diff --git a/batdetect2/compat/params.py b/batdetect2/compat/params.py
index 65ce3e9..0910d09 100644
--- a/batdetect2/compat/params.py
+++ b/batdetect2/compat/params.py
@@ -1,7 +1,7 @@
 from batdetect2.preprocess import (
     AmplitudeScaleConfig,
     AudioConfig,
-    FFTConfig,
+    STFTConfig,
     FrequencyConfig,
     LogScaleConfig,
     PcenScaleConfig,
@@ -40,7 +40,7 @@ def get_preprocessing_config(params: dict) -> PreprocessingConfig:
             duration=None,
         ),
         spectrogram=SpectrogramConfig(
-            fft=FFTConfig(
+            stft=STFTConfig(
                 window_duration=params["fft_win_length"],
                 window_overlap=params["fft_overlap"],
                 window_fn="hann",
diff --git a/batdetect2/models/__init__.py b/batdetect2/models/__init__.py
index 5883cd8..71caad1 100644
--- a/batdetect2/models/__init__.py
+++ b/batdetect2/models/__init__.py
@@ -1,10 +1,12 @@
 from enum import Enum
+from typing import Optional, Tuple

 from batdetect2.configs import BaseConfig
 from batdetect2.models.backbones import (
     Net2DFast,
     Net2DFastNoAttn,
     Net2DFastNoCoordConv,
+    Net2DPlain,
 )
 from batdetect2.models.heads import BBoxHead, ClassifierHead
 from batdetect2.models.typing import BackboneModel
@@ -24,31 +26,57 @@ class ModelType(str, Enum):
     Net2DFast = "Net2DFast"
     Net2DFastNoAttn = "Net2DFastNoAttn"
     Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
+    Net2DPlain = "Net2DPlain"


 class ModelConfig(BaseConfig):
     name: ModelType = ModelType.Net2DFast
-    num_features: int = 32
+    input_height: int = 128
+    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
+    bottleneck_channels: int = 256
+    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
+    out_channels: int = 32


 def get_backbone(
-    config: ModelConfig,
-    input_height: int = 128,
+    config: Optional[ModelConfig] = None,
 ) -> BackboneModel:
+    config = config or ModelConfig()
+
     if config.name == ModelType.Net2DFast:
         return Net2DFast(
-            input_height=input_height,
-            num_features=config.num_features,
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
         )
-    elif config.name == ModelType.Net2DFastNoAttn:
+
+    if config.name == ModelType.Net2DFastNoAttn:
         return Net2DFastNoAttn(
-            num_features=config.num_features,
-            input_height=input_height,
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
         )
-    elif config.name == ModelType.Net2DFastNoCoordConv:
+
+    if config.name == ModelType.Net2DFastNoCoordConv:
         return Net2DFastNoCoordConv(
-            num_features=config.num_features,
-            input_height=input_height,
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
         )
-    else:
-        raise ValueError(f"Unknown model type: {config.name}")
+
+    if config.name == ModelType.Net2DPlain:
+        return Net2DPlain(
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
+        )
+
+    raise ValueError(f"Unknown model type: {config.name}")
diff --git a/batdetect2/models/backbones.py b/batdetect2/models/backbones.py
index 0aa266a..5083016 100644
--- a/batdetect2/models/backbones.py
+++ b/batdetect2/models/backbones.py
@@ -1,16 +1,16 @@
-from typing import
Sequence, Tuple import torch -import torch.fft import torch.nn.functional as F -from torch import nn from batdetect2.models.blocks import ( - ConvBlockDownCoordF, - ConvBlockDownStandard, - ConvBlockUpF, - ConvBlockUpStandard, + ConvBlock, + Decoder, + DownscalingLayer, + Encoder, SelfAttention, + UpscalingLayer, + VerticalConv, ) from batdetect2.models.typing import BackboneModel @@ -21,303 +21,164 @@ __all__ = [ ] -class Net2DFast(BackboneModel): +class Net2DPlain(BackboneModel): + downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard" + upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard" + def __init__( self, - num_features: int, input_height: int = 128, + encoder_channels: Sequence[int] = (1, 32, 64, 128), + bottleneck_channels: int = 256, + decoder_channels: Sequence[int] = (256, 64, 32, 32), + out_channels: int = 32, ): super().__init__() - self.num_features = num_features - self.input_height = input_height - self.bottleneck_height = self.input_height // 32 - # encoder - self.conv_dn_0 = ConvBlockDownCoordF( - 1, - self.num_features // 4, - self.input_height, - kernel_size=3, - pad_size=1, - stride=1, + self.input_height = input_height + self.encoder_channels = tuple(encoder_channels) + self.decoder_channels = tuple(decoder_channels) + self.out_channels = out_channels + + if len(encoder_channels) != len(decoder_channels): + raise ValueError( + f"Mismatched encoder and decoder channel lists. " + f"The encoder has {len(encoder_channels)} channels " + f"(implying {len(encoder_channels) - 1} layers), " + f"while the decoder has {len(decoder_channels)} channels " + f"(implying {len(decoder_channels) - 1} layers). " + f"These lengths must be equal." + ) + + self.divide_factor = 2 ** (len(encoder_channels) - 1) + if self.input_height % self.divide_factor != 0: + raise ValueError( + f"Input height ({self.input_height}) must be divisible by " + f"the divide factor ({self.divide_factor}). " + f"This ensures proper upscaling after downscaling to recover " + f"the original input height." 
+ ) + + self.encoder = Encoder( + channels=encoder_channels, + input_height=self.input_height, + layer_type=self.downscaling_layer_type, ) - self.conv_dn_1 = ConvBlockDownCoordF( - self.num_features // 4, - self.num_features // 2, - self.input_height // 2, - kernel_size=3, - pad_size=1, - stride=1, + + self.conv_same_1 = ConvBlock( + in_channels=encoder_channels[-1], + out_channels=bottleneck_channels, ) - self.conv_dn_2 = ConvBlockDownCoordF( - self.num_features // 2, - self.num_features, - self.input_height // 4, - kernel_size=3, - pad_size=1, - stride=1, - ) - self.conv_dn_3 = nn.Conv2d( - self.num_features, - self.num_features * 2, - 3, - padding=1, - ) - self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2) # bottleneck - self.conv_1d = nn.Conv2d( - self.num_features * 2, - self.num_features * 2, - (self.input_height // 8, 1), - padding=0, - ) - self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2) - self.att = SelfAttention(self.num_features * 2, self.num_features * 2) - - # decoder - self.conv_up_2 = ConvBlockUpF( - self.num_features * 2, - self.num_features // 2, - self.input_height // 8, - ) - self.conv_up_3 = ConvBlockUpF( - self.num_features // 2, - self.num_features // 4, - self.input_height // 4, - ) - self.conv_up_4 = ConvBlockUpF( - self.num_features // 4, - self.num_features // 4, - self.input_height // 2, + self.conv_vert = VerticalConv( + in_channels=bottleneck_channels, + out_channels=bottleneck_channels, + input_height=self.input_height // (2**self.encoder.depth), ) - self.conv_op = nn.Conv2d( - self.num_features // 4, - self.num_features // 4, - kernel_size=3, - padding=1, + self.decoder = Decoder( + channels=decoder_channels, + input_height=self.input_height, + layer_type=self.upscaling_layer_type, ) - self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4) - self.out_channels = self.num_features // 4 - - def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]: - h, w = spec.shape[2:] - h_pad = (32 - h % 32) % 32 - w_pad = (32 - w % 32) % 32 - return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad + self.conv_same_2 = ConvBlock( + in_channels=decoder_channels[-1], + out_channels=out_channels, + ) def forward(self, spec: torch.Tensor) -> torch.Tensor: - # encoder - spec, h_pad, w_pad = self.pad_adjust(spec) + spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor) - x1 = self.conv_dn_0(spec) - x2 = self.conv_dn_1(x1) - x3 = self.conv_dn_2(x2) - x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3))) + # encoder + residuals = self.encoder(spec) + residuals[-1] = self.conv_same_1(residuals[-1]) # bottleneck - x = F.relu_(self.conv_1d_bn(self.conv_1d(x3))) - x = self.att(x) - x = x.repeat([1, 1, self.bottleneck_height * 4, 1]) + x = self.conv_vert(residuals[-1]) + x = x.repeat([1, 1, residuals[-1].shape[-2], 1]) # decoder - x = self.conv_up_2(x + x3) - x = self.conv_up_3(x + x2) - x = self.conv_up_4(x + x1) + x = self.decoder(x, residuals=residuals) # Restore original size - if h_pad > 0: - x = x[:, :, :-h_pad, :] + x = restore_pad(x, h_pad=h_pad, w_pad=w_pad) - if w_pad > 0: - x = x[:, :, :, :-w_pad] - - return F.relu_(self.conv_op_bn(self.conv_op(x))) + return self.conv_same_2(x) -class Net2DFastNoAttn(BackboneModel): +class Net2DFast(Net2DPlain): + downscaling_layer_type = "ConvBlockDownCoordF" + upscaling_layer_type = "ConvBlockUpF" + def __init__( self, - num_features: int, input_height: int = 128, + encoder_channels: Sequence[int] = (1, 32, 64, 128), + bottleneck_channels: int = 256, + decoder_channels: Sequence[int] = (256, 64, 32, 
32), + out_channels: int = 32, ): - super().__init__() - - self.num_features = num_features - self.input_height = input_height - self.bottleneck_height = self.input_height // 32 - - self.conv_dn_0 = ConvBlockDownCoordF( - 1, - self.num_features // 4, - self.input_height, - kernel_size=3, - pad_size=1, - stride=1, - ) - self.conv_dn_1 = ConvBlockDownCoordF( - self.num_features // 4, - self.num_features // 2, - self.input_height // 2, - kernel_size=3, - pad_size=1, - stride=1, - ) - self.conv_dn_2 = ConvBlockDownCoordF( - self.num_features // 2, - self.num_features, - self.input_height // 4, - kernel_size=3, - pad_size=1, - stride=1, - ) - self.conv_dn_3 = nn.Conv2d( - self.num_features, - self.num_features * 2, - 3, - padding=1, - ) - self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2) - - self.conv_1d = nn.Conv2d( - self.num_features * 2, - self.num_features * 2, - (self.input_height // 8, 1), - padding=0, - ) - self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2) - - self.conv_up_2 = ConvBlockUpF( - self.num_features * 2, - self.num_features // 2, - self.input_height // 8, - ) - self.conv_up_3 = ConvBlockUpF( - self.num_features // 2, - self.num_features // 4, - self.input_height // 4, - ) - self.conv_up_4 = ConvBlockUpF( - self.num_features // 4, - self.num_features // 4, - self.input_height // 2, + super().__init__( + input_height=input_height, + encoder_channels=encoder_channels, + bottleneck_channels=bottleneck_channels, + decoder_channels=decoder_channels, + out_channels=out_channels, ) - self.conv_op = nn.Conv2d( - self.num_features // 4, - self.num_features // 4, - kernel_size=3, - padding=1, - ) - self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4) - self.out_channels = self.num_features // 4 + self.att = SelfAttention(bottleneck_channels, bottleneck_channels) def forward(self, spec: torch.Tensor) -> torch.Tensor: - x1 = self.conv_dn_0(spec) - x2 = self.conv_dn_1(x1) - x3 = self.conv_dn_2(x2) - x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3))) + spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor) - x = F.relu_(self.conv_1d_bn(self.conv_1d(x3))) - x = x.repeat([1, 1, self.bottleneck_height * 4, 1]) + # encoder + residuals = self.encoder(spec) + residuals[-1] = self.conv_same_1(residuals[-1]) - x = self.conv_up_2(x + x3) - x = self.conv_up_3(x + x2) - x = self.conv_up_4(x + x1) - - return F.relu_(self.conv_op_bn(self.conv_op(x))) - - -class Net2DFastNoCoordConv(BackboneModel): - def __init__( - self, - num_features: int, - input_height: int = 128, - ): - super().__init__() - - self.num_features = num_features - self.input_height = input_height - self.bottleneck_height = self.input_height // 32 - - self.conv_dn_0 = ConvBlockDownStandard( - 1, - self.num_features // 4, - kernel_size=3, - pad_size=1, - stride=1, - ) - self.conv_dn_1 = ConvBlockDownStandard( - self.num_features // 4, - self.num_features // 2, - kernel_size=3, - pad_size=1, - stride=1, - ) - self.conv_dn_2 = ConvBlockDownStandard( - self.num_features // 2, - self.num_features, - kernel_size=3, - pad_size=1, - stride=1, - ) - self.conv_dn_3 = nn.Conv2d( - self.num_features, - self.num_features * 2, - 3, - padding=1, - ) - self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2) - - self.conv_1d = nn.Conv2d( - self.num_features * 2, - self.num_features * 2, - (self.input_height // 8, 1), - padding=0, - ) - self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2) - - self.att = SelfAttention(self.num_features * 2, self.num_features * 2) - - self.conv_up_2 = ConvBlockUpStandard( - 
self.num_features * 2, - self.num_features // 2, - self.input_height // 8, - ) - self.conv_up_3 = ConvBlockUpStandard( - self.num_features // 2, - self.num_features // 4, - self.input_height // 4, - ) - self.conv_up_4 = ConvBlockUpStandard( - self.num_features // 4, - self.num_features // 4, - self.input_height // 2, - ) - - self.conv_op = nn.Conv2d( - self.num_features // 4, - self.num_features // 4, - kernel_size=3, - padding=1, - ) - self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4) - self.out_channels = self.num_features // 4 - - def forward(self, spec: torch.Tensor) -> torch.Tensor: - x1 = self.conv_dn_0(spec) - x2 = self.conv_dn_1(x1) - x3 = self.conv_dn_2(x2) - x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3))) - - x = F.relu_(self.conv_1d_bn(self.conv_1d(x3))) + # bottleneck + x = self.conv_vert(residuals[-1]) x = self.att(x) - x = x.repeat([1, 1, self.bottleneck_height * 4, 1]) + x = x.repeat([1, 1, residuals[-1].shape[-2], 1]) - x = self.conv_up_2(x + x3) - x = self.conv_up_3(x + x2) - x = self.conv_up_4(x + x1) + # decoder + x = self.decoder(x, residuals=residuals) - return F.relu_(self.conv_op_bn(self.conv_op(x))) + # Restore original size + x = restore_pad(x, h_pad=h_pad, w_pad=w_pad) + + return self.conv_same_2(x) + + +class Net2DFastNoAttn(Net2DPlain): + downscaling_layer_type = "ConvBlockDownCoordF" + upscaling_layer_type = "ConvBlockUpF" + + +class Net2DFastNoCoordConv(Net2DFast): + downscaling_layer_type = "ConvBlockDownStandard" + upscaling_layer_type = "ConvBlockUpStandard" + + +def pad_adjust( + spec: torch.Tensor, + factor: int = 32, +) -> Tuple[torch.Tensor, int, int]: + h, w = spec.shape[2:] + h_pad = -h % factor + w_pad = -w % factor + return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad + + +def restore_pad( + x: torch.Tensor, h_pad: int = 0, w_pad: int = 0 +) -> torch.Tensor: + # Restore original size + if h_pad > 0: + x = x[:, :, :-h_pad, :] + + if w_pad > 0: + x = x[:, :, :, :-w_pad] + + return x diff --git a/batdetect2/models/blocks.py b/batdetect2/models/blocks.py index 080b0fa..1c8dda0 100644 --- a/batdetect2/models/blocks.py +++ b/batdetect2/models/blocks.py @@ -4,29 +4,35 @@ All these classes are subclasses of `torch.nn.Module` and can be used to build complex neural network architectures. """ -from typing import Tuple +import sys +from typing import Iterable, List, Literal, Sequence, Tuple import torch import torch.nn.functional as F from torch import nn -from batdetect2.configs import BaseConfig +if sys.version_info >= (3, 10): + from itertools import pairwise +else: + + def pairwise(iterable: Sequence) -> Iterable: + for x, y in zip(iterable[:-1], iterable[1:]): + yield x, y + __all__ = [ - "SelfAttention", + "ConvBlock", "ConvBlockDownCoordF", "ConvBlockDownStandard", "ConvBlockUpF", "ConvBlockUpStandard", + "SelfAttention", + "VerticalConv", + "DownscalingLayer", + "UpscalingLayer", ] -class SelfAttentionConfig(BaseConfig): - temperature: float = 1.0 - input_channels: int = 128 - attention_channels: int = 128 - - class SelfAttention(nn.Module): """Self-Attention module. 
@@ -76,13 +82,64 @@ class SelfAttention(nn.Module): return op -class ConvBlockDownCoordFConfig(BaseConfig): - in_channels: int - out_channels: int - input_height: int - kernel_size: int = 3 - pad_size: int = 1 - stride: int = 1 +class ConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + pad_size: int = 1, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=pad_size, + ) + self.conv_bn = nn.BatchNorm2d(out_channels) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.relu_(self.conv_bn(self.conv(x))) + + +class VerticalConv(nn.Module): + """Convolutional layer over full height. + + This layer applies a convolution that captures information across the + entire height of the input image. It uses a kernel with the same height as + the input, effectively condensing the vertical information into a single + output row. + + More specifically: + + * **Input:** (B, C, H, W) where B is the batch size, C is the number of + input channels, H is the image height, and W is the image width. + * **Kernel:** (C', H, 1) where C' is the number of output channels. + * **Output:** (B, C', 1, W) - The height dimension is 1 because the + convolution integrates information from all rows of the input. + + This process effectively extracts features that span the full height of + the input image. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + input_height: int, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(input_height, 1), + padding=0, + ) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.relu_(self.bn(self.conv(x))) class ConvBlockDownCoordF(nn.Module): @@ -124,14 +181,6 @@ class ConvBlockDownCoordF(nn.Module): return x -class ConvBlockDownStandardConfig(BaseConfig): - in_channels: int - out_channels: int - kernel_size: int = 3 - pad_size: int = 1 - stride: int = 1 - - class ConvBlockDownStandard(nn.Module): """Convolutional Block with Downsampling. @@ -158,18 +207,10 @@ class ConvBlockDownStandard(nn.Module): def forward(self, x): x = F.max_pool2d(self.conv(x), 2, 2) - x = F.relu(self.conv_bn(x), inplace=True) - return x + return F.relu(self.conv_bn(x), inplace=True) -class ConvBlockUpFConfig(BaseConfig): - inp_channels: int - out_channels: int - input_height: int - kernel_size: int = 3 - pad_size: int = 1 - up_mode: str = "bilinear" - up_scale: Tuple[int, int] = (2, 2) +DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"] class ConvBlockUpF(nn.Module): @@ -224,15 +265,6 @@ class ConvBlockUpF(nn.Module): return op -class ConvBlockUpStandardConfig(BaseConfig): - in_channels: int - out_channels: int - kernel_size: int = 3 - pad_size: int = 1 - up_mode: str = "bilinear" - up_scale: Tuple[int, int] = (2, 2) - - class ConvBlockUpStandard(nn.Module): """Convolutional Block with Upsampling. 
@@ -272,3 +304,143 @@ class ConvBlockUpStandard(nn.Module):
         op = self.conv(op)
         op = F.relu(self.conv_bn(op), inplace=True)
         return op
+
+
+UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"]
+
+
+def build_downscaling_layer(
+    in_channels: int,
+    out_channels: int,
+    input_height: int,
+    layer_type: DownscalingLayer,
+) -> nn.Module:
+    if layer_type == "ConvBlockDownStandard":
+        return ConvBlockDownStandard(
+            in_channels=in_channels,
+            out_channels=out_channels,
+        )
+
+    if layer_type == "ConvBlockDownCoordF":
+        return ConvBlockDownCoordF(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            input_height=input_height,
+        )
+
+    raise ValueError(
+        f"Invalid downscaling layer type {layer_type}. "
+        f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
+    )
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        channels: Sequence[int] = (1, 32, 62, 128),
+        input_height: int = 128,
+        layer_type: Literal[
+            "ConvBlockDownStandard", "ConvBlockDownCoordF"
+        ] = "ConvBlockDownStandard",
+    ):
+        super().__init__()
+
+        self.channels = channels
+        self.input_height = input_height
+
+        self.layers = nn.ModuleList(
+            [
+                build_downscaling_layer(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    input_height=input_height // (2**layer_num),
+                    layer_type=layer_type,
+                )
+                for layer_num, (in_channels, out_channels) in enumerate(
+                    pairwise(channels)
+                )
+            ]
+        )
+        self.depth = len(self.layers)
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        outputs = []
+
+        for layer in self.layers:
+            x = layer(x)
+            outputs.append(x)
+
+        return outputs
+
+
+def build_upscaling_layer(
+    in_channels: int,
+    out_channels: int,
+    input_height: int,
+    layer_type: UpscalingLayer,
+) -> nn.Module:
+    if layer_type == "ConvBlockUpStandard":
+        return ConvBlockUpStandard(
+            in_channels=in_channels,
+            out_channels=out_channels,
+        )
+
+    if layer_type == "ConvBlockUpF":
+        return ConvBlockUpF(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            input_height=input_height,
+        )
+
+    raise ValueError(
+        f"Invalid upscaling layer type {layer_type}. "
+        f"Valid values: ConvBlockUpStandard, ConvBlockUpF"
+    )
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        channels: Sequence[int] = (256, 62, 32, 32),
+        input_height: int = 128,
+        layer_type: Literal[
+            "ConvBlockUpStandard", "ConvBlockUpF"
+        ] = "ConvBlockUpStandard",
+    ):
+        super().__init__()
+
+        self.channels = channels
+        self.input_height = input_height
+        self.depth = len(self.channels) - 1
+
+        self.layers = nn.ModuleList(
+            [
+                build_upscaling_layer(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    input_height=input_height
+                    // (2 ** (self.depth - layer_num)),
+                    layer_type=layer_type,
+                )
+                for layer_num, (in_channels, out_channels) in enumerate(
+                    pairwise(channels)
+                )
+            ]
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residuals: List[torch.Tensor],
+    ) -> torch.Tensor:
+        if len(residuals) != len(self.layers):
+            raise ValueError(
+                f"Incorrect number of residuals provided. "
+                f"Expected {len(self.layers)} (matching the number of layers), "
+                f"but got {len(residuals)}."
+ ) + + for layer, res in zip(self.layers, residuals[::-1]): + x = layer(x + res) + + return x diff --git a/batdetect2/models/decoder.py b/batdetect2/models/decoder.py new file mode 100644 index 0000000..17bbebf --- /dev/null +++ b/batdetect2/models/decoder.py @@ -0,0 +1,17 @@ +import sys +from typing import Iterable, List, Literal, Sequence + +import torch +from torch import nn + +from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard + +if sys.version_info >= (3, 10): + from itertools import pairwise +else: + + def pairwise(iterable: Sequence) -> Iterable: + for x, y in zip(iterable[:-1], iterable[1:]): + yield x, y + + diff --git a/batdetect2/models/encoder.py b/batdetect2/models/encoder.py new file mode 100644 index 0000000..3f56130 --- /dev/null +++ b/batdetect2/models/encoder.py @@ -0,0 +1,17 @@ +import sys +from typing import Iterable, List, Literal, Sequence + +import torch +from torch import nn + +from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard + +if sys.version_info >= (3, 10): + from itertools import pairwise +else: + + def pairwise(iterable: Sequence) -> Iterable: + for x, y in zip(iterable[:-1], iterable[1:]): + yield x, y + + diff --git a/batdetect2/models/typing.py b/batdetect2/models/typing.py index c39229d..a55d329 100644 --- a/batdetect2/models/typing.py +++ b/batdetect2/models/typing.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import NamedTuple +from typing import NamedTuple, Tuple import torch import torch.nn as nn @@ -45,15 +45,27 @@ class BackboneModel(ABC, nn.Module): input_height: int """Height of the input spectrogram.""" - num_features: int - """Dimension of the feature tensor.""" + encoder_channels: Tuple[int, ...] + """Tuple specifying the number of channels for each convolutional layer + in the encoder. The length of this tuple determines the number of + encoder layers.""" + + decoder_channels: Tuple[int, ...] + """Tuple specifying the number of channels for each convolutional layer + in the decoder. 
The length of this tuple determines the number of + decoder layers.""" + + bottleneck_channels: int + """Number of channels in the bottleneck layer, which connects the + encoder and decoder.""" out_channels: int - """Number of output channels of the feature extractor.""" + """Number of channels in the final output feature map produced by the + backbone model.""" @abstractmethod def forward(self, spec: torch.Tensor) -> torch.Tensor: - """Forward pass of the encoder model.""" + """Forward pass of the model.""" class DetectionModel(ABC, nn.Module): diff --git a/batdetect2/preprocess/__init__.py b/batdetect2/preprocess/__init__.py index b3b4142..316a11f 100644 --- a/batdetect2/preprocess/__init__.py +++ b/batdetect2/preprocess/__init__.py @@ -13,7 +13,7 @@ from batdetect2.preprocess.audio import ( ) from batdetect2.preprocess.spectrogram import ( AmplitudeScaleConfig, - FFTConfig, + STFTConfig, FrequencyConfig, LogScaleConfig, PcenScaleConfig, @@ -27,7 +27,7 @@ __all__ = [ "AudioConfig", "ResampleConfig", "SpectrogramConfig", - "FFTConfig", + "STFTConfig", "FrequencyConfig", "PcenScaleConfig", "LogScaleConfig", diff --git a/batdetect2/preprocess/arrays.py b/batdetect2/preprocess/arrays.py new file mode 100644 index 0000000..e88a516 --- /dev/null +++ b/batdetect2/preprocess/arrays.py @@ -0,0 +1,61 @@ +import numpy as np + + +def extend_width( + array: np.ndarray, + extra: int, + axis: int = -1, + value: float = 0, +) -> np.ndarray: + dims = len(array.shape) + axis = axis % dims + pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)] + return np.pad( + array, + pad, + mode="constant", + constant_values=value, + ) + + +def make_width_divisible( + array: np.ndarray, + factor: int, + axis: int = -1, + value: float = 0, +) -> np.ndarray: + width = array.shape[axis] + + if width % factor == 0: + return array + + extra = (-width) % factor + return extend_width(array, extra, axis=axis, value=value) + + +def adjust_width( + array: np.ndarray, + width: int, + axis: int = -1, + value: float = 0, +) -> np.ndarray: + dims = len(array.shape) + axis = axis % dims + current_width = array.shape[axis] + + if current_width == width: + return array + + if current_width < width: + return extend_width( + array, + extra=width - current_width, + axis=axis, + value=value, + ) + + slices = [ + slice(None, None) if index != axis else slice(None, width) + for index in range(dims) + ] + return array[tuple(slices)] diff --git a/batdetect2/preprocess/spectrogram.py b/batdetect2/preprocess/spectrogram.py index fd0d72a..bf228f5 100644 --- a/batdetect2/preprocess/spectrogram.py +++ b/batdetect2/preprocess/spectrogram.py @@ -12,7 +12,7 @@ from soundevent.arrays import operations as ops from batdetect2.configs import BaseConfig -class FFTConfig(BaseConfig): +class STFTConfig(BaseConfig): window_duration: float = Field(default=0.002, gt=0) window_overlap: float = Field(default=0.75, ge=0, lt=1) window_fn: str = "hann" @@ -24,9 +24,15 @@ class FrequencyConfig(BaseConfig): class SpecSizeConfig(BaseConfig): - height: int = 256 + height: int = 128 + """Height of the spectrogram in pixels. This value determines the + number of frequency bands and corresponds to the vertical dimension + of the spectrogram.""" + resize_factor: Optional[float] = 0.5 - divide_factor: Optional[int] = 32 + """Factor by which to resize the spectrogram along the time axis. + A value of 0.5 reduces the temporal dimension by half, while a + value of 2.0 doubles it. 
If None, no resizing is performed.""" class LogScaleConfig(BaseConfig): @@ -50,13 +56,13 @@ Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig] class SpectrogramConfig(BaseConfig): - fft: FFTConfig = Field(default_factory=FFTConfig) + stft: STFTConfig = Field(default_factory=STFTConfig) frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig) scale: Scales = Field( default_factory=PcenScaleConfig, discriminator="name", ) - size: SpecSizeConfig = Field(default_factory=SpecSizeConfig) + size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig) denoise: bool = True max_scale: bool = False @@ -68,22 +74,11 @@ def compute_spectrogram( ) -> xr.DataArray: config = config or SpectrogramConfig() - if config.size.divide_factor: - # Need to pad the audio to make sure the spectrogram has a - # width compatible with the divide factor - resize_factor = config.size.resize_factor or 1 - wav = pad_audio( - wav, - window_duration=config.fft.window_duration, - window_overlap=config.fft.window_overlap, - divide_factor=int(config.size.divide_factor / resize_factor), - ) - spec = stft( wav, - window_duration=config.fft.window_duration, - window_overlap=config.fft.window_overlap, - window_fn=config.fft.window_fn, + window_duration=config.stft.window_duration, + window_overlap=config.stft.window_overlap, + window_fn=config.stft.window_fn, dtype=dtype, ) @@ -98,11 +93,12 @@ def compute_spectrogram( if config.denoise: spec = denoise_spectrogram(spec) - spec = resize_spectrogram( - spec, - height=config.size.height, - resize_factor=config.size.resize_factor, - ) + if config.size: + spec = resize_spectrogram( + spec, + height=config.size.height, + resize_factor=config.size.resize_factor, + ) if config.max_scale: spec = ops.scale(spec, 1 / (10e-6 + np.max(spec))) @@ -257,7 +253,7 @@ def resize_spectrogram( return ops.resize( spec, time=int(resize_factor * current_width), - frequency=int(resize_factor * height), + frequency=height, dtype=np.float32, ) @@ -285,43 +281,6 @@ def adjust_spectrogram_width( return resized -def pad_audio( - wave: xr.DataArray, - window_duration: float, - window_overlap: float, - divide_factor: int = 32, -) -> xr.DataArray: - current_duration = arrays.get_dim_width(wave, dim="time") - step = arrays.get_dim_step(wave, dim="time") - samplerate = int(1 / step) - - estimated_spec_width = duration_to_spec_width( - current_duration, - samplerate=samplerate, - window_duration=window_duration, - window_overlap=window_overlap, - ) - - if estimated_spec_width % divide_factor == 0: - return wave - - target_spec_width = int( - np.ceil(estimated_spec_width / divide_factor) * divide_factor - ) - target_samples = spec_width_to_samples( - target_spec_width, - samplerate=samplerate, - window_duration=window_duration, - window_overlap=window_overlap, - ) - return ops.adjust_dim_width( - wave, - dim="time", - width=target_samples, - position="start", - ) - - def duration_to_spec_width( duration: float, samplerate: int, @@ -357,6 +316,8 @@ def get_spectrogram_resolution( spec_height = config.size.height resize_factor = config.size.resize_factor or 1 - freq_bin_width = (max_freq - min_freq) / (spec_height * resize_factor) - hop_duration = config.fft.window_duration * (1 - config.fft.window_overlap) + freq_bin_width = (max_freq - min_freq) / spec_height + hop_duration = config.stft.window_duration * ( + 1 - config.stft.window_overlap + ) return freq_bin_width, hop_duration / resize_factor diff --git a/batdetect2/preprocess/tensors.py b/batdetect2/preprocess/tensors.py new 
file mode 100644 index 0000000..8d2c77b --- /dev/null +++ b/batdetect2/preprocess/tensors.py @@ -0,0 +1,76 @@ +from typing import Union + +import numpy as np +import torch +from torch.nn import functional as F + + +def extend_width( + array: Union[np.ndarray, torch.Tensor], + extra: int, + axis: int = -1, + value: float = 0, +) -> torch.Tensor: + if not isinstance(array, torch.Tensor): + array = torch.Tensor(array) + + dims = len(array.shape) + axis = axis % dims + pad = [ + [0, 0] if index != axis else [0, extra] + for index in range(axis, dims)[::-1] + ] + return F.pad( + array, + [x for y in pad for x in y], + value=value, + ) + + +def make_width_divisible( + array: Union[np.ndarray, torch.Tensor], + factor: int, + axis: int = -1, + value: float = 0, +) -> torch.Tensor: + if not isinstance(array, torch.Tensor): + array = torch.Tensor(array) + + width = array.shape[axis] + + if width % factor == 0: + return array + + extra = (-width) % factor + return extend_width(array, extra, axis=axis, value=value) + + +def adjust_width( + array: Union[np.ndarray, torch.Tensor], + width: int, + axis: int = -1, + value: float = 0, +) -> torch.Tensor: + if not isinstance(array, torch.Tensor): + array = torch.Tensor(array) + + dims = len(array.shape) + axis = axis % dims + current_width = array.shape[axis] + + if current_width == width: + return array + + if current_width < width: + return extend_width( + array, + extra=width - current_width, + axis=axis, + value=value, + ) + + slices = [ + slice(None, None) if index != axis else slice(None, width) + for index in range(dims) + ] + return array[tuple(slices)] diff --git a/batdetect2/train/augmentations.py b/batdetect2/train/augmentations.py index deb8720..317e07f 100644 --- a/batdetect2/train/augmentations.py +++ b/batdetect2/train/augmentations.py @@ -4,10 +4,10 @@ import numpy as np import xarray as xr from pydantic import Field from soundevent import arrays -from soundevent.arrays import operations as ops from batdetect2.configs import BaseConfig from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram +from batdetect2.preprocess.arrays import adjust_width Augmentation = Callable[[xr.Dataset], xr.Dataset] @@ -17,47 +17,12 @@ class AugmentationConfig(BaseConfig): probability: float = 0.2 -class SubclipConfig(BaseConfig): - enable: bool = True - duration: Optional[float] = None - width: Optional[int] = 512 - - -def adjust_dataset_width( - example: xr.Dataset, - duration: Optional[float] = None, - width: Optional[int] = None, -) -> xr.Dataset: - step = arrays.get_dim_step(example, "time") # type: ignore - - if width is None: - if duration is None: - raise ValueError("Either duration or width must be provided") - - width = int(np.floor(duration / step)) - - adjusted_arrays = { - name: ops.adjust_dim_width(array, "time", width) - for name, array in example.items() - if name != "audio" - } - - ratio = width / example.spectrogram.sizes["time"] - audio_width = int(example.audio.sizes["audio_time"] * ratio) - adjusted_arrays["audio"] = ops.adjust_dim_width( - example["audio"], - "audio_time", - audio_width, - ) - - return xr.Dataset(data_vars=adjusted_arrays) - - -def select_random_subclip( +def select_subclip( example: xr.Dataset, start_time: Optional[float] = None, duration: Optional[float] = None, width: Optional[int] = None, + random: bool = False, ) -> xr.Dataset: """Select a random subclip from a clip.""" step = arrays.get_dim_step(example, "time") # type: ignore @@ -73,10 +38,13 @@ def select_random_subclip( duration = width * step if 
start_time is None: - start_time = np.random.uniform(start, max(stop - duration, start)) + if random: + start_time = np.random.uniform(start, max(stop - duration, start)) + else: + start_time = start if start_time + duration > stop: - example = adjust_dataset_width(example, width=width) + return example start_index = arrays.get_coord_index( example, # type: ignore @@ -91,7 +59,7 @@ def select_random_subclip( return example.sel( time=slice(start_time, end_time), - audio_time=slice(start_time, end_time), + audio_time=slice(start_time, end_time + step), ) @@ -115,40 +83,45 @@ def mix_examples( weight = np.random.uniform(min_weight, max_weight) audio1 = example["audio"] - - audio2 = ops.adjust_dim_width( - other["audio"], "audio_time", len(audio1) - ).values - - if len(audio2) > len(audio1): - audio2 = audio2[: len(audio1)] + audio2 = adjust_width(other["audio"].values, len(audio1)) combined = weight * audio1 + (1 - weight) * audio2 - spec = compute_spectrogram( + spectrogram = compute_spectrogram( combined.rename({"audio_time": "time"}), config=config.spectrogram, - ) + ).data + + # NOTE: The subclip's spectrogram might be slightly longer than the + # spectrogram computed from the subclip's audio. This is due to a + # simplification in the subclip process: It doesn't account for the + # spectrogram parameters to precisely determine the corresponding audio + # samples. To work around this, we pad the computed spectrogram with zeros + # as needed. + previous_width = len(example["time"]) + spectrogram = adjust_width(spectrogram, previous_width) detection_heatmap = xr.apply_ufunc( np.maximum, example["detection"], - other["detection"].values, + adjust_width(other["detection"].values, previous_width), ) class_heatmap = xr.apply_ufunc( np.maximum, example["class"], - other["class"].values, + adjust_width(other["class"].values, previous_width), ) - size_heatmap = example["size"] + other["size"].values + size_heatmap = example["size"] + adjust_width( + other["size"].values, previous_width + ) return xr.Dataset( { "audio": combined, "spectrogram": xr.DataArray( - data=spec.data, + data=spectrogram, dims=example["spectrogram"].dims, coords=example["spectrogram"].coords, ), @@ -192,12 +165,23 @@ def add_echo( spectrogram = compute_spectrogram( audio.rename({"audio_time": "time"}), config=config.spectrogram, + ).data + + # NOTE: The subclip's spectrogram might be slightly longer than the + # spectrogram computed from the subclip's audio. This is due to a + # simplification in the subclip process: It doesn't account for the + # spectrogram parameters to precisely determine the corresponding audio + # samples. To work around this, we pad the computed spectrogram with zeros + # as needed. 
+ spectrogram = adjust_width( + spectrogram, + example["spectrogram"].sizes["time"], ) return example.assign( audio=audio, spectrogram=xr.DataArray( - data=spectrogram.data, + data=spectrogram, dims=example["spectrogram"].dims, coords=example["spectrogram"].coords, ), @@ -359,7 +343,6 @@ def mask_frequency( class AugmentationsConfig(BaseConfig): - subclip: SubclipConfig = Field(default_factory=SubclipConfig) mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig) echo: EchoAugmentationConfig = Field( default_factory=EchoAugmentationConfig @@ -391,23 +374,8 @@ def augment_example( preprocessing_config: Optional[PreprocessingConfig] = None, others: Optional[Callable[[], xr.Dataset]] = None, ) -> xr.Dataset: - if config.subclip.enable: - example = select_random_subclip( - example, - duration=config.subclip.duration, - width=config.subclip.width, - ) - if should_apply(config.mix) and (others is not None): other = others() - - if config.subclip.enable: - other = select_random_subclip( - other, - duration=config.subclip.duration, - width=config.subclip.width, - ) - example = mix_examples( example, other, diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py index 3fdcf24..b444f42 100644 --- a/batdetect2/train/dataset.py +++ b/batdetect2/train/dataset.py @@ -5,10 +5,17 @@ from typing import NamedTuple, Optional, Sequence, Union import numpy as np import torch import xarray as xr +from pydantic import Field from soundevent import data from torch.utils.data import Dataset -from batdetect2.train.augmentations import AugmentationsConfig, augment_example +from batdetect2.configs import BaseConfig +from batdetect2.preprocess.tensors import adjust_width +from batdetect2.train.augmentations import ( + AugmentationsConfig, + augment_example, + select_subclip, +) from batdetect2.train.preprocess import PreprocessingConfig __all__ = [ @@ -28,20 +35,36 @@ class TrainExample(NamedTuple): idx: torch.Tensor +class SubclipConfig(BaseConfig): + duration: Optional[float] = None + width: int = 512 + random: bool = False + + +class DatasetConfig(BaseConfig): + subclip: SubclipConfig = Field(default_factory=SubclipConfig) + preprocessing: PreprocessingConfig = Field( + default_factory=PreprocessingConfig + ) + augmentation: AugmentationsConfig = Field( + default_factory=AugmentationsConfig + ) + + class LabeledDataset(Dataset): + config: DatasetConfig + def __init__( self, filenames: Sequence[PathLike], augment: bool = False, - preprocessing_config: Optional[PreprocessingConfig] = None, - augmentation_config: Optional[AugmentationsConfig] = None, + subclip: bool = False, + config: Optional[DatasetConfig] = None, ): self.filenames = filenames self.augment = augment - self.preprocessing_config = ( - preprocessing_config or PreprocessingConfig() - ) - self.agumentation_config = augmentation_config or AugmentationsConfig() + self.subclip = subclip + self.config = config or DatasetConfig() def __len__(self): return len(self.filenames) @@ -49,27 +72,27 @@ class LabeledDataset(Dataset): def __getitem__(self, idx) -> TrainExample: dataset = self.get_dataset(idx) + if self.subclip: + dataset = select_subclip( + dataset, + duration=self.config.subclip.duration, + width=self.config.subclip.width, + random=self.config.subclip.random, + ) + if self.augment: dataset = augment_example( dataset, - self.agumentation_config, - preprocessing_config=self.preprocessing_config, + self.config.augmentation, + preprocessing_config=self.config.preprocessing, others=self.get_random_example, ) return 
TrainExample( - spec=torch.tensor( - dataset["spectrogram"].values.astype(np.float32) - ).unsqueeze(0), - detection_heatmap=torch.tensor( - dataset["detection"].values.astype(np.float32) - ), - class_heatmap=torch.tensor( - dataset["class"].values.astype(np.float32) - ), - size_heatmap=torch.tensor( - dataset["size"].values.astype(np.float32) - ), + spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0), + detection_heatmap=self.to_tensor(dataset["detection"]), + class_heatmap=self.to_tensor(dataset["class"]), + size_heatmap=self.to_tensor(dataset["size"]), idx=torch.tensor(idx), ) @@ -78,20 +101,30 @@ class LabeledDataset(Dataset): cls, directory: PathLike, extension: str = ".nc", + config: Optional[DatasetConfig] = None, augment: bool = False, - preprocessing_config: Optional[PreprocessingConfig] = None, - augmentation_config: Optional[AugmentationsConfig] = None, + subclip: bool = False, ): return cls( get_files(directory, extension), + config=config, augment=augment, - preprocessing_config=preprocessing_config, - augmentation_config=augmentation_config, + subclip=subclip, ) def get_random_example(self) -> xr.Dataset: idx = np.random.randint(0, len(self)) - return self.get_dataset(idx) + dataset = self.get_dataset(idx) + + if self.subclip: + dataset = select_subclip( + dataset, + duration=self.config.subclip.duration, + width=self.config.subclip.width, + random=self.config.subclip.random, + ) + + return dataset def get_dataset(self, idx) -> xr.Dataset: return xr.open_dataset(self.filenames[idx]) @@ -101,6 +134,19 @@ class LabeledDataset(Dataset): self.get_dataset(idx).attrs["clip_annotation"] ) + def to_tensor( + self, + array: xr.DataArray, + dtype=np.float32, + ) -> torch.Tensor: + tensor = torch.tensor(array.values.astype(dtype)) + + if not self.subclip: + return tensor + + width = self.config.subclip.width + return adjust_width(tensor, width) + def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]: return list(Path(directory).glob(f"*{extension}")) diff --git a/batdetect2/train/evaluate.py b/batdetect2/train/evaluate.py index a926fbb..cf3d4d1 100755 --- a/batdetect2/train/evaluate.py +++ b/batdetect2/train/evaluate.py @@ -1,16 +1,56 @@ +from typing import List + import numpy as np -from sklearn.metrics import ( - accuracy_score, - auc, - balanced_accuracy_score, - roc_curve, -) +from sklearn.metrics import auc, roc_curve +from soundevent import data +from soundevent.evaluation import match_geometries + + +def match_predictions_and_annotations( + clip_annotation: data.ClipAnnotation, + clip_prediction: data.ClipPrediction, +) -> List[data.Match]: + annotated_sound_events = [ + sound_event_annotation + for sound_event_annotation in clip_annotation.sound_events + if sound_event_annotation.sound_event.geometry is not None + ] + + predicted_sound_events = [ + sound_event_prediction + for sound_event_prediction in clip_prediction.sound_events + if sound_event_prediction.sound_event.geometry is not None + ] + + annotated_geometries: List[data.Geometry] = [ + sound_event.sound_event.geometry + for sound_event in annotated_sound_events + if sound_event.sound_event.geometry is not None + ] + + predicted_geometries: List[data.Geometry] = [ + sound_event.sound_event.geometry + for sound_event in predicted_sound_events + if sound_event.sound_event.geometry is not None + ] + + matches = [] + for id1, id2, affinity in match_geometries( + annotated_geometries, + predicted_geometries, + ): + target = annotated_sound_events[id1] if id1 is not None else None + source = 
predicted_sound_events[id2] if id2 is not None else None + matches.append( + data.Match(source=source, target=target, affinity=affinity) + ) + + return matches def compute_error_auc(op_str, gt, pred, prob): - # classification error - pred_int = (pred > prob).astype(np.int) + pred_int = (pred > prob).astype(np.int32) class_acc = (pred_int == gt).mean() * 100.0 # ROC - area under curve @@ -25,7 +65,6 @@ def compute_error_auc(op_str, gt, pred, prob): def calc_average_precision(recall, precision): - precision[np.isnan(precision)] = 0 recall[np.isnan(recall)] = 0 @@ -91,7 +130,6 @@ def compute_pre_rec( pred_class = [] file_ids = [] for pid, pp in enumerate(preds): - # filter predicted calls that are too near the start or end of the file file_dur = gts[pid]["duration"] valid_inds = (pp["start_times"] >= ignore_start_end) & ( @@ -141,7 +179,6 @@ def compute_pre_rec( gt_generic_class = [] num_positives = 0 for gg in gts: - # filter ground truth calls that are too near the start or end of the file file_dur = gg["duration"] valid_inds = (gg["start_times"] >= ignore_start_end) & ( @@ -205,7 +242,6 @@ def compute_pre_rec( # valid detection that has not already been assigned if valid_det and (gt_assigned[gt_id][det_ind] == 0): - count_as_true_pos = True if eval_mode == "top_class" and ( gt_class[gt_id][det_ind] != pred_class[ind] diff --git a/batdetect2/train/modules.py b/batdetect2/train/modules.py index be3ef1d..ff44e84 100644 --- a/batdetect2/train/modules.py +++ b/batdetect2/train/modules.py @@ -3,7 +3,9 @@ from typing import Optional import pytorch_lightning as L import torch from pydantic import Field -from torch import optim +from torch.optim.adam import Adam +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.utils.data import DataLoader from batdetect2.configs import BaseConfig from batdetect2.models import ( @@ -13,11 +15,20 @@ from batdetect2.models import ( get_backbone, ) from batdetect2.models.typing import ModelOutput -from batdetect2.post_process import PostprocessConfig +from batdetect2.post_process import ( + PostprocessConfig, + postprocess_model_outputs, +) from batdetect2.preprocess import PreprocessingConfig -from batdetect2.train.dataset import TrainExample +from batdetect2.train.dataset import LabeledDataset, TrainExample +from batdetect2.train.evaluate import match_predictions_and_annotations from batdetect2.train.losses import LossConfig, compute_loss -from batdetect2.train.targets import TargetConfig +from batdetect2.train.targets import ( + TargetConfig, + build_decoder, + build_encoder, + get_class_names, +) class OptimizerConfig(BaseConfig): @@ -55,10 +66,8 @@ class DetectorModel(L.LightningModule): self.save_hyperparameters() size = self.config.preprocessing.spectrogram.size - self.backbone = get_backbone( - input_height=int(size.height * (size.resize_factor or 1)), - config=self.config.backbone, - ) + assert size is not None + self.backbone = get_backbone(self.config.backbone) self.classifier = ClassifierHead( num_classes=len(self.config.targets.classes), @@ -74,6 +83,13 @@ class DetectorModel(L.LightningModule): self.validation_predictions = [] + self.class_names = get_class_names(self.config.targets.classes) + self.encoder = build_encoder( + self.config.targets.classes, + replacement_rules=self.config.targets.replace, + ) + self.decoder = build_decoder(self.config.targets.classes) + def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore features = self.backbone(spec) detection_probs, classification_probs = self.classifier(features) @@ 
-117,14 +133,29 @@ class DetectorModel(L.LightningModule): self.log("val/loss/classification", losses.total, logger=True) dataloaders = self.trainer.val_dataloaders - print(dataloaders) + assert isinstance(dataloaders, DataLoader) + dataset = dataloaders.dataset + assert isinstance(dataset, LabeledDataset) + clip_annotation = dataset.get_clip_annotation(batch_idx) + clip_prediction = postprocess_model_outputs( + outputs, + clips=[clip_annotation.clip], + classes=self.class_names, + decoder=self.decoder, + config=self.config.postprocessing, + )[0] + + self.validation_predictions.extend( + match_predictions_and_annotations(clip_annotation, clip_prediction) + ) + + def on_validation_epoch_end(self) -> None: + print(len(self.validation_predictions)) + self.validation_predictions.clear() def configure_optimizers(self): conf = self.config.train.optimizer - optimizer = optim.Adam(self.parameters(), lr=conf.learning_rate) - scheduler = optim.lr_scheduler.CosineAnnealingLR( - optimizer, - T_max=conf.t_max, - ) + optimizer = Adam(self.parameters(), lr=conf.learning_rate) + scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max) return [optimizer], [scheduler] diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py index 2817e6d..416d805 100644 --- a/batdetect2/train/preprocess.py +++ b/batdetect2/train/preprocess.py @@ -101,7 +101,11 @@ def generate_train_example( return dataset.assign_attrs( title=f"Training example for {clip_annotation.uuid}", config=config.model_dump_json(), - clip_annotation=clip_annotation.model_dump_json(), + clip_annotation=clip_annotation.model_dump_json( + exclude_none=True, + exclude_defaults=True, + exclude_unset=True, + ), ) diff --git a/tests/test_configs.py b/tests/test_configs.py deleted file mode 100644 index f0f84ef..0000000 --- a/tests/test_configs.py +++ /dev/null @@ -1,11 +0,0 @@ -from pathlib import Path - -from batdetect2.configs import load_config -from batdetect2.data import DatasetsConfig, load_datasets - - -def test_can_load_dataset_configs(): - root = Path(__file__).parent.parent - path = root / "conf.yaml" - config = load_config(path, schema=DatasetsConfig, field="datasets") - load_datasets(config) diff --git a/tests/test_data/test_batdetect.py b/tests/test_data/test_batdetect.py index 1ead0b8..e99505c 100644 --- a/tests/test_data/test_batdetect.py +++ b/tests/test_data/test_batdetect.py @@ -4,7 +4,7 @@ from pathlib import Path from soundevent import data -from batdetect2.data.compat import load_annotation_project +from batdetect2.compat.data import load_annotation_project_from_dir ROOT_DIR = Path(__file__).parent.parent.parent @@ -12,7 +12,7 @@ ROOT_DIR = Path(__file__).parent.parent.parent def test_load_example_annotation_project(): path = ROOT_DIR / "example_data" / "anns" audio_dir = ROOT_DIR / "example_data" / "audio" - project = load_annotation_project(path, audio_dir=audio_dir) + project = load_annotation_project_from_dir(path, audio_dir=audio_dir) assert isinstance(project, data.AnnotationProject) assert project.name == str(path) assert len(project.clip_annotations) == 3 diff --git a/tests/test_migration/test_preprocessing.py b/tests/test_migration/test_preprocessing.py index 70fd828..8480215 100644 --- a/tests/test_migration/test_preprocessing.py +++ b/tests/test_migration/test_preprocessing.py @@ -88,7 +88,7 @@ def test_spectrogram_generation_hasnt_changed( scale = preprocess.AmplitudeScaleConfig() config = preprocess.SpectrogramConfig( - fft=preprocess.FFTConfig( + stft=preprocess.STFTConfig( window_overlap=fft_overlap, 
window_duration=fft_win_length, ), diff --git a/tests/test_models/__init__.py b/tests/test_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_models/test_inputs.py b/tests/test_models/test_inputs.py new file mode 100644 index 0000000..cac3548 --- /dev/null +++ b/tests/test_models/test_inputs.py @@ -0,0 +1,34 @@ +import torch +from hypothesis import given +from hypothesis import strategies as st + +from batdetect2.models import ModelConfig, ModelType, get_backbone + + +@given( + input_width=st.integers(min_value=50, max_value=1500), + input_height=st.integers(min_value=1, max_value=16), + model_type=st.sampled_from(ModelType), +) +def test_model_can_process_spectrograms_of_any_width( + input_width, + input_height, + model_type, +): + # Input height must be divisible by 8 + input_height = 8 * input_height + + input = torch.rand([1, 1, input_height, input_width]) + + model = get_backbone( + config=ModelConfig( + name=model_type, # type: ignore + input_height=input_height, + ), + ) + + output = model(input) + assert output.shape[0] == 1 + assert output.shape[1] == model.out_channels + assert output.shape[2] == input_height + assert output.shape[3] == input_width diff --git a/tests/test_preprocessing/test_arrays.py b/tests/test_preprocessing/test_arrays.py new file mode 100644 index 0000000..8ea13da --- /dev/null +++ b/tests/test_preprocessing/test_arrays.py @@ -0,0 +1,23 @@ +import numpy as np + +from batdetect2.preprocess.arrays import adjust_width, extend_width + + +def test_extend_width(): + array = np.random.random([1, 1, 128, 100]) + + extended = extend_width(array, 100) + + assert extended.shape == (1, 1, 128, 200) + + +def test_can_adjust_short_width(): + array = np.random.random([1, 1, 128, 100]) + extended = adjust_width(array, 512) + assert extended.shape == (1, 1, 128, 512) + + +def test_can_adjust_long_width(): + array = np.random.random([1, 1, 128, 512]) + extended = adjust_width(array, 256) + assert extended.shape == (1, 1, 128, 256) diff --git a/tests/test_preprocessing/test_spectrogram.py b/tests/test_preprocessing/test_spectrogram.py index 5eb235f..2dc906c 100644 --- a/tests/test_preprocessing/test_spectrogram.py +++ b/tests/test_preprocessing/test_spectrogram.py @@ -8,14 +8,13 @@ from soundevent import arrays from batdetect2.preprocess.audio import AudioConfig, load_file_audio from batdetect2.preprocess.spectrogram import ( - FFTConfig, + STFTConfig, FrequencyConfig, SpecSizeConfig, SpectrogramConfig, compute_spectrogram, duration_to_spec_width, get_spectrogram_resolution, - pad_audio, spec_width_to_samples, stft, ) @@ -68,45 +67,6 @@ def test_can_estimate_correctly_spectrogram_width_from_duration( ) -@settings( - suppress_health_check=[HealthCheck.function_scoped_fixture], - deadline=400, -) -@given( - duration=st.floats(min_value=0.1, max_value=1.0), - window_duration=st.floats(min_value=0.001, max_value=0.01), - window_overlap=st.floats(min_value=0.2, max_value=0.9), - samplerate=st.integers(min_value=256_000, max_value=512_000), - divide_factor=st.integers(min_value=16, max_value=64), -) -def test_can_pad_audio_to_adjust_spectrogram_width( - duration: float, - window_duration: float, - window_overlap: float, - samplerate: int, - divide_factor: int, - wav_factory: Callable[..., Path], -): - path = wav_factory(duration=duration, samplerate=samplerate) - - audio = load_file_audio( - path, - # NOTE: Dont resample nor adjust duration to test if the width - # estimation works on all scenarios - config=AudioConfig(resample=None, 
duration=None), - ) - - audio = pad_audio( - audio, - window_duration=window_duration, - window_overlap=window_overlap, - divide_factor=divide_factor, - ) - - spectrogram = stft(audio, window_duration, window_overlap) - assert spectrogram.sizes["time"] % divide_factor == 0 - - def test_can_estimate_spectrogram_resolution( wav_factory: Callable[..., Path], ): @@ -120,7 +80,7 @@ def test_can_estimate_spectrogram_resolution( ) config = SpectrogramConfig( - fft=FFTConfig(), + stft=STFTConfig(), size=SpecSizeConfig(height=256, resize_factor=0.5), frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000), ) diff --git a/tests/test_preprocessing/test_tensors.py b/tests/test_preprocessing/test_tensors.py new file mode 100644 index 0000000..2115e50 --- /dev/null +++ b/tests/test_preprocessing/test_tensors.py @@ -0,0 +1,42 @@ +import numpy as np +import torch + +from batdetect2.preprocess.tensors import adjust_width, make_width_divisible + + +def test_width_is_divisible_after_adjustment(): + tensor = torch.rand([1, 1, 128, 374]) + adjusted = make_width_divisible(tensor, 32) + assert adjusted.shape[-1] % 32 == 0 + assert adjusted.shape == (1, 1, 128, 384) + + +def test_non_last_axis_is_divisible_after_adjustment(): + tensor = torch.rand([1, 1, 77, 124]) + adjusted = make_width_divisible(tensor, 32, axis=-2) + assert adjusted.shape[-2] % 32 == 0 + assert adjusted.shape == (1, 1, 96, 124) + + +def test_make_width_divisible_can_handle_numpy_array(): + array = np.random.random([1, 1, 128, 374]) + adjusted = make_width_divisible(array, 32) + assert adjusted.shape[-1] % 32 == 0 + assert adjusted.shape == (1, 1, 128, 384) + assert isinstance(adjusted, torch.Tensor) + + +def test_adjust_last_axis_width_by_default(): + tensor = torch.rand([1, 1, 128, 374]) + adjusted = adjust_width(tensor, 512) + assert adjusted.shape == (1, 1, 128, 512) + assert (tensor == adjusted[:, :, :, :374]).all() + assert (adjusted[:, :, :, 374:] == 0).all() + + +def test_can_adjust_second_to_last_axis(): + tensor = torch.rand([1, 1, 89, 512]) + adjusted = adjust_width(tensor, 128, axis=-2) + assert adjusted.shape == (1, 1, 128, 512) + assert (tensor == adjusted[:, :, :89, :]).all() + assert (adjusted[:, :, 89:, :] == 0).all() diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py index cc92621..6430d26 100644 --- a/tests/test_train/test_augmentations.py +++ b/tests/test_train/test_augmentations.py @@ -1,14 +1,14 @@ from collections.abc import Callable import numpy as np +import pytest import xarray as xr -from soundevent import data +from soundevent import arrays, data from batdetect2.train.augmentations import ( add_echo, - adjust_dataset_width, mix_examples, - select_random_subclip, + select_subclip, ) from batdetect2.train.preprocess import ( TrainPreprocessingConfig, @@ -23,11 +23,9 @@ def test_mix_examples( recording2 = recording_factory() clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7) - clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8) clip_annotation_1 = data.ClipAnnotation(clip=clip1) - clip_annotation_2 = data.ClipAnnotation(clip=clip2) config = TrainPreprocessingConfig() @@ -43,6 +41,36 @@ def test_mix_examples( assert mixed["class"].shape == example1["class"].shape +@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7]) +@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7]) +def test_mix_examples_of_different_durations( + recording_factory: Callable[..., data.Recording], + duration1: float, + duration2: float, +): + recording1 = 
recording_factory() + recording2 = recording_factory() + + clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1) + clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2) + + clip_annotation_1 = data.ClipAnnotation(clip=clip1) + clip_annotation_2 = data.ClipAnnotation(clip=clip2) + + config = TrainPreprocessingConfig() + + example1 = generate_train_example(clip_annotation_1, config) + example2 = generate_train_example(clip_annotation_2, config) + + mixed = mix_examples(example1, example2, config=config.preprocessing) + + # Check the spectrogram has the expected duration + step = arrays.get_dim_step(mixed["spectrogram"], "time") + start, stop = arrays.get_dim_range(mixed["spectrogram"], "time") + assert start == 0 + assert np.isclose(stop + step, duration1, atol=2 * step) + + def test_add_echo( recording_factory: Callable[..., data.Recording], ): @@ -67,70 +95,23 @@ def test_selected_random_subclip_has_the_correct_width( clip_annotation_1 = data.ClipAnnotation(clip=clip1) config = TrainPreprocessingConfig() original = generate_train_example(clip_annotation_1, config) - subclip = select_random_subclip(original, width=100) + subclip = select_subclip(original, width=100) assert subclip["spectrogram"].shape[1] == 100 -def test_adjust_dataset_width(): - height = 128 - width = 512 - samplerate = 48_000 +def test_add_echo_after_subclip( + recording_factory: Callable[..., data.Recording], +): + recording1 = recording_factory(duration=2) + clip1 = data.Clip(recording=recording1, start_time=0, end_time=1) + clip_annotation_1 = data.ClipAnnotation(clip=clip1) + config = TrainPreprocessingConfig() + original = generate_train_example(clip_annotation_1, config) - times = np.linspace(0, 1, width) + assert original.sizes["time"] > 512 - audio_times = np.linspace(0, 1, samplerate) - frequency = np.linspace(0, 24_000, height) + subclip = select_subclip(original, width=512) + with_echo = add_echo(subclip) - width_subset = 356 - audio_width_subset = int(samplerate * width_subset / width) - - times_subset = times[:width_subset] - audio_times_subset = audio_times[:audio_width_subset] - dimensions = ["width", "height"] - class_names = [f"species_{i}" for i in range(17)] - - spectrogram = np.random.random([height, width_subset]) - sizes = np.random.random([len(dimensions), height, width_subset]) - classes = np.random.random([len(class_names), height, width_subset]) - audio = np.random.random([int(samplerate * width_subset / width)]) - - dataset = xr.Dataset( - data_vars={ - "audio": (("audio_time",), audio), - "spectrogram": (("frequency", "time"), spectrogram), - "sizes": (("dimension", "frequency", "time"), sizes), - "classes": (("class", "frequency", "time"), classes), - }, - coords={ - "audio_time": audio_times_subset, - "time": times_subset, - "frequency": frequency, - "dimension": dimensions, - "class": class_names, - }, - ) - - adjusted = adjust_dataset_width(dataset, width=width) - - # Spectrogram was adjusted correctly - assert np.isclose(adjusted["spectrogram"].time, times).all() - assert (adjusted["spectrogram"].frequency == frequency).all() - - # Sizes was adjusted correctly - assert np.isclose(adjusted["sizes"].time, times).all() - assert (adjusted["sizes"].frequency == frequency).all() - assert list(adjusted["sizes"].dimension.values) == dimensions - - # Sizes was adjusted correctly - assert np.isclose(adjusted["classes"].time, times).all() - assert (adjusted["sizes"].frequency == frequency).all() - assert list(adjusted["classes"]["class"].values) == class_names 
- - # Audio time was adjusted corretly - assert np.isclose( - len(adjusted["audio"].audio_time), len(audio_times), atol=2 - ) - assert np.isclose( - adjusted["audio"].audio_time[-1], audio_times[-1], atol=1e-3 - ) + assert with_echo.sizes["time"] == 512 diff --git a/tests/test_train/test_labels.py b/tests/test_train/test_labels.py index 787027d..e86aae7 100644 --- a/tests/test_train/test_labels.py +++ b/tests/test_train/test_labels.py @@ -3,7 +3,6 @@ from pathlib import Path import numpy as np import xarray as xr from soundevent import data -from soundevent.types import ClassMapper from batdetect2.train.labels import generate_heatmaps @@ -23,16 +22,6 @@ clip = data.Clip( ) -class Mapper(ClassMapper): - class_labels = ["bat", "cat"] - - def encode(self, sound_event_annotation: data.SoundEventAnnotation) -> str: - return "bat" - - def decode(self, label: str) -> list: - return [data.Tag(term=data.term_from_key("species"), value="bat")] - - def test_generated_heatmaps_have_correct_dimensions(): spec = xr.DataArray( data=np.random.rand(100, 100), @@ -57,12 +46,11 @@ def test_generated_heatmaps_have_correct_dimensions(): ], ) - class_mapper = Mapper() - detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( clip_annotation.sound_events, spec, - class_mapper, + class_names=["bat", "cat"], + encoder=lambda _: "bat", ) assert isinstance(detection_heatmap, xr.DataArray) @@ -107,11 +95,13 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(): ], ) - class_mapper = Mapper() detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( clip_annotation.sound_events, spec, - class_mapper, + class_names=["bat", "cat"], + encoder=lambda _: "bat", + time_scale=1, + frequency_scale=1, ) assert size_heatmap.sel(time=10, frequency=10, dimension="width") == 10 assert size_heatmap.sel(time=10, frequency=10, dimension="height") == 10
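For reference, a minimal usage sketch of the reworked backbone API introduced in this diff (import paths, config fields, and the shape behaviour follow the code and tests above, in particular tests/test_models/test_inputs.py; the snippet itself is illustrative and not part of the patch):

import torch

from batdetect2.models import ModelConfig, ModelType, get_backbone

# Build a backbone from the new config object. All layer widths now come from
# ModelConfig (encoder_channels, bottleneck_channels, decoder_channels,
# out_channels) instead of the old single `num_features` argument.
config = ModelConfig(name=ModelType.Net2DFast, input_height=128)
backbone = get_backbone(config)

# Spectrogram widths no longer have to be multiples of 32: the forward pass
# pads the input up to the encoder's divide factor and crops the output back,
# so the output keeps the input height and width.
spec = torch.rand([1, 1, config.input_height, 371])
features = backbone(spec)
assert features.shape == (1, backbone.out_channels, config.input_height, 371)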