"""Module for preprocessing data for training."""

import os
from functools import partial
from multiprocessing import Pool
from pathlib import Path
from typing import Callable, Optional, Sequence, Union

import xarray as xr
from pydantic import Field
from soundevent import data
from tqdm.auto import tqdm

from batdetect2.configs import BaseConfig
from batdetect2.preprocess import (
    PreprocessingConfig,
    compute_spectrogram,
    load_clip_audio,
)
from batdetect2.train.labels import LabelConfig, generate_heatmaps
from batdetect2.train.targets import (
    TargetConfig,
    build_target_encoder,
    build_sound_event_filter,
    get_class_names,
)

PathLike = Union[Path, str, os.PathLike]
FilenameFn = Callable[[data.ClipAnnotation], str]

__all__ = [
    "preprocess_annotations",
    "preprocess_single_annotation",
    "generate_train_example",
    "TrainPreprocessingConfig",
]


class TrainPreprocessingConfig(BaseConfig):
    """Bundle of all configuration needed to build training examples."""

    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    target: TargetConfig = Field(default_factory=TargetConfig)
    labels: LabelConfig = Field(default_factory=LabelConfig)
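

# A minimal configuration sketch (illustrative, not a recommendation): each
# sub-config can be overridden independently, and the bundle serialises via
# pydantic, e.g.:
#
#     config = TrainPreprocessingConfig(target=TargetConfig())
#     print(config.model_dump_json())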


def generate_train_example(
    clip_annotation: data.ClipAnnotation,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    target_config: Optional[TargetConfig] = None,
    label_config: Optional[LabelConfig] = None,
) -> xr.Dataset:
    """Generate a single training example from a clip annotation."""
    config = TrainPreprocessingConfig(
        preprocessing=preprocessing_config or PreprocessingConfig(),
        target=target_config or TargetConfig(),
        labels=label_config or LabelConfig(),
    )

    wave = load_clip_audio(
        clip_annotation.clip,
        config=config.preprocessing.audio,
    )

    spectrogram = compute_spectrogram(
        wave,
        config=config.preprocessing.spectrogram,
    )

    filter_fn = build_sound_event_filter(
        include=config.target.include,
        exclude=config.target.exclude,
    )

    selected_events = [
        event for event in clip_annotation.sound_events if filter_fn(event)
    ]

    encoder = build_target_encoder(
        config.target.classes,
        replacement_rules=config.target.replace,
    )
    class_names = get_class_names(config.target.classes)

    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
        selected_events,
        spectrogram,
        class_names,
        encoder,
        target_sigma=config.labels.heatmaps.sigma,
        position=config.labels.heatmaps.position,
        time_scale=config.labels.heatmaps.time_scale,
        frequency_scale=config.labels.heatmaps.frequency_scale,
    )

    dataset = xr.Dataset(
        {
            # NOTE: Rename the waveform's time dimension to avoid a clash with
            # the spectrogram's time dimension; otherwise xarray would
            # interpolate the spectrogram and heatmaps onto the waveform's
            # (much finer) temporal resolution.
            "audio": wave.rename({"time": "audio_time"}),
            "spectrogram": spectrogram,
            "detection": detection_heatmap,
            "class": class_heatmap,
            "size": size_heatmap,
        }
    )

    return dataset.assign_attrs(
        title=f"Training example for {clip_annotation.uuid}",
        config=config.model_dump_json(),
        clip_annotation=clip_annotation.model_dump_json(
            exclude_none=True,
            exclude_defaults=True,
            exclude_unset=True,
        ),
    )
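

# Usage sketch (assumes a `clip_annotation` obtained elsewhere, e.g. loaded
# with soundevent): the returned dataset bundles the waveform, spectrogram,
# and the three target heatmaps.
#
#     example = generate_train_example(clip_annotation)
#     print(example["spectrogram"].dims)
#     print(example["detection"].shape)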


def save_to_file(
    dataset: xr.Dataset,
    path: PathLike,
) -> None:
    """Save a training example to ``path`` as compressed netCDF."""
    dataset.to_netcdf(
        path,
        encoding={
            "spectrogram": {"zlib": True},
            "size": {"zlib": True},
            "class": {"zlib": True},
            "detection": {"zlib": True},
        },
    )
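

# The output is plain netCDF, so a saved example can be inspected directly
# with xarray (sketch; the path is hypothetical):
#
#     import xarray as xr
#
#     example = xr.open_dataset("train_examples/some-uuid.nc")
#     print(example.attrs["title"])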


def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
    return f"{clip_annotation.uuid}.nc"


def preprocess_annotations(
    clip_annotations: Sequence[data.ClipAnnotation],
    output_dir: PathLike,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    target_config: Optional[TargetConfig] = None,
    label_config: Optional[LabelConfig] = None,
    max_workers: Optional[int] = None,
) -> None:
    """Preprocess annotations in parallel and save them to disk."""
    output_dir = Path(output_dir)

    if not output_dir.is_dir():
        output_dir.mkdir(parents=True)

    with Pool(max_workers) as pool:
        # Consume the iterator so the pool actually runs; tqdm provides the
        # progress bar.
        list(
            tqdm(
                pool.imap_unordered(
                    partial(
                        preprocess_single_annotation,
                        output_dir=output_dir,
                        filename_fn=filename_fn,
                        replace=replace,
                        preprocessing_config=preprocessing_config,
                        target_config=target_config,
                        label_config=label_config,
                    ),
                    clip_annotations,
                ),
                total=len(clip_annotations),
            )
        )
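

# Typical invocation (sketch; `annotations` is assumed to be a sequence of
# data.ClipAnnotation loaded elsewhere):
#
#     preprocess_annotations(
#         annotations,
#         output_dir="train_examples",
#         max_workers=4,
#     )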


def preprocess_single_annotation(
    clip_annotation: data.ClipAnnotation,
    output_dir: PathLike,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    target_config: Optional[TargetConfig] = None,
    label_config: Optional[LabelConfig] = None,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
) -> None:
    """Generate and save the training example for one annotation."""
    output_dir = Path(output_dir)

    filename = filename_fn(clip_annotation)
    path = output_dir / filename

    if path.is_file() and not replace:
        return

    if path.is_file() and replace:
        path.unlink()

    try:
        sample = generate_train_example(
            clip_annotation,
            preprocessing_config=preprocessing_config,
            target_config=target_config,
            label_config=label_config,
        )
    except Exception as error:
        raise RuntimeError(
            f"Failed to process annotation: {clip_annotation.uuid}"
        ) from error

    save_to_file(sample, path)
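

# A custom FilenameFn can organise the output however is convenient. A sketch
# that prefixes each file with its recording's stem (assumes the clip's
# recording path is populated):
#
#     def filename_by_recording(clip_annotation: data.ClipAnnotation) -> str:
#         stem = Path(clip_annotation.clip.recording.path).stem
#         return f"{stem}_{clip_annotation.uuid}.nc"
#
#     preprocess_annotations(
#         annotations,
#         output_dir="train_examples",
#         filename_fn=filename_by_recording,
#     )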