mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 17:19:34 +01:00)
Fix train preprocessing

commit 76dda0a0e9 (parent c36ef3ecb5)
@@ -65,16 +65,6 @@ __all__ = ["preprocess"]
         "top level, you don't need to specify this."
     ),
 )
-@click.option(
-    "--force",
-    is_flag=True,
-    help=(
-        "If a preprocessed file already exists, this option tells the "
-        "program to overwrite it with the new preprocessed data. Use "
-        "this if you want to re-do the preprocessing even if the files "
-        "already exist."
-    ),
-)
 @click.option(
     "--num-workers",
     type=int,
@@ -97,7 +87,6 @@ def preprocess(
     base_dir: Optional[Path] = None,
     config: Optional[Path] = None,
     config_field: Optional[str] = None,
-    force: bool = False,
     num_workers: Optional[int] = None,
     dataset_field: Optional[str] = None,
     verbose: int = 0,
@@ -149,6 +138,5 @@ def preprocess(
         dataset,
         conf,
         output=output,
-        force=force,
         max_workers=num_workers,
     )
@@ -20,18 +20,16 @@ computationally intensive steps during the actual training loop. The module
 includes utilities for parallel processing using `multiprocessing`.
 """
 
-from functools import partial
-from multiprocessing import Pool
+import os
 from pathlib import Path
-from typing import Callable, Optional, Sequence
+from typing import Callable, Dict, Optional, Sequence
 
 import numpy as np
 import torch
-import xarray as xr
+import torch.utils.data
 from loguru import logger
 from pydantic import Field
 from soundevent import data
-from tqdm.auto import tqdm
 
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.data.datasets import Dataset
@@ -41,11 +39,10 @@ from batdetect2.targets import TargetConfig, build_targets
 from batdetect2.train.labels import LabelConfig, build_clip_labeler
 from batdetect2.typing import ClipLabeller, PreprocessorProtocol
 from batdetect2.typing.preprocess import AudioLoader
-from batdetect2.utils.arrays import audio_to_xarray
+from batdetect2.typing.train import PreprocessedExample
 
 __all__ = [
     "preprocess_annotations",
-    "preprocess_single_annotation",
     "generate_train_example",
     "preprocess_dataset",
     "TrainPreprocessConfig",
@@ -75,12 +72,16 @@ def preprocess_dataset(
     dataset: Dataset,
     config: TrainPreprocessConfig,
     output: Path,
-    force: bool = False,
     max_workers: Optional[int] = None,
 ) -> None:
     targets = build_targets(config=config.targets)
     preprocessor = build_preprocessor(config=config.preprocess)
-    labeller = build_clip_labeler(targets, config=config.labels)
+    labeller = build_clip_labeler(
+        targets,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
+        config=config.labels,
+    )
     audio_loader = build_audio_loader(config=config.preprocess.audio)
 
     if not output.exists():
@@ -93,7 +94,6 @@ def preprocess_dataset(
         audio_loader=audio_loader,
         preprocessor=preprocessor,
         labeller=labeller,
-        replace=force,
         max_workers=max_workers,
     )
 
@@ -103,142 +103,60 @@ def generate_train_example(
     audio_loader: AudioLoader,
     preprocessor: PreprocessorProtocol,
     labeller: ClipLabeller,
-) -> xr.Dataset:
-    """Generate a complete training example for one annotation.
-
-    This function takes a single `ClipAnnotation`, applies the configured
-    preprocessing (`PreprocessorProtocol`) to get the processed waveform and
-    input spectrogram, applies the configured target generation
-    (`ClipLabeller`) to get the target heatmaps, and packages them all into a
-    single `xr.Dataset`.
-
-    Parameters
-    ----------
-    clip_annotation : data.ClipAnnotation
-        The annotated clip to process. Contains the reference to the `Clip`
-        (audio segment) and the associated `SoundEventAnnotation` objects.
-    preprocessor : PreprocessorProtocol
-        An initialized preprocessor object responsible for loading/processing
-        audio and computing the input spectrogram.
-    labeller : ClipLabeller
-        An initialized clip labeller function responsible for generating the
-        target heatmaps (detection, class, size) from the `clip_annotation`
-        and the computed spectrogram.
-
-    Returns
-    -------
-    xr.Dataset
-        An xarray Dataset containing the following data variables:
-        - `audio`: The preprocessed audio waveform (dims: 'audio_time').
-        - `spectrogram`: The computed input spectrogram
-          (dims: 'time', 'frequency').
-        - `detection`: The target detection heatmap
-          (dims: 'time', 'frequency').
-        - `class`: The target class heatmap
-          (dims: 'category', 'time', 'frequency').
-        - `size`: The target size heatmap
-          (dims: 'dimension', 'time', 'frequency').
-        The Dataset also includes metadata in its attributes.
-
-    Notes
-    -----
-    - The 'time' dimension of the 'audio' DataArray is renamed to 'audio_time'
-      within the output Dataset to avoid coordinate conflicts with the
-      spectrogram's 'time' dimension when stored together.
-    - The original `ClipAnnotation` metadata is stored as a JSON string in the
-      Dataset's attributes for provenance.
-    """
-    wave = audio_loader.load_clip(clip_annotation.clip)
-
-    spectrogram = _spec_to_xr(
-        preprocessor(torch.tensor(wave)),
-        start_time=clip_annotation.clip.start_time,
-        end_time=clip_annotation.clip.end_time,
-        min_freq=preprocessor.min_freq,
-        max_freq=preprocessor.max_freq,
-    )
-
+) -> Dict[str, torch.Tensor]:
+    """Generate a complete training example for one annotation."""
+    wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
+    spectrogram = preprocessor(wave)
     heatmaps = labeller(clip_annotation, spectrogram)
-
-    dataset = xr.Dataset(
-        {
-            # NOTE: Need to rename the time dimension to avoid conflicts with
-            # the spectrogram time dimension, otherwise xarray will interpolate
-            # the spectrogram and the heatmaps to the same temporal resolution
-            # as the waveform.
-            "audio": audio_to_xarray(
-                wave,
-                start_time=clip_annotation.clip.start_time,
-                end_time=clip_annotation.clip.end_time,
-                time_axis="audio_time",
-            ),
-            "spectrogram": spectrogram,
-            "detection": heatmaps.detection,
-            "class": heatmaps.classes,
-            "size": heatmaps.size,
-        }
-    )
-
-    return dataset.assign_attrs(
-        title=f"Training example for {clip_annotation.uuid}",
-        clip_annotation=clip_annotation.model_dump_json(
-            exclude_none=True,
-            exclude_defaults=True,
-            exclude_unset=True,
-        ),
+    return dict(
+        audio=wave,
+        spectrogram=spectrogram,
+        detection_heatmap=heatmaps.detection,
+        class_heatmap=heatmaps.classes,
+        size_heatmap=heatmaps.size,
     )
 
 
-def _spec_to_xr(
-    spec: torch.Tensor,
-    start_time: float,
-    end_time: float,
-    min_freq: float,
-    max_freq: float,
-) -> xr.DataArray:
-    data = spec.numpy()[0, 0]
-
-    height, width = data.shape
-
-    return xr.DataArray(
-        data=data,
-        dims=[
-            "frequency",
-            "time",
-        ],
-        coords={
-            "frequency": np.linspace(
-                min_freq, max_freq, height, endpoint=False
-            ),
-            "time": np.linspace(start_time, end_time, width, endpoint=False),
-        },
-    )
+class PreprocessingDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        clips: Dataset,
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        labeller: ClipLabeller,
+    ):
+        self.clips = clips
+        self.audio_loader = audio_loader
+        self.preprocessor = preprocessor
+        self.labeller = labeller
+
+    def __getitem__(self, idx) -> dict:
+        clip_annotation = self.clips[idx]
+        example = generate_train_example(
+            clip_annotation,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            labeller=self.labeller,
+        )
+        example["idx"] = idx
+        return example
+
+    def __len__(self) -> int:
+        return len(self.clips)
 
 
-def _save_xr_dataset_to_file(
-    dataset: xr.Dataset,
+def _save_example_to_file(
+    example: PreprocessedExample,
     path: data.PathLike,
 ) -> None:
-    """Save an xarray Dataset to a NetCDF file with compression.
-
-    Internal helper function used by `preprocess_single_annotation`.
-
-    Parameters
-    ----------
-    dataset : xr.Dataset
-        The training example dataset to save.
-    path : PathLike
-        The output file path (e.g., 'output/uuid.nc').
-    """
-    dataset.to_netcdf(
+    np.savez_compressed(
         path,
-        encoding={
-            "audio": {"zlib": True},
-            "spectrogram": {"zlib": True},
-            "size": {"zlib": True},
-            "class": {"zlib": True},
-            "detection": {"zlib": True},
-        },
+        audio=example.audio,
+        spectrogram=example.spectrogram,
+        detection_heatmap=example.detection_heatmap,
+        class_heatmap=example.class_heatmap,
+        size_heatmap=example.size_heatmap,
+        clip_annotation=example.clip_annotation,
     )
 
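Note on the hunk above: the rewrite swaps the old multiprocessing.Pool pipeline for a torch.utils.data.Dataset wrapped in a DataLoader, so the heavy per-clip work runs inside DataLoader worker processes. A minimal self-contained sketch of that idiom, with a toy dataset standing in for PreprocessingDataset (all names here are illustrative, not from the repo):

import torch.utils.data


class ToyDataset(torch.utils.data.Dataset):
    """Toy stand-in: 'preprocessing' is just squaring an integer."""

    def __init__(self, values):
        self.values = values

    def __getitem__(self, idx):
        # In the real module the expensive work (audio loading, spectrogram
        # computation, heatmap generation) happens here, inside a worker.
        return {"idx": idx, "value": self.values[idx] ** 2}

    def __len__(self):
        return len(self.values)


if __name__ == "__main__":
    # batch_size=None disables automatic batching: each iteration yields a
    # single example dict, which suits writing one file per example.
    loader = torch.utils.data.DataLoader(
        ToyDataset(list(range(8))),
        batch_size=None,
        shuffle=False,
        num_workers=2,
    )
    for example in loader:
        print(example["idx"], example["value"])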
@@ -254,50 +172,9 @@ def preprocess_annotations(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     filename_fn: FilenameFn = _get_filename,
-    replace: bool = False,
     max_workers: Optional[int] = None,
 ) -> None:
-    """Preprocess a sequence of ClipAnnotations and save results to disk.
-
-    Generates the full training example (spectrogram, heatmaps, etc.) for each
-    `ClipAnnotation` in the input sequence using the provided `preprocessor`
-    and `labeller`. Saves each example as a separate NetCDF file in the
-    `output_dir`. Utilizes multiprocessing for potentially faster processing.
-
-    Parameters
-    ----------
-    clip_annotations : Sequence[data.ClipAnnotation]
-        A sequence (e.g., list) of the clip annotations to preprocess.
-    output_dir : PathLike
-        Path to the directory where the processed NetCDF files will be saved.
-        Will be created if it doesn't exist.
-    preprocessor : PreprocessorProtocol
-        Initialized preprocessor object to generate spectrograms.
-    labeller : ClipLabeller
-        Initialized labeller function to generate target heatmaps.
-    filename_fn : FilenameFn, optional
-        Function to generate the output filename (without extension) for each
-        `ClipAnnotation`. Defaults to using the annotation UUID via
-        `_get_filename`.
-    replace : bool, default=False
-        If True, existing files in `output_dir` with the same generated name
-        will be overwritten. If False (default), existing files are skipped.
-    max_workers : int, optional
-        Maximum number of worker processes to use for parallel processing.
-        If None (default), uses the number of CPUs available (`os.cpu_count()`).
-
-    Returns
-    -------
-    None
-        This function does not return anything; its side effect is creating
-        files in the `output_dir`.
-
-    Raises
-    ------
-    RuntimeError
-        If processing fails for any individual annotation when using
-        multiprocessing. The original exception will be attached as the cause.
-    """
+    """Preprocess a sequence of ClipAnnotations and save results to disk."""
    output_dir = Path(output_dir)
 
     if not output_dir.is_dir():
@@ -312,88 +189,33 @@ def preprocess_annotations(
         max_workers=max_workers or "all available",
     )
 
-    with Pool(max_workers) as pool:
-        list(
-            tqdm(
-                pool.imap_unordered(
-                    partial(
-                        preprocess_single_annotation,
-                        output_dir=output_dir,
-                        filename_fn=filename_fn,
-                        replace=replace,
-                        audio_loader=audio_loader,
-                        preprocessor=preprocessor,
-                        labeller=labeller,
-                    ),
-                    clip_annotations,
-                ),
-                total=len(clip_annotations),
-                desc="Preprocessing annotations",
-            )
-        )
-
-    logger.info("Finished preprocessing.")
-
-
-def preprocess_single_annotation(
-    clip_annotation: data.ClipAnnotation,
-    output_dir: data.PathLike,
-    audio_loader: AudioLoader,
-    preprocessor: PreprocessorProtocol,
-    labeller: ClipLabeller,
-    filename_fn: FilenameFn = _get_filename,
-    replace: bool = False,
-) -> None:
-    """Process a single ClipAnnotation and save the result to a file.
-
-    Internal function designed to be called by `preprocess_annotations`, often
-    in parallel worker processes. It generates the training example using
-    `generate_train_example` and saves it using `save_to_file`. Handles
-    file existence checks based on the `replace` flag.
-
-    Parameters
-    ----------
-    clip_annotation : data.ClipAnnotation
-        The single annotation to process.
-    output_dir : Path
-        The directory where the output NetCDF file should be saved.
-    preprocessor : PreprocessorProtocol
-        Initialized preprocessor object.
-    labeller : ClipLabeller
-        Initialized labeller function.
-    filename_fn : FilenameFn, default=_get_filename
-        Function to determine the output filename.
-    replace : bool, default=False
-        Whether to overwrite existing output files.
-    """
-    output_dir = Path(output_dir)
-
-    filename = filename_fn(clip_annotation)
-    path = output_dir / filename
-
-    if path.is_file() and not replace:
-        logger.debug("Skipping existing file: {path}", path=path)
-        return
-
-    if path.is_file() and replace:
-        logger.debug("Removing existing file: {path}", path=path)
-        path.unlink()
-
-    logger.debug("Processing annotation {uuid}", uuid=clip_annotation.uuid)
-
-    try:
-        sample = generate_train_example(
-            clip_annotation,
-            audio_loader=audio_loader,
-            preprocessor=preprocessor,
-            labeller=labeller,
-        )
-    except Exception as error:
-        logger.error(
-            "Failed to process annotation {uuid} to {path}. Error: {error}",
-            uuid=clip_annotation.uuid,
-            path=path,
-            error=error,
-        )
-        return
-
-    _save_xr_dataset_to_file(sample, path)
+    if max_workers is None:
+        max_workers = os.cpu_count() or 0
+
+    dataset = PreprocessingDataset(
+        clips=list(clip_annotations),
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+        labeller=labeller,
+    )
+
+    loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=None,
+        shuffle=False,
+        num_workers=max_workers,
+    )
+
+    for batch in loader:
+        clip_annotation = dataset.clips[batch["idx"]]
+        filename = filename_fn(clip_annotation)
+        path = output_dir / filename
+        example = PreprocessedExample(
+            clip_annotation=clip_annotation,
+            spectrogram=batch["spectrogram"].numpy(),
+            audio=batch["audio"].numpy(),
+            class_heatmap=batch["class_heatmap"].numpy(),
+            size_heatmap=batch["size_heatmap"].numpy(),
+            detection_heatmap=batch["detection_heatmap"].numpy(),
+        )
+        _save_example_to_file(example, path)
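For reference, a file written by the new _save_example_to_file can be read back with numpy. A sketch, assuming a hypothetical path; allow_pickle=True is needed because clip_annotation is a Python object rather than an array:

import numpy as np

# Hypothetical path; real filenames come from filename_fn (the annotation
# UUID). np.savez_compressed adds a ".npz" extension if the path lacks one.
archive = np.load("preprocessed/example.npz", allow_pickle=True)

spectrogram = archive["spectrogram"]              # np.ndarray
detection_heatmap = archive["detection_heatmap"]  # np.ndarray
# The ClipAnnotation was pickled into a 0-d object array; unwrap it.
clip_annotation = archive["clip_annotation"].item()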
@@ -1,5 +1,6 @@
 from typing import Callable, NamedTuple, Protocol, Tuple
 
+import numpy as np
 import torch
 import xarray as xr
 from soundevent import data
@@ -42,6 +43,15 @@ class Heatmaps(NamedTuple):
     size: torch.Tensor
 
 
+class PreprocessedExample(NamedTuple):
+    audio: np.ndarray
+    spectrogram: np.ndarray
+    detection_heatmap: np.ndarray
+    class_heatmap: np.ndarray
+    size_heatmap: np.ndarray
+    clip_annotation: data.ClipAnnotation
+
+
 ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
 """Type alias for the final clip labelling function.
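Since PreprocessedExample is a NamedTuple, its fields map one-to-one onto the keyword arguments that _save_example_to_file passes to np.savez_compressed. A self-contained sketch of that round trip with a toy tuple (shapes and the output path are illustrative):

from typing import NamedTuple

import numpy as np


class ToyExample(NamedTuple):
    audio: np.ndarray
    spectrogram: np.ndarray


example = ToyExample(
    audio=np.zeros(1_000, dtype=np.float32),
    spectrogram=np.zeros((128, 64), dtype=np.float32),
)

# _asdict() yields {field: value}, so the tuple can be expanded straight
# into savez-style keyword arguments.
np.savez_compressed("toy_example.npz", **example._asdict())

loaded = np.load("toy_example.npz")
assert loaded["spectrogram"].shape == (128, 64)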