From 59bd14bc79d58a7261c519526a20294d6449a7e3 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 23 Apr 2025 23:15:08 +0100 Subject: [PATCH] Added clips for random cliping and augmentations --- batdetect2/modules.py | 153 --------- batdetect2/train/__init__.py | 16 +- batdetect2/train/augmentations.py | 130 ++------ batdetect2/train/clips.py | 184 ++++++++++ batdetect2/train/dataset.py | 85 ++--- batdetect2/train/lightning.py | 71 ++++ batdetect2/train/losses.py | 517 +++++++++++++++++++++++------ batdetect2/train/preprocess.py | 1 + batdetect2/train/types.py | 54 ++- tests/test_targets/test_classes.py | 11 - tests/test_train/test_clips.py | 466 ++++++++++++++++++++++++++ tests/test_train/test_lightning.py | 56 ++++ 12 files changed, 1317 insertions(+), 427 deletions(-) delete mode 100644 batdetect2/modules.py create mode 100644 batdetect2/train/clips.py create mode 100644 batdetect2/train/lightning.py create mode 100644 tests/test_train/test_clips.py create mode 100644 tests/test_train/test_lightning.py diff --git a/batdetect2/modules.py b/batdetect2/modules.py deleted file mode 100644 index c79a6a3..0000000 --- a/batdetect2/modules.py +++ /dev/null @@ -1,153 +0,0 @@ -from pathlib import Path -from typing import Optional - -import lightning as L -import torch -from pydantic import Field -from soundevent import data -from torch.optim.adam import Adam -from torch.optim.lr_scheduler import CosineAnnealingLR - -from batdetect2.configs import BaseConfig -from batdetect2.models import ( - BackboneConfig, - BBoxHead, - ClassifierHead, - ModelOutput, - build_backbone, -) -from batdetect2.postprocess import ( - PostprocessConfig, - postprocess_model_outputs, -) -from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip -from batdetect2.targets import ( - TargetConfig, - build_decoder, - build_target_encoder, - get_class_names, -) -from batdetect2.train import TrainExample, TrainingConfig, compute_loss - -__all__ = [ - "DetectorModel", -] - - -class ModuleConfig(BaseConfig): - train: TrainingConfig = Field(default_factory=TrainingConfig) - targets: TargetConfig = Field(default_factory=TargetConfig) - architecture: BackboneConfig = Field(default_factory=BackboneConfig) - preprocessing: PreprocessingConfig = Field( - default_factory=PreprocessingConfig - ) - postprocessing: PostprocessConfig = Field( - default_factory=PostprocessConfig - ) - - -class DetectorModel(L.LightningModule): - config: ModuleConfig - - def __init__( - self, - config: Optional[ModuleConfig] = None, - ): - super().__init__() - - self.config = config or ModuleConfig() - self.save_hyperparameters() - - self.backbone = build_model_backbone(self.config.architecture) - - self.classifier = ClassifierHead( - num_classes=len(self.config.targets.classes), - in_channels=self.backbone.out_channels, - ) - - self.bbox = BBoxHead(in_channels=self.backbone.out_channels) - - conf = self.config.train.loss.classification - self.class_weights = ( - torch.tensor(conf.class_weights) if conf.class_weights else None - ) - - # Training targets - self.class_names = get_class_names(self.config.targets.classes) - self.encoder = build_target_encoder( - self.config.targets.classes, - replacement_rules=self.config.targets.replace, - ) - self.decoder = build_decoder(self.config.targets.classes) - self.example_input_array = torch.randn([1, 1, 128, 512]) - - def forward(self, spec: torch.Tensor) -> ModelOutput: - features = self.backbone(spec) - detection_probs, classification_probs = self.classifier(features) - size_preds = 
self.bbox(features) - return ModelOutput( - detection_probs=detection_probs, - size_preds=size_preds, - class_probs=classification_probs, - features=features, - ) - - def training_step(self, batch: TrainExample): - outputs = self.forward(batch.spec) - losses = compute_loss( - batch, - outputs, - conf=self.config.train.loss, - class_weights=self.class_weights, - ) - - self.log("train/loss/total", losses.total, prog_bar=True, logger=True) - self.log("train/loss/detection", losses.total, logger=True) - self.log("train/loss/size", losses.total, logger=True) - self.log("train/loss/classification", losses.total, logger=True) - - return losses.total - - def validation_step(self, batch: TrainExample, batch_idx: int) -> None: - outputs = self.forward(batch.spec) - - losses = compute_loss( - batch, - outputs, - conf=self.config.train.loss, - class_weights=self.class_weights, - ) - - self.log("val/loss/total", losses.total, prog_bar=True, logger=True) - self.log("val/loss/detection", losses.total, logger=True) - self.log("val/loss/size", losses.total, logger=True) - self.log("val/loss/classification", losses.total, logger=True) - - def on_validation_epoch_end(self) -> None: - self.validation_predictions.clear() - - def configure_optimizers(self): - conf = self.config.train.optimizer - optimizer = Adam(self.parameters(), lr=conf.learning_rate) - scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max) - return [optimizer], [scheduler] - - def process_clip( - self, - clip: data.Clip, - audio_dir: Optional[Path] = None, - ) -> data.ClipPrediction: - spec = preprocess_audio_clip( - clip, - config=self.config.preprocessing, - audio_dir=audio_dir, - ) - tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0) - outputs = self.forward(tensor) - return postprocess_model_outputs( - outputs, - clips=[clip], - classes=self.class_names, - decoder=self.decoder, - config=self.config.postprocessing, - )[0] diff --git a/batdetect2/train/__init__.py b/batdetect2/train/__init__.py index 6caba28..be4c1b0 100644 --- a/batdetect2/train/__init__.py +++ b/batdetect2/train/__init__.py @@ -11,18 +11,19 @@ from batdetect2.train.augmentations import ( mask_time, mix_examples, scale_volume, - select_subclip, warp_spectrogram, ) +from batdetect2.train.clips import build_clipper, select_subclip from batdetect2.train.config import TrainingConfig, load_train_config from batdetect2.train.dataset import ( LabeledDataset, + RandomExampleSource, SubclipConfig, TrainExample, - get_preprocessed_files, + list_preprocessed_files, ) from batdetect2.train.labels import load_label_config -from batdetect2.train.losses import compute_loss +from batdetect2.train.losses import LossFunction, build_loss from batdetect2.train.preprocess import ( generate_train_example, preprocess_annotations, @@ -34,6 +35,8 @@ __all__ = [ "EchoAugmentationConfig", "FrequencyMaskAugmentationConfig", "LabeledDataset", + "LossFunction", + "RandomExampleSource", "SubclipConfig", "TimeMaskAugmentationConfig", "TrainExample", @@ -43,9 +46,11 @@ __all__ = [ "WarpAugmentationConfig", "add_echo", "build_augmentations", - "compute_loss", + "build_clipper", + "build_loss", "generate_train_example", - "get_preprocessed_files", + "list_preprocessed_files", + "load_label_config", "load_train_config", "load_trainer_config", "mask_frequency", @@ -56,5 +61,4 @@ __all__ = [ "select_subclip", "train", "warp_spectrogram", - "load_label_config", ] diff --git a/batdetect2/train/augmentations.py b/batdetect2/train/augmentations.py index 322d4ee..f8d3fd0 100644 --- 
a/batdetect2/train/augmentations.py +++ b/batdetect2/train/augmentations.py @@ -37,24 +37,24 @@ from batdetect2.train.types import Augmentation from batdetect2.utils.arrays import adjust_width __all__ = [ + "AugmentationConfig", "AugmentationsConfig", - "load_augmentation_config", - "build_augmentations", - "select_subclip", - "mix_examples", - "add_echo", - "scale_volume", - "warp_spectrogram", - "mask_time", - "mask_frequency", - "MixAugmentationConfig", + "DEFAULT_AUGMENTATION_CONFIG", "EchoAugmentationConfig", + "ExampleSource", + "FrequencyMaskAugmentationConfig", + "MixAugmentationConfig", + "TimeMaskAugmentationConfig", "VolumeAugmentationConfig", "WarpAugmentationConfig", - "TimeMaskAugmentationConfig", - "FrequencyMaskAugmentationConfig", - "AugmentationConfig", - "ExampleSource", + "add_echo", + "build_augmentations", + "load_augmentation_config", + "mask_frequency", + "mask_time", + "mix_examples", + "scale_volume", + "warp_spectrogram", ] ExampleSource = Callable[[], xr.Dataset] @@ -64,92 +64,6 @@ Used by the `mix_examples` augmentation to fetch another example to mix with. """ -def select_subclip( - example: xr.Dataset, - start_time: Optional[float] = None, - duration: Optional[float] = None, - width: Optional[int] = None, - random: bool = False, -) -> xr.Dataset: - """Extract a sub-clip (time segment) from a training example dataset. - - Selects a portion of the 'time' dimension from all relevant DataArrays - (`audio`, `spectrogram`, `detection`, `class`, `size`) within the example - Dataset. The segment can be defined by a fixed start time and - duration/width, or a random start time can be chosen. - - Parameters - ---------- - example : xr.Dataset - The input training example containing 'audio', 'spectrogram', and - target heatmaps, all with compatible 'time' (or 'audio_time') - coordinates. - start_time : float, optional - Desired start time (seconds) of the subclip. If None and `random` is - False, starts from the beginning of the example. If None and `random` - is True, a random start time is chosen. - duration : float, optional - Desired duration (seconds) of the subclip. Either `duration` or `width` - must be provided. - width : int, optional - Desired width (number of time bins) of the subclip's - spectrogram/heatmaps. Either `duration` or `width` must be provided. If - both are given, `duration` takes precedence. - random : bool, default=False - If True and `start_time` is None, selects a random start time ensuring - the subclip fits within the original example's duration. - - Returns - ------- - xr.Dataset - A new dataset containing only the selected time segment. Coordinates - are adjusted accordingly. Returns the original example if the requested - subclip cannot be extracted (e.g., duration too long). - - Raises - ------ - ValueError - If neither `duration` nor `width` is provided, or if time coordinates - are missing or invalid. 
- """ - step = arrays.get_dim_step(example, "time") # type: ignore - start, stop = arrays.get_dim_range(example, "time") # type: ignore - - if width is None: - if duration is None: - raise ValueError("Either duration or width must be provided") - - width = int(np.floor(duration / step)) - - if duration is None: - duration = width * step - - if start_time is None: - if random: - start_time = np.random.uniform(start, max(stop - duration, start)) - else: - start_time = start - - if start_time + duration > stop: - return example - - start_index = arrays.get_coord_index( - example, # type: ignore - "time", - start_time, - ) - - end_index = start_index + width - 1 - - start_time = example.time.values[start_index] - end_time = example.time.values[end_index] - - return example.sel( - time=slice(start_time, end_time), - audio_time=slice(start_time, end_time + step), - ) - - class MixAugmentationConfig(BaseConfig): """Configuration for MixUp augmentation (mixing two examples).""" @@ -878,9 +792,21 @@ def build_augmentation_from_config( ) +DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig( + steps=[ + MixAugmentationConfig(), + EchoAugmentationConfig(), + VolumeAugmentationConfig(), + WarpAugmentationConfig(), + TimeMaskAugmentationConfig(), + FrequencyMaskAugmentationConfig(), + ] +) + + def build_augmentations( - config: AugmentationsConfig, preprocessor: PreprocessorProtocol, + config: Optional[AugmentationsConfig] = None, example_source: Optional[ExampleSource] = None, ) -> Augmentation: """Build a composite augmentation pipeline function from configuration. @@ -915,6 +841,8 @@ def build_augmentations( NotImplementedError If an unknown `augmentation_type` is encountered in `config.steps`. """ + config = config or DEFAULT_AUGMENTATION_CONFIG + augmentations = [] for step_config in config.steps: diff --git a/batdetect2/train/clips.py b/batdetect2/train/clips.py new file mode 100644 index 0000000..c1d39e8 --- /dev/null +++ b/batdetect2/train/clips.py @@ -0,0 +1,184 @@ +from typing import Optional, Tuple, Union + +import numpy as np +import xarray as xr +from soundevent import arrays + +from batdetect2.configs import BaseConfig +from batdetect2.train.types import ClipperProtocol + +DEFAULT_TRAIN_CLIP_DURATION = 0.513 +DEFAULT_MAX_EMPTY_CLIP = 0.1 + + +class ClipperConfig(BaseConfig): + duration: float = DEFAULT_TRAIN_CLIP_DURATION + random: bool = True + max_empty: float = DEFAULT_MAX_EMPTY_CLIP + + +class Clipper(ClipperProtocol): + def __init__( + self, + duration: float = 0.5, + max_empty: float = 0.2, + random: bool = True, + ): + self.duration = duration + self.random = random + self.max_empty = max_empty + + def extract_clip( + self, example: xr.Dataset + ) -> Tuple[xr.Dataset, float, float]: + step = arrays.get_dim_step( + example.spectrogram, + dim=arrays.Dimensions.time.value, + ) + duration = ( + arrays.get_dim_width( + example.spectrogram, + dim=arrays.Dimensions.time.value, + ) + + step + ) + + start_time = 0 + if self.random: + start_time = np.random.uniform( + -self.max_empty, + duration - self.duration + self.max_empty, + ) + + subclip = select_subclip( + example, + start=start_time, + span=self.duration, + dim="time", + ) + + return ( + select_subclip( + subclip, + start=start_time, + span=self.duration, + dim="audio_time", + ), + start_time, + start_time + self.duration, + ) + + +def build_clipper(config: Optional[ClipperConfig] = None) -> ClipperProtocol: + config = config or ClipperConfig() + return Clipper( + duration=config.duration, + 
max_empty=config.max_empty, + random=config.random, + ) + + +def select_subclip( + dataset: xr.Dataset, + span: float, + start: float, + fill_value: float = 0, + dim: str = "time", +) -> xr.Dataset: + width = _compute_expected_width( + dataset, # type: ignore + span, + dim=dim, + ) + + coord = dataset.coords[dim] + + if len(coord) == width: + return dataset + + new_coords, start_pad, end_pad, dim_slice = _extract_coordinate( + coord, start, span + ) + + data_vars = {} + for name, data_array in dataset.data_vars.items(): + if dim not in data_array.dims: + data_vars[name] = data_array + continue + + if width == data_array.sizes[dim]: + data_vars[name] = data_array + continue + + sliced = data_array.isel({dim: dim_slice}).data + + if start_pad > 0 or end_pad > 0: + padding = [ + [0, 0] if other_dim != dim else [start_pad, end_pad] + for other_dim in data_array.dims + ] + sliced = np.pad(sliced, padding, constant_values=fill_value) + + data_vars[name] = xr.DataArray( + data=sliced, + dims=data_array.dims, + coords={**data_array.coords, dim: new_coords}, + attrs=data_array.attrs, + ) + + return xr.Dataset(data_vars=data_vars, attrs=dataset.attrs) + + +def _extract_coordinate( + coord: xr.DataArray, + start: float, + span: float, +) -> Tuple[xr.Variable, int, int, slice]: + step = arrays.get_dim_step(coord, str(coord.name)) + + current_width = len(coord) + expected_width = int(np.floor(span / step)) + + coord_start = float(coord[0]) + offset = start - coord_start + + start_index = int(np.floor(offset / step)) + end_index = start_index + expected_width + + if start_index > current_width: + raise ValueError("Requested span does not overlap with current range") + + if end_index < 0: + raise ValueError("Requested span does not overlap with current range") + + corrected_start = float(start_index * step) + corrected_end = float(end_index * step) + + start_index_offset = max(0, -start_index) + end_index_offset = max(0, end_index - current_width) + + sl = slice( + start_index if start_index >= 0 else None, + end_index if end_index < current_width else None, + ) + + return ( + arrays.create_range_dim( + str(coord.name), + start=corrected_start, + stop=corrected_end, + step=step, + ), + start_index_offset, + end_index_offset, + sl, + ) + + +def _compute_expected_width( + array: Union[xr.DataArray, xr.Dataset], + duration: float, + dim: str, +) -> int: + step = arrays.get_dim_step(array, dim) # type: ignore + return int(np.floor(duration / step)) diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py index cc839dd..4f76051 100644 --- a/batdetect2/train/dataset.py +++ b/batdetect2/train/dataset.py @@ -1,6 +1,5 @@ -import os from pathlib import Path -from typing import NamedTuple, Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple import numpy as np import torch @@ -13,28 +12,14 @@ from batdetect2.configs import BaseConfig from batdetect2.train.augmentations import ( Augmentation, AugmentationsConfig, - select_subclip, ) -from batdetect2.train.preprocess import PreprocessorProtocol -from batdetect2.utils.tensors import adjust_width +from batdetect2.train.types import ClipperProtocol, TrainExample __all__ = [ - "TrainExample", "LabeledDataset", ] -PathLike = Union[Path, str, os.PathLike] - - -class TrainExample(NamedTuple): - spec: torch.Tensor - detection_heatmap: torch.Tensor - class_heatmap: torch.Tensor - size_heatmap: torch.Tensor - idx: torch.Tensor - - class SubclipConfig(BaseConfig): duration: Optional[float] = None width: int = 512 @@ -51,14 +36,12 @@ class 
DatasetConfig(BaseConfig): class LabeledDataset(Dataset): def __init__( self, - preprocessor: PreprocessorProtocol, - filenames: Sequence[PathLike], - subclip: Optional[SubclipConfig] = None, + filenames: Sequence[data.PathLike], + clipper: ClipperProtocol, augmentation: Optional[Augmentation] = None, ): - self.preprocessor = preprocessor self.filenames = filenames - self.subclip = subclip + self.clipper = clipper self.augmentation = augmentation def __len__(self): @@ -66,14 +49,7 @@ class LabeledDataset(Dataset): def __getitem__(self, idx) -> TrainExample: dataset = self.get_dataset(idx) - - if self.subclip: - dataset = select_subclip( - dataset, - duration=self.subclip.duration, - width=self.subclip.width, - random=self.subclip.random, - ) + dataset, start_time, end_time = self.clipper.extract_clip(dataset) if self.augmentation: dataset = self.augmentation(dataset) @@ -84,37 +60,31 @@ class LabeledDataset(Dataset): class_heatmap=self.to_tensor(dataset["class"]), size_heatmap=self.to_tensor(dataset["size"]), idx=torch.tensor(idx), + start_time=start_time, + end_time=end_time, ) @classmethod def from_directory( cls, - directory: PathLike, - preprocessor: PreprocessorProtocol, + directory: data.PathLike, + clipper: ClipperProtocol, extension: str = ".nc", - subclip: Optional[SubclipConfig] = None, augmentation: Optional[Augmentation] = None, ): return cls( - preprocessor=preprocessor, - filenames=get_preprocessed_files(directory, extension), - subclip=subclip, + filenames=list_preprocessed_files(directory, extension), + clipper=clipper, augmentation=augmentation, ) - def get_random_example(self) -> xr.Dataset: + def get_random_example(self) -> Tuple[xr.Dataset, float, float]: idx = np.random.randint(0, len(self)) dataset = self.get_dataset(idx) - if self.subclip: - dataset = select_subclip( - dataset, - duration=self.subclip.duration, - width=self.subclip.width, - random=self.subclip.random, - ) + dataset, start_time, end_time = self.clipper.extract_clip(dataset) - return dataset + return dataset, start_time, end_time def get_dataset(self, idx) -> xr.Dataset: return xr.open_dataset(self.filenames[idx]) @@ -129,16 +99,23 @@ class LabeledDataset(Dataset): array: xr.DataArray, dtype=np.float32, ) -> torch.Tensor: - tensor = torch.tensor(array.values.astype(dtype)) - - if not self.subclip: - return tensor - - width = self.subclip.width - return adjust_width(tensor, width) + return torch.tensor(array.values.astype(dtype)) -def get_preprocessed_files( - directory: PathLike, extension: str = ".nc" +def list_preprocessed_files( + directory: data.PathLike, extension: str = ".nc" ) -> Sequence[Path]: return list(Path(directory).glob(f"*{extension}")) + + +class RandomExampleSource: + def __init__(self, filenames: List[str], clipper: ClipperProtocol): + self.filenames = filenames + self.clipper = clipper + + def __call__(self): + index = int(np.random.randint(len(self.filenames))) + filename = self.filenames[index] + dataset = xr.open_dataset(filename) + example, _, _ = self.clipper.extract_clip(dataset) + return example diff --git a/batdetect2/train/lightning.py b/batdetect2/train/lightning.py new file mode 100644 index 0000000..096ccd9 --- /dev/null +++ b/batdetect2/train/lightning.py @@ -0,0 +1,71 @@ +import lightning as L +import torch +from torch.optim.adam import Adam +from torch.optim.lr_scheduler import CosineAnnealingLR + +from batdetect2.models import ( + DetectionModel, + ModelOutput, +) +from batdetect2.postprocess.types import PostprocessorProtocol +from batdetect2.preprocess.types import 
PreprocessorProtocol
+from batdetect2.targets.types import TargetProtocol
+from batdetect2.train import TrainExample
+from batdetect2.train.types import LossProtocol
+
+__all__ = [
+    "TrainingModule",
+]
+
+
+class TrainingModule(L.LightningModule):
+    def __init__(
+        self,
+        detector: DetectionModel,
+        loss: LossProtocol,
+        targets: TargetProtocol,
+        preprocessor: PreprocessorProtocol,
+        postprocessor: PostprocessorProtocol,
+        learning_rate: float = 0.001,
+        t_max: int = 100,
+    ):
+        super().__init__()
+
+        self.loss = loss
+        self.detector = detector
+        self.preprocessor = preprocessor
+        self.targets = targets
+        self.postprocessor = postprocessor
+
+        self.learning_rate = learning_rate
+        self.t_max = t_max
+
+        self.save_hyperparameters()
+
+    def forward(self, spec: torch.Tensor) -> ModelOutput:
+        return self.detector(spec)
+
+    def training_step(self, batch: TrainExample):
+        outputs = self.forward(batch.spec)
+        losses = self.loss(outputs, batch)
+
+        self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
+        self.log("train/loss/detection", losses.detection, logger=True)
+        self.log("train/loss/size", losses.size, logger=True)
+        self.log("train/loss/classification", losses.classification, logger=True)
+
+        return losses.total
+
+    def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
+        outputs = self.forward(batch.spec)
+        losses = self.loss(outputs, batch)
+
+        self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
+        self.log("val/loss/detection", losses.detection, logger=True)
+        self.log("val/loss/size", losses.size, logger=True)
+        self.log("val/loss/classification", losses.classification, logger=True)
+
+    def configure_optimizers(self):
+        optimizer = Adam(self.parameters(), lr=self.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
+        return [optimizer], [scheduler]
diff --git a/batdetect2/train/losses.py b/batdetect2/train/losses.py
index 27a132c..b4083bb 100644
--- a/batdetect2/train/losses.py
+++ b/batdetect2/train/losses.py
@@ -1,115 +1,323 @@
-from typing import NamedTuple, Optional
+"""Loss functions and configurations for training BatDetect2 models.
+
+This module defines the loss functions used to train BatDetect2 models,
+including individual loss components for different prediction tasks (detection,
+classification, size regression) and a main coordinating loss function that
+combines them.
+
+It utilizes common loss types like L1 loss (`BBoxLoss`) for regression and
+Focal Loss (`FocalLoss`) for handling class imbalance in dense detection and
+classification tasks. Configuration objects (`LossConfig`, etc.) allow for easy
+customization of loss parameters and weights via configuration files.
+
+The primary entry points are:
+- `LossFunction`: An `nn.Module` that computes the weighted sum of individual
+  loss components given model outputs and ground truth targets.
+- `build_loss`: A factory function that constructs the `LossFunction` based
+  on a `LossConfig` object.
+- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
+""" + +from typing import Optional + +import numpy as np import torch import torch.nn.functional as F from pydantic import Field +from torch import nn from batdetect2.configs import BaseConfig from batdetect2.models.types import ModelOutput from batdetect2.train.dataset import TrainExample +from batdetect2.train.types import Losses, LossProtocol __all__ = [ - "bbox_size_loss", - "compute_loss", - "focal_loss", - "mse_loss", + "BBoxLoss", + "ClassificationLossConfig", + "DetectionLossConfig", + "FocalLoss", + "FocalLossConfig", + "LossConfig", + "LossFunction", + "MSELoss", + "SizeLossConfig", + "build_loss", ] class SizeLossConfig(BaseConfig): + """Configuration for the bounding box size loss component. + + Attributes + ---------- + weight : float, default=0.1 + The weighting factor applied to the size loss when combining it with + other losses (detection, classification) to form the total training + loss. + """ + weight: float = 0.1 -def bbox_size_loss( - pred_size: torch.Tensor, - gt_size: torch.Tensor, -) -> torch.Tensor: +class BBoxLoss(nn.Module): + """Computes L1 loss for bounding box size regression. + + Calculates the Mean Absolute Error (MAE or L1 loss) between the predicted + size dimensions (`pred`) and the ground truth size dimensions (`gt`). + Crucially, the loss is only computed at locations where the ground truth + size heatmap (`gt`) contains non-zero values (i.e., at the reference points + of actual annotated sound events). This prevents the model from being + penalized for size predictions in background regions. + + The loss is summed over all valid locations and normalized by the number + of valid locations. """ - Bounding box size loss. Only compute loss where there is a bounding box. - """ - gt_size_mask = (gt_size > 0).float() - return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / ( - gt_size_mask.sum() + 1e-5 - ) + + def forward(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor: + """Calculate masked L1 loss for size prediction. + + Parameters + ---------- + pred : torch.Tensor + Predicted size tensor, typically shape `(B, 2, H, W)`, where + channels represent scaled width and height. + gt : torch.Tensor + Ground truth size tensor, same shape as `pred`. Non-zero values + indicate locations and target sizes of actual annotations. + + Returns + ------- + torch.Tensor + Scalar tensor representing the calculated masked L1 loss. + """ + gt_size_mask = (gt > 0).float() + masked_pred = pred * gt_size_mask + loss = F.l1_loss(masked_pred, gt, reduction="sum") + num_pos = gt_size_mask.sum() + 1e-5 + return loss / num_pos class FocalLossConfig(BaseConfig): + """Configuration parameters for the Focal Loss function. + + Attributes + ---------- + beta : float, default=4 + Exponent controlling the down-weighting of easy negative examples. + Higher values increase down-weighting (focus more on hard negatives). + alpha : float, default=2 + Exponent controlling the down-weighting based on prediction confidence. + Higher values focus more on misclassified examples (both positive and + negative). 
+ """ + beta: float = 4 alpha: float = 2 -def focal_loss( - pred: torch.Tensor, - gt: torch.Tensor, - weights: Optional[torch.Tensor] = None, - valid_mask: Optional[torch.Tensor] = None, - eps: float = 1e-5, - beta: float = 4, - alpha: float = 2, -) -> torch.Tensor: - """ - Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints - pred (batch x c x h x w) - gt (batch x c x h x w) +class FocalLoss(nn.Module): + """Focal Loss implementation, adapted from CornerNet. + + Addresses class imbalance in dense object detection/classification tasks by + down-weighting the loss contribution from easy examples (both positive and + negative), allowing the model to focus more on hard-to-classify examples. + + Parameters + ---------- + eps : float, default=1e-5 + Small epsilon value added for numerical stability. + beta : float, default=4 + Exponent focusing on hard negative examples (modulates `(1-gt)^beta`). + alpha : float, default=2 + Exponent focusing on misclassified examples (modulates `(1-p)^alpha` + for positives and `p^alpha` for negatives). + class_weights : torch.Tensor, optional + Optional tensor containing weights for each class (applied to positive + loss). Shape should be broadcastable to the channel dimension of the + input tensors. + mask_zero : bool, default=False + If True, ignores loss contributions from spatial locations where the + ground truth `gt` tensor is zero across *all* channels. Useful for + classification heatmaps where some areas might have no assigned class. + + References + ---------- + - Lin, T. Y., et al. "Focal loss for dense object detection." ICCV 2017. + - Law, H., & Deng, J. "CornerNet: Detecting Objects as Paired Keypoints." + ECCV 2018. """ - pos_inds = gt.eq(1).float() - neg_inds = gt.lt(1).float() + def __init__( + self, + eps: float = 1e-5, + beta: float = 4, + alpha: float = 2, + class_weights: Optional[torch.Tensor] = None, + mask_zero: bool = False, + ): + super().__init__() + self.class_weights = class_weights + self.eps = eps + self.beta = beta + self.alpha = alpha + self.mask_zero = mask_zero - pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds - neg_loss = ( - torch.log(1 - pred + eps) - * torch.pow(pred, alpha) - * torch.pow(1 - gt, beta) - * neg_inds - ) + def forward( + self, + pred: torch.Tensor, + gt: torch.Tensor, + ) -> torch.Tensor: + """Compute the Focal Loss. - if weights is not None: - pos_loss = pos_loss * torch.tensor(weights) - # neg_loss = neg_loss*weights + Parameters + ---------- + pred : torch.Tensor + Predicted probabilities or logits (typically sigmoid output for + detection, or softmax/sigmoid for classification). Must be in the + range [0, 1] after potential activation. Shape `(B, C, H, W)`. + gt : torch.Tensor + Ground truth heatmap tensor. Shape `(B, C, H, W)`. Values typically + represent target probabilities (e.g., Gaussian peaks for detection, + one-hot encoding or smoothed labels for classification). For the + adapted CornerNet loss, `gt=1` indicates a positive location, and + values `< 1` indicate negative locations (with potential Gaussian + weighting `(1-gt)^beta` for negatives near positives). - if valid_mask is not None: - pos_loss = pos_loss * valid_mask - neg_loss = neg_loss * valid_mask + Returns + ------- + torch.Tensor + Scalar tensor representing the computed focal loss, normalized by + the number of positive locations. 
+ """ - pos_loss = pos_loss.sum() - neg_loss = neg_loss.sum() + pos_inds = gt.eq(1).float() + neg_inds = gt.lt(1).float() - num_pos = pos_inds.float().sum() - if num_pos == 0: - loss = -neg_loss - else: - loss = -(pos_loss + neg_loss) / num_pos - return loss + pos_loss = ( + torch.log(pred + self.eps) + * torch.pow(1 - pred, self.alpha) + * pos_inds + ) + neg_loss = ( + torch.log(1 - pred + self.eps) + * torch.pow(pred, self.alpha) + * torch.pow(1 - gt, self.beta) + * neg_inds + ) + + if self.class_weights is not None: + pos_loss = pos_loss * torch.tensor(self.class_weights) + + if self.mask_zero: + valid_mask = gt.any(dim=1, keepdim=True).float() + pos_loss = pos_loss * valid_mask + neg_loss = neg_loss * valid_mask + + pos_loss = pos_loss.sum() + neg_loss = neg_loss.sum() + + num_pos = pos_inds.float().sum() + if num_pos == 0: + loss = -neg_loss + else: + loss = -(pos_loss + neg_loss) / num_pos + return loss -def mse_loss( - pred: torch.Tensor, - gt: torch.Tensor, - valid_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: +class MSELoss(nn.Module): + """Mean Squared Error (MSE) Loss module. + + Calculates the mean squared difference between predictions and ground + truth. Optionally masks contributions where the ground truth is zero across + channels. + + Parameters + ---------- + mask_zero : bool, default=False + If True, calculates the loss only over spatial locations (H, W) where + at least one channel in the ground truth `gt` tensor is non-zero. The + loss is then averaged over these valid locations. If False (default), + the standard MSE over all elements is computed. """ - Mean squared error loss. - """ - if valid_mask is None: - op = ((gt - pred) ** 2).mean() - else: - op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum() - return op + + def __init__(self, mask_zero: bool = False): + super().__init__() + self.mask_zero = mask_zero + + def forward( + self, + pred: torch.Tensor, + gt: torch.Tensor, + ) -> torch.Tensor: + """Compute the Mean Squared Error loss. + + Parameters + ---------- + pred : torch.Tensor + Predicted tensor, shape `(B, C, H, W)`. + gt : torch.Tensor + Ground truth tensor, shape `(B, C, H, W)`. + + Returns + ------- + torch.Tensor + Scalar tensor representing the calculated MSE loss. + """ + if not self.mask_zero: + return ((gt - pred) ** 2).mean() + + valid_mask = gt.any(dim=1, keepdim=True).float() + return (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum() class DetectionLossConfig(BaseConfig): + """Configuration for the detection loss component. + + Attributes + ---------- + weight : float, default=1.0 + Weighting factor for the detection loss in the combined total loss. + focal : FocalLossConfig + Configuration for the Focal Loss used for detection. Defaults to + standard Focal Loss parameters (`alpha=2`, `beta=4`). + """ + weight: float = 1.0 focal: FocalLossConfig = Field(default_factory=FocalLossConfig) class ClassificationLossConfig(BaseConfig): + """Configuration for the classification loss component. + + Attributes + ---------- + weight : float, default=2.0 + Weighting factor for the classification loss in the combined total loss. + focal : FocalLossConfig + Configuration for the Focal Loss used for classification. Defaults to + standard Focal Loss parameters (`alpha=2`, `beta=4`). + """ + weight: float = 2.0 focal: FocalLossConfig = Field(default_factory=FocalLossConfig) - class_weights: Optional[list[float]] = None class LossConfig(BaseConfig): + """Aggregated configuration for all loss components. 
+ + Defines the configuration and weighting for detection, size regression, + and classification losses used in the main `LossFunction`. + + Attributes + ---------- + detection : DetectionLossConfig + Configuration for the detection loss (Focal Loss). + size : SizeLossConfig + Configuration for the size regression loss (L1 loss). + classification : ClassificationLossConfig + Configuration for the classification loss (Focal Loss). + """ + detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig) size: SizeLossConfig = Field(default_factory=SizeLossConfig) classification: ClassificationLossConfig = Field( @@ -117,50 +325,157 @@ class LossConfig(BaseConfig): ) -class Losses(NamedTuple): - detection: torch.Tensor - size: torch.Tensor - classification: torch.Tensor - total: torch.Tensor +class LossFunction(nn.Module, LossProtocol): + """Computes the combined training loss for the BatDetect2 model. + + Aggregates individual loss functions for detection, size regression, and + classification tasks. Calculates each component loss based on model outputs + and ground truth targets, applies configured weights, and sums them to get + the final total loss used for optimization. Also returns individual + components for monitoring. + + Parameters + ---------- + size_loss : nn.Module + Instantiated loss module for size regression (e.g., `BBoxLoss`). + detection_loss : nn.Module + Instantiated loss module for detection (e.g., `FocalLoss`). + classification_loss : nn.Module + Instantiated loss module for classification (e.g., `FocalLoss`). + size_weight : float, default=0.1 + Weighting factor for the size loss component. + detection_weight : float, default=1.0 + Weighting factor for the detection loss component. + classification_weight : float, default=2.0 + Weighting factor for the classification loss component. + + Attributes + ---------- + size_loss_fn : nn.Module + detection_loss_fn : nn.Module + classification_loss_fn : nn.Module + size_weight : float + detection_weight : float + classification_weight : float + """ + + def __init__( + self, + size_loss: nn.Module, + detection_loss: nn.Module, + classification_loss: nn.Module, + size_weight: float = 0.1, + detection_weight: float = 1.0, + classification_weight: float = 2.0, + ): + super().__init__() + self.size_loss_fn = size_loss + self.detection_loss_fn = detection_loss + self.classification_loss_fn = classification_loss + + self.size_weight = size_weight + self.detection_weight = detection_weight + self.classification_weight = classification_weight + + def forward( + self, + pred: ModelOutput, + gt: TrainExample, + ) -> Losses: + """Calculate the combined loss and individual components. + + Parameters + ---------- + pred: ModelOutput + A NamedTuple containing the model's prediction tensors for the + batch: `detection_probs`, `size_preds`, `class_probs`. + gt: TrainExample + A structure containing the ground truth targets for the batch, + expected to have attributes like `detection_heatmap`, + `size_heatmap`, and `class_heatmap` (as `torch.Tensor`). + + Returns + ------- + Losses + A NamedTuple containing the scalar loss values for detection, size, + classification, and the total weighted loss. 
+ """ + size_loss = self.size_loss_fn(pred.size_preds, gt.size_heatmap) + detection_loss = self.detection_loss_fn( + pred.detection_probs, + gt.detection_heatmap, + ) + classification_loss = self.classification_loss_fn( + pred.class_probs, + gt.class_heatmap, + ) + total_loss = ( + size_loss * self.size_weight + + classification_loss * self.classification_weight + + detection_loss * self.detection_weight + ) + return Losses( + detection=detection_loss, + size=size_loss, + classification=classification_loss, + total=total_loss, + ) -def compute_loss( - batch: TrainExample, - outputs: ModelOutput, - conf: LossConfig, - class_weights: Optional[torch.Tensor] = None, -) -> Losses: - detection_loss = focal_loss( - outputs.detection_probs, - batch.detection_heatmap, - beta=conf.detection.focal.beta, - alpha=conf.detection.focal.alpha, +def build_loss( + config: Optional[LossConfig] = None, + class_weights: Optional[np.ndarray] = None, +) -> nn.Module: + """Factory function to build the main LossFunction from configuration. + + Instantiates the necessary loss components (`BBoxLoss`, `FocalLoss`) based + on the provided `LossConfig` (or defaults) and optional `class_weights`, + then assembles them into the main `LossFunction` module with the specified + component weights. + + Parameters + ---------- + config : LossConfig, optional + Configuration object defining weights and parameters (e.g., alpha, beta + for Focal Loss) for each loss component. If None, default settings + from `LossConfig` and its nested configs are used. + class_weights : np.ndarray, optional + An array of weights for each specific class, used to adjust the + classification loss (typically Focal Loss). If provided, this overrides + any `class_weights` specified within `config.classification`. If None, + weights from the config (or default of equal weights) are used. + + Returns + ------- + LossFunction + An initialized `LossFunction` module ready for training. 
+ """ + config = config or LossConfig() + + class_weights_tensor = ( + torch.tensor(class_weights) if class_weights else None ) - size_loss = bbox_size_loss( - outputs.size_preds, - batch.size_heatmap, + detection_loss_fn = FocalLoss( + beta=config.detection.focal.beta, + alpha=config.detection.focal.alpha, + mask_zero=False, ) - valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float() - classification_loss = focal_loss( - outputs.class_probs, - batch.class_heatmap, - weights=class_weights, - valid_mask=valid_mask, - beta=conf.classification.focal.beta, - alpha=conf.classification.focal.alpha, + classification_loss_fn = FocalLoss( + beta=config.classification.focal.beta, + alpha=config.classification.focal.alpha, + class_weights=class_weights_tensor, + mask_zero=True, ) - total = ( - detection_loss * conf.detection.weight - + size_loss * conf.size.weight - + classification_loss * conf.classification.weight - ) + size_loss_fn = BBoxLoss() - return Losses( - detection=detection_loss, - size=size_loss, - classification=classification_loss, - total=total, + return LossFunction( + size_loss=size_loss_fn, + classification_loss=classification_loss_fn, + detection_loss=detection_loss_fn, + size_weight=config.size.weight, + detection_weight=config.detection.weight, + classification_weight=config.classification.weight, ) diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py index b90770f..14372ae 100644 --- a/batdetect2/train/preprocess.py +++ b/batdetect2/train/preprocess.py @@ -139,6 +139,7 @@ def _save_xr_dataset_to_file( dataset.to_netcdf( path, encoding={ + "audio": {"zlib": True}, "spectrogram": {"zlib": True}, "size": {"zlib": True}, "class": {"zlib": True}, diff --git a/batdetect2/train/types.py b/batdetect2/train/types.py index 5bc071b..1be4e3e 100644 --- a/batdetect2/train/types.py +++ b/batdetect2/train/types.py @@ -1,12 +1,17 @@ -from typing import Callable, NamedTuple +from typing import Callable, NamedTuple, Protocol, Tuple +import torch import xarray as xr from soundevent import data +from batdetect2.models import ModelOutput + __all__ = [ "Heatmaps", "ClipLabeller", "Augmentation", + "LossProtocol", + "TrainExample", ] @@ -46,3 +51,50 @@ steps, and returns the final `Heatmaps` used for model training. Augmentation = Callable[[xr.Dataset], xr.Dataset] +class TrainExample(NamedTuple): + spec: torch.Tensor + detection_heatmap: torch.Tensor + class_heatmap: torch.Tensor + size_heatmap: torch.Tensor + idx: torch.Tensor + start_time: float + end_time: float + + +class Losses(NamedTuple): + """Structure to hold the computed loss values. + + Allows returning individual loss components along with the total weighted + loss for monitoring and analysis during training. + + Attributes + ---------- + detection : torch.Tensor + Scalar tensor representing the calculated detection loss component + (before weighting). + size : torch.Tensor + Scalar tensor representing the calculated size regression loss component + (before weighting). + classification : torch.Tensor + Scalar tensor representing the calculated classification loss component + (before weighting). + total : torch.Tensor + Scalar tensor representing the final combined loss, computed as the + weighted sum of the detection, size, and classification components. + This is the value typically used for backpropagation. 
+ """ + + detection: torch.Tensor + size: torch.Tensor + classification: torch.Tensor + total: torch.Tensor + + +class LossProtocol(Protocol): + def __call__(self, pred: ModelOutput, gt: TrainExample) -> Losses: ... + + +class ClipperProtocol(Protocol): + def extract_clip( + self, example: xr.Dataset + ) -> Tuple[xr.Dataset, float, float]: ... diff --git a/tests/test_targets/test_classes.py b/tests/test_targets/test_classes.py index bda228b..4143c3c 100644 --- a/tests/test_targets/test_classes.py +++ b/tests/test_targets/test_classes.py @@ -24,17 +24,6 @@ from batdetect2.targets.classes import ( from batdetect2.targets.terms import TagInfo, TermRegistry -@pytest.fixture -def sample_term_registry() -> TermRegistry: - """Fixture for a sample TermRegistry.""" - registry = TermRegistry() - registry.add_custom_term("order") - registry.add_custom_term("species") - registry.add_custom_term("call_type") - registry.add_custom_term("quality") - return registry - - @pytest.fixture def sample_annotation( sound_event: data.SoundEvent, diff --git a/tests/test_train/test_clips.py b/tests/test_train/test_clips.py new file mode 100644 index 0000000..95ac37e --- /dev/null +++ b/tests/test_train/test_clips.py @@ -0,0 +1,466 @@ +import numpy as np +import pytest +import xarray as xr + +from batdetect2.train.clips import ( + Clipper, + _compute_expected_width, + select_subclip, +) + +AUDIO_SAMPLERATE = 48000 + +SPEC_SAMPLERATE = 100 +SPEC_FREQS = 64 +CLIP_DURATION = 0.5 + + +CLIP_WIDTH_SPEC = int(np.floor(CLIP_DURATION * SPEC_SAMPLERATE)) +CLIP_WIDTH_AUDIO = int(np.floor(CLIP_DURATION * AUDIO_SAMPLERATE)) +MAX_EMPTY = 0.2 + + +def create_test_dataset( + duration_sec: float, + spec_samplerate: int = SPEC_SAMPLERATE, + audio_samplerate: int = AUDIO_SAMPLERATE, + num_freqs: int = SPEC_FREQS, + start_time: float = 0.0, +) -> xr.Dataset: + """Creates a sample xr.Dataset for testing.""" + time_step = 1 / spec_samplerate + audio_time_step = 1 / audio_samplerate + + times = np.arange(start_time, start_time + duration_sec, step=time_step) + freqs = np.linspace(0, audio_samplerate / 2, num_freqs) + audio_times = np.arange( + start_time, + start_time + duration_sec, + step=audio_time_step, + ) + + num_time_steps = len(times) + num_audio_samples = len(audio_times) + spec_shape = (num_freqs, num_time_steps) + + spectrogram_data = np.arange(num_time_steps).reshape(1, -1) * np.ones( + (num_freqs, 1) + ) + + spectrogram = xr.DataArray( + spectrogram_data.astype(np.float32), + coords=[("frequency", freqs), ("time", times)], + name="spectrogram", + ) + + detection = xr.DataArray( + np.ones(spec_shape, dtype=np.float32) * 0.5, + coords=spectrogram.coords, + name="detection", + ) + + classes = xr.DataArray( + np.ones((3, *spec_shape), dtype=np.float32), + coords=[ + ("category", ["A", "B", "C"]), + ("frequency", freqs), + ("time", times), + ], + name="class", + ) + + size = xr.DataArray( + np.ones((2, *spec_shape), dtype=np.float32), + coords=[ + ("dimension", ["height", "width"]), + ("frequency", freqs), + ("time", times), + ], + name="size", + ) + + audio_data = np.arange(num_audio_samples) + audio = xr.DataArray( + audio_data.astype(np.float32), + coords=[("audio_time", audio_times)], + name="audio", + ) + + metadata = xr.DataArray([1, 2, 3], dims=["other_dim"], name="metadata") + + return xr.Dataset( + { + "audio": audio, + "spectrogram": spectrogram, + "detection": detection, + "class": classes, + "size": size, + "metadata": metadata, + } + ).assign_attrs( + samplerate=audio_samplerate, + spec_samplerate=spec_samplerate, + ) 
+ + +@pytest.fixture +def long_dataset() -> xr.Dataset: + """Dataset longer than the clip duration.""" + return create_test_dataset(duration_sec=2.0) + + +@pytest.fixture +def short_dataset() -> xr.Dataset: + """Dataset shorter than the clip duration.""" + return create_test_dataset(duration_sec=0.3) + + +@pytest.fixture +def exact_dataset() -> xr.Dataset: + """Dataset exactly the clip duration.""" + return create_test_dataset(duration_sec=CLIP_DURATION - 1e-9) + + +@pytest.fixture +def offset_dataset() -> xr.Dataset: + """Dataset starting at a non-zero time.""" + return create_test_dataset(duration_sec=1.0, start_time=0.5) + + +def test_select_subclip_within_bounds(long_dataset): + start_time = 0.5 + subclip = select_subclip( + long_dataset, span=CLIP_DURATION, start=start_time, dim="time" + ) + expected_width = _compute_expected_width( + long_dataset, CLIP_DURATION, "time" + ) + + assert "time" in subclip.dims + assert subclip.dims["time"] == expected_width + assert subclip.spectrogram.dims == ("frequency", "time") + assert subclip.spectrogram.shape == (SPEC_FREQS, expected_width) + assert subclip.detection.shape == (SPEC_FREQS, expected_width) + assert subclip["class"].shape == (3, SPEC_FREQS, expected_width) + assert subclip.size.shape == (2, SPEC_FREQS, expected_width) + assert subclip.time.min() >= start_time + assert ( + subclip.time.max() <= start_time + CLIP_DURATION + 1 / SPEC_SAMPLERATE + ) + + assert "metadata" in subclip + xr.testing.assert_equal(subclip.metadata, long_dataset.metadata) + + +def test_select_subclip_pad_start(long_dataset): + start_time = -0.1 + subclip = select_subclip( + long_dataset, span=CLIP_DURATION, start=start_time, dim="time" + ) + expected_width = _compute_expected_width( + long_dataset, CLIP_DURATION, "time" + ) + step = 1 / SPEC_SAMPLERATE + expected_pad_samples = int(np.floor(abs(start_time) / step)) + + assert subclip.dims["time"] == expected_width + assert subclip.spectrogram.shape[1] == expected_width + + assert np.all( + subclip.spectrogram.isel(time=slice(0, expected_pad_samples)) == 0 + ) + + assert np.any( + subclip.spectrogram.isel(time=slice(expected_pad_samples, None)) != 0 + ) + assert subclip.time.min() >= start_time + assert subclip.time.max() < start_time + CLIP_DURATION + step + + +def test_select_subclip_pad_end(long_dataset): + original_duration = long_dataset.time.max() - long_dataset.time.min() + start_time = original_duration - 0.1 + subclip = select_subclip( + long_dataset, span=CLIP_DURATION, start=start_time, dim="time" + ) + expected_width = _compute_expected_width( + long_dataset, CLIP_DURATION, "time" + ) + step = 1 / SPEC_SAMPLERATE + original_width = long_dataset.dims["time"] + expected_pad_samples = expected_width - ( + original_width - int(np.floor(start_time / step)) + ) + + assert subclip.sizes["time"] == expected_width + assert subclip.spectrogram.shape[1] == expected_width + + assert np.all( + subclip.spectrogram.isel( + time=slice(expected_width - expected_pad_samples, None) + ) + == 0 + ) + + assert np.any( + subclip.spectrogram.isel( + time=slice(0, expected_width - expected_pad_samples) + ) + != 0 + ) + assert subclip.time.min() >= start_time + assert subclip.time.max() < start_time + CLIP_DURATION + step + + +def test_select_subclip_pad_both_short_dataset(short_dataset): + start_time = -0.1 + subclip = select_subclip( + short_dataset, span=CLIP_DURATION, start=start_time, dim="time" + ) + expected_width = _compute_expected_width( + short_dataset, CLIP_DURATION, "time" + ) + step = 1 / SPEC_SAMPLERATE + + 
assert subclip.dims["time"] == expected_width + assert subclip.spectrogram.shape[1] == expected_width + + assert subclip.spectrogram.coords["time"][0] == pytest.approx( + start_time, + abs=step, + ) + assert subclip.spectrogram.coords["time"][-1] == pytest.approx( + start_time + CLIP_DURATION - step, + abs=2 * step, + ) + + +def test_select_subclip_width_consistency(long_dataset): + expected_width = _compute_expected_width( + long_dataset, CLIP_DURATION, "time" + ) + step = 1 / SPEC_SAMPLERATE + + subclip_aligned = select_subclip( + long_dataset.copy(deep=True), + span=CLIP_DURATION, + start=5 * step, + dim="time", + ) + + subclip_offset = select_subclip( + long_dataset.copy(deep=True), + span=CLIP_DURATION, + start=5.3 * step, + dim="time", + ) + + assert subclip_aligned.sizes["time"] == expected_width + assert subclip_offset.sizes["time"] == expected_width + assert subclip_aligned.spectrogram.shape[1] == expected_width + assert subclip_offset.spectrogram.shape[1] == expected_width + + +def test_select_subclip_different_dimension(long_dataset): + freq_coords = long_dataset.frequency.values + freq_min, freq_max = freq_coords.min(), freq_coords.max() + freq_span = (freq_max - freq_min) / 2 + start_freq = freq_min + freq_span / 2 + + subclip = select_subclip( + long_dataset, span=freq_span, start=start_freq, dim="frequency" + ) + + assert "frequency" in subclip.dims + assert subclip.spectrogram.shape[0] < long_dataset.spectrogram.shape[0] + assert subclip.detection.shape[0] < long_dataset.detection.shape[0] + assert subclip["class"].shape[1] < long_dataset["class"].shape[1] + assert subclip.size.shape[1] < long_dataset.size.shape[1] + + assert subclip.dims["time"] == long_dataset.dims["time"] + assert subclip.spectrogram.shape[1] == long_dataset.spectrogram.shape[1] + + xr.testing.assert_equal(subclip.audio, long_dataset.audio) + assert subclip.dims["audio_time"] == long_dataset.dims["audio_time"] + + +def test_select_subclip_fill_value(short_dataset): + fill_value = -999.0 + subclip = select_subclip( + short_dataset, + span=CLIP_DURATION, + start=0, + dim="time", + fill_value=fill_value, + ) + + expected_width = _compute_expected_width( + short_dataset, + CLIP_DURATION, + "time", + ) + + assert subclip.dims["time"] == expected_width + assert np.all(subclip.spectrogram.sel(time=slice(0.3, None)) == fill_value) + + +def test_select_subclip_no_overlap_raises_error(long_dataset): + original_duration = long_dataset.time.max() - long_dataset.time.min() + + with pytest.raises(ValueError, match="does not overlap"): + select_subclip( + long_dataset, + span=CLIP_DURATION, + start=original_duration + 1.0, + dim="time", + ) + + with pytest.raises(ValueError, match="does not overlap"): + select_subclip( + long_dataset, + span=CLIP_DURATION, + start=-1.0 * CLIP_DURATION - 1.0, + dim="time", + ) + + +def test_clipper_non_random(long_dataset, exact_dataset, short_dataset): + clipper = Clipper(duration=CLIP_DURATION, random=False) + + for ds in [long_dataset, exact_dataset, short_dataset]: + clip, _, _ = clipper.extract_clip(ds) + expected_spec_width = _compute_expected_width( + ds, CLIP_DURATION, "time" + ) + expected_audio_width = _compute_expected_width( + ds, CLIP_DURATION, "audio_time" + ) + + assert clip.dims["time"] == expected_spec_width + assert clip.dims["audio_time"] == expected_audio_width + assert clip.spectrogram.shape[1] == expected_spec_width + assert clip.audio.shape[0] == expected_audio_width + + assert clip.time.min() >= -1 / SPEC_SAMPLERATE + assert clip.audio_time.min() >= -1 / 
AUDIO_SAMPLERATE + + time_span = clip.time.max() - clip.time.min() + audio_span = clip.audio_time.max() - clip.audio_time.min() + assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE) + assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE) + + +def test_clipper_random(long_dataset): + seed = 42 + np.random.seed(seed) + clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY) + clip1, _, _ = clipper.extract_clip(long_dataset) + + np.random.seed(seed + 1) + clip2, _, _ = clipper.extract_clip(long_dataset) + + expected_spec_width = _compute_expected_width( + long_dataset, CLIP_DURATION, "time" + ) + expected_audio_width = _compute_expected_width( + long_dataset, CLIP_DURATION, "audio_time" + ) + + for clip in [clip1, clip2]: + assert clip.dims["time"] == expected_spec_width + assert clip.dims["audio_time"] == expected_audio_width + assert clip.spectrogram.shape[1] == expected_spec_width + assert clip.audio.shape[0] == expected_audio_width + + assert not np.isclose(clip1.time.min(), clip2.time.min()) + assert not np.isclose(clip1.audio_time.min(), clip2.audio_time.min()) + + for clip in [clip1, clip2]: + time_span = clip.time.max() - clip.time.min() + audio_span = clip.audio_time.max() - clip.audio_time.min() + assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE) + assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE) + + max_start_time = ( + (long_dataset.time.max() - long_dataset.time.min()) + - CLIP_DURATION + + MAX_EMPTY + ) + assert clip1.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE + assert clip2.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE + + +def test_clipper_random_max_empty_effect(long_dataset): + """Check that max_empty influences the possible start times.""" + seed = 123 + data_duration = long_dataset.time.max() - long_dataset.time.min() + + np.random.seed(seed) + clipper0 = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.0) + max_start_time0 = data_duration - CLIP_DURATION + start_times0 = [] + + for _ in range(20): + clip, _, _ = clipper0.extract_clip(long_dataset) + start_times0.append(clip.time.min().item()) + + assert all( + st <= max_start_time0 + 1 / SPEC_SAMPLERATE for st in start_times0 + ) + assert any(st > 0.1 for st in start_times0) + + np.random.seed(seed) + clipper_pos = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.2) + max_start_time_pos = data_duration - CLIP_DURATION + 0.2 + start_times_pos = [] + for _ in range(20): + clip, _, _ = clipper_pos.extract_clip(long_dataset) + start_times_pos.append(clip.time.min().item()) + assert all( + st <= max_start_time_pos + 1 / SPEC_SAMPLERATE + for st in start_times_pos + ) + + assert any(st > max_start_time0 + 1e-6 for st in start_times_pos) + + +def test_clipper_short_dataset_random(short_dataset): + clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY) + clip, _, _ = clipper.extract_clip(short_dataset) + + expected_spec_width = _compute_expected_width( + short_dataset, CLIP_DURATION, "time" + ) + expected_audio_width = _compute_expected_width( + short_dataset, CLIP_DURATION, "audio_time" + ) + + assert clip.sizes["time"] == expected_spec_width + assert clip.sizes["audio_time"] == expected_audio_width + assert clip["spectrogram"].shape[1] == expected_spec_width + assert clip["audio"].shape[0] == expected_audio_width + + assert np.any(clip.spectrogram == 0) + assert np.any(clip.audio == 0) + + +def test_clipper_exact_dataset_random(exact_dataset): + clipper = 
Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
+    clip, _, _ = clipper.extract_clip(exact_dataset)
+
+    expected_spec_width = _compute_expected_width(
+        exact_dataset, CLIP_DURATION, "time"
+    )
+    expected_audio_width = _compute_expected_width(
+        exact_dataset, CLIP_DURATION, "audio_time"
+    )
+
+    assert clip.dims["time"] == expected_spec_width
+    assert clip.dims["audio_time"] == expected_audio_width
+    assert clip.spectrogram.shape[1] == expected_spec_width
+    assert clip.audio.shape[0] == expected_audio_width
+
+    time_span = clip.time.max() - clip.time.min()
+    audio_span = clip.audio_time.max() - clip.audio_time.min()
+    assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
+    assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py
new file mode 100644
index 0000000..6ce7b71
--- /dev/null
+++ b/tests/test_train/test_lightning.py
@@ -0,0 +1,56 @@
+from pathlib import Path
+
+import lightning as L
+import torch
+import xarray as xr
+from soundevent import data
+
+from batdetect2.models import build_model
+from batdetect2.postprocess import build_postprocessor
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.targets import build_targets
+from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.losses import build_loss
+
+
+def build_default_module():
+    loss = build_loss()
+    targets = build_targets()
+    detector = build_model(num_classes=len(targets.class_names))
+    preprocessor = build_preprocessor()
+    postprocessor = build_postprocessor(targets)
+    return TrainingModule(
+        detector=detector,
+        loss=loss,
+        targets=targets,
+        preprocessor=preprocessor,
+        postprocessor=postprocessor,
+    )
+
+
+def test_can_initialize_default_module():
+    module = build_default_module()
+    assert isinstance(module, L.LightningModule)
+
+
+def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip):
+    module = build_default_module()
+    trainer = L.Trainer()
+    path = tmp_path / "example.ckpt"
+    trainer.strategy.connect(module)
+    trainer.save_checkpoint(path)
+
+    recovered = TrainingModule.load_from_checkpoint(path)
+
+    spec1 = module.preprocessor.preprocess_clip(clip)
+    spec2 = recovered.preprocessor.preprocess_clip(clip)
+
+    xr.testing.assert_equal(spec1, spec2)
+
+    input1 = torch.tensor([spec1.values]).unsqueeze(0)
+    input2 = torch.tensor([spec2.values]).unsqueeze(0)
+
+    output1 = module(input1)
+    output2 = recovered(input2)
+
+    assert all(torch.equal(o1, o2) for o1, o2 in zip(output1, output2))
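
Usage sketch (illustrative only, not part of the patch): one way the pieces introduced here — build_clipper, RandomExampleSource, build_augmentations, LabeledDataset, build_loss, and TrainingModule — are expected to fit together for a training run. The data directory, batch size, and trainer settings below are placeholder assumptions, not values taken from this commit.

# Illustrative wiring of the new training components. Placeholder values are
# marked in comments; everything else uses entry points added by this patch.
import lightning as L
from torch.utils.data import DataLoader

from batdetect2.models import build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train import (
    LabeledDataset,
    RandomExampleSource,
    build_augmentations,
    build_clipper,
    build_loss,
    list_preprocessed_files,
)
from batdetect2.train.lightning import TrainingModule

# Clipper crops each preprocessed example to a fixed-width training window.
clipper = build_clipper()  # defaults: 0.513 s clips, random start

# The mix augmentation needs a second example; RandomExampleSource draws
# random clips from the same set of preprocessed files.
files = list_preprocessed_files("path/to/preprocessed")  # placeholder path
example_source = RandomExampleSource([str(f) for f in files], clipper)
preprocessor = build_preprocessor()
augmentation = build_augmentations(preprocessor, example_source=example_source)

dataset = LabeledDataset(files, clipper=clipper, augmentation=augmentation)
loader = DataLoader(dataset, batch_size=8, shuffle=True)  # placeholder batch size

targets = build_targets()
module = TrainingModule(
    detector=build_model(num_classes=len(targets.class_names)),
    loss=build_loss(),
    targets=targets,
    preprocessor=preprocessor,
    postprocessor=build_postprocessor(targets),
)

L.Trainer(max_epochs=1).fit(module, train_dataloaders=loader)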