mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00

commit f6cdd4e87e (parent 9cf159efff): Starting to create dataset builders

This commit renames the batdetect2-format annotation loaders, adds a small YAML-backed config loader, introduces dataset-definition models in a new batdetect2/data.py, and threads an audio_dir argument through the audio loading and preprocessing functions.
batdetect2/compat/data.py:

@@ -1,5 +1,6 @@
 """Compatibility functions between old and new data structures."""
 
+import json
 import os
 import uuid
 from pathlib import Path
@@ -17,7 +18,7 @@ PathLike = Union[Path, str, os.PathLike]
 
 __all__ = [
     "convert_to_annotation_group",
-    "load_annotation_project",
+    "load_annotation_project_from_dir",
 ]
 
 SPECIES_TAG_KEY = "species"
@@ -298,7 +299,38 @@ def list_file_annotations(path: PathLike) -> List[Path]:
     return [file for file in path.glob("*.json")]
 
 
-def load_annotation_project(
+def load_annotation_project_from_file(
+    path: PathLike,
+    name: Optional[str] = None,
+    audio_dir: Optional[PathLike] = None,
+) -> data.AnnotationProject:
+    old_annotations = json.loads(Path(path).read_text())
+
+    annotations = []
+    tasks = []
+
+    for ann in old_annotations:
+        try:
+            ann = FileAnnotation.model_validate(ann)
+        except ValueError:
+            continue
+
+        try:
+            clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
+        except FileNotFoundError:
+            continue
+
+        annotations.append(file_annotation_to_clip_annotation(ann, clip))
+        tasks.append(file_annotation_to_annotation_task(ann, clip))
+
+    return data.AnnotationProject(
+        name=name or str(path),
+        clip_annotations=annotations,
+        tasks=tasks,
+    )
+
+
+def load_annotation_project_from_dir(
     path: PathLike,
     name: Optional[str] = None,
     audio_dir: Optional[PathLike] = None,
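For context, a sketch of how the two loaders are meant to be called; the paths and project name below are invented for illustration:

# Hypothetical usage; paths and the project name are illustrative.
from batdetect2.compat.data import (
    load_annotation_project_from_dir,
    load_annotation_project_from_file,
)

# A single JSON file containing a list of file annotations:
project = load_annotation_project_from_file(
    "annotations/all_annotations.json",
    name="my-project",
    audio_dir="recordings/",
)

# A directory of per-recording *.json annotation files:
project = load_annotation_project_from_dir(
    "annotations/",
    audio_dir="recordings/",
)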
@ -1,5 +1,26 @@
|
|||||||
|
from typing import Optional, Type, TypeVar
|
||||||
|
|
||||||
|
import yaml
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
|
|
||||||
class BaseConfig(BaseModel):
|
class BaseConfig(BaseModel):
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(
|
||||||
|
path: PathLike,
|
||||||
|
schema: Type[T],
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> T:
|
||||||
|
with open(path, "r") as file:
|
||||||
|
config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
if field:
|
||||||
|
config = config[field]
|
||||||
|
|
||||||
|
return schema.model_validate(config)
|
||||||
|
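A sketch of how load_config is meant to be used; the schema class, its fields, and the YAML layout below are invented for illustration:

# Illustrative only: TrainingConfig and the YAML layout are made up.
from batdetect2.configs import BaseConfig, load_config


class TrainingConfig(BaseConfig):
    learning_rate: float = 1e-3
    batch_size: int = 32


# Given a config.yaml such as:
#
#   train:
#     learning_rate: 0.001
#     batch_size: 64
#
# the optional `field` argument selects a nested section before validation:
config = load_config("config.yaml", TrainingConfig, field="train")
print(config.batch_size)  # 64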
batdetect2/data.py (new file, 87 lines):

@@ -0,0 +1,87 @@
+from pathlib import Path
+from typing import List, Literal, Tuple, Union
+
+from pydantic import Field
+from soundevent import data, io
+
+from batdetect2.compat.data import (
+    load_annotation_project_from_dir,
+    load_annotation_project_from_file,
+)
+from batdetect2.configs import BaseConfig
+
+
+class BatDetect2AnnotationFiles(BaseConfig):
+    format: Literal["batdetect2"] = "batdetect2"
+    path: Path
+
+
+class BatDetect2AnnotationFile(BaseConfig):
+    format: Literal["batdetect2_file"] = "batdetect2_file"
+    path: Path
+
+
+class AOEFAnnotationFile(BaseConfig):
+    format: Literal["aoef"] = "aoef"
+    annotations_file: Path
+
+
+AnnotationFormats = Union[
+    BatDetect2AnnotationFiles,
+    BatDetect2AnnotationFile,
+    AOEFAnnotationFile,
+]
+
+
+class DatasetInfo(BaseConfig):
+    name: str
+    audio_dir: Path
+    annotations: AnnotationFormats = Field(discriminator="format")
+
+
+class DatasetsConfig(BaseConfig):
+    train: List[DatasetInfo] = Field(default_factory=list)
+    test: List[DatasetInfo] = Field(default_factory=list)
+
+
+def load_dataset(info: DatasetInfo) -> data.AnnotationProject:
+    if info.annotations.format == "batdetect2":
+        return load_annotation_project_from_dir(
+            info.annotations.path,
+            name=info.name,
+            audio_dir=info.audio_dir,
+        )
+
+    if info.annotations.format == "batdetect2_file":
+        return load_annotation_project_from_file(
+            info.annotations.path,
+            name=info.name,
+            audio_dir=info.audio_dir,
+        )
+
+    if info.annotations.format == "aoef":
+        return io.load(  # type: ignore
+            info.annotations.annotations_file,
+            audio_dir=info.audio_dir,
+        )
+
+    raise NotImplementedError(
+        f"Unknown annotation format: {info.annotations.format}"
+    )
+
+
+def load_datasets(
+    config: DatasetsConfig,
+) -> Tuple[List[data.ClipAnnotation], List[data.ClipAnnotation]]:
+    test_annotations = []
+    train_annotations = []
+
+    for dataset in config.train:
+        project = load_dataset(dataset)
+        train_annotations.extend(project.clip_annotations)
+
+    for dataset in config.test:
+        project = load_dataset(dataset)
+        test_annotations.extend(project.clip_annotations)
+
+    return train_annotations, test_annotations
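To see how these pieces compose, here is a hypothetical end-to-end use; the YAML content, dataset names, and file paths are invented for illustration:

# Hypothetical use of DatasetsConfig; YAML content and paths are invented.
from batdetect2.configs import load_config
from batdetect2.data import DatasetsConfig, load_datasets

# A datasets.yaml might look like:
#
#   train:
#     - name: uk-bats
#       audio_dir: audio/uk
#       annotations:
#         format: batdetect2
#         path: annotations/uk
#   test:
#     - name: mexico-bats
#       audio_dir: audio/mx
#       annotations:
#         format: aoef
#         annotations_file: annotations/mx.json
#
# The `format` field picks the annotation schema via the discriminator.
config = load_config("datasets.yaml", DatasetsConfig)
train_annotations, test_annotations = load_datasets(config)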
Deleted file (a generic ClipDataset module):

@@ -1,33 +0,0 @@
-from typing import Callable, Generic, Iterable, List, TypeVar
-
-from soundevent import data
-from torch.utils.data import Dataset
-
-__all__ = [
-    "ClipDataset",
-]
-
-
-E = TypeVar("E")
-
-
-class ClipDataset(Dataset, Generic[E]):
-    clips: List[data.Clip]
-
-    transform: Callable[[data.Clip], E]
-
-    def __init__(
-        self,
-        clips: Iterable[data.Clip],
-        transform: Callable[[data.Clip], E],
-        name: str = "ClipDataset",
-    ):
-        self.clips = list(clips)
-        self.transform = transform
-        self.name = name
-
-    def __len__(self) -> int:
-        return len(self.clips)
-
-    def __getitem__(self, idx: int) -> E:
-        return self.transform(self.clips[idx])
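For reference, the removed class was a thin adapter wrapping an iterable of soundevent Clips and a Clip-to-example transform into a torch Dataset; the names in this sketch are placeholders:

# Roughly what the removed class provided (`clips` and `my_transform`
# are hypothetical placeholders, not names from the codebase).
def my_transform(clip):
    return clip.start_time  # any Clip -> example function

dataset = ClipDataset(clips, transform=my_transform)
example = dataset[0]  # equivalent to my_transform(dataset.clips[0])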
In the preprocessing module (file path not shown in this view):

@@ -49,6 +49,7 @@ class PreprocessingConfig(BaseModel):
 def preprocess_audio_clip(
     clip: data.Clip,
     config: Optional[PreprocessingConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
 ) -> xr.DataArray:
     """Preprocesses audio clip to generate spectrogram.
 
@@ -66,5 +67,5 @@ def preprocess_audio_clip
 
     """
     config = config or PreprocessingConfig()
-    wav = load_clip_audio(clip, config=config.audio)
+    wav = load_clip_audio(clip, config=config.audio, audio_dir=audio_dir)
     return compute_spectrogram(wav, config=config.spectrogram)
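With the new argument, a clip whose recording is stored under a relative path can be resolved against a base directory at preprocessing time; a minimal sketch, with an illustrative directory:

# Sketch: `clip` is assumed to reference a recording stored with a path
# relative to /data/recordings.
spec = preprocess_audio_clip(clip, audio_dir="/data/recordings")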
In the audio loading module (file path not shown in this view):

@@ -30,15 +30,22 @@ class AudioConfig(BaseConfig):
 def load_file_audio(
     path: data.PathLike,
     config: Optional[AudioConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     recording = data.Recording.from_file(path)
-    return load_recording_audio(recording, config=config, dtype=dtype)
+    return load_recording_audio(
+        recording,
+        config=config,
+        dtype=dtype,
+        audio_dir=audio_dir,
+    )
 
 
 def load_recording_audio(
     recording: data.Recording,
     config: Optional[AudioConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     clip = data.Clip(
@@ -46,17 +53,25 @@ def load_recording_audio(
         start_time=0,
         end_time=recording.duration,
     )
-    return load_clip_audio(clip, config=config, dtype=dtype)
+    return load_clip_audio(
+        clip,
+        config=config,
+        dtype=dtype,
+        audio_dir=audio_dir,
+    )
 
 
 def load_clip_audio(
     clip: data.Clip,
     config: Optional[AudioConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     config = config or AudioConfig()
 
-    wav = audio.load_clip(clip).sel(channel=0).astype(dtype)
+    wav = (
+        audio.load_clip(clip, audio_dir=audio_dir).sel(channel=0).astype(dtype)
+    )
 
     if config.duration is not None:
         wav = adjust_audio_duration(wav, duration=config.duration)
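The same audio_dir now flows through all three loaders, so callers can keep relative recording paths in their metadata; a sketch with illustrative paths:

# Sketch: resolve relative recording paths against a base directory.
wav = load_file_audio("calls/night1.wav", audio_dir="/data/recordings")
wav = load_clip_audio(clip, audio_dir="/data/recordings")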
In the training dataset module (file path not shown in this view):

@@ -28,10 +28,6 @@ class TrainExample(NamedTuple):
     idx: torch.Tensor
 
 
-def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
-    return list(Path(directory).glob(f"*{extension}"))
-
-
 class LabeledDataset(Dataset):
     def __init__(
         self,
@@ -92,3 +88,7 @@ class LabeledDataset(Dataset):
         return data.ClipAnnotation.model_validate_json(
             self.get_dataset(idx).attrs["clip_annotation"]
         )
+
+
+def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
+    return list(Path(directory).glob(f"*{extension}"))
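get_files is unchanged apart from moving below the class it serves; it simply globs preprocessed training examples (the directory below is illustrative):

# All *.nc files in a directory of preprocessed training examples.
files = get_files("train_examples/")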
pyproject.toml:

@@ -26,6 +26,8 @@ dependencies = [
     "onnx>=1.16.0",
     "lightning[extra]>=2.2.2",
     "tensorboard>=2.16.2",
+    "omegaconf>=2.3.0",
+    "pyyaml>=6.0.2",
 ]
 requires-python = ">=3.9,<3.13"
 readme = "README.md"
In a test module (file path not shown in this view):

@@ -5,7 +5,7 @@ from typing import List
 import numpy as np
 import pytest
 
-from batdetect2.compat.data import load_annotation_project
+from batdetect2.compat.data import load_annotation_project_from_dir
 from batdetect2.compat.params import get_training_preprocessing_config
 from batdetect2.train.preprocess import generate_train_example
 
@@ -36,7 +36,7 @@ def test_can_generate_similar_training_inputs(
     size_mask = dataset["size_mask"]
     class_mask = dataset["class_mask"]
 
-    project = load_annotation_project(
+    project = load_annotation_project_from_dir(
         example_anns_dir,
         audio_dir=example_audio_dir,
    )
uv.lock (generated; 4 lines added):

@@ -198,9 +198,11 @@ dependencies = [
     { name = "matplotlib" },
     { name = "netcdf4" },
     { name = "numpy" },
+    { name = "omegaconf" },
     { name = "onnx" },
     { name = "pandas" },
     { name = "pytorch-lightning" },
+    { name = "pyyaml" },
     { name = "scikit-learn" },
     { name = "scipy" },
     { name = "soundevent", extra = ["audio", "geometry", "plot"] },
@@ -231,9 +233,11 @@ requires-dist = [
     { name = "matplotlib", specifier = ">=3.7.1" },
     { name = "netcdf4", specifier = ">=1.6.5" },
     { name = "numpy", specifier = ">=1.23.5" },
+    { name = "omegaconf", specifier = ">=2.3.0" },
     { name = "onnx", specifier = ">=1.16.0" },
     { name = "pandas", specifier = ">=1.5.3" },
     { name = "pytorch-lightning", specifier = ">=2.2.2" },
+    { name = "pyyaml", specifier = ">=6.0.2" },
     { name = "scikit-learn", specifier = ">=1.2.2" },
     { name = "scipy", specifier = ">=1.10.1" },
     { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.3" },