Reworking model creation

mbsantiago 2024-11-19 19:34:54 +00:00
parent 36c90a600f
commit 9cf159efff
24 changed files with 967 additions and 1484 deletions

View File

@ -1,10 +1,13 @@
from batdetect2.preprocess import ( from batdetect2.preprocess import (
AmplitudeScaleConfig,
AudioConfig, AudioConfig,
FFTConfig, FFTConfig,
FrequencyConfig, FrequencyConfig,
PcenConfig, LogScaleConfig,
PcenScaleConfig,
PreprocessingConfig, PreprocessingConfig,
ResampleConfig, ResampleConfig,
Scales,
SpecSizeConfig, SpecSizeConfig,
SpectrogramConfig, SpectrogramConfig,
) )
@ -17,12 +20,12 @@ from batdetect2.train.preprocess import (
) )
def get_spectrogram_scale(scale: str): def get_spectrogram_scale(scale: str) -> Scales:
if scale == "pcen": if scale == "pcen":
return PcenConfig() return PcenScaleConfig()
if scale == "log": if scale == "log":
return "log" return LogScaleConfig()
return None return AmplitudeScaleConfig()
def get_preprocessing_config(params: dict) -> PreprocessingConfig: def get_preprocessing_config(params: dict) -> PreprocessingConfig:
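
A brief, hedged illustration of the updated helper: every branch now returns a concrete scale config that plugs directly into SpectrogramConfig (class names taken from this commit; the surrounding usage is assumed).

    from batdetect2.preprocess import PcenScaleConfig, SpectrogramConfig

    # What get_spectrogram_scale("pcen") now returns: a config object rather
    # than a bare string or None.
    scale = PcenScaleConfig()
    spec_config = SpectrogramConfig(scale=scale)
    print(spec_config.scale.name)  # "pcen", the discriminator field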

View File

@ -1,11 +1,54 @@
from batdetect2.models.feature_extractors import ( from enum import Enum
from batdetect2.configs import BaseConfig
from batdetect2.models.backbones import (
Net2DFast, Net2DFast,
Net2DFastNoAttn, Net2DFastNoAttn,
Net2DFastNoCoordConv, Net2DFastNoCoordConv,
) )
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.typing import BackboneModel
__all__ = [ __all__ = [
"get_backbone",
"Net2DFast", "Net2DFast",
"Net2DFastNoAttn", "Net2DFastNoAttn",
"Net2DFastNoCoordConv", "Net2DFastNoCoordConv",
"ModelType",
"BBoxHead",
"ClassifierHead",
] ]
class ModelType(str, Enum):
Net2DFast = "Net2DFast"
Net2DFastNoAttn = "Net2DFastNoAttn"
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
class ModelConfig(BaseConfig):
name: ModelType = ModelType.Net2DFast
num_features: int = 128
def get_backbone(
config: ModelConfig,
input_height: int = 128,
) -> BackboneModel:
if config.name == ModelType.Net2DFast:
return Net2DFast(
input_height=input_height,
num_features=config.num_features,
)
elif config.name == ModelType.Net2DFastNoAttn:
return Net2DFastNoAttn(
num_features=config.num_features,
input_height=input_height,
)
elif config.name == ModelType.Net2DFastNoCoordConv:
return Net2DFastNoCoordConv(
num_features=config.num_features,
input_height=input_height,
)
else:
raise ValueError(f"Unknown model type: {config.name}")

View File

@ -12,7 +12,7 @@ from batdetect2.models.blocks import (
ConvBlockUpStandard, ConvBlockUpStandard,
SelfAttention, SelfAttention,
) )
from batdetect2.models.typing import FeatureExtractorModel from batdetect2.models.typing import BackboneModel
__all__ = [ __all__ = [
"Net2DFast", "Net2DFast",
@ -21,7 +21,7 @@ __all__ = [
] ]
class Net2DFast(FeatureExtractorModel): class Net2DFast(BackboneModel):
def __init__( def __init__(
self, self,
num_features: int, num_features: int,
@ -37,7 +37,7 @@ class Net2DFast(FeatureExtractorModel):
1, 1,
self.num_features // 4, self.num_features // 4,
self.input_height, self.input_height,
k_size=3, kernel_size=3,
pad_size=1, pad_size=1,
stride=1, stride=1,
) )
@ -45,7 +45,7 @@ class Net2DFast(FeatureExtractorModel):
self.num_features // 4, self.num_features // 4,
self.num_features // 2, self.num_features // 2,
self.input_height // 2, self.input_height // 2,
k_size=3, kernel_size=3,
pad_size=1, pad_size=1,
stride=1, stride=1,
) )
@ -53,7 +53,7 @@ class Net2DFast(FeatureExtractorModel):
self.num_features // 2, self.num_features // 2,
self.num_features, self.num_features,
self.input_height // 4, self.input_height // 4,
k_size=3, kernel_size=3,
pad_size=1, pad_size=1,
stride=1, stride=1,
) )
@ -100,6 +100,8 @@ class Net2DFast(FeatureExtractorModel):
) )
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4) self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
self.out_channels = self.num_features // 4
def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]: def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
h, w = spec.shape[2:] h, w = spec.shape[2:]
h_pad = (32 - h % 32) % 32 h_pad = (32 - h % 32) % 32
@ -135,7 +137,7 @@ class Net2DFast(FeatureExtractorModel):
return F.relu_(self.conv_op_bn(self.conv_op(x))) return F.relu_(self.conv_op_bn(self.conv_op(x)))
class Net2DFastNoAttn(FeatureExtractorModel): class Net2DFastNoAttn(BackboneModel):
def __init__( def __init__(
self, self,
num_features: int, num_features: int,
@ -151,7 +153,7 @@ class Net2DFastNoAttn(FeatureExtractorModel):
1, 1,
self.num_features // 4, self.num_features // 4,
self.input_height, self.input_height,
k_size=3, kernel_size=3,
pad_size=1, pad_size=1,
stride=1, stride=1,
) )
@ -159,7 +161,7 @@ class Net2DFastNoAttn(FeatureExtractorModel):
self.num_features // 4, self.num_features // 4,
self.num_features // 2, self.num_features // 2,
self.input_height // 2, self.input_height // 2,
k_size=3, kernel_size=3,
pad_size=1, pad_size=1,
stride=1, stride=1,
) )
@ -167,7 +169,7 @@ class Net2DFastNoAttn(FeatureExtractorModel):
self.num_features // 2, self.num_features // 2,
self.num_features, self.num_features,
self.input_height // 4, self.input_height // 4,
k_size=3, kernel_size=3,
pad_size=1, pad_size=1,
stride=1, stride=1,
) )
@ -210,6 +212,7 @@ class Net2DFastNoAttn(FeatureExtractorModel):
padding=1, padding=1,
) )
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4) self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
self.out_channels = self.num_features // 4
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
x1 = self.conv_dn_0(spec) x1 = self.conv_dn_0(spec)
@ -227,7 +230,7 @@ class Net2DFastNoAttn(FeatureExtractorModel):
return F.relu_(self.conv_op_bn(self.conv_op(x))) return F.relu_(self.conv_op_bn(self.conv_op(x)))
class Net2DFastNoCoordConv(FeatureExtractorModel): class Net2DFastNoCoordConv(BackboneModel):
def __init__( def __init__(
self, self,
num_features: int, num_features: int,
@ -242,21 +245,21 @@ class Net2DFastNoCoordConv(FeatureExtractorModel):
self.conv_dn_0 = ConvBlockDownStandard( self.conv_dn_0 = ConvBlockDownStandard(
1, 1,
self.num_features // 4, self.num_features // 4,
k_size=3, kernel_size=3,
pad_size=1, pad_size=1,
stride=1, stride=1,
) )
self.conv_dn_1 = ConvBlockDownStandard( self.conv_dn_1 = ConvBlockDownStandard(
self.num_features // 4, self.num_features // 4,
self.num_features // 2, self.num_features // 2,
k_size=3, kernel_size=3,
pad_size=1, pad_size=1,
stride=1, stride=1,
) )
self.conv_dn_2 = ConvBlockDownStandard( self.conv_dn_2 = ConvBlockDownStandard(
self.num_features // 2, self.num_features // 2,
self.num_features, self.num_features,
k_size=3, kernel_size=3,
pad_size=1, pad_size=1,
stride=1, stride=1,
) )
@ -301,6 +304,7 @@ class Net2DFastNoCoordConv(FeatureExtractorModel):
padding=1, padding=1,
) )
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4) self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
self.out_channels = self.num_features // 4
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
x1 = self.conv_dn_0(spec) x1 = self.conv_dn_0(spec)
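
A hedged sketch of the renamed backbone in use; the input shape is illustrative (the forward pass pads height and width to multiples of 32):

    import torch

    from batdetect2.models.backbones import Net2DFast

    backbone = Net2DFast(num_features=128, input_height=128)
    spec = torch.randn(1, 1, 128, 512)  # (batch, channel, freq bins, time bins)
    features = backbone(spec)
    # The new out_channels attribute matches the channel dimension of the output.
    assert features.shape[1] == backbone.out_channels == 128 // 4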

View File

@ -10,6 +10,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from batdetect2.configs import BaseConfig
__all__ = [ __all__ = [
"SelfAttention", "SelfAttention",
"ConvBlockDownCoordF", "ConvBlockDownCoordF",
@ -19,22 +21,33 @@ __all__ = [
] ]
class SelfAttentionConfig(BaseConfig):
temperature: float = 1.0
input_channels: int = 128
attention_channels: int = 128
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
"""Self-Attention module. """Self-Attention module.
This module implements self-attention mechanism. This module implements self-attention mechanism.
""" """
def __init__(self, ip_dim: int, att_dim: int): def __init__(
self,
in_channels: int,
attention_channels: int,
temperature: float = 1.0,
):
super().__init__() super().__init__()
# Note, does not encode position information (absolute or realtive) # Note, does not encode position information (absolute or relative)
self.temperature = 1.0 self.temperature = temperature
self.att_dim = att_dim self.att_dim = attention_channels
self.key_fun = nn.Linear(ip_dim, att_dim) self.key_fun = nn.Linear(in_channels, attention_channels)
self.val_fun = nn.Linear(ip_dim, att_dim) self.value_fun = nn.Linear(in_channels, attention_channels)
self.que_fun = nn.Linear(ip_dim, att_dim) self.query_fun = nn.Linear(in_channels, attention_channels)
self.pro_fun = nn.Linear(att_dim, ip_dim) self.pro_fun = nn.Linear(attention_channels, in_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.squeeze(2).permute(0, 2, 1) x = x.squeeze(2).permute(0, 2, 1)
@ -43,11 +56,11 @@ class SelfAttention(nn.Module):
x, self.key_fun.weight.T x, self.key_fun.weight.T
) + self.key_fun.bias.unsqueeze(0).unsqueeze(0) ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
query = torch.matmul( query = torch.matmul(
x, self.que_fun.weight.T x, self.query_fun.weight.T
) + self.que_fun.bias.unsqueeze(0).unsqueeze(0) ) + self.query_fun.bias.unsqueeze(0).unsqueeze(0)
value = torch.matmul( value = torch.matmul(
x, self.val_fun.weight.T x, self.value_fun.weight.T
) + self.val_fun.bias.unsqueeze(0).unsqueeze(0) ) + self.value_fun.bias.unsqueeze(0).unsqueeze(0)
kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / ( kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
self.temperature * self.att_dim self.temperature * self.att_dim
@ -63,6 +76,15 @@ class SelfAttention(nn.Module):
return op return op
class ConvBlockDownCoordFConfig(BaseConfig):
in_channels: int
out_channels: int
input_height: int
kernel_size: int = 3
pad_size: int = 1
stride: int = 1
class ConvBlockDownCoordF(nn.Module): class ConvBlockDownCoordF(nn.Module):
"""Convolutional Block with Downsampling and Coord Feature. """Convolutional Block with Downsampling and Coord Feature.
@ -72,27 +94,27 @@ class ConvBlockDownCoordF(nn.Module):
def __init__( def __init__(
self, self,
in_chn: int, in_channels: int,
out_chn: int, out_channels: int,
ip_height: int, input_height: int,
k_size: int = 3, kernel_size: int = 3,
pad_size: int = 1, pad_size: int = 1,
stride: int = 1, stride: int = 1,
): ):
super().__init__() super().__init__()
self.coords = nn.Parameter( self.coords = nn.Parameter(
torch.linspace(-1, 1, ip_height)[None, None, ..., None], torch.linspace(-1, 1, input_height)[None, None, ..., None],
requires_grad=False, requires_grad=False,
) )
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
in_chn + 1, in_channels + 1,
out_chn, out_channels,
kernel_size=k_size, kernel_size=kernel_size,
padding=pad_size, padding=pad_size,
stride=stride, stride=stride,
) )
self.conv_bn = nn.BatchNorm2d(out_chn) self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3]) freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
@ -102,6 +124,14 @@ class ConvBlockDownCoordF(nn.Module):
return x return x
class ConvBlockDownStandardConfig(BaseConfig):
in_channels: int
out_channels: int
kernel_size: int = 3
pad_size: int = 1
stride: int = 1
class ConvBlockDownStandard(nn.Module): class ConvBlockDownStandard(nn.Module):
"""Convolutional Block with Downsampling. """Convolutional Block with Downsampling.
@ -110,21 +140,21 @@ class ConvBlockDownStandard(nn.Module):
def __init__( def __init__(
self, self,
in_chn, in_channels: int,
out_chn, out_channels: int,
k_size=3, kernel_size: int = 3,
pad_size=1, pad_size: int = 1,
stride=1, stride: int = 1,
): ):
super(ConvBlockDownStandard, self).__init__() super(ConvBlockDownStandard, self).__init__()
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
in_chn, in_channels,
out_chn, out_channels,
kernel_size=k_size, kernel_size=kernel_size,
padding=pad_size, padding=pad_size,
stride=stride, stride=stride,
) )
self.conv_bn = nn.BatchNorm2d(out_chn) self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x): def forward(self, x):
x = F.max_pool2d(self.conv(x), 2, 2) x = F.max_pool2d(self.conv(x), 2, 2)
@ -132,6 +162,16 @@ class ConvBlockDownStandard(nn.Module):
return x return x
class ConvBlockUpFConfig(BaseConfig):
inp_channels: int
out_channels: int
input_height: int
kernel_size: int = 3
pad_size: int = 1
up_mode: str = "bilinear"
up_scale: Tuple[int, int] = (2, 2)
class ConvBlockUpF(nn.Module): class ConvBlockUpF(nn.Module):
"""Convolutional Block with Upsampling and Coord Feature. """Convolutional Block with Upsampling and Coord Feature.
@ -141,10 +181,10 @@ class ConvBlockUpF(nn.Module):
def __init__( def __init__(
self, self,
in_chn: int, in_channels: int,
out_chn: int, out_channels: int,
ip_height: int, input_height: int,
k_size: int = 3, kernel_size: int = 3,
pad_size: int = 1, pad_size: int = 1,
up_mode: str = "bilinear", up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2), up_scale: Tuple[int, int] = (2, 2),
@ -154,15 +194,18 @@ class ConvBlockUpF(nn.Module):
self.up_scale = up_scale self.up_scale = up_scale
self.up_mode = up_mode self.up_mode = up_mode
self.coords = nn.Parameter( self.coords = nn.Parameter(
torch.linspace(-1, 1, ip_height * up_scale[0])[ torch.linspace(-1, 1, input_height * up_scale[0])[
None, None, ..., None None, None, ..., None
], ],
requires_grad=False, requires_grad=False,
) )
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size in_channels + 1,
out_channels,
kernel_size=kernel_size,
padding=pad_size,
) )
self.conv_bn = nn.BatchNorm2d(out_chn) self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
op = F.interpolate( op = F.interpolate(
@ -181,6 +224,15 @@ class ConvBlockUpF(nn.Module):
return op return op
class ConvBlockUpStandardConfig(BaseConfig):
in_channels: int
out_channels: int
kernel_size: int = 3
pad_size: int = 1
up_mode: str = "bilinear"
up_scale: Tuple[int, int] = (2, 2)
class ConvBlockUpStandard(nn.Module): class ConvBlockUpStandard(nn.Module):
"""Convolutional Block with Upsampling. """Convolutional Block with Upsampling.
@ -189,9 +241,9 @@ class ConvBlockUpStandard(nn.Module):
def __init__( def __init__(
self, self,
in_chn: int, in_channels: int,
out_chn: int, out_channels: int,
k_size: int = 3, kernel_size: int = 3,
pad_size: int = 1, pad_size: int = 1,
up_mode: str = "bilinear", up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2), up_scale: Tuple[int, int] = (2, 2),
@ -200,9 +252,12 @@ class ConvBlockUpStandard(nn.Module):
self.up_scale = up_scale self.up_scale = up_scale
self.up_mode = up_mode self.up_mode = up_mode
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
in_chn, out_chn, kernel_size=k_size, padding=pad_size in_channels,
out_channels,
kernel_size=kernel_size,
padding=pad_size,
) )
self.conv_bn = nn.BatchNorm2d(out_chn) self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
op = F.interpolate( op = F.interpolate(
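
A hedged sketch of the renamed SelfAttention arguments, applied to a dummy feature map with a collapsed frequency axis, as used inside the backbones:

    import torch

    from batdetect2.models.blocks import SelfAttention

    attention = SelfAttention(in_channels=64, attention_channels=64, temperature=1.0)
    x = torch.randn(2, 64, 1, 100)  # (batch, channels, height=1, time)
    out = attention(x)
    print(out.shape)  # expected to match the input shape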

View File

@ -1,142 +0,0 @@
from typing import Optional, Type
import pytorch_lightning as L
import torch
import xarray as xr
from soundevent import data
from torch import nn, optim
from batdetect2.data.labels import ClassMapper
from batdetect2.data.preprocessing import (
PreprocessingConfig,
preprocess,
)
from batdetect2.models.feature_extractors import Net2DFast
from batdetect2.models.post_process import (
PostprocessConfig,
postprocess_model_outputs,
)
from batdetect2.models.typing import FeatureExtractorModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample
class DetectorModel(L.LightningModule):
def __init__(
self,
class_mapper: ClassMapper,
feature_extractor_class: Type[FeatureExtractorModel] = Net2DFast,
learning_rate: float = 1e-3,
input_height: int = 128,
num_features: int = 32,
preprocessing_config: Optional[PreprocessingConfig] = None,
postprocessing_config: Optional[PostprocessConfig] = None,
):
super().__init__()
preprocessing_config = preprocessing_config or PreprocessingConfig()
postprocessing_config = postprocessing_config or PostprocessConfig()
self.save_hyperparameters()
self.preprocessing_config = preprocessing_config
self.postprocessing_config = postprocessing_config
self.class_mapper = class_mapper
self.learning_rate = learning_rate
self.input_height = input_height
self.num_features = num_features
self.num_classes = class_mapper.num_classes
self.feature_extractor = feature_extractor_class(
input_height=input_height,
num_features=num_features,
)
self.classifier = nn.Conv2d(
self.feature_extractor.num_features // 4,
self.num_classes + 1,
kernel_size=1,
padding=0,
)
self.bbox = nn.Conv2d(
self.feature_extractor.num_features // 4,
2,
kernel_size=1,
padding=0,
)
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
features = self.feature_extractor(spec)
classification_logits = self.classifier(features)
classification_probs = torch.softmax(classification_logits, dim=1)
detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)
return ModelOutput(
detection_probs=detection_probs,
size_preds=self.bbox(features),
class_probs=classification_probs[:, :-1],
features=features,
)
def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
return preprocess(
clip,
config=self.preprocessing_config,
)
def compute_clip_features(self, clip: data.Clip) -> torch.Tensor:
spectrogram = self.compute_spectrogram(clip)
return self.feature_extractor(
torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
)
def compute_clip_predictions(self, clip: data.Clip) -> data.ClipPrediction:
spectrogram = self.compute_spectrogram(clip)
spec_tensor = (
torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
)
outputs = self(spec_tensor)
return postprocess_model_outputs(
outputs,
[clip],
class_mapper=self.class_mapper,
config=self.postprocessing_config,
)[0]
def compute_loss(
self,
outputs: ModelOutput,
batch: TrainExample,
) -> torch.Tensor:
detection_loss = losses.focal_loss(
outputs.detection_probs,
batch.detection_heatmap,
)
size_loss = losses.bbox_size_loss(
outputs.size_preds,
batch.size_heatmap,
)
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
classification_loss = losses.focal_loss(
outputs.class_probs,
batch.class_heatmap,
valid_mask=valid_mask,
)
return detection_loss + size_loss + classification_loss
def training_step( # type: ignore
self,
batch: TrainExample,
):
outputs = self.forward(batch.spec)
loss = self.compute_loss(outputs, batch)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
return [optimizer], [scheduler]

View File

@ -0,0 +1,51 @@
from typing import NamedTuple
import torch
from torch import nn
__all__ = ["ClassifierHead"]
class Output(NamedTuple):
detection: torch.Tensor
classification: torch.Tensor
class ClassifierHead(nn.Module):
def __init__(self, num_classes: int, in_channels: int):
super().__init__()
self.num_classes = num_classes
self.in_channels = in_channels
self.classifier = nn.Conv2d(
self.in_channels,
# Add one to account for the background class
self.num_classes + 1,
kernel_size=1,
padding=0,
)
def forward(self, features: torch.Tensor) -> Output:
logits = self.classifier(features)
probs = torch.softmax(logits, dim=1)
detection_probs = probs[:, :-1].sum(dim=1, keepdim=True)
return Output(
detection=detection_probs,
classification=probs[:, :-1],
)
class BBoxHead(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.bbox = nn.Conv2d(
self.in_channels,
2,
kernel_size=1,
padding=0,
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
return self.bbox(features)
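
A hedged sketch wiring the new heads to a backbone through the out_channels attribute introduced in this commit (the class count here is arbitrary):

    import torch

    from batdetect2.models.backbones import Net2DFast
    from batdetect2.models.heads import BBoxHead, ClassifierHead

    backbone = Net2DFast(num_features=128, input_height=128)
    classifier = ClassifierHead(num_classes=17, in_channels=backbone.out_channels)
    bbox = BBoxHead(in_channels=backbone.out_channels)

    features = backbone(torch.randn(1, 1, 128, 256))
    detection, classification = classifier(features)  # Output NamedTuple
    sizes = bbox(features)  # per-pixel width/height predictions, 2 channels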

View File

@ -6,7 +6,7 @@ import torch.nn as nn
__all__ = [ __all__ = [
"ModelOutput", "ModelOutput",
"FeatureExtractorModel", "BackboneModel",
] ]
@ -41,13 +41,16 @@ class ModelOutput(NamedTuple):
"""Tensor with intermediate features.""" """Tensor with intermediate features."""
class FeatureExtractorModel(ABC, nn.Module): class BackboneModel(ABC, nn.Module):
input_height: int input_height: int
"""Height of the input spectrogram.""" """Height of the input spectrogram."""
num_features: int num_features: int
"""Dimension of the feature tensor.""" """Dimension of the feature tensor."""
out_channels: int
"""Number of output channels of the feature extractor."""
@abstractmethod @abstractmethod
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Forward pass of the encoder model.""" """Forward pass of the encoder model."""

View File

@ -8,7 +8,6 @@ from pydantic import BaseModel, Field
from soundevent import data from soundevent import data
from torch import nn from torch import nn
from batdetect2.data.labels import ClassMapper
from batdetect2.models.typing import ModelOutput from batdetect2.models.typing import ModelOutput
__all__ = [ __all__ = [
@ -37,7 +36,8 @@ TagFunction = Callable[[int], List[data.Tag]]
def postprocess_model_outputs( def postprocess_model_outputs(
outputs: ModelOutput, outputs: ModelOutput,
clips: List[data.Clip], clips: List[data.Clip],
class_mapper: ClassMapper, classes: List[str],
decoder: Callable[[str], List[data.Tag]],
config: PostprocessConfig, config: PostprocessConfig,
) -> List[data.ClipPrediction]: ) -> List[data.ClipPrediction]:
"""Postprocesses model outputs to generate clip predictions. """Postprocesses model outputs to generate clip predictions.
@ -108,7 +108,8 @@ def postprocess_model_outputs(
size_preds, size_preds,
class_probs, class_probs,
features, features,
class_mapper=class_mapper, classes=classes,
decoder=decoder,
min_freq=config.min_freq, min_freq=config.min_freq,
max_freq=config.max_freq, max_freq=config.max_freq,
detection_threshold=config.detection_threshold, detection_threshold=config.detection_threshold,
@ -132,7 +133,8 @@ def compute_sound_events_from_outputs(
size_preds: torch.Tensor, size_preds: torch.Tensor,
class_probs: torch.Tensor, class_probs: torch.Tensor,
features: torch.Tensor, features: torch.Tensor,
class_mapper: ClassMapper, classes: List[str],
decoder: Callable[[str], List[data.Tag]],
min_freq: int = 10000, min_freq: int = 10000,
max_freq: int = 120000, max_freq: int = 120000,
detection_threshold: float = DETECTION_THRESHOLD, detection_threshold: float = DETECTION_THRESHOLD,
@ -181,7 +183,8 @@ def compute_sound_events_from_outputs(
predicted_tags: List[data.PredictedTag] = [] predicted_tags: List[data.PredictedTag] = []
for label_id, class_score in enumerate(class_prob): for label_id, class_score in enumerate(class_prob):
corresponding_tags = class_mapper.inverse_transform(label_id) class_name = classes[label_id]
corresponding_tags = decoder(class_name)
predicted_tags.extend( predicted_tags.extend(
[ [
data.PredictedTag( data.PredictedTag(
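
A hedged sketch of the new interface: instead of a ClassMapper, callers pass the ordered class names plus a decoder callable that maps a name to its tags. The tag lookup below is a hypothetical stand-in for whatever taxonomy the caller maintains.

    from typing import Dict, List

    from soundevent import data

    # Hypothetical mapping from class name to pre-built tags.
    tag_lookup: Dict[str, List[data.Tag]] = {}

    classes: List[str] = ["class_a", "class_b"]  # order must match the model output

    def decoder(class_name: str) -> List[data.Tag]:
        # Look up the tags for a predicted class; unknown names decode to no tags.
        return tag_lookup.get(class_name, [])

    # predictions = postprocess_model_outputs(
    #     outputs, clips, classes=classes, decoder=decoder, config=config
    # )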

View File

@ -12,9 +12,12 @@ from batdetect2.preprocess.audio import (
load_clip_audio, load_clip_audio,
) )
from batdetect2.preprocess.spectrogram import ( from batdetect2.preprocess.spectrogram import (
AmplitudeScaleConfig,
FFTConfig, FFTConfig,
FrequencyConfig, FrequencyConfig,
PcenConfig, LogScaleConfig,
PcenScaleConfig,
Scales,
SpecSizeConfig, SpecSizeConfig,
SpectrogramConfig, SpectrogramConfig,
compute_spectrogram, compute_spectrogram,
@ -26,7 +29,10 @@ __all__ = [
"SpectrogramConfig", "SpectrogramConfig",
"FFTConfig", "FFTConfig",
"FrequencyConfig", "FrequencyConfig",
"PcenConfig", "PcenScaleConfig",
"LogScaleConfig",
"AmplitudeScaleConfig",
"Scales",
"SpecSizeConfig", "SpecSizeConfig",
"PreprocessingConfig", "PreprocessingConfig",
"preprocess_audio_clip", "preprocess_audio_clip",

View File

@ -23,7 +23,18 @@ class FrequencyConfig(BaseConfig):
min_freq: int = Field(default=10_000, gt=0) min_freq: int = Field(default=10_000, gt=0)
class PcenConfig(BaseConfig): class SpecSizeConfig(BaseConfig):
height: int = 256
resize_factor: Optional[float] = 0.5
divide_factor: Optional[int] = 32
class LogScaleConfig(BaseConfig):
name: Literal["log"] = "log"
class PcenScaleConfig(BaseConfig):
name: Literal["pcen"] = "pcen"
time_constant: float = 0.4 time_constant: float = 0.4
hop_length: int = 512 hop_length: int = 512
gain: float = 0.98 gain: float = 0.98
@ -31,19 +42,21 @@ class PcenConfig(BaseConfig):
power: float = 0.5 power: float = 0.5
class SpecSizeConfig(BaseConfig): class AmplitudeScaleConfig(BaseConfig):
height: int = 256 name: Literal["amplitude"] = "amplitude"
resize_factor: Optional[float] = 0.5
divide_factor: Optional[int] = 32
Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]
class SpectrogramConfig(BaseConfig): class SpectrogramConfig(BaseConfig):
fft: FFTConfig = Field(default_factory=FFTConfig) fft: FFTConfig = Field(default_factory=FFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig) frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
scale: Union[Literal["log"], None, PcenConfig] = Field( scale: Scales = Field(
default_factory=PcenConfig default_factory=PcenScaleConfig,
discriminator="name",
) )
size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig) size: SpecSizeConfig = Field(default_factory=SpecSizeConfig)
denoise: bool = True denoise: bool = True
max_scale: bool = False max_scale: bool = False
@ -55,7 +68,7 @@ def compute_spectrogram(
) -> xr.DataArray: ) -> xr.DataArray:
config = config or SpectrogramConfig() config = config or SpectrogramConfig()
if config.size and config.size.divide_factor: if config.size.divide_factor:
# Need to pad the audio to make sure the spectrogram has a # Need to pad the audio to make sure the spectrogram has a
# width compatible with the divide factor # width compatible with the divide factor
wav = pad_audio( wav = pad_audio(
@ -84,12 +97,11 @@ def compute_spectrogram(
if config.denoise: if config.denoise:
spec = denoise_spectrogram(spec) spec = denoise_spectrogram(spec)
if config.size: spec = resize_spectrogram(
spec = resize_spectrogram( spec,
spec, height=config.size.height,
height=config.size.height, resize_factor=config.size.resize_factor,
resize_factor=config.size.resize_factor, )
)
if config.max_scale: if config.max_scale:
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec))) spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
@ -180,13 +192,13 @@ def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
def scale_spectrogram( def scale_spectrogram(
spec: xr.DataArray, spec: xr.DataArray,
scale: Union[Literal["log"], None, PcenConfig], scale: Scales,
dtype: DTypeLike = np.float32, dtype: DTypeLike = np.float32,
) -> xr.DataArray: ) -> xr.DataArray:
if scale == "log": if scale.name == "log":
return scale_log(spec, dtype=dtype) return scale_log(spec, dtype=dtype)
if isinstance(scale, PcenConfig): if scale.name == "pcen":
return scale_pcen( return scale_pcen(
spec, spec,
time_constant=scale.time_constant, time_constant=scale.time_constant,
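
A hedged sketch of the discriminated union in practice, assuming BaseConfig is a pydantic v2 model (the discriminator="name" field on SpectrogramConfig.scale suggests as much):

    from batdetect2.preprocess.spectrogram import SpectrogramConfig

    # The "name" field selects which scale config gets validated.
    config = SpectrogramConfig.model_validate(
        {"scale": {"name": "pcen", "gain": 0.98, "power": 0.5}}
    )
    print(type(config.scale).__name__)  # PcenScaleConfig

    log_config = SpectrogramConfig.model_validate({"scale": {"name": "log"}})
    print(type(log_config.scale).__name__)  # LogScaleConfig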

View File

@ -1,941 +0,0 @@
"""Functions and dataloaders for training and testing the model."""
import copy
from typing import List, Optional, Tuple
import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
import batdetect2.utils.audio_utils as au
from batdetect2.types import (
Annotation,
AudioLoaderAnnotationGroup,
AudioLoaderParameters,
FileAnnotation,
)
def generate_gt_heatmaps(
spec_op_shape: Tuple[int, int],
sampling_rate: float,
ann: AudioLoaderAnnotationGroup,
class_names: List[str],
fft_win_length: float,
fft_overlap: float,
max_freq: float,
min_freq: float,
resize_factor: float,
target_sigma: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AudioLoaderAnnotationGroup]:
"""Generate ground truth heatmaps from annotations.
Parameters
----------
spec_op_shape : Tuple[int, int]
Shape of the input spectrogram.
sampling_rate : int
Sampling rate of the input audio in Hz.
ann : AnnotationGroup
Dictionary containing the annotation information.
params : HeatmapParameters
Parameters controlling the generation of the heatmaps.
Returns
-------
y_2d_det : np.ndarray
2D heatmap of the presence of an event.
y_2d_size : np.ndarray
2D heatmap of the size of the bounding box associated to event.
y_2d_classes : np.ndarray
3D array containing the ground-truth class probabilities for each
pixel.
ann_aug : AnnotationGroup
A dictionary containing the annotation information of the
annotations that are within the input spectrogram, augmented with
the x and y indices of their pixel location in the input spectrogram.
"""
# spec may be resized on input into the network
num_classes = len(class_names)
op_height = spec_op_shape[0]
op_width = spec_op_shape[1]
freq_per_bin = (max_freq - min_freq) / op_height
# start and end times
x_pos_start = au.time_to_x_coords(
ann["start_times"],
sampling_rate,
fft_win_length,
fft_overlap,
)
x_pos_start = (resize_factor * x_pos_start).astype(np.int32)
x_pos_end = au.time_to_x_coords(
ann["end_times"],
sampling_rate,
fft_win_length,
fft_overlap,
)
x_pos_end = (resize_factor * x_pos_end).astype(np.int32)
# location on y axis i.e. frequency
y_pos_low = (ann["low_freqs"] - min_freq) / freq_per_bin
y_pos_low = (op_height - y_pos_low).astype(np.int32)
y_pos_high = (ann["high_freqs"] - min_freq) / freq_per_bin
y_pos_high = (op_height - y_pos_high).astype(np.int32)
bb_widths = x_pos_end - x_pos_start
bb_heights = y_pos_low - y_pos_high
# Only include annotations that are within the input spectrogram
valid_inds = np.where(
(x_pos_start >= 0)
& (x_pos_start < op_width)
& (y_pos_low >= 0)
& (y_pos_low < (op_height - 1))
)[0]
ann_aug: AudioLoaderAnnotationGroup = {
**ann,
"start_times": ann["start_times"][valid_inds],
"end_times": ann["end_times"][valid_inds],
"high_freqs": ann["high_freqs"][valid_inds],
"low_freqs": ann["low_freqs"][valid_inds],
"class_ids": ann["class_ids"][valid_inds],
"individual_ids": ann["individual_ids"][valid_inds],
"x_inds": x_pos_start[valid_inds],
"y_inds": y_pos_low[valid_inds],
}
# if the number of calls is only 1, then it is unique
# TODO would be better if we found these unique calls at the merging stage
if len(ann_aug["individual_ids"]) == 1:
ann_aug["individual_ids"][0] = 0
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
# num classes and "background" class
y_2d_classes: np.ndarray = np.zeros(
(num_classes + 1, op_height, op_width), dtype=np.float32
)
# create 2D ground truth heatmaps
for ii in valid_inds:
draw_gaussian(
y_2d_det[0, :],
(x_pos_start[ii], y_pos_low[ii]),
target_sigma,
)
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
cls_id = ann["class_ids"][ii]
if cls_id > -1:
draw_gaussian(
y_2d_classes[cls_id, :],
(x_pos_start[ii], y_pos_low[ii]),
target_sigma,
)
# be careful as this will have a 1.0 places where we have event but
# dont know gt class this will be masked in training anyway
y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0)
y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...]
y_2d_classes[np.isnan(y_2d_classes)] = 0.0
return y_2d_det, y_2d_size, y_2d_classes, ann_aug
def draw_gaussian(
heatmap: np.ndarray,
center: Tuple[int, int],
sigmax: float,
sigmay: Optional[float] = None,
) -> bool:
"""Draw a 2D gaussian into the heatmap.
If the gaussian center is outside the heatmap, then the gaussian is not
drawn.
Parameters
----------
heatmap : np.ndarray
The heatmap to draw into. Should be of shape (height, width).
center : Tuple[int, int]
The center of the gaussian in (x, y) format.
sigmax : float
The standard deviation of the gaussian in the x direction.
sigmay : Optional[float], optional
The standard deviation of the gaussian in the y direction. If None,
then sigmay = sigmax, by default None.
Returns
-------
bool
True if the gaussian was drawn, False if it was not (because
the center was outside the heatmap).
"""
# center is (x, y)
# this edits the heatmap inplace
if sigmay is None:
sigmay = sigmax
tmp_size = np.maximum(sigmax, sigmay) * 3
mu_x = int(center[0] + 0.5)
mu_y = int(center[1] + 0.5)
w, h = heatmap.shape[0], heatmap.shape[1]
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
return False
size = 2 * tmp_size + 1
x = np.arange(0, size, 1, np.float32)
y = x[:, np.newaxis]
x0 = y0 = size // 2
# g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
g = np.exp(
-((x - x0) ** 2) / (2 * sigmax**2) - ((y - y0) ** 2) / (2 * sigmay**2)
)
g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
img_x = max(0, ul[0]), min(br[0], h)
img_y = max(0, ul[1]), min(br[1], w)
heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum(
heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]],
g[g_y[0] : g_y[1], g_x[0] : g_x[1]],
)
return True
def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:
"""Pad array with -1s."""
return np.hstack((ip_array, np.ones(pad_size, dtype=np.int32) * -1))
def warp_spec_aug(
spec: torch.Tensor,
ann: AudioLoaderAnnotationGroup,
stretch_squeeze_delta: float,
) -> torch.Tensor:
"""Warp spectrogram by randomly stretching and squeezing.
Parameters
----------
spec: torch.Tensor
Spectrogram to warp.
ann: AnnotationGroup
Annotation group for the spectrogram. Must be provided to sync
the start and stop times with the spectrogram after warping.
stretch_squeeze_delta: float
Maximum amount to stretch or squeeze the spectrogram.
Returns
-------
torch.Tensor
Warped spectrogram.
Notes
-----
This function modifies the annotation group in place.
"""
# Augment spectrogram by randomly stretch and squeezing
# NOTE this also changes the start and stop time in place
delta = stretch_squeeze_delta
op_size = (spec.shape[1], spec.shape[2])
resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
resize_amt = int(spec.shape[2] * resize_fract_r)
if resize_amt >= spec.shape[2]:
spec_r = torch.cat(
(
spec,
torch.zeros(
(1, spec.shape[1], resize_amt - spec.shape[2]),
dtype=spec.dtype,
),
),
dim=2,
)
else:
spec_r = spec[:, :, :resize_amt]
# Resize the spectrogram
spec = F.interpolate(
spec_r.unsqueeze(0),
size=op_size,
mode="bilinear",
align_corners=False,
).squeeze(0)
# Update the start and stop times
ann["start_times"] *= 1.0 / resize_fract_r
ann["end_times"] *= 1.0 / resize_fract_r
return spec
def mask_time_aug(
spec: torch.Tensor,
mask_max_time_perc: float,
) -> torch.Tensor:
"""Mask out random blocks of time.
Will randomly mask out a block of time in the spectrogram. The block
will be between 0.0 and `mask_max_time_perc` of the total time.
A random number of blocks will be masked out between 1 and 3.
Parameters
----------
spec: torch.Tensor
Spectrogram to mask.
mask_max_time_perc: float
Maximum percentage of time to mask out.
Returns
-------
torch.Tensor
Spectrogram with masked out time blocks.
Notes
-----
This function is based on the implementation in::
SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Recognition
"""
fm = torchaudio.transforms.TimeMasking(
int(spec.shape[1] * mask_max_time_perc)
)
for _ in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
def mask_freq_aug(
spec: torch.Tensor,
mask_max_freq_perc: float,
) -> torch.Tensor:
"""Mask out random blocks of frequency.
Will randomly mask out a block of frequency in the spectrogram. The block
will be between 0.0 and `mask_max_freq_perc` of the total frequency.
A random number of blocks will be masked out between 1 and 3.
Parameters
----------
spec: torch.Tensor
Spectrogram to mask.
mask_max_freq_perc: float
Maximum percentage of frequency to mask out.
Returns
-------
torch.Tensor
Spectrogram with masked out frequency blocks.
Notes
-----
This function is based on the implementation in::
SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Recognition
"""
fm = torchaudio.transforms.FrequencyMasking(
int(spec.shape[1] * mask_max_freq_perc)
)
for _ in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
def scale_vol_aug(
spec: torch.Tensor,
spec_amp_scaling: float,
) -> torch.Tensor:
"""Scale the volume of the spectrogram.
Parameters
----------
spec: torch.Tensor
Spectrogram to scale.
spec_amp_scaling: float
Maximum scaling factor.
Returns
-------
torch.Tensor
"""
return spec * np.random.random() * spec_amp_scaling
def echo_aug(
audio: np.ndarray,
sampling_rate: float,
echo_max_delay: float,
) -> np.ndarray:
"""Add echo to audio.
Parameters
----------
audio: np.ndarray
Audio to add echo to.
sampling_rate: float
Sampling rate of the audio.
echo_max_delay: float
Maximum delay of the echo in seconds.
Returns
-------
np.ndarray
Audio with echo added.
"""
sample_offset = (
int(echo_max_delay * np.random.random() * sampling_rate) + 1
)
# NOTE: This seems to be wrong, as the echo should be added to the
# end of the audio, not the beginning.
audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
return audio
def resample_aug(
audio: np.ndarray,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
resize_factor: float,
spec_divide_factor: float,
spec_train_width: int,
aug_sampling_rates: List[int],
) -> Tuple[np.ndarray, float, float]:
"""Resample audio augmentation.
Will resample the audio to a random sampling rate from the list of
sampling rates in `aug_sampling_rates`.
Parameters
----------
audio: np.ndarray
Audio to resample.
sampling_rate: float
Original sampling rate of the audio.
fft_win_length: float
Length of the FFT window in seconds.
fft_overlap: float
Amount of overlap between FFT windows.
resize_factor: float
Factor to resize the spectrogram by.
spec_divide_factor: float
Factor to divide the spectrogram by.
spec_train_width: int
Width of the spectrogram.
aug_sampling_rates: List[int]
List of sampling rates to resample to.
Returns
-------
audio : np.ndarray
Resampled audio.
sampling_rate : float
New sampling rate.
duration : float
Duration of the audio in seconds.
"""
sampling_rate_old = sampling_rate
sampling_rate = np.random.choice(aug_sampling_rates)
audio = librosa.resample(
audio,
orig_sr=sampling_rate_old,
target_sr=sampling_rate,
res_type="polyphase",
)
audio = au.pad_audio(
audio,
sampling_rate,
fft_win_length,
fft_overlap,
resize_factor,
spec_divide_factor,
spec_train_width,
)
duration = audio.shape[0] / float(sampling_rate)
return audio, sampling_rate, duration
def resample_audio(
num_samples: int,
sampling_rate: float,
audio2: np.ndarray,
sampling_rate2: float,
) -> Tuple[np.ndarray, float]:
"""Resample audio.
Parameters
----------
num_samples: int
Expected number of samples for the output audio.
sampling_rate: float
Original sampling rate of the audio.
audio2: np.ndarray
Audio to resample.
sampling_rate2: float
Target sampling rate of the audio.
Returns
-------
audio2 : np.ndarray
Resampled audio.
sampling_rate2 : float
New sampling rate.
"""
# resample to target sampling rate
if sampling_rate != sampling_rate2:
audio2 = librosa.resample(
audio2,
orig_sr=sampling_rate2,
target_sr=sampling_rate,
res_type="polyphase",
)
sampling_rate2 = sampling_rate
# pad or trim to the correct length
if audio2.shape[0] < num_samples:
audio2 = np.hstack(
(
audio2,
np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype),
)
)
elif audio2.shape[0] > num_samples:
audio2 = audio2[:num_samples]
return audio2, sampling_rate2
def combine_audio_aug(
audio: np.ndarray,
sampling_rate: float,
ann: AudioLoaderAnnotationGroup,
audio2: np.ndarray,
sampling_rate2: float,
ann2: AudioLoaderAnnotationGroup,
) -> Tuple[np.ndarray, AudioLoaderAnnotationGroup]:
"""Combine two audio files.
Will combine two audio files by resampling them to the same sampling rate
and then combining them with a random weight. The annotations will be
combined by taking the union of the two sets of annotations.
Parameters
----------
audio: np.ndarray
First Audio to combine.
sampling_rate: int
Sampling rate of the first audio.
ann: AnnotationGroup
Annotations for the first audio.
audio2: np.ndarray
Second Audio to combine.
sampling_rate2: int
Sampling rate of the second audio.
ann2: AnnotationGroup
Annotations for the second audio.
Returns
-------
audio : np.ndarray
Combined audio.
ann : AnnotationGroup
Combined annotations.
"""
# resample so they are the same
audio2, sampling_rate2 = resample_audio(
audio.shape[0],
sampling_rate,
audio2,
sampling_rate2,
)
# # set mean and std to be the same
# audio2 = (audio2 - audio2.mean())
# audio2 = (audio2/audio2.std())*audio.std()
# audio2 = audio2 + audio.mean()
if (
ann.get("annotated", False)
and (ann2.get("annotated", False))
and (sampling_rate2 == sampling_rate)
and (audio.shape[0] == audio2.shape[0])
):
comb_weight = 0.3 + np.random.random() * 0.4
audio = comb_weight * audio + (1 - comb_weight) * audio2
inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
for kk in ann.keys():
# when combining calls from different files, assume they come
# from different individuals
if kk == "individual_ids":
if (ann[kk] > -1).sum() > 0:
ann2[kk][ann2[kk] > -1] += (
np.max(ann[kk][ann[kk] > -1]) + 1
)
if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
return audio, ann
def _prepare_annotation(
annotation: Annotation,
class_names: List[str],
) -> Annotation:
try:
class_id = class_names.index(annotation["class"])
except ValueError:
class_id = -1
ann: Annotation = {
**annotation,
"class_id": class_id,
}
if "individual" in ann:
ann["individual"] = int(ann["individual"]) # type: ignore
return ann
def _prepare_file_annotation(
annotation: FileAnnotation,
class_names: List[str],
classes_to_ignore: List[str],
) -> AudioLoaderAnnotationGroup:
annotations = [
_prepare_annotation(ann, class_names)
for ann in annotation["annotation"]
if ann["class"] not in classes_to_ignore
]
try:
class_id_file = class_names.index(annotation["class_name"])
except ValueError:
class_id_file = -1
ret: AudioLoaderAnnotationGroup = {
"id": annotation["id"],
"annotated": annotation["annotated"],
"duration": annotation["duration"],
"issues": annotation["issues"],
"time_exp": annotation["time_exp"],
"class_name": annotation["class_name"],
"notes": annotation["notes"],
"annotation": annotations,
"start_times": np.array([ann["start_time"] for ann in annotations]),
"end_times": np.array([ann["end_time"] for ann in annotations]),
"high_freqs": np.array([ann["high_freq"] for ann in annotations]),
"low_freqs": np.array([ann["low_freq"] for ann in annotations]),
"class_ids": np.array(
[ann.get("class_id", -1) for ann in annotations]
),
"individual_ids": np.array([ann["individual"] for ann in annotations]),
"class_id_file": class_id_file,
}
return ret
class AudioLoader(torch.utils.data.Dataset):
"""Main AudioLoader for training and testing."""
def __init__(
self,
data_anns_ip: List[FileAnnotation],
params: AudioLoaderParameters,
dataset_name: Optional[str] = None,
is_train: bool = False,
return_spec_for_viz: bool = False,
):
self.is_train = is_train
self.params = params
self.return_spec_for_viz = return_spec_for_viz
self.data_anns: List[AudioLoaderAnnotationGroup] = [
_prepare_file_annotation(
ann,
params["class_names"],
params["classes_to_ignore"],
)
for ann in data_anns_ip
]
ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
self.max_num_anns = 2 * np.max(
ann_cnt
) # x2 because we may be combining files during training
print("\n")
if dataset_name is not None:
print("Dataset : " + dataset_name)
if self.is_train:
print("Split type : train")
else:
print("Split type : test")
print("Num files : " + str(len(self.data_anns)))
print("Num calls : " + str(np.sum(ann_cnt)))
def get_file_and_anns(
self,
index: Optional[int] = None,
) -> Tuple[np.ndarray, float, float, AudioLoaderAnnotationGroup]:
"""Get an audio file and its annotations.
Parameters
----------
index : int, optional
Index of the file to be loaded. If None, a random file is chosen.
Returns
-------
audio_raw : np.ndarray
Loaded audio file.
sampling_rate : float
Sampling rate of the audio file.
duration : float
Duration of the audio file in seconds.
ann : AnnotationGroup
AnnotationGroup object containing the annotations for the audio file.
"""
# if no file specified, choose random one
if index is None:
index = np.random.randint(0, len(self.data_anns))
audio_file = self.data_anns[index]["file_path"]
sampling_rate, audio_raw = au.load_audio(
audio_file,
self.data_anns[index]["time_exp"],
self.params["target_samp_rate"],
self.params["scale_raw_audio"],
)
# copy annotation
ann = copy.deepcopy(self.data_anns[index])
# ann["annotated"] = self.data_anns[index]["annotated"]
# ann["class_id_file"] = self.data_anns[index]["class_id_file"]
# keys = [
# "start_times",
# "end_times",
# "high_freqs",
# "low_freqs",
# "class_ids",
# "individual_ids",
# ]
# for kk in keys:
# ann[kk] = self.data_anns[index][kk].copy()
# if train then grab a random crop
if self.is_train:
nfft = int(self.params["fft_win_length"] * sampling_rate)
noverlap = int(self.params["fft_overlap"] * nfft)
length_samples = (
self.params["spec_train_width"] * (nfft - noverlap) + noverlap
)
if audio_raw.shape[0] - length_samples > 0:
sample_crop = np.random.randint(
audio_raw.shape[0] - length_samples
)
else:
sample_crop = 0
audio_raw = audio_raw[sample_crop : sample_crop + length_samples]
ann["start_times"] = ann["start_times"] - sample_crop / float(
sampling_rate
)
ann["end_times"] = ann["end_times"] - sample_crop / float(
sampling_rate
)
# pad audio
if self.is_train:
op_spec_target_size = self.params["spec_train_width"]
else:
op_spec_target_size = None
audio_raw = au.pad_audio(
audio_raw,
sampling_rate,
self.params["fft_win_length"],
self.params["fft_overlap"],
self.params["resize_factor"],
self.params["spec_divide_factor"],
op_spec_target_size,
)
duration = audio_raw.shape[0] / float(sampling_rate)
# sort based on time
inds = np.argsort(ann["start_times"])
for kk in ann.keys():
if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = ann[kk][inds]
return audio_raw, sampling_rate, duration, ann
def __getitem__(self, index):
"""Get an item from the dataset."""
# load audio file
audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
# augment on raw audio
if self.is_train and self.params["augment_at_train"]:
# augment - combine with random audio file
if (
self.params["augment_at_train_combine"]
and np.random.random() < self.params["aug_prob"]
):
(
audio2,
sampling_rate2,
_,
ann2,
) = self.get_file_and_anns()
audio, ann = combine_audio_aug(
audio, sampling_rate, ann, audio2, sampling_rate2, ann2
)
# simulate echo by adding delayed copy of the file
if np.random.random() < self.params["aug_prob"]:
audio = echo_aug(
audio,
sampling_rate,
echo_max_delay=self.params["echo_max_delay"],
)
# resample the audio
# if np.random.random() < self.params["aug_prob"]:
# audio, sampling_rate, duration = resample_aug(
# audio, sampling_rate, self.params
# )
# create spectrogram
spec, _ = au.generate_spectrogram(
audio,
sampling_rate,
params=dict(
fft_win_length=self.params["fft_win_length"],
fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
min_freq=self.params["min_freq"],
spec_scale=self.params["spec_scale"],
denoise_spec_avg=self.params["denoise_spec_avg"],
max_scale_spec=self.params["max_scale_spec"],
),
)
rsf = self.params["resize_factor"]
spec_op_shape = (
int(self.params["spec_height"] * rsf),
int(spec.shape[1] * rsf),
)
# resize the spec
spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
spec = F.interpolate(
spec,
size=spec_op_shape,
mode="bilinear",
align_corners=False,
).squeeze(0)
# augment spectrogram
if self.is_train and self.params["augment_at_train"]:
if np.random.random() < self.params["aug_prob"]:
spec = scale_vol_aug(
spec,
spec_amp_scaling=self.params["spec_amp_scaling"],
)
if np.random.random() < self.params["aug_prob"]:
spec = warp_spec_aug(
spec,
ann,
stretch_squeeze_delta=self.params["stretch_squeeze_delta"],
)
if np.random.random() < self.params["aug_prob"]:
spec = mask_time_aug(
spec,
mask_max_time_perc=self.params["mask_max_time_perc"],
)
if np.random.random() < self.params["aug_prob"]:
spec = mask_freq_aug(
spec,
mask_max_freq_perc=self.params["mask_max_freq_perc"],
)
outputs = {}
outputs["spec"] = spec
if self.return_spec_for_viz:
outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze(
0
)
# create ground truth heatmaps
(
outputs["y_2d_det"],
outputs["y_2d_size"],
outputs["y_2d_classes"],
ann_aug,
) = generate_gt_heatmaps(
spec_op_shape,
sampling_rate,
ann,
class_names=self.params["class_names"],
fft_win_length=self.params["fft_win_length"],
fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
min_freq=self.params["min_freq"],
resize_factor=self.params["resize_factor"],
target_sigma=self.params["target_sigma"],
)
# hack to get around requirement that all vectors are the same length
# in the output batch
pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
outputs["is_valid"] = pad_aray(
np.ones(len(ann_aug["individual_ids"])), pad_size
)
keys = [
"class_ids",
"individual_ids",
"x_inds",
"y_inds",
"start_times",
"end_times",
"low_freqs",
"high_freqs",
]
for kk in keys:
outputs[kk] = pad_aray(ann_aug[kk], pad_size)
# convert to pytorch
for kk in outputs.keys():
if type(outputs[kk]) != torch.Tensor:
outputs[kk] = torch.from_numpy(outputs[kk])
# scalars
outputs["class_id_file"] = ann["class_id_file"]
outputs["annotated"] = ann["annotated"]
outputs["duration"] = duration
outputs["sampling_rate"] = sampling_rate
outputs["file_id"] = index
return outputs
def __len__(self):
"""Denotes the total number of samples."""
return len(self.data_anns)

View File

@ -1,133 +1,161 @@
from functools import wraps from typing import Callable, Optional, Union
from typing import Callable, List, Optional
import numpy as np import numpy as np
import xarray as xr import xarray as xr
from pydantic import Field
from soundevent import arrays
from soundevent.arrays import operations as ops
from batdetect2.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
Augmentation = Callable[[xr.Dataset], xr.Dataset] Augmentation = Callable[[xr.Dataset], xr.Dataset]
AUGMENTATION_PROBABILITY = 0.2 class AugmentationConfig(BaseConfig):
MAX_DELAY = 0.005 enable: bool = True
STRETCH_SQUEEZE_DELTA = 0.04 probability: float = 0.2
MASK_MAX_TIME_PERC: float = 0.05
MASK_MAX_FREQ_PERC: float = 0.10
def maybe_apply( class SubclipConfig(BaseConfig):
augmentation: Callable, enable: bool = True
prob: float = AUGMENTATION_PROBABILITY, duration: Optional[float] = None
) -> Callable:
"""Apply an augmentation with a given probability."""
@wraps(augmentation)
def _augmentation(x):
if np.random.rand() > prob:
return x
return augmentation(x)
return _augmentation
def select_random_subclip( def select_random_subclip(
train_example: xr.Dataset, example: xr.Dataset,
start_time: Optional[float] = None,
duration: Optional[float] = None, duration: Optional[float] = None,
proportion: float = 0.9, width: Optional[int] = None,
) -> xr.Dataset: ) -> xr.Dataset:
"""Select a random subclip from a clip.""" """Select a random subclip from a clip."""
step = arrays.get_dim_step(example, "time") # type: ignore
time_coords = train_example.coords["time"] if width is None:
if duration is None:
raise ValueError("Either duration or width must be provided")
start_time = time_coords.attrs.get("min", time_coords.min()) width = int(np.floor(duration / step))
end_time = time_coords.attrs.get("max", time_coords.max())
if duration is None: if duration is None:
duration = (end_time - start_time) * proportion duration = width * step
start_time = np.random.uniform(start_time, end_time - duration) if start_time is None:
return train_example.sel(time=slice(start_time, start_time + duration)) start, stop = arrays.get_dim_range(example, "time") # type: ignore
start_time = np.random.uniform(start, stop - duration)
start_index = arrays.get_coord_index(
example, # type: ignore
"time",
start_time,
)
end_index = start_index + width - 1
start_time = example.time.values[start_index]
end_time = example.time.values[end_index]
return example.sel(
time=slice(start_time, end_time),
audio_time=slice(start_time, end_time),
)
def combine_audio( class MixAugmentationConfig(AugmentationConfig):
audio1: xr.DataArray, min_weight: float = 0.3
audio2: xr.DataArray, max_weight: float = 0.7
alpha: Optional[float] = None,
min_alpha: float = 0.3,
max_alpha: float = 0.7, def mix_examples(
) -> xr.DataArray: example: xr.Dataset,
other: xr.Dataset,
weight: Optional[float] = None,
min_weight: float = 0.3,
max_weight: float = 0.7,
config: Optional[PreprocessingConfig] = None,
) -> xr.Dataset:
"""Combine two audio clips.""" """Combine two audio clips."""
config = config or PreprocessingConfig()
if alpha is None: if weight is None:
alpha = np.random.uniform(min_alpha, max_alpha) weight = np.random.uniform(min_weight, max_weight)
return alpha * audio1 + (1 - alpha) * audio2.data audio2 = other["audio"].values
audio1 = ops.adjust_dim_width(example["audio"], "audio_time", len(audio2))
combined = weight * audio1 + (1 - weight) * audio2
spec = compute_spectrogram(
combined.rename({"audio_time": "time"}),
config=config.spectrogram,
)
detection_heatmap = xr.apply_ufunc(
np.maximum,
example["detection"],
other["detection"].values,
)
class_heatmap = xr.apply_ufunc(
np.maximum,
example["class"],
other["class"].values,
)
size_heatmap = example["size"] + other["size"].values
return xr.Dataset(
{
"audio": combined,
"spectrogram": spec,
"detection": detection_heatmap,
"class": class_heatmap,
"size": size_heatmap,
}
)
# def random_mix( class EchoAugmentationConfig(AugmentationConfig):
# audio: xr.DataArray, max_delay: float = 0.005
# clip: data.ClipAnnotation, min_weight: float = 0.0
# provider: Optional[ClipProvider] = None, max_weight: float = 1.0
# alpha: Optional[float] = None,
# min_alpha: float = 0.3,
# max_alpha: float = 0.7,
# join_annotations: bool = True,
# ) -> Tuple[xr.DataArray, data.ClipAnnotation]:
# """Mix two audio clips."""
# if provider is None:
# raise ValueError("No audio provider given.")
#
# try:
# other_audio, other_clip = provider(clip)
# except (StopIteration, ValueError):
# raise ValueError("No more audio sources available.")
#
# new_audio = combine_audio(
# audio,
# other_audio,
# alpha=alpha,
# min_alpha=min_alpha,
# max_alpha=max_alpha,
# )
#
# if join_annotations:
# clip = clip.model_copy(
# update=dict(
# sound_events=clip.sound_events + other_clip.sound_events,
# )
# )
#
# return new_audio, clip
def add_echo( def add_echo(
train_example: xr.Dataset, example: xr.Dataset,
delay: Optional[float] = None, delay: Optional[float] = None,
alpha: Optional[float] = None, weight: Optional[float] = None,
min_alpha: float = 0.0, min_weight: float = 0.1,
max_alpha: float = 1.0, max_weight: float = 1.0,
max_delay: float = MAX_DELAY, max_delay: float = 0.005,
config: Optional[PreprocessingConfig] = None,
) -> xr.Dataset: ) -> xr.Dataset:
"""Add a delay to the audio.""" """Add a delay to the audio."""
config = config or PreprocessingConfig()
if delay is None: if delay is None:
delay = np.random.uniform(0, max_delay) delay = np.random.uniform(0, max_delay)
if alpha is None: if weight is None:
alpha = np.random.uniform(min_alpha, max_alpha) weight = np.random.uniform(min_weight, max_weight)
spec = train_example["spectrogram"] audio = example["audio"]
step = arrays.get_dim_step(audio, "audio_time")
audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
audio = audio + weight * audio_delay
time_coords = spec.coords["time"] spectrogram = compute_spectrogram(
start_time = time_coords.attrs["min"] audio.rename({"audio_time": "time"}),
end_time = time_coords.attrs["max"] config=config.spectrogram,
step = (end_time - start_time) / time_coords.size )
spec_delay = spec.shift(time=int(delay / step), fill_value=0) return example.assign(audio=audio, spectrogram=spectrogram)
return train_example.assign(spectrogram=spec + alpha * spec_delay)
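
A hedged sketch of how the new per-augmentation configs might be instantiated (the module path is assumed; field names come from the classes in this hunk):

    from batdetect2.train.augmentations import EchoAugmentationConfig  # assumed path

    echo_config = EchoAugmentationConfig(
        enable=True,
        probability=0.2,
        max_delay=0.005,
        min_weight=0.0,
        max_weight=1.0,
    )
    # add_echo(example, max_delay=echo_config.max_delay, ...) would then be applied
    # to a training example Dataset with probability echo_config.probability.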
class VolumeAugmentationConfig(AugmentationConfig):
min_scaling: float = 0.0
max_scaling: float = 2.0
def scale_volume( def scale_volume(
train_example: xr.Dataset, example: xr.Dataset,
factor: Optional[float] = None, factor: Optional[float] = None,
max_scaling: float = 2, max_scaling: float = 2,
min_scaling: float = 0, min_scaling: float = 0,
@ -136,106 +164,235 @@ def scale_volume(
if factor is None: if factor is None:
factor = np.random.uniform(min_scaling, max_scaling) factor = np.random.uniform(min_scaling, max_scaling)
return train_example.assign( return example.assign(spectrogram=example["spectrogram"] * factor)
spectrogram=train_example["spectrogram"] * factor
)
class WarpAugmentationConfig(AugmentationConfig):
delta: float = 0.04
def warp_spectrogram( def warp_spectrogram(
train_example: xr.Dataset, example: xr.Dataset,
factor: Optional[float] = None, factor: Optional[float] = None,
delta: float = STRETCH_SQUEEZE_DELTA, delta: float = 0.04,
) -> xr.Dataset: ) -> xr.Dataset:
"""Warp a spectrogram.""" """Warp a spectrogram."""
if factor is None: if factor is None:
factor = np.random.uniform(1 - delta, 1 + delta) factor = np.random.uniform(1 - delta, 1 + delta)
time_coords = train_example.coords["time"]
start_time = time_coords.attrs["min"]
end_time = time_coords.attrs["max"]
start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
duration = end_time - start_time duration = end_time - start_time
new_time = np.linspace( new_time = np.linspace(
start_time, start_time,
start_time + duration * factor, start_time + duration * factor,
train_example.time.size, example.time.size,
) )
return train_example.interp(time=new_time)
spectrogram = (
example["spectrogram"]
.interp(
coords={"time": new_time},
method="linear",
kwargs=dict(
fill_value=0,
),
)
.clip(min=0)
)
detection = example["detection"].interp(
time=new_time,
method="nearest",
kwargs=dict(
fill_value=0,
),
)
classification = example["class"].interp(
time=new_time,
method="nearest",
kwargs=dict(
fill_value=0,
),
)
size = example["size"].interp(
time=new_time,
method="nearest",
kwargs=dict(
fill_value=0,
),
)
return example.assign(
{
"time": new_time,
"spectrogram": spectrogram,
"detection": detection,
"class": classification,
"size": size,
}
)
def mask_axis(
train_example: xr.Dataset,
dim: str,
start: float,
end: float,
mask_all: bool = False,
mask_value: float = 0,
) -> xr.Dataset:
if dim not in train_example.dims:
raise ValueError(f"Axis {dim} not found in array")
coord = train_example.coords[dim]
condition = (coord < start) | (coord > end)
if mask_all:
return train_example.where(condition, other=mask_value)
return train_example.assign(
spectrogram=train_example.spectrogram.where(
condition, other=mask_value
)
)
def mask_axis(
array: xr.DataArray,
dim: str,
start: float,
end: float,
mask_value: Union[float, Callable[[xr.DataArray], float]] = np.mean,
) -> xr.DataArray:
if dim not in array.dims:
raise ValueError(f"Axis {dim} not found in array")
coord = array.coords[dim]
condition = (coord < start) | (coord > end)
if callable(mask_value):
mask_value = mask_value(array)
return array.where(condition, other=mask_value)
class TimeMaskAugmentationConfig(AugmentationConfig):
max_perc: float = 0.05
max_masks: int = 3
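To illustrate the reworked mask_axis above, here is a minimal sketch (not part of the commit) that masks a frequency band of a synthetic spectrogram; with the new signature the masked region is filled with the array mean by default, or with a constant when a float is passed:

import numpy as np
import xarray as xr

spec = xr.DataArray(
    np.random.rand(128, 512),
    dims=("frequency", "time"),
    coords={
        "frequency": np.linspace(10_000, 120_000, 128),
        "time": np.linspace(0.0, 0.5, 512),
    },
)

# Values between 40 kHz and 60 kHz are replaced; everything else is kept.
masked_mean = mask_axis(spec, "frequency", start=40_000, end=60_000)
masked_zero = mask_axis(spec, "frequency", start=40_000, end=60_000, mask_value=0.0)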
def mask_time( def mask_time(
train_example: xr.Dataset, example: xr.Dataset,
max_time_mask: float = MASK_MAX_TIME_PERC, max_perc: float = 0.05,
max_num_masks: int = 3, max_mask: int = 3,
) -> xr.Dataset: ) -> xr.Dataset:
"""Mask a random section of the time axis.""" """Mask a random section of the time axis."""
num_masks = np.random.randint(1, max_num_masks + 1)
time_coord = train_example.coords["time"]
start_time = time_coord.attrs.get("min", time_coord.min())
end_time = time_coord.attrs.get("max", time_coord.max())
num_masks = np.random.randint(1, max_mask + 1)
start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
spectrogram = example["spectrogram"]
for _ in range(num_masks): for _ in range(num_masks):
mask_size = np.random.uniform(0, max_time_mask) mask_size = np.random.uniform(0, max_perc) * (end_time - start_time)
start = np.random.uniform(start_time, end_time - mask_size) start = np.random.uniform(start_time, end_time - mask_size)
end = start + mask_size end = start + mask_size
train_example = mask_axis(train_example, "time", start, end) spectrogram = mask_axis(spectrogram, "time", start, end)
return train_example return example.assign(spectrogram=spectrogram)
class FrequencyMaskAugmentationConfig(AugmentationConfig):
max_perc: float = 0.10
max_masks: int = 3
def mask_frequency( def mask_frequency(
train_example: xr.Dataset, example: xr.Dataset,
max_freq_mask: float = MASK_MAX_FREQ_PERC, max_perc: float = 0.10,
max_num_masks: int = 3, max_masks: int = 3,
) -> xr.Dataset: ) -> xr.Dataset:
"""Mask a random section of the frequency axis.""" """Mask a random section of the frequency axis."""
num_masks = np.random.randint(1, max_num_masks + 1)
freq_coord = train_example.coords["frequency"]
min_freq = float(freq_coord.min())
max_freq = float(freq_coord.max())
num_masks = np.random.randint(1, max_masks + 1)
min_freq, max_freq = arrays.get_dim_range(example, "frequency")  # type: ignore
spectrogram = example["spectrogram"]
for _ in range(num_masks): for _ in range(num_masks):
mask_size = np.random.uniform(0, max_freq_mask) mask_size = np.random.uniform(0, max_perc) * (max_freq - min_freq)
start = np.random.uniform(min_freq, max_freq - mask_size) start = np.random.uniform(min_freq, max_freq - mask_size)
end = start + mask_size end = start + mask_size
train_example = mask_axis(train_example, "frequency", start, end) spectrogram = mask_axis(spectrogram, "frequency", start, end)
return train_example return example.assign(spectrogram=spectrogram)
AUGMENTATIONS: List[Augmentation] = [
select_random_subclip,
add_echo,
scale_volume,
mask_time,
mask_frequency,
]
class AugmentationsConfig(BaseConfig):
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig)
echo: EchoAugmentationConfig = Field(
default_factory=EchoAugmentationConfig
)
volume: VolumeAugmentationConfig = Field(
default_factory=VolumeAugmentationConfig
)
warp: WarpAugmentationConfig = Field(
default_factory=WarpAugmentationConfig
)
time_mask: TimeMaskAugmentationConfig = Field(
default_factory=TimeMaskAugmentationConfig
)
frequency_mask: FrequencyMaskAugmentationConfig = Field(
default_factory=FrequencyMaskAugmentationConfig
)
def should_apply(config: AugmentationConfig) -> bool:
if not config.enable:
return False
return np.random.uniform() < config.probability
def augment_example(
example: xr.Dataset,
config: AugmentationsConfig,
preprocessing_config: Optional[PreprocessingConfig] = None,
others: Optional[Callable[[], xr.Dataset]] = None,
) -> xr.Dataset:
if config.subclip.enable:
example = select_random_subclip(
example,
duration=config.subclip.duration,
)
if should_apply(config.mix) and others is not None:
other = others()
if config.subclip.enable:
other = select_random_subclip(
other,
duration=config.subclip.duration,
)
example = mix_examples(
example,
other,
min_weight=config.mix.min_weight,
max_weight=config.mix.max_weight,
config=preprocessing_config,
)
if should_apply(config.echo):
example = add_echo(
example,
max_delay=config.echo.max_delay,
min_weight=config.echo.min_weight,
max_weight=config.echo.max_weight,
config=preprocessing_config,
)
if should_apply(config.volume):
example = scale_volume(
example,
max_scaling=config.volume.max_scaling,
min_scaling=config.volume.min_scaling,
)
if should_apply(config.warp):
example = warp_spectrogram(
example,
delta=config.warp.delta,
)
if should_apply(config.time_mask):
example = mask_time(
example,
max_perc=config.time_mask.max_perc,
max_mask=config.time_mask.max_masks,
)
if should_apply(config.frequency_mask):
example = mask_frequency(
example,
max_perc=config.frequency_mask.max_perc,
max_masks=config.frequency_mask.max_masks,
)
return example
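For orientation, a minimal usage sketch of the config-driven entry point above; the .nc file names are hypothetical and the examples are assumed to be preprocessed training datasets with audio, spectrogram and heatmap variables:

import xarray as xr

from batdetect2.preprocess import PreprocessingConfig
from batdetect2.train.augmentations import AugmentationsConfig, augment_example

example = xr.open_dataset("example_000.nc")  # hypothetical preprocessed example

augmented = augment_example(
    example,
    AugmentationsConfig(),
    preprocessing_config=PreprocessingConfig(),
    # Mix-up needs a second example; any zero-argument callable will do.
    others=lambda: xr.open_dataset("example_001.nc"),
)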

View File

@ -1,12 +1,14 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, NamedTuple, Optional, Sequence, Union from typing import NamedTuple, Optional, Sequence, Union
import numpy as np
import torch import torch
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
from torch.utils.data import Dataset from torch.utils.data import Dataset
from batdetect2.train.augmentations import AugmentationsConfig, augment_example
from batdetect2.train.preprocess import PreprocessingConfig from batdetect2.train.preprocess import PreprocessingConfig
__all__ = [ __all__ = [
@ -34,21 +36,44 @@ class LabeledDataset(Dataset):
def __init__( def __init__(
self, self,
filenames: Sequence[PathLike], filenames: Sequence[PathLike],
transform: Optional[Callable[[xr.Dataset], xr.Dataset]] = None, augment: bool = False,
preprocessing_config: Optional[PreprocessingConfig] = None,
augmentation_config: Optional[AugmentationsConfig] = None,
): ):
self.filenames = filenames self.filenames = filenames
self.transform = transform self.augment = augment
self.preprocessing_config = (
preprocessing_config or PreprocessingConfig()
)
self.augmentation_config = augmentation_config or AugmentationsConfig()
def __len__(self): def __len__(self):
return len(self.filenames) return len(self.filenames)
def __getitem__(self, idx) -> TrainExample: def __getitem__(self, idx) -> TrainExample:
data = self.load(idx) dataset = self.get_dataset(idx)
if self.augment:
dataset = augment_example(
dataset,
self.augmentation_config,
preprocessing_config=self.preprocessing_config,
others=self.get_random_example,
)
return TrainExample(
spec=data["spectrogram"],
detection_heatmap=data["detection"],
class_heatmap=data["class"],
size_heatmap=data["size"],
idx=torch.tensor(idx),
)
return TrainExample(
spec=torch.tensor(
dataset["spectrogram"].values.astype(np.float32)
).unsqueeze(0),
detection_heatmap=torch.tensor(
dataset["detection"].values.astype(np.float32)
),
class_heatmap=torch.tensor(
dataset["class"].values.astype(np.float32)
),
size_heatmap=torch.tensor(
dataset["size"].values.astype(np.float32)
),
idx=torch.tensor(idx),
)
@ -56,44 +81,14 @@ class LabeledDataset(Dataset):
def from_directory(cls, directory: PathLike, extension: str = ".nc"): def from_directory(cls, directory: PathLike, extension: str = ".nc"):
return cls(get_files(directory, extension)) return cls(get_files(directory, extension))
def load(self, idx) -> Dict[str, torch.Tensor]:
dataset = self.get_dataset(idx)
return {
"spectrogram": torch.tensor(
dataset["spectrogram"].values
).unsqueeze(0),
"detection": torch.tensor(dataset["detection"].values),
"class": torch.tensor(dataset["class"].values),
"size": torch.tensor(dataset["size"].values),
}
def apply_augmentation(self, dataset: xr.Dataset) -> xr.Dataset:
if self.transform is not None:
return self.transform(dataset)
return dataset
def get_dataset(self, idx):
return xr.open_dataset(self.filenames[idx])
def get_spectrogram(self, idx):
return xr.open_dataset(self.filenames[idx])["spectrogram"]
def get_detection_mask(self, idx):
return xr.open_dataset(self.filenames[idx])["detection"]
def get_class_mask(self, idx):
return xr.open_dataset(self.filenames[idx])["class"]
def get_size_mask(self, idx):
return xr.open_dataset(self.filenames[idx])["size"]
def get_clip_annotation(self, idx):
filename = self.filenames[idx]
dataset = xr.open_dataset(filename)
clip_annotation = dataset.attrs["clip_annotation"]
return data.ClipAnnotation.model_validate_json(clip_annotation)
def get_preprocessing_configuration(self, idx):
config = xr.open_dataset(self.filenames[idx]).attrs["configuration"]
return PreprocessingConfig.model_validate_json(config)
def get_random_example(self) -> xr.Dataset:
idx = np.random.randint(0, len(self))
return self.get_dataset(idx)
def get_dataset(self, idx) -> xr.Dataset:
return xr.open_dataset(self.filenames[idx])
def get_clip_annotation(self, idx) -> data.ClipAnnotation:
return data.ClipAnnotation.model_validate_json(
self.get_dataset(idx).attrs["clip_annotation"]
)
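A short sketch of how the reworked dataset might be consumed (the directory path is hypothetical; augmentation stays off unless augment=True is passed to the constructor):

from torch.utils.data import DataLoader

from batdetect2.train.dataset import LabeledDataset

dataset = LabeledDataset.from_directory("train_examples/")
loader = DataLoader(dataset, batch_size=8, shuffle=True)

batch = next(iter(loader))
# TrainExample is a NamedTuple, so the default collate keeps the field names.
print(batch.spec.shape)  # (8, 1, height, width), float32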

View File

@ -1,18 +1,14 @@
from typing import Sequence, Tuple from typing import Callable, List, Optional, Sequence, Tuple
import numpy as np import numpy as np
import xarray as xr import xarray as xr
from scipy.ndimage import gaussian_filter from scipy.ndimage import gaussian_filter
from soundevent import arrays, data, geometry from soundevent import arrays, data, geometry
from soundevent.geometry.operations import Positions from soundevent.geometry.operations import Positions
from soundevent.types import ClassMapper
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
__all__ = [ __all__ = ["generate_heatmaps"]
"ClassMapper",
"generate_heatmaps",
]
class HeatmapsConfig(BaseConfig): class HeatmapsConfig(BaseConfig):
@ -25,7 +21,8 @@ class HeatmapsConfig(BaseConfig):
def generate_heatmaps( def generate_heatmaps(
sound_events: Sequence[data.SoundEventAnnotation], sound_events: Sequence[data.SoundEventAnnotation],
spec: xr.DataArray, spec: xr.DataArray,
class_mapper: ClassMapper, class_names: List[str],
encoder: Callable[[data.SoundEventAnnotation], Optional[str]],
target_sigma: float = 3.0, target_sigma: float = 3.0,
position: Positions = "bottom-left", position: Positions = "bottom-left",
time_scale: float = 1000.0, time_scale: float = 1000.0,
@ -42,10 +39,10 @@ def generate_heatmaps(
# Initialize heatmaps # Initialize heatmaps
detection_heatmap = xr.zeros_like(spec, dtype=dtype) detection_heatmap = xr.zeros_like(spec, dtype=dtype)
class_heatmap = xr.DataArray( class_heatmap = xr.DataArray(
data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype), data=np.zeros((len(class_names), *spec.shape), dtype=dtype),
dims=["category", *spec.dims], dims=["category", *spec.dims],
coords={ coords={
"category": [*class_mapper.class_labels], "category": [*class_names],
**spec.coords, **spec.coords,
}, },
) )
@ -92,7 +89,7 @@ def generate_heatmaps(
) )
# Get the class name of the sound event # Get the class name of the sound event
class_name = class_mapper.encode(sound_event_annotation) class_name = encoder(sound_event_annotation)
if class_name is None: if class_name is None:
# If the label is None skip the sound event # If the label is None skip the sound event
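A minimal sketch of the new functional interface (clip_annotation and spectrogram are assumed to come from the preprocessing step; the helpers are the ones added to batdetect2.train.targets later in this commit):

from batdetect2.train.targets import TargetConfig, build_encoder, get_class_names

target_config = TargetConfig()
detection, classes, sizes = generate_heatmaps(
    clip_annotation.sound_events,
    spectrogram,
    class_names=get_class_names(target_config.classes),
    encoder=build_encoder(target_config.classes),
)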

View File

@ -1,56 +0,0 @@
import pytorch_lightning as L
from torch import Tensor, optim
from batdetect2.models.typing import DetectionModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample
__all__ = [
"LitDetectorModel",
]
class LitDetectorModel(L.LightningModule):
model: DetectionModel
def __init__(self, model: DetectionModel, learning_rate: float = 1e-3):
super().__init__()
self.model = model
self.learning_rate = learning_rate
def compute_loss(
self,
outputs: ModelOutput,
batch: TrainExample,
) -> Tensor:
detection_loss = losses.focal_loss(
outputs.detection_probs,
batch.detection_heatmap,
)
size_loss = losses.bbox_size_loss(
outputs.size_preds,
batch.size_heatmap,
)
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
classification_loss = losses.focal_loss(
outputs.class_probs,
batch.class_heatmap,
valid_mask=valid_mask,
)
return detection_loss + size_loss + classification_loss
def training_step(self, batch: TrainExample, batch_idx: int): # type: ignore
outputs: ModelOutput = self.model(batch.spec)
loss = self.compute_loss(outputs, batch)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
return [optimizer], [scheduler]

View File

@ -1,7 +1,16 @@
from typing import Optional from typing import NamedTuple, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from pydantic import Field
from batdetect2.configs import BaseConfig
from batdetect2.models.typing import ModelOutput
from batdetect2.train.dataset import TrainExample
class SizeLossConfig(BaseConfig):
weight: float = 0.1
def bbox_size_loss( def bbox_size_loss(
@ -17,6 +26,11 @@ def bbox_size_loss(
) )
class FocalLossConfig(BaseConfig):
beta: float = 4
alpha: float = 2
def focal_loss( def focal_loss(
pred: torch.Tensor, pred: torch.Tensor,
gt: torch.Tensor, gt: torch.Tensor,
@ -44,7 +58,7 @@ def focal_loss(
) )
if weights is not None: if weights is not None:
pos_loss = pos_loss * weights pos_loss = pos_loss * torch.tensor(weights)
# neg_loss = neg_loss*weights # neg_loss = neg_loss*weights
if valid_mask is not None: if valid_mask is not None:
@ -75,3 +89,71 @@ def mse_loss(
else: else:
op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum() op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
return op return op
class DetectionLossConfig(BaseConfig):
weight: float = 1.0
focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
class ClassificationLossConfig(BaseConfig):
weight: float = 2.0
focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
class_weights: Optional[list[float]] = None
class LossConfig(BaseConfig):
detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig)
size: SizeLossConfig = Field(default_factory=SizeLossConfig)
classification: ClassificationLossConfig = Field(
default_factory=ClassificationLossConfig
)
class Losses(NamedTuple):
detection: torch.Tensor
size: torch.Tensor
classification: torch.Tensor
total: torch.Tensor
def compute_loss(
batch: TrainExample,
outputs: ModelOutput,
conf: LossConfig,
class_weights: Optional[torch.Tensor] = None,
) -> Losses:
detection_loss = focal_loss(
outputs.detection_probs,
batch.detection_heatmap,
beta=conf.detection.focal.beta,
alpha=conf.detection.focal.alpha,
)
size_loss = bbox_size_loss(
outputs.size_preds,
batch.size_heatmap,
)
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
classification_loss = focal_loss(
outputs.class_probs,
batch.class_heatmap,
weights=class_weights,
valid_mask=valid_mask,
beta=conf.classification.focal.beta,
alpha=conf.classification.focal.alpha,
)
total = (
detection_loss * conf.detection.weight
+ size_loss * conf.size.weight
+ classification_loss * conf.classification.weight
)
return Losses(
detection=detection_loss,
size=size_loss,
classification=classification_loss,
total=total,
)
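As a rough illustration (batch and outputs are assumed to come from a dataloader and a forward pass), the weighted total can be used directly for backpropagation while the per-task terms remain available for logging:

loss_config = LossConfig()  # defaults: detection 1.0, size 0.1, classification 2.0
losses = compute_loss(batch, outputs, conf=loss_config)
losses.total.backward()
# losses.detection, losses.size and losses.classification can be logged separately.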

View File

@ -0,0 +1,98 @@
from typing import Optional
import pytorch_lightning as L
import torch
from pydantic import Field
from torch import optim
from batdetect2.configs import BaseConfig
from batdetect2.models import (
BBoxHead,
ClassifierHead,
ModelConfig,
get_backbone,
)
from batdetect2.models.typing import ModelOutput
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.train.dataset import TrainExample
from batdetect2.train.losses import LossConfig, compute_loss
from batdetect2.train.targets import TargetConfig
class OptimizerConfig(BaseConfig):
learning_rate: float = 1e-3
t_max: int = 100
class TrainingConfig(BaseConfig):
loss: LossConfig = Field(default_factory=LossConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
class ModuleConfig(BaseConfig):
train: TrainingConfig = Field(default_factory=TrainingConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
backbone: ModelConfig = Field(default_factory=ModelConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class DetectorModel(L.LightningModule):
config: ModuleConfig
def __init__(
self,
config: Optional[ModuleConfig] = None,
):
super().__init__()
self.config = config or ModuleConfig()
self.save_hyperparameters()
self.backbone = get_backbone(
input_height=self.config.preprocessing.spectrogram.size.height,
config=self.config.backbone,
)
self.classifier = ClassifierHead(
num_classes=len(self.config.targets.classes),
in_channels=self.backbone.out_channels,
)
self.bbox = BBoxHead(in_channels=self.backbone.out_channels)
conf = self.config.train.loss.classification
self.class_weights = (
torch.tensor(conf.class_weights) if conf.class_weights else None
)
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
features = self.backbone(spec)
detection_probs, classification_probs = self.classifier(features)
size_preds = self.bbox(features)
return ModelOutput(
detection_probs=detection_probs,
size_preds=size_preds,
class_probs=classification_probs,
features=features,
)
def training_step(self, batch: TrainExample):
outputs = self.forward(batch.spec)
losses = compute_loss(
batch,
outputs,
conf=self.config.train.loss,
class_weights=self.class_weights,
)
return losses.total
def configure_optimizers(self):
conf = self.config.train.optimizer
optimizer = optim.Adam(self.parameters(), lr=conf.learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=conf.t_max,
)
return [optimizer], [scheduler]
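To show how the pieces fit together, a hedged training sketch; the import path of the new module is assumed, and the data directory is hypothetical:

import pytorch_lightning as L
from torch.utils.data import DataLoader

from batdetect2.train.dataset import LabeledDataset
from batdetect2.train.module import DetectorModel, ModuleConfig  # path assumed

model = DetectorModel(ModuleConfig())
loader = DataLoader(
    LabeledDataset.from_directory("train_examples/"),
    batch_size=8,
    shuffle=True,
)

trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=loader)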

View File

@ -20,8 +20,9 @@ from batdetect2.preprocess import (
from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps
from batdetect2.train.targets import ( from batdetect2.train.targets import (
TargetConfig, TargetConfig,
build_class_mapper, build_encoder,
build_sound_event_filter, build_sound_event_filter,
get_class_names,
) )
PathLike = Union[Path, str, os.PathLike] PathLike = Union[Path, str, os.PathLike]
@ -64,11 +65,15 @@ def generate_train_example(
selected_events = [ selected_events = [
event for event in clip_annotation.sound_events if filter_fn(event) event for event in clip_annotation.sound_events if filter_fn(event)
] ]
class_mapper = build_class_mapper(config.target.classes)
class_names = get_class_names(config.target.classes)
encoder = build_encoder(config.target.classes)
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
selected_events, selected_events,
spectrogram, spectrogram,
class_mapper, class_names,
encoder,
target_sigma=config.heatmaps.sigma, target_sigma=config.heatmaps.sigma,
position=config.heatmaps.position, position=config.heatmaps.position,
time_scale=config.heatmaps.time_scale, time_scale=config.heatmaps.time_scale,

View File

@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Set
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.types import ClassMapper
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.terms import TagInfo, get_tag_from_info from batdetect2.terms import TagInfo, get_tag_from_info
@ -12,11 +11,25 @@ from batdetect2.terms import TagInfo, get_tag_from_info
class TargetConfig(BaseConfig): class TargetConfig(BaseConfig):
"""Configuration for target generation.""" """Configuration for target generation."""
classes: List[TagInfo] = Field(default_factory=list)
generic_class: Optional[TagInfo] = None
include: Optional[List[TagInfo]] = None
exclude: Optional[List[TagInfo]] = None
classes: List[TagInfo] = Field(
default_factory=lambda: [
TagInfo(key="class", value=value) for value in DEFAULT_SPECIES_LIST
]
)
generic_class: Optional[TagInfo] = Field(
default_factory=lambda: TagInfo(key="class", value="Bat")
)
include: Optional[List[TagInfo]] = Field(
default_factory=lambda: [TagInfo(key="event", value="Echolocation")]
)
exclude: Optional[List[TagInfo]] = Field(
default_factory=lambda: [
TagInfo(key="class", value=""),
TagInfo(key="class", value=" "),
TagInfo(key="class", value="Unknown"),
]
)
def build_sound_event_filter( def build_sound_event_filter(
@ -36,13 +49,54 @@ def build_sound_event_filter(
) )
def build_class_mapper(classes: List[TagInfo]) -> ClassMapper:
target_tags = [get_tag_from_info(tag) for tag in classes]
labels = [tag.label if tag.label else tag.value for tag in classes]
return GenericMapper(
classes=target_tags,
labels=labels,
)
def get_tag_label(tag_info: TagInfo) -> str:
return tag_info.label if tag_info.label else tag_info.value
def get_class_names(classes: List[TagInfo]) -> List[str]:
return sorted({get_tag_label(tag) for tag in classes})
def build_encoder(
classes: List[TagInfo],
) -> Callable[[data.SoundEventAnnotation], Optional[str]]:
target_tags = set([get_tag_from_info(tag) for tag in classes])
tag_mapping = {
tag: get_tag_label(tag_info)
for tag, tag_info in zip(target_tags, classes)
}
def encoder(
sound_event_annotation: data.SoundEventAnnotation,
) -> Optional[str]:
tags = set(sound_event_annotation.tags)
intersection = tags & target_tags
if not intersection:
return None
first = intersection.pop()
return tag_mapping[first]
return encoder
def build_decoder(
classes: List[TagInfo],
) -> Callable[[str], List[data.Tag]]:
target_tags = set([get_tag_from_info(tag) for tag in classes])
tag_mapping = {
get_tag_label(tag_info): tag
for tag, tag_info in zip(target_tags, classes)
}
def decoder(label: str) -> List[data.Tag]:
tag = tag_mapping.get(label)
return [tag] if tag else []
return decoder
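A small sketch of the new helpers in use with the default configuration (the exact Tag contents depend on get_tag_from_info, which is not shown here):

from batdetect2.train.targets import (
    TargetConfig,
    build_decoder,
    build_encoder,
    get_class_names,
)

target_config = TargetConfig()
class_names = get_class_names(target_config.classes)  # sorted, de-duplicated labels
encode = build_encoder(target_config.classes)         # SoundEventAnnotation -> Optional[str]
decode = build_decoder(target_config.classes)         # label -> List[data.Tag]

tags = decode("Myotis daubentonii")
# encode(annotation) returns the label of a matching target tag, or None when
# the annotation carries none of the configured class tags.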
def filter_sound_event( def filter_sound_event(
@ -61,39 +115,22 @@ def filter_sound_event(
return True return True
class GenericMapper(ClassMapper):
"""Generic class mapper configuration."""
def __init__(
self,
classes: List[data.Tag],
labels: List[str],
):
if not len(classes) == len(labels):
raise ValueError("Number of targets and class labels must match.")
self.targets = set(classes)
self.class_labels = list(dict.fromkeys(labels))
self._mapping = {tag: label for tag, label in zip(classes, labels)}
self._inverse_mapping = {
label: tag for tag, label in zip(classes, labels)
}
def encode(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> Optional[str]:
tags = set(sound_event_annotation.tags)
intersection = tags & self.targets
if not intersection:
return None
tag = intersection.pop()
return self._mapping[tag]
def decode(self, label: str) -> List[data.Tag]:
if label not in self._inverse_mapping:
return []
return [self._inverse_mapping[label]]
DEFAULT_SPECIES_LIST = [
"Barbastellus barbastellus",
"Eptesicus serotinus",
"Myotis alcathoe",
"Myotis bechsteinii",
"Myotis brandtii",
"Myotis daubentonii",
"Myotis mystacinus",
"Myotis nattereri",
"Nyctalus leisleri",
"Nyctalus noctula",
"Pipistrellus nathusii",
"Pipistrellus pipistrellus",
"Pipistrellus pygmaeus",
"Plecotus auritus",
"Plecotus austriacus",
"Rhinolophus ferrumequinum",
"Rhinolophus hipposideros",
]

View File

@ -91,7 +91,7 @@ def recording_factory(wav_factory: Callable[..., Path]):
recording_id: Optional[uuid.UUID] = None, recording_id: Optional[uuid.UUID] = None,
duration: float = 1, duration: float = 1,
channels: int = 1, channels: int = 1,
samplerate: int = 44100, samplerate: int = 256_000,
time_expansion: float = 1, time_expansion: float = 1,
) -> data.Recording: ) -> data.Recording:
path = path or wav_factory( path = path or wav_factory(

View File

@ -80,11 +80,12 @@ def test_spectrogram_generation_hasnt_changed(
max_freq = 120_000 max_freq = 120_000
fft_overlap = 0.75 fft_overlap = 0.75
scale = None
if spec_scale == "log": if spec_scale == "log":
scale = "log" scale = preprocess.LogScaleConfig()
elif spec_scale == "pcen": elif spec_scale == "pcen":
scale = preprocess.PcenConfig() scale = preprocess.PcenScaleConfig()
else:
scale = preprocess.AmplitudeScaleConfig()
config = preprocess.SpectrogramConfig( config = preprocess.SpectrogramConfig(
fft=preprocess.FFTConfig( fft=preprocess.FFTConfig(

View File

View File

@ -0,0 +1,70 @@
from collections.abc import Callable
import xarray as xr
from soundevent import data
from batdetect2.train.augmentations import (
add_echo,
mix_examples,
select_random_subclip,
)
from batdetect2.train.preprocess import (
TrainPreprocessingConfig,
generate_train_example,
)
def test_mix_examples(
recording_factory: Callable[..., data.Recording],
):
recording1 = recording_factory()
recording2 = recording_factory()
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
config = TrainPreprocessingConfig()
example1 = generate_train_example(clip_annotation_1, config)
example2 = generate_train_example(clip_annotation_2, config)
mixed = mix_examples(example1, example2, config=config.preprocessing)
assert mixed["spectrogram"].shape == example1["spectrogram"].shape
assert mixed["detection"].shape == example1["detection"].shape
assert mixed["size"].shape == example1["size"].shape
assert mixed["class"].shape == example1["class"].shape
def test_add_echo(
recording_factory: Callable[..., data.Recording],
):
recording1 = recording_factory()
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
config = TrainPreprocessingConfig()
original = generate_train_example(clip_annotation_1, config)
with_echo = add_echo(original, config=config.preprocessing)
assert with_echo["spectrogram"].shape == original["spectrogram"].shape
xr.testing.assert_identical(with_echo["size"], original["size"])
xr.testing.assert_identical(with_echo["class"], original["class"])
xr.testing.assert_identical(with_echo["detection"], original["detection"])
def test_selected_random_subclip_has_the_correct_width(
recording_factory: Callable[..., data.Recording],
):
recording1 = recording_factory()
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
config = TrainPreprocessingConfig()
original = generate_train_example(clip_annotation_1, config)
subclip = select_random_subclip(original, width=100)
assert subclip["spectrogram"].shape[1] == 100

View File

@ -5,7 +5,7 @@ import xarray as xr
from soundevent import data from soundevent import data
from soundevent.types import ClassMapper from soundevent.types import ClassMapper
from batdetect2.data.labels import generate_heatmaps from batdetect2.train.labels import generate_heatmaps
recording = data.Recording( recording = data.Recording(
samplerate=256_000, samplerate=256_000,
@ -60,7 +60,7 @@ def test_generated_heatmaps_have_correct_dimensions():
class_mapper = Mapper() class_mapper = Mapper()
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
clip_annotation, clip_annotation.sound_events,
spec, spec,
class_mapper, class_mapper,
) )
@ -109,7 +109,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions():
class_mapper = Mapper() class_mapper = Mapper()
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
clip_annotation, clip_annotation.sound_events,
spec, spec,
class_mapper, class_mapper,
) )