mbsantiago 2025-01-28 19:35:57 +00:00
parent f7d6516550
commit 48e009fa9d
26 changed files with 967 additions and 658 deletions

View File

@ -1,7 +1,7 @@
from batdetect2.preprocess import (
AmplitudeScaleConfig,
AudioConfig,
FFTConfig,
STFTConfig,
FrequencyConfig,
LogScaleConfig,
PcenScaleConfig,
@ -40,7 +40,7 @@ def get_preprocessing_config(params: dict) -> PreprocessingConfig:
duration=None,
),
spectrogram=SpectrogramConfig(
fft=FFTConfig(
stft=STFTConfig(
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
window_fn="hann",

View File

@ -1,10 +1,12 @@
from enum import Enum
from typing import Optional, Tuple
from batdetect2.configs import BaseConfig
from batdetect2.models.backbones import (
Net2DFast,
Net2DFastNoAttn,
Net2DFastNoCoordConv,
Net2DPlain,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.typing import BackboneModel
@ -24,31 +26,57 @@ class ModelType(str, Enum):
Net2DFast = "Net2DFast"
Net2DFastNoAttn = "Net2DFastNoAttn"
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
Net2DPlain = "Net2DPlain"
class ModelConfig(BaseConfig):
name: ModelType = ModelType.Net2DFast
num_features: int = 32
input_height: int = 128
encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
bottleneck_channels: int = 256
decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
out_channels: int = 32
def get_backbone(
config: ModelConfig,
input_height: int = 128,
config: Optional[ModelConfig] = None,
) -> BackboneModel:
config = config or ModelConfig()
if config.name == ModelType.Net2DFast:
return Net2DFast(
input_height=input_height,
num_features=config.num_features,
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
elif config.name == ModelType.Net2DFastNoAttn:
if config.name == ModelType.Net2DFastNoAttn:
return Net2DFastNoAttn(
num_features=config.num_features,
input_height=input_height,
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
elif config.name == ModelType.Net2DFastNoCoordConv:
if config.name == ModelType.Net2DFastNoCoordConv:
return Net2DFastNoCoordConv(
num_features=config.num_features,
input_height=input_height,
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
else:
raise ValueError(f"Unknown model type: {config.name}")
if config.name == ModelType.Net2DPlain:
return Net2DPlain(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
raise ValueError(f"Unknown model type: {config.name}")
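For illustration, a minimal sketch of how the refactored factory is meant to be called after this change; the values shown simply repeat the defaults declared in ModelConfig above and are not additional requirements:

from batdetect2.models import ModelConfig, ModelType, get_backbone

# Hypothetical example: input height and channel layout mirror the defaults above.
config = ModelConfig(
    name=ModelType.Net2DFastNoAttn,
    input_height=128,
    encoder_channels=(1, 32, 64, 128),
    bottleneck_channels=256,
    decoder_channels=(256, 64, 32, 32),
    out_channels=32,
)
backbone = get_backbone(config)  # input_height now comes from the config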

View File

@ -1,16 +1,16 @@
from typing import Tuple
from typing import Sequence, Tuple
import torch
import torch.fft
import torch.nn.functional as F
from torch import nn
from batdetect2.models.blocks import (
ConvBlockDownCoordF,
ConvBlockDownStandard,
ConvBlockUpF,
ConvBlockUpStandard,
ConvBlock,
Decoder,
DownscalingLayer,
Encoder,
SelfAttention,
UpscalingLayer,
VerticalConv,
)
from batdetect2.models.typing import BackboneModel
@ -21,303 +21,164 @@ __all__ = [
]
class Net2DFast(BackboneModel):
class Net2DPlain(BackboneModel):
downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard"
upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard"
def __init__(
self,
num_features: int,
input_height: int = 128,
encoder_channels: Sequence[int] = (1, 32, 64, 128),
bottleneck_channels: int = 256,
decoder_channels: Sequence[int] = (256, 64, 32, 32),
out_channels: int = 32,
):
super().__init__()
self.num_features = num_features
self.input_height = input_height
self.bottleneck_height = self.input_height // 32
# encoder
self.conv_dn_0 = ConvBlockDownCoordF(
1,
self.num_features // 4,
self.input_height,
kernel_size=3,
pad_size=1,
stride=1,
self.input_height = input_height
self.encoder_channels = tuple(encoder_channels)
self.decoder_channels = tuple(decoder_channels)
self.out_channels = out_channels
if len(encoder_channels) != len(decoder_channels):
raise ValueError(
f"Mismatched encoder and decoder channel lists. "
f"The encoder has {len(encoder_channels)} channels "
f"(implying {len(encoder_channels) - 1} layers), "
f"while the decoder has {len(decoder_channels)} channels "
f"(implying {len(decoder_channels) - 1} layers). "
f"These lengths must be equal."
)
self.divide_factor = 2 ** (len(encoder_channels) - 1)
if self.input_height % self.divide_factor != 0:
raise ValueError(
f"Input height ({self.input_height}) must be divisible by "
f"the divide factor ({self.divide_factor}). "
f"This ensures proper upscaling after downscaling to recover "
f"the original input height."
)
self.encoder = Encoder(
channels=encoder_channels,
input_height=self.input_height,
layer_type=self.downscaling_layer_type,
)
self.conv_dn_1 = ConvBlockDownCoordF(
self.num_features // 4,
self.num_features // 2,
self.input_height // 2,
kernel_size=3,
pad_size=1,
stride=1,
self.conv_same_1 = ConvBlock(
in_channels=encoder_channels[-1],
out_channels=bottleneck_channels,
)
self.conv_dn_2 = ConvBlockDownCoordF(
self.num_features // 2,
self.num_features,
self.input_height // 4,
kernel_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(
self.num_features,
self.num_features * 2,
3,
padding=1,
)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
# bottleneck
self.conv_1d = nn.Conv2d(
self.num_features * 2,
self.num_features * 2,
(self.input_height // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
self.att = SelfAttention(self.num_features * 2, self.num_features * 2)
# decoder
self.conv_up_2 = ConvBlockUpF(
self.num_features * 2,
self.num_features // 2,
self.input_height // 8,
)
self.conv_up_3 = ConvBlockUpF(
self.num_features // 2,
self.num_features // 4,
self.input_height // 4,
)
self.conv_up_4 = ConvBlockUpF(
self.num_features // 4,
self.num_features // 4,
self.input_height // 2,
self.conv_vert = VerticalConv(
in_channels=bottleneck_channels,
out_channels=bottleneck_channels,
input_height=self.input_height // (2**self.encoder.depth),
)
self.conv_op = nn.Conv2d(
self.num_features // 4,
self.num_features // 4,
kernel_size=3,
padding=1,
self.decoder = Decoder(
channels=decoder_channels,
input_height=self.input_height,
layer_type=self.upscaling_layer_type,
)
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
self.out_channels = self.num_features // 4
def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
h, w = spec.shape[2:]
h_pad = (32 - h % 32) % 32
w_pad = (32 - w % 32) % 32
return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
self.conv_same_2 = ConvBlock(
in_channels=decoder_channels[-1],
out_channels=out_channels,
)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
# encoder
spec, h_pad, w_pad = self.pad_adjust(spec)
spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
# encoder
residuals = self.encoder(spec)
residuals[-1] = self.conv_same_1(residuals[-1])
# bottleneck
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = self.att(x)
x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
x = self.conv_vert(residuals[-1])
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
# decoder
x = self.conv_up_2(x + x3)
x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1)
x = self.decoder(x, residuals=residuals)
# Restore original size
if h_pad > 0:
x = x[:, :, :-h_pad, :]
x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
if w_pad > 0:
x = x[:, :, :, :-w_pad]
return F.relu_(self.conv_op_bn(self.conv_op(x)))
return self.conv_same_2(x)
class Net2DFastNoAttn(BackboneModel):
class Net2DFast(Net2DPlain):
downscaling_layer_type = "ConvBlockDownCoordF"
upscaling_layer_type = "ConvBlockUpF"
def __init__(
self,
num_features: int,
input_height: int = 128,
encoder_channels: Sequence[int] = (1, 32, 64, 128),
bottleneck_channels: int = 256,
decoder_channels: Sequence[int] = (256, 64, 32, 32),
out_channels: int = 32,
):
super().__init__()
self.num_features = num_features
self.input_height = input_height
self.bottleneck_height = self.input_height // 32
self.conv_dn_0 = ConvBlockDownCoordF(
1,
self.num_features // 4,
self.input_height,
kernel_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownCoordF(
self.num_features // 4,
self.num_features // 2,
self.input_height // 2,
kernel_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownCoordF(
self.num_features // 2,
self.num_features,
self.input_height // 4,
kernel_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(
self.num_features,
self.num_features * 2,
3,
padding=1,
)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
self.conv_1d = nn.Conv2d(
self.num_features * 2,
self.num_features * 2,
(self.input_height // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
self.conv_up_2 = ConvBlockUpF(
self.num_features * 2,
self.num_features // 2,
self.input_height // 8,
)
self.conv_up_3 = ConvBlockUpF(
self.num_features // 2,
self.num_features // 4,
self.input_height // 4,
)
self.conv_up_4 = ConvBlockUpF(
self.num_features // 4,
self.num_features // 4,
self.input_height // 2,
super().__init__(
input_height=input_height,
encoder_channels=encoder_channels,
bottleneck_channels=bottleneck_channels,
decoder_channels=decoder_channels,
out_channels=out_channels,
)
self.conv_op = nn.Conv2d(
self.num_features // 4,
self.num_features // 4,
kernel_size=3,
padding=1,
)
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
self.out_channels = self.num_features // 4
self.att = SelfAttention(bottleneck_channels, bottleneck_channels)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
# encoder
residuals = self.encoder(spec)
residuals[-1] = self.conv_same_1(residuals[-1])
x = self.conv_up_2(x + x3)
x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1)
return F.relu_(self.conv_op_bn(self.conv_op(x)))
class Net2DFastNoCoordConv(BackboneModel):
def __init__(
self,
num_features: int,
input_height: int = 128,
):
super().__init__()
self.num_features = num_features
self.input_height = input_height
self.bottleneck_height = self.input_height // 32
self.conv_dn_0 = ConvBlockDownStandard(
1,
self.num_features // 4,
kernel_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownStandard(
self.num_features // 4,
self.num_features // 2,
kernel_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownStandard(
self.num_features // 2,
self.num_features,
kernel_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(
self.num_features,
self.num_features * 2,
3,
padding=1,
)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
self.conv_1d = nn.Conv2d(
self.num_features * 2,
self.num_features * 2,
(self.input_height // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
self.att = SelfAttention(self.num_features * 2, self.num_features * 2)
self.conv_up_2 = ConvBlockUpStandard(
self.num_features * 2,
self.num_features // 2,
self.input_height // 8,
)
self.conv_up_3 = ConvBlockUpStandard(
self.num_features // 2,
self.num_features // 4,
self.input_height // 4,
)
self.conv_up_4 = ConvBlockUpStandard(
self.num_features // 4,
self.num_features // 4,
self.input_height // 2,
)
self.conv_op = nn.Conv2d(
self.num_features // 4,
self.num_features // 4,
kernel_size=3,
padding=1,
)
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
self.out_channels = self.num_features // 4
def forward(self, spec: torch.Tensor) -> torch.Tensor:
x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
# bottleneck
x = self.conv_vert(residuals[-1])
x = self.att(x)
x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
x = self.conv_up_2(x + x3)
x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1)
# decoder
x = self.decoder(x, residuals=residuals)
return F.relu_(self.conv_op_bn(self.conv_op(x)))
# Restore original size
x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
return self.conv_same_2(x)
class Net2DFastNoAttn(Net2DPlain):
downscaling_layer_type = "ConvBlockDownCoordF"
upscaling_layer_type = "ConvBlockUpF"
class Net2DFastNoCoordConv(Net2DFast):
downscaling_layer_type = "ConvBlockDownStandard"
upscaling_layer_type = "ConvBlockUpStandard"
def pad_adjust(
spec: torch.Tensor,
factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
h, w = spec.shape[2:]
h_pad = -h % factor
w_pad = -w % factor
return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
def restore_pad(
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor:
# Restore original size
if h_pad > 0:
x = x[:, :, :-h_pad, :]
if w_pad > 0:
x = x[:, :, :, :-w_pad]
return x
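A quick sketch (shapes are illustrative) of how pad_adjust and restore_pad cooperate; -h % factor is the smallest non-negative padding that makes a dimension divisible by factor:

import torch

spec = torch.rand(1, 1, 128, 100)            # width 100 is not a multiple of 32
padded, h_pad, w_pad = pad_adjust(spec, factor=32)
# -128 % 32 == 0 and -100 % 32 == 28, so only the width gets padded.
assert (h_pad, w_pad) == (0, 28)
assert padded.shape == (1, 1, 128, 128)
# restore_pad trims the same amount back off after the forward pass.
assert restore_pad(padded, h_pad=h_pad, w_pad=w_pad).shape == spec.shape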

View File

@ -4,29 +4,35 @@ All these classes are subclasses of `torch.nn.Module` and can be used to build
complex neural network architectures.
"""
from typing import Tuple
import sys
from typing import Iterable, List, Literal, Sequence, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from batdetect2.configs import BaseConfig
if sys.version_info >= (3, 10):
from itertools import pairwise
else:
def pairwise(iterable: Sequence) -> Iterable:
for x, y in zip(iterable[:-1], iterable[1:]):
yield x, y
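The backport behaves like itertools.pairwise on Python 3.10+; the Encoder and Decoder below use it to turn a single channel list into per-layer (in_channels, out_channels) pairs, for example:

list(pairwise((1, 32, 64, 128)))  # -> [(1, 32), (32, 64), (64, 128)]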
__all__ = [
"SelfAttention",
"ConvBlock",
"ConvBlockDownCoordF",
"ConvBlockDownStandard",
"ConvBlockUpF",
"ConvBlockUpStandard",
"SelfAttention",
"VerticalConv",
"DownscalingLayer",
"UpscalingLayer",
]
class SelfAttentionConfig(BaseConfig):
temperature: float = 1.0
input_channels: int = 128
attention_channels: int = 128
class SelfAttention(nn.Module):
"""Self-Attention module.
@ -76,13 +82,64 @@ class SelfAttention(nn.Module):
return op
class ConvBlockDownCoordFConfig(BaseConfig):
in_channels: int
out_channels: int
input_height: int
kernel_size: int = 3
pad_size: int = 1
stride: int = 1
class ConvBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
pad_size: int = 1,
):
super().__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=pad_size,
)
self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.relu_(self.conv_bn(self.conv(x)))
class VerticalConv(nn.Module):
"""Convolutional layer over full height.
This layer applies a convolution that captures information across the
entire height of the input image. It uses a kernel with the same height as
the input, effectively condensing the vertical information into a single
output row.
More specifically:
* **Input:** (B, C, H, W) where B is the batch size, C is the number of
input channels, H is the image height, and W is the image width.
* **Kernel:** (C', H, 1) where C' is the number of output channels.
* **Output:** (B, C', 1, W) - The height dimension is 1 because the
convolution integrates information from all rows of the input.
This process effectively extracts features that span the full height of
the input image.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
input_height: int,
):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(input_height, 1),
padding=0,
)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.relu_(self.bn(self.conv(x)))
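A small illustration of the shape contract described in the docstring (channel counts and sizes here are arbitrary):

import torch

layer = VerticalConv(in_channels=64, out_channels=64, input_height=16)
x = torch.rand(2, 64, 16, 100)               # (B, C, H, W)
y = layer(x)
assert y.shape == (2, 64, 1, 100)            # height collapsed to a single row
# The backbones broadcast this row back over the height axis before decoding,
# e.g. y.repeat([1, 1, 16, 1]).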
class ConvBlockDownCoordF(nn.Module):
@ -124,14 +181,6 @@ class ConvBlockDownCoordF(nn.Module):
return x
class ConvBlockDownStandardConfig(BaseConfig):
in_channels: int
out_channels: int
kernel_size: int = 3
pad_size: int = 1
stride: int = 1
class ConvBlockDownStandard(nn.Module):
"""Convolutional Block with Downsampling.
@ -158,18 +207,10 @@ class ConvBlockDownStandard(nn.Module):
def forward(self, x):
x = F.max_pool2d(self.conv(x), 2, 2)
x = F.relu(self.conv_bn(x), inplace=True)
return x
return F.relu(self.conv_bn(x), inplace=True)
class ConvBlockUpFConfig(BaseConfig):
inp_channels: int
out_channels: int
input_height: int
kernel_size: int = 3
pad_size: int = 1
up_mode: str = "bilinear"
up_scale: Tuple[int, int] = (2, 2)
DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"]
class ConvBlockUpF(nn.Module):
@ -224,15 +265,6 @@ class ConvBlockUpF(nn.Module):
return op
class ConvBlockUpStandardConfig(BaseConfig):
in_channels: int
out_channels: int
kernel_size: int = 3
pad_size: int = 1
up_mode: str = "bilinear"
up_scale: Tuple[int, int] = (2, 2)
class ConvBlockUpStandard(nn.Module):
"""Convolutional Block with Upsampling.
@ -272,3 +304,143 @@ class ConvBlockUpStandard(nn.Module):
op = self.conv(op)
op = F.relu(self.conv_bn(op), inplace=True)
return op
UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"]
def build_downscaling_layer(
in_channels: int,
out_channels: int,
input_height: int,
layer_type: DownscalingLayer,
) -> nn.Module:
if layer_type == "ConvBlockDownStandard":
return ConvBlockDownStandard(
in_channels=in_channels,
out_channels=out_channels,
)
if layer_type == "ConvBlockDownCoordF":
return ConvBlockDownCoordF(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height,
)
raise ValueError(
f"Invalid downscaling layer type {layer_type}. "
f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
)
class Encoder(nn.Module):
def __init__(
self,
channels: Sequence[int] = (1, 32, 62, 128),
input_height: int = 128,
layer_type: Literal[
"ConvBlockDownStandard", "ConvBlockDownCoordF"
] = "ConvBlockDownStandard",
):
super().__init__()
self.channels = channels
self.input_height = input_height
self.layers = nn.ModuleList(
[
build_downscaling_layer(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height // (2**layer_num),
layer_type=layer_type,
)
for layer_num, (in_channels, out_channels) in enumerate(
pairwise(channels)
)
]
)
self.depth = len(self.layers)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
outputs = []
for layer in self.layers:
x = layer(x)
outputs.append(x)
return outputs
def build_upscaling_layer(
in_channels: int,
out_channels: int,
input_height: int,
layer_type: UpscalingLayer,
) -> nn.Module:
if layer_type == "ConvBlockUpStandard":
return ConvBlockUpStandard(
in_channels=in_channels,
out_channels=out_channels,
)
if layer_type == "ConvBlockUpF":
return ConvBlockUpF(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height,
)
raise ValueError(
f"Invalid upscaling layer type {layer_type}. "
f"Valid values: ConvBlockUpStandard, ConvBlockUpF"
)
class Decoder(nn.Module):
def __init__(
self,
channels: Sequence[int] = (256, 62, 32, 32),
input_height: int = 128,
layer_type: Literal[
"ConvBlockUpStandard", "ConvBlockUpF"
] = "ConvBlockUpStandard",
):
super().__init__()
self.channels = channels
self.input_height = input_height
self.depth = len(self.channels) - 1
self.layers = nn.ModuleList(
[
build_upscaling_layer(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height
// (2 ** (self.depth - layer_num)),
layer_type=layer_type,
)
for layer_num, (in_channels, out_channels) in enumerate(
pairwise(channels)
)
]
)
def forward(
self,
x: torch.Tensor,
residuals: List[torch.Tensor],
) -> torch.Tensor:
if len(residuals) != len(self.layers):
raise ValueError(
f"Incorrect number of residuals provided. "
f"Expected {len(self.layers)} (matching the number of layers), "
f"but got {len(residuals)}."
)
for layer, res in zip(self.layers, residuals[::-1]):
x = layer(x + res)
return x

View File

@ -0,0 +1,17 @@
import sys
from typing import Iterable, List, Literal, Sequence
import torch
from torch import nn
from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard
if sys.version_info >= (3, 10):
from itertools import pairwise
else:
def pairwise(iterable: Sequence) -> Iterable:
for x, y in zip(iterable[:-1], iterable[1:]):
yield x, y

View File

@ -0,0 +1,17 @@
import sys
from typing import Iterable, List, Literal, Sequence
import torch
from torch import nn
from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard
if sys.version_info >= (3, 10):
from itertools import pairwise
else:
def pairwise(iterable: Sequence) -> Iterable:
for x, y in zip(iterable[:-1], iterable[1:]):
yield x, y

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import NamedTuple
from typing import NamedTuple, Tuple
import torch
import torch.nn as nn
@ -45,15 +45,27 @@ class BackboneModel(ABC, nn.Module):
input_height: int
"""Height of the input spectrogram."""
num_features: int
"""Dimension of the feature tensor."""
encoder_channels: Tuple[int, ...]
"""Tuple specifying the number of channels for each convolutional layer
in the encoder. The length of this tuple determines the number of
encoder layers."""
decoder_channels: Tuple[int, ...]
"""Tuple specifying the number of channels for each convolutional layer
in the decoder. The length of this tuple determines the number of
decoder layers."""
bottleneck_channels: int
"""Number of channels in the bottleneck layer, which connects the
encoder and decoder."""
out_channels: int
"""Number of output channels of the feature extractor."""
"""Number of channels in the final output feature map produced by the
backbone model."""
@abstractmethod
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Forward pass of the encoder model."""
"""Forward pass of the model."""
class DetectionModel(ABC, nn.Module):

View File

@ -13,7 +13,7 @@ from batdetect2.preprocess.audio import (
)
from batdetect2.preprocess.spectrogram import (
AmplitudeScaleConfig,
FFTConfig,
STFTConfig,
FrequencyConfig,
LogScaleConfig,
PcenScaleConfig,
@ -27,7 +27,7 @@ __all__ = [
"AudioConfig",
"ResampleConfig",
"SpectrogramConfig",
"FFTConfig",
"STFTConfig",
"FrequencyConfig",
"PcenScaleConfig",
"LogScaleConfig",

View File

@ -0,0 +1,61 @@
import numpy as np
def extend_width(
array: np.ndarray,
extra: int,
axis: int = -1,
value: float = 0,
) -> np.ndarray:
dims = len(array.shape)
axis = axis % dims
pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)]
return np.pad(
array,
pad,
mode="constant",
constant_values=value,
)
def make_width_divisible(
array: np.ndarray,
factor: int,
axis: int = -1,
value: float = 0,
) -> np.ndarray:
width = array.shape[axis]
if width % factor == 0:
return array
extra = (-width) % factor
return extend_width(array, extra, axis=axis, value=value)
def adjust_width(
array: np.ndarray,
width: int,
axis: int = -1,
value: float = 0,
) -> np.ndarray:
dims = len(array.shape)
axis = axis % dims
current_width = array.shape[axis]
if current_width == width:
return array
if current_width < width:
return extend_width(
array,
extra=width - current_width,
axis=axis,
value=value,
)
slices = [
slice(None, None) if index != axis else slice(None, width)
for index in range(dims)
]
return array[tuple(slices)]
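A short usage sketch of these helpers (array sizes are arbitrary; the dedicated tests further below exercise the same behaviour):

import numpy as np

spec = np.random.random((128, 100))
assert make_width_divisible(spec, factor=32).shape == (128, 128)  # pad to the next multiple of 32
assert adjust_width(spec, 512).shape == (128, 512)                # zero-pad up to a target width
assert adjust_width(spec, 64).shape == (128, 64)                  # or crop down to it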

View File

@ -12,7 +12,7 @@ from soundevent.arrays import operations as ops
from batdetect2.configs import BaseConfig
class FFTConfig(BaseConfig):
class STFTConfig(BaseConfig):
window_duration: float = Field(default=0.002, gt=0)
window_overlap: float = Field(default=0.75, ge=0, lt=1)
window_fn: str = "hann"
@ -24,9 +24,15 @@ class FrequencyConfig(BaseConfig):
class SpecSizeConfig(BaseConfig):
height: int = 256
height: int = 128
"""Height of the spectrogram in pixels. This value determines the
number of frequency bands and corresponds to the vertical dimension
of the spectrogram."""
resize_factor: Optional[float] = 0.5
divide_factor: Optional[int] = 32
"""Factor by which to resize the spectrogram along the time axis.
A value of 0.5 reduces the temporal dimension by half, while a
value of 2.0 doubles it. If None, no resizing is performed."""
class LogScaleConfig(BaseConfig):
@ -50,13 +56,13 @@ Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]
class SpectrogramConfig(BaseConfig):
fft: FFTConfig = Field(default_factory=FFTConfig)
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
scale: Scales = Field(
default_factory=PcenScaleConfig,
discriminator="name",
)
size: SpecSizeConfig = Field(default_factory=SpecSizeConfig)
size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
denoise: bool = True
max_scale: bool = False
@ -68,22 +74,11 @@ def compute_spectrogram(
) -> xr.DataArray:
config = config or SpectrogramConfig()
if config.size.divide_factor:
# Need to pad the audio to make sure the spectrogram has a
# width compatible with the divide factor
resize_factor = config.size.resize_factor or 1
wav = pad_audio(
wav,
window_duration=config.fft.window_duration,
window_overlap=config.fft.window_overlap,
divide_factor=int(config.size.divide_factor / resize_factor),
)
spec = stft(
wav,
window_duration=config.fft.window_duration,
window_overlap=config.fft.window_overlap,
window_fn=config.fft.window_fn,
window_duration=config.stft.window_duration,
window_overlap=config.stft.window_overlap,
window_fn=config.stft.window_fn,
dtype=dtype,
)
@ -98,11 +93,12 @@ def compute_spectrogram(
if config.denoise:
spec = denoise_spectrogram(spec)
spec = resize_spectrogram(
spec,
height=config.size.height,
resize_factor=config.size.resize_factor,
)
if config.size:
spec = resize_spectrogram(
spec,
height=config.size.height,
resize_factor=config.size.resize_factor,
)
if config.max_scale:
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
@ -257,7 +253,7 @@ def resize_spectrogram(
return ops.resize(
spec,
time=int(resize_factor * current_width),
frequency=int(resize_factor * height),
frequency=height,
dtype=np.float32,
)
@ -285,43 +281,6 @@ def adjust_spectrogram_width(
return resized
def pad_audio(
wave: xr.DataArray,
window_duration: float,
window_overlap: float,
divide_factor: int = 32,
) -> xr.DataArray:
current_duration = arrays.get_dim_width(wave, dim="time")
step = arrays.get_dim_step(wave, dim="time")
samplerate = int(1 / step)
estimated_spec_width = duration_to_spec_width(
current_duration,
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
)
if estimated_spec_width % divide_factor == 0:
return wave
target_spec_width = int(
np.ceil(estimated_spec_width / divide_factor) * divide_factor
)
target_samples = spec_width_to_samples(
target_spec_width,
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
)
return ops.adjust_dim_width(
wave,
dim="time",
width=target_samples,
position="start",
)
def duration_to_spec_width(
duration: float,
samplerate: int,
@ -357,6 +316,8 @@ def get_spectrogram_resolution(
spec_height = config.size.height
resize_factor = config.size.resize_factor or 1
freq_bin_width = (max_freq - min_freq) / (spec_height * resize_factor)
hop_duration = config.fft.window_duration * (1 - config.fft.window_overlap)
freq_bin_width = (max_freq - min_freq) / spec_height
hop_duration = config.stft.window_duration * (
1 - config.stft.window_overlap
)
return freq_bin_width, hop_duration / resize_factor
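A worked example of the updated resolution formula, using the STFT defaults above and the 10-120 kHz range used in the test suite (illustrative values only):

min_freq, max_freq = 10_000, 120_000
height, resize_factor = 128, 0.5
window_duration, window_overlap = 0.002, 0.75

freq_bin_width = (max_freq - min_freq) / height           # 859.375 Hz per bin
hop_duration = window_duration * (1 - window_overlap)     # 0.0005 s
time_bin_width = hop_duration / resize_factor             # 0.001 s per column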

View File

@ -0,0 +1,76 @@
from typing import Union
import numpy as np
import torch
from torch.nn import functional as F
def extend_width(
array: Union[np.ndarray, torch.Tensor],
extra: int,
axis: int = -1,
value: float = 0,
) -> torch.Tensor:
if not isinstance(array, torch.Tensor):
array = torch.Tensor(array)
dims = len(array.shape)
axis = axis % dims
pad = [
[0, 0] if index != axis else [0, extra]
for index in range(axis, dims)[::-1]
]
return F.pad(
array,
[x for y in pad for x in y],
value=value,
)
def make_width_divisible(
array: Union[np.ndarray, torch.Tensor],
factor: int,
axis: int = -1,
value: float = 0,
) -> torch.Tensor:
if not isinstance(array, torch.Tensor):
array = torch.Tensor(array)
width = array.shape[axis]
if width % factor == 0:
return array
extra = (-width) % factor
return extend_width(array, extra, axis=axis, value=value)
def adjust_width(
array: Union[np.ndarray, torch.Tensor],
width: int,
axis: int = -1,
value: float = 0,
) -> torch.Tensor:
if not isinstance(array, torch.Tensor):
array = torch.Tensor(array)
dims = len(array.shape)
axis = axis % dims
current_width = array.shape[axis]
if current_width == width:
return array
if current_width < width:
return extend_width(
array,
extra=width - current_width,
axis=axis,
value=value,
)
slices = [
slice(None, None) if index != axis else slice(None, width)
for index in range(dims)
]
return array[tuple(slices)]

View File

@ -4,10 +4,10 @@ import numpy as np
import xarray as xr
from pydantic import Field
from soundevent import arrays
from soundevent.arrays import operations as ops
from batdetect2.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
from batdetect2.preprocess.arrays import adjust_width
Augmentation = Callable[[xr.Dataset], xr.Dataset]
@ -17,47 +17,12 @@ class AugmentationConfig(BaseConfig):
probability: float = 0.2
class SubclipConfig(BaseConfig):
enable: bool = True
duration: Optional[float] = None
width: Optional[int] = 512
def adjust_dataset_width(
example: xr.Dataset,
duration: Optional[float] = None,
width: Optional[int] = None,
) -> xr.Dataset:
step = arrays.get_dim_step(example, "time") # type: ignore
if width is None:
if duration is None:
raise ValueError("Either duration or width must be provided")
width = int(np.floor(duration / step))
adjusted_arrays = {
name: ops.adjust_dim_width(array, "time", width)
for name, array in example.items()
if name != "audio"
}
ratio = width / example.spectrogram.sizes["time"]
audio_width = int(example.audio.sizes["audio_time"] * ratio)
adjusted_arrays["audio"] = ops.adjust_dim_width(
example["audio"],
"audio_time",
audio_width,
)
return xr.Dataset(data_vars=adjusted_arrays)
def select_random_subclip(
def select_subclip(
example: xr.Dataset,
start_time: Optional[float] = None,
duration: Optional[float] = None,
width: Optional[int] = None,
random: bool = False,
) -> xr.Dataset:
"""Select a random subclip from a clip."""
step = arrays.get_dim_step(example, "time") # type: ignore
@ -73,10 +38,13 @@ def select_random_subclip(
duration = width * step
if start_time is None:
start_time = np.random.uniform(start, max(stop - duration, start))
if random:
start_time = np.random.uniform(start, max(stop - duration, start))
else:
start_time = start
if start_time + duration > stop:
example = adjust_dataset_width(example, width=width)
return example
start_index = arrays.get_coord_index(
example, # type: ignore
@ -91,7 +59,7 @@ def select_random_subclip(
return example.sel(
time=slice(start_time, end_time),
audio_time=slice(start_time, end_time),
audio_time=slice(start_time, end_time + step),
)
@ -115,40 +83,45 @@ def mix_examples(
weight = np.random.uniform(min_weight, max_weight)
audio1 = example["audio"]
audio2 = ops.adjust_dim_width(
other["audio"], "audio_time", len(audio1)
).values
if len(audio2) > len(audio1):
audio2 = audio2[: len(audio1)]
audio2 = adjust_width(other["audio"].values, len(audio1))
combined = weight * audio1 + (1 - weight) * audio2
spec = compute_spectrogram(
spectrogram = compute_spectrogram(
combined.rename({"audio_time": "time"}),
config=config.spectrogram,
)
).data
# NOTE: The subclip's spectrogram might be slightly longer than the
# spectrogram computed from the subclip's audio. This is due to a
# simplification in the subclip process: It doesn't account for the
# spectrogram parameters to precisely determine the corresponding audio
# samples. To work around this, we pad the computed spectrogram with zeros
# as needed.
previous_width = len(example["time"])
spectrogram = adjust_width(spectrogram, previous_width)
detection_heatmap = xr.apply_ufunc(
np.maximum,
example["detection"],
other["detection"].values,
adjust_width(other["detection"].values, previous_width),
)
class_heatmap = xr.apply_ufunc(
np.maximum,
example["class"],
other["class"].values,
adjust_width(other["class"].values, previous_width),
)
size_heatmap = example["size"] + other["size"].values
size_heatmap = example["size"] + adjust_width(
other["size"].values, previous_width
)
return xr.Dataset(
{
"audio": combined,
"spectrogram": xr.DataArray(
data=spec.data,
data=spectrogram,
dims=example["spectrogram"].dims,
coords=example["spectrogram"].coords,
),
@ -192,12 +165,23 @@ def add_echo(
spectrogram = compute_spectrogram(
audio.rename({"audio_time": "time"}),
config=config.spectrogram,
).data
# NOTE: The subclip's spectrogram might be slightly longer than the
# spectrogram computed from the subclip's audio. This is due to a
# simplification in the subclip process: It doesn't account for the
# spectrogram parameters to precisely determine the corresponding audio
# samples. To work around this, we pad the computed spectrogram with zeros
# as needed.
spectrogram = adjust_width(
spectrogram,
example["spectrogram"].sizes["time"],
)
return example.assign(
audio=audio,
spectrogram=xr.DataArray(
data=spectrogram.data,
data=spectrogram,
dims=example["spectrogram"].dims,
coords=example["spectrogram"].coords,
),
@ -359,7 +343,6 @@ def mask_frequency(
class AugmentationsConfig(BaseConfig):
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig)
echo: EchoAugmentationConfig = Field(
default_factory=EchoAugmentationConfig
@ -391,23 +374,8 @@ def augment_example(
preprocessing_config: Optional[PreprocessingConfig] = None,
others: Optional[Callable[[], xr.Dataset]] = None,
) -> xr.Dataset:
if config.subclip.enable:
example = select_random_subclip(
example,
duration=config.subclip.duration,
width=config.subclip.width,
)
if should_apply(config.mix) and (others is not None):
other = others()
if config.subclip.enable:
other = select_random_subclip(
other,
duration=config.subclip.duration,
width=config.subclip.width,
)
example = mix_examples(
example,
other,

View File

@ -5,10 +5,17 @@ from typing import NamedTuple, Optional, Sequence, Union
import numpy as np
import torch
import xarray as xr
from pydantic import Field
from soundevent import data
from torch.utils.data import Dataset
from batdetect2.train.augmentations import AugmentationsConfig, augment_example
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.tensors import adjust_width
from batdetect2.train.augmentations import (
AugmentationsConfig,
augment_example,
select_subclip,
)
from batdetect2.train.preprocess import PreprocessingConfig
__all__ = [
@ -28,20 +35,36 @@ class TrainExample(NamedTuple):
idx: torch.Tensor
class SubclipConfig(BaseConfig):
duration: Optional[float] = None
width: int = 512
random: bool = False
class DatasetConfig(BaseConfig):
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
augmentation: AugmentationsConfig = Field(
default_factory=AugmentationsConfig
)
class LabeledDataset(Dataset):
config: DatasetConfig
def __init__(
self,
filenames: Sequence[PathLike],
augment: bool = False,
preprocessing_config: Optional[PreprocessingConfig] = None,
augmentation_config: Optional[AugmentationsConfig] = None,
subclip: bool = False,
config: Optional[DatasetConfig] = None,
):
self.filenames = filenames
self.augment = augment
self.preprocessing_config = (
preprocessing_config or PreprocessingConfig()
)
self.agumentation_config = augmentation_config or AugmentationsConfig()
self.subclip = subclip
self.config = config or DatasetConfig()
def __len__(self):
return len(self.filenames)
@ -49,27 +72,27 @@ class LabeledDataset(Dataset):
def __getitem__(self, idx) -> TrainExample:
dataset = self.get_dataset(idx)
if self.subclip:
dataset = select_subclip(
dataset,
duration=self.config.subclip.duration,
width=self.config.subclip.width,
random=self.config.subclip.random,
)
if self.augment:
dataset = augment_example(
dataset,
self.agumentation_config,
preprocessing_config=self.preprocessing_config,
self.config.augmentation,
preprocessing_config=self.config.preprocessing,
others=self.get_random_example,
)
return TrainExample(
spec=torch.tensor(
dataset["spectrogram"].values.astype(np.float32)
).unsqueeze(0),
detection_heatmap=torch.tensor(
dataset["detection"].values.astype(np.float32)
),
class_heatmap=torch.tensor(
dataset["class"].values.astype(np.float32)
),
size_heatmap=torch.tensor(
dataset["size"].values.astype(np.float32)
),
spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
detection_heatmap=self.to_tensor(dataset["detection"]),
class_heatmap=self.to_tensor(dataset["class"]),
size_heatmap=self.to_tensor(dataset["size"]),
idx=torch.tensor(idx),
)
@ -78,20 +101,30 @@ class LabeledDataset(Dataset):
cls,
directory: PathLike,
extension: str = ".nc",
config: Optional[DatasetConfig] = None,
augment: bool = False,
preprocessing_config: Optional[PreprocessingConfig] = None,
augmentation_config: Optional[AugmentationsConfig] = None,
subclip: bool = False,
):
return cls(
get_files(directory, extension),
config=config,
augment=augment,
preprocessing_config=preprocessing_config,
augmentation_config=augmentation_config,
subclip=subclip,
)
def get_random_example(self) -> xr.Dataset:
idx = np.random.randint(0, len(self))
return self.get_dataset(idx)
dataset = self.get_dataset(idx)
if self.subclip:
dataset = select_subclip(
dataset,
duration=self.config.subclip.duration,
width=self.config.subclip.width,
random=self.config.subclip.random,
)
return dataset
def get_dataset(self, idx) -> xr.Dataset:
return xr.open_dataset(self.filenames[idx])
@ -101,6 +134,19 @@ class LabeledDataset(Dataset):
self.get_dataset(idx).attrs["clip_annotation"]
)
def to_tensor(
self,
array: xr.DataArray,
dtype=np.float32,
) -> torch.Tensor:
tensor = torch.tensor(array.values.astype(dtype))
if not self.subclip:
return tensor
width = self.config.subclip.width
return adjust_width(tensor, width)
def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
return list(Path(directory).glob(f"*{extension}"))

View File

@ -1,16 +1,56 @@
from typing import List
import numpy as np
from sklearn.metrics import (
accuracy_score,
auc,
balanced_accuracy_score,
roc_curve,
)
from sklearn.metrics import auc, roc_curve
from soundevent import data
from soundevent.evaluation import match_geometries
def match_predictions_and_annotations(
clip_annotation: data.ClipAnnotation,
clip_prediction: data.ClipPrediction,
) -> List[data.Match]:
annotated_sound_events = [
sound_event_annotation
for sound_event_annotation in clip_annotation.sound_events
if sound_event_annotation.sound_event.geometry is not None
]
predicted_sound_events = [
sound_event_prediction
for sound_event_prediction in clip_prediction.sound_events
if sound_event_prediction.sound_event.geometry is not None
]
annotated_geometries: List[data.Geometry] = [
sound_event.sound_event.geometry
for sound_event in annotated_sound_events
if sound_event.sound_event.geometry is not None
]
predicted_geometries: List[data.Geometry] = [
sound_event.sound_event.geometry
for sound_event in predicted_sound_events
if sound_event.sound_event.geometry is not None
]
matches = []
for id1, id2, affinity in match_geometries(
annotated_geometries,
predicted_geometries,
):
target = annotated_sound_events[id1] if id1 is not None else None
source = predicted_sound_events[id2] if id2 is not None else None
matches.append(
data.Match(source=source, target=target, affinity=affinity)
)
return matches
def compute_error_auc(op_str, gt, pred, prob):
# classification error
pred_int = (pred > prob).astype(np.int)
pred_int = (pred > prob).astype(np.int32)
class_acc = (pred_int == gt).mean() * 100.0
# ROC - area under curve
@ -25,7 +65,6 @@ def compute_error_auc(op_str, gt, pred, prob):
def calc_average_precision(recall, precision):
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
@ -91,7 +130,6 @@ def compute_pre_rec(
pred_class = []
file_ids = []
for pid, pp in enumerate(preds):
# filter predicted calls that are too near the start or end of the file
file_dur = gts[pid]["duration"]
valid_inds = (pp["start_times"] >= ignore_start_end) & (
@ -141,7 +179,6 @@ def compute_pre_rec(
gt_generic_class = []
num_positives = 0
for gg in gts:
# filter ground truth calls that are too near the start or end of the file
file_dur = gg["duration"]
valid_inds = (gg["start_times"] >= ignore_start_end) & (
@ -205,7 +242,6 @@ def compute_pre_rec(
# valid detection that has not already been assigned
if valid_det and (gt_assigned[gt_id][det_ind] == 0):
count_as_true_pos = True
if eval_mode == "top_class" and (
gt_class[gt_id][det_ind] != pred_class[ind]

View File

@ -3,7 +3,9 @@ from typing import Optional
import pytorch_lightning as L
import torch
from pydantic import Field
from torch import optim
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from batdetect2.configs import BaseConfig
from batdetect2.models import (
@ -13,11 +15,20 @@ from batdetect2.models import (
get_backbone,
)
from batdetect2.models.typing import ModelOutput
from batdetect2.post_process import PostprocessConfig
from batdetect2.post_process import (
PostprocessConfig,
postprocess_model_outputs,
)
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.train.dataset import TrainExample
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.evaluate import match_predictions_and_annotations
from batdetect2.train.losses import LossConfig, compute_loss
from batdetect2.train.targets import TargetConfig
from batdetect2.train.targets import (
TargetConfig,
build_decoder,
build_encoder,
get_class_names,
)
class OptimizerConfig(BaseConfig):
@ -55,10 +66,8 @@ class DetectorModel(L.LightningModule):
self.save_hyperparameters()
size = self.config.preprocessing.spectrogram.size
self.backbone = get_backbone(
input_height=int(size.height * (size.resize_factor or 1)),
config=self.config.backbone,
)
assert size is not None
self.backbone = get_backbone(self.config.backbone)
self.classifier = ClassifierHead(
num_classes=len(self.config.targets.classes),
@ -74,6 +83,13 @@ class DetectorModel(L.LightningModule):
self.validation_predictions = []
self.class_names = get_class_names(self.config.targets.classes)
self.encoder = build_encoder(
self.config.targets.classes,
replacement_rules=self.config.targets.replace,
)
self.decoder = build_decoder(self.config.targets.classes)
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
features = self.backbone(spec)
detection_probs, classification_probs = self.classifier(features)
@ -117,14 +133,29 @@ class DetectorModel(L.LightningModule):
self.log("val/loss/classification", losses.total, logger=True)
dataloaders = self.trainer.val_dataloaders
print(dataloaders)
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, LabeledDataset)
clip_annotation = dataset.get_clip_annotation(batch_idx)
clip_prediction = postprocess_model_outputs(
outputs,
clips=[clip_annotation.clip],
classes=self.class_names,
decoder=self.decoder,
config=self.config.postprocessing,
)[0]
self.validation_predictions.extend(
match_predictions_and_annotations(clip_annotation, clip_prediction)
)
def on_validation_epoch_end(self) -> None:
print(len(self.validation_predictions))
self.validation_predictions.clear()
def configure_optimizers(self):
conf = self.config.train.optimizer
optimizer = optim.Adam(self.parameters(), lr=conf.learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=conf.t_max,
)
optimizer = Adam(self.parameters(), lr=conf.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
return [optimizer], [scheduler]

View File

@ -101,7 +101,11 @@ def generate_train_example(
return dataset.assign_attrs(
title=f"Training example for {clip_annotation.uuid}",
config=config.model_dump_json(),
clip_annotation=clip_annotation.model_dump_json(),
clip_annotation=clip_annotation.model_dump_json(
exclude_none=True,
exclude_defaults=True,
exclude_unset=True,
),
)

View File

@ -1,11 +0,0 @@
from pathlib import Path
from batdetect2.configs import load_config
from batdetect2.data import DatasetsConfig, load_datasets
def test_can_load_dataset_configs():
root = Path(__file__).parent.parent
path = root / "conf.yaml"
config = load_config(path, schema=DatasetsConfig, field="datasets")
load_datasets(config)

View File

@ -4,7 +4,7 @@ from pathlib import Path
from soundevent import data
from batdetect2.data.compat import load_annotation_project
from batdetect2.compat.data import load_annotation_project_from_dir
ROOT_DIR = Path(__file__).parent.parent.parent
@ -12,7 +12,7 @@ ROOT_DIR = Path(__file__).parent.parent.parent
def test_load_example_annotation_project():
path = ROOT_DIR / "example_data" / "anns"
audio_dir = ROOT_DIR / "example_data" / "audio"
project = load_annotation_project(path, audio_dir=audio_dir)
project = load_annotation_project_from_dir(path, audio_dir=audio_dir)
assert isinstance(project, data.AnnotationProject)
assert project.name == str(path)
assert len(project.clip_annotations) == 3

View File

@ -88,7 +88,7 @@ def test_spectrogram_generation_hasnt_changed(
scale = preprocess.AmplitudeScaleConfig()
config = preprocess.SpectrogramConfig(
fft=preprocess.FFTConfig(
stft=preprocess.STFTConfig(
window_overlap=fft_overlap,
window_duration=fft_win_length,
),

View File

View File

@ -0,0 +1,34 @@
import torch
from hypothesis import given
from hypothesis import strategies as st
from batdetect2.models import ModelConfig, ModelType, get_backbone
@given(
input_width=st.integers(min_value=50, max_value=1500),
input_height=st.integers(min_value=1, max_value=16),
model_type=st.sampled_from(ModelType),
)
def test_model_can_process_spectrograms_of_any_width(
input_width,
input_height,
model_type,
):
# Input height must be divisible by 8
input_height = 8 * input_height
input = torch.rand([1, 1, input_height, input_width])
model = get_backbone(
config=ModelConfig(
name=model_type, # type: ignore
input_height=input_height,
),
)
output = model(input)
assert output.shape[0] == 1
assert output.shape[1] == model.out_channels
assert output.shape[2] == input_height
assert output.shape[3] == input_width

View File

@ -0,0 +1,23 @@
import numpy as np
from batdetect2.preprocess.arrays import adjust_width, extend_width
def test_extend_width():
array = np.random.random([1, 1, 128, 100])
extended = extend_width(array, 100)
assert extended.shape == (1, 1, 128, 200)
def test_can_adjust_short_width():
array = np.random.random([1, 1, 128, 100])
extended = adjust_width(array, 512)
assert extended.shape == (1, 1, 128, 512)
def test_can_adjust_long_width():
array = np.random.random([1, 1, 128, 512])
extended = adjust_width(array, 256)
assert extended.shape == (1, 1, 128, 256)

View File

@ -8,14 +8,13 @@ from soundevent import arrays
from batdetect2.preprocess.audio import AudioConfig, load_file_audio
from batdetect2.preprocess.spectrogram import (
FFTConfig,
STFTConfig,
FrequencyConfig,
SpecSizeConfig,
SpectrogramConfig,
compute_spectrogram,
duration_to_spec_width,
get_spectrogram_resolution,
pad_audio,
spec_width_to_samples,
stft,
)
@ -68,45 +67,6 @@ def test_can_estimate_correctly_spectrogram_width_from_duration(
)
@settings(
suppress_health_check=[HealthCheck.function_scoped_fixture],
deadline=400,
)
@given(
duration=st.floats(min_value=0.1, max_value=1.0),
window_duration=st.floats(min_value=0.001, max_value=0.01),
window_overlap=st.floats(min_value=0.2, max_value=0.9),
samplerate=st.integers(min_value=256_000, max_value=512_000),
divide_factor=st.integers(min_value=16, max_value=64),
)
def test_can_pad_audio_to_adjust_spectrogram_width(
duration: float,
window_duration: float,
window_overlap: float,
samplerate: int,
divide_factor: int,
wav_factory: Callable[..., Path],
):
path = wav_factory(duration=duration, samplerate=samplerate)
audio = load_file_audio(
path,
# NOTE: Dont resample nor adjust duration to test if the width
# estimation works on all scenarios
config=AudioConfig(resample=None, duration=None),
)
audio = pad_audio(
audio,
window_duration=window_duration,
window_overlap=window_overlap,
divide_factor=divide_factor,
)
spectrogram = stft(audio, window_duration, window_overlap)
assert spectrogram.sizes["time"] % divide_factor == 0
def test_can_estimate_spectrogram_resolution(
wav_factory: Callable[..., Path],
):
@ -120,7 +80,7 @@ def test_can_estimate_spectrogram_resolution(
)
config = SpectrogramConfig(
fft=FFTConfig(),
stft=STFTConfig(),
size=SpecSizeConfig(height=256, resize_factor=0.5),
frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
)

View File

@ -0,0 +1,42 @@
import numpy as np
import torch
from batdetect2.preprocess.tensors import adjust_width, make_width_divisible
def test_width_is_divisible_after_adjustment():
tensor = torch.rand([1, 1, 128, 374])
adjusted = make_width_divisible(tensor, 32)
assert adjusted.shape[-1] % 32 == 0
assert adjusted.shape == (1, 1, 128, 384)
def test_non_last_axis_is_divisible_after_adjustment():
tensor = torch.rand([1, 1, 77, 124])
adjusted = make_width_divisible(tensor, 32, axis=-2)
assert adjusted.shape[-2] % 32 == 0
assert adjusted.shape == (1, 1, 96, 124)
def test_make_width_divisible_can_handle_numpy_array():
array = np.random.random([1, 1, 128, 374])
adjusted = make_width_divisible(array, 32)
assert adjusted.shape[-1] % 32 == 0
assert adjusted.shape == (1, 1, 128, 384)
assert isinstance(adjusted, torch.Tensor)
def test_adjust_last_axis_width_by_default():
tensor = torch.rand([1, 1, 128, 374])
adjusted = adjust_width(tensor, 512)
assert adjusted.shape == (1, 1, 128, 512)
assert (tensor == adjusted[:, :, :, :374]).all()
assert (adjusted[:, :, :, 374:] == 0).all()
def test_can_adjust_second_to_last_axis():
tensor = torch.rand([1, 1, 89, 512])
adjusted = adjust_width(tensor, 128, axis=-2)
assert adjusted.shape == (1, 1, 128, 512)
assert (tensor == adjusted[:, :, :89, :]).all()
assert (adjusted[:, :, 89:, :] == 0).all()

View File

@ -1,14 +1,14 @@
from collections.abc import Callable
import numpy as np
import pytest
import xarray as xr
from soundevent import data
from soundevent import arrays, data
from batdetect2.train.augmentations import (
add_echo,
adjust_dataset_width,
mix_examples,
select_random_subclip,
select_subclip,
)
from batdetect2.train.preprocess import (
TrainPreprocessingConfig,
@ -23,11 +23,9 @@ def test_mix_examples(
recording2 = recording_factory()
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
config = TrainPreprocessingConfig()
@ -43,6 +41,36 @@ def test_mix_examples(
assert mixed["class"].shape == example1["class"].shape
@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
def test_mix_examples_of_different_durations(
recording_factory: Callable[..., data.Recording],
duration1: float,
duration2: float,
):
recording1 = recording_factory()
recording2 = recording_factory()
clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1)
clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
config = TrainPreprocessingConfig()
example1 = generate_train_example(clip_annotation_1, config)
example2 = generate_train_example(clip_annotation_2, config)
mixed = mix_examples(example1, example2, config=config.preprocessing)
# Check the spectrogram has the expected duration
step = arrays.get_dim_step(mixed["spectrogram"], "time")
start, stop = arrays.get_dim_range(mixed["spectrogram"], "time")
assert start == 0
assert np.isclose(stop + step, duration1, atol=2 * step)
def test_add_echo(
recording_factory: Callable[..., data.Recording],
):
@ -67,70 +95,23 @@ def test_selected_random_subclip_has_the_correct_width(
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
config = TrainPreprocessingConfig()
original = generate_train_example(clip_annotation_1, config)
subclip = select_random_subclip(original, width=100)
subclip = select_subclip(original, width=100)
assert subclip["spectrogram"].shape[1] == 100
def test_adjust_dataset_width():
height = 128
width = 512
samplerate = 48_000
def test_add_echo_after_subclip(
recording_factory: Callable[..., data.Recording],
):
recording1 = recording_factory(duration=2)
clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
config = TrainPreprocessingConfig()
original = generate_train_example(clip_annotation_1, config)
times = np.linspace(0, 1, width)
assert original.sizes["time"] > 512
audio_times = np.linspace(0, 1, samplerate)
frequency = np.linspace(0, 24_000, height)
subclip = select_subclip(original, width=512)
with_echo = add_echo(subclip)
width_subset = 356
audio_width_subset = int(samplerate * width_subset / width)
times_subset = times[:width_subset]
audio_times_subset = audio_times[:audio_width_subset]
dimensions = ["width", "height"]
class_names = [f"species_{i}" for i in range(17)]
spectrogram = np.random.random([height, width_subset])
sizes = np.random.random([len(dimensions), height, width_subset])
classes = np.random.random([len(class_names), height, width_subset])
audio = np.random.random([int(samplerate * width_subset / width)])
dataset = xr.Dataset(
data_vars={
"audio": (("audio_time",), audio),
"spectrogram": (("frequency", "time"), spectrogram),
"sizes": (("dimension", "frequency", "time"), sizes),
"classes": (("class", "frequency", "time"), classes),
},
coords={
"audio_time": audio_times_subset,
"time": times_subset,
"frequency": frequency,
"dimension": dimensions,
"class": class_names,
},
)
adjusted = adjust_dataset_width(dataset, width=width)
# Spectrogram was adjusted correctly
assert np.isclose(adjusted["spectrogram"].time, times).all()
assert (adjusted["spectrogram"].frequency == frequency).all()
# Sizes was adjusted correctly
assert np.isclose(adjusted["sizes"].time, times).all()
assert (adjusted["sizes"].frequency == frequency).all()
assert list(adjusted["sizes"].dimension.values) == dimensions
# Sizes was adjusted correctly
assert np.isclose(adjusted["classes"].time, times).all()
assert (adjusted["sizes"].frequency == frequency).all()
assert list(adjusted["classes"]["class"].values) == class_names
# Audio time was adjusted corretly
assert np.isclose(
len(adjusted["audio"].audio_time), len(audio_times), atol=2
)
assert np.isclose(
adjusted["audio"].audio_time[-1], audio_times[-1], atol=1e-3
)
assert with_echo.sizes["time"] == 512

View File

@ -3,7 +3,6 @@ from pathlib import Path
import numpy as np
import xarray as xr
from soundevent import data
from soundevent.types import ClassMapper
from batdetect2.train.labels import generate_heatmaps
@ -23,16 +22,6 @@ clip = data.Clip(
)
class Mapper(ClassMapper):
class_labels = ["bat", "cat"]
def encode(self, sound_event_annotation: data.SoundEventAnnotation) -> str:
return "bat"
def decode(self, label: str) -> list:
return [data.Tag(term=data.term_from_key("species"), value="bat")]
def test_generated_heatmaps_have_correct_dimensions():
spec = xr.DataArray(
data=np.random.rand(100, 100),
@ -57,12 +46,11 @@ def test_generated_heatmaps_have_correct_dimensions():
],
)
class_mapper = Mapper()
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
clip_annotation.sound_events,
spec,
class_mapper,
class_names=["bat", "cat"],
encoder=lambda _: "bat",
)
assert isinstance(detection_heatmap, xr.DataArray)
@ -107,11 +95,13 @@ def test_generated_heatmap_are_non_zero_at_correct_positions():
],
)
class_mapper = Mapper()
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
clip_annotation.sound_events,
spec,
class_mapper,
class_names=["bat", "cat"],
encoder=lambda _: "bat",
time_scale=1,
frequency_scale=1,
)
assert size_heatmap.sel(time=10, frequency=10, dimension="width") == 10
assert size_heatmap.sel(time=10, frequency=10, dimension="height") == 10