Fix train preprocessing

mbsantiago 2025-08-25 14:00:56 +01:00
parent c36ef3ecb5
commit 76dda0a0e9
3 changed files with 93 additions and 273 deletions

View File

@@ -65,16 +65,6 @@ __all__ = ["preprocess"]
        "top level, you don't need to specify this."
    ),
)
@click.option(
    "--force",
    is_flag=True,
    help=(
        "If a preprocessed file already exists, this option tells the "
        "program to overwrite it with the new preprocessed data. Use "
        "this if you want to re-do the preprocessing even if the files "
        "already exist."
    ),
)
@click.option(
    "--num-workers",
    type=int,
@@ -97,7 +87,6 @@ def preprocess(
    base_dir: Optional[Path] = None,
    config: Optional[Path] = None,
    config_field: Optional[str] = None,
    force: bool = False,
    num_workers: Optional[int] = None,
    dataset_field: Optional[str] = None,
    verbose: int = 0,
@@ -149,6 +138,5 @@ def preprocess(
        dataset,
        conf,
        output=output,
        force=force,
        max_workers=num_workers,
    )
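
With `--force` removed here and the `replace` plumbing removed below, preprocessing now always writes its outputs. A caller who wants the old skip-existing behaviour could pre-filter the annotations instead. A minimal sketch under the assumption (from the second file in this commit) that outputs are `.npz` files named after the annotation UUID; the helper itself is hypothetical, not part of the commit:

# Sketch, not part of the commit.
from pathlib import Path
from typing import List

from soundevent import data


def drop_already_preprocessed(
    clip_annotations: List[data.ClipAnnotation],
    output_dir: Path,
) -> List[data.ClipAnnotation]:
    # Hypothetical pre-filter replicating the removed skip-existing logic.
    done = {path.stem for path in output_dir.glob("*.npz")}
    return [
        annotation
        for annotation in clip_annotations
        if str(annotation.uuid) not in done
    ]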

View File

@@ -20,18 +20,16 @@ computationally intensive steps during the actual training loop. The module
includes utilities for parallel processing using `multiprocessing`.
"""
from functools import partial
from multiprocessing import Pool
import os
from pathlib import Path
from typing import Callable, Optional, Sequence
from typing import Callable, Dict, Optional, Sequence
import numpy as np
import torch
import xarray as xr
import torch.utils.data
from loguru import logger
from pydantic import Field
from soundevent import data
from tqdm.auto import tqdm
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.datasets import Dataset
@@ -41,11 +39,10 @@ from batdetect2.targets import TargetConfig, build_targets
from batdetect2.train.labels import LabelConfig, build_clip_labeler
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.utils.arrays import audio_to_xarray
from batdetect2.typing.train import PreprocessedExample
__all__ = [
    "preprocess_annotations",
    "preprocess_single_annotation",
    "generate_train_example",
    "preprocess_dataset",
    "TrainPreprocessConfig",
@@ -75,12 +72,16 @@ def preprocess_dataset(
    dataset: Dataset,
    config: TrainPreprocessConfig,
    output: Path,
    force: bool = False,
    max_workers: Optional[int] = None,
) -> None:
    targets = build_targets(config=config.targets)
    preprocessor = build_preprocessor(config=config.preprocess)
    labeller = build_clip_labeler(targets, config=config.labels)
    labeller = build_clip_labeler(
        targets,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        config=config.labels,
    )
    audio_loader = build_audio_loader(config=config.preprocess.audio)
    if not output.exists():
@@ -93,7 +94,6 @@ def preprocess_dataset(
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        labeller=labeller,
        replace=force,
        max_workers=max_workers,
    )
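
The substantive change in this hunk: the labeller is now built with the preprocessor's frequency bounds, presumably so the generated heatmaps share the spectrogram's frequency axis rather than a separately configured range. A rough sketch of the wiring, using only names imported in this file; constructing the config with defaults is purely illustrative:

# Sketch, not part of the commit.
config = TrainPreprocessConfig()  # illustration; normally loaded from file
preprocessor = build_preprocessor(config=config.preprocess)
labeller = build_clip_labeler(
    build_targets(config=config.targets),
    # the heatmap frequency axis now follows the spectrogram's range
    min_freq=preprocessor.min_freq,
    max_freq=preprocessor.max_freq,
    config=config.labels,
)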
@@ -103,142 +103,60 @@ def generate_train_example(
    audio_loader: AudioLoader,
    preprocessor: PreprocessorProtocol,
    labeller: ClipLabeller,
) -> xr.Dataset:
"""Generate a complete training example for one annotation.
This function takes a single `ClipAnnotation`, applies the configured
preprocessing (`PreprocessorProtocol`) to get the processed waveform and
input spectrogram, applies the configured target generation
(`ClipLabeller`) to get the target heatmaps, and packages them all into a
single `xr.Dataset`.
Parameters
----------
clip_annotation : data.ClipAnnotation
The annotated clip to process. Contains the reference to the `Clip`
(audio segment) and the associated `SoundEventAnnotation` objects.
preprocessor : PreprocessorProtocol
An initialized preprocessor object responsible for loading/processing
audio and computing the input spectrogram.
labeller : ClipLabeller
An initialized clip labeller function responsible for generating the
target heatmaps (detection, class, size) from the `clip_annotation`
and the computed spectrogram.
Returns
-------
xr.Dataset
An xarray Dataset containing the following data variables:
- `audio`: The preprocessed audio waveform (dims: 'audio_time').
- `spectrogram`: The computed input spectrogram
(dims: 'time', 'frequency').
- `detection`: The target detection heatmap
(dims: 'time', 'frequency').
- `class`: The target class heatmap
(dims: 'category', 'time', 'frequency').
- `size`: The target size heatmap
(dims: 'dimension', 'time', 'frequency').
The Dataset also includes metadata in its attributes.
Notes
-----
- The 'time' dimension of the 'audio' DataArray is renamed to 'audio_time'
within the output Dataset to avoid coordinate conflicts with the
spectrogram's 'time' dimension when stored together.
- The original `ClipAnnotation` metadata is stored as a JSON string in the
Dataset's attributes for provenance.
"""
    wave = audio_loader.load_clip(clip_annotation.clip)
    spectrogram = _spec_to_xr(
        preprocessor(torch.tensor(wave)),
        start_time=clip_annotation.clip.start_time,
        end_time=clip_annotation.clip.end_time,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
    )
) -> Dict[str, torch.Tensor]:
    """Generate a complete training example for one annotation."""
    wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
    spectrogram = preprocessor(wave)
    heatmaps = labeller(clip_annotation, spectrogram)
    dataset = xr.Dataset(
        {
            # NOTE: Need to rename the time dimension to avoid conflicts with
            # the spectrogram time dimension, otherwise xarray will interpolate
            # the spectrogram and the heatmaps to the same temporal resolution
            # as the waveform.
            "audio": audio_to_xarray(
                wave,
                start_time=clip_annotation.clip.start_time,
                end_time=clip_annotation.clip.end_time,
                time_axis="audio_time",
            ),
            "spectrogram": spectrogram,
            "detection": heatmaps.detection,
            "class": heatmaps.classes,
            "size": heatmaps.size,
        }
    )
    return dataset.assign_attrs(
        title=f"Training example for {clip_annotation.uuid}",
        clip_annotation=clip_annotation.model_dump_json(
            exclude_none=True,
            exclude_defaults=True,
            exclude_unset=True,
        ),
    return dict(
        audio=wave,
        spectrogram=spectrogram,
        detection_heatmap=heatmaps.detection,
        class_heatmap=heatmaps.classes,
        size_heatmap=heatmaps.size,
    )
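
The function now returns a plain dict of tensors instead of an `xr.Dataset`. A quick usage sketch; the keys are taken from the return above, the surrounding objects are assumed to be in scope, and shapes depend on the preprocessor:

# Sketch, not part of the commit.
example = generate_train_example(
    clip_annotation,
    audio_loader=audio_loader,
    preprocessor=preprocessor,
    labeller=labeller,
)
# All values are torch.Tensors, keyed "audio", "spectrogram",
# "detection_heatmap", "class_heatmap" and "size_heatmap".
spectrogram = example["spectrogram"]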
def _spec_to_xr(
    spec: torch.Tensor,
    start_time: float,
    end_time: float,
    min_freq: float,
    max_freq: float,
) -> xr.DataArray:
    data = spec.numpy()[0, 0]
class PreprocessingDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        clips: Dataset,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        labeller: ClipLabeller,
    ):
        self.clips = clips
        self.audio_loader = audio_loader
        self.preprocessor = preprocessor
        self.labeller = labeller
    height, width = data.shape
    def __getitem__(self, idx) -> dict:
        clip_annotation = self.clips[idx]
        example = generate_train_example(
            clip_annotation,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            labeller=self.labeller,
        )
        example["idx"] = idx
        return example
    return xr.DataArray(
        data=data,
        dims=[
            "frequency",
            "time",
        ],
        coords={
            "frequency": np.linspace(
                min_freq, max_freq, height, endpoint=False
            ),
            "time": np.linspace(start_time, end_time, width, endpoint=False),
        },
    )
    def __len__(self) -> int:
        return len(self.clips)
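
`PreprocessingDataset` is a standard map-style `torch.utils.data.Dataset`; the integer `idx` stored in each example lets the consumer recover the originating `ClipAnnotation` after the dict has crossed a worker-process boundary. A sketch of the intended pattern, with an illustrative worker count:

# Sketch, not part of the commit.
dataset = PreprocessingDataset(
    clips=list(clip_annotations),
    audio_loader=audio_loader,
    preprocessor=preprocessor,
    labeller=labeller,
)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=None,  # yield one example dict at a time, no collation
    num_workers=4,    # preprocessing runs in parallel worker processes
)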
def _save_xr_dataset_to_file(
    dataset: xr.Dataset,
def _save_example_to_file(
    example: PreprocessedExample,
    path: data.PathLike,
) -> None:
    """Save an xarray Dataset to a NetCDF file with compression.

    Internal helper function used by `preprocess_single_annotation`.

    Parameters
    ----------
    dataset : xr.Dataset
        The training example dataset to save.
    path : PathLike
        The output file path (e.g., 'output/uuid.nc').
    """
    dataset.to_netcdf(
    np.savez_compressed(
        path,
        encoding={
            "audio": {"zlib": True},
            "spectrogram": {"zlib": True},
            "size": {"zlib": True},
            "class": {"zlib": True},
            "detection": {"zlib": True},
        },
        audio=example.audio,
        spectrogram=example.spectrogram,
        detection_heatmap=example.detection_heatmap,
        class_heatmap=example.class_heatmap,
        size_heatmap=example.size_heatmap,
        clip_annotation=example.clip_annotation,
    )
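
Examples are now stored as compressed NumPy archives instead of NetCDF. Note that `clip_annotation` is not an array, so `np.savez_compressed` pickles it into a 0-d object array; a sketch of reading one file back (the file name is illustrative):

# Sketch, not part of the commit.
import numpy as np

with np.load("output/example.npz", allow_pickle=True) as archive:
    spectrogram = archive["spectrogram"]
    detection = archive["detection_heatmap"]
    # the pickled annotation must be unwrapped from its 0-d object array
    clip_annotation = archive["clip_annotation"].item()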
@@ -254,50 +172,9 @@ def preprocess_annotations(
    audio_loader: AudioLoader,
    labeller: ClipLabeller,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
    max_workers: Optional[int] = None,
) -> None:
"""Preprocess a sequence of ClipAnnotations and save results to disk.
Generates the full training example (spectrogram, heatmaps, etc.) for each
`ClipAnnotation` in the input sequence using the provided `preprocessor`
and `labeller`. Saves each example as a separate NetCDF file in the
`output_dir`. Utilizes multiprocessing for potentially faster processing.
Parameters
----------
clip_annotations : Sequence[data.ClipAnnotation]
A sequence (e.g., list) of the clip annotations to preprocess.
output_dir : PathLike
Path to the directory where the processed NetCDF files will be saved.
Will be created if it doesn't exist.
preprocessor : PreprocessorProtocol
Initialized preprocessor object to generate spectrograms.
labeller : ClipLabeller
Initialized labeller function to generate target heatmaps.
filename_fn : FilenameFn, optional
Function to generate the output filename (without extension) for each
`ClipAnnotation`. Defaults to using the annotation UUID via
`_get_filename`.
replace : bool, default=False
If True, existing files in `output_dir` with the same generated name
will be overwritten. If False (default), existing files are skipped.
max_workers : int, optional
Maximum number of worker processes to use for parallel processing.
If None (default), uses the number of CPUs available (`os.cpu_count()`).
Returns
-------
None
This function does not return anything; its side effect is creating
files in the `output_dir`.
Raises
------
RuntimeError
If processing fails for any individual annotation when using
multiprocessing. The original exception will be attached as the cause.
"""
"""Preprocess a sequence of ClipAnnotations and save results to disk."""
output_dir = Path(output_dir)
if not output_dir.is_dir():
@@ -312,88 +189,33 @@ def preprocess_annotations(
        max_workers=max_workers or "all available",
    )
    with Pool(max_workers) as pool:
        list(
            tqdm(
                pool.imap_unordered(
                    partial(
                        preprocess_single_annotation,
                        output_dir=output_dir,
                        filename_fn=filename_fn,
                        replace=replace,
                        audio_loader=audio_loader,
                        preprocessor=preprocessor,
                        labeller=labeller,
                    ),
                    clip_annotations,
                ),
                total=len(clip_annotations),
                desc="Preprocessing annotations",
            )
        )
    if max_workers is None:
        max_workers = os.cpu_count() or 0
    dataset = PreprocessingDataset(
        clips=list(clip_annotations),
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        labeller=labeller,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=max_workers,
    )
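    # batch_size=None disables automatic batching, so each "batch" below is a
    # single example dict; parallelism comes from the DataLoader workers, and
    # shuffle=False preserves the input order.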
    for batch in loader:
        clip_annotation = dataset.clips[batch["idx"]]
        filename = filename_fn(clip_annotation)
        path = output_dir / filename
        example = PreprocessedExample(
            clip_annotation=clip_annotation,
            spectrogram=batch["spectrogram"].numpy(),
            audio=batch["audio"].numpy(),
            class_heatmap=batch["class_heatmap"].numpy(),
            size_heatmap=batch["size_heatmap"].numpy(),
            detection_heatmap=batch["detection_heatmap"].numpy(),
        )
logger.info("Finished preprocessing.")
def preprocess_single_annotation(
clip_annotation: data.ClipAnnotation,
output_dir: data.PathLike,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
filename_fn: FilenameFn = _get_filename,
replace: bool = False,
) -> None:
"""Process a single ClipAnnotation and save the result to a file.
Internal function designed to be called by `preprocess_annotations`, often
in parallel worker processes. It generates the training example using
`generate_train_example` and saves it using `save_to_file`. Handles
file existence checks based on the `replace` flag.
Parameters
----------
clip_annotation : data.ClipAnnotation
The single annotation to process.
output_dir : Path
The directory where the output NetCDF file should be saved.
preprocessor : PreprocessorProtocol
Initialized preprocessor object.
labeller : ClipLabeller
Initialized labeller function.
filename_fn : FilenameFn, default=_get_filename
Function to determine the output filename.
replace : bool, default=False
Whether to overwrite existing output files.
"""
output_dir = Path(output_dir)
filename = filename_fn(clip_annotation)
path = output_dir / filename
if path.is_file() and not replace:
logger.debug("Skipping existing file: {path}", path=path)
return
if path.is_file() and replace:
logger.debug("Removing existing file: {path}", path=path)
path.unlink()
logger.debug("Processing annotation {uuid}", uuid=clip_annotation.uuid)
try:
sample = generate_train_example(
clip_annotation,
audio_loader=audio_loader,
preprocessor=preprocessor,
labeller=labeller,
)
except Exception as error:
logger.error(
"Failed to process annotation {uuid} to {path}. Error: {error}",
uuid=clip_annotation.uuid,
path=path,
error=error,
)
return
_save_xr_dataset_to_file(sample, path)
        _save_example_to_file(example, path)

View File

@@ -1,5 +1,6 @@
from typing import Callable, NamedTuple, Protocol, Tuple
import numpy as np
import torch
import xarray as xr
from soundevent import data
@@ -42,6 +43,15 @@ class Heatmaps(NamedTuple):
    size: torch.Tensor


class PreprocessedExample(NamedTuple):
    audio: np.ndarray
    spectrogram: np.ndarray
    detection_heatmap: np.ndarray
    class_heatmap: np.ndarray
    size_heatmap: np.ndarray
    clip_annotation: data.ClipAnnotation
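
A small illustration of the container, using the heatmap layouts described in the removed docstring earlier in this commit (`class` as (category, time, frequency), `size` as (dimension, time, frequency)); every shape is a placeholder, and `annotation` is assumed to be an existing `data.ClipAnnotation`:

# Sketch, not part of the commit; placeholder shapes only.
example = PreprocessedExample(
    audio=np.zeros(2048, dtype=np.float32),
    spectrogram=np.zeros((128, 512), dtype=np.float32),
    detection_heatmap=np.zeros((128, 512), dtype=np.float32),
    class_heatmap=np.zeros((5, 128, 512), dtype=np.float32),  # 5 categories
    size_heatmap=np.zeros((2, 128, 512), dtype=np.float32),   # width, height
    clip_annotation=annotation,  # assumed: an existing data.ClipAnnotation
)
audio, spectrogram, *_ = example  # NamedTuple: positional unpacking works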
ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
"""Type alias for the final clip labelling function.