diff --git a/src/batdetect2/plotting/clip_annotations.py b/src/batdetect2/plotting/clip_annotations.py
index 276c5f4..ca4665b 100644
--- a/src/batdetect2/plotting/clip_annotations.py
+++ b/src/batdetect2/plotting/clip_annotations.py
@@ -17,8 +17,6 @@ def plot_clip_annotation(
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     cmap: str = "gray",
     alpha: float = 1,
@@ -31,8 +29,6 @@ def plot_clip_annotation(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=cmap,
     )

diff --git a/src/batdetect2/plotting/clip_predictions.py b/src/batdetect2/plotting/clip_predictions.py
index 994f984..86f4833 100644
--- a/src/batdetect2/plotting/clip_predictions.py
+++ b/src/batdetect2/plotting/clip_predictions.py
@@ -21,8 +21,6 @@ def plot_clip_prediction(
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_legend: bool = False,
     spec_cmap: str = "gray",
     linewidth: float = 1,
@@ -34,8 +32,6 @@ def plot_clip_prediction(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )

diff --git a/src/batdetect2/plotting/clips.py b/src/batdetect2/plotting/clips.py
index df1fb16..28a806f 100644
--- a/src/batdetect2/plotting/clips.py
+++ b/src/batdetect2/plotting/clips.py
@@ -1,13 +1,13 @@
 from typing import Optional, Tuple

 import matplotlib.pyplot as plt
+import torch
 from matplotlib.axes import Axes
 from soundevent import data

-from batdetect2.preprocess import (
-    PreprocessorProtocol,
-    get_default_preprocessor,
-)
+from batdetect2.plotting.common import plot_spectrogram
+from batdetect2.preprocess import build_audio_loader, get_default_preprocessor
+from batdetect2.typing import AudioLoader, PreprocessorProtocol

 __all__ = [
     "plot_clip",
@@ -16,12 +16,11 @@ __all__ = [

 def plot_clip(
     clip: data.Clip,
+    audio_loader: Optional[AudioLoader] = None,
     preprocessor: Optional[PreprocessorProtocol] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     spec_cmap: str = "gray",
 ) -> Axes:
     if ax is None:
@@ -30,15 +29,16 @@ def plot_clip(
     if preprocessor is None:
         preprocessor = get_default_preprocessor()

-    spec = preprocessor.preprocess_clip(clip, audio_dir=audio_dir)
+    if audio_loader is None:
+        audio_loader = build_audio_loader()

-    spec.plot(  # type: ignore
+    wav = torch.tensor(audio_loader.load_clip(clip, audio_dir=audio_dir))
+    spec = preprocessor(wav)
+
+    plot_spectrogram(
+        spec,
         ax=ax,
-        add_colorbar=add_colorbar,
         cmap=spec_cmap,
-        add_labels=add_labels,
-        vmin=spec.min().item(),
-        vmax=spec.max().item(),
     )

     return ax

diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py
index 7d5b19c..63561ce 100644
--- a/src/batdetect2/plotting/matches.py
+++ b/src/batdetect2/plotting/matches.py
@@ -40,8 +40,6 @@ def plot_matches(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     color_mapper: Optional[TagColorMapper] = None,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -59,8 +57,6 @@ def plot_matches(
         ax=ax,
         figsize=figsize,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )

@@ -128,8 +124,6 @@ def plot_false_positive_match(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -160,8 +154,6 @@ def plot_false_positive_match(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )

@@ -196,8 +188,6 @@ def plot_false_negative_match(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -226,8 +216,6 @@ def plot_false_negative_match(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )

@@ -262,8 +250,6 @@ def plot_true_positive_match(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -294,8 +280,6 @@ def plot_true_positive_match(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )

@@ -343,8 +327,6 @@ def plot_cross_trigger_match(
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_colorbar: bool = False,
-    add_labels: bool = False,
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
@@ -375,8 +357,6 @@ def plot_cross_trigger_match(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
-        add_colorbar=add_colorbar,
-        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )

diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py
index f3b2830..c7862c8 100644
--- a/src/batdetect2/train/callbacks.py
+++ b/src/batdetect2/train/callbacks.py
@@ -67,7 +67,7 @@ class ValidationMetrics(Callback):

         for class_name, fig in plot_example_gallery(
             matches,
-            preprocessor=pl_module.preprocessor,
+            preprocessor=pl_module.model.preprocessor,
             n_examples=4,
         ):
             plotter(
@@ -94,7 +94,7 @@
     ) -> None:
         matches = _match_all_collected_examples(
             self._matches,
-            pl_module.targets,
+            pl_module.model.targets,
             config=self.match_config,
         )

@@ -127,8 +127,8 @@
                 batch,
                 outputs,
                 dataset=self.get_dataset(trainer),
-                postprocessor=pl_module.postprocessor,
-                targets=pl_module.targets,
+                postprocessor=pl_module.model.postprocessor,
+                targets=pl_module.model.targets,
             )
         )

diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py
index 928e687..58ecfa5 100644
--- a/src/batdetect2/train/dataset.py
+++ b/src/batdetect2/train/dataset.py
@@ -73,12 +73,12 @@ class LabeledDataset(Dataset):
         return load_preprocessed_example(self.filenames[idx])

     def get_clip_annotation(self, idx) -> data.ClipAnnotation:
-        item = np.load(self.filenames[idx])
-        return item["clip_annotation"]
+        item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+")
+        return item["clip_annotation"].tolist()


 def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
-    item = np.load(path)
+    item = np.load(path, mmap_mode="r+")
     return PreprocessedExample(
         audio=torch.tensor(item["audio"]),
         spectrogram=torch.tensor(item["spectrogram"]),
diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py
index b2fb261..1e42651 100644
--- a/src/batdetect2/train/lightning.py
+++ b/src/batdetect2/train/lightning.py
@@ -12,6 +12,8 @@ __all__ = [


 class TrainingModule(L.LightningModule):
+    model: Model
+
     def __init__(
         self,
         model: Model,
diff --git a/src/batdetect2/train/preprocess.py b/src/batdetect2/train/preprocess.py
index fdc0c70..1a21a77 100644
--- a/src/batdetect2/train/preprocess.py
+++ b/src/batdetect2/train/preprocess.py
@@ -1,24 +1,4 @@
-"""Preprocesses datasets for BatDetect2 model training.
-
-This module provides functions to take a collection of annotated audio clips
-(`soundevent.data.ClipAnnotation`) and process them into the final format
-required for training a BatDetect2 model. This typically involves:
-
-1. Loading the relevant audio segment for each annotation using a configured
-   `PreprocessorProtocol`.
-2. Generating the corresponding input spectrogram using the
-   `PreprocessorProtocol`.
-3. Generating the target heatmaps (detection, classification, size) using a
-   configured `ClipLabeller` (which encapsulates the `TargetProtocol` logic).
-4. Packaging the input spectrogram, target heatmaps, and potentially the
-   processed audio waveform into an `xarray.Dataset`.
-5. Saving each processed `xarray.Dataset` to a separate file (typically NetCDF)
-   in an output directory.
-
-This offline preprocessing is often preferred for large datasets as it avoids
-computationally intensive steps during the actual training loop. The module
-includes utilities for parallel processing using `multiprocessing`.
-"""
+"""Preprocesses datasets for BatDetect2 model training."""

 import os
 from pathlib import Path
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index a6fda60..20f37b1 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -1,6 +1,7 @@
 from collections.abc import Sequence
 from typing import List, Optional

+import torch
 from lightning import Trainer
 from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
@@ -25,7 +26,12 @@ from batdetect2.train.dataset import (
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
 from batdetect2.train.losses import build_loss
-from batdetect2.typing import PreprocessorProtocol, TargetProtocol
+from batdetect2.typing import (
+    PreprocessorProtocol,
+    TargetProtocol,
+    TrainExample,
+)
+from batdetect2.utils.arrays import adjust_width

 __all__ = [
     "build_train_dataset",
@@ -53,7 +59,7 @@ def train(
     else:
         module = build_training_module(config)

-    trainer = build_trainer(config, targets=module.targets)
+    trainer = build_trainer(config, targets=module.model.targets)

     train_dataloader = build_train_loader(
         train_examples,
@@ -160,6 +166,7 @@ def build_train_loader(
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
+        collate_fn=_collate_fn,
     )


@@ -188,6 +195,28 @@ def build_val_loader(
         batch_size=loader_conf.batch_size,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
+        collate_fn=_collate_fn,
+    )
+
+
+def _collate_fn(batch: List[TrainExample]) -> TrainExample:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return TrainExample(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        detection_heatmap=torch.stack(
+            [adjust_width(item.detection_heatmap, max_width) for item in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(item.size_heatmap, max_width) for item in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(item.class_heatmap, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
     )
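A note on the new `_collate_fn`: spectrogram width scales with clip duration, so examples in a batch can have different time dimensions and PyTorch's default collate would fail to stack them. The custom collate pads every tensor to the widest example in the batch before stacking. `adjust_width` itself comes from `batdetect2.utils.arrays` and is not shown in this diff; the sketch below is only an assumption about the behaviour the collate relies on (right-pad or crop the last axis), and the real implementation may differ, e.g. in padding value or placement.

```python
import torch
import torch.nn.functional as F


def adjust_width(x: torch.Tensor, width: int) -> torch.Tensor:
    """Pad (with zeros) or crop the last axis of `x` to exactly `width`.

    Sketch only: the real batdetect2.utils.arrays.adjust_width is not
    shown in this diff and may use a different padding scheme.
    """
    current = x.shape[-1]
    if current == width:
        return x
    if current > width:
        # Crop extra columns off the end of the time axis.
        return x[..., :width]
    # F.pad takes (left, right) amounts for the last dimension.
    return F.pad(x, (0, width - current))


# Example: a 3-item batch of specs with widths 93, 100, and 87 all
# become width 100, so torch.stack succeeds.
specs = [torch.rand(1, 128, w) for w in (93, 100, 87)]
max_width = max(s.shape[-1] for s in specs)
batch = torch.stack([adjust_width(s, max_width) for s in specs])
assert batch.shape == (3, 1, 128, 100)
```

One caveat: `idx`, `start_time`, and `end_time` are stacked without width adjustment, which assumes each is a fixed-size (typically scalar) tensor per example.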