batdetect2/batdetect2/train/preprocess.py
2025-04-04 10:58:54 +01:00

209 lines
6.0 KiB
Python

"""Module for preprocessing data for training."""
import os
from functools import partial
from multiprocessing import Pool
from pathlib import Path
from typing import Callable, Optional, Sequence, Union
import xarray as xr
from pydantic import Field
from soundevent import data
from tqdm.auto import tqdm
from batdetect2.configs import BaseConfig
from batdetect2.preprocess import (
PreprocessingConfig,
compute_spectrogram,
load_clip_audio,
)
from batdetect2.train.labels import LabelConfig, generate_heatmaps
from batdetect2.train.targets import (
TargetConfig,
build_target_encoder,
build_sound_event_filter,
get_class_names,
)
# Anything the functions in this module accept as a filesystem path.
PathLike = Union[Path, str, os.PathLike]
# Maps a clip annotation to the filename (joined onto the output directory)
# under which its training example is stored.
FilenameFn = Callable[[data.ClipAnnotation], str]
# Names exported by ``from ... import *``.
__all__ = [
    "preprocess_annotations",
    "preprocess_single_annotation",
    "generate_train_example",
    "TrainPreprocessingConfig",
]
class TrainPreprocessingConfig(BaseConfig):
    """All configuration sections needed to build one training example.

    Every section is optional; an omitted section is filled in with a
    freshly constructed default instance of its config class.
    """

    # Audio loading and spectrogram computation settings.
    preprocessing: PreprocessingConfig = Field(default_factory=PreprocessingConfig)
    # Sound-event filtering (include/exclude), class definitions and
    # replacement rules used to encode targets.
    target: TargetConfig = Field(default_factory=TargetConfig)
    # Heatmap/label generation settings.
    labels: LabelConfig = Field(default_factory=LabelConfig)
def generate_train_example(
    clip_annotation: data.ClipAnnotation,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    target_config: Optional[TargetConfig] = None,
    label_config: Optional[LabelConfig] = None,
) -> xr.Dataset:
    """Generate a training example.

    Loads the clip's audio, computes its spectrogram, keeps only the sound
    events accepted by the configured target filter, and renders the
    detection, class and size heatmaps for those events.

    Args:
        clip_annotation: The annotated clip to convert.
        preprocessing_config: Audio and spectrogram settings. A default
            ``PreprocessingConfig()`` is used when omitted.
        target_config: Event filtering and class encoding settings. A
            default ``TargetConfig()`` is used when omitted.
        label_config: Heatmap generation settings. A default
            ``LabelConfig()`` is used when omitted.

    Returns:
        An ``xr.Dataset`` with the variables ``audio``, ``spectrogram``,
        ``detection``, ``class`` and ``size``. Its attributes carry a
        human-readable ``title``, the full configuration as JSON
        (``config``), and the clip annotation as compact JSON
        (``clip_annotation``).
    """
    config = TrainPreprocessingConfig(
        preprocessing=preprocessing_config or PreprocessingConfig(),
        target=target_config or TargetConfig(),
        labels=label_config or LabelConfig(),
    )
    preprocessing = config.preprocessing
    targets = config.target

    wave = load_clip_audio(clip_annotation.clip, config=preprocessing.audio)
    spectrogram = compute_spectrogram(wave, config=preprocessing.spectrogram)

    keep_event = build_sound_event_filter(
        include=targets.include,
        exclude=targets.exclude,
    )
    selected_events = list(filter(keep_event, clip_annotation.sound_events))

    encoder = build_target_encoder(
        targets.classes,
        replacement_rules=targets.replace,
    )
    class_names = get_class_names(targets.classes)

    heatmap_cfg = config.labels.heatmaps
    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
        selected_events,
        spectrogram,
        class_names,
        encoder,
        target_sigma=heatmap_cfg.sigma,
        position=heatmap_cfg.position,
        time_scale=heatmap_cfg.time_scale,
        frequency_scale=heatmap_cfg.frequency_scale,
    )

    data_vars = {
        # The waveform is sampled far more finely in time than the
        # spectrogram (and, presumably, the heatmaps derived from it). If
        # they all shared a "time" dimension, xarray would align them on a
        # common time coordinate when building the Dataset (reindexing onto
        # the union of coordinates) instead of keeping each variable at its
        # native resolution. Renaming the waveform's dimension keeps them
        # independent.
        "audio": wave.rename({"time": "audio_time"}),
        "spectrogram": spectrogram,
        "detection": detection_heatmap,
        "class": class_heatmap,
        "size": size_heatmap,
    }
    dataset = xr.Dataset(data_vars)

    attrs = dict(
        title=f"Training example for {clip_annotation.uuid}",
        config=config.model_dump_json(),
        # Drop None/default/unset fields to keep the stored JSON compact.
        clip_annotation=clip_annotation.model_dump_json(
            exclude_none=True,
            exclude_defaults=True,
            exclude_unset=True,
        ),
    )
    return dataset.assign_attrs(**attrs)
