Added clips for random cliping and augmentations

This commit is contained in:
mbsantiago 2025-04-23 23:15:08 +01:00
parent 2396815c13
commit 59bd14bc79
12 changed files with 1317 additions and 427 deletions

View File

@ -1,153 +0,0 @@
from pathlib import Path
from typing import Optional
import lightning as L
import torch
from pydantic import Field
from soundevent import data
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.configs import BaseConfig
from batdetect2.models import (
BackboneConfig,
BBoxHead,
ClassifierHead,
ModelOutput,
build_backbone,
)
from batdetect2.postprocess import (
PostprocessConfig,
postprocess_model_outputs,
)
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
from batdetect2.targets import (
TargetConfig,
build_decoder,
build_target_encoder,
get_class_names,
)
from batdetect2.train import TrainExample, TrainingConfig, compute_loss
__all__ = [
"DetectorModel",
]
class ModuleConfig(BaseConfig):
train: TrainingConfig = Field(default_factory=TrainingConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
architecture: BackboneConfig = Field(default_factory=BackboneConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocessing: PostprocessConfig = Field(
default_factory=PostprocessConfig
)
class DetectorModel(L.LightningModule):
config: ModuleConfig
def __init__(
self,
config: Optional[ModuleConfig] = None,
):
super().__init__()
self.config = config or ModuleConfig()
self.save_hyperparameters()
self.backbone = build_model_backbone(self.config.architecture)
self.classifier = ClassifierHead(
num_classes=len(self.config.targets.classes),
in_channels=self.backbone.out_channels,
)
self.bbox = BBoxHead(in_channels=self.backbone.out_channels)
conf = self.config.train.loss.classification
self.class_weights = (
torch.tensor(conf.class_weights) if conf.class_weights else None
)
# Training targets
self.class_names = get_class_names(self.config.targets.classes)
self.encoder = build_target_encoder(
self.config.targets.classes,
replacement_rules=self.config.targets.replace,
)
self.decoder = build_decoder(self.config.targets.classes)
self.example_input_array = torch.randn([1, 1, 128, 512])
def forward(self, spec: torch.Tensor) -> ModelOutput:
features = self.backbone(spec)
detection_probs, classification_probs = self.classifier(features)
size_preds = self.bbox(features)
return ModelOutput(
detection_probs=detection_probs,
size_preds=size_preds,
class_probs=classification_probs,
features=features,
)
def training_step(self, batch: TrainExample):
outputs = self.forward(batch.spec)
losses = compute_loss(
batch,
outputs,
conf=self.config.train.loss,
class_weights=self.class_weights,
)
self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
self.log("train/loss/detection", losses.total, logger=True)
self.log("train/loss/size", losses.total, logger=True)
self.log("train/loss/classification", losses.total, logger=True)
return losses.total
def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
outputs = self.forward(batch.spec)
losses = compute_loss(
batch,
outputs,
conf=self.config.train.loss,
class_weights=self.class_weights,
)
self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
self.log("val/loss/detection", losses.total, logger=True)
self.log("val/loss/size", losses.total, logger=True)
self.log("val/loss/classification", losses.total, logger=True)
def on_validation_epoch_end(self) -> None:
self.validation_predictions.clear()
def configure_optimizers(self):
conf = self.config.train.optimizer
optimizer = Adam(self.parameters(), lr=conf.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
return [optimizer], [scheduler]
def process_clip(
self,
clip: data.Clip,
audio_dir: Optional[Path] = None,
) -> data.ClipPrediction:
spec = preprocess_audio_clip(
clip,
config=self.config.preprocessing,
audio_dir=audio_dir,
)
tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
outputs = self.forward(tensor)
return postprocess_model_outputs(
outputs,
clips=[clip],
classes=self.class_names,
decoder=self.decoder,
config=self.config.postprocessing,
)[0]

View File

@ -11,18 +11,19 @@ from batdetect2.train.augmentations import (
mask_time,
mix_examples,
scale_volume,
select_subclip,
warp_spectrogram,
)
from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
SubclipConfig,
TrainExample,
get_preprocessed_files,
list_preprocessed_files,
)
from batdetect2.train.labels import load_label_config
from batdetect2.train.losses import compute_loss
from batdetect2.train.losses import LossFunction, build_loss
from batdetect2.train.preprocess import (
generate_train_example,
preprocess_annotations,
@ -34,6 +35,8 @@ __all__ = [
"EchoAugmentationConfig",
"FrequencyMaskAugmentationConfig",
"LabeledDataset",
"LossFunction",
"RandomExampleSource",
"SubclipConfig",
"TimeMaskAugmentationConfig",
"TrainExample",
@ -43,9 +46,11 @@ __all__ = [
"WarpAugmentationConfig",
"add_echo",
"build_augmentations",
"compute_loss",
"build_clipper",
"build_loss",
"generate_train_example",
"get_preprocessed_files",
"list_preprocessed_files",
"load_label_config",
"load_train_config",
"load_trainer_config",
"mask_frequency",
@ -56,5 +61,4 @@ __all__ = [
"select_subclip",
"train",
"warp_spectrogram",
"load_label_config",
]

View File

@ -37,24 +37,24 @@ from batdetect2.train.types import Augmentation
from batdetect2.utils.arrays import adjust_width
__all__ = [
"AugmentationConfig",
"AugmentationsConfig",
"load_augmentation_config",
"build_augmentations",
"select_subclip",
"mix_examples",
"add_echo",
"scale_volume",
"warp_spectrogram",
"mask_time",
"mask_frequency",
"MixAugmentationConfig",
"DEFAULT_AUGMENTATION_CONFIG",
"EchoAugmentationConfig",
"ExampleSource",
"FrequencyMaskAugmentationConfig",
"MixAugmentationConfig",
"TimeMaskAugmentationConfig",
"VolumeAugmentationConfig",
"WarpAugmentationConfig",
"TimeMaskAugmentationConfig",
"FrequencyMaskAugmentationConfig",
"AugmentationConfig",
"ExampleSource",
"add_echo",
"build_augmentations",
"load_augmentation_config",
"mask_frequency",
"mask_time",
"mix_examples",
"scale_volume",
"warp_spectrogram",
]
ExampleSource = Callable[[], xr.Dataset]
@ -64,92 +64,6 @@ Used by the `mix_examples` augmentation to fetch another example to mix with.
"""
def select_subclip(
example: xr.Dataset,
start_time: Optional[float] = None,
duration: Optional[float] = None,
width: Optional[int] = None,
random: bool = False,
) -> xr.Dataset:
"""Extract a sub-clip (time segment) from a training example dataset.
Selects a portion of the 'time' dimension from all relevant DataArrays
(`audio`, `spectrogram`, `detection`, `class`, `size`) within the example
Dataset. The segment can be defined by a fixed start time and
duration/width, or a random start time can be chosen.
Parameters
----------
example : xr.Dataset
The input training example containing 'audio', 'spectrogram', and
target heatmaps, all with compatible 'time' (or 'audio_time')
coordinates.
start_time : float, optional
Desired start time (seconds) of the subclip. If None and `random` is
False, starts from the beginning of the example. If None and `random`
is True, a random start time is chosen.
duration : float, optional
Desired duration (seconds) of the subclip. Either `duration` or `width`
must be provided.
width : int, optional
Desired width (number of time bins) of the subclip's
spectrogram/heatmaps. Either `duration` or `width` must be provided. If
both are given, `duration` takes precedence.
random : bool, default=False
If True and `start_time` is None, selects a random start time ensuring
the subclip fits within the original example's duration.
Returns
-------
xr.Dataset
A new dataset containing only the selected time segment. Coordinates
are adjusted accordingly. Returns the original example if the requested
subclip cannot be extracted (e.g., duration too long).
Raises
------
ValueError
If neither `duration` nor `width` is provided, or if time coordinates
are missing or invalid.
"""
step = arrays.get_dim_step(example, "time") # type: ignore
start, stop = arrays.get_dim_range(example, "time") # type: ignore
if width is None:
if duration is None:
raise ValueError("Either duration or width must be provided")
width = int(np.floor(duration / step))
if duration is None:
duration = width * step
if start_time is None:
if random:
start_time = np.random.uniform(start, max(stop - duration, start))
else:
start_time = start
if start_time + duration > stop:
return example
start_index = arrays.get_coord_index(
example, # type: ignore
"time",
start_time,
)
end_index = start_index + width - 1
start_time = example.time.values[start_index]
end_time = example.time.values[end_index]
return example.sel(
time=slice(start_time, end_time),
audio_time=slice(start_time, end_time + step),
)
class MixAugmentationConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples)."""
@ -878,9 +792,21 @@ def build_augmentation_from_config(
)
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
steps=[
MixAugmentationConfig(),
EchoAugmentationConfig(),
VolumeAugmentationConfig(),
WarpAugmentationConfig(),
TimeMaskAugmentationConfig(),
FrequencyMaskAugmentationConfig(),
]
)
def build_augmentations(
config: AugmentationsConfig,
preprocessor: PreprocessorProtocol,
config: Optional[AugmentationsConfig] = None,
example_source: Optional[ExampleSource] = None,
) -> Augmentation:
"""Build a composite augmentation pipeline function from configuration.
@ -915,6 +841,8 @@ def build_augmentations(
NotImplementedError
If an unknown `augmentation_type` is encountered in `config.steps`.
"""
config = config or DEFAULT_AUGMENTATION_CONFIG
augmentations = []
for step_config in config.steps:

184
batdetect2/train/clips.py Normal file
View File

@ -0,0 +1,184 @@
from typing import Optional, Tuple, Union
import numpy as np
import xarray as xr
from soundevent import arrays
from batdetect2.configs import BaseConfig
from batdetect2.train.types import ClipperProtocol
DEFAULT_TRAIN_CLIP_DURATION = 0.513
DEFAULT_MAX_EMPTY_CLIP = 0.1
class ClipperConfig(BaseConfig):
duration: float = DEFAULT_TRAIN_CLIP_DURATION
random: bool = True
max_empty: float = DEFAULT_MAX_EMPTY_CLIP
class Clipper(ClipperProtocol):
def __init__(
self,
duration: float = 0.5,
max_empty: float = 0.2,
random: bool = True,
):
self.duration = duration
self.random = random
self.max_empty = max_empty
def extract_clip(
self, example: xr.Dataset
) -> Tuple[xr.Dataset, float, float]:
step = arrays.get_dim_step(
example.spectrogram,
dim=arrays.Dimensions.time.value,
)
duration = (
arrays.get_dim_width(
example.spectrogram,
dim=arrays.Dimensions.time.value,
)
+ step
)
start_time = 0
if self.random:
start_time = np.random.uniform(
-self.max_empty,
duration - self.duration + self.max_empty,
)
subclip = select_subclip(
example,
start=start_time,
span=self.duration,
dim="time",
)
return (
select_subclip(
subclip,
start=start_time,
span=self.duration,
dim="audio_time",
),
start_time,
start_time + self.duration,
)
def build_clipper(config: Optional[ClipperConfig] = None) -> ClipperProtocol:
config = config or ClipperConfig()
return Clipper(
duration=config.duration,
max_empty=config.max_empty,
random=config.random,
)
def select_subclip(
dataset: xr.Dataset,
span: float,
start: float,
fill_value: float = 0,
dim: str = "time",
) -> xr.Dataset:
width = _compute_expected_width(
dataset, # type: ignore
span,
dim=dim,
)
coord = dataset.coords[dim]
if len(coord) == width:
return dataset
new_coords, start_pad, end_pad, dim_slice = _extract_coordinate(
coord, start, span
)
data_vars = {}
for name, data_array in dataset.data_vars.items():
if dim not in data_array.dims:
data_vars[name] = data_array
continue
if width == data_array.sizes[dim]:
data_vars[name] = data_array
continue
sliced = data_array.isel({dim: dim_slice}).data
if start_pad > 0 or end_pad > 0:
padding = [
[0, 0] if other_dim != dim else [start_pad, end_pad]
for other_dim in data_array.dims
]
sliced = np.pad(sliced, padding, constant_values=fill_value)
data_vars[name] = xr.DataArray(
data=sliced,
dims=data_array.dims,
coords={**data_array.coords, dim: new_coords},
attrs=data_array.attrs,
)
return xr.Dataset(data_vars=data_vars, attrs=dataset.attrs)
def _extract_coordinate(
coord: xr.DataArray,
start: float,
span: float,
) -> Tuple[xr.Variable, int, int, slice]:
step = arrays.get_dim_step(coord, str(coord.name))
current_width = len(coord)
expected_width = int(np.floor(span / step))
coord_start = float(coord[0])
offset = start - coord_start
start_index = int(np.floor(offset / step))
end_index = start_index + expected_width
if start_index > current_width:
raise ValueError("Requested span does not overlap with current range")
if end_index < 0:
raise ValueError("Requested span does not overlap with current range")
corrected_start = float(start_index * step)
corrected_end = float(end_index * step)
start_index_offset = max(0, -start_index)
end_index_offset = max(0, end_index - current_width)
sl = slice(
start_index if start_index >= 0 else None,
end_index if end_index < current_width else None,
)
return (
arrays.create_range_dim(
str(coord.name),
start=corrected_start,
stop=corrected_end,
step=step,
),
start_index_offset,
end_index_offset,
sl,
)
def _compute_expected_width(
array: Union[xr.DataArray, xr.Dataset],
duration: float,
dim: str,
) -> int:
step = arrays.get_dim_step(array, dim) # type: ignore
return int(np.floor(duration / step))

View File

@ -1,6 +1,5 @@
import os
from pathlib import Path
from typing import NamedTuple, Optional, Sequence, Union
from typing import List, Optional, Sequence, Tuple
import numpy as np
import torch
@ -13,28 +12,14 @@ from batdetect2.configs import BaseConfig
from batdetect2.train.augmentations import (
Augmentation,
AugmentationsConfig,
select_subclip,
)
from batdetect2.train.preprocess import PreprocessorProtocol
from batdetect2.utils.tensors import adjust_width
from batdetect2.train.types import ClipperProtocol, TrainExample
__all__ = [
"TrainExample",
"LabeledDataset",
]
PathLike = Union[Path, str, os.PathLike]
class TrainExample(NamedTuple):
spec: torch.Tensor
detection_heatmap: torch.Tensor
class_heatmap: torch.Tensor
size_heatmap: torch.Tensor
idx: torch.Tensor
class SubclipConfig(BaseConfig):
duration: Optional[float] = None
width: int = 512
@ -51,14 +36,12 @@ class DatasetConfig(BaseConfig):
class LabeledDataset(Dataset):
def __init__(
self,
preprocessor: PreprocessorProtocol,
filenames: Sequence[PathLike],
subclip: Optional[SubclipConfig] = None,
filenames: Sequence[data.PathLike],
clipper: ClipperProtocol,
augmentation: Optional[Augmentation] = None,
):
self.preprocessor = preprocessor
self.filenames = filenames
self.subclip = subclip
self.clipper = clipper
self.augmentation = augmentation
def __len__(self):
@ -66,14 +49,7 @@ class LabeledDataset(Dataset):
def __getitem__(self, idx) -> TrainExample:
dataset = self.get_dataset(idx)
if self.subclip:
dataset = select_subclip(
dataset,
duration=self.subclip.duration,
width=self.subclip.width,
random=self.subclip.random,
)
dataset, start_time, end_time = self.clipper.extract_clip(dataset)
if self.augmentation:
dataset = self.augmentation(dataset)
@ -84,37 +60,31 @@ class LabeledDataset(Dataset):
class_heatmap=self.to_tensor(dataset["class"]),
size_heatmap=self.to_tensor(dataset["size"]),
idx=torch.tensor(idx),
start_time=start_time,
end_time=end_time,
)
@classmethod
def from_directory(
cls,
directory: PathLike,
preprocessor: PreprocessorProtocol,
directory: data.PathLike,
clipper: ClipperProtocol,
extension: str = ".nc",
subclip: Optional[SubclipConfig] = None,
augmentation: Optional[Augmentation] = None,
):
return cls(
preprocessor=preprocessor,
filenames=get_preprocessed_files(directory, extension),
subclip=subclip,
filenames=list_preprocessed_files(directory, extension),
clipper=clipper,
augmentation=augmentation,
)
def get_random_example(self) -> xr.Dataset:
def get_random_example(self) -> Tuple[xr.Dataset, float, float]:
idx = np.random.randint(0, len(self))
dataset = self.get_dataset(idx)
if self.subclip:
dataset = select_subclip(
dataset,
duration=self.subclip.duration,
width=self.subclip.width,
random=self.subclip.random,
)
dataset, start_time, end_time = self.clipper.extract_clip(dataset)
return dataset
return dataset, start_time, end_time
def get_dataset(self, idx) -> xr.Dataset:
return xr.open_dataset(self.filenames[idx])
@ -129,16 +99,23 @@ class LabeledDataset(Dataset):
array: xr.DataArray,
dtype=np.float32,
) -> torch.Tensor:
tensor = torch.tensor(array.values.astype(dtype))
if not self.subclip:
return tensor
width = self.subclip.width
return adjust_width(tensor, width)
return torch.tensor(array.values.astype(dtype))
def get_preprocessed_files(
directory: PathLike, extension: str = ".nc"
def list_preprocessed_files(
directory: data.PathLike, extension: str = ".nc"
) -> Sequence[Path]:
return list(Path(directory).glob(f"*{extension}"))
class RandomExampleSource:
def __init__(self, filenames: List[str], clipper: ClipperProtocol):
self.filenames = filenames
self.clipper = clipper
def __call__(self):
index = int(np.random.randint(len(self.filenames)))
filename = self.filenames[index]
dataset = xr.open_dataset(filename)
example, _, _ = self.clipper.extract_clip(dataset)
return example

View File

@ -0,0 +1,71 @@
import lightning as L
import torch
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.models import (
DetectionModel,
ModelOutput,
)
from batdetect2.postprocess.types import PostprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
from batdetect2.train import TrainExample
from batdetect2.train.types import LossProtocol
__all__ = [
"TrainingModule",
]
class TrainingModule(L.LightningModule):
def __init__(
self,
detector: DetectionModel,
loss: LossProtocol,
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
learning_rate: float = 0.001,
t_max: int = 100,
):
super().__init__()
self.loss = loss
self.detector = detector
self.preprocessor = preprocessor
self.targets = targets
self.postprocessor = postprocessor
self.learning_rate = learning_rate
self.t_max = t_max
self.save_hyperparameters()
def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.detector(spec)
def training_step(self, batch: TrainExample):
outputs = self.forward(batch.spec)
losses = self.loss(outputs, batch)
self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
self.log("train/loss/detection", losses.total, logger=True)
self.log("train/loss/size", losses.total, logger=True)
self.log("train/loss/classification", losses.total, logger=True)
return losses.total
def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
outputs = self.forward(batch.spec)
losses = self.loss(outputs, batch)
self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
self.log("val/loss/detection", losses.total, logger=True)
self.log("val/loss/size", losses.total, logger=True)
self.log("val/loss/classification", losses.total, logger=True)
def configure_optimizers(self):
optimizer = Adam(self.parameters(), lr=self.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
return [optimizer], [scheduler]

View File

@ -1,74 +1,215 @@
from typing import NamedTuple, Optional
"""Loss functions and configurations for training BatDetect2 models.
This module defines the loss functions used to train BatDetect2 models,
including individual loss components for different prediction tasks (detection,
classification, size regression) and a main coordinating loss function that
combines them.
It utilizes common loss types like L1 loss (`BBoxLoss`) for regression and
Focal Loss (`FocalLoss`) for handling class imbalance in dense detection and
classification tasks. Configuration objects (`LossConfig`, etc.) allow for easy
customization of loss parameters and weights via configuration files.
The primary entry points are:
- `LossFunction`: An `nn.Module` that computes the weighted sum of individual
loss components given model outputs and ground truth targets.
- `build_loss`: A factory function that constructs the `LossFunction` based
on a `LossConfig` object.
- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
"""
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.types import ModelOutput
from batdetect2.train.dataset import TrainExample
from batdetect2.train.types import Losses, LossProtocol
__all__ = [
"bbox_size_loss",
"compute_loss",
"focal_loss",
"mse_loss",
"BBoxLoss",
"ClassificationLossConfig",
"DetectionLossConfig",
"FocalLoss",
"FocalLossConfig",
"LossConfig",
"LossFunction",
"MSELoss",
"SizeLossConfig",
"build_loss",
]
class SizeLossConfig(BaseConfig):
"""Configuration for the bounding box size loss component.
Attributes
----------
weight : float, default=0.1
The weighting factor applied to the size loss when combining it with
other losses (detection, classification) to form the total training
loss.
"""
weight: float = 0.1
def bbox_size_loss(
pred_size: torch.Tensor,
gt_size: torch.Tensor,
) -> torch.Tensor:
class BBoxLoss(nn.Module):
"""Computes L1 loss for bounding box size regression.
Calculates the Mean Absolute Error (MAE or L1 loss) between the predicted
size dimensions (`pred`) and the ground truth size dimensions (`gt`).
Crucially, the loss is only computed at locations where the ground truth
size heatmap (`gt`) contains non-zero values (i.e., at the reference points
of actual annotated sound events). This prevents the model from being
penalized for size predictions in background regions.
The loss is summed over all valid locations and normalized by the number
of valid locations.
"""
Bounding box size loss. Only compute loss where there is a bounding box.
def forward(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
"""Calculate masked L1 loss for size prediction.
Parameters
----------
pred : torch.Tensor
Predicted size tensor, typically shape `(B, 2, H, W)`, where
channels represent scaled width and height.
gt : torch.Tensor
Ground truth size tensor, same shape as `pred`. Non-zero values
indicate locations and target sizes of actual annotations.
Returns
-------
torch.Tensor
Scalar tensor representing the calculated masked L1 loss.
"""
gt_size_mask = (gt_size > 0).float()
return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / (
gt_size_mask.sum() + 1e-5
)
gt_size_mask = (gt > 0).float()
masked_pred = pred * gt_size_mask
loss = F.l1_loss(masked_pred, gt, reduction="sum")
num_pos = gt_size_mask.sum() + 1e-5
return loss / num_pos
class FocalLossConfig(BaseConfig):
"""Configuration parameters for the Focal Loss function.
Attributes
----------
beta : float, default=4
Exponent controlling the down-weighting of easy negative examples.
Higher values increase down-weighting (focus more on hard negatives).
alpha : float, default=2
Exponent controlling the down-weighting based on prediction confidence.
Higher values focus more on misclassified examples (both positive and
negative).
"""
beta: float = 4
alpha: float = 2
def focal_loss(
pred: torch.Tensor,
gt: torch.Tensor,
weights: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
class FocalLoss(nn.Module):
"""Focal Loss implementation, adapted from CornerNet.
Addresses class imbalance in dense object detection/classification tasks by
down-weighting the loss contribution from easy examples (both positive and
negative), allowing the model to focus more on hard-to-classify examples.
Parameters
----------
eps : float, default=1e-5
Small epsilon value added for numerical stability.
beta : float, default=4
Exponent focusing on hard negative examples (modulates `(1-gt)^beta`).
alpha : float, default=2
Exponent focusing on misclassified examples (modulates `(1-p)^alpha`
for positives and `p^alpha` for negatives).
class_weights : torch.Tensor, optional
Optional tensor containing weights for each class (applied to positive
loss). Shape should be broadcastable to the channel dimension of the
input tensors.
mask_zero : bool, default=False
If True, ignores loss contributions from spatial locations where the
ground truth `gt` tensor is zero across *all* channels. Useful for
classification heatmaps where some areas might have no assigned class.
References
----------
- Lin, T. Y., et al. "Focal loss for dense object detection." ICCV 2017.
- Law, H., & Deng, J. "CornerNet: Detecting Objects as Paired Keypoints."
ECCV 2018.
"""
def __init__(
self,
eps: float = 1e-5,
beta: float = 4,
alpha: float = 2,
class_weights: Optional[torch.Tensor] = None,
mask_zero: bool = False,
):
super().__init__()
self.class_weights = class_weights
self.eps = eps
self.beta = beta
self.alpha = alpha
self.mask_zero = mask_zero
def forward(
self,
pred: torch.Tensor,
gt: torch.Tensor,
) -> torch.Tensor:
"""
Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
pred (batch x c x h x w)
gt (batch x c x h x w)
"""Compute the Focal Loss.
Parameters
----------
pred : torch.Tensor
Predicted probabilities or logits (typically sigmoid output for
detection, or softmax/sigmoid for classification). Must be in the
range [0, 1] after potential activation. Shape `(B, C, H, W)`.
gt : torch.Tensor
Ground truth heatmap tensor. Shape `(B, C, H, W)`. Values typically
represent target probabilities (e.g., Gaussian peaks for detection,
one-hot encoding or smoothed labels for classification). For the
adapted CornerNet loss, `gt=1` indicates a positive location, and
values `< 1` indicate negative locations (with potential Gaussian
weighting `(1-gt)^beta` for negatives near positives).
Returns
-------
torch.Tensor
Scalar tensor representing the computed focal loss, normalized by
the number of positive locations.
"""
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
pos_loss = (
torch.log(pred + self.eps)
* torch.pow(1 - pred, self.alpha)
* pos_inds
)
neg_loss = (
torch.log(1 - pred + eps)
* torch.pow(pred, alpha)
* torch.pow(1 - gt, beta)
torch.log(1 - pred + self.eps)
* torch.pow(pred, self.alpha)
* torch.pow(1 - gt, self.beta)
* neg_inds
)
if weights is not None:
pos_loss = pos_loss * torch.tensor(weights)
# neg_loss = neg_loss*weights
if self.class_weights is not None:
pos_loss = pos_loss * torch.tensor(self.class_weights)
if valid_mask is not None:
if self.mask_zero:
valid_mask = gt.any(dim=1, keepdim=True).float()
pos_loss = pos_loss * valid_mask
neg_loss = neg_loss * valid_mask
@ -83,33 +224,100 @@ def focal_loss(
return loss
def mse_loss(
class MSELoss(nn.Module):
"""Mean Squared Error (MSE) Loss module.
Calculates the mean squared difference between predictions and ground
truth. Optionally masks contributions where the ground truth is zero across
channels.
Parameters
----------
mask_zero : bool, default=False
If True, calculates the loss only over spatial locations (H, W) where
at least one channel in the ground truth `gt` tensor is non-zero. The
loss is then averaged over these valid locations. If False (default),
the standard MSE over all elements is computed.
"""
def __init__(self, mask_zero: bool = False):
super().__init__()
self.mask_zero = mask_zero
def forward(
self,
pred: torch.Tensor,
gt: torch.Tensor,
valid_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute the Mean Squared Error loss.
Parameters
----------
pred : torch.Tensor
Predicted tensor, shape `(B, C, H, W)`.
gt : torch.Tensor
Ground truth tensor, shape `(B, C, H, W)`.
Returns
-------
torch.Tensor
Scalar tensor representing the calculated MSE loss.
"""
Mean squared error loss.
"""
if valid_mask is None:
op = ((gt - pred) ** 2).mean()
else:
op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
return op
if not self.mask_zero:
return ((gt - pred) ** 2).mean()
valid_mask = gt.any(dim=1, keepdim=True).float()
return (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
class DetectionLossConfig(BaseConfig):
"""Configuration for the detection loss component.
Attributes
----------
weight : float, default=1.0
Weighting factor for the detection loss in the combined total loss.
focal : FocalLossConfig
Configuration for the Focal Loss used for detection. Defaults to
standard Focal Loss parameters (`alpha=2`, `beta=4`).
"""
weight: float = 1.0
focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
class ClassificationLossConfig(BaseConfig):
"""Configuration for the classification loss component.
Attributes
----------
weight : float, default=2.0
Weighting factor for the classification loss in the combined total loss.
focal : FocalLossConfig
Configuration for the Focal Loss used for classification. Defaults to
standard Focal Loss parameters (`alpha=2`, `beta=4`).
"""
weight: float = 2.0
focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
class_weights: Optional[list[float]] = None
class LossConfig(BaseConfig):
"""Aggregated configuration for all loss components.
Defines the configuration and weighting for detection, size regression,
and classification losses used in the main `LossFunction`.
Attributes
----------
detection : DetectionLossConfig
Configuration for the detection loss (Focal Loss).
size : SizeLossConfig
Configuration for the size regression loss (L1 loss).
classification : ClassificationLossConfig
Configuration for the classification loss (Focal Loss).
"""
detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig)
size: SizeLossConfig = Field(default_factory=SizeLossConfig)
classification: ClassificationLossConfig = Field(
@ -117,50 +325,157 @@ class LossConfig(BaseConfig):
)
class Losses(NamedTuple):
detection: torch.Tensor
size: torch.Tensor
classification: torch.Tensor
total: torch.Tensor
class LossFunction(nn.Module, LossProtocol):
"""Computes the combined training loss for the BatDetect2 model.
Aggregates individual loss functions for detection, size regression, and
classification tasks. Calculates each component loss based on model outputs
and ground truth targets, applies configured weights, and sums them to get
the final total loss used for optimization. Also returns individual
components for monitoring.
def compute_loss(
batch: TrainExample,
outputs: ModelOutput,
conf: LossConfig,
class_weights: Optional[torch.Tensor] = None,
Parameters
----------
size_loss : nn.Module
Instantiated loss module for size regression (e.g., `BBoxLoss`).
detection_loss : nn.Module
Instantiated loss module for detection (e.g., `FocalLoss`).
classification_loss : nn.Module
Instantiated loss module for classification (e.g., `FocalLoss`).
size_weight : float, default=0.1
Weighting factor for the size loss component.
detection_weight : float, default=1.0
Weighting factor for the detection loss component.
classification_weight : float, default=2.0
Weighting factor for the classification loss component.
Attributes
----------
size_loss_fn : nn.Module
detection_loss_fn : nn.Module
classification_loss_fn : nn.Module
size_weight : float
detection_weight : float
classification_weight : float
"""
def __init__(
self,
size_loss: nn.Module,
detection_loss: nn.Module,
classification_loss: nn.Module,
size_weight: float = 0.1,
detection_weight: float = 1.0,
classification_weight: float = 2.0,
):
super().__init__()
self.size_loss_fn = size_loss
self.detection_loss_fn = detection_loss
self.classification_loss_fn = classification_loss
self.size_weight = size_weight
self.detection_weight = detection_weight
self.classification_weight = classification_weight
def forward(
self,
pred: ModelOutput,
gt: TrainExample,
) -> Losses:
detection_loss = focal_loss(
outputs.detection_probs,
batch.detection_heatmap,
beta=conf.detection.focal.beta,
alpha=conf.detection.focal.alpha,
)
"""Calculate the combined loss and individual components.
size_loss = bbox_size_loss(
outputs.size_preds,
batch.size_heatmap,
)
Parameters
----------
pred: ModelOutput
A NamedTuple containing the model's prediction tensors for the
batch: `detection_probs`, `size_preds`, `class_probs`.
gt: TrainExample
A structure containing the ground truth targets for the batch,
expected to have attributes like `detection_heatmap`,
`size_heatmap`, and `class_heatmap` (as `torch.Tensor`).
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
classification_loss = focal_loss(
outputs.class_probs,
batch.class_heatmap,
weights=class_weights,
valid_mask=valid_mask,
beta=conf.classification.focal.beta,
alpha=conf.classification.focal.alpha,
Returns
-------
Losses
A NamedTuple containing the scalar loss values for detection, size,
classification, and the total weighted loss.
"""
size_loss = self.size_loss_fn(pred.size_preds, gt.size_heatmap)
detection_loss = self.detection_loss_fn(
pred.detection_probs,
gt.detection_heatmap,
)
total = (
detection_loss * conf.detection.weight
+ size_loss * conf.size.weight
+ classification_loss * conf.classification.weight
classification_loss = self.classification_loss_fn(
pred.class_probs,
gt.class_heatmap,
)
total_loss = (
size_loss * self.size_weight
+ classification_loss * self.classification_weight
+ detection_loss * self.detection_weight
)
return Losses(
detection=detection_loss,
size=size_loss,
classification=classification_loss,
total=total,
total=total_loss,
)
def build_loss(
config: Optional[LossConfig] = None,
class_weights: Optional[np.ndarray] = None,
) -> nn.Module:
"""Factory function to build the main LossFunction from configuration.
Instantiates the necessary loss components (`BBoxLoss`, `FocalLoss`) based
on the provided `LossConfig` (or defaults) and optional `class_weights`,
then assembles them into the main `LossFunction` module with the specified
component weights.
Parameters
----------
config : LossConfig, optional
Configuration object defining weights and parameters (e.g., alpha, beta
for Focal Loss) for each loss component. If None, default settings
from `LossConfig` and its nested configs are used.
class_weights : np.ndarray, optional
An array of weights for each specific class, used to adjust the
classification loss (typically Focal Loss). If provided, this overrides
any `class_weights` specified within `config.classification`. If None,
weights from the config (or default of equal weights) are used.
Returns
-------
LossFunction
An initialized `LossFunction` module ready for training.
"""
config = config or LossConfig()
class_weights_tensor = (
torch.tensor(class_weights) if class_weights else None
)
detection_loss_fn = FocalLoss(
beta=config.detection.focal.beta,
alpha=config.detection.focal.alpha,
mask_zero=False,
)
classification_loss_fn = FocalLoss(
beta=config.classification.focal.beta,
alpha=config.classification.focal.alpha,
class_weights=class_weights_tensor,
mask_zero=True,
)
size_loss_fn = BBoxLoss()
return LossFunction(
size_loss=size_loss_fn,
classification_loss=classification_loss_fn,
detection_loss=detection_loss_fn,
size_weight=config.size.weight,
detection_weight=config.detection.weight,
classification_weight=config.classification.weight,
)

View File

@ -139,6 +139,7 @@ def _save_xr_dataset_to_file(
dataset.to_netcdf(
path,
encoding={
"audio": {"zlib": True},
"spectrogram": {"zlib": True},
"size": {"zlib": True},
"class": {"zlib": True},

View File

@ -1,12 +1,17 @@
from typing import Callable, NamedTuple
from typing import Callable, NamedTuple, Protocol, Tuple
import torch
import xarray as xr
from soundevent import data
from batdetect2.models import ModelOutput
__all__ = [
"Heatmaps",
"ClipLabeller",
"Augmentation",
"LossProtocol",
"TrainExample",
]
@ -46,3 +51,50 @@ steps, and returns the final `Heatmaps` used for model training.
Augmentation = Callable[[xr.Dataset], xr.Dataset]
class TrainExample(NamedTuple):
spec: torch.Tensor
detection_heatmap: torch.Tensor
class_heatmap: torch.Tensor
size_heatmap: torch.Tensor
idx: torch.Tensor
start_time: float
end_time: float
class Losses(NamedTuple):
"""Structure to hold the computed loss values.
Allows returning individual loss components along with the total weighted
loss for monitoring and analysis during training.
Attributes
----------
detection : torch.Tensor
Scalar tensor representing the calculated detection loss component
(before weighting).
size : torch.Tensor
Scalar tensor representing the calculated size regression loss component
(before weighting).
classification : torch.Tensor
Scalar tensor representing the calculated classification loss component
(before weighting).
total : torch.Tensor
Scalar tensor representing the final combined loss, computed as the
weighted sum of the detection, size, and classification components.
This is the value typically used for backpropagation.
"""
detection: torch.Tensor
size: torch.Tensor
classification: torch.Tensor
total: torch.Tensor
class LossProtocol(Protocol):
def __call__(self, pred: ModelOutput, gt: TrainExample) -> Losses: ...
class ClipperProtocol(Protocol):
def extract_clip(
self, example: xr.Dataset
) -> Tuple[xr.Dataset, float, float]: ...

View File

@ -24,17 +24,6 @@ from batdetect2.targets.classes import (
from batdetect2.targets.terms import TagInfo, TermRegistry
@pytest.fixture
def sample_term_registry() -> TermRegistry:
"""Fixture for a sample TermRegistry."""
registry = TermRegistry()
registry.add_custom_term("order")
registry.add_custom_term("species")
registry.add_custom_term("call_type")
registry.add_custom_term("quality")
return registry
@pytest.fixture
def sample_annotation(
sound_event: data.SoundEvent,

View File

@ -0,0 +1,466 @@
import numpy as np
import pytest
import xarray as xr
from batdetect2.train.clips import (
Clipper,
_compute_expected_width,
select_subclip,
)
AUDIO_SAMPLERATE = 48000
SPEC_SAMPLERATE = 100
SPEC_FREQS = 64
CLIP_DURATION = 0.5
CLIP_WIDTH_SPEC = int(np.floor(CLIP_DURATION * SPEC_SAMPLERATE))
CLIP_WIDTH_AUDIO = int(np.floor(CLIP_DURATION * AUDIO_SAMPLERATE))
MAX_EMPTY = 0.2
def create_test_dataset(
duration_sec: float,
spec_samplerate: int = SPEC_SAMPLERATE,
audio_samplerate: int = AUDIO_SAMPLERATE,
num_freqs: int = SPEC_FREQS,
start_time: float = 0.0,
) -> xr.Dataset:
"""Creates a sample xr.Dataset for testing."""
time_step = 1 / spec_samplerate
audio_time_step = 1 / audio_samplerate
times = np.arange(start_time, start_time + duration_sec, step=time_step)
freqs = np.linspace(0, audio_samplerate / 2, num_freqs)
audio_times = np.arange(
start_time,
start_time + duration_sec,
step=audio_time_step,
)
num_time_steps = len(times)
num_audio_samples = len(audio_times)
spec_shape = (num_freqs, num_time_steps)
spectrogram_data = np.arange(num_time_steps).reshape(1, -1) * np.ones(
(num_freqs, 1)
)
spectrogram = xr.DataArray(
spectrogram_data.astype(np.float32),
coords=[("frequency", freqs), ("time", times)],
name="spectrogram",
)
detection = xr.DataArray(
np.ones(spec_shape, dtype=np.float32) * 0.5,
coords=spectrogram.coords,
name="detection",
)
classes = xr.DataArray(
np.ones((3, *spec_shape), dtype=np.float32),
coords=[
("category", ["A", "B", "C"]),
("frequency", freqs),
("time", times),
],
name="class",
)
size = xr.DataArray(
np.ones((2, *spec_shape), dtype=np.float32),
coords=[
("dimension", ["height", "width"]),
("frequency", freqs),
("time", times),
],
name="size",
)
audio_data = np.arange(num_audio_samples)
audio = xr.DataArray(
audio_data.astype(np.float32),
coords=[("audio_time", audio_times)],
name="audio",
)
metadata = xr.DataArray([1, 2, 3], dims=["other_dim"], name="metadata")
return xr.Dataset(
{
"audio": audio,
"spectrogram": spectrogram,
"detection": detection,
"class": classes,
"size": size,
"metadata": metadata,
}
).assign_attrs(
samplerate=audio_samplerate,
spec_samplerate=spec_samplerate,
)
@pytest.fixture
def long_dataset() -> xr.Dataset:
"""Dataset longer than the clip duration."""
return create_test_dataset(duration_sec=2.0)
@pytest.fixture
def short_dataset() -> xr.Dataset:
"""Dataset shorter than the clip duration."""
return create_test_dataset(duration_sec=0.3)
@pytest.fixture
def exact_dataset() -> xr.Dataset:
"""Dataset exactly the clip duration."""
return create_test_dataset(duration_sec=CLIP_DURATION - 1e-9)
@pytest.fixture
def offset_dataset() -> xr.Dataset:
"""Dataset starting at a non-zero time."""
return create_test_dataset(duration_sec=1.0, start_time=0.5)
def test_select_subclip_within_bounds(long_dataset):
start_time = 0.5
subclip = select_subclip(
long_dataset, span=CLIP_DURATION, start=start_time, dim="time"
)
expected_width = _compute_expected_width(
long_dataset, CLIP_DURATION, "time"
)
assert "time" in subclip.dims
assert subclip.dims["time"] == expected_width
assert subclip.spectrogram.dims == ("frequency", "time")
assert subclip.spectrogram.shape == (SPEC_FREQS, expected_width)
assert subclip.detection.shape == (SPEC_FREQS, expected_width)
assert subclip["class"].shape == (3, SPEC_FREQS, expected_width)
assert subclip.size.shape == (2, SPEC_FREQS, expected_width)
assert subclip.time.min() >= start_time
assert (
subclip.time.max() <= start_time + CLIP_DURATION + 1 / SPEC_SAMPLERATE
)
assert "metadata" in subclip
xr.testing.assert_equal(subclip.metadata, long_dataset.metadata)
def test_select_subclip_pad_start(long_dataset):
start_time = -0.1
subclip = select_subclip(
long_dataset, span=CLIP_DURATION, start=start_time, dim="time"
)
expected_width = _compute_expected_width(
long_dataset, CLIP_DURATION, "time"
)
step = 1 / SPEC_SAMPLERATE
expected_pad_samples = int(np.floor(abs(start_time) / step))
assert subclip.dims["time"] == expected_width
assert subclip.spectrogram.shape[1] == expected_width
assert np.all(
subclip.spectrogram.isel(time=slice(0, expected_pad_samples)) == 0
)
assert np.any(
subclip.spectrogram.isel(time=slice(expected_pad_samples, None)) != 0
)
assert subclip.time.min() >= start_time
assert subclip.time.max() < start_time + CLIP_DURATION + step
def test_select_subclip_pad_end(long_dataset):
original_duration = long_dataset.time.max() - long_dataset.time.min()
start_time = original_duration - 0.1
subclip = select_subclip(
long_dataset, span=CLIP_DURATION, start=start_time, dim="time"
)
expected_width = _compute_expected_width(
long_dataset, CLIP_DURATION, "time"
)
step = 1 / SPEC_SAMPLERATE
original_width = long_dataset.dims["time"]
expected_pad_samples = expected_width - (
original_width - int(np.floor(start_time / step))
)
assert subclip.sizes["time"] == expected_width
assert subclip.spectrogram.shape[1] == expected_width
assert np.all(
subclip.spectrogram.isel(
time=slice(expected_width - expected_pad_samples, None)
)
== 0
)
assert np.any(
subclip.spectrogram.isel(
time=slice(0, expected_width - expected_pad_samples)
)
!= 0
)
assert subclip.time.min() >= start_time
assert subclip.time.max() < start_time + CLIP_DURATION + step
def test_select_subclip_pad_both_short_dataset(short_dataset):
start_time = -0.1
subclip = select_subclip(
short_dataset, span=CLIP_DURATION, start=start_time, dim="time"
)
expected_width = _compute_expected_width(
short_dataset, CLIP_DURATION, "time"
)
step = 1 / SPEC_SAMPLERATE
assert subclip.dims["time"] == expected_width
assert subclip.spectrogram.shape[1] == expected_width
assert subclip.spectrogram.coords["time"][0] == pytest.approx(
start_time,
abs=step,
)
assert subclip.spectrogram.coords["time"][-1] == pytest.approx(
start_time + CLIP_DURATION - step,
abs=2 * step,
)
def test_select_subclip_width_consistency(long_dataset):
expected_width = _compute_expected_width(
long_dataset, CLIP_DURATION, "time"
)
step = 1 / SPEC_SAMPLERATE
subclip_aligned = select_subclip(
long_dataset.copy(deep=True),
span=CLIP_DURATION,
start=5 * step,
dim="time",
)
subclip_offset = select_subclip(
long_dataset.copy(deep=True),
span=CLIP_DURATION,
start=5.3 * step,
dim="time",
)
assert subclip_aligned.sizes["time"] == expected_width
assert subclip_offset.sizes["time"] == expected_width
assert subclip_aligned.spectrogram.shape[1] == expected_width
assert subclip_offset.spectrogram.shape[1] == expected_width
def test_select_subclip_different_dimension(long_dataset):
freq_coords = long_dataset.frequency.values
freq_min, freq_max = freq_coords.min(), freq_coords.max()
freq_span = (freq_max - freq_min) / 2
start_freq = freq_min + freq_span / 2
subclip = select_subclip(
long_dataset, span=freq_span, start=start_freq, dim="frequency"
)
assert "frequency" in subclip.dims
assert subclip.spectrogram.shape[0] < long_dataset.spectrogram.shape[0]
assert subclip.detection.shape[0] < long_dataset.detection.shape[0]
assert subclip["class"].shape[1] < long_dataset["class"].shape[1]
assert subclip.size.shape[1] < long_dataset.size.shape[1]
assert subclip.dims["time"] == long_dataset.dims["time"]
assert subclip.spectrogram.shape[1] == long_dataset.spectrogram.shape[1]
xr.testing.assert_equal(subclip.audio, long_dataset.audio)
assert subclip.dims["audio_time"] == long_dataset.dims["audio_time"]
def test_select_subclip_fill_value(short_dataset):
fill_value = -999.0
subclip = select_subclip(
short_dataset,
span=CLIP_DURATION,
start=0,
dim="time",
fill_value=fill_value,
)
expected_width = _compute_expected_width(
short_dataset,
CLIP_DURATION,
"time",
)
assert subclip.dims["time"] == expected_width
assert np.all(subclip.spectrogram.sel(time=slice(0.3, None)) == fill_value)
def test_select_subclip_no_overlap_raises_error(long_dataset):
original_duration = long_dataset.time.max() - long_dataset.time.min()
with pytest.raises(ValueError, match="does not overlap"):
select_subclip(
long_dataset,
span=CLIP_DURATION,
start=original_duration + 1.0,
dim="time",
)
with pytest.raises(ValueError, match="does not overlap"):
select_subclip(
long_dataset,
span=CLIP_DURATION,
start=-1.0 * CLIP_DURATION - 1.0,
dim="time",
)
def test_clipper_non_random(long_dataset, exact_dataset, short_dataset):
clipper = Clipper(duration=CLIP_DURATION, random=False)
for ds in [long_dataset, exact_dataset, short_dataset]:
clip, _, _ = clipper.extract_clip(ds)
expected_spec_width = _compute_expected_width(
ds, CLIP_DURATION, "time"
)
expected_audio_width = _compute_expected_width(
ds, CLIP_DURATION, "audio_time"
)
assert clip.dims["time"] == expected_spec_width
assert clip.dims["audio_time"] == expected_audio_width
assert clip.spectrogram.shape[1] == expected_spec_width
assert clip.audio.shape[0] == expected_audio_width
assert clip.time.min() >= -1 / SPEC_SAMPLERATE
assert clip.audio_time.min() >= -1 / AUDIO_SAMPLERATE
time_span = clip.time.max() - clip.time.min()
audio_span = clip.audio_time.max() - clip.audio_time.min()
assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
def test_clipper_random(long_dataset):
seed = 42
np.random.seed(seed)
clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
clip1, _, _ = clipper.extract_clip(long_dataset)
np.random.seed(seed + 1)
clip2, _, _ = clipper.extract_clip(long_dataset)
expected_spec_width = _compute_expected_width(
long_dataset, CLIP_DURATION, "time"
)
expected_audio_width = _compute_expected_width(
long_dataset, CLIP_DURATION, "audio_time"
)
for clip in [clip1, clip2]:
assert clip.dims["time"] == expected_spec_width
assert clip.dims["audio_time"] == expected_audio_width
assert clip.spectrogram.shape[1] == expected_spec_width
assert clip.audio.shape[0] == expected_audio_width
assert not np.isclose(clip1.time.min(), clip2.time.min())
assert not np.isclose(clip1.audio_time.min(), clip2.audio_time.min())
for clip in [clip1, clip2]:
time_span = clip.time.max() - clip.time.min()
audio_span = clip.audio_time.max() - clip.audio_time.min()
assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
max_start_time = (
(long_dataset.time.max() - long_dataset.time.min())
- CLIP_DURATION
+ MAX_EMPTY
)
assert clip1.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE
assert clip2.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE
def test_clipper_random_max_empty_effect(long_dataset):
"""Check that max_empty influences the possible start times."""
seed = 123
data_duration = long_dataset.time.max() - long_dataset.time.min()
np.random.seed(seed)
clipper0 = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.0)
max_start_time0 = data_duration - CLIP_DURATION
start_times0 = []
for _ in range(20):
clip, _, _ = clipper0.extract_clip(long_dataset)
start_times0.append(clip.time.min().item())
assert all(
st <= max_start_time0 + 1 / SPEC_SAMPLERATE for st in start_times0
)
assert any(st > 0.1 for st in start_times0)
np.random.seed(seed)
clipper_pos = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.2)
max_start_time_pos = data_duration - CLIP_DURATION + 0.2
start_times_pos = []
for _ in range(20):
clip, _, _ = clipper_pos.extract_clip(long_dataset)
start_times_pos.append(clip.time.min().item())
assert all(
st <= max_start_time_pos + 1 / SPEC_SAMPLERATE
for st in start_times_pos
)
assert any(st > max_start_time0 + 1e-6 for st in start_times_pos)
def test_clipper_short_dataset_random(short_dataset):
clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
clip, _, _ = clipper.extract_clip(short_dataset)
expected_spec_width = _compute_expected_width(
short_dataset, CLIP_DURATION, "time"
)
expected_audio_width = _compute_expected_width(
short_dataset, CLIP_DURATION, "audio_time"
)
assert clip.sizes["time"] == expected_spec_width
assert clip.sizes["audio_time"] == expected_audio_width
assert clip["spectrogram"].shape[1] == expected_spec_width
assert clip["audio"].shape[0] == expected_audio_width
assert np.any(clip.spectrogram == 0)
assert np.any(clip.audio == 0)
def test_clipper_exact_dataset_random(exact_dataset):
clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
clip, _, _ = clipper.extract_clip(exact_dataset)
expected_spec_width = _compute_expected_width(
exact_dataset, CLIP_DURATION, "time"
)
expected_audio_width = _compute_expected_width(
exact_dataset, CLIP_DURATION, "audio_time"
)
assert clip.dims["time"] == expected_spec_width
assert clip.dims["audio_time"] == expected_audio_width
assert clip.spectrogram.shape[1] == expected_spec_width
assert clip.audio.shape[0] == expected_audio_width
time_span = clip.time.max() - clip.time.min()
audio_span = clip.audio_time.max() - clip.audio_time.min()
assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)

View File

@ -0,0 +1,56 @@
from pathlib import Path
import lightning as L
import torch
import xarray as xr
from soundevent import data
from batdetect2.models import build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import build_loss
def build_default_module():
loss = build_loss()
targets = build_targets()
detector = build_model(num_classes=len(targets.class_names))
preprocessor = build_preprocessor()
postprocessor = build_postprocessor(targets)
return TrainingModule(
detector=detector,
loss=loss,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
def test_can_initialize_default_module():
module = build_default_module()
assert isinstance(module, L.LightningModule)
def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip):
module = build_default_module()
trainer = L.Trainer()
path = tmp_path / "example.ckpt"
trainer.strategy.connect(module)
trainer.save_checkpoint(path)
recovered = TrainingModule.load_from_checkpoint(path)
spec1 = module.preprocessor.preprocess_clip(clip)
spec2 = recovered.preprocessor.preprocess_clip(clip)
xr.testing.assert_equal(spec1, spec2)
input1 = torch.tensor([spec1.values]).unsqueeze(0)
input2 = torch.tensor([spec2.values]).unsqueeze(0)
output1 = module(input1)
output2 = recovered(input2)
assert output1 == output2