From 76dda0a0e95bafc51beca57096c81d5656e80ac5 Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Mon, 25 Aug 2025 14:00:56 +0100
Subject: [PATCH] Fix train preprocessing

---
 src/batdetect2/cli/preprocess.py   |  12 -
 src/batdetect2/train/preprocess.py | 344 +++++++----------------------
 src/batdetect2/typing/train.py     |  10 +
 3 files changed, 93 insertions(+), 273 deletions(-)

diff --git a/src/batdetect2/cli/preprocess.py b/src/batdetect2/cli/preprocess.py
index 0ab537c..b539d5d 100644
--- a/src/batdetect2/cli/preprocess.py
+++ b/src/batdetect2/cli/preprocess.py
@@ -65,16 +65,6 @@ __all__ = ["preprocess"]
         "top level, you don't need to specify this."
     ),
 )
-@click.option(
-    "--force",
-    is_flag=True,
-    help=(
-        "If a preprocessed file already exists, this option tells the "
-        "program to overwrite it with the new preprocessed data. Use "
-        "this if you want to re-do the preprocessing even if the files "
-        "already exist."
-    ),
-)
 @click.option(
     "--num-workers",
     type=int,
@@ -97,7 +87,6 @@ def preprocess(
     base_dir: Optional[Path] = None,
     config: Optional[Path] = None,
     config_field: Optional[str] = None,
-    force: bool = False,
     num_workers: Optional[int] = None,
     dataset_field: Optional[str] = None,
     verbose: int = 0,
@@ -149,6 +138,5 @@ def preprocess(
         dataset,
         conf,
         output=output,
-        force=force,
         max_workers=num_workers,
     )
diff --git a/src/batdetect2/train/preprocess.py b/src/batdetect2/train/preprocess.py
index a8b5bff..3a26c59 100644
--- a/src/batdetect2/train/preprocess.py
+++ b/src/batdetect2/train/preprocess.py
@@ -20,18 +20,16 @@ computationally intensive steps during the actual training loop.
 The module includes utilities for parallel processing using `multiprocessing`.
 """

-from functools import partial
-from multiprocessing import Pool
+import os
 from pathlib import Path
-from typing import Callable, Optional, Sequence
+from typing import Callable, Dict, Optional, Sequence

 import numpy as np
 import torch
-import xarray as xr
+import torch.utils.data
 from loguru import logger
 from pydantic import Field
 from soundevent import data
-from tqdm.auto import tqdm

 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.data.datasets import Dataset
@@ -41,11 +39,10 @@ from batdetect2.targets import TargetConfig, build_targets
 from batdetect2.train.labels import LabelConfig, build_clip_labeler
 from batdetect2.typing import ClipLabeller, PreprocessorProtocol
 from batdetect2.typing.preprocess import AudioLoader
-from batdetect2.utils.arrays import audio_to_xarray
+from batdetect2.typing.train import PreprocessedExample

 __all__ = [
     "preprocess_annotations",
-    "preprocess_single_annotation",
     "generate_train_example",
     "preprocess_dataset",
     "TrainPreprocessConfig",
@@ -75,12 +72,16 @@ def preprocess_dataset(
     dataset: Dataset,
     config: TrainPreprocessConfig,
     output: Path,
-    force: bool = False,
     max_workers: Optional[int] = None,
 ) -> None:
     targets = build_targets(config=config.targets)
     preprocessor = build_preprocessor(config=config.preprocess)
-    labeller = build_clip_labeler(targets, config=config.labels)
+    labeller = build_clip_labeler(
+        targets,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
+        config=config.labels,
+    )
     audio_loader = build_audio_loader(config=config.preprocess.audio)

     if not output.exists():
@@ -93,7 +94,6 @@ def preprocess_dataset(
         audio_loader=audio_loader,
         preprocessor=preprocessor,
         labeller=labeller,
-        replace=force,
         max_workers=max_workers,
     )

@@ -103,142 +103,60 @@ def generate_train_example(
     audio_loader: AudioLoader,
     preprocessor: PreprocessorProtocol,
     labeller: ClipLabeller,
-) -> xr.Dataset:
-    """Generate a complete training example for one annotation.
-
-    This function takes a single `ClipAnnotation`, applies the configured
-    preprocessing (`PreprocessorProtocol`) to get the processed waveform and
-    input spectrogram, applies the configured target generation
-    (`ClipLabeller`) to get the target heatmaps, and packages them all into a
-    single `xr.Dataset`.
-
-    Parameters
-    ----------
-    clip_annotation : data.ClipAnnotation
-        The annotated clip to process. Contains the reference to the `Clip`
-        (audio segment) and the associated `SoundEventAnnotation` objects.
-    preprocessor : PreprocessorProtocol
-        An initialized preprocessor object responsible for loading/processing
-        audio and computing the input spectrogram.
-    labeller : ClipLabeller
-        An initialized clip labeller function responsible for generating the
-        target heatmaps (detection, class, size) from the `clip_annotation`
-        and the computed spectrogram.
-
-    Returns
-    -------
-    xr.Dataset
-        An xarray Dataset containing the following data variables:
-        - `audio`: The preprocessed audio waveform (dims: 'audio_time').
-        - `spectrogram`: The computed input spectrogram
-          (dims: 'time', 'frequency').
-        - `detection`: The target detection heatmap
-          (dims: 'time', 'frequency').
-        - `class`: The target class heatmap
-          (dims: 'category', 'time', 'frequency').
-        - `size`: The target size heatmap
-          (dims: 'dimension', 'time', 'frequency').
-        The Dataset also includes metadata in its attributes.
-
-    Notes
-    -----
-    - The 'time' dimension of the 'audio' DataArray is renamed to 'audio_time'
-      within the output Dataset to avoid coordinate conflicts with the
-      spectrogram's 'time' dimension when stored together.
-    - The original `ClipAnnotation` metadata is stored as a JSON string in the
-      Dataset's attributes for provenance.
-    """
-    wave = audio_loader.load_clip(clip_annotation.clip)
-
-    spectrogram = _spec_to_xr(
-        preprocessor(torch.tensor(wave)),
-        start_time=clip_annotation.clip.start_time,
-        end_time=clip_annotation.clip.end_time,
-        min_freq=preprocessor.min_freq,
-        max_freq=preprocessor.max_freq,
-    )
-
+) -> Dict[str, torch.Tensor]:
+    """Generate a complete training example for one annotation."""
+    wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
+    spectrogram = preprocessor(wave)
     heatmaps = labeller(clip_annotation, spectrogram)
-
-    dataset = xr.Dataset(
-        {
-            # NOTE: Need to rename the time dimension to avoid conflicts with
-            # the spectrogram time dimension, otherwise xarray will interpolate
-            # the spectrogram and the heatmaps to the same temporal resolution
-            # as the waveform.
-            "audio": audio_to_xarray(
-                wave,
-                start_time=clip_annotation.clip.start_time,
-                end_time=clip_annotation.clip.end_time,
-                time_axis="audio_time",
-            ),
-            "spectrogram": spectrogram,
-            "detection": heatmaps.detection,
-            "class": heatmaps.classes,
-            "size": heatmaps.size,
-        }
-    )
-
-    return dataset.assign_attrs(
-        title=f"Training example for {clip_annotation.uuid}",
-        clip_annotation=clip_annotation.model_dump_json(
-            exclude_none=True,
-            exclude_defaults=True,
-            exclude_unset=True,
-        ),
+    return dict(
+        audio=wave,
+        spectrogram=spectrogram,
+        detection_heatmap=heatmaps.detection,
+        class_heatmap=heatmaps.classes,
+        size_heatmap=heatmaps.size,
     )


-def _spec_to_xr(
-    spec: torch.Tensor,
-    start_time: float,
-    end_time: float,
-    min_freq: float,
-    max_freq: float,
-) -> xr.DataArray:
-    data = spec.numpy()[0, 0]
+class PreprocessingDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        clips: Dataset,
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        labeller: ClipLabeller,
+    ):
+        self.clips = clips
+        self.audio_loader = audio_loader
+        self.preprocessor = preprocessor
+        self.labeller = labeller

-    height, width = data.shape
+    def __getitem__(self, idx) -> dict:
+        clip_annotation = self.clips[idx]
+        example = generate_train_example(
+            clip_annotation,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            labeller=self.labeller,
+        )
+        example["idx"] = idx
+        return example

-    return xr.DataArray(
-        data=data,
-        dims=[
-            "frequency",
-            "time",
-        ],
-        coords={
-            "frequency": np.linspace(
-                min_freq, max_freq, height, endpoint=False
-            ),
-            "time": np.linspace(start_time, end_time, width, endpoint=False),
-        },
-    )
+    def __len__(self) -> int:
+        return len(self.clips)


-def _save_xr_dataset_to_file(
-    dataset: xr.Dataset,
+def _save_example_to_file(
+    example: PreprocessedExample,
     path: data.PathLike,
 ) -> None:
-    """Save an xarray Dataset to a NetCDF file with compression.
-
-    Internal helper function used by `preprocess_single_annotation`.
-
-    Parameters
-    ----------
-    dataset : xr.Dataset
-        The training example dataset to save.
-    path : PathLike
-        The output file path (e.g., 'output/uuid.nc').
-    """
-    dataset.to_netcdf(
+    np.savez_compressed(
         path,
-        encoding={
-            "audio": {"zlib": True},
-            "spectrogram": {"zlib": True},
-            "size": {"zlib": True},
-            "class": {"zlib": True},
-            "detection": {"zlib": True},
-        },
+        audio=example.audio,
+        spectrogram=example.spectrogram,
+        detection_heatmap=example.detection_heatmap,
+        class_heatmap=example.class_heatmap,
+        size_heatmap=example.size_heatmap,
+        clip_annotation=example.clip_annotation,
     )


@@ -254,50 +172,9 @@ def preprocess_annotations(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     filename_fn: FilenameFn = _get_filename,
-    replace: bool = False,
     max_workers: Optional[int] = None,
 ) -> None:
-    """Preprocess a sequence of ClipAnnotations and save results to disk.
-
-    Generates the full training example (spectrogram, heatmaps, etc.) for each
-    `ClipAnnotation` in the input sequence using the provided `preprocessor`
-    and `labeller`. Saves each example as a separate NetCDF file in the
-    `output_dir`. Utilizes multiprocessing for potentially faster processing.
-
-    Parameters
-    ----------
-    clip_annotations : Sequence[data.ClipAnnotation]
-        A sequence (e.g., list) of the clip annotations to preprocess.
-    output_dir : PathLike
-        Path to the directory where the processed NetCDF files will be saved.
-        Will be created if it doesn't exist.
-    preprocessor : PreprocessorProtocol
-        Initialized preprocessor object to generate spectrograms.
-    labeller : ClipLabeller
-        Initialized labeller function to generate target heatmaps.
-    filename_fn : FilenameFn, optional
-        Function to generate the output filename (without extension) for each
-        `ClipAnnotation`. Defaults to using the annotation UUID via
-        `_get_filename`.
-    replace : bool, default=False
-        If True, existing files in `output_dir` with the same generated name
-        will be overwritten. If False (default), existing files are skipped.
-    max_workers : int, optional
-        Maximum number of worker processes to use for parallel processing.
-        If None (default), uses the number of CPUs available (`os.cpu_count()`).
-
-    Returns
-    -------
-    None
-        This function does not return anything; its side effect is creating
-        files in the `output_dir`.
-
-    Raises
-    ------
-    RuntimeError
-        If processing fails for any individual annotation when using
-        multiprocessing. The original exception will be attached as the cause.
-    """
+    """Preprocess a sequence of ClipAnnotations and save results to disk."""
     output_dir = Path(output_dir)

     if not output_dir.is_dir():
@@ -312,88 +189,33 @@ def preprocess_annotations(
         max_workers=max_workers or "all available",
     )

-    with Pool(max_workers) as pool:
-        list(
-            tqdm(
-                pool.imap_unordered(
-                    partial(
-                        preprocess_single_annotation,
-                        output_dir=output_dir,
-                        filename_fn=filename_fn,
-                        replace=replace,
-                        audio_loader=audio_loader,
-                        preprocessor=preprocessor,
-                        labeller=labeller,
-                    ),
-                    clip_annotations,
-                ),
-                total=len(clip_annotations),
-                desc="Preprocessing annotations",
-            )
+    if max_workers is None:
+        max_workers = os.cpu_count() or 0
+
+    dataset = PreprocessingDataset(
+        clips=list(clip_annotations),
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+        labeller=labeller,
+    )
+
+    loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=None,
+        shuffle=False,
+        num_workers=max_workers,
+    )
+
+    for batch in loader:
+        clip_annotation = dataset.clips[batch["idx"]]
+        filename = filename_fn(clip_annotation)
+        path = output_dir / filename
+        example = PreprocessedExample(
+            clip_annotation=clip_annotation,
+            spectrogram=batch["spectrogram"].numpy(),
+            audio=batch["audio"].numpy(),
+            class_heatmap=batch["class_heatmap"].numpy(),
+            size_heatmap=batch["size_heatmap"].numpy(),
+            detection_heatmap=batch["detection_heatmap"].numpy(),
         )
-
-    logger.info("Finished preprocessing.")
-
-
-def preprocess_single_annotation(
-    clip_annotation: data.ClipAnnotation,
-    output_dir: data.PathLike,
-    audio_loader: AudioLoader,
-    preprocessor: PreprocessorProtocol,
-    labeller: ClipLabeller,
-    filename_fn: FilenameFn = _get_filename,
-    replace: bool = False,
-) -> None:
-    """Process a single ClipAnnotation and save the result to a file.
-
-    Internal function designed to be called by `preprocess_annotations`, often
-    in parallel worker processes. It generates the training example using
-    `generate_train_example` and saves it using `save_to_file`. Handles
-    file existence checks based on the `replace` flag.
-
-    Parameters
-    ----------
-    clip_annotation : data.ClipAnnotation
-        The single annotation to process.
-    output_dir : Path
-        The directory where the output NetCDF file should be saved.
-    preprocessor : PreprocessorProtocol
-        Initialized preprocessor object.
-    labeller : ClipLabeller
-        Initialized labeller function.
-    filename_fn : FilenameFn, default=_get_filename
-        Function to determine the output filename.
-    replace : bool, default=False
-        Whether to overwrite existing output files.
-    """
-    output_dir = Path(output_dir)
-
-    filename = filename_fn(clip_annotation)
-    path = output_dir / filename
-
-    if path.is_file() and not replace:
-        logger.debug("Skipping existing file: {path}", path=path)
-        return
-
-    if path.is_file() and replace:
-        logger.debug("Removing existing file: {path}", path=path)
-        path.unlink()
-
-    logger.debug("Processing annotation {uuid}", uuid=clip_annotation.uuid)
-
-    try:
-        sample = generate_train_example(
-            clip_annotation,
-            audio_loader=audio_loader,
-            preprocessor=preprocessor,
-            labeller=labeller,
-        )
-    except Exception as error:
-        logger.error(
-            "Failed to process annotation {uuid} to {path}. Error: {error}",
-            uuid=clip_annotation.uuid,
-            path=path,
-            error=error,
-        )
-        return
-
-    _save_xr_dataset_to_file(sample, path)
+        _save_example_to_file(example, path)
diff --git a/src/batdetect2/typing/train.py b/src/batdetect2/typing/train.py
index 7d6079c..79868d9 100644
--- a/src/batdetect2/typing/train.py
+++ b/src/batdetect2/typing/train.py
@@ -1,5 +1,6 @@
 from typing import Callable, NamedTuple, Protocol, Tuple

+import numpy as np
 import torch
 import xarray as xr
 from soundevent import data
@@ -42,6 +43,15 @@ class Heatmaps(NamedTuple):
     size: torch.Tensor


+class PreprocessedExample(NamedTuple):
+    audio: np.ndarray
+    spectrogram: np.ndarray
+    detection_heatmap: np.ndarray
+    class_heatmap: np.ndarray
+    size_heatmap: np.ndarray
+    clip_annotation: data.ClipAnnotation
+
+
 ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
 """Type alias for the final clip labelling function.
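
A note on the parallelism model this patch adopts: with `batch_size=None`, a
`torch.utils.data.DataLoader` disables automatic batching, so it simply maps
`Dataset.__getitem__` across worker processes and yields one example at a
time, in order, leaving the main process with only the cheap, I/O-bound
saving step. The following standalone toy sketch, which is not part of the
patch, isolates just that pattern:

    import torch
    import torch.utils.data


    class SquareDataset(torch.utils.data.Dataset):
        """Toy stand-in for PreprocessingDataset: __getitem__ does the heavy work."""

        def __init__(self, n: int):
            self.n = n

        def __getitem__(self, idx: int) -> dict:
            # Placeholder for the expensive spectrogram/heatmap computation.
            return {"idx": idx, "value": torch.tensor(idx) ** 2}

        def __len__(self) -> int:
            return self.n


    if __name__ == "__main__":
        loader = torch.utils.data.DataLoader(
            SquareDataset(10),
            batch_size=None,  # no collation: items are yielded one by one, as-is
            shuffle=False,
            num_workers=2,  # __getitem__ runs in worker processes
        )
        for item in loader:
            print(item["idx"], item["value"].item())

Compared with the removed `Pool.imap_unordered`, this keeps results in
dataset order and reuses PyTorch's worker lifecycle and error propagation
instead of hand-rolled multiprocessing.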
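
The on-disk format also changes from NetCDF (`xr.Dataset.to_netcdf`) to a
compressed NumPy archive (`np.savez_compressed`), so any reader of these
files must change with it. Below is a minimal sketch of loading one of the
new `.npz` files back into a `PreprocessedExample`; `load_example` is a
hypothetical helper, not part of the patch, and it assumes the file was
written by `_save_example_to_file` above:

    from pathlib import Path

    import numpy as np

    from batdetect2.typing.train import PreprocessedExample


    def load_example(path: Path) -> PreprocessedExample:
        # allow_pickle=True is required: clip_annotation is not a numeric
        # array, so np.savez_compressed stored it as a pickled 0-d object
        # array.
        archive = np.load(path, allow_pickle=True)
        return PreprocessedExample(
            audio=archive["audio"],
            spectrogram=archive["spectrogram"],
            detection_heatmap=archive["detection_heatmap"],
            class_heatmap=archive["class_heatmap"],
            size_heatmap=archive["size_heatmap"],
            clip_annotation=archive["clip_annotation"].item(),  # unwrap object
        )

One caveat of this design: pickled objects inside the `.npz` tie the files to
the installed `soundevent`/pydantic versions, whereas the removed xarray code
stored the annotation as a JSON string, which is more portable across
versions.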