def save_to_file(
    dataset: xr.Dataset,
    path: PathLike,
) -> None:
    """Write a training example dataset to ``path`` in netCDF format.

    The spectrogram and the three heatmap variables are stored with zlib
    compression. Every other variable (e.g. ``audio``) is written with the
    default encoding.
    """
    compressed_vars = ("spectrogram", "size", "class", "detection")
    encoding = {name: {"zlib": True} for name in compressed_vars}
    dataset.to_netcdf(path, encoding=encoding)
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
    """Default ``FilenameFn``: the annotation's UUID plus a ``.nc`` suffix."""
    uuid = clip_annotation.uuid
    return f"{uuid}.nc"
def preprocess_annotations(
    clip_annotations: Sequence[data.ClipAnnotation],
    output_dir: PathLike,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    target_config: Optional[TargetConfig] = None,
    label_config: Optional[LabelConfig] = None,
    max_workers: Optional[int] = None,
) -> None:
    """Preprocess annotations and save to disk.

    Each clip annotation is converted into a training example by
    ``preprocess_single_annotation`` in a ``multiprocessing.Pool``, and the
    result is written to ``output_dir``. A progress bar tracks completion.

    Args:
        clip_annotations: The annotated clips to preprocess.
        output_dir: Destination directory. It is created (with parents) if
            it does not exist.
        filename_fn: Maps a clip annotation to its output filename. It is
            sent to worker processes, so it must be picklable (e.g. a
            module-level function, not a lambda).
        replace: If true, regenerate and overwrite existing files;
            otherwise clips whose output file already exists are skipped.
        preprocessing_config: Audio/spectrogram settings (default if None).
        target_config: Target filtering/encoding settings (default if None).
        label_config: Heatmap generation settings (default if None).
        max_workers: Number of worker processes; ``None`` lets ``Pool``
            pick its default.

    Raises:
        RuntimeError: Propagated from a worker when generating any single
            training example fails.
        FileExistsError: If ``output_dir`` exists but is not a directory.
    """
    output_dir = Path(output_dir)
    # exist_ok avoids the check-then-create race of a separate is_dir()
    # test; it still raises if the path exists and is not a directory.
    output_dir.mkdir(parents=True, exist_ok=True)

    process_one = partial(
        preprocess_single_annotation,
        output_dir=output_dir,
        filename_fn=filename_fn,
        replace=replace,
        preprocessing_config=preprocessing_config,
        target_config=target_config,
        label_config=label_config,
    )

    with Pool(max_workers) as pool:
        results = pool.imap_unordered(process_one, clip_annotations)
        # The workers return None. The iterator must be fully consumed
        # before leaving the ``with`` block (which terminates the pool);
        # consuming it also re-raises any worker exception here.
        for _ in tqdm(results, total=len(clip_annotations)):
            pass
def preprocess_single_annotation(
    clip_annotation: data.ClipAnnotation,
    output_dir: PathLike,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    target_config: Optional[TargetConfig] = None,
    label_config: Optional[LabelConfig] = None,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
) -> None:
    """Generate and save the training example for a single clip annotation.

    The example is written to ``output_dir / filename_fn(clip_annotation)``.
    If that file already exists, it is left untouched unless ``replace`` is
    true. In that case the example is regenerated, and the old file is
    removed only once the new example has been generated successfully.

    Args:
        clip_annotation: The annotated clip to convert.
        output_dir: Directory in which to write the example file.
        preprocessing_config: Audio/spectrogram settings (default if None).
        target_config: Target filtering/encoding settings (default if None).
        label_config: Heatmap generation settings (default if None).
        filename_fn: Maps the clip annotation to its output filename.
        replace: Whether to overwrite an existing output file.

    Raises:
        RuntimeError: If generating the training example fails. The
            underlying exception is chained as ``__cause__``.
    """
    path = Path(output_dir) / filename_fn(clip_annotation)

    if path.is_file() and not replace:
        return

    try:
        sample = generate_train_example(
            clip_annotation,
            preprocessing_config=preprocessing_config,
            target_config=target_config,
            label_config=label_config,
        )
    except Exception as error:
        raise RuntimeError(
            f"Failed to process annotation: {clip_annotation.uuid}"
        ) from error

    # Remove the stale file only after the replacement has been generated,
    # so a failure above does not destroy a previously saved example.
    if replace and path.is_file():
        path.unlink()

    save_to_file(sample, path)