mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00

Added clips for random clipping and augmentations

This commit is contained in:
  parent 2396815c13
  commit 59bd14bc79
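
For orientation, a minimal wiring sketch of how the pieces introduced in this commit fit together. This is not part of the diff; the directory path is a hypothetical placeholder and default settings are assumed.

from torch.utils.data import DataLoader

from batdetect2.train.clips import ClipperConfig, build_clipper
from batdetect2.train.dataset import LabeledDataset, list_preprocessed_files
from batdetect2.train.losses import build_loss

# Cut a fixed-duration window (0.513 s by default) at a random offset from
# each preprocessed training example at load time.
clipper = build_clipper(ClipperConfig(duration=0.513, random=True))

# "train_examples/" is a placeholder directory of preprocessed *.nc datasets.
filenames = list_preprocessed_files("train_examples/")

dataset = LabeledDataset(filenames, clipper=clipper)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Default weights: detection 1.0, size 0.1, classification 2.0.
loss_fn = build_loss()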
(removed file, 153 lines: the previous all-in-one `DetectorModel` Lightning module)
@@ -1,153 +0,0 @@
-from pathlib import Path
-from typing import Optional
-
-import lightning as L
-import torch
-from pydantic import Field
-from soundevent import data
-from torch.optim.adam import Adam
-from torch.optim.lr_scheduler import CosineAnnealingLR
-
-from batdetect2.configs import BaseConfig
-from batdetect2.models import (
-    BackboneConfig,
-    BBoxHead,
-    ClassifierHead,
-    ModelOutput,
-    build_backbone,
-)
-from batdetect2.postprocess import (
-    PostprocessConfig,
-    postprocess_model_outputs,
-)
-from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
-from batdetect2.targets import (
-    TargetConfig,
-    build_decoder,
-    build_target_encoder,
-    get_class_names,
-)
-from batdetect2.train import TrainExample, TrainingConfig, compute_loss
-
-__all__ = [
-    "DetectorModel",
-]
-
-
-class ModuleConfig(BaseConfig):
-    train: TrainingConfig = Field(default_factory=TrainingConfig)
-    targets: TargetConfig = Field(default_factory=TargetConfig)
-    architecture: BackboneConfig = Field(default_factory=BackboneConfig)
-    preprocessing: PreprocessingConfig = Field(
-        default_factory=PreprocessingConfig
-    )
-    postprocessing: PostprocessConfig = Field(
-        default_factory=PostprocessConfig
-    )
-
-
-class DetectorModel(L.LightningModule):
-    config: ModuleConfig
-
-    def __init__(
-        self,
-        config: Optional[ModuleConfig] = None,
-    ):
-        super().__init__()
-
-        self.config = config or ModuleConfig()
-        self.save_hyperparameters()
-
-        self.backbone = build_model_backbone(self.config.architecture)
-
-        self.classifier = ClassifierHead(
-            num_classes=len(self.config.targets.classes),
-            in_channels=self.backbone.out_channels,
-        )
-
-        self.bbox = BBoxHead(in_channels=self.backbone.out_channels)
-
-        conf = self.config.train.loss.classification
-        self.class_weights = (
-            torch.tensor(conf.class_weights) if conf.class_weights else None
-        )
-
-        # Training targets
-        self.class_names = get_class_names(self.config.targets.classes)
-        self.encoder = build_target_encoder(
-            self.config.targets.classes,
-            replacement_rules=self.config.targets.replace,
-        )
-        self.decoder = build_decoder(self.config.targets.classes)
-        self.example_input_array = torch.randn([1, 1, 128, 512])
-
-    def forward(self, spec: torch.Tensor) -> ModelOutput:
-        features = self.backbone(spec)
-        detection_probs, classification_probs = self.classifier(features)
-        size_preds = self.bbox(features)
-        return ModelOutput(
-            detection_probs=detection_probs,
-            size_preds=size_preds,
-            class_probs=classification_probs,
-            features=features,
-        )
-
-    def training_step(self, batch: TrainExample):
-        outputs = self.forward(batch.spec)
-        losses = compute_loss(
-            batch,
-            outputs,
-            conf=self.config.train.loss,
-            class_weights=self.class_weights,
-        )
-
-        self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
-        self.log("train/loss/detection", losses.total, logger=True)
-        self.log("train/loss/size", losses.total, logger=True)
-        self.log("train/loss/classification", losses.total, logger=True)
-
-        return losses.total
-
-    def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
-        outputs = self.forward(batch.spec)
-
-        losses = compute_loss(
-            batch,
-            outputs,
-            conf=self.config.train.loss,
-            class_weights=self.class_weights,
-        )
-
-        self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
-        self.log("val/loss/detection", losses.total, logger=True)
-        self.log("val/loss/size", losses.total, logger=True)
-        self.log("val/loss/classification", losses.total, logger=True)
-
-    def on_validation_epoch_end(self) -> None:
-        self.validation_predictions.clear()
-
-    def configure_optimizers(self):
-        conf = self.config.train.optimizer
-        optimizer = Adam(self.parameters(), lr=conf.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
-        return [optimizer], [scheduler]
-
-    def process_clip(
-        self,
-        clip: data.Clip,
-        audio_dir: Optional[Path] = None,
-    ) -> data.ClipPrediction:
-        spec = preprocess_audio_clip(
-            clip,
-            config=self.config.preprocessing,
-            audio_dir=audio_dir,
-        )
-        tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
-        outputs = self.forward(tensor)
-        return postprocess_model_outputs(
-            outputs,
-            clips=[clip],
-            classes=self.class_names,
-            decoder=self.decoder,
-            config=self.config.postprocessing,
-        )[0]
@@ -11,18 +11,19 @@ from batdetect2.train.augmentations import (
     mask_time,
     mix_examples,
     scale_volume,
-    select_subclip,
     warp_spectrogram,
 )
+from batdetect2.train.clips import build_clipper, select_subclip
 from batdetect2.train.config import TrainingConfig, load_train_config
 from batdetect2.train.dataset import (
     LabeledDataset,
+    RandomExampleSource,
     SubclipConfig,
     TrainExample,
-    get_preprocessed_files,
+    list_preprocessed_files,
 )
 from batdetect2.train.labels import load_label_config
-from batdetect2.train.losses import compute_loss
+from batdetect2.train.losses import LossFunction, build_loss
 from batdetect2.train.preprocess import (
     generate_train_example,
     preprocess_annotations,
@@ -34,6 +35,8 @@ __all__ = [
     "EchoAugmentationConfig",
     "FrequencyMaskAugmentationConfig",
     "LabeledDataset",
+    "LossFunction",
+    "RandomExampleSource",
     "SubclipConfig",
     "TimeMaskAugmentationConfig",
     "TrainExample",
@@ -43,9 +46,11 @@ __all__ = [
     "WarpAugmentationConfig",
     "add_echo",
     "build_augmentations",
-    "compute_loss",
+    "build_clipper",
+    "build_loss",
     "generate_train_example",
-    "get_preprocessed_files",
+    "list_preprocessed_files",
+    "load_label_config",
     "load_train_config",
     "load_trainer_config",
     "mask_frequency",
@@ -56,5 +61,4 @@ __all__ = [
     "select_subclip",
     "train",
     "warp_spectrogram",
-    "load_label_config",
 ]
@@ -37,24 +37,24 @@ from batdetect2.train.types import Augmentation
 from batdetect2.utils.arrays import adjust_width

 __all__ = [
+    "AugmentationConfig",
     "AugmentationsConfig",
-    "load_augmentation_config",
-    "build_augmentations",
-    "select_subclip",
-    "mix_examples",
-    "add_echo",
-    "scale_volume",
-    "warp_spectrogram",
-    "mask_time",
-    "mask_frequency",
-    "MixAugmentationConfig",
+    "DEFAULT_AUGMENTATION_CONFIG",
     "EchoAugmentationConfig",
+    "ExampleSource",
+    "FrequencyMaskAugmentationConfig",
+    "MixAugmentationConfig",
+    "TimeMaskAugmentationConfig",
     "VolumeAugmentationConfig",
     "WarpAugmentationConfig",
-    "TimeMaskAugmentationConfig",
-    "FrequencyMaskAugmentationConfig",
-    "AugmentationConfig",
-    "ExampleSource",
+    "add_echo",
+    "build_augmentations",
+    "load_augmentation_config",
+    "mask_frequency",
+    "mask_time",
+    "mix_examples",
+    "scale_volume",
+    "warp_spectrogram",
 ]

 ExampleSource = Callable[[], xr.Dataset]
@@ -64,92 +64,6 @@ Used by the `mix_examples` augmentation to fetch another example to mix with.
 """


-def select_subclip(
-    example: xr.Dataset,
-    start_time: Optional[float] = None,
-    duration: Optional[float] = None,
-    width: Optional[int] = None,
-    random: bool = False,
-) -> xr.Dataset:
-    """Extract a sub-clip (time segment) from a training example dataset.
-
-    Selects a portion of the 'time' dimension from all relevant DataArrays
-    (`audio`, `spectrogram`, `detection`, `class`, `size`) within the example
-    Dataset. The segment can be defined by a fixed start time and
-    duration/width, or a random start time can be chosen.
-
-    Parameters
-    ----------
-    example : xr.Dataset
-        The input training example containing 'audio', 'spectrogram', and
-        target heatmaps, all with compatible 'time' (or 'audio_time')
-        coordinates.
-    start_time : float, optional
-        Desired start time (seconds) of the subclip. If None and `random` is
-        False, starts from the beginning of the example. If None and `random`
-        is True, a random start time is chosen.
-    duration : float, optional
-        Desired duration (seconds) of the subclip. Either `duration` or `width`
-        must be provided.
-    width : int, optional
-        Desired width (number of time bins) of the subclip's
-        spectrogram/heatmaps. Either `duration` or `width` must be provided. If
-        both are given, `duration` takes precedence.
-    random : bool, default=False
-        If True and `start_time` is None, selects a random start time ensuring
-        the subclip fits within the original example's duration.
-
-    Returns
-    -------
-    xr.Dataset
-        A new dataset containing only the selected time segment. Coordinates
-        are adjusted accordingly. Returns the original example if the requested
-        subclip cannot be extracted (e.g., duration too long).
-
-    Raises
-    ------
-    ValueError
-        If neither `duration` nor `width` is provided, or if time coordinates
-        are missing or invalid.
-    """
-    step = arrays.get_dim_step(example, "time")  # type: ignore
-    start, stop = arrays.get_dim_range(example, "time")  # type: ignore
-
-    if width is None:
-        if duration is None:
-            raise ValueError("Either duration or width must be provided")
-
-        width = int(np.floor(duration / step))
-
-    if duration is None:
-        duration = width * step
-
-    if start_time is None:
-        if random:
-            start_time = np.random.uniform(start, max(stop - duration, start))
-        else:
-            start_time = start
-
-    if start_time + duration > stop:
-        return example
-
-    start_index = arrays.get_coord_index(
-        example,  # type: ignore
-        "time",
-        start_time,
-    )
-
-    end_index = start_index + width - 1
-
-    start_time = example.time.values[start_index]
-    end_time = example.time.values[end_index]
-
-    return example.sel(
-        time=slice(start_time, end_time),
-        audio_time=slice(start_time, end_time + step),
-    )
-
-
 class MixAugmentationConfig(BaseConfig):
     """Configuration for MixUp augmentation (mixing two examples)."""

@@ -878,9 +792,21 @@ def build_augmentation_from_config(
     )


+DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
+    steps=[
+        MixAugmentationConfig(),
+        EchoAugmentationConfig(),
+        VolumeAugmentationConfig(),
+        WarpAugmentationConfig(),
+        TimeMaskAugmentationConfig(),
+        FrequencyMaskAugmentationConfig(),
+    ]
+)
+
+
 def build_augmentations(
-    config: AugmentationsConfig,
     preprocessor: PreprocessorProtocol,
+    config: Optional[AugmentationsConfig] = None,
     example_source: Optional[ExampleSource] = None,
 ) -> Augmentation:
     """Build a composite augmentation pipeline function from configuration.
@@ -915,6 +841,8 @@ def build_augmentations(
     NotImplementedError
         If an unknown `augmentation_type` is encountered in `config.steps`.
     """
+    config = config or DEFAULT_AUGMENTATION_CONFIG
+
     augmentations = []

     for step_config in config.steps:
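
A small usage sketch for the new default pipeline (not part of the diff). The `preprocessor` argument is taken as a function parameter here because it is built elsewhere in the package; `config=None` falls back to the new `DEFAULT_AUGMENTATION_CONFIG`.

from batdetect2.train.augmentations import build_augmentations
from batdetect2.train.clips import build_clipper
from batdetect2.train.dataset import RandomExampleSource, list_preprocessed_files


def make_default_augmentation(preprocessor, directory):
    # `directory` is a placeholder path holding preprocessed *.nc examples.
    clipper = build_clipper()
    filenames = [str(p) for p in list_preprocessed_files(directory)]

    # Mix-up style augmentation needs a second example; it is drawn at random
    # and clipped with the same clipper.
    example_source = RandomExampleSource(filenames, clipper)

    # config=None selects DEFAULT_AUGMENTATION_CONFIG: mix, echo, volume,
    # warp, time mask, frequency mask, in that order.
    return build_augmentations(preprocessor, example_source=example_source)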
batdetect2/train/clips.py (new file, 184 lines)
@@ -0,0 +1,184 @@
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import xarray as xr
+from soundevent import arrays
+
+from batdetect2.configs import BaseConfig
+from batdetect2.train.types import ClipperProtocol
+
+DEFAULT_TRAIN_CLIP_DURATION = 0.513
+DEFAULT_MAX_EMPTY_CLIP = 0.1
+
+
+class ClipperConfig(BaseConfig):
+    duration: float = DEFAULT_TRAIN_CLIP_DURATION
+    random: bool = True
+    max_empty: float = DEFAULT_MAX_EMPTY_CLIP
+
+
+class Clipper(ClipperProtocol):
+    def __init__(
+        self,
+        duration: float = 0.5,
+        max_empty: float = 0.2,
+        random: bool = True,
+    ):
+        self.duration = duration
+        self.random = random
+        self.max_empty = max_empty
+
+    def extract_clip(
+        self, example: xr.Dataset
+    ) -> Tuple[xr.Dataset, float, float]:
+        step = arrays.get_dim_step(
+            example.spectrogram,
+            dim=arrays.Dimensions.time.value,
+        )
+        duration = (
+            arrays.get_dim_width(
+                example.spectrogram,
+                dim=arrays.Dimensions.time.value,
+            )
+            + step
+        )
+
+        start_time = 0
+        if self.random:
+            start_time = np.random.uniform(
+                -self.max_empty,
+                duration - self.duration + self.max_empty,
+            )
+
+        subclip = select_subclip(
+            example,
+            start=start_time,
+            span=self.duration,
+            dim="time",
+        )
+
+        return (
+            select_subclip(
+                subclip,
+                start=start_time,
+                span=self.duration,
+                dim="audio_time",
+            ),
+            start_time,
+            start_time + self.duration,
+        )
+
+
+def build_clipper(config: Optional[ClipperConfig] = None) -> ClipperProtocol:
+    config = config or ClipperConfig()
+    return Clipper(
+        duration=config.duration,
+        max_empty=config.max_empty,
+        random=config.random,
+    )
+
+
+def select_subclip(
+    dataset: xr.Dataset,
+    span: float,
+    start: float,
+    fill_value: float = 0,
+    dim: str = "time",
+) -> xr.Dataset:
+    width = _compute_expected_width(
+        dataset,  # type: ignore
+        span,
+        dim=dim,
+    )
+
+    coord = dataset.coords[dim]
+
+    if len(coord) == width:
+        return dataset
+
+    new_coords, start_pad, end_pad, dim_slice = _extract_coordinate(
+        coord, start, span
+    )
+
+    data_vars = {}
+    for name, data_array in dataset.data_vars.items():
+        if dim not in data_array.dims:
+            data_vars[name] = data_array
+            continue
+
+        if width == data_array.sizes[dim]:
+            data_vars[name] = data_array
+            continue
+
+        sliced = data_array.isel({dim: dim_slice}).data
+
+        if start_pad > 0 or end_pad > 0:
+            padding = [
+                [0, 0] if other_dim != dim else [start_pad, end_pad]
+                for other_dim in data_array.dims
+            ]
+            sliced = np.pad(sliced, padding, constant_values=fill_value)
+
+        data_vars[name] = xr.DataArray(
+            data=sliced,
+            dims=data_array.dims,
+            coords={**data_array.coords, dim: new_coords},
+            attrs=data_array.attrs,
+        )
+
+    return xr.Dataset(data_vars=data_vars, attrs=dataset.attrs)
+
+
+def _extract_coordinate(
+    coord: xr.DataArray,
+    start: float,
+    span: float,
+) -> Tuple[xr.Variable, int, int, slice]:
+    step = arrays.get_dim_step(coord, str(coord.name))
+
+    current_width = len(coord)
+    expected_width = int(np.floor(span / step))
+
+    coord_start = float(coord[0])
+    offset = start - coord_start
+
+    start_index = int(np.floor(offset / step))
+    end_index = start_index + expected_width
+
+    if start_index > current_width:
+        raise ValueError("Requested span does not overlap with current range")
+
+    if end_index < 0:
+        raise ValueError("Requested span does not overlap with current range")
+
+    corrected_start = float(start_index * step)
+    corrected_end = float(end_index * step)
+
+    start_index_offset = max(0, -start_index)
+    end_index_offset = max(0, end_index - current_width)
+
+    sl = slice(
+        start_index if start_index >= 0 else None,
+        end_index if end_index < current_width else None,
+    )
+
+    return (
+        arrays.create_range_dim(
+            str(coord.name),
+            start=corrected_start,
+            stop=corrected_end,
+            step=step,
+        ),
+        start_index_offset,
+        end_index_offset,
+        sl,
+    )
+
+
+def _compute_expected_width(
+    array: Union[xr.DataArray, xr.Dataset],
+    duration: float,
+    dim: str,
+) -> int:
+    step = arrays.get_dim_step(array, dim)  # type: ignore
+    return int(np.floor(duration / step))
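
To make the clip arithmetic concrete, a toy sketch (not part of the diff; values and the toy dataset are illustrative, and the exact behavior depends on the soundevent.arrays helpers used above). A span is converted to a bin count with floor(span / step), and spans that fall partly outside the data are padded with `fill_value` so the returned width is always the expected one.

import numpy as np
import xarray as xr

from batdetect2.train.clips import select_subclip

# A 1 s "spectrogram" sampled at 100 time bins per second, so step = 0.01 s
# and a 0.5 s span maps to floor(0.5 / 0.01) = 50 bins.
times = np.arange(0, 1, 0.01)
spec = xr.DataArray(
    np.random.rand(4, len(times)),
    coords=[("frequency", np.linspace(10_000, 80_000, 4)), ("time", times)],
    name="spectrogram",
)
example = xr.Dataset({"spectrogram": spec})

# Starting 0.1 s before the data begins: the missing 10 bins are zero-padded
# on the left, so the clip should still be exactly 50 bins wide.
clip = select_subclip(example, span=0.5, start=-0.1, fill_value=0, dim="time")
print(clip.sizes["time"])  # expected: 50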
@@ -1,6 +1,5 @@
-import os
 from pathlib import Path
-from typing import NamedTuple, Optional, Sequence, Union
+from typing import List, Optional, Sequence, Tuple

 import numpy as np
 import torch
@@ -13,28 +12,14 @@ from batdetect2.configs import BaseConfig
 from batdetect2.train.augmentations import (
     Augmentation,
     AugmentationsConfig,
-    select_subclip,
 )
-from batdetect2.train.preprocess import PreprocessorProtocol
-from batdetect2.utils.tensors import adjust_width
+from batdetect2.train.types import ClipperProtocol, TrainExample

 __all__ = [
-    "TrainExample",
     "LabeledDataset",
 ]


-PathLike = Union[Path, str, os.PathLike]
-
-
-class TrainExample(NamedTuple):
-    spec: torch.Tensor
-    detection_heatmap: torch.Tensor
-    class_heatmap: torch.Tensor
-    size_heatmap: torch.Tensor
-    idx: torch.Tensor
-
-
 class SubclipConfig(BaseConfig):
     duration: Optional[float] = None
     width: int = 512
@@ -51,14 +36,12 @@ class DatasetConfig(BaseConfig):
 class LabeledDataset(Dataset):
     def __init__(
         self,
-        preprocessor: PreprocessorProtocol,
-        filenames: Sequence[PathLike],
-        subclip: Optional[SubclipConfig] = None,
+        filenames: Sequence[data.PathLike],
+        clipper: ClipperProtocol,
         augmentation: Optional[Augmentation] = None,
     ):
-        self.preprocessor = preprocessor
         self.filenames = filenames
-        self.subclip = subclip
+        self.clipper = clipper
         self.augmentation = augmentation

     def __len__(self):
@@ -66,14 +49,7 @@ class LabeledDataset(Dataset):

     def __getitem__(self, idx) -> TrainExample:
         dataset = self.get_dataset(idx)
-        if self.subclip:
-            dataset = select_subclip(
-                dataset,
-                duration=self.subclip.duration,
-                width=self.subclip.width,
-                random=self.subclip.random,
-            )
-
+        dataset, start_time, end_time = self.clipper.extract_clip(dataset)
         if self.augmentation:
             dataset = self.augmentation(dataset)
@@ -84,37 +60,31 @@ class LabeledDataset(Dataset):
             class_heatmap=self.to_tensor(dataset["class"]),
             size_heatmap=self.to_tensor(dataset["size"]),
             idx=torch.tensor(idx),
+            start_time=start_time,
+            end_time=end_time,
         )

     @classmethod
     def from_directory(
         cls,
-        directory: PathLike,
-        preprocessor: PreprocessorProtocol,
+        directory: data.PathLike,
+        clipper: ClipperProtocol,
         extension: str = ".nc",
-        subclip: Optional[SubclipConfig] = None,
         augmentation: Optional[Augmentation] = None,
     ):
         return cls(
-            preprocessor=preprocessor,
-            filenames=get_preprocessed_files(directory, extension),
-            subclip=subclip,
+            filenames=list_preprocessed_files(directory, extension),
+            clipper=clipper,
             augmentation=augmentation,
         )

-    def get_random_example(self) -> xr.Dataset:
+    def get_random_example(self) -> Tuple[xr.Dataset, float, float]:
         idx = np.random.randint(0, len(self))
         dataset = self.get_dataset(idx)

-        if self.subclip:
-            dataset = select_subclip(
-                dataset,
-                duration=self.subclip.duration,
-                width=self.subclip.width,
-                random=self.subclip.random,
-            )
+        dataset, start_time, end_time = self.clipper.extract_clip(dataset)

-        return dataset
+        return dataset, start_time, end_time

     def get_dataset(self, idx) -> xr.Dataset:
         return xr.open_dataset(self.filenames[idx])
@@ -129,16 +99,23 @@ class LabeledDataset(Dataset):
         array: xr.DataArray,
         dtype=np.float32,
     ) -> torch.Tensor:
-        tensor = torch.tensor(array.values.astype(dtype))
-
-        if not self.subclip:
-            return tensor
-
-        width = self.subclip.width
-        return adjust_width(tensor, width)
+        return torch.tensor(array.values.astype(dtype))


-def get_preprocessed_files(
-    directory: PathLike, extension: str = ".nc"
+def list_preprocessed_files(
+    directory: data.PathLike, extension: str = ".nc"
 ) -> Sequence[Path]:
     return list(Path(directory).glob(f"*{extension}"))
+
+
+class RandomExampleSource:
+    def __init__(self, filenames: List[str], clipper: ClipperProtocol):
+        self.filenames = filenames
+        self.clipper = clipper
+
+    def __call__(self):
+        index = int(np.random.randint(len(self.filenames)))
+        filename = self.filenames[index]
+        dataset = xr.open_dataset(filename)
+        example, _, _ = self.clipper.extract_clip(dataset)
+        return example
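
A hedged usage sketch of the reworked dataset (not part of the diff; the directory is a placeholder). Each item is a `TrainExample`, so the default PyTorch collate stacks the spectrogram and the three heatmaps into batched tensors and the start/end times into 1-D tensors.

from torch.utils.data import DataLoader

from batdetect2.train.clips import build_clipper
from batdetect2.train.dataset import LabeledDataset


def make_loader(directory, batch_size=8):
    # `directory` holds the preprocessed *.nc training examples (placeholder).
    dataset = LabeledDataset.from_directory(directory, clipper=build_clipper())
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)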
batdetect2/train/lightning.py (new file, 71 lines)
@@ -0,0 +1,71 @@
+import lightning as L
+import torch
+from torch.optim.adam import Adam
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+from batdetect2.models import (
+    DetectionModel,
+    ModelOutput,
+)
+from batdetect2.postprocess.types import PostprocessorProtocol
+from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets.types import TargetProtocol
+from batdetect2.train import TrainExample
+from batdetect2.train.types import LossProtocol
+
+__all__ = [
+    "TrainingModule",
+]
+
+
+class TrainingModule(L.LightningModule):
+    def __init__(
+        self,
+        detector: DetectionModel,
+        loss: LossProtocol,
+        targets: TargetProtocol,
+        preprocessor: PreprocessorProtocol,
+        postprocessor: PostprocessorProtocol,
+        learning_rate: float = 0.001,
+        t_max: int = 100,
+    ):
+        super().__init__()
+
+        self.loss = loss
+        self.detector = detector
+        self.preprocessor = preprocessor
+        self.targets = targets
+        self.postprocessor = postprocessor
+
+        self.learning_rate = learning_rate
+        self.t_max = t_max
+
+        self.save_hyperparameters()
+
+    def forward(self, spec: torch.Tensor) -> ModelOutput:
+        return self.detector(spec)
+
+    def training_step(self, batch: TrainExample):
+        outputs = self.forward(batch.spec)
+        losses = self.loss(outputs, batch)
+
+        self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
+        self.log("train/loss/detection", losses.total, logger=True)
+        self.log("train/loss/size", losses.total, logger=True)
+        self.log("train/loss/classification", losses.total, logger=True)
+
+        return losses.total
+
+    def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
+        outputs = self.forward(batch.spec)
+        losses = self.loss(outputs, batch)
+
+        self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
+        self.log("val/loss/detection", losses.total, logger=True)
+        self.log("val/loss/size", losses.total, logger=True)
+        self.log("val/loss/classification", losses.total, logger=True)
+
+    def configure_optimizers(self):
+        optimizer = Adam(self.parameters(), lr=self.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
+        return [optimizer], [scheduler]
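
A hedged construction sketch (not part of the diff). The detector, targets, preprocessor, and postprocessor are taken as function parameters because they are built by other parts of the package; only the loss is created here with the new factory.

from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import build_loss


def make_training_module(detector, targets, preprocessor, postprocessor):
    # The four collaborators are placeholders; they come from batdetect2.models,
    # batdetect2.targets, batdetect2.preprocess and batdetect2.postprocess.
    return TrainingModule(
        detector=detector,
        loss=build_loss(),
        targets=targets,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        learning_rate=1e-3,
        t_max=100,
    )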
@@ -1,115 +1,323 @@
-from typing import NamedTuple, Optional
+"""Loss functions and configurations for training BatDetect2 models.
+
+This module defines the loss functions used to train BatDetect2 models,
+including individual loss components for different prediction tasks (detection,
+classification, size regression) and a main coordinating loss function that
+combines them.
+
+It utilizes common loss types like L1 loss (`BBoxLoss`) for regression and
+Focal Loss (`FocalLoss`) for handling class imbalance in dense detection and
+classification tasks. Configuration objects (`LossConfig`, etc.) allow for easy
+customization of loss parameters and weights via configuration files.
+
+The primary entry points are:
+- `LossFunction`: An `nn.Module` that computes the weighted sum of individual
+  loss components given model outputs and ground truth targets.
+- `build_loss`: A factory function that constructs the `LossFunction` based
+  on a `LossConfig` object.
+- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
+"""
+
+from typing import Optional

+import numpy as np
 import torch
 import torch.nn.functional as F
 from pydantic import Field
+from torch import nn

 from batdetect2.configs import BaseConfig
 from batdetect2.models.types import ModelOutput
 from batdetect2.train.dataset import TrainExample
+from batdetect2.train.types import Losses, LossProtocol

 __all__ = [
-    "bbox_size_loss",
-    "compute_loss",
-    "focal_loss",
-    "mse_loss",
+    "BBoxLoss",
+    "ClassificationLossConfig",
+    "DetectionLossConfig",
+    "FocalLoss",
+    "FocalLossConfig",
+    "LossConfig",
+    "LossFunction",
+    "MSELoss",
+    "SizeLossConfig",
+    "build_loss",
 ]


 class SizeLossConfig(BaseConfig):
+    """Configuration for the bounding box size loss component.
+
+    Attributes
+    ----------
+    weight : float, default=0.1
+        The weighting factor applied to the size loss when combining it with
+        other losses (detection, classification) to form the total training
+        loss.
+    """
+
     weight: float = 0.1


-def bbox_size_loss(
-    pred_size: torch.Tensor,
-    gt_size: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Bounding box size loss. Only compute loss where there is a bounding box.
-    """
-    gt_size_mask = (gt_size > 0).float()
-    return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / (
-        gt_size_mask.sum() + 1e-5
-    )
+class BBoxLoss(nn.Module):
+    """Computes L1 loss for bounding box size regression.
+
+    Calculates the Mean Absolute Error (MAE or L1 loss) between the predicted
+    size dimensions (`pred`) and the ground truth size dimensions (`gt`).
+    Crucially, the loss is only computed at locations where the ground truth
+    size heatmap (`gt`) contains non-zero values (i.e., at the reference points
+    of actual annotated sound events). This prevents the model from being
+    penalized for size predictions in background regions.
+
+    The loss is summed over all valid locations and normalized by the number
+    of valid locations.
+    """
+
+    def forward(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
+        """Calculate masked L1 loss for size prediction.
+
+        Parameters
+        ----------
+        pred : torch.Tensor
+            Predicted size tensor, typically shape `(B, 2, H, W)`, where
+            channels represent scaled width and height.
+        gt : torch.Tensor
+            Ground truth size tensor, same shape as `pred`. Non-zero values
+            indicate locations and target sizes of actual annotations.
+
+        Returns
+        -------
+        torch.Tensor
+            Scalar tensor representing the calculated masked L1 loss.
+        """
+        gt_size_mask = (gt > 0).float()
+        masked_pred = pred * gt_size_mask
+        loss = F.l1_loss(masked_pred, gt, reduction="sum")
+        num_pos = gt_size_mask.sum() + 1e-5
+        return loss / num_pos


 class FocalLossConfig(BaseConfig):
+    """Configuration parameters for the Focal Loss function.
+
+    Attributes
+    ----------
+    beta : float, default=4
+        Exponent controlling the down-weighting of easy negative examples.
+        Higher values increase down-weighting (focus more on hard negatives).
+    alpha : float, default=2
+        Exponent controlling the down-weighting based on prediction confidence.
+        Higher values focus more on misclassified examples (both positive and
+        negative).
+    """
+
     beta: float = 4
     alpha: float = 2


-def focal_loss(
-    pred: torch.Tensor,
-    gt: torch.Tensor,
-    weights: Optional[torch.Tensor] = None,
-    valid_mask: Optional[torch.Tensor] = None,
-    eps: float = 1e-5,
-    beta: float = 4,
-    alpha: float = 2,
-) -> torch.Tensor:
-    """
-    Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
-    pred (batch x c x h x w)
-    gt (batch x c x h x w)
-    """
-    pos_inds = gt.eq(1).float()
-    neg_inds = gt.lt(1).float()
-
-    pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
-    neg_loss = (
-        torch.log(1 - pred + eps)
-        * torch.pow(pred, alpha)
-        * torch.pow(1 - gt, beta)
-        * neg_inds
-    )
-
-    if weights is not None:
-        pos_loss = pos_loss * torch.tensor(weights)
-        # neg_loss = neg_loss*weights
-
-    if valid_mask is not None:
-        pos_loss = pos_loss * valid_mask
-        neg_loss = neg_loss * valid_mask
-
-    pos_loss = pos_loss.sum()
-    neg_loss = neg_loss.sum()
-
-    num_pos = pos_inds.float().sum()
-    if num_pos == 0:
-        loss = -neg_loss
-    else:
-        loss = -(pos_loss + neg_loss) / num_pos
-    return loss
+class FocalLoss(nn.Module):
+    """Focal Loss implementation, adapted from CornerNet.
+
+    Addresses class imbalance in dense object detection/classification tasks by
+    down-weighting the loss contribution from easy examples (both positive and
+    negative), allowing the model to focus more on hard-to-classify examples.
+
+    Parameters
+    ----------
+    eps : float, default=1e-5
+        Small epsilon value added for numerical stability.
+    beta : float, default=4
+        Exponent focusing on hard negative examples (modulates `(1-gt)^beta`).
+    alpha : float, default=2
+        Exponent focusing on misclassified examples (modulates `(1-p)^alpha`
+        for positives and `p^alpha` for negatives).
+    class_weights : torch.Tensor, optional
+        Optional tensor containing weights for each class (applied to positive
+        loss). Shape should be broadcastable to the channel dimension of the
+        input tensors.
+    mask_zero : bool, default=False
+        If True, ignores loss contributions from spatial locations where the
+        ground truth `gt` tensor is zero across *all* channels. Useful for
+        classification heatmaps where some areas might have no assigned class.
+
+    References
+    ----------
+    - Lin, T. Y., et al. "Focal loss for dense object detection." ICCV 2017.
+    - Law, H., & Deng, J. "CornerNet: Detecting Objects as Paired Keypoints."
+      ECCV 2018.
+    """
+
+    def __init__(
+        self,
+        eps: float = 1e-5,
+        beta: float = 4,
+        alpha: float = 2,
+        class_weights: Optional[torch.Tensor] = None,
+        mask_zero: bool = False,
+    ):
+        super().__init__()
+        self.class_weights = class_weights
+        self.eps = eps
+        self.beta = beta
+        self.alpha = alpha
+        self.mask_zero = mask_zero
+
+    def forward(
+        self,
+        pred: torch.Tensor,
+        gt: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute the Focal Loss.
+
+        Parameters
+        ----------
+        pred : torch.Tensor
+            Predicted probabilities or logits (typically sigmoid output for
+            detection, or softmax/sigmoid for classification). Must be in the
+            range [0, 1] after potential activation. Shape `(B, C, H, W)`.
+        gt : torch.Tensor
+            Ground truth heatmap tensor. Shape `(B, C, H, W)`. Values typically
+            represent target probabilities (e.g., Gaussian peaks for detection,
+            one-hot encoding or smoothed labels for classification). For the
+            adapted CornerNet loss, `gt=1` indicates a positive location, and
+            values `< 1` indicate negative locations (with potential Gaussian
+            weighting `(1-gt)^beta` for negatives near positives).
+
+        Returns
+        -------
+        torch.Tensor
+            Scalar tensor representing the computed focal loss, normalized by
+            the number of positive locations.
+        """
+        pos_inds = gt.eq(1).float()
+        neg_inds = gt.lt(1).float()
+
+        pos_loss = (
+            torch.log(pred + self.eps)
+            * torch.pow(1 - pred, self.alpha)
+            * pos_inds
+        )
+        neg_loss = (
+            torch.log(1 - pred + self.eps)
+            * torch.pow(pred, self.alpha)
+            * torch.pow(1 - gt, self.beta)
+            * neg_inds
+        )
+
+        if self.class_weights is not None:
+            pos_loss = pos_loss * torch.tensor(self.class_weights)
+
+        if self.mask_zero:
+            valid_mask = gt.any(dim=1, keepdim=True).float()
+            pos_loss = pos_loss * valid_mask
+            neg_loss = neg_loss * valid_mask
+
+        pos_loss = pos_loss.sum()
+        neg_loss = neg_loss.sum()
+
+        num_pos = pos_inds.float().sum()
+        if num_pos == 0:
+            loss = -neg_loss
+        else:
+            loss = -(pos_loss + neg_loss) / num_pos
+        return loss


-def mse_loss(
-    pred: torch.Tensor,
-    gt: torch.Tensor,
-    valid_mask: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    """
-    Mean squared error loss.
-    """
-    if valid_mask is None:
-        op = ((gt - pred) ** 2).mean()
-    else:
-        op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
-    return op
+class MSELoss(nn.Module):
+    """Mean Squared Error (MSE) Loss module.
+
+    Calculates the mean squared difference between predictions and ground
+    truth. Optionally masks contributions where the ground truth is zero across
+    channels.
+
+    Parameters
+    ----------
+    mask_zero : bool, default=False
+        If True, calculates the loss only over spatial locations (H, W) where
+        at least one channel in the ground truth `gt` tensor is non-zero. The
+        loss is then averaged over these valid locations. If False (default),
+        the standard MSE over all elements is computed.
+    """
+
+    def __init__(self, mask_zero: bool = False):
+        super().__init__()
+        self.mask_zero = mask_zero
+
+    def forward(
+        self,
+        pred: torch.Tensor,
+        gt: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute the Mean Squared Error loss.
+
+        Parameters
+        ----------
+        pred : torch.Tensor
+            Predicted tensor, shape `(B, C, H, W)`.
+        gt : torch.Tensor
+            Ground truth tensor, shape `(B, C, H, W)`.
+
+        Returns
+        -------
+        torch.Tensor
+            Scalar tensor representing the calculated MSE loss.
+        """
+        if not self.mask_zero:
+            return ((gt - pred) ** 2).mean()
+
+        valid_mask = gt.any(dim=1, keepdim=True).float()
+        return (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()


 class DetectionLossConfig(BaseConfig):
+    """Configuration for the detection loss component.
+
+    Attributes
+    ----------
+    weight : float, default=1.0
+        Weighting factor for the detection loss in the combined total loss.
+    focal : FocalLossConfig
+        Configuration for the Focal Loss used for detection. Defaults to
+        standard Focal Loss parameters (`alpha=2`, `beta=4`).
+    """
+
     weight: float = 1.0
     focal: FocalLossConfig = Field(default_factory=FocalLossConfig)


 class ClassificationLossConfig(BaseConfig):
+    """Configuration for the classification loss component.
+
+    Attributes
+    ----------
+    weight : float, default=2.0
+        Weighting factor for the classification loss in the combined total loss.
+    focal : FocalLossConfig
+        Configuration for the Focal Loss used for classification. Defaults to
+        standard Focal Loss parameters (`alpha=2`, `beta=4`).
+    """
+
     weight: float = 2.0
     focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
-    class_weights: Optional[list[float]] = None


 class LossConfig(BaseConfig):
+    """Aggregated configuration for all loss components.
+
+    Defines the configuration and weighting for detection, size regression,
+    and classification losses used in the main `LossFunction`.
+
+    Attributes
+    ----------
+    detection : DetectionLossConfig
+        Configuration for the detection loss (Focal Loss).
+    size : SizeLossConfig
+        Configuration for the size regression loss (L1 loss).
+    classification : ClassificationLossConfig
+        Configuration for the classification loss (Focal Loss).
+    """
+
     detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig)
     size: SizeLossConfig = Field(default_factory=SizeLossConfig)
     classification: ClassificationLossConfig = Field(
@@ -117,50 +325,157 @@ class LossConfig(BaseConfig):
     )


-class Losses(NamedTuple):
-    detection: torch.Tensor
-    size: torch.Tensor
-    classification: torch.Tensor
-    total: torch.Tensor
+class LossFunction(nn.Module, LossProtocol):
+    """Computes the combined training loss for the BatDetect2 model.
+
+    Aggregates individual loss functions for detection, size regression, and
+    classification tasks. Calculates each component loss based on model outputs
+    and ground truth targets, applies configured weights, and sums them to get
+    the final total loss used for optimization. Also returns individual
+    components for monitoring.
+
+    Parameters
+    ----------
+    size_loss : nn.Module
+        Instantiated loss module for size regression (e.g., `BBoxLoss`).
+    detection_loss : nn.Module
+        Instantiated loss module for detection (e.g., `FocalLoss`).
+    classification_loss : nn.Module
+        Instantiated loss module for classification (e.g., `FocalLoss`).
+    size_weight : float, default=0.1
+        Weighting factor for the size loss component.
+    detection_weight : float, default=1.0
+        Weighting factor for the detection loss component.
+    classification_weight : float, default=2.0
+        Weighting factor for the classification loss component.
+
+    Attributes
+    ----------
+    size_loss_fn : nn.Module
+    detection_loss_fn : nn.Module
+    classification_loss_fn : nn.Module
+    size_weight : float
+    detection_weight : float
+    classification_weight : float
+    """
+
+    def __init__(
+        self,
+        size_loss: nn.Module,
+        detection_loss: nn.Module,
+        classification_loss: nn.Module,
+        size_weight: float = 0.1,
+        detection_weight: float = 1.0,
+        classification_weight: float = 2.0,
+    ):
+        super().__init__()
+        self.size_loss_fn = size_loss
+        self.detection_loss_fn = detection_loss
+        self.classification_loss_fn = classification_loss
+
+        self.size_weight = size_weight
+        self.detection_weight = detection_weight
+        self.classification_weight = classification_weight
+
+    def forward(
+        self,
+        pred: ModelOutput,
+        gt: TrainExample,
+    ) -> Losses:
+        """Calculate the combined loss and individual components.
+
+        Parameters
+        ----------
+        pred: ModelOutput
+            A NamedTuple containing the model's prediction tensors for the
+            batch: `detection_probs`, `size_preds`, `class_probs`.
+        gt: TrainExample
+            A structure containing the ground truth targets for the batch,
+            expected to have attributes like `detection_heatmap`,
+            `size_heatmap`, and `class_heatmap` (as `torch.Tensor`).
+
+        Returns
+        -------
+        Losses
+            A NamedTuple containing the scalar loss values for detection, size,
+            classification, and the total weighted loss.
+        """
+        size_loss = self.size_loss_fn(pred.size_preds, gt.size_heatmap)
+        detection_loss = self.detection_loss_fn(
+            pred.detection_probs,
+            gt.detection_heatmap,
+        )
+        classification_loss = self.classification_loss_fn(
+            pred.class_probs,
+            gt.class_heatmap,
+        )
+        total_loss = (
+            size_loss * self.size_weight
+            + classification_loss * self.classification_weight
+            + detection_loss * self.detection_weight
+        )
+        return Losses(
+            detection=detection_loss,
+            size=size_loss,
+            classification=classification_loss,
+            total=total_loss,
+        )


-def compute_loss(
-    batch: TrainExample,
-    outputs: ModelOutput,
-    conf: LossConfig,
-    class_weights: Optional[torch.Tensor] = None,
-) -> Losses:
-    detection_loss = focal_loss(
-        outputs.detection_probs,
-        batch.detection_heatmap,
-        beta=conf.detection.focal.beta,
-        alpha=conf.detection.focal.alpha,
-    )
-
-    size_loss = bbox_size_loss(
-        outputs.size_preds,
-        batch.size_heatmap,
-    )
-
-    valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
-    classification_loss = focal_loss(
-        outputs.class_probs,
-        batch.class_heatmap,
-        weights=class_weights,
-        valid_mask=valid_mask,
-        beta=conf.classification.focal.beta,
-        alpha=conf.classification.focal.alpha,
-    )
-
-    total = (
-        detection_loss * conf.detection.weight
-        + size_loss * conf.size.weight
-        + classification_loss * conf.classification.weight
-    )
-
-    return Losses(
-        detection=detection_loss,
-        size=size_loss,
-        classification=classification_loss,
-        total=total,
-    )
+def build_loss(
+    config: Optional[LossConfig] = None,
+    class_weights: Optional[np.ndarray] = None,
+) -> nn.Module:
+    """Factory function to build the main LossFunction from configuration.
+
+    Instantiates the necessary loss components (`BBoxLoss`, `FocalLoss`) based
+    on the provided `LossConfig` (or defaults) and optional `class_weights`,
+    then assembles them into the main `LossFunction` module with the specified
+    component weights.
+
+    Parameters
+    ----------
+    config : LossConfig, optional
+        Configuration object defining weights and parameters (e.g., alpha, beta
+        for Focal Loss) for each loss component. If None, default settings
+        from `LossConfig` and its nested configs are used.
+    class_weights : np.ndarray, optional
+        An array of weights for each specific class, used to adjust the
+        classification loss (typically Focal Loss). If provided, this overrides
+        any `class_weights` specified within `config.classification`. If None,
+        weights from the config (or default of equal weights) are used.
+
+    Returns
+    -------
+    LossFunction
+        An initialized `LossFunction` module ready for training.
+    """
+    config = config or LossConfig()
+
+    class_weights_tensor = (
+        torch.tensor(class_weights) if class_weights else None
+    )
+
+    detection_loss_fn = FocalLoss(
+        beta=config.detection.focal.beta,
+        alpha=config.detection.focal.alpha,
+        mask_zero=False,
+    )
+
+    classification_loss_fn = FocalLoss(
+        beta=config.classification.focal.beta,
+        alpha=config.classification.focal.alpha,
+        class_weights=class_weights_tensor,
+        mask_zero=True,
+    )
+
+    size_loss_fn = BBoxLoss()
+
+    return LossFunction(
+        size_loss=size_loss_fn,
+        classification_loss=classification_loss_fn,
+        detection_loss=detection_loss_fn,
+        size_weight=config.size.weight,
+        detection_weight=config.detection.weight,
+        classification_weight=config.classification.weight,
+    )
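
A hedged sketch of the new loss API with dummy tensors (not part of the diff; shapes and values are illustrative only). With the default weights the total is 1.0 * detection + 0.1 * size + 2.0 * classification, so for component values 0.6, 0.4 and 0.2 the combined loss would be 1.0*0.6 + 0.1*0.4 + 2.0*0.2 = 1.04.

import torch

from batdetect2.models import ModelOutput
from batdetect2.train.losses import build_loss
from batdetect2.train.types import TrainExample

loss_fn = build_loss()  # BBoxLoss + two FocalLoss modules, default weights

# Tiny fake batch: 1 example, 2 classes, 4 x 8 heatmaps (illustrative shapes).
pred = ModelOutput(
    detection_probs=torch.rand(1, 1, 4, 8),
    size_preds=torch.rand(1, 2, 4, 8),
    class_probs=torch.rand(1, 2, 4, 8),
    features=torch.rand(1, 32, 4, 8),
)
gt = TrainExample(
    spec=torch.rand(1, 1, 4, 8),
    detection_heatmap=torch.rand(1, 1, 4, 8),
    class_heatmap=torch.rand(1, 2, 4, 8),
    size_heatmap=torch.rand(1, 2, 4, 8),
    idx=torch.tensor(0),
    start_time=0.0,
    end_time=0.513,
)

losses = loss_fn(pred, gt)
# losses.total combines the components with the configured weights:
# 1.0 * losses.detection + 0.1 * losses.size + 2.0 * losses.classification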
@@ -139,6 +139,7 @@ def _save_xr_dataset_to_file(
     dataset.to_netcdf(
         path,
         encoding={
+            "audio": {"zlib": True},
             "spectrogram": {"zlib": True},
             "size": {"zlib": True},
             "class": {"zlib": True},
@@ -1,12 +1,17 @@
-from typing import Callable, NamedTuple
+from typing import Callable, NamedTuple, Protocol, Tuple

+import torch
 import xarray as xr
 from soundevent import data

+from batdetect2.models import ModelOutput
+
 __all__ = [
     "Heatmaps",
     "ClipLabeller",
     "Augmentation",
+    "LossProtocol",
+    "TrainExample",
 ]
@@ -46,3 +51,50 @@ steps, and returns the final `Heatmaps` used for model training.
 Augmentation = Callable[[xr.Dataset], xr.Dataset]
+
+
+class TrainExample(NamedTuple):
+    spec: torch.Tensor
+    detection_heatmap: torch.Tensor
+    class_heatmap: torch.Tensor
+    size_heatmap: torch.Tensor
+    idx: torch.Tensor
+    start_time: float
+    end_time: float
+
+
+class Losses(NamedTuple):
+    """Structure to hold the computed loss values.
+
+    Allows returning individual loss components along with the total weighted
+    loss for monitoring and analysis during training.
+
+    Attributes
+    ----------
+    detection : torch.Tensor
+        Scalar tensor representing the calculated detection loss component
+        (before weighting).
+    size : torch.Tensor
+        Scalar tensor representing the calculated size regression loss component
+        (before weighting).
+    classification : torch.Tensor
+        Scalar tensor representing the calculated classification loss component
+        (before weighting).
+    total : torch.Tensor
+        Scalar tensor representing the final combined loss, computed as the
+        weighted sum of the detection, size, and classification components.
+        This is the value typically used for backpropagation.
+    """
+
+    detection: torch.Tensor
+    size: torch.Tensor
+    classification: torch.Tensor
+    total: torch.Tensor
+
+
+class LossProtocol(Protocol):
+    def __call__(self, pred: ModelOutput, gt: TrainExample) -> Losses: ...
+
+
+class ClipperProtocol(Protocol):
+    def extract_clip(
+        self, example: xr.Dataset
+    ) -> Tuple[xr.Dataset, float, float]: ...
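
Because `ClipperProtocol` is a structural `Protocol`, any object with a matching `extract_clip` satisfies it. A hedged sketch of a deterministic clipper (hypothetical, not part of the commit) that could be useful for evaluation, built on the `select_subclip` helper added above:

from typing import Tuple

import xarray as xr

from batdetect2.train.clips import select_subclip


class FixedStartClipper:
    """Illustrative ClipperProtocol implementation: always clips from t = 0."""

    def __init__(self, duration: float = 0.513):
        self.duration = duration

    def extract_clip(
        self, example: xr.Dataset
    ) -> Tuple[xr.Dataset, float, float]:
        clip = select_subclip(example, start=0.0, span=self.duration, dim="time")
        clip = select_subclip(clip, start=0.0, span=self.duration, dim="audio_time")
        return clip, 0.0, self.duration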
@@ -24,17 +24,6 @@ from batdetect2.targets.classes import (
 from batdetect2.targets.terms import TagInfo, TermRegistry


-@pytest.fixture
-def sample_term_registry() -> TermRegistry:
-    """Fixture for a sample TermRegistry."""
-    registry = TermRegistry()
-    registry.add_custom_term("order")
-    registry.add_custom_term("species")
-    registry.add_custom_term("call_type")
-    registry.add_custom_term("quality")
-    return registry
-
-
 @pytest.fixture
 def sample_annotation(
     sound_event: data.SoundEvent,
tests/test_train/test_clips.py (new file, 466 lines)
@@ -0,0 +1,466 @@
import numpy as np
import pytest
import xarray as xr

from batdetect2.train.clips import (
    Clipper,
    _compute_expected_width,
    select_subclip,
)

AUDIO_SAMPLERATE = 48000

SPEC_SAMPLERATE = 100
SPEC_FREQS = 64
CLIP_DURATION = 0.5


CLIP_WIDTH_SPEC = int(np.floor(CLIP_DURATION * SPEC_SAMPLERATE))
CLIP_WIDTH_AUDIO = int(np.floor(CLIP_DURATION * AUDIO_SAMPLERATE))
MAX_EMPTY = 0.2


def create_test_dataset(
    duration_sec: float,
    spec_samplerate: int = SPEC_SAMPLERATE,
    audio_samplerate: int = AUDIO_SAMPLERATE,
    num_freqs: int = SPEC_FREQS,
    start_time: float = 0.0,
) -> xr.Dataset:
    """Creates a sample xr.Dataset for testing."""
    time_step = 1 / spec_samplerate
    audio_time_step = 1 / audio_samplerate

    times = np.arange(start_time, start_time + duration_sec, step=time_step)
    freqs = np.linspace(0, audio_samplerate / 2, num_freqs)
    audio_times = np.arange(
        start_time,
        start_time + duration_sec,
        step=audio_time_step,
    )

    num_time_steps = len(times)
    num_audio_samples = len(audio_times)
    spec_shape = (num_freqs, num_time_steps)

    spectrogram_data = np.arange(num_time_steps).reshape(1, -1) * np.ones(
        (num_freqs, 1)
    )

    spectrogram = xr.DataArray(
        spectrogram_data.astype(np.float32),
        coords=[("frequency", freqs), ("time", times)],
        name="spectrogram",
    )

    detection = xr.DataArray(
        np.ones(spec_shape, dtype=np.float32) * 0.5,
        coords=spectrogram.coords,
        name="detection",
    )

    classes = xr.DataArray(
        np.ones((3, *spec_shape), dtype=np.float32),
        coords=[
            ("category", ["A", "B", "C"]),
            ("frequency", freqs),
            ("time", times),
        ],
        name="class",
    )

    size = xr.DataArray(
        np.ones((2, *spec_shape), dtype=np.float32),
        coords=[
            ("dimension", ["height", "width"]),
            ("frequency", freqs),
            ("time", times),
        ],
        name="size",
    )

    audio_data = np.arange(num_audio_samples)
    audio = xr.DataArray(
        audio_data.astype(np.float32),
        coords=[("audio_time", audio_times)],
        name="audio",
    )

    metadata = xr.DataArray([1, 2, 3], dims=["other_dim"], name="metadata")

    return xr.Dataset(
        {
            "audio": audio,
            "spectrogram": spectrogram,
            "detection": detection,
            "class": classes,
            "size": size,
            "metadata": metadata,
        }
    ).assign_attrs(
        samplerate=audio_samplerate,
        spec_samplerate=spec_samplerate,
    )

@pytest.fixture
def long_dataset() -> xr.Dataset:
    """Dataset longer than the clip duration."""
    return create_test_dataset(duration_sec=2.0)


@pytest.fixture
def short_dataset() -> xr.Dataset:
    """Dataset shorter than the clip duration."""
    return create_test_dataset(duration_sec=0.3)


@pytest.fixture
def exact_dataset() -> xr.Dataset:
    """Dataset exactly the clip duration."""
    return create_test_dataset(duration_sec=CLIP_DURATION - 1e-9)


@pytest.fixture
def offset_dataset() -> xr.Dataset:
    """Dataset starting at a non-zero time."""
    return create_test_dataset(duration_sec=1.0, start_time=0.5)

def test_select_subclip_within_bounds(long_dataset):
    start_time = 0.5
    subclip = select_subclip(
        long_dataset, span=CLIP_DURATION, start=start_time, dim="time"
    )
    expected_width = _compute_expected_width(
        long_dataset, CLIP_DURATION, "time"
    )

    assert "time" in subclip.dims
    assert subclip.dims["time"] == expected_width
    assert subclip.spectrogram.dims == ("frequency", "time")
    assert subclip.spectrogram.shape == (SPEC_FREQS, expected_width)
    assert subclip.detection.shape == (SPEC_FREQS, expected_width)
    assert subclip["class"].shape == (3, SPEC_FREQS, expected_width)
    assert subclip.size.shape == (2, SPEC_FREQS, expected_width)
    assert subclip.time.min() >= start_time
    assert (
        subclip.time.max() <= start_time + CLIP_DURATION + 1 / SPEC_SAMPLERATE
    )

    assert "metadata" in subclip
    xr.testing.assert_equal(subclip.metadata, long_dataset.metadata)


def test_select_subclip_pad_start(long_dataset):
    start_time = -0.1
    subclip = select_subclip(
        long_dataset, span=CLIP_DURATION, start=start_time, dim="time"
    )
    expected_width = _compute_expected_width(
        long_dataset, CLIP_DURATION, "time"
    )
    step = 1 / SPEC_SAMPLERATE
    expected_pad_samples = int(np.floor(abs(start_time) / step))

    assert subclip.dims["time"] == expected_width
    assert subclip.spectrogram.shape[1] == expected_width

    assert np.all(
        subclip.spectrogram.isel(time=slice(0, expected_pad_samples)) == 0
    )

    assert np.any(
        subclip.spectrogram.isel(time=slice(expected_pad_samples, None)) != 0
    )
    assert subclip.time.min() >= start_time
    assert subclip.time.max() < start_time + CLIP_DURATION + step


def test_select_subclip_pad_end(long_dataset):
    original_duration = long_dataset.time.max() - long_dataset.time.min()
    start_time = original_duration - 0.1
    subclip = select_subclip(
        long_dataset, span=CLIP_DURATION, start=start_time, dim="time"
    )
    expected_width = _compute_expected_width(
        long_dataset, CLIP_DURATION, "time"
    )
    step = 1 / SPEC_SAMPLERATE
    original_width = long_dataset.dims["time"]
    expected_pad_samples = expected_width - (
        original_width - int(np.floor(start_time / step))
    )

    assert subclip.sizes["time"] == expected_width
    assert subclip.spectrogram.shape[1] == expected_width

    assert np.all(
        subclip.spectrogram.isel(
            time=slice(expected_width - expected_pad_samples, None)
        )
        == 0
    )

    assert np.any(
        subclip.spectrogram.isel(
            time=slice(0, expected_width - expected_pad_samples)
        )
        != 0
    )
    assert subclip.time.min() >= start_time
    assert subclip.time.max() < start_time + CLIP_DURATION + step

def test_select_subclip_pad_both_short_dataset(short_dataset):
    start_time = -0.1
    subclip = select_subclip(
        short_dataset, span=CLIP_DURATION, start=start_time, dim="time"
    )
    expected_width = _compute_expected_width(
        short_dataset, CLIP_DURATION, "time"
    )
    step = 1 / SPEC_SAMPLERATE

    assert subclip.dims["time"] == expected_width
    assert subclip.spectrogram.shape[1] == expected_width

    assert subclip.spectrogram.coords["time"][0] == pytest.approx(
        start_time,
        abs=step,
    )
    assert subclip.spectrogram.coords["time"][-1] == pytest.approx(
        start_time + CLIP_DURATION - step,
        abs=2 * step,
    )


def test_select_subclip_width_consistency(long_dataset):
    expected_width = _compute_expected_width(
        long_dataset, CLIP_DURATION, "time"
    )
    step = 1 / SPEC_SAMPLERATE

    subclip_aligned = select_subclip(
        long_dataset.copy(deep=True),
        span=CLIP_DURATION,
        start=5 * step,
        dim="time",
    )

    subclip_offset = select_subclip(
        long_dataset.copy(deep=True),
        span=CLIP_DURATION,
        start=5.3 * step,
        dim="time",
    )

    assert subclip_aligned.sizes["time"] == expected_width
    assert subclip_offset.sizes["time"] == expected_width
    assert subclip_aligned.spectrogram.shape[1] == expected_width
    assert subclip_offset.spectrogram.shape[1] == expected_width

def test_select_subclip_different_dimension(long_dataset):
    freq_coords = long_dataset.frequency.values
    freq_min, freq_max = freq_coords.min(), freq_coords.max()
    freq_span = (freq_max - freq_min) / 2
    start_freq = freq_min + freq_span / 2

    subclip = select_subclip(
        long_dataset, span=freq_span, start=start_freq, dim="frequency"
    )

    assert "frequency" in subclip.dims
    assert subclip.spectrogram.shape[0] < long_dataset.spectrogram.shape[0]
    assert subclip.detection.shape[0] < long_dataset.detection.shape[0]
    assert subclip["class"].shape[1] < long_dataset["class"].shape[1]
    assert subclip.size.shape[1] < long_dataset.size.shape[1]

    assert subclip.dims["time"] == long_dataset.dims["time"]
    assert subclip.spectrogram.shape[1] == long_dataset.spectrogram.shape[1]

    xr.testing.assert_equal(subclip.audio, long_dataset.audio)
    assert subclip.dims["audio_time"] == long_dataset.dims["audio_time"]


def test_select_subclip_fill_value(short_dataset):
    fill_value = -999.0
    subclip = select_subclip(
        short_dataset,
        span=CLIP_DURATION,
        start=0,
        dim="time",
        fill_value=fill_value,
    )

    expected_width = _compute_expected_width(
        short_dataset,
        CLIP_DURATION,
        "time",
    )

    assert subclip.dims["time"] == expected_width
    assert np.all(subclip.spectrogram.sel(time=slice(0.3, None)) == fill_value)


def test_select_subclip_no_overlap_raises_error(long_dataset):
    original_duration = long_dataset.time.max() - long_dataset.time.min()

    with pytest.raises(ValueError, match="does not overlap"):
        select_subclip(
            long_dataset,
            span=CLIP_DURATION,
            start=original_duration + 1.0,
            dim="time",
        )

    with pytest.raises(ValueError, match="does not overlap"):
        select_subclip(
            long_dataset,
            span=CLIP_DURATION,
            start=-1.0 * CLIP_DURATION - 1.0,
            dim="time",
        )

def test_clipper_non_random(long_dataset, exact_dataset, short_dataset):
    clipper = Clipper(duration=CLIP_DURATION, random=False)

    for ds in [long_dataset, exact_dataset, short_dataset]:
        clip, _, _ = clipper.extract_clip(ds)
        expected_spec_width = _compute_expected_width(
            ds, CLIP_DURATION, "time"
        )
        expected_audio_width = _compute_expected_width(
            ds, CLIP_DURATION, "audio_time"
        )

        assert clip.dims["time"] == expected_spec_width
        assert clip.dims["audio_time"] == expected_audio_width
        assert clip.spectrogram.shape[1] == expected_spec_width
        assert clip.audio.shape[0] == expected_audio_width

        assert clip.time.min() >= -1 / SPEC_SAMPLERATE
        assert clip.audio_time.min() >= -1 / AUDIO_SAMPLERATE

        time_span = clip.time.max() - clip.time.min()
        audio_span = clip.audio_time.max() - clip.audio_time.min()
        assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
        assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)


def test_clipper_random(long_dataset):
    seed = 42
    np.random.seed(seed)
    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
    clip1, _, _ = clipper.extract_clip(long_dataset)

    np.random.seed(seed + 1)
    clip2, _, _ = clipper.extract_clip(long_dataset)

    expected_spec_width = _compute_expected_width(
        long_dataset, CLIP_DURATION, "time"
    )
    expected_audio_width = _compute_expected_width(
        long_dataset, CLIP_DURATION, "audio_time"
    )

    for clip in [clip1, clip2]:
        assert clip.dims["time"] == expected_spec_width
        assert clip.dims["audio_time"] == expected_audio_width
        assert clip.spectrogram.shape[1] == expected_spec_width
        assert clip.audio.shape[0] == expected_audio_width

    assert not np.isclose(clip1.time.min(), clip2.time.min())
    assert not np.isclose(clip1.audio_time.min(), clip2.audio_time.min())

    for clip in [clip1, clip2]:
        time_span = clip.time.max() - clip.time.min()
        audio_span = clip.audio_time.max() - clip.audio_time.min()
        assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
        assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)

    max_start_time = (
        (long_dataset.time.max() - long_dataset.time.min())
        - CLIP_DURATION
        + MAX_EMPTY
    )
    assert clip1.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE
    assert clip2.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE

def test_clipper_random_max_empty_effect(long_dataset):
    """Check that max_empty influences the possible start times."""
    seed = 123
    data_duration = long_dataset.time.max() - long_dataset.time.min()

    np.random.seed(seed)
    clipper0 = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.0)
    max_start_time0 = data_duration - CLIP_DURATION
    start_times0 = []

    for _ in range(20):
        clip, _, _ = clipper0.extract_clip(long_dataset)
        start_times0.append(clip.time.min().item())

    assert all(
        st <= max_start_time0 + 1 / SPEC_SAMPLERATE for st in start_times0
    )
    assert any(st > 0.1 for st in start_times0)

    np.random.seed(seed)
    clipper_pos = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.2)
    max_start_time_pos = data_duration - CLIP_DURATION + 0.2
    start_times_pos = []
    for _ in range(20):
        clip, _, _ = clipper_pos.extract_clip(long_dataset)
        start_times_pos.append(clip.time.min().item())
    assert all(
        st <= max_start_time_pos + 1 / SPEC_SAMPLERATE
        for st in start_times_pos
    )

    assert any(st > max_start_time0 + 1e-6 for st in start_times_pos)

def test_clipper_short_dataset_random(short_dataset):
    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
    clip, _, _ = clipper.extract_clip(short_dataset)

    expected_spec_width = _compute_expected_width(
        short_dataset, CLIP_DURATION, "time"
    )
    expected_audio_width = _compute_expected_width(
        short_dataset, CLIP_DURATION, "audio_time"
    )

    assert clip.sizes["time"] == expected_spec_width
    assert clip.sizes["audio_time"] == expected_audio_width
    assert clip["spectrogram"].shape[1] == expected_spec_width
    assert clip["audio"].shape[0] == expected_audio_width

    assert np.any(clip.spectrogram == 0)
    assert np.any(clip.audio == 0)


def test_clipper_exact_dataset_random(exact_dataset):
    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
    clip, _, _ = clipper.extract_clip(exact_dataset)

    expected_spec_width = _compute_expected_width(
        exact_dataset, CLIP_DURATION, "time"
    )
    expected_audio_width = _compute_expected_width(
        exact_dataset, CLIP_DURATION, "audio_time"
    )

    assert clip.dims["time"] == expected_spec_width
    assert clip.dims["audio_time"] == expected_audio_width
    assert clip.spectrogram.shape[1] == expected_spec_width
    assert clip.audio.shape[0] == expected_audio_width

    time_span = clip.time.max() - clip.time.min()
    audio_span = clip.audio_time.max() - clip.audio_time.min()
    assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
    assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)

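Aside (not part of the commit): these tests pin down the clip-extraction contract that the rest of the training pipeline relies on. `extract_clip` returns a `(dataset, start, end)` tuple of fixed width, and with `random=True` the start time is drawn so that at most `max_empty` seconds of padding can appear at the end of the clip. The sketch below is a rough, illustrative clipper honouring that contract under stated assumptions: it subclips only the `time` dimension (the shipped `Clipper` in `batdetect2.train.clips` also handles `audio_time`), and `RandomClipper` is a hypothetical name, not part of the library.

from typing import Tuple

import numpy as np
import xarray as xr

from batdetect2.train.clips import select_subclip


class RandomClipper:
    """Illustrative clipper with the same extract_clip signature as Clipper."""

    def __init__(self, duration: float, max_empty: float = 0.2):
        self.duration = duration
        self.max_empty = max_empty

    def extract_clip(
        self, example: xr.Dataset
    ) -> Tuple[xr.Dataset, float, float]:
        data_start = float(example.time.min())
        data_end = float(example.time.max())
        # The latest allowed start leaves at most `max_empty` seconds of
        # padding at the end of the extracted clip.
        latest = max(data_end - self.duration + self.max_empty, data_start)
        start = float(np.random.uniform(data_start, latest))
        clip = select_subclip(
            example, span=self.duration, start=start, dim="time"
        )
        return clip, start, start + self.duration
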
tests/test_train/test_lightning.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from pathlib import Path

import lightning as L
import torch
import xarray as xr
from soundevent import data

from batdetect2.models import build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import build_loss


def build_default_module():
    loss = build_loss()
    targets = build_targets()
    detector = build_model(num_classes=len(targets.class_names))
    preprocessor = build_preprocessor()
    postprocessor = build_postprocessor(targets)
    return TrainingModule(
        detector=detector,
        loss=loss,
        targets=targets,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
    )


def test_can_initialize_default_module():
    module = build_default_module()
    assert isinstance(module, L.LightningModule)


def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip):
    module = build_default_module()
    trainer = L.Trainer()
    path = tmp_path / "example.ckpt"
    trainer.strategy.connect(module)
    trainer.save_checkpoint(path)

    recovered = TrainingModule.load_from_checkpoint(path)

    spec1 = module.preprocessor.preprocess_clip(clip)
    spec2 = recovered.preprocessor.preprocess_clip(clip)

    xr.testing.assert_equal(spec1, spec2)

    input1 = torch.tensor([spec1.values]).unsqueeze(0)
    input2 = torch.tensor([spec2.values]).unsqueeze(0)

    output1 = module(input1)
    output2 = recovered(input2)

    assert output1 == output2