From 9cf159efff3499e7f0ec3803e342e22c1737f940 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 19 Nov 2024 19:34:54 +0000 Subject: [PATCH] Reworking model creation --- batdetect2/compat/params.py | 13 +- batdetect2/models/__init__.py | 45 +- .../{feature_extractors.py => backbones.py} | 30 +- batdetect2/models/blocks.py | 139 ++- batdetect2/models/detectors.py | 142 --- batdetect2/models/heads.py | 51 + batdetect2/models/typing.py | 7 +- batdetect2/{models => }/post_process.py | 13 +- batdetect2/preprocess/__init__.py | 10 +- batdetect2/preprocess/spectrogram.py | 48 +- batdetect2/train/audio_dataloader.py | 941 ------------------ batdetect2/train/augmentations.py | 439 +++++--- batdetect2/train/dataset.py | 87 +- batdetect2/train/labels.py | 17 +- batdetect2/train/light.py | 56 -- batdetect2/train/losses.py | 86 +- batdetect2/train/modules.py | 98 ++ batdetect2/train/preprocess.py | 11 +- batdetect2/train/targets.py | 133 ++- tests/conftest.py | 2 +- tests/test_migration/test_preprocessing.py | 7 +- tests/test_train/__init__.py | 0 tests/test_train/test_augmentations.py | 70 ++ .../{test_data => test_train}/test_labels.py | 6 +- 24 files changed, 967 insertions(+), 1484 deletions(-) rename batdetect2/models/{feature_extractors.py => backbones.py} (93%) delete mode 100644 batdetect2/models/detectors.py create mode 100644 batdetect2/models/heads.py rename batdetect2/{models => }/post_process.py (96%) delete mode 100644 batdetect2/train/audio_dataloader.py delete mode 100644 batdetect2/train/light.py create mode 100644 batdetect2/train/modules.py create mode 100644 tests/test_train/__init__.py create mode 100644 tests/test_train/test_augmentations.py rename tests/{test_data => test_train}/test_labels.py (96%) diff --git a/batdetect2/compat/params.py b/batdetect2/compat/params.py index acb811f..65ce3e9 100644 --- a/batdetect2/compat/params.py +++ b/batdetect2/compat/params.py @@ -1,10 +1,13 @@ from batdetect2.preprocess import ( + AmplitudeScaleConfig, AudioConfig, FFTConfig, FrequencyConfig, - PcenConfig, + LogScaleConfig, + PcenScaleConfig, PreprocessingConfig, ResampleConfig, + Scales, SpecSizeConfig, SpectrogramConfig, ) @@ -17,12 +20,12 @@ from batdetect2.train.preprocess import ( ) -def get_spectrogram_scale(scale: str): +def get_spectrogram_scale(scale: str) -> Scales: if scale == "pcen": - return PcenConfig() + return PcenScaleConfig() if scale == "log": - return "log" - return None + return LogScaleConfig() + return AmplitudeScaleConfig() def get_preprocessing_config(params: dict) -> PreprocessingConfig: diff --git a/batdetect2/models/__init__.py b/batdetect2/models/__init__.py index ef37e70..be06c42 100644 --- a/batdetect2/models/__init__.py +++ b/batdetect2/models/__init__.py @@ -1,11 +1,54 @@ -from batdetect2.models.feature_extractors import ( +from enum import Enum + +from batdetect2.configs import BaseConfig +from batdetect2.models.backbones import ( Net2DFast, Net2DFastNoAttn, Net2DFastNoCoordConv, ) +from batdetect2.models.heads import BBoxHead, ClassifierHead +from batdetect2.models.typing import BackboneModel __all__ = [ + "get_backbone", "Net2DFast", "Net2DFastNoAttn", "Net2DFastNoCoordConv", + "ModelType", + "BBoxHead", + "ClassifierHead", ] + + +class ModelType(str, Enum): + Net2DFast = "Net2DFast" + Net2DFastNoAttn = "Net2DFastNoAttn" + Net2DFastNoCoordConv = "Net2DFastNoCoordConv" + + +class ModelConfig(BaseConfig): + name: ModelType = ModelType.Net2DFast + num_features: int = 128 + + +def get_backbone( + config: ModelConfig, + input_height: int = 128, +) -> 
BackboneModel: + if config.name == ModelType.Net2DFast: + return Net2DFast( + input_height=input_height, + num_features=config.num_features, + ) + elif config.name == ModelType.Net2DFastNoAttn: + return Net2DFastNoAttn( + num_features=config.num_features, + input_height=input_height, + ) + elif config.name == ModelType.Net2DFastNoCoordConv: + return Net2DFastNoCoordConv( + num_features=config.num_features, + input_height=input_height, + ) + else: + raise ValueError(f"Unknown model type: {config.name}") diff --git a/batdetect2/models/feature_extractors.py b/batdetect2/models/backbones.py similarity index 93% rename from batdetect2/models/feature_extractors.py rename to batdetect2/models/backbones.py index 4c437b9..0aa266a 100644 --- a/batdetect2/models/feature_extractors.py +++ b/batdetect2/models/backbones.py @@ -12,7 +12,7 @@ from batdetect2.models.blocks import ( ConvBlockUpStandard, SelfAttention, ) -from batdetect2.models.typing import FeatureExtractorModel +from batdetect2.models.typing import BackboneModel __all__ = [ "Net2DFast", @@ -21,7 +21,7 @@ __all__ = [ ] -class Net2DFast(FeatureExtractorModel): +class Net2DFast(BackboneModel): def __init__( self, num_features: int, @@ -37,7 +37,7 @@ class Net2DFast(FeatureExtractorModel): 1, self.num_features // 4, self.input_height, - k_size=3, + kernel_size=3, pad_size=1, stride=1, ) @@ -45,7 +45,7 @@ class Net2DFast(FeatureExtractorModel): self.num_features // 4, self.num_features // 2, self.input_height // 2, - k_size=3, + kernel_size=3, pad_size=1, stride=1, ) @@ -53,7 +53,7 @@ class Net2DFast(FeatureExtractorModel): self.num_features // 2, self.num_features, self.input_height // 4, - k_size=3, + kernel_size=3, pad_size=1, stride=1, ) @@ -100,6 +100,8 @@ class Net2DFast(FeatureExtractorModel): ) self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4) + self.out_channels = self.num_features // 4 + def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]: h, w = spec.shape[2:] h_pad = (32 - h % 32) % 32 @@ -135,7 +137,7 @@ class Net2DFast(FeatureExtractorModel): return F.relu_(self.conv_op_bn(self.conv_op(x))) -class Net2DFastNoAttn(FeatureExtractorModel): +class Net2DFastNoAttn(BackboneModel): def __init__( self, num_features: int, @@ -151,7 +153,7 @@ class Net2DFastNoAttn(FeatureExtractorModel): 1, self.num_features // 4, self.input_height, - k_size=3, + kernel_size=3, pad_size=1, stride=1, ) @@ -159,7 +161,7 @@ class Net2DFastNoAttn(FeatureExtractorModel): self.num_features // 4, self.num_features // 2, self.input_height // 2, - k_size=3, + kernel_size=3, pad_size=1, stride=1, ) @@ -167,7 +169,7 @@ class Net2DFastNoAttn(FeatureExtractorModel): self.num_features // 2, self.num_features, self.input_height // 4, - k_size=3, + kernel_size=3, pad_size=1, stride=1, ) @@ -210,6 +212,7 @@ class Net2DFastNoAttn(FeatureExtractorModel): padding=1, ) self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4) + self.out_channels = self.num_features // 4 def forward(self, spec: torch.Tensor) -> torch.Tensor: x1 = self.conv_dn_0(spec) @@ -227,7 +230,7 @@ class Net2DFastNoAttn(FeatureExtractorModel): return F.relu_(self.conv_op_bn(self.conv_op(x))) -class Net2DFastNoCoordConv(FeatureExtractorModel): +class Net2DFastNoCoordConv(BackboneModel): def __init__( self, num_features: int, @@ -242,21 +245,21 @@ class Net2DFastNoCoordConv(FeatureExtractorModel): self.conv_dn_0 = ConvBlockDownStandard( 1, self.num_features // 4, - k_size=3, + kernel_size=3, pad_size=1, stride=1, ) self.conv_dn_1 = ConvBlockDownStandard( self.num_features // 4, 
self.num_features // 2, - k_size=3, + kernel_size=3, pad_size=1, stride=1, ) self.conv_dn_2 = ConvBlockDownStandard( self.num_features // 2, self.num_features, - k_size=3, + kernel_size=3, pad_size=1, stride=1, ) @@ -301,6 +304,7 @@ class Net2DFastNoCoordConv(FeatureExtractorModel): padding=1, ) self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4) + self.out_channels = self.num_features // 4 def forward(self, spec: torch.Tensor) -> torch.Tensor: x1 = self.conv_dn_0(spec) diff --git a/batdetect2/models/blocks.py b/batdetect2/models/blocks.py index d58bdb1..080b0fa 100644 --- a/batdetect2/models/blocks.py +++ b/batdetect2/models/blocks.py @@ -10,6 +10,8 @@ import torch import torch.nn.functional as F from torch import nn +from batdetect2.configs import BaseConfig + __all__ = [ "SelfAttention", "ConvBlockDownCoordF", @@ -19,22 +21,33 @@ __all__ = [ ] +class SelfAttentionConfig(BaseConfig): + temperature: float = 1.0 + input_channels: int = 128 + attention_channels: int = 128 + + class SelfAttention(nn.Module): """Self-Attention module. This module implements self-attention mechanism. """ - def __init__(self, ip_dim: int, att_dim: int): + def __init__( + self, + in_channels: int, + attention_channels: int, + temperature: float = 1.0, + ): super().__init__() - # Note, does not encode position information (absolute or realtive) - self.temperature = 1.0 - self.att_dim = att_dim - self.key_fun = nn.Linear(ip_dim, att_dim) - self.val_fun = nn.Linear(ip_dim, att_dim) - self.que_fun = nn.Linear(ip_dim, att_dim) - self.pro_fun = nn.Linear(att_dim, ip_dim) + # Note, does not encode position information (absolute or relative) + self.temperature = temperature + self.att_dim = attention_channels + self.key_fun = nn.Linear(in_channels, attention_channels) + self.value_fun = nn.Linear(in_channels, attention_channels) + self.query_fun = nn.Linear(in_channels, attention_channels) + self.pro_fun = nn.Linear(attention_channels, in_channels) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.squeeze(2).permute(0, 2, 1) @@ -43,11 +56,11 @@ class SelfAttention(nn.Module): x, self.key_fun.weight.T ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0) query = torch.matmul( - x, self.que_fun.weight.T - ) + self.que_fun.bias.unsqueeze(0).unsqueeze(0) + x, self.query_fun.weight.T + ) + self.query_fun.bias.unsqueeze(0).unsqueeze(0) value = torch.matmul( - x, self.val_fun.weight.T - ) + self.val_fun.bias.unsqueeze(0).unsqueeze(0) + x, self.value_fun.weight.T + ) + self.value_fun.bias.unsqueeze(0).unsqueeze(0) kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / ( self.temperature * self.att_dim @@ -63,6 +76,15 @@ class SelfAttention(nn.Module): return op +class ConvBlockDownCoordFConfig(BaseConfig): + in_channels: int + out_channels: int + input_height: int + kernel_size: int = 3 + pad_size: int = 1 + stride: int = 1 + + class ConvBlockDownCoordF(nn.Module): """Convolutional Block with Downsampling and Coord Feature. 
@@ -72,27 +94,27 @@ class ConvBlockDownCoordF(nn.Module): def __init__( self, - in_chn: int, - out_chn: int, - ip_height: int, - k_size: int = 3, + in_channels: int, + out_channels: int, + input_height: int, + kernel_size: int = 3, pad_size: int = 1, stride: int = 1, ): super().__init__() self.coords = nn.Parameter( - torch.linspace(-1, 1, ip_height)[None, None, ..., None], + torch.linspace(-1, 1, input_height)[None, None, ..., None], requires_grad=False, ) self.conv = nn.Conv2d( - in_chn + 1, - out_chn, - kernel_size=k_size, + in_channels + 1, + out_channels, + kernel_size=kernel_size, padding=pad_size, stride=stride, ) - self.conv_bn = nn.BatchNorm2d(out_chn) + self.conv_bn = nn.BatchNorm2d(out_channels) def forward(self, x: torch.Tensor) -> torch.Tensor: freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3]) @@ -102,6 +124,14 @@ class ConvBlockDownCoordF(nn.Module): return x +class ConvBlockDownStandardConfig(BaseConfig): + in_channels: int + out_channels: int + kernel_size: int = 3 + pad_size: int = 1 + stride: int = 1 + + class ConvBlockDownStandard(nn.Module): """Convolutional Block with Downsampling. @@ -110,21 +140,21 @@ class ConvBlockDownStandard(nn.Module): def __init__( self, - in_chn, - out_chn, - k_size=3, - pad_size=1, - stride=1, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + pad_size: int = 1, + stride: int = 1, ): super(ConvBlockDownStandard, self).__init__() self.conv = nn.Conv2d( - in_chn, - out_chn, - kernel_size=k_size, + in_channels, + out_channels, + kernel_size=kernel_size, padding=pad_size, stride=stride, ) - self.conv_bn = nn.BatchNorm2d(out_chn) + self.conv_bn = nn.BatchNorm2d(out_channels) def forward(self, x): x = F.max_pool2d(self.conv(x), 2, 2) @@ -132,6 +162,16 @@ class ConvBlockDownStandard(nn.Module): return x +class ConvBlockUpFConfig(BaseConfig): + inp_channels: int + out_channels: int + input_height: int + kernel_size: int = 3 + pad_size: int = 1 + up_mode: str = "bilinear" + up_scale: Tuple[int, int] = (2, 2) + + class ConvBlockUpF(nn.Module): """Convolutional Block with Upsampling and Coord Feature. @@ -141,10 +181,10 @@ class ConvBlockUpF(nn.Module): def __init__( self, - in_chn: int, - out_chn: int, - ip_height: int, - k_size: int = 3, + in_channels: int, + out_channels: int, + input_height: int, + kernel_size: int = 3, pad_size: int = 1, up_mode: str = "bilinear", up_scale: Tuple[int, int] = (2, 2), @@ -154,15 +194,18 @@ class ConvBlockUpF(nn.Module): self.up_scale = up_scale self.up_mode = up_mode self.coords = nn.Parameter( - torch.linspace(-1, 1, ip_height * up_scale[0])[ + torch.linspace(-1, 1, input_height * up_scale[0])[ None, None, ..., None ], requires_grad=False, ) self.conv = nn.Conv2d( - in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size + in_channels + 1, + out_channels, + kernel_size=kernel_size, + padding=pad_size, ) - self.conv_bn = nn.BatchNorm2d(out_chn) + self.conv_bn = nn.BatchNorm2d(out_channels) def forward(self, x: torch.Tensor) -> torch.Tensor: op = F.interpolate( @@ -181,6 +224,15 @@ class ConvBlockUpF(nn.Module): return op +class ConvBlockUpStandardConfig(BaseConfig): + in_channels: int + out_channels: int + kernel_size: int = 3 + pad_size: int = 1 + up_mode: str = "bilinear" + up_scale: Tuple[int, int] = (2, 2) + + class ConvBlockUpStandard(nn.Module): """Convolutional Block with Upsampling. 
@@ -189,9 +241,9 @@ class ConvBlockUpStandard(nn.Module): def __init__( self, - in_chn: int, - out_chn: int, - k_size: int = 3, + in_channels: int, + out_channels: int, + kernel_size: int = 3, pad_size: int = 1, up_mode: str = "bilinear", up_scale: Tuple[int, int] = (2, 2), @@ -200,9 +252,12 @@ class ConvBlockUpStandard(nn.Module): self.up_scale = up_scale self.up_mode = up_mode self.conv = nn.Conv2d( - in_chn, out_chn, kernel_size=k_size, padding=pad_size + in_channels, + out_channels, + kernel_size=kernel_size, + padding=pad_size, ) - self.conv_bn = nn.BatchNorm2d(out_chn) + self.conv_bn = nn.BatchNorm2d(out_channels) def forward(self, x: torch.Tensor) -> torch.Tensor: op = F.interpolate( diff --git a/batdetect2/models/detectors.py b/batdetect2/models/detectors.py deleted file mode 100644 index 0d5595d..0000000 --- a/batdetect2/models/detectors.py +++ /dev/null @@ -1,142 +0,0 @@ -from typing import Optional, Type - -import pytorch_lightning as L -import torch -import xarray as xr -from soundevent import data -from torch import nn, optim - -from batdetect2.data.labels import ClassMapper -from batdetect2.data.preprocessing import ( - PreprocessingConfig, - preprocess, -) -from batdetect2.models.feature_extractors import Net2DFast -from batdetect2.models.post_process import ( - PostprocessConfig, - postprocess_model_outputs, -) -from batdetect2.models.typing import FeatureExtractorModel, ModelOutput -from batdetect2.train import losses -from batdetect2.train.dataset import TrainExample - - -class DetectorModel(L.LightningModule): - def __init__( - self, - class_mapper: ClassMapper, - feature_extractor_class: Type[FeatureExtractorModel] = Net2DFast, - learning_rate: float = 1e-3, - input_height: int = 128, - num_features: int = 32, - preprocessing_config: Optional[PreprocessingConfig] = None, - postprocessing_config: Optional[PostprocessConfig] = None, - ): - super().__init__() - - preprocessing_config = preprocessing_config or PreprocessingConfig() - postprocessing_config = postprocessing_config or PostprocessConfig() - - self.save_hyperparameters() - - self.preprocessing_config = preprocessing_config - self.postprocessing_config = postprocessing_config - self.class_mapper = class_mapper - self.learning_rate = learning_rate - self.input_height = input_height - self.num_features = num_features - self.num_classes = class_mapper.num_classes - - self.feature_extractor = feature_extractor_class( - input_height=input_height, - num_features=num_features, - ) - - self.classifier = nn.Conv2d( - self.feature_extractor.num_features // 4, - self.num_classes + 1, - kernel_size=1, - padding=0, - ) - - self.bbox = nn.Conv2d( - self.feature_extractor.num_features // 4, - 2, - kernel_size=1, - padding=0, - ) - - def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore - features = self.feature_extractor(spec) - classification_logits = self.classifier(features) - classification_probs = torch.softmax(classification_logits, dim=1) - detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True) - return ModelOutput( - detection_probs=detection_probs, - size_preds=self.bbox(features), - class_probs=classification_probs[:, :-1], - features=features, - ) - - def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray: - return preprocess( - clip, - config=self.preprocessing_config, - ) - - def compute_clip_features(self, clip: data.Clip) -> torch.Tensor: - spectrogram = self.compute_spectrogram(clip) - return self.feature_extractor( - 
torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
-        )
-
-    def compute_clip_predictions(self, clip: data.Clip) -> data.ClipPrediction:
-        spectrogram = self.compute_spectrogram(clip)
-        spec_tensor = (
-            torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
-        )
-        outputs = self(spec_tensor)
-        return postprocess_model_outputs(
-            outputs,
-            [clip],
-            class_mapper=self.class_mapper,
-            config=self.postprocessing_config,
-        )[0]
-
-    def compute_loss(
-        self,
-        outputs: ModelOutput,
-        batch: TrainExample,
-    ) -> torch.Tensor:
-        detection_loss = losses.focal_loss(
-            outputs.detection_probs,
-            batch.detection_heatmap,
-        )
-
-        size_loss = losses.bbox_size_loss(
-            outputs.size_preds,
-            batch.size_heatmap,
-        )
-
-        valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
-        classification_loss = losses.focal_loss(
-            outputs.class_probs,
-            batch.class_heatmap,
-            valid_mask=valid_mask,
-        )
-
-        return detection_loss + size_loss + classification_loss
-
-    def training_step(  # type: ignore
-        self,
-        batch: TrainExample,
-    ):
-        outputs = self.forward(batch.spec)
-        loss = self.compute_loss(outputs, batch)
-        self.log("train_loss", loss)
-        return loss
-
-    def configure_optimizers(self):
-        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
-        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
-        return [optimizer], [scheduler]
diff --git a/batdetect2/models/heads.py b/batdetect2/models/heads.py
new file mode 100644
index 0000000..d4281f4
--- /dev/null
+++ b/batdetect2/models/heads.py
@@ -0,0 +1,51 @@
+from typing import NamedTuple
+
+import torch
+from torch import nn
+
+__all__ = ["BBoxHead", "ClassifierHead"]
+
+
+class Output(NamedTuple):
+    detection: torch.Tensor
+    classification: torch.Tensor
+
+
+class ClassifierHead(nn.Module):
+    def __init__(self, num_classes: int, in_channels: int):
+        super().__init__()
+
+        self.num_classes = num_classes
+        self.in_channels = in_channels
+        self.classifier = nn.Conv2d(
+            self.in_channels,
+            # Add one to account for the background class
+            self.num_classes + 1,
+            kernel_size=1,
+            padding=0,
+        )
+
+    def forward(self, features: torch.Tensor) -> Output:
+        logits = self.classifier(features)
+        probs = torch.softmax(logits, dim=1)
+        detection_probs = probs[:, :-1].sum(dim=1, keepdim=True)
+        return Output(
+            detection=detection_probs,
+            classification=probs[:, :-1],
+        )
+
+
+class BBoxHead(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.bbox = nn.Conv2d(
+            self.in_channels,
+            2,
+            kernel_size=1,
+            padding=0,
+        )
+
+    def forward(self, features: torch.Tensor) -> torch.Tensor:
+        return self.bbox(features)
diff --git a/batdetect2/models/typing.py b/batdetect2/models/typing.py
index 7a0d0ee..c39229d 100644
--- a/batdetect2/models/typing.py
+++ b/batdetect2/models/typing.py
@@ -6,7 +6,7 @@ import torch.nn as nn
 
 __all__ = [
     "ModelOutput",
-    "FeatureExtractorModel",
+    "BackboneModel",
 ]
 
 
@@ -41,13 +41,16 @@ class ModelOutput(NamedTuple):
     """Tensor with intermediate features."""
 
 
-class FeatureExtractorModel(ABC, nn.Module):
+class BackboneModel(ABC, nn.Module):
     input_height: int
     """Height of the input spectrogram."""
 
     num_features: int
     """Dimension of the feature tensor."""
 
+    out_channels: int
+    """Number of output channels of the feature extractor."""
+
     @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
         """Forward pass of the encoder model."""
diff --git a/batdetect2/models/post_process.py b/batdetect2/post_process.py
similarity index 96%
rename from 
batdetect2/models/post_process.py rename to batdetect2/post_process.py index 97c6052..c39994c 100644 --- a/batdetect2/models/post_process.py +++ b/batdetect2/post_process.py @@ -8,7 +8,6 @@ from pydantic import BaseModel, Field from soundevent import data from torch import nn -from batdetect2.data.labels import ClassMapper from batdetect2.models.typing import ModelOutput __all__ = [ @@ -37,7 +36,8 @@ TagFunction = Callable[[int], List[data.Tag]] def postprocess_model_outputs( outputs: ModelOutput, clips: List[data.Clip], - class_mapper: ClassMapper, + classes: List[str], + decoder: Callable[[str], List[data.Tag]], config: PostprocessConfig, ) -> List[data.ClipPrediction]: """Postprocesses model outputs to generate clip predictions. @@ -108,7 +108,8 @@ def postprocess_model_outputs( size_preds, class_probs, features, - class_mapper=class_mapper, + classes=classes, + decoder=decoder, min_freq=config.min_freq, max_freq=config.max_freq, detection_threshold=config.detection_threshold, @@ -132,7 +133,8 @@ def compute_sound_events_from_outputs( size_preds: torch.Tensor, class_probs: torch.Tensor, features: torch.Tensor, - class_mapper: ClassMapper, + classes: List[str], + decoder: Callable[[str], List[data.Tag]], min_freq: int = 10000, max_freq: int = 120000, detection_threshold: float = DETECTION_THRESHOLD, @@ -181,7 +183,8 @@ def compute_sound_events_from_outputs( predicted_tags: List[data.PredictedTag] = [] for label_id, class_score in enumerate(class_prob): - corresponding_tags = class_mapper.inverse_transform(label_id) + class_name = classes[label_id] + corresponding_tags = decoder(class_name) predicted_tags.extend( [ data.PredictedTag( diff --git a/batdetect2/preprocess/__init__.py b/batdetect2/preprocess/__init__.py index 4c9560e..a9c2ee5 100644 --- a/batdetect2/preprocess/__init__.py +++ b/batdetect2/preprocess/__init__.py @@ -12,9 +12,12 @@ from batdetect2.preprocess.audio import ( load_clip_audio, ) from batdetect2.preprocess.spectrogram import ( + AmplitudeScaleConfig, FFTConfig, FrequencyConfig, - PcenConfig, + LogScaleConfig, + PcenScaleConfig, + Scales, SpecSizeConfig, SpectrogramConfig, compute_spectrogram, @@ -26,7 +29,10 @@ __all__ = [ "SpectrogramConfig", "FFTConfig", "FrequencyConfig", - "PcenConfig", + "PcenScaleConfig", + "LogScaleConfig", + "AmplitudeScaleConfig", + "Scales", "SpecSizeConfig", "PreprocessingConfig", "preprocess_audio_clip", diff --git a/batdetect2/preprocess/spectrogram.py b/batdetect2/preprocess/spectrogram.py index 2026785..6a619f8 100644 --- a/batdetect2/preprocess/spectrogram.py +++ b/batdetect2/preprocess/spectrogram.py @@ -23,7 +23,18 @@ class FrequencyConfig(BaseConfig): min_freq: int = Field(default=10_000, gt=0) -class PcenConfig(BaseConfig): +class SpecSizeConfig(BaseConfig): + height: int = 256 + resize_factor: Optional[float] = 0.5 + divide_factor: Optional[int] = 32 + + +class LogScaleConfig(BaseConfig): + name: Literal["log"] = "log" + + +class PcenScaleConfig(BaseConfig): + name: Literal["pcen"] = "pcen" time_constant: float = 0.4 hop_length: int = 512 gain: float = 0.98 @@ -31,19 +42,21 @@ class PcenConfig(BaseConfig): power: float = 0.5 -class SpecSizeConfig(BaseConfig): - height: int = 256 - resize_factor: Optional[float] = 0.5 - divide_factor: Optional[int] = 32 +class AmplitudeScaleConfig(BaseConfig): + name: Literal["amplitude"] = "amplitude" + + +Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig] class SpectrogramConfig(BaseConfig): fft: FFTConfig = Field(default_factory=FFTConfig) frequencies: FrequencyConfig = 
Field(default_factory=FrequencyConfig) - scale: Union[Literal["log"], None, PcenConfig] = Field( - default_factory=PcenConfig + scale: Scales = Field( + default_factory=PcenScaleConfig, + discriminator="name", ) - size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig) + size: SpecSizeConfig = Field(default_factory=SpecSizeConfig) denoise: bool = True max_scale: bool = False @@ -55,7 +68,7 @@ def compute_spectrogram( ) -> xr.DataArray: config = config or SpectrogramConfig() - if config.size and config.size.divide_factor: + if config.size.divide_factor: # Need to pad the audio to make sure the spectrogram has a # width compatible with the divide factor wav = pad_audio( @@ -84,12 +97,11 @@ def compute_spectrogram( if config.denoise: spec = denoise_spectrogram(spec) - if config.size: - spec = resize_spectrogram( - spec, - height=config.size.height, - resize_factor=config.size.resize_factor, - ) + spec = resize_spectrogram( + spec, + height=config.size.height, + resize_factor=config.size.resize_factor, + ) if config.max_scale: spec = ops.scale(spec, 1 / (10e-6 + np.max(spec))) @@ -180,13 +192,13 @@ def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray: def scale_spectrogram( spec: xr.DataArray, - scale: Union[Literal["log"], None, PcenConfig], + scale: Scales, dtype: DTypeLike = np.float32, ) -> xr.DataArray: - if scale == "log": + if scale.name == "log": return scale_log(spec, dtype=dtype) - if isinstance(scale, PcenConfig): + if scale.name == "pcen": return scale_pcen( spec, time_constant=scale.time_constant, diff --git a/batdetect2/train/audio_dataloader.py b/batdetect2/train/audio_dataloader.py deleted file mode 100644 index 22c9070..0000000 --- a/batdetect2/train/audio_dataloader.py +++ /dev/null @@ -1,941 +0,0 @@ -"""Functions and dataloaders for training and testing the model.""" - -import copy -from typing import List, Optional, Tuple - -import librosa -import numpy as np -import torch -import torch.nn.functional as F -import torch.utils.data -import torchaudio - -import batdetect2.utils.audio_utils as au -from batdetect2.types import ( - Annotation, - AudioLoaderAnnotationGroup, - AudioLoaderParameters, - FileAnnotation, -) - - -def generate_gt_heatmaps( - spec_op_shape: Tuple[int, int], - sampling_rate: float, - ann: AudioLoaderAnnotationGroup, - class_names: List[str], - fft_win_length: float, - fft_overlap: float, - max_freq: float, - min_freq: float, - resize_factor: float, - target_sigma: float, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AudioLoaderAnnotationGroup]: - """Generate ground truth heatmaps from annotations. - - Parameters - ---------- - spec_op_shape : Tuple[int, int] - Shape of the input spectrogram. - sampling_rate : int - Sampling rate of the input audio in Hz. - ann : AnnotationGroup - Dictionary containing the annotation information. - params : HeatmapParameters - Parameters controlling the generation of the heatmaps. - - Returns - ------- - y_2d_det : np.ndarray - 2D heatmap of the presence of an event. - y_2d_size : np.ndarray - 2D heatmap of the size of the bounding box associated to event. - y_2d_classes : np.ndarray - 3D array containing the ground-truth class probabilities for each - pixel. - ann_aug : AnnotationGroup - A dictionary containing the annotation information of the - annotations that are within the input spectrogram, augmented with - the x and y indices of their pixel location in the input spectrogram. 
- """ - # spec may be resized on input into the network - num_classes = len(class_names) - op_height = spec_op_shape[0] - op_width = spec_op_shape[1] - freq_per_bin = (max_freq - min_freq) / op_height - - # start and end times - x_pos_start = au.time_to_x_coords( - ann["start_times"], - sampling_rate, - fft_win_length, - fft_overlap, - ) - x_pos_start = (resize_factor * x_pos_start).astype(np.int32) - x_pos_end = au.time_to_x_coords( - ann["end_times"], - sampling_rate, - fft_win_length, - fft_overlap, - ) - x_pos_end = (resize_factor * x_pos_end).astype(np.int32) - - # location on y axis i.e. frequency - y_pos_low = (ann["low_freqs"] - min_freq) / freq_per_bin - y_pos_low = (op_height - y_pos_low).astype(np.int32) - y_pos_high = (ann["high_freqs"] - min_freq) / freq_per_bin - y_pos_high = (op_height - y_pos_high).astype(np.int32) - bb_widths = x_pos_end - x_pos_start - bb_heights = y_pos_low - y_pos_high - - # Only include annotations that are within the input spectrogram - valid_inds = np.where( - (x_pos_start >= 0) - & (x_pos_start < op_width) - & (y_pos_low >= 0) - & (y_pos_low < (op_height - 1)) - )[0] - - ann_aug: AudioLoaderAnnotationGroup = { - **ann, - "start_times": ann["start_times"][valid_inds], - "end_times": ann["end_times"][valid_inds], - "high_freqs": ann["high_freqs"][valid_inds], - "low_freqs": ann["low_freqs"][valid_inds], - "class_ids": ann["class_ids"][valid_inds], - "individual_ids": ann["individual_ids"][valid_inds], - "x_inds": x_pos_start[valid_inds], - "y_inds": y_pos_low[valid_inds], - } - - # if the number of calls is only 1, then it is unique - # TODO would be better if we found these unique calls at the merging stage - if len(ann_aug["individual_ids"]) == 1: - ann_aug["individual_ids"][0] = 0 - - y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32) - y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32) - - # num classes and "background" class - y_2d_classes: np.ndarray = np.zeros( - (num_classes + 1, op_height, op_width), dtype=np.float32 - ) - - # create 2D ground truth heatmaps - for ii in valid_inds: - draw_gaussian( - y_2d_det[0, :], - (x_pos_start[ii], y_pos_low[ii]), - target_sigma, - ) - y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii] - y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii] - - cls_id = ann["class_ids"][ii] - if cls_id > -1: - draw_gaussian( - y_2d_classes[cls_id, :], - (x_pos_start[ii], y_pos_low[ii]), - target_sigma, - ) - - # be careful as this will have a 1.0 places where we have event but - # dont know gt class this will be masked in training anyway - y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0) - y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...] - y_2d_classes[np.isnan(y_2d_classes)] = 0.0 - - return y_2d_det, y_2d_size, y_2d_classes, ann_aug - - -def draw_gaussian( - heatmap: np.ndarray, - center: Tuple[int, int], - sigmax: float, - sigmay: Optional[float] = None, -) -> bool: - """Draw a 2D gaussian into the heatmap. - - If the gaussian center is outside the heatmap, then the gaussian is not - drawn. - - Parameters - ---------- - heatmap : np.ndarray - The heatmap to draw into. Should be of shape (height, width). - center : Tuple[int, int] - The center of the gaussian in (x, y) format. - sigmax : float - The standard deviation of the gaussian in the x direction. - sigmay : Optional[float], optional - The standard deviation of the gaussian in the y direction. If None, - then sigmay = sigmax, by default None. 
- - Returns - ------- - bool - True if the gaussian was drawn, False if it was not (because - the center was outside the heatmap). - - - """ - # center is (x, y) - # this edits the heatmap inplace - - if sigmay is None: - sigmay = sigmax - tmp_size = np.maximum(sigmax, sigmay) * 3 - mu_x = int(center[0] + 0.5) - mu_y = int(center[1] + 0.5) - w, h = heatmap.shape[0], heatmap.shape[1] - ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] - br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] - - if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0: - return False - - size = 2 * tmp_size + 1 - x = np.arange(0, size, 1, np.float32) - y = x[:, np.newaxis] - x0 = y0 = size // 2 - # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) - g = np.exp( - -((x - x0) ** 2) / (2 * sigmax**2) - ((y - y0) ** 2) / (2 * sigmay**2) - ) - g_x = max(0, -ul[0]), min(br[0], h) - ul[0] - g_y = max(0, -ul[1]), min(br[1], w) - ul[1] - img_x = max(0, ul[0]), min(br[0], h) - img_y = max(0, ul[1]), min(br[1], w) - heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum( - heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]], - g[g_y[0] : g_y[1], g_x[0] : g_x[1]], - ) - return True - - -def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray: - """Pad array with -1s.""" - return np.hstack((ip_array, np.ones(pad_size, dtype=np.int32) * -1)) - - -def warp_spec_aug( - spec: torch.Tensor, - ann: AudioLoaderAnnotationGroup, - stretch_squeeze_delta: float, -) -> torch.Tensor: - """Warp spectrogram by randomly stretching and squeezing. - - Parameters - ---------- - spec: torch.Tensor - Spectrogram to warp. - ann: AnnotationGroup - Annotation group for the spectrogram. Must be provided to sync - the start and stop times with the spectrogram after warping. - stretch_squeeze_delta: float - Maximum amount to stretch or squeeze the spectrogram. - - Returns - ------- - torch.Tensor - Warped spectrogram. - - Notes - ----- - This function modifies the annotation group in place. - """ - # Augment spectrogram by randomly stretch and squeezing - # NOTE this also changes the start and stop time in place - - delta = stretch_squeeze_delta - op_size = (spec.shape[1], spec.shape[2]) - resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0 - resize_amt = int(spec.shape[2] * resize_fract_r) - - if resize_amt >= spec.shape[2]: - spec_r = torch.cat( - ( - spec, - torch.zeros( - (1, spec.shape[1], resize_amt - spec.shape[2]), - dtype=spec.dtype, - ), - ), - dim=2, - ) - else: - spec_r = spec[:, :, :resize_amt] - - # Resize the spectrogram - spec = F.interpolate( - spec_r.unsqueeze(0), - size=op_size, - mode="bilinear", - align_corners=False, - ).squeeze(0) - - # Update the start and stop times - ann["start_times"] *= 1.0 / resize_fract_r - ann["end_times"] *= 1.0 / resize_fract_r - - return spec - - -def mask_time_aug( - spec: torch.Tensor, - mask_max_time_perc: float, -) -> torch.Tensor: - """Mask out random blocks of time. - - Will randomly mask out a block of time in the spectrogram. The block - will be between 0.0 and `mask_max_time_perc` of the total time. - A random number of blocks will be masked out between 1 and 3. - - Parameters - ---------- - spec: torch.Tensor - Spectrogram to mask. - mask_max_time_perc: float - Maximum percentage of time to mask out. - - Returns - ------- - torch.Tensor - Spectrogram with masked out time blocks. 
- - Notes - ----- - This function is based on the implementation in:: - - SpecAugment: A Simple Data Augmentation Method for Automatic Speech - Recognition - """ - fm = torchaudio.transforms.TimeMasking( - int(spec.shape[1] * mask_max_time_perc) - ) - for _ in range(np.random.randint(1, 4)): - spec = fm(spec) - return spec - - -def mask_freq_aug( - spec: torch.Tensor, - mask_max_freq_perc: float, -) -> torch.Tensor: - """Mask out random blocks of frequency. - - Will randomly mask out a block of frequency in the spectrogram. The block - will be between 0.0 and `mask_max_freq_perc` of the total frequency. - A random number of blocks will be masked out between 1 and 3. - - Parameters - ---------- - spec: torch.Tensor - Spectrogram to mask. - mask_max_freq_perc: float - Maximum percentage of frequency to mask out. - - Returns - ------- - torch.Tensor - Spectrogram with masked out frequency blocks. - - Notes - ----- - This function is based on the implementation in:: - - SpecAugment: A Simple Data Augmentation Method for Automatic Speech - Recognition - """ - fm = torchaudio.transforms.FrequencyMasking( - int(spec.shape[1] * mask_max_freq_perc) - ) - for _ in range(np.random.randint(1, 4)): - spec = fm(spec) - return spec - - -def scale_vol_aug( - spec: torch.Tensor, - spec_amp_scaling: float, -) -> torch.Tensor: - """Scale the volume of the spectrogram. - - Parameters - ---------- - spec: torch.Tensor - Spectrogram to scale. - spec_amp_scaling: float - Maximum scaling factor. - - Returns - ------- - torch.Tensor - """ - return spec * np.random.random() * spec_amp_scaling - - -def echo_aug( - audio: np.ndarray, - sampling_rate: float, - echo_max_delay: float, -) -> np.ndarray: - """Add echo to audio. - - Parameters - ---------- - audio: np.ndarray - Audio to add echo to. - sampling_rate: float - Sampling rate of the audio. - echo_max_delay: float - Maximum delay of the echo in seconds. - - Returns - ------- - np.ndarray - Audio with echo added. - """ - sample_offset = ( - int(echo_max_delay * np.random.random() * sampling_rate) + 1 - ) - # NOTE: This seems to be wrong, as the echo should be added to the - # end of the audio, not the beginning. - audio[:-sample_offset] += np.random.random() * audio[sample_offset:] - return audio - - -def resample_aug( - audio: np.ndarray, - sampling_rate: float, - fft_win_length: float, - fft_overlap: float, - resize_factor: float, - spec_divide_factor: float, - spec_train_width: int, - aug_sampling_rates: List[int], -) -> Tuple[np.ndarray, float, float]: - """Resample audio augmentation. - - Will resample the audio to a random sampling rate from the list of - sampling rates in `aug_sampling_rates`. - - Parameters - ---------- - audio: np.ndarray - Audio to resample. - sampling_rate: float - Original sampling rate of the audio. - fft_win_length: float - Length of the FFT window in seconds. - fft_overlap: float - Amount of overlap between FFT windows. - resize_factor: float - Factor to resize the spectrogram by. - spec_divide_factor: float - Factor to divide the spectrogram by. - spec_train_width: int - Width of the spectrogram. - aug_sampling_rates: List[int] - List of sampling rates to resample to. - - Returns - ------- - audio : np.ndarray - Resampled audio. - sampling_rate : float - New sampling rate. - duration : float - Duration of the audio in seconds. 
- """ - sampling_rate_old = sampling_rate - sampling_rate = np.random.choice(aug_sampling_rates) - audio = librosa.resample( - audio, - orig_sr=sampling_rate_old, - target_sr=sampling_rate, - res_type="polyphase", - ) - - audio = au.pad_audio( - audio, - sampling_rate, - fft_win_length, - fft_overlap, - resize_factor, - spec_divide_factor, - spec_train_width, - ) - duration = audio.shape[0] / float(sampling_rate) - return audio, sampling_rate, duration - - -def resample_audio( - num_samples: int, - sampling_rate: float, - audio2: np.ndarray, - sampling_rate2: float, -) -> Tuple[np.ndarray, float]: - """Resample audio. - - Parameters - ---------- - num_samples: int - Expected number of samples for the output audio. - sampling_rate: float - Original sampling rate of the audio. - audio2: np.ndarray - Audio to resample. - sampling_rate2: float - Target sampling rate of the audio. - - Returns - ------- - audio2 : np.ndarray - Resampled audio. - sampling_rate2 : float - New sampling rate. - """ - # resample to target sampling rate - if sampling_rate != sampling_rate2: - audio2 = librosa.resample( - audio2, - orig_sr=sampling_rate2, - target_sr=sampling_rate, - res_type="polyphase", - ) - sampling_rate2 = sampling_rate - - # pad or trim to the correct length - if audio2.shape[0] < num_samples: - audio2 = np.hstack( - ( - audio2, - np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype), - ) - ) - elif audio2.shape[0] > num_samples: - audio2 = audio2[:num_samples] - - return audio2, sampling_rate2 - - -def combine_audio_aug( - audio: np.ndarray, - sampling_rate: float, - ann: AudioLoaderAnnotationGroup, - audio2: np.ndarray, - sampling_rate2: float, - ann2: AudioLoaderAnnotationGroup, -) -> Tuple[np.ndarray, AudioLoaderAnnotationGroup]: - """Combine two audio files. - - Will combine two audio files by resampling them to the same sampling rate - and then combining them with a random weight. The annotations will be - combined by taking the union of the two sets of annotations. - - Parameters - ---------- - audio: np.ndarray - First Audio to combine. - sampling_rate: int - Sampling rate of the first audio. - ann: AnnotationGroup - Annotations for the first audio. - audio2: np.ndarray - Second Audio to combine. - sampling_rate2: int - Sampling rate of the second audio. - ann2: AnnotationGroup - Annotations for the second audio. - - Returns - ------- - audio : np.ndarray - Combined audio. - ann : AnnotationGroup - Combined annotations. 
- """ - # resample so they are the same - audio2, sampling_rate2 = resample_audio( - audio.shape[0], - sampling_rate, - audio2, - sampling_rate2, - ) - - # # set mean and std to be the same - # audio2 = (audio2 - audio2.mean()) - # audio2 = (audio2/audio2.std())*audio.std() - # audio2 = audio2 + audio.mean() - - if ( - ann.get("annotated", False) - and (ann2.get("annotated", False)) - and (sampling_rate2 == sampling_rate) - and (audio.shape[0] == audio2.shape[0]) - ): - comb_weight = 0.3 + np.random.random() * 0.4 - audio = comb_weight * audio + (1 - comb_weight) * audio2 - inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"]))) - for kk in ann.keys(): - # when combining calls from different files, assume they come - # from different individuals - if kk == "individual_ids": - if (ann[kk] > -1).sum() > 0: - ann2[kk][ann2[kk] > -1] += ( - np.max(ann[kk][ann[kk] > -1]) + 1 - ) - - if (kk != "class_id_file") and (kk != "annotated"): - ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds] - - return audio, ann - - -def _prepare_annotation( - annotation: Annotation, - class_names: List[str], -) -> Annotation: - try: - class_id = class_names.index(annotation["class"]) - except ValueError: - class_id = -1 - - ann: Annotation = { - **annotation, - "class_id": class_id, - } - - if "individual" in ann: - ann["individual"] = int(ann["individual"]) # type: ignore - - return ann - - -def _prepare_file_annotation( - annotation: FileAnnotation, - class_names: List[str], - classes_to_ignore: List[str], -) -> AudioLoaderAnnotationGroup: - annotations = [ - _prepare_annotation(ann, class_names) - for ann in annotation["annotation"] - if ann["class"] not in classes_to_ignore - ] - - try: - class_id_file = class_names.index(annotation["class_name"]) - except ValueError: - class_id_file = -1 - - ret: AudioLoaderAnnotationGroup = { - "id": annotation["id"], - "annotated": annotation["annotated"], - "duration": annotation["duration"], - "issues": annotation["issues"], - "time_exp": annotation["time_exp"], - "class_name": annotation["class_name"], - "notes": annotation["notes"], - "annotation": annotations, - "start_times": np.array([ann["start_time"] for ann in annotations]), - "end_times": np.array([ann["end_time"] for ann in annotations]), - "high_freqs": np.array([ann["high_freq"] for ann in annotations]), - "low_freqs": np.array([ann["low_freq"] for ann in annotations]), - "class_ids": np.array( - [ann.get("class_id", -1) for ann in annotations] - ), - "individual_ids": np.array([ann["individual"] for ann in annotations]), - "class_id_file": class_id_file, - } - - return ret - - -class AudioLoader(torch.utils.data.Dataset): - """Main AudioLoader for training and testing.""" - - def __init__( - self, - data_anns_ip: List[FileAnnotation], - params: AudioLoaderParameters, - dataset_name: Optional[str] = None, - is_train: bool = False, - return_spec_for_viz: bool = False, - ): - self.is_train = is_train - self.params = params - self.return_spec_for_viz = return_spec_for_viz - self.data_anns: List[AudioLoaderAnnotationGroup] = [ - _prepare_file_annotation( - ann, - params["class_names"], - params["classes_to_ignore"], - ) - for ann in data_anns_ip - ] - - ann_cnt = [len(aa["annotation"]) for aa in self.data_anns] - self.max_num_anns = 2 * np.max( - ann_cnt - ) # x2 because we may be combining files during training - - print("\n") - if dataset_name is not None: - print("Dataset : " + dataset_name) - if self.is_train: - print("Split type : train") - else: - print("Split type : test") - print("Num files : " 
+ str(len(self.data_anns))) - print("Num calls : " + str(np.sum(ann_cnt))) - - def get_file_and_anns( - self, - index: Optional[int] = None, - ) -> Tuple[np.ndarray, float, float, AudioLoaderAnnotationGroup]: - """Get an audio file and its annotations. - - Parameters - ---------- - index : int, optional - Index of the file to be loaded. If None, a random file is chosen. - - Returns - ------- - audio_raw : np.ndarray - Loaded audio file. - sampling_rate : float - Sampling rate of the audio file. - duration : float - Duration of the audio file in seconds. - ann : AnnotationGroup - AnnotationGroup object containing the annotations for the audio file. - """ - # if no file specified, choose random one - if index is None: - index = np.random.randint(0, len(self.data_anns)) - - audio_file = self.data_anns[index]["file_path"] - sampling_rate, audio_raw = au.load_audio( - audio_file, - self.data_anns[index]["time_exp"], - self.params["target_samp_rate"], - self.params["scale_raw_audio"], - ) - - # copy annotation - ann = copy.deepcopy(self.data_anns[index]) - # ann["annotated"] = self.data_anns[index]["annotated"] - # ann["class_id_file"] = self.data_anns[index]["class_id_file"] - # keys = [ - # "start_times", - # "end_times", - # "high_freqs", - # "low_freqs", - # "class_ids", - # "individual_ids", - # ] - # for kk in keys: - # ann[kk] = self.data_anns[index][kk].copy() - - # if train then grab a random crop - if self.is_train: - nfft = int(self.params["fft_win_length"] * sampling_rate) - noverlap = int(self.params["fft_overlap"] * nfft) - length_samples = ( - self.params["spec_train_width"] * (nfft - noverlap) + noverlap - ) - - if audio_raw.shape[0] - length_samples > 0: - sample_crop = np.random.randint( - audio_raw.shape[0] - length_samples - ) - else: - sample_crop = 0 - audio_raw = audio_raw[sample_crop : sample_crop + length_samples] - ann["start_times"] = ann["start_times"] - sample_crop / float( - sampling_rate - ) - ann["end_times"] = ann["end_times"] - sample_crop / float( - sampling_rate - ) - - # pad audio - if self.is_train: - op_spec_target_size = self.params["spec_train_width"] - else: - op_spec_target_size = None - audio_raw = au.pad_audio( - audio_raw, - sampling_rate, - self.params["fft_win_length"], - self.params["fft_overlap"], - self.params["resize_factor"], - self.params["spec_divide_factor"], - op_spec_target_size, - ) - duration = audio_raw.shape[0] / float(sampling_rate) - - # sort based on time - inds = np.argsort(ann["start_times"]) - for kk in ann.keys(): - if (kk != "class_id_file") and (kk != "annotated"): - ann[kk] = ann[kk][inds] - - return audio_raw, sampling_rate, duration, ann - - def __getitem__(self, index): - """Get an item from the dataset.""" - # load audio file - audio, sampling_rate, duration, ann = self.get_file_and_anns(index) - - # augment on raw audio - if self.is_train and self.params["augment_at_train"]: - # augment - combine with random audio file - if ( - self.params["augment_at_train_combine"] - and np.random.random() < self.params["aug_prob"] - ): - ( - audio2, - sampling_rate2, - _, - ann2, - ) = self.get_file_and_anns() - audio, ann = combine_audio_aug( - audio, sampling_rate, ann, audio2, sampling_rate2, ann2 - ) - - # simulate echo by adding delayed copy of the file - if np.random.random() < self.params["aug_prob"]: - audio = echo_aug( - audio, - sampling_rate, - echo_max_delay=self.params["echo_max_delay"], - ) - - # resample the audio - # if np.random.random() < self.params["aug_prob"]: - # audio, sampling_rate, duration = resample_aug( - 
# audio, sampling_rate, self.params - # ) - - # create spectrogram - spec, _ = au.generate_spectrogram( - audio, - sampling_rate, - params=dict( - fft_win_length=self.params["fft_win_length"], - fft_overlap=self.params["fft_overlap"], - max_freq=self.params["max_freq"], - min_freq=self.params["min_freq"], - spec_scale=self.params["spec_scale"], - denoise_spec_avg=self.params["denoise_spec_avg"], - max_scale_spec=self.params["max_scale_spec"], - ), - ) - rsf = self.params["resize_factor"] - spec_op_shape = ( - int(self.params["spec_height"] * rsf), - int(spec.shape[1] * rsf), - ) - - # resize the spec - spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0) - spec = F.interpolate( - spec, - size=spec_op_shape, - mode="bilinear", - align_corners=False, - ).squeeze(0) - - # augment spectrogram - if self.is_train and self.params["augment_at_train"]: - if np.random.random() < self.params["aug_prob"]: - spec = scale_vol_aug( - spec, - spec_amp_scaling=self.params["spec_amp_scaling"], - ) - - if np.random.random() < self.params["aug_prob"]: - spec = warp_spec_aug( - spec, - ann, - stretch_squeeze_delta=self.params["stretch_squeeze_delta"], - ) - - if np.random.random() < self.params["aug_prob"]: - spec = mask_time_aug( - spec, - mask_max_time_perc=self.params["mask_max_time_perc"], - ) - - if np.random.random() < self.params["aug_prob"]: - spec = mask_freq_aug( - spec, - mask_max_freq_perc=self.params["mask_max_freq_perc"], - ) - - outputs = {} - outputs["spec"] = spec - if self.return_spec_for_viz: - outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze( - 0 - ) - - # create ground truth heatmaps - ( - outputs["y_2d_det"], - outputs["y_2d_size"], - outputs["y_2d_classes"], - ann_aug, - ) = generate_gt_heatmaps( - spec_op_shape, - sampling_rate, - ann, - class_names=self.params["class_names"], - fft_win_length=self.params["fft_win_length"], - fft_overlap=self.params["fft_overlap"], - max_freq=self.params["max_freq"], - min_freq=self.params["min_freq"], - resize_factor=self.params["resize_factor"], - target_sigma=self.params["target_sigma"], - ) - - # hack to get around requirement that all vectors are the same length - # in the output batch - pad_size = self.max_num_anns - len(ann_aug["individual_ids"]) - outputs["is_valid"] = pad_aray( - np.ones(len(ann_aug["individual_ids"])), pad_size - ) - keys = [ - "class_ids", - "individual_ids", - "x_inds", - "y_inds", - "start_times", - "end_times", - "low_freqs", - "high_freqs", - ] - for kk in keys: - outputs[kk] = pad_aray(ann_aug[kk], pad_size) - - # convert to pytorch - for kk in outputs.keys(): - if type(outputs[kk]) != torch.Tensor: - outputs[kk] = torch.from_numpy(outputs[kk]) - - # scalars - outputs["class_id_file"] = ann["class_id_file"] - outputs["annotated"] = ann["annotated"] - outputs["duration"] = duration - outputs["sampling_rate"] = sampling_rate - outputs["file_id"] = index - - return outputs - - def __len__(self): - """Denotes the total number of samples.""" - return len(self.data_anns) diff --git a/batdetect2/train/augmentations.py b/batdetect2/train/augmentations.py index 159d1a9..1fc2611 100644 --- a/batdetect2/train/augmentations.py +++ b/batdetect2/train/augmentations.py @@ -1,133 +1,161 @@ -from functools import wraps -from typing import Callable, List, Optional +from typing import Callable, Optional, Union import numpy as np import xarray as xr +from pydantic import Field +from soundevent import arrays +from soundevent.arrays import operations as ops + +from batdetect2.configs import BaseConfig +from 
batdetect2.preprocess import PreprocessingConfig, compute_spectrogram Augmentation = Callable[[xr.Dataset], xr.Dataset] -AUGMENTATION_PROBABILITY = 0.2 -MAX_DELAY = 0.005 -STRETCH_SQUEEZE_DELTA = 0.04 -MASK_MAX_TIME_PERC: float = 0.05 -MASK_MAX_FREQ_PERC: float = 0.10 +class AugmentationConfig(BaseConfig): + enable: bool = True + probability: float = 0.2 -def maybe_apply( - augmentation: Callable, - prob: float = AUGMENTATION_PROBABILITY, -) -> Callable: - """Apply an augmentation with a given probability.""" - - @wraps(augmentation) - def _augmentation(x): - if np.random.rand() > prob: - return x - return augmentation(x) - - return _augmentation +class SubclipConfig(BaseConfig): + enable: bool = True + duration: Optional[float] = None def select_random_subclip( - train_example: xr.Dataset, + example: xr.Dataset, + start_time: Optional[float] = None, duration: Optional[float] = None, - proportion: float = 0.9, + width: Optional[int] = None, ) -> xr.Dataset: """Select a random subclip from a clip.""" + step = arrays.get_dim_step(example, "time") # type: ignore - time_coords = train_example.coords["time"] + if width is None: + if duration is None: + raise ValueError("Either duration or width must be provided") - start_time = time_coords.attrs.get("min", time_coords.min()) - end_time = time_coords.attrs.get("max", time_coords.max()) + width = int(np.floor(duration / step)) if duration is None: - duration = (end_time - start_time) * proportion + duration = width * step - start_time = np.random.uniform(start_time, end_time - duration) - return train_example.sel(time=slice(start_time, start_time + duration)) + if start_time is None: + start, stop = arrays.get_dim_range(example, "time") # type: ignore + start_time = np.random.uniform(start, stop - duration) + + start_index = arrays.get_coord_index( + example, # type: ignore + "time", + start_time, + ) + end_index = start_index + width - 1 + start_time = example.time.values[start_index] + end_time = example.time.values[end_index] + + return example.sel( + time=slice(start_time, end_time), + audio_time=slice(start_time, end_time), + ) -def combine_audio( - audio1: xr.DataArray, - audio2: xr.DataArray, - alpha: Optional[float] = None, - min_alpha: float = 0.3, - max_alpha: float = 0.7, -) -> xr.DataArray: +class MixAugmentationConfig(AugmentationConfig): + min_weight: float = 0.3 + max_weight: float = 0.7 + + +def mix_examples( + example: xr.Dataset, + other: xr.Dataset, + weight: Optional[float] = None, + min_weight: float = 0.3, + max_weight: float = 0.7, + config: Optional[PreprocessingConfig] = None, +) -> xr.Dataset: """Combine two audio clips.""" + config = config or PreprocessingConfig() - if alpha is None: - alpha = np.random.uniform(min_alpha, max_alpha) + if weight is None: + weight = np.random.uniform(min_weight, max_weight) - return alpha * audio1 + (1 - alpha) * audio2.data + audio2 = other["audio"].values + audio1 = ops.adjust_dim_width(example["audio"], "audio_time", len(audio2)) + combined = weight * audio1 + (1 - weight) * audio2 + + spec = compute_spectrogram( + combined.rename({"audio_time": "time"}), + config=config.spectrogram, + ) + + detection_heatmap = xr.apply_ufunc( + np.maximum, + example["detection"], + other["detection"].values, + ) + + class_heatmap = xr.apply_ufunc( + np.maximum, + example["class"], + other["class"].values, + ) + + size_heatmap = example["size"] + other["size"].values + + return xr.Dataset( + { + "audio": combined, + "spectrogram": spec, + "detection": detection_heatmap, + "class": class_heatmap, + 
"size": size_heatmap, + } + ) -# def random_mix( -# audio: xr.DataArray, -# clip: data.ClipAnnotation, -# provider: Optional[ClipProvider] = None, -# alpha: Optional[float] = None, -# min_alpha: float = 0.3, -# max_alpha: float = 0.7, -# join_annotations: bool = True, -# ) -> Tuple[xr.DataArray, data.ClipAnnotation]: -# """Mix two audio clips.""" -# if provider is None: -# raise ValueError("No audio provider given.") -# -# try: -# other_audio, other_clip = provider(clip) -# except (StopIteration, ValueError): -# raise ValueError("No more audio sources available.") -# -# new_audio = combine_audio( -# audio, -# other_audio, -# alpha=alpha, -# min_alpha=min_alpha, -# max_alpha=max_alpha, -# ) -# -# if join_annotations: -# clip = clip.model_copy( -# update=dict( -# sound_events=clip.sound_events + other_clip.sound_events, -# ) -# ) -# -# return new_audio, clip +class EchoAugmentationConfig(AugmentationConfig): + max_delay: float = 0.005 + min_weight: float = 0.0 + max_weight: float = 1.0 def add_echo( - train_example: xr.Dataset, + example: xr.Dataset, delay: Optional[float] = None, - alpha: Optional[float] = None, - min_alpha: float = 0.0, - max_alpha: float = 1.0, - max_delay: float = MAX_DELAY, + weight: Optional[float] = None, + min_weight: float = 0.1, + max_weight: float = 1.0, + max_delay: float = 0.005, + config: Optional[PreprocessingConfig] = None, ) -> xr.Dataset: """Add a delay to the audio.""" + config = config or PreprocessingConfig() + if delay is None: delay = np.random.uniform(0, max_delay) - if alpha is None: - alpha = np.random.uniform(min_alpha, max_alpha) + if weight is None: + weight = np.random.uniform(min_weight, max_weight) - spec = train_example["spectrogram"] + audio = example["audio"] + step = arrays.get_dim_step(audio, "audio_time") + audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0) + audio = audio + weight * audio_delay - time_coords = spec.coords["time"] - start_time = time_coords.attrs["min"] - end_time = time_coords.attrs["max"] - step = (end_time - start_time) / time_coords.size + spectrogram = compute_spectrogram( + audio.rename({"audio_time": "time"}), + config=config.spectrogram, + ) - spec_delay = spec.shift(time=int(delay / step), fill_value=0) + return example.assign(audio=audio, spectrogram=spectrogram) - return train_example.assign(spectrogram=spec + alpha * spec_delay) + +class VolumeAugmentationConfig(AugmentationConfig): + min_scaling: float = 0.0 + max_scaling: float = 2.0 def scale_volume( - train_example: xr.Dataset, + example: xr.Dataset, factor: Optional[float] = None, max_scaling: float = 2, min_scaling: float = 0, @@ -136,106 +164,235 @@ def scale_volume( if factor is None: factor = np.random.uniform(min_scaling, max_scaling) - return train_example.assign( - spectrogram=train_example["spectrogram"] * factor - ) + return example.assign(spectrogram=example["spectrogram"] * factor) + + +class WarpAugmentationConfig(AugmentationConfig): + delta: float = 0.04 def warp_spectrogram( - train_example: xr.Dataset, + example: xr.Dataset, factor: Optional[float] = None, - delta: float = STRETCH_SQUEEZE_DELTA, + delta: float = 0.04, ) -> xr.Dataset: """Warp a spectrogram.""" if factor is None: factor = np.random.uniform(1 - delta, 1 + delta) - time_coords = train_example.coords["time"] - start_time = time_coords.attrs["min"] - end_time = time_coords.attrs["max"] + start_time, end_time = arrays.get_dim_range(example, "time") # type: ignore duration = end_time - start_time new_time = np.linspace( start_time, start_time + duration * 
factor, - train_example.time.size, + example.time.size, ) - return train_example.interp(time=new_time) + spectrogram = ( + example["spectrogram"] + .interp( + coords={"time": new_time}, + method="linear", + kwargs=dict( + fill_value=0, + ), + ) + .clip(min=0) + ) + + detection = example["detection"].interp( + time=new_time, + method="nearest", + kwargs=dict( + fill_value=0, + ), + ) + + classification = example["class"].interp( + time=new_time, + method="nearest", + kwargs=dict( + fill_value=0, + ), + ) + + size = example["size"].interp( + time=new_time, + method="nearest", + kwargs=dict( + fill_value=0, + ), + ) + + return example.assign( + { + "time": new_time, + "spectrogram": spectrogram, + "detection": detection, + "class": classification, + "size": size, + } + ) def mask_axis( - train_example: xr.Dataset, + array: xr.DataArray, dim: str, start: float, end: float, - mask_all: bool = False, - mask_value: float = 0, -) -> xr.Dataset: - if dim not in train_example.dims: + mask_value: Union[float, Callable[[xr.DataArray], float]] = np.mean, +) -> xr.DataArray: + if dim not in array.dims: raise ValueError(f"Axis {dim} not found in array") - coord = train_example.coords[dim] + coord = array.coords[dim] condition = (coord < start) | (coord > end) - if mask_all: - return train_example.where(condition, other=mask_value) + if callable(mask_value): + mask_value = mask_value(array) - return train_example.assign( - spectrogram=train_example.spectrogram.where( - condition, other=mask_value - ) - ) + return array.where(condition, other=mask_value) + + +class TimeMaskAugmentationConfig(AugmentationConfig): + max_perc: float = 0.05 + max_masks: int = 3 def mask_time( - train_example: xr.Dataset, - max_time_mask: float = MASK_MAX_TIME_PERC, - max_num_masks: int = 3, + example: xr.Dataset, + max_perc: float = 0.05, + max_mask: int = 3, ) -> xr.Dataset: """Mask a random section of the time axis.""" + num_masks = np.random.randint(1, max_mask + 1) + start_time, end_time = arrays.get_dim_range(example, "time") # type: ignore - num_masks = np.random.randint(1, max_num_masks + 1) - - time_coord = train_example.coords["time"] - start_time = time_coord.attrs.get("min", time_coord.min()) - end_time = time_coord.attrs.get("max", time_coord.max()) - + spectrogram = example["spectrogram"] for _ in range(num_masks): - mask_size = np.random.uniform(0, max_time_mask) + mask_size = np.random.uniform(0, max_perc) * (end_time - start_time) start = np.random.uniform(start_time, end_time - mask_size) end = start + mask_size - train_example = mask_axis(train_example, "time", start, end) + spectrogram = mask_axis(spectrogram, "time", start, end) - return train_example + return example.assign(spectrogram=spectrogram) + + +class FrequencyMaskAugmentationConfig(AugmentationConfig): + max_perc: float = 0.10 + max_masks: int = 3 def mask_frequency( - train_example: xr.Dataset, - max_freq_mask: float = MASK_MAX_FREQ_PERC, - max_num_masks: int = 3, + example: xr.Dataset, + max_perc: float = 0.10, + max_masks: int = 3, ) -> xr.Dataset: """Mask a random section of the frequency axis.""" + num_masks = np.random.randint(1, max_masks + 1) + min_freq, max_freq = arrays.get_dim_range(example, "frequency") # type: ignore - num_masks = np.random.randint(1, max_num_masks + 1) - - freq_coord = train_example.coords["frequency"] - min_freq = float(freq_coord.min()) - max_freq = float(freq_coord.max()) - + spectrogram = example["spectrogram"] for _ in range(num_masks): - mask_size = np.random.uniform(0, max_freq_mask) + mask_size = 
np.random.uniform(0, max_perc) * (max_freq - min_freq) start = np.random.uniform(min_freq, max_freq - mask_size) end = start + mask_size - train_example = mask_axis(train_example, "frequency", start, end) + spectrogram = mask_axis(spectrogram, "frequency", start, end) - return train_example + return example.assign(spectrogram=spectrogram) -AUGMENTATIONS: List[Augmentation] = [ - select_random_subclip, - add_echo, - scale_volume, - mask_time, - mask_frequency, -] +class AugmentationsConfig(BaseConfig): + subclip: SubclipConfig = Field(default_factory=SubclipConfig) + mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig) + echo: EchoAugmentationConfig = Field( + default_factory=EchoAugmentationConfig + ) + volume: VolumeAugmentationConfig = Field( + default_factory=VolumeAugmentationConfig + ) + warp: WarpAugmentationConfig = Field( + default_factory=WarpAugmentationConfig + ) + time_mask: TimeMaskAugmentationConfig = Field( + default_factory=TimeMaskAugmentationConfig + ) + frequency_mask: FrequencyMaskAugmentationConfig = Field( + default_factory=FrequencyMaskAugmentationConfig + ) + + +def should_apply(config: AugmentationConfig) -> bool: + if not config.enable: + return False + + return np.random.uniform() < config.probability + + +def augment_example( + example: xr.Dataset, + config: AugmentationsConfig, + preprocessing_config: Optional[PreprocessingConfig] = None, + others: Optional[Callable[[], xr.Dataset]] = None, +) -> xr.Dataset: + if config.subclip.enable: + example = select_random_subclip( + example, + duration=config.subclip.duration, + ) + + if should_apply(config.mix) and others is not None: + other = others() + + if config.subclip.enable: + other = select_random_subclip( + other, + duration=config.subclip.duration, + ) + + example = mix_examples( + example, + other, + min_weight=config.mix.min_weight, + max_weight=config.mix.max_weight, + config=preprocessing_config, + ) + + if should_apply(config.echo): + example = add_echo( + example, + max_delay=config.echo.max_delay, + min_weight=config.echo.min_weight, + max_weight=config.echo.max_weight, + config=preprocessing_config, + ) + + if should_apply(config.volume): + example = scale_volume( + example, + max_scaling=config.volume.max_scaling, + min_scaling=config.volume.min_scaling, + ) + + if should_apply(config.warp): + example = warp_spectrogram( + example, + delta=config.warp.delta, + ) + + if should_apply(config.time_mask): + example = mask_time( + example, + max_perc=config.time_mask.max_perc, + max_mask=config.time_mask.max_masks, + ) + + if should_apply(config.frequency_mask): + example = mask_frequency( + example, + max_perc=config.frequency_mask.max_perc, + max_masks=config.frequency_mask.max_masks, + ) + + return example diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py index 19d1dc2..2c7b04c 100644 --- a/batdetect2/train/dataset.py +++ b/batdetect2/train/dataset.py @@ -1,12 +1,14 @@ import os from pathlib import Path -from typing import Callable, Dict, NamedTuple, Optional, Sequence, Union +from typing import NamedTuple, Optional, Sequence, Union +import numpy as np import torch import xarray as xr from soundevent import data from torch.utils.data import Dataset +from batdetect2.train.augmentations import AugmentationsConfig, augment_example from batdetect2.train.preprocess import PreprocessingConfig __all__ = [ @@ -34,21 +36,44 @@ class LabeledDataset(Dataset): def __init__( self, filenames: Sequence[PathLike], - transform: Optional[Callable[[xr.Dataset], xr.Dataset]] = None, 
+ augment: bool = False, + preprocessing_config: Optional[PreprocessingConfig] = None, + augmentation_config: Optional[AugmentationsConfig] = None, ): self.filenames = filenames - self.transform = transform + self.augment = augment + self.preprocessing_config = ( + preprocessing_config or PreprocessingConfig() + ) + self.agumentation_config = augmentation_config or AugmentationsConfig() def __len__(self): return len(self.filenames) def __getitem__(self, idx) -> TrainExample: - data = self.load(idx) + dataset = self.get_dataset(idx) + + if self.augment: + dataset = augment_example( + dataset, + self.agumentation_config, + preprocessing_config=self.preprocessing_config, + others=self.get_random_example, + ) + return TrainExample( - spec=data["spectrogram"], - detection_heatmap=data["detection"], - class_heatmap=data["class"], - size_heatmap=data["size"], + spec=torch.tensor( + dataset["spectrogram"].values.astype(np.float32) + ).unsqueeze(0), + detection_heatmap=torch.tensor( + dataset["detection"].values.astype(np.float32) + ), + class_heatmap=torch.tensor( + dataset["class"].values.astype(np.float32) + ), + size_heatmap=torch.tensor( + dataset["size"].values.astype(np.float32) + ), idx=torch.tensor(idx), ) @@ -56,44 +81,14 @@ class LabeledDataset(Dataset): def from_directory(cls, directory: PathLike, extension: str = ".nc"): return cls(get_files(directory, extension)) - def load(self, idx) -> Dict[str, torch.Tensor]: - dataset = self.get_dataset(idx) - return { - "spectrogram": torch.tensor( - dataset["spectrogram"].values - ).unsqueeze(0), - "detection": torch.tensor(dataset["detection"].values), - "class": torch.tensor(dataset["class"].values), - "size": torch.tensor(dataset["size"].values), - } + def get_random_example(self) -> xr.Dataset: + idx = np.random.randint(0, len(self)) + return self.get_dataset(idx) - def apply_augmentation(self, dataset: xr.Dataset) -> xr.Dataset: - if self.transform is not None: - return self.transform(dataset) - - return dataset - - def get_dataset(self, idx): + def get_dataset(self, idx) -> xr.Dataset: return xr.open_dataset(self.filenames[idx]) - def get_spectrogram(self, idx): - return xr.open_dataset(self.filenames[idx])["spectrogram"] - - def get_detection_mask(self, idx): - return xr.open_dataset(self.filenames[idx])["detection"] - - def get_class_mask(self, idx): - return xr.open_dataset(self.filenames[idx])["class"] - - def get_size_mask(self, idx): - return xr.open_dataset(self.filenames[idx])["size"] - - def get_clip_annotation(self, idx): - filename = self.filenames[idx] - dataset = xr.open_dataset(filename) - clip_annotation = dataset.attrs["clip_annotation"] - return data.ClipAnnotation.model_validate_json(clip_annotation) - - def get_preprocessing_configuration(self, idx): - config = xr.open_dataset(self.filenames[idx]).attrs["configuration"] - return PreprocessingConfig.model_validate_json(config) + def get_clip_annotation(self, idx) -> data.ClipAnnotation: + return data.ClipAnnotation.model_validate_json( + self.get_dataset(idx).attrs["clip_annotation"] + ) diff --git a/batdetect2/train/labels.py b/batdetect2/train/labels.py index a1dc340..2a38892 100644 --- a/batdetect2/train/labels.py +++ b/batdetect2/train/labels.py @@ -1,18 +1,14 @@ -from typing import Sequence, Tuple +from typing import Callable, List, Optional, Sequence, Tuple import numpy as np import xarray as xr from scipy.ndimage import gaussian_filter from soundevent import arrays, data, geometry from soundevent.geometry.operations import Positions -from soundevent.types import 
ClassMapper from batdetect2.configs import BaseConfig -__all__ = [ - "ClassMapper", - "generate_heatmaps", -] +__all__ = ["generate_heatmaps"] class HeatmapsConfig(BaseConfig): @@ -25,7 +21,8 @@ class HeatmapsConfig(BaseConfig): def generate_heatmaps( sound_events: Sequence[data.SoundEventAnnotation], spec: xr.DataArray, - class_mapper: ClassMapper, + class_names: List[str], + encoder: Callable[[data.SoundEventAnnotation], Optional[str]], target_sigma: float = 3.0, position: Positions = "bottom-left", time_scale: float = 1000.0, @@ -42,10 +39,10 @@ def generate_heatmaps( # Initialize heatmaps detection_heatmap = xr.zeros_like(spec, dtype=dtype) class_heatmap = xr.DataArray( - data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype), + data=np.zeros((len(class_names), *spec.shape), dtype=dtype), dims=["category", *spec.dims], coords={ - "category": [*class_mapper.class_labels], + "category": [*class_names], **spec.coords, }, ) @@ -92,7 +89,7 @@ def generate_heatmaps( ) # Get the class name of the sound event - class_name = class_mapper.encode(sound_event_annotation) + class_name = encoder(sound_event_annotation) if class_name is None: # If the label is None skip the sound event diff --git a/batdetect2/train/light.py b/batdetect2/train/light.py deleted file mode 100644 index 2cbc047..0000000 --- a/batdetect2/train/light.py +++ /dev/null @@ -1,56 +0,0 @@ -import pytorch_lightning as L -from torch import Tensor, optim - -from batdetect2.models.typing import DetectionModel, ModelOutput -from batdetect2.train import losses - -from batdetect2.train.dataset import TrainExample - - -__all__ = [ - "LitDetectorModel", -] - - -class LitDetectorModel(L.LightningModule): - model: DetectionModel - - def __init__(self, model: DetectionModel, learning_rate: float = 1e-3): - super().__init__() - self.model = model - self.learning_rate = learning_rate - - def compute_loss( - self, - outputs: ModelOutput, - batch: TrainExample, - ) -> Tensor: - detection_loss = losses.focal_loss( - outputs.detection_probs, - batch.detection_heatmap, - ) - - size_loss = losses.bbox_size_loss( - outputs.size_preds, - batch.size_heatmap, - ) - - valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float() - classification_loss = losses.focal_loss( - outputs.class_probs, - batch.class_heatmap, - valid_mask=valid_mask, - ) - - return detection_loss + size_loss + classification_loss - - def training_step(self, batch: TrainExample, batch_idx: int): # type: ignore - outputs: ModelOutput = self.model(batch.spec) - loss = self.compute_loss(outputs, batch) - self.log("train_loss", loss) - return loss - - def configure_optimizers(self): - optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100) - return [optimizer], [scheduler] diff --git a/batdetect2/train/losses.py b/batdetect2/train/losses.py index 22a7eea..e88fbfb 100644 --- a/batdetect2/train/losses.py +++ b/batdetect2/train/losses.py @@ -1,7 +1,16 @@ -from typing import Optional +from typing import NamedTuple, Optional import torch import torch.nn.functional as F +from pydantic import Field + +from batdetect2.configs import BaseConfig +from batdetect2.models.typing import ModelOutput +from batdetect2.train.dataset import TrainExample + + +class SizeLossConfig(BaseConfig): + weight: float = 0.1 def bbox_size_loss( @@ -17,6 +26,11 @@ def bbox_size_loss( ) +class FocalLossConfig(BaseConfig): + beta: float = 4 + alpha: float = 2 + + def focal_loss( pred: torch.Tensor, gt: torch.Tensor, @@ -44,7 
+58,7 @@ def focal_loss( ) if weights is not None: - pos_loss = pos_loss * weights + pos_loss = pos_loss * torch.tensor(weights) # neg_loss = neg_loss*weights if valid_mask is not None: @@ -75,3 +89,71 @@ def mse_loss( else: op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum() return op + + +class DetectionLossConfig(BaseConfig): + weight: float = 1.0 + focal: FocalLossConfig = Field(default_factory=FocalLossConfig) + + +class ClassificationLossConfig(BaseConfig): + weight: float = 2.0 + focal: FocalLossConfig = Field(default_factory=FocalLossConfig) + class_weights: Optional[list[float]] = None + + +class LossConfig(BaseConfig): + detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig) + size: SizeLossConfig = Field(default_factory=SizeLossConfig) + classification: ClassificationLossConfig = Field( + default_factory=ClassificationLossConfig + ) + + +class Losses(NamedTuple): + detection: torch.Tensor + size: torch.Tensor + classification: torch.Tensor + total: torch.Tensor + + +def compute_loss( + batch: TrainExample, + outputs: ModelOutput, + conf: LossConfig, + class_weights: Optional[torch.Tensor] = None, +) -> Losses: + detection_loss = focal_loss( + outputs.detection_probs, + batch.detection_heatmap, + beta=conf.detection.focal.beta, + alpha=conf.detection.focal.alpha, + ) + + size_loss = bbox_size_loss( + outputs.size_preds, + batch.size_heatmap, + ) + + valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float() + classification_loss = focal_loss( + outputs.class_probs, + batch.class_heatmap, + weights=class_weights, + valid_mask=valid_mask, + beta=conf.classification.focal.beta, + alpha=conf.classification.focal.alpha, + ) + + total = ( + detection_loss * conf.detection.weight + + size_loss * conf.size.weight + + classification_loss * conf.classification.weight + ) + + return Losses( + detection=detection_loss, + size=size_loss, + classification=classification_loss, + total=total, + ) diff --git a/batdetect2/train/modules.py b/batdetect2/train/modules.py new file mode 100644 index 0000000..5ae5631 --- /dev/null +++ b/batdetect2/train/modules.py @@ -0,0 +1,98 @@ +from typing import Optional + +import pytorch_lightning as L +import torch +from pydantic import Field +from torch import optim + +from batdetect2.configs import BaseConfig +from batdetect2.models import ( + BBoxHead, + ClassifierHead, + ModelConfig, + get_backbone, +) +from batdetect2.models.typing import ModelOutput +from batdetect2.preprocess import PreprocessingConfig +from batdetect2.train.dataset import TrainExample +from batdetect2.train.losses import LossConfig, compute_loss +from batdetect2.train.targets import TargetConfig + + +class OptimizerConfig(BaseConfig): + learning_rate: float = 1e-3 + t_max: int = 100 + + +class TrainingConfig(BaseConfig): + loss: LossConfig = Field(default_factory=LossConfig) + optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) + + +class ModuleConfig(BaseConfig): + train: TrainingConfig = Field(default_factory=TrainingConfig) + targets: TargetConfig = Field(default_factory=TargetConfig) + optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) + backbone: ModelConfig = Field(default_factory=ModelConfig) + preprocessing: PreprocessingConfig = Field( + default_factory=PreprocessingConfig + ) + + +class DetectorModel(L.LightningModule): + config: ModuleConfig + + def __init__( + self, + config: Optional[ModuleConfig] = None, + ): + super().__init__() + self.config = config or ModuleConfig() + self.save_hyperparameters() + + 
self.backbone = get_backbone( + input_height=self.config.preprocessing.spectrogram.size.height, + config=self.config.backbone, + ) + + self.classifier = ClassifierHead( + num_classes=len(self.config.targets.classes), + in_channels=self.backbone.out_channels, + ) + + self.bbox = BBoxHead(in_channels=self.backbone.out_channels) + + conf = self.training_config.loss.classification + self.class_weights = ( + torch.tensor(conf.class_weights) if conf.class_weights else None + ) + + def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore + features = self.backbone(spec) + detection_probs, classification_probs = self.classifier(features) + size_preds = self.bbox(features) + return ModelOutput( + detection_probs=detection_probs, + size_preds=size_preds, + class_probs=classification_probs, + features=features, + ) + + def training_step(self, batch: TrainExample): + outputs = self.forward(batch.spec) + losses = compute_loss( + batch, + outputs, + conf=self.config.train.loss, + class_weights=self.class_weights, + ) + return losses.total + + def configure_optimizers(self): + conf = self.config.train.optimizer + optimizer = optim.Adam(self.parameters(), lr=conf.learning_rate) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=conf.t_max, + ) + return [optimizer], [scheduler] diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py index 309d49d..fbcad4c 100644 --- a/batdetect2/train/preprocess.py +++ b/batdetect2/train/preprocess.py @@ -20,8 +20,9 @@ from batdetect2.preprocess import ( from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps from batdetect2.train.targets import ( TargetConfig, - build_class_mapper, + build_encoder, build_sound_event_filter, + get_class_names, ) PathLike = Union[Path, str, os.PathLike] @@ -64,11 +65,15 @@ def generate_train_example( selected_events = [ event for event in clip_annotation.sound_events if filter_fn(event) ] - class_mapper = build_class_mapper(config.target.classes) + + class_names = get_class_names(config.target.classes) + encoder = build_encoder(config.target.classes) + detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( selected_events, spectrogram, - class_mapper, + class_names, + encoder, target_sigma=config.heatmaps.sigma, position=config.heatmaps.position, time_scale=config.heatmaps.time_scale, diff --git a/batdetect2/train/targets.py b/batdetect2/train/targets.py index dc23259..c92b724 100644 --- a/batdetect2/train/targets.py +++ b/batdetect2/train/targets.py @@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Set from pydantic import Field from soundevent import data -from soundevent.types import ClassMapper from batdetect2.configs import BaseConfig from batdetect2.terms import TagInfo, get_tag_from_info @@ -12,11 +11,25 @@ from batdetect2.terms import TagInfo, get_tag_from_info class TargetConfig(BaseConfig): """Configuration for target generation.""" - classes: List[TagInfo] = Field(default_factory=list) - generic_class: Optional[TagInfo] = None + classes: List[TagInfo] = Field( + default_factory=lambda: [ + TagInfo(key="class", value=value) for value in DEFAULT_SPECIES_LIST + ] + ) + generic_class: Optional[TagInfo] = Field( + default_factory=lambda: TagInfo(key="class", value="Bat") + ) - include: Optional[List[TagInfo]] = None - exclude: Optional[List[TagInfo]] = None + include: Optional[List[TagInfo]] = Field( + default_factory=lambda: [TagInfo(key="event", value="Echolocation")] + ) + exclude: Optional[List[TagInfo]] = Field( + default_factory=lambda: 
[ + TagInfo(key="class", value=""), + TagInfo(key="class", value=" "), + TagInfo(key="class", value="Unknown"), + ] + ) def build_sound_event_filter( @@ -36,13 +49,54 @@ def build_sound_event_filter( ) -def build_class_mapper(classes: List[TagInfo]) -> ClassMapper: - target_tags = [get_tag_from_info(tag) for tag in classes] - labels = [tag.label if tag.label else tag.value for tag in classes] - return GenericMapper( - classes=target_tags, - labels=labels, - ) +def get_tag_label(tag_info: TagInfo) -> str: + return tag_info.label if tag_info.label else tag_info.value + + +def get_class_names(classes: List[TagInfo]) -> List[str]: + return sorted({get_tag_label(tag) for tag in classes}) + + +def build_encoder( + classes: List[TagInfo], +) -> Callable[[data.SoundEventAnnotation], Optional[str]]: + target_tags = set([get_tag_from_info(tag) for tag in classes]) + + tag_mapping = { + tag: get_tag_label(tag_info) + for tag, tag_info in zip(target_tags, classes) + } + + def encoder( + sound_event_annotation: data.SoundEventAnnotation, + ) -> Optional[str]: + tags = set(sound_event_annotation.tags) + + intersection = tags & target_tags + + if not intersection: + return None + + first = intersection.pop() + return tag_mapping[first] + + return encoder + + +def build_decoder( + classes: List[TagInfo], +) -> Callable[[str], List[data.Tag]]: + target_tags = set([get_tag_from_info(tag) for tag in classes]) + tag_mapping = { + get_tag_label(tag_info): tag + for tag, tag_info in zip(target_tags, classes) + } + + def decoder(label: str) -> List[data.Tag]: + tag = tag_mapping.get(label) + return [tag] if tag else [] + + return decoder def filter_sound_event( @@ -61,39 +115,22 @@ def filter_sound_event( return True -class GenericMapper(ClassMapper): - """Generic class mapper configuration.""" - - def __init__( - self, - classes: List[data.Tag], - labels: List[str], - ): - if not len(classes) == len(labels): - raise ValueError("Number of targets and class labels must match.") - - self.targets = set(classes) - self.class_labels = list(dict.fromkeys(labels)) - - self._mapping = {tag: label for tag, label in zip(classes, labels)} - self._inverse_mapping = { - label: tag for tag, label in zip(classes, labels) - } - - def encode( - self, - sound_event_annotation: data.SoundEventAnnotation, - ) -> Optional[str]: - tags = set(sound_event_annotation.tags) - - intersection = tags & self.targets - if not intersection: - return None - - tag = intersection.pop() - return self._mapping[tag] - - def decode(self, label: str) -> List[data.Tag]: - if label not in self._inverse_mapping: - return [] - return [self._inverse_mapping[label]] +DEFAULT_SPECIES_LIST = [ + "Barbastellus barbastellus", + "Eptesicus serotinus", + "Myotis alcathoe", + "Myotis bechsteinii", + "Myotis brandtii", + "Myotis daubentonii", + "Myotis mystacinus", + "Myotis nattereri", + "Nyctalus leisleri", + "Nyctalus noctula", + "Pipistrellus nathusii", + "Pipistrellus pipistrellus", + "Pipistrellus pygmaeus", + "Plecotus auritus", + "Plecotus austriacus", + "Rhinolophus ferrumequinum", + "Rhinolophus hipposideros", +] diff --git a/tests/conftest.py b/tests/conftest.py index 2c7e5a7..f34e8fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -91,7 +91,7 @@ def recording_factory(wav_factory: Callable[..., Path]): recording_id: Optional[uuid.UUID] = None, duration: float = 1, channels: int = 1, - samplerate: int = 44100, + samplerate: int = 256_000, time_expansion: float = 1, ) -> data.Recording: path = path or wav_factory( diff --git 
a/tests/test_migration/test_preprocessing.py b/tests/test_migration/test_preprocessing.py index d9ee89a..70fd828 100644 --- a/tests/test_migration/test_preprocessing.py +++ b/tests/test_migration/test_preprocessing.py @@ -80,11 +80,12 @@ def test_spectrogram_generation_hasnt_changed( max_freq = 120_000 fft_overlap = 0.75 - scale = None if spec_scale == "log": - scale = "log" + scale = preprocess.LogScaleConfig() elif spec_scale == "pcen": - scale = preprocess.PcenConfig() + scale = preprocess.PcenScaleConfig() + else: + scale = preprocess.AmplitudeScaleConfig() config = preprocess.SpectrogramConfig( fft=preprocess.FFTConfig( diff --git a/tests/test_train/__init__.py b/tests/test_train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py new file mode 100644 index 0000000..bc2a1a1 --- /dev/null +++ b/tests/test_train/test_augmentations.py @@ -0,0 +1,70 @@ +from collections.abc import Callable + +import xarray as xr +from soundevent import data + +from batdetect2.train.augmentations import ( + add_echo, + mix_examples, + select_random_subclip, +) +from batdetect2.train.preprocess import ( + TrainPreprocessingConfig, + generate_train_example, +) + + +def test_mix_examples( + recording_factory: Callable[..., data.Recording], +): + recording1 = recording_factory() + recording2 = recording_factory() + + clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7) + + clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8) + + clip_annotation_1 = data.ClipAnnotation(clip=clip1) + + clip_annotation_2 = data.ClipAnnotation(clip=clip2) + + config = TrainPreprocessingConfig() + + example1 = generate_train_example(clip_annotation_1, config) + example2 = generate_train_example(clip_annotation_2, config) + + mixed = mix_examples(example1, example2, config=config.preprocessing) + + assert mixed["spectrogram"].shape == example1["spectrogram"].shape + assert mixed["detection"].shape == example1["detection"].shape + assert mixed["size"].shape == example1["size"].shape + assert mixed["class"].shape == example1["class"].shape + + +def test_add_echo( + recording_factory: Callable[..., data.Recording], +): + recording1 = recording_factory() + clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7) + clip_annotation_1 = data.ClipAnnotation(clip=clip1) + config = TrainPreprocessingConfig() + original = generate_train_example(clip_annotation_1, config) + with_echo = add_echo(original, config=config.preprocessing) + + assert with_echo["spectrogram"].shape == original["spectrogram"].shape + xr.testing.assert_identical(with_echo["size"], original["size"]) + xr.testing.assert_identical(with_echo["class"], original["class"]) + xr.testing.assert_identical(with_echo["detection"], original["detection"]) + + +def test_selected_random_subclip_has_the_correct_width( + recording_factory: Callable[..., data.Recording], +): + recording1 = recording_factory() + clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7) + clip_annotation_1 = data.ClipAnnotation(clip=clip1) + config = TrainPreprocessingConfig() + original = generate_train_example(clip_annotation_1, config) + subclip = select_random_subclip(original, width=100) + + assert subclip["spectrogram"].shape[1] == 100 diff --git a/tests/test_data/test_labels.py b/tests/test_train/test_labels.py similarity index 96% rename from tests/test_data/test_labels.py rename to tests/test_train/test_labels.py index d2d7488..787027d 100644 --- 
a/tests/test_data/test_labels.py +++ b/tests/test_train/test_labels.py @@ -5,7 +5,7 @@ import xarray as xr from soundevent import data from soundevent.types import ClassMapper -from batdetect2.data.labels import generate_heatmaps +from batdetect2.train.labels import generate_heatmaps recording = data.Recording( samplerate=256_000, @@ -60,7 +60,7 @@ def test_generated_heatmaps_have_correct_dimensions(): class_mapper = Mapper() detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( - clip_annotation, + clip_annotation.sound_events, spec, class_mapper, ) @@ -109,7 +109,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(): class_mapper = Mapper() detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( - clip_annotation, + clip_annotation.sound_events, spec, class_mapper, )
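A minimal usage sketch of the reworked pieces introduced by this patch, for orientation only (it is not part of the diff). Module paths, class names, and defaults are taken from the files changed above; the directory name, batch size, and trainer settings are illustrative assumptions.

import pytorch_lightning as pl
from torch.utils.data import DataLoader

from batdetect2.train.dataset import LabeledDataset
from batdetect2.train.modules import DetectorModel, ModuleConfig

# Load the .nc training examples written by
# batdetect2.train.preprocess.generate_train_example.
dataset = LabeledDataset.from_directory("train_examples/")  # path is illustrative

# To enable the new augmentation pipeline, construct the dataset directly:
# dataset = LabeledDataset(dataset.filenames, augment=True)

loader = DataLoader(dataset, batch_size=8, shuffle=True)

# DetectorModel assembles the backbone, classifier head, and bbox head
# from a single ModuleConfig (backbone, targets, preprocessing, losses).
model = DetectorModel(config=ModuleConfig())

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, train_dataloaders=loader)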