Fix train preprocessing

mirror of https://github.com/macaodha/batdetect2.git
commit 76dda0a0e9
parent c36ef3ecb5
@@ -65,16 +65,6 @@ __all__ = ["preprocess"]
         "top level, you don't need to specify this."
     ),
 )
-@click.option(
-    "--force",
-    is_flag=True,
-    help=(
-        "If a preprocessed file already exists, this option tells the "
-        "program to overwrite it with the new preprocessed data. Use "
-        "this if you want to re-do the preprocessing even if the files "
-        "already exist."
-    ),
-)
 @click.option(
     "--num-workers",
     type=int,
@@ -97,7 +87,6 @@ def preprocess(
     base_dir: Optional[Path] = None,
     config: Optional[Path] = None,
     config_field: Optional[str] = None,
-    force: bool = False,
     num_workers: Optional[int] = None,
     dataset_field: Optional[str] = None,
     verbose: int = 0,
@@ -149,6 +138,5 @@ def preprocess(
         dataset,
         conf,
         output=output,
-        force=force,
         max_workers=num_workers,
     )
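The flag is dropped rather than relocated: later in this commit the `replace=force` argument is also removed and the rewritten processing loop in `preprocess_annotations` (see below) appears to write every example unconditionally, so `--force` no longer had any behaviour to toggle. A minimal sketch of driving the same preprocessing from Python instead of the CLI; the import path `batdetect2.train.preprocess` and the default-constructed config are assumptions, not shown in this diff:

from pathlib import Path

from batdetect2.data.datasets import Dataset
from batdetect2.train.preprocess import TrainPreprocessConfig, preprocess_dataset


def run_preprocessing(dataset: Dataset, output: Path) -> None:
    # Assumption: TrainPreprocessConfig defaults are usable as-is.
    config = TrainPreprocessConfig()
    preprocess_dataset(
        dataset,
        config,
        output=output,
        max_workers=4,  # forwarded to the DataLoader worker count
    )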
@@ -20,18 +20,16 @@ computationally intensive steps during the actual training loop. The module
 includes utilities for parallel processing using `multiprocessing`.
 """
 
-from functools import partial
-from multiprocessing import Pool
+import os
 from pathlib import Path
-from typing import Callable, Optional, Sequence
+from typing import Callable, Dict, Optional, Sequence
 
 import numpy as np
 import torch
-import xarray as xr
+import torch.utils.data
 from loguru import logger
 from pydantic import Field
 from soundevent import data
-from tqdm.auto import tqdm
 
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.data.datasets import Dataset
@@ -41,11 +39,10 @@ from batdetect2.targets import TargetConfig, build_targets
 from batdetect2.train.labels import LabelConfig, build_clip_labeler
 from batdetect2.typing import ClipLabeller, PreprocessorProtocol
 from batdetect2.typing.preprocess import AudioLoader
-from batdetect2.utils.arrays import audio_to_xarray
+from batdetect2.typing.train import PreprocessedExample
 
 __all__ = [
     "preprocess_annotations",
-    "preprocess_single_annotation",
     "generate_train_example",
     "preprocess_dataset",
     "TrainPreprocessConfig",
@@ -75,12 +72,16 @@ def preprocess_dataset(
     dataset: Dataset,
     config: TrainPreprocessConfig,
     output: Path,
     force: bool = False,
     max_workers: Optional[int] = None,
 ) -> None:
     targets = build_targets(config=config.targets)
     preprocessor = build_preprocessor(config=config.preprocess)
-    labeller = build_clip_labeler(targets, config=config.labels)
+    labeller = build_clip_labeler(
+        targets,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
+        config=config.labels,
+    )
     audio_loader = build_audio_loader(config=config.preprocess.audio)
 
     if not output.exists():
@@ -93,7 +94,6 @@ def preprocess_dataset(
         audio_loader=audio_loader,
         preprocessor=preprocessor,
         labeller=labeller,
-        replace=force,
         max_workers=max_workers,
     )
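The substantive change here is that build_clip_labeler now receives the preprocessor's frequency bounds, presumably so the generated heatmaps share the spectrogram's frequency range rather than carrying their own. A hedged sketch of the same wiring done by hand; only the function names appear in this diff, and the import locations of build_preprocessor and build_audio_loader are assumptions:

from batdetect2.targets import build_targets
from batdetect2.train.labels import build_clip_labeler

# Assumed import location for these two builders.
from batdetect2.preprocess import build_audio_loader, build_preprocessor


def build_pipeline(config):
    targets = build_targets(config=config.targets)
    preprocessor = build_preprocessor(config=config.preprocess)
    # New in this commit: the labeller is told the preprocessor's
    # frequency bounds so labels and spectrogram stay aligned.
    labeller = build_clip_labeler(
        targets,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        config=config.labels,
    )
    audio_loader = build_audio_loader(config=config.preprocess.audio)
    return audio_loader, preprocessor, labeller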
@@ -103,142 +103,60 @@ def generate_train_example(
     audio_loader: AudioLoader,
     preprocessor: PreprocessorProtocol,
     labeller: ClipLabeller,
-) -> xr.Dataset:
-    """Generate a complete training example for one annotation.
-
-    This function takes a single `ClipAnnotation`, applies the configured
-    preprocessing (`PreprocessorProtocol`) to get the processed waveform and
-    input spectrogram, applies the configured target generation
-    (`ClipLabeller`) to get the target heatmaps, and packages them all into a
-    single `xr.Dataset`.
-
-    Parameters
-    ----------
-    clip_annotation : data.ClipAnnotation
-        The annotated clip to process. Contains the reference to the `Clip`
-        (audio segment) and the associated `SoundEventAnnotation` objects.
-    preprocessor : PreprocessorProtocol
-        An initialized preprocessor object responsible for loading/processing
-        audio and computing the input spectrogram.
-    labeller : ClipLabeller
-        An initialized clip labeller function responsible for generating the
-        target heatmaps (detection, class, size) from the `clip_annotation`
-        and the computed spectrogram.
-
-    Returns
-    -------
-    xr.Dataset
-        An xarray Dataset containing the following data variables:
-        - `audio`: The preprocessed audio waveform (dims: 'audio_time').
-        - `spectrogram`: The computed input spectrogram
-          (dims: 'time', 'frequency').
-        - `detection`: The target detection heatmap
-          (dims: 'time', 'frequency').
-        - `class`: The target class heatmap
-          (dims: 'category', 'time', 'frequency').
-        - `size`: The target size heatmap
-          (dims: 'dimension', 'time', 'frequency').
-        The Dataset also includes metadata in its attributes.
-
-    Notes
-    -----
-    - The 'time' dimension of the 'audio' DataArray is renamed to 'audio_time'
-      within the output Dataset to avoid coordinate conflicts with the
-      spectrogram's 'time' dimension when stored together.
-    - The original `ClipAnnotation` metadata is stored as a JSON string in the
-      Dataset's attributes for provenance.
-    """
-    wave = audio_loader.load_clip(clip_annotation.clip)
-
-    spectrogram = _spec_to_xr(
-        preprocessor(torch.tensor(wave)),
-        start_time=clip_annotation.clip.start_time,
-        end_time=clip_annotation.clip.end_time,
-        min_freq=preprocessor.min_freq,
-        max_freq=preprocessor.max_freq,
-    )
-
+) -> Dict[str, torch.Tensor]:
+    """Generate a complete training example for one annotation."""
+    wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
+    spectrogram = preprocessor(wave)
     heatmaps = labeller(clip_annotation, spectrogram)
 
-    dataset = xr.Dataset(
-        {
-            # NOTE: Need to rename the time dimension to avoid conflicts with
-            # the spectrogram time dimension, otherwise xarray will interpolate
-            # the spectrogram and the heatmaps to the same temporal resolution
-            # as the waveform.
-            "audio": audio_to_xarray(
-                wave,
-                start_time=clip_annotation.clip.start_time,
-                end_time=clip_annotation.clip.end_time,
-                time_axis="audio_time",
-            ),
-            "spectrogram": spectrogram,
-            "detection": heatmaps.detection,
-            "class": heatmaps.classes,
-            "size": heatmaps.size,
-        }
-    )
-
-    return dataset.assign_attrs(
-        title=f"Training example for {clip_annotation.uuid}",
-        clip_annotation=clip_annotation.model_dump_json(
-            exclude_none=True,
-            exclude_defaults=True,
-            exclude_unset=True,
-        ),
+    return dict(
+        audio=wave,
+        spectrogram=spectrogram,
+        detection_heatmap=heatmaps.detection,
+        class_heatmap=heatmaps.classes,
+        size_heatmap=heatmaps.size,
     )
 
 
-def _spec_to_xr(
-    spec: torch.Tensor,
-    start_time: float,
-    end_time: float,
-    min_freq: float,
-    max_freq: float,
-) -> xr.DataArray:
-    data = spec.numpy()[0, 0]
-
-    height, width = data.shape
-
-    return xr.DataArray(
-        data=data,
-        dims=[
-            "frequency",
-            "time",
-        ],
-        coords={
-            "frequency": np.linspace(
-                min_freq, max_freq, height, endpoint=False
-            ),
-            "time": np.linspace(start_time, end_time, width, endpoint=False),
-        },
-    )
+class PreprocessingDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        clips: Dataset,
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        labeller: ClipLabeller,
+    ):
+        self.clips = clips
+        self.audio_loader = audio_loader
+        self.preprocessor = preprocessor
+        self.labeller = labeller
+
+    def __getitem__(self, idx) -> dict:
+        clip_annotation = self.clips[idx]
+        example = generate_train_example(
+            clip_annotation,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            labeller=self.labeller,
+        )
+        example["idx"] = idx
+        return example
+
+    def __len__(self) -> int:
+        return len(self.clips)
 
 
-def _save_xr_dataset_to_file(
-    dataset: xr.Dataset,
+def _save_example_to_file(
+    example: PreprocessedExample,
     path: data.PathLike,
 ) -> None:
-    """Save an xarray Dataset to a NetCDF file with compression.
-
-    Internal helper function used by `preprocess_single_annotation`.
-
-    Parameters
-    ----------
-    dataset : xr.Dataset
-        The training example dataset to save.
-    path : PathLike
-        The output file path (e.g., 'output/uuid.nc').
-    """
-    dataset.to_netcdf(
+    np.savez_compressed(
         path,
-        encoding={
-            "audio": {"zlib": True},
-            "spectrogram": {"zlib": True},
-            "size": {"zlib": True},
-            "class": {"zlib": True},
-            "detection": {"zlib": True},
-        },
+        audio=example.audio,
+        spectrogram=example.spectrogram,
+        detection_heatmap=example.detection_heatmap,
+        class_heatmap=example.class_heatmap,
+        size_heatmap=example.size_heatmap,
+        clip_annotation=example.clip_annotation,
     )
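generate_train_example now hands back plain torch tensors keyed by name, and PreprocessingDataset is the indexable wrapper that lets torch.utils.data.DataLoader fan the work out across workers. A small sketch of inspecting one item from the class defined above; the exact tensor shapes depend on the preprocessor and are not specified in this diff:

import torch


def describe_item(dataset: PreprocessingDataset) -> None:
    # One item is the dict from generate_train_example, plus the
    # "idx" key added in __getitem__.
    item = dataset[0]
    for key in (
        "audio",
        "spectrogram",
        "detection_heatmap",
        "class_heatmap",
        "size_heatmap",
    ):
        tensor = item[key]
        assert isinstance(tensor, torch.Tensor)
        print(key, tuple(tensor.shape))
    print("idx:", item["idx"])  # plain int, used later to recover the annotation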
@@ -254,50 +172,9 @@ def preprocess_annotations(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     filename_fn: FilenameFn = _get_filename,
     replace: bool = False,
     max_workers: Optional[int] = None,
 ) -> None:
-    """Preprocess a sequence of ClipAnnotations and save results to disk.
-
-    Generates the full training example (spectrogram, heatmaps, etc.) for each
-    `ClipAnnotation` in the input sequence using the provided `preprocessor`
-    and `labeller`. Saves each example as a separate NetCDF file in the
-    `output_dir`. Utilizes multiprocessing for potentially faster processing.
-
-    Parameters
-    ----------
-    clip_annotations : Sequence[data.ClipAnnotation]
-        A sequence (e.g., list) of the clip annotations to preprocess.
-    output_dir : PathLike
-        Path to the directory where the processed NetCDF files will be saved.
-        Will be created if it doesn't exist.
-    preprocessor : PreprocessorProtocol
-        Initialized preprocessor object to generate spectrograms.
-    labeller : ClipLabeller
-        Initialized labeller function to generate target heatmaps.
-    filename_fn : FilenameFn, optional
-        Function to generate the output filename (without extension) for each
-        `ClipAnnotation`. Defaults to using the annotation UUID via
-        `_get_filename`.
-    replace : bool, default=False
-        If True, existing files in `output_dir` with the same generated name
-        will be overwritten. If False (default), existing files are skipped.
-    max_workers : int, optional
-        Maximum number of worker processes to use for parallel processing.
-        If None (default), uses the number of CPUs available (`os.cpu_count()`).
-
-    Returns
-    -------
-    None
-        This function does not return anything; its side effect is creating
-        files in the `output_dir`.
-
-    Raises
-    ------
-    RuntimeError
-        If processing fails for any individual annotation when using
-        multiprocessing. The original exception will be attached as the cause.
-    """
+    """Preprocess a sequence of ClipAnnotations and save results to disk."""
     output_dir = Path(output_dir)
 
     if not output_dir.is_dir():
@@ -312,88 +189,33 @@ def preprocess_annotations(
         max_workers=max_workers or "all available",
     )
 
-    with Pool(max_workers) as pool:
-        list(
-            tqdm(
-                pool.imap_unordered(
-                    partial(
-                        preprocess_single_annotation,
-                        output_dir=output_dir,
-                        filename_fn=filename_fn,
-                        replace=replace,
-                        audio_loader=audio_loader,
-                        preprocessor=preprocessor,
-                        labeller=labeller,
-                    ),
-                    clip_annotations,
-                ),
-                total=len(clip_annotations),
-                desc="Preprocessing annotations",
-            )
-        )
+    if max_workers is None:
+        max_workers = os.cpu_count() or 0
+
+    dataset = PreprocessingDataset(
+        clips=list(clip_annotations),
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+        labeller=labeller,
+    )
+
+    loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=None,
+        shuffle=False,
+        num_workers=max_workers,
+    )
+
+    for batch in loader:
+        clip_annotation = dataset.clips[batch["idx"]]
+        filename = filename_fn(clip_annotation)
+        path = output_dir / filename
+        example = PreprocessedExample(
+            clip_annotation=clip_annotation,
+            spectrogram=batch["spectrogram"].numpy(),
+            audio=batch["audio"].numpy(),
+            class_heatmap=batch["class_heatmap"].numpy(),
+            size_heatmap=batch["size_heatmap"].numpy(),
+            detection_heatmap=batch["detection_heatmap"].numpy(),
+        )
+        _save_example_to_file(example, path)
+
     logger.info("Finished preprocessing.")
-
-
-def preprocess_single_annotation(
-    clip_annotation: data.ClipAnnotation,
-    output_dir: data.PathLike,
-    audio_loader: AudioLoader,
-    preprocessor: PreprocessorProtocol,
-    labeller: ClipLabeller,
-    filename_fn: FilenameFn = _get_filename,
-    replace: bool = False,
-) -> None:
-    """Process a single ClipAnnotation and save the result to a file.
-
-    Internal function designed to be called by `preprocess_annotations`, often
-    in parallel worker processes. It generates the training example using
-    `generate_train_example` and saves it using `save_to_file`. Handles
-    file existence checks based on the `replace` flag.
-
-    Parameters
-    ----------
-    clip_annotation : data.ClipAnnotation
-        The single annotation to process.
-    output_dir : Path
-        The directory where the output NetCDF file should be saved.
-    preprocessor : PreprocessorProtocol
-        Initialized preprocessor object.
-    labeller : ClipLabeller
-        Initialized labeller function.
-    filename_fn : FilenameFn, default=_get_filename
-        Function to determine the output filename.
-    replace : bool, default=False
-        Whether to overwrite existing output files.
-    """
-    output_dir = Path(output_dir)
-
-    filename = filename_fn(clip_annotation)
-    path = output_dir / filename
-
-    if path.is_file() and not replace:
-        logger.debug("Skipping existing file: {path}", path=path)
-        return
-
-    if path.is_file() and replace:
-        logger.debug("Removing existing file: {path}", path=path)
-        path.unlink()
-
-    logger.debug("Processing annotation {uuid}", uuid=clip_annotation.uuid)
-
-    try:
-        sample = generate_train_example(
-            clip_annotation,
-            audio_loader=audio_loader,
-            preprocessor=preprocessor,
-            labeller=labeller,
-        )
-    except Exception as error:
-        logger.error(
-            "Failed to process annotation {uuid} to {path}. Error: {error}",
-            uuid=clip_annotation.uuid,
-            path=path,
-            error=error,
-        )
-        return
-
-    _save_xr_dataset_to_file(sample, path)
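Two details of the new loop are worth noting. batch_size=None disables the DataLoader's automatic batching, so each yielded item is exactly one __getitem__ dict with no added batch dimension, while num_workers subprocesses run the heavy preprocessing in parallel; and the integer "idx" travels through the loader so the parent process can look the ClipAnnotation back up rather than shipping the pydantic object through worker IPC. A self-contained toy version of the pattern, not batdetect2 code:

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Stands in for PreprocessingDataset: one dict per clip."""

    def __init__(self, values):
        self.values = values

    def __getitem__(self, idx):
        # Stand-in for generate_train_example(...)
        return {"spectrogram": torch.full((2, 3), self.values[idx]), "idx": idx}

    def __len__(self):
        return len(self.values)


ds = ToyDataset([1.0, 2.0, 3.0])

# batch_size=None: items come through unbatched; raise num_workers to
# process several clips concurrently, as in the code above.
loader = DataLoader(ds, batch_size=None, shuffle=False, num_workers=0)

for item in loader:
    source = ds.values[item["idx"]]  # map the result back via idx
    print(item["idx"], source, item["spectrogram"].shape)  # shape stays (2, 3)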
@@ -1,5 +1,6 @@
 from typing import Callable, NamedTuple, Protocol, Tuple
 
+import numpy as np
 import torch
 import xarray as xr
 from soundevent import data
@@ -42,6 +43,15 @@ class Heatmaps(NamedTuple):
     size: torch.Tensor
 
 
+class PreprocessedExample(NamedTuple):
+    audio: np.ndarray
+    spectrogram: np.ndarray
+    detection_heatmap: np.ndarray
+    class_heatmap: np.ndarray
+    size_heatmap: np.ndarray
+    clip_annotation: data.ClipAnnotation
+
+
 ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
 """Type alias for the final clip labelling function.
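PreprocessedExample mixes numpy arrays with a pydantic model, and _save_example_to_file above persists it via np.savez_compressed, which pickles the non-array clip_annotation field into a 0-d object array (savez also appends ".npz" to the filename if it is missing). A hedged sketch of reading one example back under that assumption:

import numpy as np

from batdetect2.typing.train import PreprocessedExample


def load_example(path) -> PreprocessedExample:
    # allow_pickle is required because clip_annotation is not an array;
    # .item() unwraps the 0-d object array back into the pydantic model.
    archive = np.load(path, allow_pickle=True)
    return PreprocessedExample(
        audio=archive["audio"],
        spectrogram=archive["spectrogram"],
        detection_heatmap=archive["detection_heatmap"],
        class_heatmap=archive["class_heatmap"],
        size_heatmap=archive["size_heatmap"],
        clip_annotation=archive["clip_annotation"].item(),
    )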