Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 22:51:58 +02:00)

Added clips for random clipping and augmentations

This commit is contained in:
parent 2396815c13
commit 59bd14bc79
@ -1,153 +0,0 @@
from pathlib import Path
from typing import Optional

import lightning as L
import torch
from pydantic import Field
from soundevent import data
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from batdetect2.configs import BaseConfig
from batdetect2.models import (
    BackboneConfig,
    BBoxHead,
    ClassifierHead,
    ModelOutput,
    build_backbone,
)
from batdetect2.postprocess import (
    PostprocessConfig,
    postprocess_model_outputs,
)
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
from batdetect2.targets import (
    TargetConfig,
    build_decoder,
    build_target_encoder,
    get_class_names,
)
from batdetect2.train import TrainExample, TrainingConfig, compute_loss

__all__ = [
    "DetectorModel",
]


class ModuleConfig(BaseConfig):
    train: TrainingConfig = Field(default_factory=TrainingConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)
    architecture: BackboneConfig = Field(default_factory=BackboneConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    postprocessing: PostprocessConfig = Field(
        default_factory=PostprocessConfig
    )


class DetectorModel(L.LightningModule):
    config: ModuleConfig

    def __init__(
        self,
        config: Optional[ModuleConfig] = None,
    ):
        super().__init__()

        self.config = config or ModuleConfig()
        self.save_hyperparameters()

        self.backbone = build_model_backbone(self.config.architecture)

        self.classifier = ClassifierHead(
            num_classes=len(self.config.targets.classes),
            in_channels=self.backbone.out_channels,
        )

        self.bbox = BBoxHead(in_channels=self.backbone.out_channels)

        conf = self.config.train.loss.classification
        self.class_weights = (
            torch.tensor(conf.class_weights) if conf.class_weights else None
        )

        # Training targets
        self.class_names = get_class_names(self.config.targets.classes)
        self.encoder = build_target_encoder(
            self.config.targets.classes,
            replacement_rules=self.config.targets.replace,
        )
        self.decoder = build_decoder(self.config.targets.classes)
        self.example_input_array = torch.randn([1, 1, 128, 512])

    def forward(self, spec: torch.Tensor) -> ModelOutput:
        features = self.backbone(spec)
        detection_probs, classification_probs = self.classifier(features)
        size_preds = self.bbox(features)
        return ModelOutput(
            detection_probs=detection_probs,
            size_preds=size_preds,
            class_probs=classification_probs,
            features=features,
        )

    def training_step(self, batch: TrainExample):
        outputs = self.forward(batch.spec)
        losses = compute_loss(
            batch,
            outputs,
            conf=self.config.train.loss,
            class_weights=self.class_weights,
        )

        self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
        self.log("train/loss/detection", losses.total, logger=True)
        self.log("train/loss/size", losses.total, logger=True)
        self.log("train/loss/classification", losses.total, logger=True)

        return losses.total

    def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
        outputs = self.forward(batch.spec)

        losses = compute_loss(
            batch,
            outputs,
            conf=self.config.train.loss,
            class_weights=self.class_weights,
        )

        self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
        self.log("val/loss/detection", losses.total, logger=True)
        self.log("val/loss/size", losses.total, logger=True)
        self.log("val/loss/classification", losses.total, logger=True)

    def on_validation_epoch_end(self) -> None:
        self.validation_predictions.clear()

    def configure_optimizers(self):
        conf = self.config.train.optimizer
        optimizer = Adam(self.parameters(), lr=conf.learning_rate)
        scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
        return [optimizer], [scheduler]

    def process_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[Path] = None,
    ) -> data.ClipPrediction:
        spec = preprocess_audio_clip(
            clip,
            config=self.config.preprocessing,
            audio_dir=audio_dir,
        )
        tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
        outputs = self.forward(tensor)
        return postprocess_model_outputs(
            outputs,
            clips=[clip],
            classes=self.class_names,
            decoder=self.decoder,
            config=self.config.postprocessing,
        )[0]
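For reference, the removed DetectorModel bundled preprocessing, inference, and postprocessing behind a single process_clip call. A rough usage sketch (the audio path and clip bounds are illustrative, not part of this commit):

from soundevent import data

# DetectorModel is the class defined in the file removed above.
model = DetectorModel()  # default ModuleConfig
recording = data.Recording.from_file("example.wav")  # illustrative path
clip = data.Clip(recording=recording, start_time=0.0, end_time=0.5)
prediction = model.process_clip(clip)  # -> data.ClipPrediction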
@ -11,18 +11,19 @@ from batdetect2.train.augmentations import (
    mask_time,
    mix_examples,
    scale_volume,
    select_subclip,
    warp_spectrogram,
)
from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.dataset import (
    LabeledDataset,
    RandomExampleSource,
    SubclipConfig,
    TrainExample,
    get_preprocessed_files,
    list_preprocessed_files,
)
from batdetect2.train.labels import load_label_config
from batdetect2.train.losses import compute_loss
from batdetect2.train.losses import LossFunction, build_loss
from batdetect2.train.preprocess import (
    generate_train_example,
    preprocess_annotations,
@ -34,6 +35,8 @@ __all__ = [
    "EchoAugmentationConfig",
    "FrequencyMaskAugmentationConfig",
    "LabeledDataset",
    "LossFunction",
    "RandomExampleSource",
    "SubclipConfig",
    "TimeMaskAugmentationConfig",
    "TrainExample",
@ -43,9 +46,11 @@ __all__ = [
    "WarpAugmentationConfig",
    "add_echo",
    "build_augmentations",
    "compute_loss",
    "build_clipper",
    "build_loss",
    "generate_train_example",
    "get_preprocessed_files",
    "list_preprocessed_files",
    "load_label_config",
    "load_train_config",
    "load_trainer_config",
    "mask_frequency",
@ -56,5 +61,4 @@ __all__ = [
    "select_subclip",
    "train",
    "warp_spectrogram",
    "load_label_config",
]
@ -37,24 +37,24 @@ from batdetect2.train.types import Augmentation
from batdetect2.utils.arrays import adjust_width

__all__ = [
    "AugmentationConfig",
    "AugmentationsConfig",
    "load_augmentation_config",
    "build_augmentations",
    "select_subclip",
    "mix_examples",
    "add_echo",
    "scale_volume",
    "warp_spectrogram",
    "mask_time",
    "mask_frequency",
    "MixAugmentationConfig",
    "DEFAULT_AUGMENTATION_CONFIG",
    "EchoAugmentationConfig",
    "ExampleSource",
    "FrequencyMaskAugmentationConfig",
    "MixAugmentationConfig",
    "TimeMaskAugmentationConfig",
    "VolumeAugmentationConfig",
    "WarpAugmentationConfig",
    "TimeMaskAugmentationConfig",
    "FrequencyMaskAugmentationConfig",
    "AugmentationConfig",
    "ExampleSource",
    "add_echo",
    "build_augmentations",
    "load_augmentation_config",
    "mask_frequency",
    "mask_time",
    "mix_examples",
    "scale_volume",
    "warp_spectrogram",
]

ExampleSource = Callable[[], xr.Dataset]
@ -64,92 +64,6 @@ Used by the `mix_examples` augmentation to fetch another example to mix with.
"""


def select_subclip(
    example: xr.Dataset,
    start_time: Optional[float] = None,
    duration: Optional[float] = None,
    width: Optional[int] = None,
    random: bool = False,
) -> xr.Dataset:
    """Extract a sub-clip (time segment) from a training example dataset.

    Selects a portion of the 'time' dimension from all relevant DataArrays
    (`audio`, `spectrogram`, `detection`, `class`, `size`) within the example
    Dataset. The segment can be defined by a fixed start time and
    duration/width, or a random start time can be chosen.

    Parameters
    ----------
    example : xr.Dataset
        The input training example containing 'audio', 'spectrogram', and
        target heatmaps, all with compatible 'time' (or 'audio_time')
        coordinates.
    start_time : float, optional
        Desired start time (seconds) of the subclip. If None and `random` is
        False, starts from the beginning of the example. If None and `random`
        is True, a random start time is chosen.
    duration : float, optional
        Desired duration (seconds) of the subclip. Either `duration` or `width`
        must be provided.
    width : int, optional
        Desired width (number of time bins) of the subclip's
        spectrogram/heatmaps. Either `duration` or `width` must be provided. If
        both are given, `duration` takes precedence.
    random : bool, default=False
        If True and `start_time` is None, selects a random start time ensuring
        the subclip fits within the original example's duration.

    Returns
    -------
    xr.Dataset
        A new dataset containing only the selected time segment. Coordinates
        are adjusted accordingly. Returns the original example if the requested
        subclip cannot be extracted (e.g., duration too long).

    Raises
    ------
    ValueError
        If neither `duration` nor `width` is provided, or if time coordinates
        are missing or invalid.
    """
    step = arrays.get_dim_step(example, "time")  # type: ignore
    start, stop = arrays.get_dim_range(example, "time")  # type: ignore

    if width is None:
        if duration is None:
            raise ValueError("Either duration or width must be provided")

        width = int(np.floor(duration / step))

    if duration is None:
        duration = width * step

    if start_time is None:
        if random:
            start_time = np.random.uniform(start, max(stop - duration, start))
        else:
            start_time = start

    if start_time + duration > stop:
        return example

    start_index = arrays.get_coord_index(
        example,  # type: ignore
        "time",
        start_time,
    )

    end_index = start_index + width - 1

    start_time = example.time.values[start_index]
    end_time = example.time.values[end_index]

    return example.sel(
        time=slice(start_time, end_time),
        audio_time=slice(start_time, end_time + step),
    )


class MixAugmentationConfig(BaseConfig):
    """Configuration for MixUp augmentation (mixing two examples)."""

@ -878,9 +792,21 @@ def build_augmentation_from_config(
    )


DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
    steps=[
        MixAugmentationConfig(),
        EchoAugmentationConfig(),
        VolumeAugmentationConfig(),
        WarpAugmentationConfig(),
        TimeMaskAugmentationConfig(),
        FrequencyMaskAugmentationConfig(),
    ]
)


def build_augmentations(
    config: AugmentationsConfig,
    preprocessor: PreprocessorProtocol,
    config: Optional[AugmentationsConfig] = None,
    example_source: Optional[ExampleSource] = None,
) -> Augmentation:
    """Build a composite augmentation pipeline function from configuration.
@ -915,6 +841,8 @@ def build_augmentations(
    NotImplementedError
        If an unknown `augmentation_type` is encountered in `config.steps`.
    """
    config = config or DEFAULT_AUGMENTATION_CONFIG

    augmentations = []

    for step_config in config.steps:
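The hunk above removes select_subclip from the augmentation module and gives build_augmentations a preprocessor argument plus an optional config. A hedged sketch of the new call, assuming the preprocessor has been built elsewhere and that keyword arguments are accepted in this form:

from batdetect2.train.augmentations import (
    DEFAULT_AUGMENTATION_CONFIG,
    build_augmentations,
)

# `preprocessor` is assumed to be a PreprocessorProtocol instance built from
# the preprocessing config elsewhere in the package.
augment = build_augmentations(
    preprocessor,
    config=DEFAULT_AUGMENTATION_CONFIG,  # or None to fall back to the default
    example_source=None,                 # supply a callable to enable mix-up
)
augmented_example = augment(example)     # example: xr.Dataset training example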
batdetect2/train/clips.py (new file, 184 lines)
@ -0,0 +1,184 @@
from typing import Optional, Tuple, Union

import numpy as np
import xarray as xr
from soundevent import arrays

from batdetect2.configs import BaseConfig
from batdetect2.train.types import ClipperProtocol

DEFAULT_TRAIN_CLIP_DURATION = 0.513
DEFAULT_MAX_EMPTY_CLIP = 0.1


class ClipperConfig(BaseConfig):
    duration: float = DEFAULT_TRAIN_CLIP_DURATION
    random: bool = True
    max_empty: float = DEFAULT_MAX_EMPTY_CLIP


class Clipper(ClipperProtocol):
    def __init__(
        self,
        duration: float = 0.5,
        max_empty: float = 0.2,
        random: bool = True,
    ):
        self.duration = duration
        self.random = random
        self.max_empty = max_empty

    def extract_clip(
        self, example: xr.Dataset
    ) -> Tuple[xr.Dataset, float, float]:
        step = arrays.get_dim_step(
            example.spectrogram,
            dim=arrays.Dimensions.time.value,
        )
        duration = (
            arrays.get_dim_width(
                example.spectrogram,
                dim=arrays.Dimensions.time.value,
            )
            + step
        )

        start_time = 0
        if self.random:
            start_time = np.random.uniform(
                -self.max_empty,
                duration - self.duration + self.max_empty,
            )

        subclip = select_subclip(
            example,
            start=start_time,
            span=self.duration,
            dim="time",
        )

        return (
            select_subclip(
                subclip,
                start=start_time,
                span=self.duration,
                dim="audio_time",
            ),
            start_time,
            start_time + self.duration,
        )


def build_clipper(config: Optional[ClipperConfig] = None) -> ClipperProtocol:
    config = config or ClipperConfig()
    return Clipper(
        duration=config.duration,
        max_empty=config.max_empty,
        random=config.random,
    )


def select_subclip(
    dataset: xr.Dataset,
    span: float,
    start: float,
    fill_value: float = 0,
    dim: str = "time",
) -> xr.Dataset:
    width = _compute_expected_width(
        dataset,  # type: ignore
        span,
        dim=dim,
    )

    coord = dataset.coords[dim]

    if len(coord) == width:
        return dataset

    new_coords, start_pad, end_pad, dim_slice = _extract_coordinate(
        coord, start, span
    )

    data_vars = {}
    for name, data_array in dataset.data_vars.items():
        if dim not in data_array.dims:
            data_vars[name] = data_array
            continue

        if width == data_array.sizes[dim]:
            data_vars[name] = data_array
            continue

        sliced = data_array.isel({dim: dim_slice}).data

        if start_pad > 0 or end_pad > 0:
            padding = [
                [0, 0] if other_dim != dim else [start_pad, end_pad]
                for other_dim in data_array.dims
            ]
            sliced = np.pad(sliced, padding, constant_values=fill_value)

        data_vars[name] = xr.DataArray(
            data=sliced,
            dims=data_array.dims,
            coords={**data_array.coords, dim: new_coords},
            attrs=data_array.attrs,
        )

    return xr.Dataset(data_vars=data_vars, attrs=dataset.attrs)


def _extract_coordinate(
    coord: xr.DataArray,
    start: float,
    span: float,
) -> Tuple[xr.Variable, int, int, slice]:
    step = arrays.get_dim_step(coord, str(coord.name))

    current_width = len(coord)
    expected_width = int(np.floor(span / step))

    coord_start = float(coord[0])
    offset = start - coord_start

    start_index = int(np.floor(offset / step))
    end_index = start_index + expected_width

    if start_index > current_width:
        raise ValueError("Requested span does not overlap with current range")

    if end_index < 0:
        raise ValueError("Requested span does not overlap with current range")

    corrected_start = float(start_index * step)
    corrected_end = float(end_index * step)

    start_index_offset = max(0, -start_index)
    end_index_offset = max(0, end_index - current_width)

    sl = slice(
        start_index if start_index >= 0 else None,
        end_index if end_index < current_width else None,
    )

    return (
        arrays.create_range_dim(
            str(coord.name),
            start=corrected_start,
            stop=corrected_end,
            step=step,
        ),
        start_index_offset,
        end_index_offset,
        sl,
    )


def _compute_expected_width(
    array: Union[xr.DataArray, xr.Dataset],
    duration: float,
    dim: str,
) -> int:
    step = arrays.get_dim_step(array, dim)  # type: ignore
    return int(np.floor(duration / step))
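A small usage sketch for the new clipping helpers. The toy dataset below only mimics the time/audio_time layout used by the tests; real training examples also carry detection, class, and size heatmaps:

import numpy as np
import xarray as xr

from batdetect2.train.clips import build_clipper

times = np.arange(0.0, 2.0, 1 / 100)          # 100 Hz spectrogram frames
audio_times = np.arange(0.0, 2.0, 1 / 48000)  # 48 kHz audio samples
example = xr.Dataset(
    {
        "spectrogram": (("frequency", "time"), np.random.rand(64, len(times))),
        "audio": (("audio_time",), np.random.rand(len(audio_times))),
    },
    coords={
        "frequency": np.linspace(0, 24_000, 64),
        "time": times,
        "audio_time": audio_times,
    },
)

clipper = build_clipper()  # defaults: 0.513 s random clips, max_empty=0.1
clip, start, end = clipper.extract_clip(example)
print(clip.sizes["time"], start, end)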
@ -1,6 +1,5 @@
import os
from pathlib import Path
from typing import NamedTuple, Optional, Sequence, Union
from typing import List, Optional, Sequence, Tuple

import numpy as np
import torch
@ -13,28 +12,14 @@ from batdetect2.configs import BaseConfig
from batdetect2.train.augmentations import (
    Augmentation,
    AugmentationsConfig,
    select_subclip,
)
from batdetect2.train.preprocess import PreprocessorProtocol
from batdetect2.utils.tensors import adjust_width
from batdetect2.train.types import ClipperProtocol, TrainExample

__all__ = [
    "TrainExample",
    "LabeledDataset",
]


PathLike = Union[Path, str, os.PathLike]


class TrainExample(NamedTuple):
    spec: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor
    idx: torch.Tensor


class SubclipConfig(BaseConfig):
    duration: Optional[float] = None
    width: int = 512
@ -51,14 +36,12 @@ class DatasetConfig(BaseConfig):
class LabeledDataset(Dataset):
    def __init__(
        self,
        preprocessor: PreprocessorProtocol,
        filenames: Sequence[PathLike],
        subclip: Optional[SubclipConfig] = None,
        filenames: Sequence[data.PathLike],
        clipper: ClipperProtocol,
        augmentation: Optional[Augmentation] = None,
    ):
        self.preprocessor = preprocessor
        self.filenames = filenames
        self.subclip = subclip
        self.clipper = clipper
        self.augmentation = augmentation

    def __len__(self):
@ -66,14 +49,7 @@ class LabeledDataset(Dataset):

    def __getitem__(self, idx) -> TrainExample:
        dataset = self.get_dataset(idx)

        if self.subclip:
            dataset = select_subclip(
                dataset,
                duration=self.subclip.duration,
                width=self.subclip.width,
                random=self.subclip.random,
            )
        dataset, start_time, end_time = self.clipper.extract_clip(dataset)

        if self.augmentation:
            dataset = self.augmentation(dataset)
@ -84,37 +60,31 @@ class LabeledDataset(Dataset):
            class_heatmap=self.to_tensor(dataset["class"]),
            size_heatmap=self.to_tensor(dataset["size"]),
            idx=torch.tensor(idx),
            start_time=start_time,
            end_time=end_time,
        )

    @classmethod
    def from_directory(
        cls,
        directory: PathLike,
        preprocessor: PreprocessorProtocol,
        directory: data.PathLike,
        clipper: ClipperProtocol,
        extension: str = ".nc",
        subclip: Optional[SubclipConfig] = None,
        augmentation: Optional[Augmentation] = None,
    ):
        return cls(
            preprocessor=preprocessor,
            filenames=get_preprocessed_files(directory, extension),
            subclip=subclip,
            filenames=list_preprocessed_files(directory, extension),
            clipper=clipper,
            augmentation=augmentation,
        )

    def get_random_example(self) -> xr.Dataset:
    def get_random_example(self) -> Tuple[xr.Dataset, float, float]:
        idx = np.random.randint(0, len(self))
        dataset = self.get_dataset(idx)

        if self.subclip:
            dataset = select_subclip(
                dataset,
                duration=self.subclip.duration,
                width=self.subclip.width,
                random=self.subclip.random,
            )
        dataset, start_time, end_time = self.clipper.extract_clip(dataset)

        return dataset
        return dataset, start_time, end_time

    def get_dataset(self, idx) -> xr.Dataset:
        return xr.open_dataset(self.filenames[idx])
@ -129,16 +99,23 @@ class LabeledDataset(Dataset):
        array: xr.DataArray,
        dtype=np.float32,
    ) -> torch.Tensor:
        tensor = torch.tensor(array.values.astype(dtype))

        if not self.subclip:
            return tensor

        width = self.subclip.width
        return adjust_width(tensor, width)
        return torch.tensor(array.values.astype(dtype))


def get_preprocessed_files(
    directory: PathLike, extension: str = ".nc"
def list_preprocessed_files(
    directory: data.PathLike, extension: str = ".nc"
) -> Sequence[Path]:
    return list(Path(directory).glob(f"*{extension}"))


class RandomExampleSource:
    def __init__(self, filenames: List[str], clipper: ClipperProtocol):
        self.filenames = filenames
        self.clipper = clipper

    def __call__(self):
        index = int(np.random.randint(len(self.filenames)))
        filename = self.filenames[index]
        dataset = xr.open_dataset(filename)
        example, _, _ = self.clipper.extract_clip(dataset)
        return example
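With the subclip logic moved into the clipper, building a training dataset now looks roughly like this. The directory path is illustrative and the keyword layout follows the mixed old/new lines above, so treat it as a sketch rather than the exact API:

from torch.utils.data import DataLoader

from batdetect2.train.clips import build_clipper
from batdetect2.train.dataset import LabeledDataset

# `preprocessor` is assumed to be a PreprocessorProtocol instance built elsewhere.
dataset = LabeledDataset.from_directory(
    directory="path/to/preprocessed",  # directory of .nc training examples
    preprocessor=preprocessor,
    clipper=build_clipper(),           # 0.513 s random clips by default
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)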
batdetect2/train/lightning.py (new file, 71 lines)
@ -0,0 +1,71 @@
import lightning as L
import torch
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from batdetect2.models import (
    DetectionModel,
    ModelOutput,
)
from batdetect2.postprocess.types import PostprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
from batdetect2.train import TrainExample
from batdetect2.train.types import LossProtocol

__all__ = [
    "TrainingModule",
]


class TrainingModule(L.LightningModule):
    def __init__(
        self,
        detector: DetectionModel,
        loss: LossProtocol,
        targets: TargetProtocol,
        preprocessor: PreprocessorProtocol,
        postprocessor: PostprocessorProtocol,
        learning_rate: float = 0.001,
        t_max: int = 100,
    ):
        super().__init__()

        self.loss = loss
        self.detector = detector
        self.preprocessor = preprocessor
        self.targets = targets
        self.postprocessor = postprocessor

        self.learning_rate = learning_rate
        self.t_max = t_max

        self.save_hyperparameters()

    def forward(self, spec: torch.Tensor) -> ModelOutput:
        return self.detector(spec)

    def training_step(self, batch: TrainExample):
        outputs = self.forward(batch.spec)
        losses = self.loss(outputs, batch)

        self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
        self.log("train/loss/detection", losses.detection, logger=True)
        self.log("train/loss/size", losses.size, logger=True)
        self.log("train/loss/classification", losses.classification, logger=True)

        return losses.total

    def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
        outputs = self.forward(batch.spec)
        losses = self.loss(outputs, batch)

        self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
        self.log("val/loss/detection", losses.detection, logger=True)
        self.log("val/loss/size", losses.size, logger=True)
        self.log("val/loss/classification", losses.classification, logger=True)

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
        return [optimizer], [scheduler]
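A wiring sketch for the new TrainingModule. Only build_loss appears in this diff; the detector, targets, preprocessor, and postprocessor are assumed to come from their own builders elsewhere in the package:

from lightning import Trainer

from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import build_loss

module = TrainingModule(
    detector=detector,            # DetectionModel instance
    loss=build_loss(),            # default LossConfig weights
    targets=targets,              # TargetProtocol instance
    preprocessor=preprocessor,    # PreprocessorProtocol instance
    postprocessor=postprocessor,  # PostprocessorProtocol instance
    learning_rate=1e-3,
    t_max=100,
)

trainer = Trainer(max_epochs=100)
trainer.fit(module, train_dataloaders=train_loader)  # batches of TrainExample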
@ -1,115 +1,323 @@
from typing import NamedTuple, Optional
"""Loss functions and configurations for training BatDetect2 models.

This module defines the loss functions used to train BatDetect2 models,
including individual loss components for different prediction tasks (detection,
classification, size regression) and a main coordinating loss function that
combines them.

It utilizes common loss types like L1 loss (`BBoxLoss`) for regression and
Focal Loss (`FocalLoss`) for handling class imbalance in dense detection and
classification tasks. Configuration objects (`LossConfig`, etc.) allow for easy
customization of loss parameters and weights via configuration files.

The primary entry points are:
- `LossFunction`: An `nn.Module` that computes the weighted sum of individual
  loss components given model outputs and ground truth targets.
- `build_loss`: A factory function that constructs the `LossFunction` based
  on a `LossConfig` object.
- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
"""

from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from pydantic import Field
from torch import nn

from batdetect2.configs import BaseConfig
from batdetect2.models.types import ModelOutput
from batdetect2.train.dataset import TrainExample
from batdetect2.train.types import Losses, LossProtocol

__all__ = [
    "bbox_size_loss",
    "compute_loss",
    "focal_loss",
    "mse_loss",
    "BBoxLoss",
    "ClassificationLossConfig",
    "DetectionLossConfig",
    "FocalLoss",
    "FocalLossConfig",
    "LossConfig",
    "LossFunction",
    "MSELoss",
    "SizeLossConfig",
    "build_loss",
]


class SizeLossConfig(BaseConfig):
    """Configuration for the bounding box size loss component.

    Attributes
    ----------
    weight : float, default=0.1
        The weighting factor applied to the size loss when combining it with
        other losses (detection, classification) to form the total training
        loss.
    """

    weight: float = 0.1


def bbox_size_loss(
    pred_size: torch.Tensor,
    gt_size: torch.Tensor,
) -> torch.Tensor:
class BBoxLoss(nn.Module):
    """Computes L1 loss for bounding box size regression.

    Calculates the Mean Absolute Error (MAE or L1 loss) between the predicted
    size dimensions (`pred`) and the ground truth size dimensions (`gt`).
    Crucially, the loss is only computed at locations where the ground truth
    size heatmap (`gt`) contains non-zero values (i.e., at the reference points
    of actual annotated sound events). This prevents the model from being
    penalized for size predictions in background regions.

    The loss is summed over all valid locations and normalized by the number
    of valid locations.
    """
    Bounding box size loss. Only compute loss where there is a bounding box.
    """
    gt_size_mask = (gt_size > 0).float()
    return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / (
        gt_size_mask.sum() + 1e-5
    )

    def forward(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        """Calculate masked L1 loss for size prediction.

        Parameters
        ----------
        pred : torch.Tensor
            Predicted size tensor, typically shape `(B, 2, H, W)`, where
            channels represent scaled width and height.
        gt : torch.Tensor
            Ground truth size tensor, same shape as `pred`. Non-zero values
            indicate locations and target sizes of actual annotations.

        Returns
        -------
        torch.Tensor
            Scalar tensor representing the calculated masked L1 loss.
        """
        gt_size_mask = (gt > 0).float()
        masked_pred = pred * gt_size_mask
        loss = F.l1_loss(masked_pred, gt, reduction="sum")
        num_pos = gt_size_mask.sum() + 1e-5
        return loss / num_pos


class FocalLossConfig(BaseConfig):
    """Configuration parameters for the Focal Loss function.

    Attributes
    ----------
    beta : float, default=4
        Exponent controlling the down-weighting of easy negative examples.
        Higher values increase down-weighting (focus more on hard negatives).
    alpha : float, default=2
        Exponent controlling the down-weighting based on prediction confidence.
        Higher values focus more on misclassified examples (both positive and
        negative).
    """

    beta: float = 4
    alpha: float = 2


def focal_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    valid_mask: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
    beta: float = 4,
    alpha: float = 2,
) -> torch.Tensor:
    """
    Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
    pred (batch x c x h x w)
    gt (batch x c x h x w)
class FocalLoss(nn.Module):
    """Focal Loss implementation, adapted from CornerNet.

    Addresses class imbalance in dense object detection/classification tasks by
    down-weighting the loss contribution from easy examples (both positive and
    negative), allowing the model to focus more on hard-to-classify examples.

    Parameters
    ----------
    eps : float, default=1e-5
        Small epsilon value added for numerical stability.
    beta : float, default=4
        Exponent focusing on hard negative examples (modulates `(1-gt)^beta`).
    alpha : float, default=2
        Exponent focusing on misclassified examples (modulates `(1-p)^alpha`
        for positives and `p^alpha` for negatives).
    class_weights : torch.Tensor, optional
        Optional tensor containing weights for each class (applied to positive
        loss). Shape should be broadcastable to the channel dimension of the
        input tensors.
    mask_zero : bool, default=False
        If True, ignores loss contributions from spatial locations where the
        ground truth `gt` tensor is zero across *all* channels. Useful for
        classification heatmaps where some areas might have no assigned class.

    References
    ----------
    - Lin, T. Y., et al. "Focal loss for dense object detection." ICCV 2017.
    - Law, H., & Deng, J. "CornerNet: Detecting Objects as Paired Keypoints."
      ECCV 2018.
    """

    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()
    def __init__(
        self,
        eps: float = 1e-5,
        beta: float = 4,
        alpha: float = 2,
        class_weights: Optional[torch.Tensor] = None,
        mask_zero: bool = False,
    ):
        super().__init__()
        self.class_weights = class_weights
        self.eps = eps
        self.beta = beta
        self.alpha = alpha
        self.mask_zero = mask_zero

    pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
    neg_loss = (
        torch.log(1 - pred + eps)
        * torch.pow(pred, alpha)
        * torch.pow(1 - gt, beta)
        * neg_inds
    )
    def forward(
        self,
        pred: torch.Tensor,
        gt: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the Focal Loss.

    if weights is not None:
        pos_loss = pos_loss * torch.tensor(weights)
        # neg_loss = neg_loss*weights
        Parameters
        ----------
        pred : torch.Tensor
            Predicted probabilities or logits (typically sigmoid output for
            detection, or softmax/sigmoid for classification). Must be in the
            range [0, 1] after potential activation. Shape `(B, C, H, W)`.
        gt : torch.Tensor
            Ground truth heatmap tensor. Shape `(B, C, H, W)`. Values typically
            represent target probabilities (e.g., Gaussian peaks for detection,
            one-hot encoding or smoothed labels for classification). For the
            adapted CornerNet loss, `gt=1` indicates a positive location, and
            values `< 1` indicate negative locations (with potential Gaussian
            weighting `(1-gt)^beta` for negatives near positives).

    if valid_mask is not None:
        pos_loss = pos_loss * valid_mask
        neg_loss = neg_loss * valid_mask
        Returns
        -------
        torch.Tensor
            Scalar tensor representing the computed focal loss, normalized by
            the number of positive locations.
        """

    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()

    num_pos = pos_inds.float().sum()
    if num_pos == 0:
        loss = -neg_loss
    else:
        loss = -(pos_loss + neg_loss) / num_pos
    return loss
        pos_loss = (
            torch.log(pred + self.eps)
            * torch.pow(1 - pred, self.alpha)
            * pos_inds
        )
        neg_loss = (
            torch.log(1 - pred + self.eps)
            * torch.pow(pred, self.alpha)
            * torch.pow(1 - gt, self.beta)
            * neg_inds
        )

        if self.class_weights is not None:
            pos_loss = pos_loss * torch.tensor(self.class_weights)

        if self.mask_zero:
            valid_mask = gt.any(dim=1, keepdim=True).float()
            pos_loss = pos_loss * valid_mask
            neg_loss = neg_loss * valid_mask

        pos_loss = pos_loss.sum()
        neg_loss = neg_loss.sum()

        num_pos = pos_inds.float().sum()
        if num_pos == 0:
            loss = -neg_loss
        else:
            loss = -(pos_loss + neg_loss) / num_pos
        return loss


def mse_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    valid_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
class MSELoss(nn.Module):
    """Mean Squared Error (MSE) Loss module.

    Calculates the mean squared difference between predictions and ground
    truth. Optionally masks contributions where the ground truth is zero across
    channels.

    Parameters
    ----------
    mask_zero : bool, default=False
        If True, calculates the loss only over spatial locations (H, W) where
        at least one channel in the ground truth `gt` tensor is non-zero. The
        loss is then averaged over these valid locations. If False (default),
        the standard MSE over all elements is computed.
    """
    Mean squared error loss.
    """
    if valid_mask is None:
        op = ((gt - pred) ** 2).mean()
    else:
        op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
    return op

    def __init__(self, mask_zero: bool = False):
        super().__init__()
        self.mask_zero = mask_zero

    def forward(
        self,
        pred: torch.Tensor,
        gt: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the Mean Squared Error loss.

        Parameters
        ----------
        pred : torch.Tensor
            Predicted tensor, shape `(B, C, H, W)`.
        gt : torch.Tensor
            Ground truth tensor, shape `(B, C, H, W)`.

        Returns
        -------
        torch.Tensor
            Scalar tensor representing the calculated MSE loss.
        """
        if not self.mask_zero:
            return ((gt - pred) ** 2).mean()

        valid_mask = gt.any(dim=1, keepdim=True).float()
        return (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()


class DetectionLossConfig(BaseConfig):
    """Configuration for the detection loss component.

    Attributes
    ----------
    weight : float, default=1.0
        Weighting factor for the detection loss in the combined total loss.
    focal : FocalLossConfig
        Configuration for the Focal Loss used for detection. Defaults to
        standard Focal Loss parameters (`alpha=2`, `beta=4`).
    """

    weight: float = 1.0
    focal: FocalLossConfig = Field(default_factory=FocalLossConfig)


class ClassificationLossConfig(BaseConfig):
    """Configuration for the classification loss component.

    Attributes
    ----------
    weight : float, default=2.0
        Weighting factor for the classification loss in the combined total loss.
    focal : FocalLossConfig
        Configuration for the Focal Loss used for classification. Defaults to
        standard Focal Loss parameters (`alpha=2`, `beta=4`).
    """

    weight: float = 2.0
    focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
    class_weights: Optional[list[float]] = None


class LossConfig(BaseConfig):
    """Aggregated configuration for all loss components.

    Defines the configuration and weighting for detection, size regression,
    and classification losses used in the main `LossFunction`.

    Attributes
    ----------
    detection : DetectionLossConfig
        Configuration for the detection loss (Focal Loss).
    size : SizeLossConfig
        Configuration for the size regression loss (L1 loss).
    classification : ClassificationLossConfig
        Configuration for the classification loss (Focal Loss).
    """

    detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig)
    size: SizeLossConfig = Field(default_factory=SizeLossConfig)
    classification: ClassificationLossConfig = Field(
@ -117,50 +325,157 @@ class LossConfig(BaseConfig):
    )


class Losses(NamedTuple):
    detection: torch.Tensor
    size: torch.Tensor
    classification: torch.Tensor
    total: torch.Tensor
class LossFunction(nn.Module, LossProtocol):
    """Computes the combined training loss for the BatDetect2 model.

    Aggregates individual loss functions for detection, size regression, and
    classification tasks. Calculates each component loss based on model outputs
    and ground truth targets, applies configured weights, and sums them to get
    the final total loss used for optimization. Also returns individual
    components for monitoring.

    Parameters
    ----------
    size_loss : nn.Module
        Instantiated loss module for size regression (e.g., `BBoxLoss`).
    detection_loss : nn.Module
        Instantiated loss module for detection (e.g., `FocalLoss`).
    classification_loss : nn.Module
        Instantiated loss module for classification (e.g., `FocalLoss`).
    size_weight : float, default=0.1
        Weighting factor for the size loss component.
    detection_weight : float, default=1.0
        Weighting factor for the detection loss component.
    classification_weight : float, default=2.0
        Weighting factor for the classification loss component.

    Attributes
    ----------
    size_loss_fn : nn.Module
    detection_loss_fn : nn.Module
    classification_loss_fn : nn.Module
    size_weight : float
    detection_weight : float
    classification_weight : float
    """

    def __init__(
        self,
        size_loss: nn.Module,
        detection_loss: nn.Module,
        classification_loss: nn.Module,
        size_weight: float = 0.1,
        detection_weight: float = 1.0,
        classification_weight: float = 2.0,
    ):
        super().__init__()
        self.size_loss_fn = size_loss
        self.detection_loss_fn = detection_loss
        self.classification_loss_fn = classification_loss

        self.size_weight = size_weight
        self.detection_weight = detection_weight
        self.classification_weight = classification_weight

    def forward(
        self,
        pred: ModelOutput,
        gt: TrainExample,
    ) -> Losses:
        """Calculate the combined loss and individual components.

        Parameters
        ----------
        pred: ModelOutput
            A NamedTuple containing the model's prediction tensors for the
            batch: `detection_probs`, `size_preds`, `class_probs`.
        gt: TrainExample
            A structure containing the ground truth targets for the batch,
            expected to have attributes like `detection_heatmap`,
            `size_heatmap`, and `class_heatmap` (as `torch.Tensor`).

        Returns
        -------
        Losses
            A NamedTuple containing the scalar loss values for detection, size,
            classification, and the total weighted loss.
        """
        size_loss = self.size_loss_fn(pred.size_preds, gt.size_heatmap)
        detection_loss = self.detection_loss_fn(
            pred.detection_probs,
            gt.detection_heatmap,
        )
        classification_loss = self.classification_loss_fn(
            pred.class_probs,
            gt.class_heatmap,
        )
        total_loss = (
            size_loss * self.size_weight
            + classification_loss * self.classification_weight
            + detection_loss * self.detection_weight
        )
        return Losses(
            detection=detection_loss,
            size=size_loss,
            classification=classification_loss,
            total=total_loss,
        )


def compute_loss(
    batch: TrainExample,
    outputs: ModelOutput,
    conf: LossConfig,
    class_weights: Optional[torch.Tensor] = None,
) -> Losses:
    detection_loss = focal_loss(
        outputs.detection_probs,
        batch.detection_heatmap,
        beta=conf.detection.focal.beta,
        alpha=conf.detection.focal.alpha,
def build_loss(
    config: Optional[LossConfig] = None,
    class_weights: Optional[np.ndarray] = None,
) -> nn.Module:
    """Factory function to build the main LossFunction from configuration.

    Instantiates the necessary loss components (`BBoxLoss`, `FocalLoss`) based
    on the provided `LossConfig` (or defaults) and optional `class_weights`,
    then assembles them into the main `LossFunction` module with the specified
    component weights.

    Parameters
    ----------
    config : LossConfig, optional
        Configuration object defining weights and parameters (e.g., alpha, beta
        for Focal Loss) for each loss component. If None, default settings
        from `LossConfig` and its nested configs are used.
    class_weights : np.ndarray, optional
        An array of weights for each specific class, used to adjust the
        classification loss (typically Focal Loss). If provided, this overrides
        any `class_weights` specified within `config.classification`. If None,
        weights from the config (or default of equal weights) are used.

    Returns
    -------
    LossFunction
        An initialized `LossFunction` module ready for training.
    """
    config = config or LossConfig()

    class_weights_tensor = (
        torch.tensor(class_weights) if class_weights else None
    )

    size_loss = bbox_size_loss(
        outputs.size_preds,
        batch.size_heatmap,
    detection_loss_fn = FocalLoss(
        beta=config.detection.focal.beta,
        alpha=config.detection.focal.alpha,
        mask_zero=False,
    )

    valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
    classification_loss = focal_loss(
        outputs.class_probs,
        batch.class_heatmap,
        weights=class_weights,
        valid_mask=valid_mask,
        beta=conf.classification.focal.beta,
        alpha=conf.classification.focal.alpha,
    classification_loss_fn = FocalLoss(
        beta=config.classification.focal.beta,
        alpha=config.classification.focal.alpha,
        class_weights=class_weights_tensor,
        mask_zero=True,
    )

    total = (
        detection_loss * conf.detection.weight
        + size_loss * conf.size.weight
        + classification_loss * conf.classification.weight
    )
    size_loss_fn = BBoxLoss()

    return Losses(
        detection=detection_loss,
        size=size_loss,
        classification=classification_loss,
        total=total,
    return LossFunction(
        size_loss=size_loss_fn,
        classification_loss=classification_loss_fn,
        detection_loss=detection_loss_fn,
        size_weight=config.size.weight,
        detection_weight=config.detection.weight,
        classification_weight=config.classification.weight,
    )
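A minimal sketch exercising the new loss stack with dummy tensors; the channel counts are arbitrary and the field names follow the ModelOutput and TrainExample definitions elsewhere in this diff:

import torch

from batdetect2.models import ModelOutput
from batdetect2.train.losses import build_loss
from batdetect2.train.types import TrainExample

loss_fn = build_loss()  # defaults: detection 1.0, size 0.1, classification 2.0

pred = ModelOutput(
    detection_probs=torch.rand(2, 1, 64, 128),
    size_preds=torch.rand(2, 2, 64, 128),
    class_probs=torch.rand(2, 6, 64, 128),
    features=torch.rand(2, 32, 64, 128),
)
gt = TrainExample(
    spec=torch.rand(2, 1, 64, 128),
    detection_heatmap=torch.rand(2, 1, 64, 128),
    class_heatmap=torch.rand(2, 6, 64, 128),
    size_heatmap=torch.rand(2, 2, 64, 128),
    idx=torch.tensor([0, 1]),
    start_time=0.0,
    end_time=0.5,
)

losses = loss_fn(pred, gt)
print(losses.total, losses.detection, losses.size, losses.classification)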
@ -139,6 +139,7 @@ def _save_xr_dataset_to_file(
    dataset.to_netcdf(
        path,
        encoding={
            "audio": {"zlib": True},
            "spectrogram": {"zlib": True},
            "size": {"zlib": True},
            "class": {"zlib": True},
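This hunk adjusts the per-variable compression used when training examples are written to netCDF. Saving a dataset manually with the same encoding looks like this (the output path is illustrative):

# `example` is an xr.Dataset training example, e.g. from generate_train_example.
example.to_netcdf(
    "example.nc",
    encoding={
        "audio": {"zlib": True},
        "spectrogram": {"zlib": True},
        "size": {"zlib": True},
        "class": {"zlib": True},
    },
)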
@ -1,12 +1,17 @@
from typing import Callable, NamedTuple
from typing import Callable, NamedTuple, Protocol, Tuple

import torch
import xarray as xr
from soundevent import data

from batdetect2.models import ModelOutput

__all__ = [
    "Heatmaps",
    "ClipLabeller",
    "Augmentation",
    "LossProtocol",
    "TrainExample",
]


@ -46,3 +51,50 @@ steps, and returns the final `Heatmaps` used for model training.
Augmentation = Callable[[xr.Dataset], xr.Dataset]


class TrainExample(NamedTuple):
    spec: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor
    idx: torch.Tensor
    start_time: float
    end_time: float


class Losses(NamedTuple):
    """Structure to hold the computed loss values.

    Allows returning individual loss components along with the total weighted
    loss for monitoring and analysis during training.

    Attributes
    ----------
    detection : torch.Tensor
        Scalar tensor representing the calculated detection loss component
        (before weighting).
    size : torch.Tensor
        Scalar tensor representing the calculated size regression loss component
        (before weighting).
    classification : torch.Tensor
        Scalar tensor representing the calculated classification loss component
        (before weighting).
    total : torch.Tensor
        Scalar tensor representing the final combined loss, computed as the
        weighted sum of the detection, size, and classification components.
        This is the value typically used for backpropagation.
    """

    detection: torch.Tensor
    size: torch.Tensor
    classification: torch.Tensor
    total: torch.Tensor


class LossProtocol(Protocol):
    def __call__(self, pred: ModelOutput, gt: TrainExample) -> Losses: ...


class ClipperProtocol(Protocol):
    def extract_clip(
        self, example: xr.Dataset
    ) -> Tuple[xr.Dataset, float, float]: ...
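ClipperProtocol only requires an extract_clip method that returns the clipped dataset together with its start and end time, so custom clipping strategies can be dropped in. A minimal, illustrative implementation that returns the example unchanged:

from typing import Tuple

import xarray as xr


class IdentityClipper:
    """Trivial clipper: returns the whole example with its time bounds."""

    def extract_clip(
        self, example: xr.Dataset
    ) -> Tuple[xr.Dataset, float, float]:
        start = float(example.time.values[0])
        end = float(example.time.values[-1])
        return example, start, end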
@ -24,17 +24,6 @@ from batdetect2.targets.classes import (
from batdetect2.targets.terms import TagInfo, TermRegistry


@pytest.fixture
def sample_term_registry() -> TermRegistry:
    """Fixture for a sample TermRegistry."""
    registry = TermRegistry()
    registry.add_custom_term("order")
    registry.add_custom_term("species")
    registry.add_custom_term("call_type")
    registry.add_custom_term("quality")
    return registry


@pytest.fixture
def sample_annotation(
    sound_event: data.SoundEvent,
466
tests/test_train/test_clips.py
Normal file
466
tests/test_train/test_clips.py
Normal file
@ -0,0 +1,466 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import xarray as xr
|
||||
|
||||
from batdetect2.train.clips import (
|
||||
Clipper,
|
||||
_compute_expected_width,
|
||||
select_subclip,
|
||||
)
|
||||
|
||||
AUDIO_SAMPLERATE = 48000
|
||||
|
||||
SPEC_SAMPLERATE = 100
|
||||
SPEC_FREQS = 64
|
||||
CLIP_DURATION = 0.5
|
||||
|
||||
|
||||
CLIP_WIDTH_SPEC = int(np.floor(CLIP_DURATION * SPEC_SAMPLERATE))
|
||||
CLIP_WIDTH_AUDIO = int(np.floor(CLIP_DURATION * AUDIO_SAMPLERATE))
|
||||
MAX_EMPTY = 0.2
|
||||
|
||||
|
||||
def create_test_dataset(
|
||||
duration_sec: float,
|
||||
spec_samplerate: int = SPEC_SAMPLERATE,
|
||||
audio_samplerate: int = AUDIO_SAMPLERATE,
|
||||
num_freqs: int = SPEC_FREQS,
|
||||
start_time: float = 0.0,
|
||||
) -> xr.Dataset:
|
||||
"""Creates a sample xr.Dataset for testing."""
|
||||
time_step = 1 / spec_samplerate
|
||||
audio_time_step = 1 / audio_samplerate
|
||||
|
||||
times = np.arange(start_time, start_time + duration_sec, step=time_step)
|
||||
freqs = np.linspace(0, audio_samplerate / 2, num_freqs)
|
||||
audio_times = np.arange(
|
||||
start_time,
|
||||
start_time + duration_sec,
|
||||
step=audio_time_step,
|
||||
)
|
||||
|
||||
num_time_steps = len(times)
|
||||
num_audio_samples = len(audio_times)
|
||||
spec_shape = (num_freqs, num_time_steps)
|
||||
|
||||
spectrogram_data = np.arange(num_time_steps).reshape(1, -1) * np.ones(
|
||||
(num_freqs, 1)
|
||||
)
|
||||
|
||||
spectrogram = xr.DataArray(
|
||||
spectrogram_data.astype(np.float32),
|
||||
coords=[("frequency", freqs), ("time", times)],
|
||||
name="spectrogram",
|
||||
)
|
||||
|
||||
detection = xr.DataArray(
|
||||
np.ones(spec_shape, dtype=np.float32) * 0.5,
|
||||
coords=spectrogram.coords,
|
||||
name="detection",
|
||||
)
|
||||
|
||||
classes = xr.DataArray(
|
||||
np.ones((3, *spec_shape), dtype=np.float32),
|
||||
coords=[
|
||||
("category", ["A", "B", "C"]),
|
||||
("frequency", freqs),
|
||||
("time", times),
|
||||
],
|
||||
name="class",
|
||||
)
|
||||
|
||||
size = xr.DataArray(
|
||||
np.ones((2, *spec_shape), dtype=np.float32),
|
||||
coords=[
|
||||
("dimension", ["height", "width"]),
|
||||
("frequency", freqs),
|
||||
("time", times),
|
||||
],
|
||||
name="size",
|
||||
)
|
||||
|
||||
audio_data = np.arange(num_audio_samples)
|
||||
audio = xr.DataArray(
|
||||
audio_data.astype(np.float32),
|
||||
coords=[("audio_time", audio_times)],
|
||||
name="audio",
|
||||
)
|
||||
|
||||
metadata = xr.DataArray([1, 2, 3], dims=["other_dim"], name="metadata")
|
||||
|
||||
return xr.Dataset(
|
||||
{
|
||||
"audio": audio,
|
||||
"spectrogram": spectrogram,
|
||||
"detection": detection,
|
||||
"class": classes,
|
||||
"size": size,
|
||||
"metadata": metadata,
|
||||
}
|
||||
).assign_attrs(
|
||||
samplerate=audio_samplerate,
|
||||
spec_samplerate=spec_samplerate,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_dataset() -> xr.Dataset:
|
||||
"""Dataset longer than the clip duration."""
|
||||
return create_test_dataset(duration_sec=2.0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_dataset() -> xr.Dataset:
|
||||
"""Dataset shorter than the clip duration."""
|
||||
return create_test_dataset(duration_sec=0.3)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def exact_dataset() -> xr.Dataset:
|
||||
"""Dataset exactly the clip duration."""
|
||||
return create_test_dataset(duration_sec=CLIP_DURATION - 1e-9)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def offset_dataset() -> xr.Dataset:
|
||||
"""Dataset starting at a non-zero time."""
|
||||
return create_test_dataset(duration_sec=1.0, start_time=0.5)
|
||||
|
||||
|
||||
def test_select_subclip_within_bounds(long_dataset):
|
||||
start_time = 0.5
|
||||
subclip = select_subclip(
|
||||
long_dataset, span=CLIP_DURATION, start=start_time, dim="time"
|
||||
)
|
||||
expected_width = _compute_expected_width(
|
||||
long_dataset, CLIP_DURATION, "time"
|
||||
)
|
||||
|
||||
assert "time" in subclip.dims
|
||||
assert subclip.dims["time"] == expected_width
|
||||
assert subclip.spectrogram.dims == ("frequency", "time")
|
||||
assert subclip.spectrogram.shape == (SPEC_FREQS, expected_width)
|
||||
assert subclip.detection.shape == (SPEC_FREQS, expected_width)
|
||||
assert subclip["class"].shape == (3, SPEC_FREQS, expected_width)
|
||||
assert subclip.size.shape == (2, SPEC_FREQS, expected_width)
|
||||
assert subclip.time.min() >= start_time
|
||||
assert (
|
||||
subclip.time.max() <= start_time + CLIP_DURATION + 1 / SPEC_SAMPLERATE
|
||||
)
|
||||
|
||||
assert "metadata" in subclip
|
||||
xr.testing.assert_equal(subclip.metadata, long_dataset.metadata)
|
||||
|
||||
|
||||
def test_select_subclip_pad_start(long_dataset):
|
||||
start_time = -0.1
|
||||
subclip = select_subclip(
|
||||
long_dataset, span=CLIP_DURATION, start=start_time, dim="time"
|
||||
)
|
||||
expected_width = _compute_expected_width(
|
||||
long_dataset, CLIP_DURATION, "time"
|
||||
)
|
||||
step = 1 / SPEC_SAMPLERATE
|
||||
expected_pad_samples = int(np.floor(abs(start_time) / step))
|
||||
|
||||
assert subclip.dims["time"] == expected_width
|
||||
assert subclip.spectrogram.shape[1] == expected_width
|
||||
|
||||
assert np.all(
|
||||
subclip.spectrogram.isel(time=slice(0, expected_pad_samples)) == 0
|
||||
)
|
||||
|
||||
assert np.any(
|
||||
subclip.spectrogram.isel(time=slice(expected_pad_samples, None)) != 0
|
||||
)
|
||||
assert subclip.time.min() >= start_time
|
||||
assert subclip.time.max() < start_time + CLIP_DURATION + step
|
||||
|
||||
|
||||
def test_select_subclip_pad_end(long_dataset):
|
||||
original_duration = long_dataset.time.max() - long_dataset.time.min()
|
||||
start_time = original_duration - 0.1
|
||||
subclip = select_subclip(
|
||||
long_dataset, span=CLIP_DURATION, start=start_time, dim="time"
|
||||
)
|
||||
expected_width = _compute_expected_width(
|
||||
long_dataset, CLIP_DURATION, "time"
|
||||
)
|
||||
step = 1 / SPEC_SAMPLERATE
|
||||
original_width = long_dataset.dims["time"]
|
||||
expected_pad_samples = expected_width - (
|
||||
original_width - int(np.floor(start_time / step))
|
||||
)
|
||||
|
||||
assert subclip.sizes["time"] == expected_width
|
||||
assert subclip.spectrogram.shape[1] == expected_width
|
||||
|
||||
assert np.all(
|
||||
subclip.spectrogram.isel(
|
||||
time=slice(expected_width - expected_pad_samples, None)
|
||||
)
|
||||
== 0
|
||||
)
|
||||
|
||||
assert np.any(
|
||||
subclip.spectrogram.isel(
|
||||
time=slice(0, expected_width - expected_pad_samples)
|
||||
)
|
||||
!= 0
|
||||
)
|
||||
assert subclip.time.min() >= start_time
|
||||
assert subclip.time.max() < start_time + CLIP_DURATION + step
|
||||
|
||||
|
||||
def test_select_subclip_pad_both_short_dataset(short_dataset):
|
||||
start_time = -0.1
|
||||
subclip = select_subclip(
|
||||
short_dataset, span=CLIP_DURATION, start=start_time, dim="time"
|
||||
)
|
||||
expected_width = _compute_expected_width(
|
||||
short_dataset, CLIP_DURATION, "time"
|
||||
)
|
||||
step = 1 / SPEC_SAMPLERATE
|
||||
|
||||
assert subclip.dims["time"] == expected_width
|
||||
assert subclip.spectrogram.shape[1] == expected_width
|
||||
|
||||
assert subclip.spectrogram.coords["time"][0] == pytest.approx(
|
||||
start_time,
|
||||
abs=step,
|
||||
)
|
||||
assert subclip.spectrogram.coords["time"][-1] == pytest.approx(
|
||||
start_time + CLIP_DURATION - step,
|
||||
abs=2 * step,
|
||||
)
|
||||
|
||||
|
||||
def test_select_subclip_width_consistency(long_dataset):
|
||||
expected_width = _compute_expected_width(
|
||||
long_dataset, CLIP_DURATION, "time"
|
||||
)
|
||||
step = 1 / SPEC_SAMPLERATE
|
||||
|
||||
subclip_aligned = select_subclip(
|
||||
long_dataset.copy(deep=True),
|
||||
span=CLIP_DURATION,
|
||||
start=5 * step,
|
||||
dim="time",
|
||||
)
|
||||
|
||||
subclip_offset = select_subclip(
|
||||
long_dataset.copy(deep=True),
|
||||
span=CLIP_DURATION,
|
||||
start=5.3 * step,
|
||||
dim="time",
|
||||
)
    assert subclip_aligned.sizes["time"] == expected_width
    assert subclip_offset.sizes["time"] == expected_width
    assert subclip_aligned.spectrogram.shape[1] == expected_width
    assert subclip_offset.spectrogram.shape[1] == expected_width


def test_select_subclip_different_dimension(long_dataset):
    freq_coords = long_dataset.frequency.values
    freq_min, freq_max = freq_coords.min(), freq_coords.max()
    freq_span = (freq_max - freq_min) / 2
    start_freq = freq_min + freq_span / 2

    subclip = select_subclip(
        long_dataset, span=freq_span, start=start_freq, dim="frequency"
    )
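
    # Cropping along frequency should shrink every frequency-indexed array
    # but leave the time axis, and the audio waveform, completely untouched.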
    assert "frequency" in subclip.dims
    assert subclip.spectrogram.shape[0] < long_dataset.spectrogram.shape[0]
    assert subclip.detection.shape[0] < long_dataset.detection.shape[0]
    assert subclip["class"].shape[1] < long_dataset["class"].shape[1]
    assert subclip.size.shape[1] < long_dataset.size.shape[1]

    assert subclip.sizes["time"] == long_dataset.sizes["time"]
    assert subclip.spectrogram.shape[1] == long_dataset.spectrogram.shape[1]

    xr.testing.assert_equal(subclip.audio, long_dataset.audio)
    assert subclip.sizes["audio_time"] == long_dataset.sizes["audio_time"]


def test_select_subclip_fill_value(short_dataset):
    fill_value = -999.0
    subclip = select_subclip(
        short_dataset,
        span=CLIP_DURATION,
        start=0,
        dim="time",
        fill_value=fill_value,
    )

    expected_width = _compute_expected_width(
        short_dataset,
        CLIP_DURATION,
        "time",
    )
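
    # The short fixture ends before CLIP_DURATION, so the padded tail of the
    # window (checked here from 0.3 s onwards, past the fixture's data)
    # should carry the custom fill_value rather than the default of zero.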
    assert subclip.sizes["time"] == expected_width
    assert np.all(subclip.spectrogram.sel(time=slice(0.3, None)) == fill_value)


def test_select_subclip_no_overlap_raises_error(long_dataset):
    original_duration = long_dataset.time.max() - long_dataset.time.min()

    with pytest.raises(ValueError, match="does not overlap"):
        select_subclip(
            long_dataset,
            span=CLIP_DURATION,
            start=original_duration + 1.0,
            dim="time",
        )

    with pytest.raises(ValueError, match="does not overlap"):
        select_subclip(
            long_dataset,
            span=CLIP_DURATION,
            start=-1.0 * CLIP_DURATION - 1.0,
            dim="time",
        )


def test_clipper_non_random(long_dataset, exact_dataset, short_dataset):
    clipper = Clipper(duration=CLIP_DURATION, random=False)
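
    # With random=False the clipper should behave deterministically: each
    # clip is expected to start at (approximately) time zero and span
    # CLIP_DURATION, with padding whenever the data is shorter than that.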
    for ds in [long_dataset, exact_dataset, short_dataset]:
        clip, _, _ = clipper.extract_clip(ds)
        expected_spec_width = _compute_expected_width(
            ds, CLIP_DURATION, "time"
        )
        expected_audio_width = _compute_expected_width(
            ds, CLIP_DURATION, "audio_time"
        )

        assert clip.sizes["time"] == expected_spec_width
        assert clip.sizes["audio_time"] == expected_audio_width
        assert clip.spectrogram.shape[1] == expected_spec_width
        assert clip.audio.shape[0] == expected_audio_width

        assert clip.time.min() >= -1 / SPEC_SAMPLERATE
        assert clip.audio_time.min() >= -1 / AUDIO_SAMPLERATE

        time_span = clip.time.max() - clip.time.min()
        audio_span = clip.audio_time.max() - clip.audio_time.min()
        assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
        assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)


def test_clipper_random(long_dataset):
    seed = 42
    np.random.seed(seed)
    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
    clip1, _, _ = clipper.extract_clip(long_dataset)

    np.random.seed(seed + 1)
    clip2, _, _ = clipper.extract_clip(long_dataset)
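
    # Two different seeds should give two different random start positions,
    # but the extracted clips must have identical widths and spans.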
    expected_spec_width = _compute_expected_width(
        long_dataset, CLIP_DURATION, "time"
    )
    expected_audio_width = _compute_expected_width(
        long_dataset, CLIP_DURATION, "audio_time"
    )

    for clip in [clip1, clip2]:
        assert clip.sizes["time"] == expected_spec_width
        assert clip.sizes["audio_time"] == expected_audio_width
        assert clip.spectrogram.shape[1] == expected_spec_width
        assert clip.audio.shape[0] == expected_audio_width

    assert not np.isclose(clip1.time.min(), clip2.time.min())
    assert not np.isclose(clip1.audio_time.min(), clip2.audio_time.min())

    for clip in [clip1, clip2]:
        time_span = clip.time.max() - clip.time.min()
        audio_span = clip.audio_time.max() - clip.audio_time.min()
        assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
        assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)

    max_start_time = (
        (long_dataset.time.max() - long_dataset.time.min())
        - CLIP_DURATION
        + MAX_EMPTY
    )
    assert clip1.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE
    assert clip2.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE


def test_clipper_random_max_empty_effect(long_dataset):
    """Check that max_empty influences the possible start times."""
    seed = 123
    data_duration = long_dataset.time.max() - long_dataset.time.min()

    np.random.seed(seed)
    clipper0 = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.0)
    max_start_time0 = data_duration - CLIP_DURATION
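    # With max_empty=0.0 a random clip can never start later than
    # data_duration - CLIP_DURATION, so it always lies fully inside the data;
    # a positive max_empty relaxes that bound by the same amount, which the
    # second half of this test checks.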
    start_times0 = []

    for _ in range(20):
        clip, _, _ = clipper0.extract_clip(long_dataset)
        start_times0.append(clip.time.min().item())

    assert all(
        st <= max_start_time0 + 1 / SPEC_SAMPLERATE for st in start_times0
    )
    assert any(st > 0.1 for st in start_times0)

    np.random.seed(seed)
    clipper_pos = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.2)
    max_start_time_pos = data_duration - CLIP_DURATION + 0.2
    start_times_pos = []
    for _ in range(20):
        clip, _, _ = clipper_pos.extract_clip(long_dataset)
        start_times_pos.append(clip.time.min().item())
    assert all(
        st <= max_start_time_pos + 1 / SPEC_SAMPLERATE
        for st in start_times_pos
    )

    assert any(st > max_start_time0 + 1e-6 for st in start_times_pos)


def test_clipper_short_dataset_random(short_dataset):
    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
    clip, _, _ = clipper.extract_clip(short_dataset)

    expected_spec_width = _compute_expected_width(
        short_dataset, CLIP_DURATION, "time"
    )
    expected_audio_width = _compute_expected_width(
        short_dataset, CLIP_DURATION, "audio_time"
    )
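
    # The short fixture cannot fill a whole clip, so whatever start the
    # clipper picks, both the spectrogram and the audio should contain some
    # zero-valued padding.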
    assert clip.sizes["time"] == expected_spec_width
    assert clip.sizes["audio_time"] == expected_audio_width
    assert clip["spectrogram"].shape[1] == expected_spec_width
    assert clip["audio"].shape[0] == expected_audio_width

    assert np.any(clip.spectrogram == 0)
    assert np.any(clip.audio == 0)


def test_clipper_exact_dataset_random(exact_dataset):
    clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY)
    clip, _, _ = clipper.extract_clip(exact_dataset)

    expected_spec_width = _compute_expected_width(
        exact_dataset, CLIP_DURATION, "time"
    )
    expected_audio_width = _compute_expected_width(
        exact_dataset, CLIP_DURATION, "audio_time"
    )

    assert clip.sizes["time"] == expected_spec_width
    assert clip.sizes["audio_time"] == expected_audio_width
    assert clip.spectrogram.shape[1] == expected_spec_width
    assert clip.audio.shape[0] == expected_audio_width

    time_span = clip.time.max() - clip.time.min()
    audio_span = clip.audio_time.max() - clip.audio_time.min()
    assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE)
    assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE)
tests/test_train/test_lightning.py (new file, 56 lines)
@ -0,0 +1,56 @@
from pathlib import Path

import lightning as L
import torch
import xarray as xr
from soundevent import data

from batdetect2.models import build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import build_loss


def build_default_module():
    loss = build_loss()
    targets = build_targets()
    detector = build_model(num_classes=len(targets.class_names))
    preprocessor = build_preprocessor()
    postprocessor = build_postprocessor(targets)
    return TrainingModule(
        detector=detector,
        loss=loss,
        targets=targets,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
    )


def test_can_initialize_default_module():
    module = build_default_module()
    assert isinstance(module, L.LightningModule)


def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip):
    module = build_default_module()
    trainer = L.Trainer()
    path = tmp_path / "example.ckpt"
    trainer.strategy.connect(module)
    trainer.save_checkpoint(path)

    recovered = TrainingModule.load_from_checkpoint(path)
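
    # Loading the checkpoint should restore the full module, including its
    # preprocessor configuration and model weights, so preprocessing the same
    # clip and running a forward pass should match the original module.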
    spec1 = module.preprocessor.preprocess_clip(clip)
    spec2 = recovered.preprocessor.preprocess_clip(clip)

    xr.testing.assert_equal(spec1, spec2)

    input1 = torch.tensor([spec1.values]).unsqueeze(0)
    input2 = torch.tensor([spec2.values]).unsqueeze(0)

    output1 = module(input1)
    output2 = recovered(input2)

    # Compare the outputs field by field: `output1 == output2` on a tuple of
    # multi-element tensors is ambiguous and would raise at runtime.
    for field1, field2 in zip(output1, output2):
        assert torch.equal(field1, field2)