Mirror of https://github.com/macaodha/batdetect2.git — synced 2025-06-29 14:41:58 +02:00

Commit 48e009fa9d ("WIP"), parent f7d6516550.
@@ -1,7 +1,7 @@
from batdetect2.preprocess import (
    AmplitudeScaleConfig,
    AudioConfig,
    FFTConfig,
    STFTConfig,
    FrequencyConfig,
    LogScaleConfig,
    PcenScaleConfig,
@@ -40,7 +40,7 @@ def get_preprocessing_config(params: dict) -> PreprocessingConfig:
            duration=None,
        ),
        spectrogram=SpectrogramConfig(
            fft=FFTConfig(
            stft=STFTConfig(
                window_duration=params["fft_win_length"],
                window_overlap=params["fft_overlap"],
                window_fn="hann",
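Editor's note (not part of the diff): the hunk above renames FFTConfig to STFTConfig and the `fft` field to `stft`. A minimal sketch of building the renamed config from a legacy-style parameter dict, assuming only the class and field names shown in this commit; the numeric values are illustrative defaults.

```python
from batdetect2.preprocess import SpectrogramConfig, STFTConfig

# Legacy-style parameters, as consumed by get_preprocessing_config above.
params = {"fft_win_length": 0.002, "fft_overlap": 0.75}

spectrogram = SpectrogramConfig(
    stft=STFTConfig(
        window_duration=params["fft_win_length"],
        window_overlap=params["fft_overlap"],
        window_fn="hann",
    ),
)
```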
@@ -1,10 +1,12 @@
from enum import Enum
from typing import Optional, Tuple

from batdetect2.configs import BaseConfig
from batdetect2.models.backbones import (
    Net2DFast,
    Net2DFastNoAttn,
    Net2DFastNoCoordConv,
    Net2DPlain,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.typing import BackboneModel
@@ -24,31 +26,57 @@ class ModelType(str, Enum):
    Net2DFast = "Net2DFast"
    Net2DFastNoAttn = "Net2DFastNoAttn"
    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
    Net2DPlain = "Net2DPlain"


class ModelConfig(BaseConfig):
    name: ModelType = ModelType.Net2DFast
    num_features: int = 32
    input_height: int = 128
    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
    bottleneck_channels: int = 256
    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
    out_channels: int = 32


def get_backbone(
    config: ModelConfig,
    input_height: int = 128,
    config: Optional[ModelConfig] = None,
) -> BackboneModel:
    config = config or ModelConfig()

    if config.name == ModelType.Net2DFast:
        return Net2DFast(
            input_height=input_height,
            num_features=config.num_features,
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )
    elif config.name == ModelType.Net2DFastNoAttn:

    if config.name == ModelType.Net2DFastNoAttn:
        return Net2DFastNoAttn(
            num_features=config.num_features,
            input_height=input_height,
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )
    elif config.name == ModelType.Net2DFastNoCoordConv:

    if config.name == ModelType.Net2DFastNoCoordConv:
        return Net2DFastNoCoordConv(
            num_features=config.num_features,
            input_height=input_height,
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )
    else:
        raise ValueError(f"Unknown model type: {config.name}")

    if config.name == ModelType.Net2DPlain:
        return Net2DPlain(
            input_height=config.input_height,
            encoder_channels=config.encoder_channels,
            bottleneck_channels=config.bottleneck_channels,
            decoder_channels=config.decoder_channels,
            out_channels=config.out_channels,
        )

    raise ValueError(f"Unknown model type: {config.name}")
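Editor's note (not part of the diff): the factory above now takes the input height and channel layout from ModelConfig rather than from separate arguments. A sketch of the new, config-driven call, using the default values shown in this commit and the `get_backbone(config=...)` form used by the new test file further down; the concrete numbers are illustrative only.

```python
import torch

from batdetect2.models import ModelConfig, ModelType, get_backbone

config = ModelConfig(
    name=ModelType.Net2DFast,
    input_height=128,  # must be divisible by 2 ** (len(encoder_channels) - 1)
    encoder_channels=(1, 32, 64, 128),
    bottleneck_channels=256,
    decoder_channels=(256, 64, 32, 32),
    out_channels=32,
)

backbone = get_backbone(config=config)  # input height now comes from the config

# Any spectrogram width should work; height must match config.input_height.
features = backbone(torch.rand(1, 1, config.input_height, 512))
assert features.shape[1] == backbone.out_channels
```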
@ -1,16 +1,16 @@
|
||||
from typing import Tuple
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.fft
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.models.blocks import (
|
||||
ConvBlockDownCoordF,
|
||||
ConvBlockDownStandard,
|
||||
ConvBlockUpF,
|
||||
ConvBlockUpStandard,
|
||||
ConvBlock,
|
||||
Decoder,
|
||||
DownscalingLayer,
|
||||
Encoder,
|
||||
SelfAttention,
|
||||
UpscalingLayer,
|
||||
VerticalConv,
|
||||
)
|
||||
from batdetect2.models.typing import BackboneModel
|
||||
|
||||
@ -21,303 +21,164 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class Net2DFast(BackboneModel):
|
||||
class Net2DPlain(BackboneModel):
|
||||
downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard"
|
||||
upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_features: int,
|
||||
input_height: int = 128,
|
||||
encoder_channels: Sequence[int] = (1, 32, 64, 128),
|
||||
bottleneck_channels: int = 256,
|
||||
decoder_channels: Sequence[int] = (256, 64, 32, 32),
|
||||
out_channels: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.input_height = input_height
|
||||
self.bottleneck_height = self.input_height // 32
|
||||
|
||||
# encoder
|
||||
self.conv_dn_0 = ConvBlockDownCoordF(
|
||||
1,
|
||||
self.num_features // 4,
|
||||
self.input_height,
|
||||
kernel_size=3,
|
||||
pad_size=1,
|
||||
stride=1,
|
||||
self.input_height = input_height
|
||||
self.encoder_channels = tuple(encoder_channels)
|
||||
self.decoder_channels = tuple(decoder_channels)
|
||||
self.out_channels = out_channels
|
||||
|
||||
if len(encoder_channels) != len(decoder_channels):
|
||||
raise ValueError(
|
||||
f"Mismatched encoder and decoder channel lists. "
|
||||
f"The encoder has {len(encoder_channels)} channels "
|
||||
f"(implying {len(encoder_channels) - 1} layers), "
|
||||
f"while the decoder has {len(decoder_channels)} channels "
|
||||
f"(implying {len(decoder_channels) - 1} layers). "
|
||||
f"These lengths must be equal."
|
||||
)
|
||||
|
||||
self.divide_factor = 2 ** (len(encoder_channels) - 1)
|
||||
if self.input_height % self.divide_factor != 0:
|
||||
raise ValueError(
|
||||
f"Input height ({self.input_height}) must be divisible by "
|
||||
f"the divide factor ({self.divide_factor}). "
|
||||
f"This ensures proper upscaling after downscaling to recover "
|
||||
f"the original input height."
|
||||
)
|
||||
|
||||
self.encoder = Encoder(
|
||||
channels=encoder_channels,
|
||||
input_height=self.input_height,
|
||||
layer_type=self.downscaling_layer_type,
|
||||
)
|
||||
self.conv_dn_1 = ConvBlockDownCoordF(
|
||||
self.num_features // 4,
|
||||
self.num_features // 2,
|
||||
self.input_height // 2,
|
||||
kernel_size=3,
|
||||
pad_size=1,
|
||||
stride=1,
|
||||
|
||||
self.conv_same_1 = ConvBlock(
|
||||
in_channels=encoder_channels[-1],
|
||||
out_channels=bottleneck_channels,
|
||||
)
|
||||
self.conv_dn_2 = ConvBlockDownCoordF(
|
||||
self.num_features // 2,
|
||||
self.num_features,
|
||||
self.input_height // 4,
|
||||
kernel_size=3,
|
||||
pad_size=1,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_dn_3 = nn.Conv2d(
|
||||
self.num_features,
|
||||
self.num_features * 2,
|
||||
3,
|
||||
padding=1,
|
||||
)
|
||||
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
|
||||
|
||||
# bottleneck
|
||||
self.conv_1d = nn.Conv2d(
|
||||
self.num_features * 2,
|
||||
self.num_features * 2,
|
||||
(self.input_height // 8, 1),
|
||||
padding=0,
|
||||
)
|
||||
self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
|
||||
self.att = SelfAttention(self.num_features * 2, self.num_features * 2)
|
||||
|
||||
# decoder
|
||||
self.conv_up_2 = ConvBlockUpF(
|
||||
self.num_features * 2,
|
||||
self.num_features // 2,
|
||||
self.input_height // 8,
|
||||
)
|
||||
self.conv_up_3 = ConvBlockUpF(
|
||||
self.num_features // 2,
|
||||
self.num_features // 4,
|
||||
self.input_height // 4,
|
||||
)
|
||||
self.conv_up_4 = ConvBlockUpF(
|
||||
self.num_features // 4,
|
||||
self.num_features // 4,
|
||||
self.input_height // 2,
|
||||
self.conv_vert = VerticalConv(
|
||||
in_channels=bottleneck_channels,
|
||||
out_channels=bottleneck_channels,
|
||||
input_height=self.input_height // (2**self.encoder.depth),
|
||||
)
|
||||
|
||||
self.conv_op = nn.Conv2d(
|
||||
self.num_features // 4,
|
||||
self.num_features // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
self.decoder = Decoder(
|
||||
channels=decoder_channels,
|
||||
input_height=self.input_height,
|
||||
layer_type=self.upscaling_layer_type,
|
||||
)
|
||||
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
|
||||
|
||||
self.out_channels = self.num_features // 4
|
||||
|
||||
def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
|
||||
h, w = spec.shape[2:]
|
||||
h_pad = (32 - h % 32) % 32
|
||||
w_pad = (32 - w % 32) % 32
|
||||
return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
|
||||
self.conv_same_2 = ConvBlock(
|
||||
in_channels=decoder_channels[-1],
|
||||
out_channels=out_channels,
|
||||
)
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
# encoder
|
||||
spec, h_pad, w_pad = self.pad_adjust(spec)
|
||||
spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
|
||||
|
||||
x1 = self.conv_dn_0(spec)
|
||||
x2 = self.conv_dn_1(x1)
|
||||
x3 = self.conv_dn_2(x2)
|
||||
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
|
||||
# encoder
|
||||
residuals = self.encoder(spec)
|
||||
residuals[-1] = self.conv_same_1(residuals[-1])
|
||||
|
||||
# bottleneck
|
||||
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
|
||||
x = self.att(x)
|
||||
x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
|
||||
x = self.conv_vert(residuals[-1])
|
||||
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
|
||||
|
||||
# decoder
|
||||
x = self.conv_up_2(x + x3)
|
||||
x = self.conv_up_3(x + x2)
|
||||
x = self.conv_up_4(x + x1)
|
||||
x = self.decoder(x, residuals=residuals)
|
||||
|
||||
# Restore original size
|
||||
if h_pad > 0:
|
||||
x = x[:, :, :-h_pad, :]
|
||||
x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
|
||||
|
||||
if w_pad > 0:
|
||||
x = x[:, :, :, :-w_pad]
|
||||
|
||||
return F.relu_(self.conv_op_bn(self.conv_op(x)))
|
||||
return self.conv_same_2(x)
|
||||
|
||||
|
||||
class Net2DFastNoAttn(BackboneModel):
|
||||
class Net2DFast(Net2DPlain):
|
||||
downscaling_layer_type = "ConvBlockDownCoordF"
|
||||
upscaling_layer_type = "ConvBlockUpF"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_features: int,
|
||||
input_height: int = 128,
|
||||
encoder_channels: Sequence[int] = (1, 32, 64, 128),
|
||||
bottleneck_channels: int = 256,
|
||||
decoder_channels: Sequence[int] = (256, 64, 32, 32),
|
||||
out_channels: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_features = num_features
|
||||
self.input_height = input_height
|
||||
self.bottleneck_height = self.input_height // 32
|
||||
|
||||
self.conv_dn_0 = ConvBlockDownCoordF(
|
||||
1,
|
||||
self.num_features // 4,
|
||||
self.input_height,
|
||||
kernel_size=3,
|
||||
pad_size=1,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_dn_1 = ConvBlockDownCoordF(
|
||||
self.num_features // 4,
|
||||
self.num_features // 2,
|
||||
self.input_height // 2,
|
||||
kernel_size=3,
|
||||
pad_size=1,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_dn_2 = ConvBlockDownCoordF(
|
||||
self.num_features // 2,
|
||||
self.num_features,
|
||||
self.input_height // 4,
|
||||
kernel_size=3,
|
||||
pad_size=1,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_dn_3 = nn.Conv2d(
|
||||
self.num_features,
|
||||
self.num_features * 2,
|
||||
3,
|
||||
padding=1,
|
||||
)
|
||||
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
|
||||
|
||||
self.conv_1d = nn.Conv2d(
|
||||
self.num_features * 2,
|
||||
self.num_features * 2,
|
||||
(self.input_height // 8, 1),
|
||||
padding=0,
|
||||
)
|
||||
self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
|
||||
|
||||
self.conv_up_2 = ConvBlockUpF(
|
||||
self.num_features * 2,
|
||||
self.num_features // 2,
|
||||
self.input_height // 8,
|
||||
)
|
||||
self.conv_up_3 = ConvBlockUpF(
|
||||
self.num_features // 2,
|
||||
self.num_features // 4,
|
||||
self.input_height // 4,
|
||||
)
|
||||
self.conv_up_4 = ConvBlockUpF(
|
||||
self.num_features // 4,
|
||||
self.num_features // 4,
|
||||
self.input_height // 2,
|
||||
super().__init__(
|
||||
input_height=input_height,
|
||||
encoder_channels=encoder_channels,
|
||||
bottleneck_channels=bottleneck_channels,
|
||||
decoder_channels=decoder_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
|
||||
self.conv_op = nn.Conv2d(
|
||||
self.num_features // 4,
|
||||
self.num_features // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
)
|
||||
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
|
||||
self.out_channels = self.num_features // 4
|
||||
self.att = SelfAttention(bottleneck_channels, bottleneck_channels)
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
x1 = self.conv_dn_0(spec)
|
||||
x2 = self.conv_dn_1(x1)
|
||||
x3 = self.conv_dn_2(x2)
|
||||
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
|
||||
spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
|
||||
|
||||
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
|
||||
x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
|
||||
# encoder
|
||||
residuals = self.encoder(spec)
|
||||
residuals[-1] = self.conv_same_1(residuals[-1])
|
||||
|
||||
x = self.conv_up_2(x + x3)
|
||||
x = self.conv_up_3(x + x2)
|
||||
x = self.conv_up_4(x + x1)
|
||||
|
||||
return F.relu_(self.conv_op_bn(self.conv_op(x)))
|
||||
|
||||
|
||||
class Net2DFastNoCoordConv(BackboneModel):
|
||||
def __init__(
|
||||
self,
|
||||
num_features: int,
|
||||
input_height: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_features = num_features
|
||||
self.input_height = input_height
|
||||
self.bottleneck_height = self.input_height // 32
|
||||
|
||||
self.conv_dn_0 = ConvBlockDownStandard(
|
||||
1,
|
||||
self.num_features // 4,
|
||||
kernel_size=3,
|
||||
pad_size=1,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_dn_1 = ConvBlockDownStandard(
|
||||
self.num_features // 4,
|
||||
self.num_features // 2,
|
||||
kernel_size=3,
|
||||
pad_size=1,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_dn_2 = ConvBlockDownStandard(
|
||||
self.num_features // 2,
|
||||
self.num_features,
|
||||
kernel_size=3,
|
||||
pad_size=1,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_dn_3 = nn.Conv2d(
|
||||
self.num_features,
|
||||
self.num_features * 2,
|
||||
3,
|
||||
padding=1,
|
||||
)
|
||||
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
|
||||
|
||||
self.conv_1d = nn.Conv2d(
|
||||
self.num_features * 2,
|
||||
self.num_features * 2,
|
||||
(self.input_height // 8, 1),
|
||||
padding=0,
|
||||
)
|
||||
self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
|
||||
|
||||
self.att = SelfAttention(self.num_features * 2, self.num_features * 2)
|
||||
|
||||
self.conv_up_2 = ConvBlockUpStandard(
|
||||
self.num_features * 2,
|
||||
self.num_features // 2,
|
||||
self.input_height // 8,
|
||||
)
|
||||
self.conv_up_3 = ConvBlockUpStandard(
|
||||
self.num_features // 2,
|
||||
self.num_features // 4,
|
||||
self.input_height // 4,
|
||||
)
|
||||
self.conv_up_4 = ConvBlockUpStandard(
|
||||
self.num_features // 4,
|
||||
self.num_features // 4,
|
||||
self.input_height // 2,
|
||||
)
|
||||
|
||||
self.conv_op = nn.Conv2d(
|
||||
self.num_features // 4,
|
||||
self.num_features // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
)
|
||||
self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
|
||||
self.out_channels = self.num_features // 4
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
x1 = self.conv_dn_0(spec)
|
||||
x2 = self.conv_dn_1(x1)
|
||||
x3 = self.conv_dn_2(x2)
|
||||
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
|
||||
|
||||
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
|
||||
# bottleneck
|
||||
x = self.conv_vert(residuals[-1])
|
||||
x = self.att(x)
|
||||
x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
|
||||
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
|
||||
|
||||
x = self.conv_up_2(x + x3)
|
||||
x = self.conv_up_3(x + x2)
|
||||
x = self.conv_up_4(x + x1)
|
||||
# decoder
|
||||
x = self.decoder(x, residuals=residuals)
|
||||
|
||||
return F.relu_(self.conv_op_bn(self.conv_op(x)))
|
||||
# Restore original size
|
||||
x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
|
||||
|
||||
return self.conv_same_2(x)


class Net2DFastNoAttn(Net2DPlain):
    downscaling_layer_type = "ConvBlockDownCoordF"
    upscaling_layer_type = "ConvBlockUpF"


class Net2DFastNoCoordConv(Net2DFast):
    downscaling_layer_type = "ConvBlockDownStandard"
    upscaling_layer_type = "ConvBlockUpStandard"


def pad_adjust(
    spec: torch.Tensor,
    factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
    h, w = spec.shape[2:]
    h_pad = -h % factor
    w_pad = -w % factor
    return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad


def restore_pad(
    x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor:
    # Restore original size
    if h_pad > 0:
        x = x[:, :, :-h_pad, :]

    if w_pad > 0:
        x = x[:, :, :, :-w_pad]

    return x
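Editor's note (not part of the diff): the two module-level helpers above replace the old hard-coded pad-to-32 logic in the forward passes. A small round-trip sketch, assuming the helpers stay module-level in batdetect2.models.backbones as shown here.

```python
import torch

from batdetect2.models.backbones import pad_adjust, restore_pad  # assumed import path

spec = torch.rand(1, 1, 128, 500)           # width 500 is not a multiple of 32
padded, h_pad, w_pad = pad_adjust(spec, factor=32)
assert padded.shape == (1, 1, 128, 512)     # -500 % 32 == 12 extra zero columns

restored = restore_pad(padded, h_pad=h_pad, w_pad=w_pad)
assert restored.shape == spec.shape         # padding is stripped again after the decoder
```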
@@ -4,29 +4,35 @@ All these classes are subclasses of `torch.nn.Module` and can be used to build
complex neural network architectures.
"""

from typing import Tuple
import sys
from typing import Iterable, List, Literal, Sequence, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from batdetect2.configs import BaseConfig
if sys.version_info >= (3, 10):
    from itertools import pairwise
else:

    def pairwise(iterable: Sequence) -> Iterable:
        for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y


__all__ = [
    "SelfAttention",
    "ConvBlock",
    "ConvBlockDownCoordF",
    "ConvBlockDownStandard",
    "ConvBlockUpF",
    "ConvBlockUpStandard",
    "SelfAttention",
    "VerticalConv",
    "DownscalingLayer",
    "UpscalingLayer",
]
|
||||
|
||||
|
||||
class SelfAttentionConfig(BaseConfig):
|
||||
temperature: float = 1.0
|
||||
input_channels: int = 128
|
||||
attention_channels: int = 128
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
"""Self-Attention module.
|
||||
|
||||
@ -76,13 +82,64 @@ class SelfAttention(nn.Module):
|
||||
return op
|
||||
|
||||
|
||||
class ConvBlockDownCoordFConfig(BaseConfig):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
kernel_size: int = 3
|
||||
pad_size: int = 1
|
||||
stride: int = 1
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
pad_size: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=pad_size,
|
||||
)
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.relu_(self.conv_bn(self.conv(x)))
|
||||
|
||||
|
||||
class VerticalConv(nn.Module):
|
||||
"""Convolutional layer over full height.
|
||||
|
||||
This layer applies a convolution that captures information across the
|
||||
entire height of the input image. It uses a kernel with the same height as
|
||||
the input, effectively condensing the vertical information into a single
|
||||
output row.
|
||||
|
||||
More specifically:
|
||||
|
||||
* **Input:** (B, C, H, W) where B is the batch size, C is the number of
|
||||
input channels, H is the image height, and W is the image width.
|
||||
* **Kernel:** (C', H, 1) where C' is the number of output channels.
|
||||
* **Output:** (B, C', 1, W) - The height dimension is 1 because the
|
||||
convolution integrates information from all rows of the input.
|
||||
|
||||
This process effectively extracts features that span the full height of
|
||||
the input image.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
input_height: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(input_height, 1),
|
||||
padding=0,
|
||||
)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.relu_(self.bn(self.conv(x)))
|
||||
|
||||
|
||||
class ConvBlockDownCoordF(nn.Module):
|
||||
@ -124,14 +181,6 @@ class ConvBlockDownCoordF(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class ConvBlockDownStandardConfig(BaseConfig):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
kernel_size: int = 3
|
||||
pad_size: int = 1
|
||||
stride: int = 1
|
||||
|
||||
|
||||
class ConvBlockDownStandard(nn.Module):
|
||||
"""Convolutional Block with Downsampling.
|
||||
|
||||
@ -158,18 +207,10 @@ class ConvBlockDownStandard(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||
x = F.relu(self.conv_bn(x), inplace=True)
|
||||
return x
|
||||
return F.relu(self.conv_bn(x), inplace=True)
|
||||
|
||||
|
||||
class ConvBlockUpFConfig(BaseConfig):
|
||||
inp_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
kernel_size: int = 3
|
||||
pad_size: int = 1
|
||||
up_mode: str = "bilinear"
|
||||
up_scale: Tuple[int, int] = (2, 2)
|
||||
DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"]
|
||||
|
||||
|
||||
class ConvBlockUpF(nn.Module):
|
||||
@ -224,15 +265,6 @@ class ConvBlockUpF(nn.Module):
|
||||
return op
|
||||
|
||||
|
||||
class ConvBlockUpStandardConfig(BaseConfig):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
kernel_size: int = 3
|
||||
pad_size: int = 1
|
||||
up_mode: str = "bilinear"
|
||||
up_scale: Tuple[int, int] = (2, 2)
|
||||
|
||||
|
||||
class ConvBlockUpStandard(nn.Module):
|
||||
"""Convolutional Block with Upsampling.
|
||||
|
||||
@ -272,3 +304,143 @@ class ConvBlockUpStandard(nn.Module):
|
||||
op = self.conv(op)
|
||||
op = F.relu(self.conv_bn(op), inplace=True)
|
||||
return op
|
||||
|
||||
|
||||
UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"]
|
||||
|
||||
|
||||
def build_downscaling_layer(
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
input_height: int,
|
||||
layer_type: DownscalingLayer,
|
||||
) -> nn.Module:
|
||||
if layer_type == "ConvBlockDownStandard":
|
||||
return ConvBlockDownStandard(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
|
||||
if layer_type == "ConvBlockDownCoordF":
|
||||
return ConvBlockDownCoordF(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height,
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid downscaling layer type {layer_type}. "
|
||||
f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
|
||||
)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: Sequence[int] = (1, 32, 62, 128),
|
||||
input_height: int = 128,
|
||||
layer_type: Literal[
|
||||
"ConvBlockDownStandard", "ConvBlockDownCoordF"
|
||||
] = "ConvBlockDownStandard",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.channels = channels
|
||||
self.input_height = input_height
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
build_downscaling_layer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height // (2**layer_num),
|
||||
layer_type=layer_type,
|
||||
)
|
||||
for layer_num, (in_channels, out_channels) in enumerate(
|
||||
pairwise(channels)
|
||||
)
|
||||
]
|
||||
)
|
||||
self.depth = len(self.layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
outputs = []
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
outputs.append(x)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def build_upscaling_layer(
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
input_height: int,
|
||||
layer_type: UpscalingLayer,
|
||||
) -> nn.Module:
|
||||
if layer_type == "ConvBlockUpStandard":
|
||||
return ConvBlockUpStandard(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
|
||||
if layer_type == "ConvBlockUpF":
|
||||
return ConvBlockUpF(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height,
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid upscaling layer type {layer_type}. "
|
||||
f"Valid values: ConvBlockUpStandard, ConvBlockUpF"
|
||||
)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: Sequence[int] = (256, 62, 32, 32),
|
||||
input_height: int = 128,
|
||||
layer_type: Literal[
|
||||
"ConvBlockUpStandard", "ConvBlockUpF"
|
||||
] = "ConvBlockUpStandard",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.channels = channels
|
||||
self.input_height = input_height
|
||||
self.depth = len(self.channels) - 1
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
build_upscaling_layer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height
|
||||
// (2 ** (self.depth - layer_num)),
|
||||
layer_type=layer_type,
|
||||
)
|
||||
for layer_num, (in_channels, out_channels) in enumerate(
|
||||
pairwise(channels)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residuals: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
if len(residuals) != len(self.layers):
|
||||
raise ValueError(
|
||||
f"Incorrect number of residuals provided. "
|
||||
f"Expected {len(self.layers)} (matching the number of layers), "
|
||||
f"but got {len(residuals)}."
|
||||
)
|
||||
|
||||
for layer, res in zip(self.layers, residuals[::-1]):
|
||||
x = layer(x + res)
|
||||
|
||||
return x
batdetect2/models/decoder.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import sys
from typing import Iterable, List, Literal, Sequence

import torch
from torch import nn

from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard

if sys.version_info >= (3, 10):
    from itertools import pairwise
else:

    def pairwise(iterable: Sequence) -> Iterable:
        for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y
batdetect2/models/encoder.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import sys
from typing import Iterable, List, Literal, Sequence

import torch
from torch import nn

from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard

if sys.version_info >= (3, 10):
    from itertools import pairwise
else:

    def pairwise(iterable: Sequence) -> Iterable:
        for x, y in zip(iterable[:-1], iterable[1:]):
            yield x, y
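Editor's note (not part of the diff): both new modules carry the same `pairwise` fallback for Python < 3.10. A quick sketch of what it yields and how the Encoder/Decoder builders in this commit use it to pair consecutive channel counts; the channel tuple is just the default from ModelConfig.

```python
from itertools import pairwise  # Python >= 3.10; the fallback above is equivalent

channels = (1, 32, 64, 128)
assert list(pairwise(channels)) == [(1, 32), (32, 64), (64, 128)]

# Each (in_channels, out_channels) pair becomes one downscaling layer,
# so a 4-entry channel tuple yields a 3-layer encoder (depth == 3).
```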
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import NamedTuple
from typing import NamedTuple, Tuple

import torch
import torch.nn as nn
@@ -45,15 +45,27 @@ class BackboneModel(ABC, nn.Module):
    input_height: int
    """Height of the input spectrogram."""

    num_features: int
    """Dimension of the feature tensor."""
    encoder_channels: Tuple[int, ...]
    """Tuple specifying the number of channels for each convolutional layer
    in the encoder. The length of this tuple determines the number of
    encoder layers."""

    decoder_channels: Tuple[int, ...]
    """Tuple specifying the number of channels for each convolutional layer
    in the decoder. The length of this tuple determines the number of
    decoder layers."""

    bottleneck_channels: int
    """Number of channels in the bottleneck layer, which connects the
    encoder and decoder."""

    out_channels: int
    """Number of output channels of the feature extractor."""
    """Number of channels in the final output feature map produced by the
    backbone model."""

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Forward pass of the encoder model."""
        """Forward pass of the model."""


class DetectionModel(ABC, nn.Module):
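Editor's note (not part of the diff): the new channel-tuple attributes documented above fix the network depth and therefore the required input-height divisibility checked in Net2DPlain. A small sketch of that relationship, using the defaults from this commit.

```python
encoder_channels = (1, 32, 64, 128)   # defaults from ModelConfig in this commit
decoder_channels = (256, 64, 32, 32)

depth = len(encoder_channels) - 1     # number of downscaling layers
divide_factor = 2 ** depth            # 8: input_height must be divisible by this

assert len(encoder_channels) == len(decoder_channels)
assert 128 % divide_factor == 0       # the default input_height of 128 passes the check
```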
@@ -13,7 +13,7 @@ from batdetect2.preprocess.audio import (
)
from batdetect2.preprocess.spectrogram import (
    AmplitudeScaleConfig,
    FFTConfig,
    STFTConfig,
    FrequencyConfig,
    LogScaleConfig,
    PcenScaleConfig,
@@ -27,7 +27,7 @@ __all__ = [
    "AudioConfig",
    "ResampleConfig",
    "SpectrogramConfig",
    "FFTConfig",
    "STFTConfig",
    "FrequencyConfig",
    "PcenScaleConfig",
    "LogScaleConfig",
batdetect2/preprocess/arrays.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import numpy as np


def extend_width(
    array: np.ndarray,
    extra: int,
    axis: int = -1,
    value: float = 0,
) -> np.ndarray:
    dims = len(array.shape)
    axis = axis % dims
    pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)]
    return np.pad(
        array,
        pad,
        mode="constant",
        constant_values=value,
    )


def make_width_divisible(
    array: np.ndarray,
    factor: int,
    axis: int = -1,
    value: float = 0,
) -> np.ndarray:
    width = array.shape[axis]

    if width % factor == 0:
        return array

    extra = (-width) % factor
    return extend_width(array, extra, axis=axis, value=value)


def adjust_width(
    array: np.ndarray,
    width: int,
    axis: int = -1,
    value: float = 0,
) -> np.ndarray:
    dims = len(array.shape)
    axis = axis % dims
    current_width = array.shape[axis]

    if current_width == width:
        return array

    if current_width < width:
        return extend_width(
            array,
            extra=width - current_width,
            axis=axis,
            value=value,
        )

    slices = [
        slice(None, None) if index != axis else slice(None, width)
        for index in range(dims)
    ]
    return array[tuple(slices)]
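Editor's note (not part of the diff): a usage sketch for the new NumPy helpers above; the array shapes are illustrative only.

```python
import numpy as np

from batdetect2.preprocess.arrays import adjust_width, make_width_divisible

spec = np.random.rand(128, 500)                 # (frequency, time)

padded = make_width_divisible(spec, factor=32)  # zero-pads the last axis up to 512
assert padded.shape == (128, 512)

cropped = adjust_width(padded, 500)             # trims (or pads) back to an exact width
assert cropped.shape == (128, 500)
```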
@ -12,7 +12,7 @@ from soundevent.arrays import operations as ops
|
||||
from batdetect2.configs import BaseConfig
|
||||
|
||||
|
||||
class FFTConfig(BaseConfig):
|
||||
class STFTConfig(BaseConfig):
|
||||
window_duration: float = Field(default=0.002, gt=0)
|
||||
window_overlap: float = Field(default=0.75, ge=0, lt=1)
|
||||
window_fn: str = "hann"
|
||||
@ -24,9 +24,15 @@ class FrequencyConfig(BaseConfig):
|
||||
|
||||
|
||||
class SpecSizeConfig(BaseConfig):
|
||||
height: int = 256
|
||||
height: int = 128
|
||||
"""Height of the spectrogram in pixels. This value determines the
|
||||
number of frequency bands and corresponds to the vertical dimension
|
||||
of the spectrogram."""
|
||||
|
||||
resize_factor: Optional[float] = 0.5
|
||||
divide_factor: Optional[int] = 32
|
||||
"""Factor by which to resize the spectrogram along the time axis.
|
||||
A value of 0.5 reduces the temporal dimension by half, while a
|
||||
value of 2.0 doubles it. If None, no resizing is performed."""
|
||||
|
||||
|
||||
class LogScaleConfig(BaseConfig):
|
||||
@ -50,13 +56,13 @@ Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]
|
||||
|
||||
|
||||
class SpectrogramConfig(BaseConfig):
|
||||
fft: FFTConfig = Field(default_factory=FFTConfig)
|
||||
stft: STFTConfig = Field(default_factory=STFTConfig)
|
||||
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
|
||||
scale: Scales = Field(
|
||||
default_factory=PcenScaleConfig,
|
||||
discriminator="name",
|
||||
)
|
||||
size: SpecSizeConfig = Field(default_factory=SpecSizeConfig)
|
||||
size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
|
||||
denoise: bool = True
|
||||
max_scale: bool = False
|
||||
|
||||
@ -68,22 +74,11 @@ def compute_spectrogram(
|
||||
) -> xr.DataArray:
|
||||
config = config or SpectrogramConfig()
|
||||
|
||||
if config.size.divide_factor:
|
||||
# Need to pad the audio to make sure the spectrogram has a
|
||||
# width compatible with the divide factor
|
||||
resize_factor = config.size.resize_factor or 1
|
||||
wav = pad_audio(
|
||||
wav,
|
||||
window_duration=config.fft.window_duration,
|
||||
window_overlap=config.fft.window_overlap,
|
||||
divide_factor=int(config.size.divide_factor / resize_factor),
|
||||
)
|
||||
|
||||
spec = stft(
|
||||
wav,
|
||||
window_duration=config.fft.window_duration,
|
||||
window_overlap=config.fft.window_overlap,
|
||||
window_fn=config.fft.window_fn,
|
||||
window_duration=config.stft.window_duration,
|
||||
window_overlap=config.stft.window_overlap,
|
||||
window_fn=config.stft.window_fn,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
@ -98,11 +93,12 @@ def compute_spectrogram(
|
||||
if config.denoise:
|
||||
spec = denoise_spectrogram(spec)
|
||||
|
||||
spec = resize_spectrogram(
|
||||
spec,
|
||||
height=config.size.height,
|
||||
resize_factor=config.size.resize_factor,
|
||||
)
|
||||
if config.size:
|
||||
spec = resize_spectrogram(
|
||||
spec,
|
||||
height=config.size.height,
|
||||
resize_factor=config.size.resize_factor,
|
||||
)
|
||||
|
||||
if config.max_scale:
|
||||
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
|
||||
@ -257,7 +253,7 @@ def resize_spectrogram(
|
||||
return ops.resize(
|
||||
spec,
|
||||
time=int(resize_factor * current_width),
|
||||
frequency=int(resize_factor * height),
|
||||
frequency=height,
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
@ -285,43 +281,6 @@ def adjust_spectrogram_width(
|
||||
return resized
|
||||
|
||||
|
||||
def pad_audio(
|
||||
wave: xr.DataArray,
|
||||
window_duration: float,
|
||||
window_overlap: float,
|
||||
divide_factor: int = 32,
|
||||
) -> xr.DataArray:
|
||||
current_duration = arrays.get_dim_width(wave, dim="time")
|
||||
step = arrays.get_dim_step(wave, dim="time")
|
||||
samplerate = int(1 / step)
|
||||
|
||||
estimated_spec_width = duration_to_spec_width(
|
||||
current_duration,
|
||||
samplerate=samplerate,
|
||||
window_duration=window_duration,
|
||||
window_overlap=window_overlap,
|
||||
)
|
||||
|
||||
if estimated_spec_width % divide_factor == 0:
|
||||
return wave
|
||||
|
||||
target_spec_width = int(
|
||||
np.ceil(estimated_spec_width / divide_factor) * divide_factor
|
||||
)
|
||||
target_samples = spec_width_to_samples(
|
||||
target_spec_width,
|
||||
samplerate=samplerate,
|
||||
window_duration=window_duration,
|
||||
window_overlap=window_overlap,
|
||||
)
|
||||
return ops.adjust_dim_width(
|
||||
wave,
|
||||
dim="time",
|
||||
width=target_samples,
|
||||
position="start",
|
||||
)
|
||||
|
||||
|
||||
def duration_to_spec_width(
|
||||
duration: float,
|
||||
samplerate: int,
|
||||
@ -357,6 +316,8 @@ def get_spectrogram_resolution(
|
||||
|
||||
spec_height = config.size.height
|
||||
resize_factor = config.size.resize_factor or 1
|
||||
freq_bin_width = (max_freq - min_freq) / (spec_height * resize_factor)
|
||||
hop_duration = config.fft.window_duration * (1 - config.fft.window_overlap)
|
||||
freq_bin_width = (max_freq - min_freq) / spec_height
|
||||
hop_duration = config.stft.window_duration * (
|
||||
1 - config.stft.window_overlap
|
||||
)
|
||||
return freq_bin_width, hop_duration / resize_factor
|
||||
|
76
batdetect2/preprocess/tensors.py
Normal file
76
batdetect2/preprocess/tensors.py
Normal file
@ -0,0 +1,76 @@
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def extend_width(
|
||||
array: Union[np.ndarray, torch.Tensor],
|
||||
extra: int,
|
||||
axis: int = -1,
|
||||
value: float = 0,
|
||||
) -> torch.Tensor:
|
||||
if not isinstance(array, torch.Tensor):
|
||||
array = torch.Tensor(array)
|
||||
|
||||
dims = len(array.shape)
|
||||
axis = axis % dims
|
||||
pad = [
|
||||
[0, 0] if index != axis else [0, extra]
|
||||
for index in range(axis, dims)[::-1]
|
||||
]
|
||||
return F.pad(
|
||||
array,
|
||||
[x for y in pad for x in y],
|
||||
value=value,
|
||||
)
|
||||
|
||||
|
||||
def make_width_divisible(
|
||||
array: Union[np.ndarray, torch.Tensor],
|
||||
factor: int,
|
||||
axis: int = -1,
|
||||
value: float = 0,
|
||||
) -> torch.Tensor:
|
||||
if not isinstance(array, torch.Tensor):
|
||||
array = torch.Tensor(array)
|
||||
|
||||
width = array.shape[axis]
|
||||
|
||||
if width % factor == 0:
|
||||
return array
|
||||
|
||||
extra = (-width) % factor
|
||||
return extend_width(array, extra, axis=axis, value=value)
|
||||
|
||||
|
||||
def adjust_width(
|
||||
array: Union[np.ndarray, torch.Tensor],
|
||||
width: int,
|
||||
axis: int = -1,
|
||||
value: float = 0,
|
||||
) -> torch.Tensor:
|
||||
if not isinstance(array, torch.Tensor):
|
||||
array = torch.Tensor(array)
|
||||
|
||||
dims = len(array.shape)
|
||||
axis = axis % dims
|
||||
current_width = array.shape[axis]
|
||||
|
||||
if current_width == width:
|
||||
return array
|
||||
|
||||
if current_width < width:
|
||||
return extend_width(
|
||||
array,
|
||||
extra=width - current_width,
|
||||
axis=axis,
|
||||
value=value,
|
||||
)
|
||||
|
||||
slices = [
|
||||
slice(None, None) if index != axis else slice(None, width)
|
||||
for index in range(dims)
|
||||
]
|
||||
return array[tuple(slices)]
|
@ -4,10 +4,10 @@ import numpy as np
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from soundevent import arrays
|
||||
from soundevent.arrays import operations as ops
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
|
||||
from batdetect2.preprocess.arrays import adjust_width
|
||||
|
||||
Augmentation = Callable[[xr.Dataset], xr.Dataset]
|
||||
|
||||
@ -17,47 +17,12 @@ class AugmentationConfig(BaseConfig):
|
||||
probability: float = 0.2
|
||||
|
||||
|
||||
class SubclipConfig(BaseConfig):
|
||||
enable: bool = True
|
||||
duration: Optional[float] = None
|
||||
width: Optional[int] = 512
|
||||
|
||||
|
||||
def adjust_dataset_width(
|
||||
example: xr.Dataset,
|
||||
duration: Optional[float] = None,
|
||||
width: Optional[int] = None,
|
||||
) -> xr.Dataset:
|
||||
step = arrays.get_dim_step(example, "time") # type: ignore
|
||||
|
||||
if width is None:
|
||||
if duration is None:
|
||||
raise ValueError("Either duration or width must be provided")
|
||||
|
||||
width = int(np.floor(duration / step))
|
||||
|
||||
adjusted_arrays = {
|
||||
name: ops.adjust_dim_width(array, "time", width)
|
||||
for name, array in example.items()
|
||||
if name != "audio"
|
||||
}
|
||||
|
||||
ratio = width / example.spectrogram.sizes["time"]
|
||||
audio_width = int(example.audio.sizes["audio_time"] * ratio)
|
||||
adjusted_arrays["audio"] = ops.adjust_dim_width(
|
||||
example["audio"],
|
||||
"audio_time",
|
||||
audio_width,
|
||||
)
|
||||
|
||||
return xr.Dataset(data_vars=adjusted_arrays)
|
||||
|
||||
|
||||
def select_random_subclip(
|
||||
def select_subclip(
|
||||
example: xr.Dataset,
|
||||
start_time: Optional[float] = None,
|
||||
duration: Optional[float] = None,
|
||||
width: Optional[int] = None,
|
||||
random: bool = False,
|
||||
) -> xr.Dataset:
|
||||
"""Select a random subclip from a clip."""
|
||||
step = arrays.get_dim_step(example, "time") # type: ignore
|
||||
@ -73,10 +38,13 @@ def select_random_subclip(
|
||||
duration = width * step
|
||||
|
||||
if start_time is None:
|
||||
start_time = np.random.uniform(start, max(stop - duration, start))
|
||||
if random:
|
||||
start_time = np.random.uniform(start, max(stop - duration, start))
|
||||
else:
|
||||
start_time = start
|
||||
|
||||
if start_time + duration > stop:
|
||||
example = adjust_dataset_width(example, width=width)
|
||||
return example
|
||||
|
||||
start_index = arrays.get_coord_index(
|
||||
example, # type: ignore
|
||||
@ -91,7 +59,7 @@ def select_random_subclip(
|
||||
|
||||
return example.sel(
|
||||
time=slice(start_time, end_time),
|
||||
audio_time=slice(start_time, end_time),
|
||||
audio_time=slice(start_time, end_time + step),
|
||||
)
|
||||
|
||||
|
||||
@ -115,40 +83,45 @@ def mix_examples(
|
||||
weight = np.random.uniform(min_weight, max_weight)
|
||||
|
||||
audio1 = example["audio"]
|
||||
|
||||
audio2 = ops.adjust_dim_width(
|
||||
other["audio"], "audio_time", len(audio1)
|
||||
).values
|
||||
|
||||
if len(audio2) > len(audio1):
|
||||
audio2 = audio2[: len(audio1)]
|
||||
audio2 = adjust_width(other["audio"].values, len(audio1))
|
||||
|
||||
combined = weight * audio1 + (1 - weight) * audio2
|
||||
|
||||
spec = compute_spectrogram(
|
||||
spectrogram = compute_spectrogram(
|
||||
combined.rename({"audio_time": "time"}),
|
||||
config=config.spectrogram,
|
||||
)
|
||||
).data
|
||||
|
||||
# NOTE: The subclip's spectrogram might be slightly longer than the
|
||||
# spectrogram computed from the subclip's audio. This is due to a
|
||||
# simplification in the subclip process: It doesn't account for the
|
||||
# spectrogram parameters to precisely determine the corresponding audio
|
||||
# samples. To work around this, we pad the computed spectrogram with zeros
|
||||
# as needed.
|
||||
previous_width = len(example["time"])
|
||||
spectrogram = adjust_width(spectrogram, previous_width)
|
||||
|
||||
detection_heatmap = xr.apply_ufunc(
|
||||
np.maximum,
|
||||
example["detection"],
|
||||
other["detection"].values,
|
||||
adjust_width(other["detection"].values, previous_width),
|
||||
)
|
||||
|
||||
class_heatmap = xr.apply_ufunc(
|
||||
np.maximum,
|
||||
example["class"],
|
||||
other["class"].values,
|
||||
adjust_width(other["class"].values, previous_width),
|
||||
)
|
||||
|
||||
size_heatmap = example["size"] + other["size"].values
|
||||
size_heatmap = example["size"] + adjust_width(
|
||||
other["size"].values, previous_width
|
||||
)
|
||||
|
||||
return xr.Dataset(
|
||||
{
|
||||
"audio": combined,
|
||||
"spectrogram": xr.DataArray(
|
||||
data=spec.data,
|
||||
data=spectrogram,
|
||||
dims=example["spectrogram"].dims,
|
||||
coords=example["spectrogram"].coords,
|
||||
),
|
||||
@ -192,12 +165,23 @@ def add_echo(
|
||||
spectrogram = compute_spectrogram(
|
||||
audio.rename({"audio_time": "time"}),
|
||||
config=config.spectrogram,
|
||||
).data
|
||||
|
||||
# NOTE: The subclip's spectrogram might be slightly longer than the
|
||||
# spectrogram computed from the subclip's audio. This is due to a
|
||||
# simplification in the subclip process: It doesn't account for the
|
||||
# spectrogram parameters to precisely determine the corresponding audio
|
||||
# samples. To work around this, we pad the computed spectrogram with zeros
|
||||
# as needed.
|
||||
spectrogram = adjust_width(
|
||||
spectrogram,
|
||||
example["spectrogram"].sizes["time"],
|
||||
)
|
||||
|
||||
return example.assign(
|
||||
audio=audio,
|
||||
spectrogram=xr.DataArray(
|
||||
data=spectrogram.data,
|
||||
data=spectrogram,
|
||||
dims=example["spectrogram"].dims,
|
||||
coords=example["spectrogram"].coords,
|
||||
),
|
||||
@ -359,7 +343,6 @@ def mask_frequency(
|
||||
|
||||
|
||||
class AugmentationsConfig(BaseConfig):
|
||||
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
|
||||
mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig)
|
||||
echo: EchoAugmentationConfig = Field(
|
||||
default_factory=EchoAugmentationConfig
|
||||
@ -391,23 +374,8 @@ def augment_example(
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
others: Optional[Callable[[], xr.Dataset]] = None,
|
||||
) -> xr.Dataset:
|
||||
if config.subclip.enable:
|
||||
example = select_random_subclip(
|
||||
example,
|
||||
duration=config.subclip.duration,
|
||||
width=config.subclip.width,
|
||||
)
|
||||
|
||||
if should_apply(config.mix) and (others is not None):
|
||||
other = others()
|
||||
|
||||
if config.subclip.enable:
|
||||
other = select_random_subclip(
|
||||
other,
|
||||
duration=config.subclip.duration,
|
||||
width=config.subclip.width,
|
||||
)
|
||||
|
||||
example = mix_examples(
|
||||
example,
|
||||
other,
|
||||
|
@ -5,10 +5,17 @@ from typing import NamedTuple, Optional, Sequence, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from batdetect2.train.augmentations import AugmentationsConfig, augment_example
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess.tensors import adjust_width
|
||||
from batdetect2.train.augmentations import (
|
||||
AugmentationsConfig,
|
||||
augment_example,
|
||||
select_subclip,
|
||||
)
|
||||
from batdetect2.train.preprocess import PreprocessingConfig
|
||||
|
||||
__all__ = [
|
||||
@ -28,20 +35,36 @@ class TrainExample(NamedTuple):
|
||||
idx: torch.Tensor
|
||||
|
||||
|
||||
class SubclipConfig(BaseConfig):
|
||||
duration: Optional[float] = None
|
||||
width: int = 512
|
||||
random: bool = False
|
||||
|
||||
|
||||
class DatasetConfig(BaseConfig):
|
||||
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
augmentation: AugmentationsConfig = Field(
|
||||
default_factory=AugmentationsConfig
|
||||
)
|
||||
|
||||
|
||||
class LabeledDataset(Dataset):
|
||||
config: DatasetConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filenames: Sequence[PathLike],
|
||||
augment: bool = False,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
augmentation_config: Optional[AugmentationsConfig] = None,
|
||||
subclip: bool = False,
|
||||
config: Optional[DatasetConfig] = None,
|
||||
):
|
||||
self.filenames = filenames
|
||||
self.augment = augment
|
||||
self.preprocessing_config = (
|
||||
preprocessing_config or PreprocessingConfig()
|
||||
)
|
||||
self.agumentation_config = augmentation_config or AugmentationsConfig()
|
||||
self.subclip = subclip
|
||||
self.config = config or DatasetConfig()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.filenames)
|
||||
@ -49,27 +72,27 @@ class LabeledDataset(Dataset):
|
||||
def __getitem__(self, idx) -> TrainExample:
|
||||
dataset = self.get_dataset(idx)
|
||||
|
||||
if self.subclip:
|
||||
dataset = select_subclip(
|
||||
dataset,
|
||||
duration=self.config.subclip.duration,
|
||||
width=self.config.subclip.width,
|
||||
random=self.config.subclip.random,
|
||||
)
|
||||
|
||||
if self.augment:
|
||||
dataset = augment_example(
|
||||
dataset,
|
||||
self.agumentation_config,
|
||||
preprocessing_config=self.preprocessing_config,
|
||||
self.config.augmentation,
|
||||
preprocessing_config=self.config.preprocessing,
|
||||
others=self.get_random_example,
|
||||
)
|
||||
|
||||
return TrainExample(
|
||||
spec=torch.tensor(
|
||||
dataset["spectrogram"].values.astype(np.float32)
|
||||
).unsqueeze(0),
|
||||
detection_heatmap=torch.tensor(
|
||||
dataset["detection"].values.astype(np.float32)
|
||||
),
|
||||
class_heatmap=torch.tensor(
|
||||
dataset["class"].values.astype(np.float32)
|
||||
),
|
||||
size_heatmap=torch.tensor(
|
||||
dataset["size"].values.astype(np.float32)
|
||||
),
|
||||
spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
|
||||
detection_heatmap=self.to_tensor(dataset["detection"]),
|
||||
class_heatmap=self.to_tensor(dataset["class"]),
|
||||
size_heatmap=self.to_tensor(dataset["size"]),
|
||||
idx=torch.tensor(idx),
|
||||
)
|
||||
|
||||
@ -78,20 +101,30 @@ class LabeledDataset(Dataset):
|
||||
cls,
|
||||
directory: PathLike,
|
||||
extension: str = ".nc",
|
||||
config: Optional[DatasetConfig] = None,
|
||||
augment: bool = False,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
augmentation_config: Optional[AugmentationsConfig] = None,
|
||||
subclip: bool = False,
|
||||
):
|
||||
return cls(
|
||||
get_files(directory, extension),
|
||||
config=config,
|
||||
augment=augment,
|
||||
preprocessing_config=preprocessing_config,
|
||||
augmentation_config=augmentation_config,
|
||||
subclip=subclip,
|
||||
)
|
||||
|
||||
def get_random_example(self) -> xr.Dataset:
|
||||
idx = np.random.randint(0, len(self))
|
||||
return self.get_dataset(idx)
|
||||
dataset = self.get_dataset(idx)
|
||||
|
||||
if self.subclip:
|
||||
dataset = select_subclip(
|
||||
dataset,
|
||||
duration=self.config.subclip.duration,
|
||||
width=self.config.subclip.width,
|
||||
random=self.config.subclip.random,
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
def get_dataset(self, idx) -> xr.Dataset:
|
||||
return xr.open_dataset(self.filenames[idx])
|
||||
@ -101,6 +134,19 @@ class LabeledDataset(Dataset):
|
||||
self.get_dataset(idx).attrs["clip_annotation"]
|
||||
)
|
||||
|
||||
def to_tensor(
|
||||
self,
|
||||
array: xr.DataArray,
|
||||
dtype=np.float32,
|
||||
) -> torch.Tensor:
|
||||
tensor = torch.tensor(array.values.astype(dtype))
|
||||
|
||||
if not self.subclip:
|
||||
return tensor
|
||||
|
||||
width = self.config.subclip.width
|
||||
return adjust_width(tensor, width)
|
||||
|
||||
|
||||
def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
|
||||
return list(Path(directory).glob(f"*{extension}"))
|
||||
|
@ -1,16 +1,56 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
auc,
|
||||
balanced_accuracy_score,
|
||||
roc_curve,
|
||||
)
|
||||
from sklearn.metrics import auc, roc_curve
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_geometries
|
||||
|
||||
|
||||
def match_predictions_and_annotations(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
clip_prediction: data.ClipPrediction,
|
||||
) -> List[data.Match]:
|
||||
annotated_sound_events = [
|
||||
sound_event_annotation
|
||||
for sound_event_annotation in clip_annotation.sound_events
|
||||
if sound_event_annotation.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
predicted_sound_events = [
|
||||
sound_event_prediction
|
||||
for sound_event_prediction in clip_prediction.sound_events
|
||||
if sound_event_prediction.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
annotated_geometries: List[data.Geometry] = [
|
||||
sound_event.sound_event.geometry
|
||||
for sound_event in annotated_sound_events
|
||||
if sound_event.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
predicted_geometries: List[data.Geometry] = [
|
||||
sound_event.sound_event.geometry
|
||||
for sound_event in predicted_sound_events
|
||||
if sound_event.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
matches = []
|
||||
for id1, id2, affinity in match_geometries(
|
||||
annotated_geometries,
|
||||
predicted_geometries,
|
||||
):
|
||||
target = annotated_sound_events[id1] if id1 is not None else None
|
||||
source = predicted_sound_events[id2] if id2 is not None else None
|
||||
matches.append(
|
||||
data.Match(source=source, target=target, affinity=affinity)
|
||||
)
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def compute_error_auc(op_str, gt, pred, prob):
|
||||
|
||||
# classification error
|
||||
pred_int = (pred > prob).astype(np.int)
|
||||
pred_int = (pred > prob).astype(np.int32)
|
||||
class_acc = (pred_int == gt).mean() * 100.0
|
||||
|
||||
# ROC - area under curve
|
||||
@ -25,7 +65,6 @@ def compute_error_auc(op_str, gt, pred, prob):
|
||||
|
||||
|
||||
def calc_average_precision(recall, precision):
|
||||
|
||||
precision[np.isnan(precision)] = 0
|
||||
recall[np.isnan(recall)] = 0
|
||||
|
||||
@ -91,7 +130,6 @@ def compute_pre_rec(
|
||||
pred_class = []
|
||||
file_ids = []
|
||||
for pid, pp in enumerate(preds):
|
||||
|
||||
# filter predicted calls that are too near the start or end of the file
|
||||
file_dur = gts[pid]["duration"]
|
||||
valid_inds = (pp["start_times"] >= ignore_start_end) & (
|
||||
@ -141,7 +179,6 @@ def compute_pre_rec(
|
||||
gt_generic_class = []
|
||||
num_positives = 0
|
||||
for gg in gts:
|
||||
|
||||
# filter ground truth calls that are too near the start or end of the file
|
||||
file_dur = gg["duration"]
|
||||
valid_inds = (gg["start_times"] >= ignore_start_end) & (
|
||||
@ -205,7 +242,6 @@ def compute_pre_rec(
|
||||
|
||||
# valid detection that has not already been assigned
|
||||
if valid_det and (gt_assigned[gt_id][det_ind] == 0):
|
||||
|
||||
count_as_true_pos = True
|
||||
if eval_mode == "top_class" and (
|
||||
gt_class[gt_id][det_ind] != pred_class[ind]
|
||||
|
@ -3,7 +3,9 @@ from typing import Optional
|
||||
import pytorch_lightning as L
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from torch import optim
|
||||
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from batdetect2.configs import BaseConfig
from batdetect2.models import (
@ -13,11 +15,20 @@ from batdetect2.models import (
    get_backbone,
)
from batdetect2.models.typing import ModelOutput
from batdetect2.post_process import PostprocessConfig
from batdetect2.post_process import (
    PostprocessConfig,
    postprocess_model_outputs,
)
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.train.dataset import TrainExample
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.evaluate import match_predictions_and_annotations
from batdetect2.train.losses import LossConfig, compute_loss
from batdetect2.train.targets import TargetConfig
from batdetect2.train.targets import (
    TargetConfig,
    build_decoder,
    build_encoder,
    get_class_names,
)


class OptimizerConfig(BaseConfig):
@ -55,10 +66,8 @@ class DetectorModel(L.LightningModule):
        self.save_hyperparameters()

        size = self.config.preprocessing.spectrogram.size
        self.backbone = get_backbone(
            input_height=int(size.height * (size.resize_factor or 1)),
            config=self.config.backbone,
        )
        assert size is not None
        self.backbone = get_backbone(self.config.backbone)

        self.classifier = ClassifierHead(
            num_classes=len(self.config.targets.classes),
@ -74,6 +83,13 @@ class DetectorModel(L.LightningModule):

        self.validation_predictions = []

        self.class_names = get_class_names(self.config.targets.classes)
        self.encoder = build_encoder(
            self.config.targets.classes,
            replacement_rules=self.config.targets.replace,
        )
        self.decoder = build_decoder(self.config.targets.classes)

    def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
        features = self.backbone(spec)
        detection_probs, classification_probs = self.classifier(features)
@ -117,14 +133,29 @@ class DetectorModel(L.LightningModule):
        self.log("val/loss/classification", losses.total, logger=True)

        dataloaders = self.trainer.val_dataloaders
        print(dataloaders)
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, LabeledDataset)
        clip_annotation = dataset.get_clip_annotation(batch_idx)

        clip_prediction = postprocess_model_outputs(
            outputs,
            clips=[clip_annotation.clip],
            classes=self.class_names,
            decoder=self.decoder,
            config=self.config.postprocessing,
        )[0]

        self.validation_predictions.extend(
            match_predictions_and_annotations(clip_annotation, clip_prediction)
        )

    def on_validation_epoch_end(self) -> None:
        print(len(self.validation_predictions))
        self.validation_predictions.clear()

    def configure_optimizers(self):
        conf = self.config.train.optimizer
        optimizer = optim.Adam(self.parameters(), lr=conf.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=conf.t_max,
        )
        optimizer = Adam(self.parameters(), lr=conf.learning_rate)
        scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
        return [optimizer], [scheduler]
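For reference, the optimizer setup above reduces to the following standalone sketch. The field names learning_rate and t_max come from how conf is used in the hunk; the concrete values and the parameter list are illustrative only, not part of the commit.

import torch
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

# Stand-in parameters; in the module these come from self.parameters().
params = [torch.nn.Parameter(torch.zeros(1))]

learning_rate = 1e-3  # assumed value; conf.learning_rate in the module
t_max = 100  # assumed value; conf.t_max in the module

optimizer = Adam(params, lr=learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=t_max)

# Lightning accepts a ([optimizers], [schedulers]) pair from configure_optimizers.
optimizers, schedulers = [optimizer], [scheduler]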
@ -101,7 +101,11 @@ def generate_train_example(
    return dataset.assign_attrs(
        title=f"Training example for {clip_annotation.uuid}",
        config=config.model_dump_json(),
        clip_annotation=clip_annotation.model_dump_json(),
        clip_annotation=clip_annotation.model_dump_json(
            exclude_none=True,
            exclude_defaults=True,
            exclude_unset=True,
        ),
    )
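The added exclude_* flags trim the stored clip_annotation attribute down to the fields that were explicitly provided. A minimal pydantic sketch of the difference (the model and its fields below are made up for illustration, not taken from batdetect2):

from typing import Optional

from pydantic import BaseModel


class Example(BaseModel):
    name: str
    notes: Optional[str] = None
    version: int = 1


example = Example(name="clip")

# All fields, including defaults and None values.
print(example.model_dump_json())
# {"name":"clip","notes":null,"version":1}

# Only what was explicitly set on the instance.
print(example.model_dump_json(exclude_none=True, exclude_defaults=True, exclude_unset=True))
# {"name":"clip"}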
@ -1,11 +0,0 @@
from pathlib import Path

from batdetect2.configs import load_config
from batdetect2.data import DatasetsConfig, load_datasets


def test_can_load_dataset_configs():
    root = Path(__file__).parent.parent
    path = root / "conf.yaml"
    config = load_config(path, schema=DatasetsConfig, field="datasets")
    load_datasets(config)
@ -4,7 +4,7 @@ from pathlib import Path

from soundevent import data

from batdetect2.data.compat import load_annotation_project
from batdetect2.compat.data import load_annotation_project_from_dir

ROOT_DIR = Path(__file__).parent.parent.parent

@ -12,7 +12,7 @@ ROOT_DIR = Path(__file__).parent.parent.parent
def test_load_example_annotation_project():
    path = ROOT_DIR / "example_data" / "anns"
    audio_dir = ROOT_DIR / "example_data" / "audio"
    project = load_annotation_project(path, audio_dir=audio_dir)
    project = load_annotation_project_from_dir(path, audio_dir=audio_dir)
    assert isinstance(project, data.AnnotationProject)
    assert project.name == str(path)
    assert len(project.clip_annotations) == 3
@ -88,7 +88,7 @@ def test_spectrogram_generation_hasnt_changed(
    scale = preprocess.AmplitudeScaleConfig()

    config = preprocess.SpectrogramConfig(
        fft=preprocess.FFTConfig(
        stft=preprocess.STFTConfig(
            window_overlap=fft_overlap,
            window_duration=fft_win_length,
        ),

New file: tests/test_models/__init__.py (0 lines)
New file: tests/test_models/test_inputs.py (34 lines)
@ -0,0 +1,34 @@
import torch
from hypothesis import given
from hypothesis import strategies as st

from batdetect2.models import ModelConfig, ModelType, get_backbone


@given(
    input_width=st.integers(min_value=50, max_value=1500),
    input_height=st.integers(min_value=1, max_value=16),
    model_type=st.sampled_from(ModelType),
)
def test_model_can_process_spectrograms_of_any_width(
    input_width,
    input_height,
    model_type,
):
    # Input height must be divisible by 8
    input_height = 8 * input_height

    input = torch.rand([1, 1, input_height, input_width])

    model = get_backbone(
        config=ModelConfig(
            name=model_type,  # type: ignore
            input_height=input_height,
        ),
    )

    output = model(input)
    assert output.shape[0] == 1
    assert output.shape[1] == model.out_channels
    assert output.shape[2] == input_height
    assert output.shape[3] == input_width
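The divisible-by-8 restriction in this test presumably reflects the number of stride-2 downsampling stages in the encoder: three stages would give a factor of 2**3 = 8, and a height that is not a multiple of that factor could not be restored exactly on the way back up. This is an inference from the test, not something the diff states:

# Assumption: three 2x downsampling stages in the encoder.
downscaling_factor = 2 ** 3
assert downscaling_factor == 8
# Every height the test generates (8 * k for k in 1..16) is compatible.
assert all((8 * k) % downscaling_factor == 0 for k in range(1, 17))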
New file: tests/test_preprocessing/test_arrays.py (23 lines)
@ -0,0 +1,23 @@
import numpy as np

from batdetect2.preprocess.arrays import adjust_width, extend_width


def test_extend_width():
    array = np.random.random([1, 1, 128, 100])

    extended = extend_width(array, 100)

    assert extended.shape == (1, 1, 128, 200)


def test_can_adjust_short_width():
    array = np.random.random([1, 1, 128, 100])
    extended = adjust_width(array, 512)
    assert extended.shape == (1, 1, 128, 512)


def test_can_adjust_long_width():
    array = np.random.random([1, 1, 128, 512])
    extended = adjust_width(array, 256)
    assert extended.shape == (1, 1, 128, 256)
@ -8,14 +8,13 @@ from soundevent import arrays

from batdetect2.preprocess.audio import AudioConfig, load_file_audio
from batdetect2.preprocess.spectrogram import (
    FFTConfig,
    STFTConfig,
    FrequencyConfig,
    SpecSizeConfig,
    SpectrogramConfig,
    compute_spectrogram,
    duration_to_spec_width,
    get_spectrogram_resolution,
    pad_audio,
    spec_width_to_samples,
    stft,
)
@ -68,45 +67,6 @@ def test_can_estimate_correctly_spectrogram_width_from_duration(
    )


@settings(
    suppress_health_check=[HealthCheck.function_scoped_fixture],
    deadline=400,
)
@given(
    duration=st.floats(min_value=0.1, max_value=1.0),
    window_duration=st.floats(min_value=0.001, max_value=0.01),
    window_overlap=st.floats(min_value=0.2, max_value=0.9),
    samplerate=st.integers(min_value=256_000, max_value=512_000),
    divide_factor=st.integers(min_value=16, max_value=64),
)
def test_can_pad_audio_to_adjust_spectrogram_width(
    duration: float,
    window_duration: float,
    window_overlap: float,
    samplerate: int,
    divide_factor: int,
    wav_factory: Callable[..., Path],
):
    path = wav_factory(duration=duration, samplerate=samplerate)

    audio = load_file_audio(
        path,
        # NOTE: Don't resample or adjust the duration, so that the width
        # estimation is tested in all scenarios
        config=AudioConfig(resample=None, duration=None),
    )

    audio = pad_audio(
        audio,
        window_duration=window_duration,
        window_overlap=window_overlap,
        divide_factor=divide_factor,
    )

    spectrogram = stft(audio, window_duration, window_overlap)
    assert spectrogram.sizes["time"] % divide_factor == 0


def test_can_estimate_spectrogram_resolution(
    wav_factory: Callable[..., Path],
):
@ -120,7 +80,7 @@ def test_can_estimate_spectrogram_resolution(
    )

    config = SpectrogramConfig(
        fft=FFTConfig(),
        stft=STFTConfig(),
        size=SpecSizeConfig(height=256, resize_factor=0.5),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
    )
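For reference, the property that the removed test_can_pad_audio_to_adjust_spectrogram_width checked rests on standard STFT bookkeeping: with a hop of window_duration * (1 - window_overlap) seconds, the spectrogram gains one column per hop, so padding the waveform to the right number of samples makes the width a multiple of divide_factor. A rough sketch of that estimate (the generic relationship only, not necessarily the exact formula pad_audio or duration_to_spec_width uses):

def approx_spec_width(n_samples: int, samplerate: int, window_duration: float, window_overlap: float) -> int:
    # Hop size in samples between consecutive STFT frames.
    hop = window_duration * (1 - window_overlap) * samplerate
    return int(n_samples / hop)


def samples_for_divisible_width(n_samples: int, samplerate: int, window_duration: float, window_overlap: float, divide_factor: int) -> int:
    # Round the estimated width up to the next multiple of divide_factor
    # and convert back to a number of samples.
    width = approx_spec_width(n_samples, samplerate, window_duration, window_overlap)
    target_width = -(-width // divide_factor) * divide_factor
    hop = window_duration * (1 - window_overlap) * samplerate
    return int(target_width * hop)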
New file: tests/test_preprocessing/test_tensors.py (42 lines)
@ -0,0 +1,42 @@
import numpy as np
import torch

from batdetect2.preprocess.tensors import adjust_width, make_width_divisible


def test_width_is_divisible_after_adjustment():
    tensor = torch.rand([1, 1, 128, 374])
    adjusted = make_width_divisible(tensor, 32)
    assert adjusted.shape[-1] % 32 == 0
    assert adjusted.shape == (1, 1, 128, 384)


def test_non_last_axis_is_divisible_after_adjustment():
    tensor = torch.rand([1, 1, 77, 124])
    adjusted = make_width_divisible(tensor, 32, axis=-2)
    assert adjusted.shape[-2] % 32 == 0
    assert adjusted.shape == (1, 1, 96, 124)


def test_make_width_divisible_can_handle_numpy_array():
    array = np.random.random([1, 1, 128, 374])
    adjusted = make_width_divisible(array, 32)
    assert adjusted.shape[-1] % 32 == 0
    assert adjusted.shape == (1, 1, 128, 384)
    assert isinstance(adjusted, torch.Tensor)


def test_adjust_last_axis_width_by_default():
    tensor = torch.rand([1, 1, 128, 374])
    adjusted = adjust_width(tensor, 512)
    assert adjusted.shape == (1, 1, 128, 512)
    assert (tensor == adjusted[:, :, :, :374]).all()
    assert (adjusted[:, :, :, 374:] == 0).all()


def test_can_adjust_second_to_last_axis():
    tensor = torch.rand([1, 1, 89, 512])
    adjusted = adjust_width(tensor, 128, axis=-2)
    assert adjusted.shape == (1, 1, 128, 512)
    assert (tensor == adjusted[:, :, :89, :]).all()
    assert (adjusted[:, :, 89:, :] == 0).all()
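Judging only by what these new tests assert (zero padding at the end of the chosen axis, cropping when the input is too long, numpy input coming back as a tensor), the two helpers behave roughly like the sketch below; the real implementations in batdetect2.preprocess.tensors may differ in detail:

import torch


def adjust_width_sketch(x, width: int, axis: int = -1) -> torch.Tensor:
    # Accept numpy arrays as well as tensors, as the tests expect.
    x = torch.as_tensor(x)
    axis = axis % x.ndim
    current = x.shape[axis]
    if current >= width:
        # Crop from the start of the axis when the input is too long.
        return x.narrow(axis, 0, width)
    # Otherwise zero-pad at the end of the axis.
    pad_shape = list(x.shape)
    pad_shape[axis] = width - current
    return torch.cat([x, x.new_zeros(pad_shape)], dim=axis)


def make_width_divisible_sketch(x, factor: int, axis: int = -1) -> torch.Tensor:
    x = torch.as_tensor(x)
    current = x.shape[axis % x.ndim]
    # Round the current size up to the next multiple of factor.
    target = -(-current // factor) * factor
    return adjust_width_sketch(x, target, axis=axis)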
@ -1,14 +1,14 @@
from collections.abc import Callable

import numpy as np
import pytest
import xarray as xr
from soundevent import data
from soundevent import arrays, data

from batdetect2.train.augmentations import (
    add_echo,
    adjust_dataset_width,
    mix_examples,
    select_random_subclip,
    select_subclip,
)
from batdetect2.train.preprocess import (
    TrainPreprocessingConfig,
@ -23,11 +23,9 @@ def test_mix_examples(
    recording2 = recording_factory()

    clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)

    clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8)

    clip_annotation_1 = data.ClipAnnotation(clip=clip1)

    clip_annotation_2 = data.ClipAnnotation(clip=clip2)

    config = TrainPreprocessingConfig()
@ -43,6 +41,36 @@ def test_mix_examples(
    assert mixed["class"].shape == example1["class"].shape


@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
def test_mix_examples_of_different_durations(
    recording_factory: Callable[..., data.Recording],
    duration1: float,
    duration2: float,
):
    recording1 = recording_factory()
    recording2 = recording_factory()

    clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1)
    clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2)

    clip_annotation_1 = data.ClipAnnotation(clip=clip1)
    clip_annotation_2 = data.ClipAnnotation(clip=clip2)

    config = TrainPreprocessingConfig()

    example1 = generate_train_example(clip_annotation_1, config)
    example2 = generate_train_example(clip_annotation_2, config)

    mixed = mix_examples(example1, example2, config=config.preprocessing)

    # Check the spectrogram has the expected duration
    step = arrays.get_dim_step(mixed["spectrogram"], "time")
    start, stop = arrays.get_dim_range(mixed["spectrogram"], "time")
    assert start == 0
    assert np.isclose(stop + step, duration1, atol=2 * step)


def test_add_echo(
    recording_factory: Callable[..., data.Recording],
):
@ -67,70 +95,23 @@ def test_selected_random_subclip_has_the_correct_width(
    clip_annotation_1 = data.ClipAnnotation(clip=clip1)
    config = TrainPreprocessingConfig()
    original = generate_train_example(clip_annotation_1, config)
    subclip = select_random_subclip(original, width=100)
    subclip = select_subclip(original, width=100)

    assert subclip["spectrogram"].shape[1] == 100


def test_adjust_dataset_width():
    height = 128
    width = 512
    samplerate = 48_000
def test_add_echo_after_subclip(
    recording_factory: Callable[..., data.Recording],
):
    recording1 = recording_factory(duration=2)
    clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
    clip_annotation_1 = data.ClipAnnotation(clip=clip1)
    config = TrainPreprocessingConfig()
    original = generate_train_example(clip_annotation_1, config)

    times = np.linspace(0, 1, width)
    assert original.sizes["time"] > 512

    audio_times = np.linspace(0, 1, samplerate)
    frequency = np.linspace(0, 24_000, height)
    subclip = select_subclip(original, width=512)
    with_echo = add_echo(subclip)

    width_subset = 356
    audio_width_subset = int(samplerate * width_subset / width)

    times_subset = times[:width_subset]
    audio_times_subset = audio_times[:audio_width_subset]
    dimensions = ["width", "height"]
    class_names = [f"species_{i}" for i in range(17)]

    spectrogram = np.random.random([height, width_subset])
    sizes = np.random.random([len(dimensions), height, width_subset])
    classes = np.random.random([len(class_names), height, width_subset])
    audio = np.random.random([int(samplerate * width_subset / width)])

    dataset = xr.Dataset(
        data_vars={
            "audio": (("audio_time",), audio),
            "spectrogram": (("frequency", "time"), spectrogram),
            "sizes": (("dimension", "frequency", "time"), sizes),
            "classes": (("class", "frequency", "time"), classes),
        },
        coords={
            "audio_time": audio_times_subset,
            "time": times_subset,
            "frequency": frequency,
            "dimension": dimensions,
            "class": class_names,
        },
    )

    adjusted = adjust_dataset_width(dataset, width=width)

    # Spectrogram was adjusted correctly
    assert np.isclose(adjusted["spectrogram"].time, times).all()
    assert (adjusted["spectrogram"].frequency == frequency).all()

    # Sizes was adjusted correctly
    assert np.isclose(adjusted["sizes"].time, times).all()
    assert (adjusted["sizes"].frequency == frequency).all()
    assert list(adjusted["sizes"].dimension.values) == dimensions

    # Classes was adjusted correctly
    assert np.isclose(adjusted["classes"].time, times).all()
    assert (adjusted["sizes"].frequency == frequency).all()
    assert list(adjusted["classes"]["class"].values) == class_names

    # Audio time was adjusted correctly
    assert np.isclose(
        len(adjusted["audio"].audio_time), len(audio_times), atol=2
    )
    assert np.isclose(
        adjusted["audio"].audio_time[-1], audio_times[-1], atol=1e-3
    )
    assert with_echo.sizes["time"] == 512
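A short note on the duration check in test_mix_examples_of_different_durations above: assuming arrays.get_dim_range returns the first and last coordinate values of the time axis, a spectrogram with N frames starting at 0 with step dt spans (0, (N - 1) * dt), so stop + step recovers roughly N * dt, i.e. the clip duration; the 2 * step tolerance absorbs the off-by-one rounding introduced when the width is padded or trimmed. For example, with dt = 0.002 s and duration1 = 0.4 s the mixed example should hold about 200 frames, giving stop ≈ 0.398 s and stop + step ≈ 0.4 s.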
@ -3,7 +3,6 @@ from pathlib import Path
import numpy as np
import xarray as xr
from soundevent import data
from soundevent.types import ClassMapper

from batdetect2.train.labels import generate_heatmaps

@ -23,16 +22,6 @@ clip = data.Clip(
)


class Mapper(ClassMapper):
    class_labels = ["bat", "cat"]

    def encode(self, sound_event_annotation: data.SoundEventAnnotation) -> str:
        return "bat"

    def decode(self, label: str) -> list:
        return [data.Tag(term=data.term_from_key("species"), value="bat")]


def test_generated_heatmaps_have_correct_dimensions():
    spec = xr.DataArray(
        data=np.random.rand(100, 100),
@ -57,12 +46,11 @@ def test_generated_heatmaps_have_correct_dimensions():
        ],
    )

    class_mapper = Mapper()

    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
        clip_annotation.sound_events,
        spec,
        class_mapper,
        class_names=["bat", "cat"],
        encoder=lambda _: "bat",
    )

    assert isinstance(detection_heatmap, xr.DataArray)
@ -107,11 +95,13 @@ def test_generated_heatmap_are_non_zero_at_correct_positions():
        ],
    )

    class_mapper = Mapper()
    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
        clip_annotation.sound_events,
        spec,
        class_mapper,
        class_names=["bat", "cat"],
        encoder=lambda _: "bat",
        time_scale=1,
        frequency_scale=1,
    )
    assert size_heatmap.sel(time=10, frequency=10, dimension="width") == 10
    assert size_heatmap.sel(time=10, frequency=10, dimension="height") == 10