Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 22:51:58 +02:00)

Commit 48e009fa9d ("WIP")
Parent: f7d6516550
@@ -1,7 +1,7 @@
 from batdetect2.preprocess import (
     AmplitudeScaleConfig,
     AudioConfig,
-    FFTConfig,
+    STFTConfig,
     FrequencyConfig,
     LogScaleConfig,
     PcenScaleConfig,
@@ -40,7 +40,7 @@ def get_preprocessing_config(params: dict) -> PreprocessingConfig:
             duration=None,
         ),
         spectrogram=SpectrogramConfig(
-            fft=FFTConfig(
+            stft=STFTConfig(
                 window_duration=params["fft_win_length"],
                 window_overlap=params["fft_overlap"],
                 window_fn="hann",
@@ -1,10 +1,12 @@
 from enum import Enum
+from typing import Optional, Tuple
 
 from batdetect2.configs import BaseConfig
 from batdetect2.models.backbones import (
     Net2DFast,
     Net2DFastNoAttn,
     Net2DFastNoCoordConv,
+    Net2DPlain,
 )
 from batdetect2.models.heads import BBoxHead, ClassifierHead
 from batdetect2.models.typing import BackboneModel
@@ -24,31 +26,57 @@ class ModelType(str, Enum):
     Net2DFast = "Net2DFast"
     Net2DFastNoAttn = "Net2DFastNoAttn"
     Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
+    Net2DPlain = "Net2DPlain"
 
 
 class ModelConfig(BaseConfig):
     name: ModelType = ModelType.Net2DFast
-    num_features: int = 32
+    input_height: int = 128
+    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
+    bottleneck_channels: int = 256
+    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
+    out_channels: int = 32
 
 
 def get_backbone(
-    config: ModelConfig,
-    input_height: int = 128,
+    config: Optional[ModelConfig] = None,
 ) -> BackboneModel:
+    config = config or ModelConfig()
+
     if config.name == ModelType.Net2DFast:
         return Net2DFast(
-            input_height=input_height,
-            num_features=config.num_features,
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
         )
-    elif config.name == ModelType.Net2DFastNoAttn:
+
+    if config.name == ModelType.Net2DFastNoAttn:
         return Net2DFastNoAttn(
-            num_features=config.num_features,
-            input_height=input_height,
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
        )
-    elif config.name == ModelType.Net2DFastNoCoordConv:
+
+    if config.name == ModelType.Net2DFastNoCoordConv:
         return Net2DFastNoCoordConv(
-            num_features=config.num_features,
-            input_height=input_height,
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
        )
-    else:
-        raise ValueError(f"Unknown model type: {config.name}")
+
+    if config.name == ModelType.Net2DPlain:
+        return Net2DPlain(
+            input_height=config.input_height,
+            encoder_channels=config.encoder_channels,
+            bottleneck_channels=config.bottleneck_channels,
+            decoder_channels=config.decoder_channels,
+            out_channels=config.out_channels,
+        )
+
+    raise ValueError(f"Unknown model type: {config.name}")
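For orientation (not part of the commit), the new configuration-driven API could be exercised roughly as below. The import path `batdetect2.models` is an assumption based on the module layout visible in this diff; shapes follow the defaults shown above.

    import torch

    from batdetect2.models import ModelConfig, ModelType, get_backbone  # assumed path

    # input_height now lives in ModelConfig instead of being a separate argument.
    config = ModelConfig(name=ModelType.Net2DFast, input_height=128)
    backbone = get_backbone(config)

    # A spectrogram batch: (batch, channels, frequency bins, time frames).
    spec = torch.rand(1, 1, 128, 512)
    features = backbone(spec)
    print(features.shape)  # expected: (1, config.out_channels, 128, 512)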
@@ -1,16 +1,16 @@
-from typing import Tuple
+from typing import Sequence, Tuple
 
 import torch
-import torch.fft
 import torch.nn.functional as F
-from torch import nn
 
 from batdetect2.models.blocks import (
-    ConvBlockDownCoordF,
-    ConvBlockDownStandard,
-    ConvBlockUpF,
-    ConvBlockUpStandard,
+    ConvBlock,
+    Decoder,
+    DownscalingLayer,
+    Encoder,
     SelfAttention,
+    UpscalingLayer,
+    VerticalConv,
 )
 from batdetect2.models.typing import BackboneModel
 
@@ -21,303 +21,164 @@ __all__ = [
 ]
 
 
-class Net2DFast(BackboneModel):
+class Net2DPlain(BackboneModel):
+    downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard"
+    upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard"
+
     def __init__(
         self,
-        num_features: int,
         input_height: int = 128,
+        encoder_channels: Sequence[int] = (1, 32, 64, 128),
+        bottleneck_channels: int = 256,
+        decoder_channels: Sequence[int] = (256, 64, 32, 32),
+        out_channels: int = 32,
     ):
         super().__init__()
-        self.num_features = num_features
-        self.input_height = input_height
-        self.bottleneck_height = self.input_height // 32
 
-        # encoder
-        self.conv_dn_0 = ConvBlockDownCoordF(
-            1,
-            self.num_features // 4,
-            self.input_height,
-            kernel_size=3,
-            pad_size=1,
-            stride=1,
-        )
-        self.conv_dn_1 = ConvBlockDownCoordF(
-            self.num_features // 4,
-            self.num_features // 2,
-            self.input_height // 2,
-            kernel_size=3,
-            pad_size=1,
-            stride=1,
-        )
-        self.conv_dn_2 = ConvBlockDownCoordF(
-            self.num_features // 2,
-            self.num_features,
-            self.input_height // 4,
-            kernel_size=3,
-            pad_size=1,
-            stride=1,
-        )
-        self.conv_dn_3 = nn.Conv2d(
-            self.num_features,
-            self.num_features * 2,
-            3,
-            padding=1,
-        )
-        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
+        self.input_height = input_height
+        self.encoder_channels = tuple(encoder_channels)
+        self.decoder_channels = tuple(decoder_channels)
+        self.out_channels = out_channels
+
+        if len(encoder_channels) != len(decoder_channels):
+            raise ValueError(
+                f"Mismatched encoder and decoder channel lists. "
+                f"The encoder has {len(encoder_channels)} channels "
+                f"(implying {len(encoder_channels) - 1} layers), "
+                f"while the decoder has {len(decoder_channels)} channels "
+                f"(implying {len(decoder_channels) - 1} layers). "
+                f"These lengths must be equal."
+            )
+
+        self.divide_factor = 2 ** (len(encoder_channels) - 1)
+        if self.input_height % self.divide_factor != 0:
+            raise ValueError(
+                f"Input height ({self.input_height}) must be divisible by "
+                f"the divide factor ({self.divide_factor}). "
+                f"This ensures proper upscaling after downscaling to recover "
+                f"the original input height."
+            )
+
+        self.encoder = Encoder(
+            channels=encoder_channels,
+            input_height=self.input_height,
+            layer_type=self.downscaling_layer_type,
+        )
+
+        self.conv_same_1 = ConvBlock(
+            in_channels=encoder_channels[-1],
+            out_channels=bottleneck_channels,
+        )
 
         # bottleneck
-        self.conv_1d = nn.Conv2d(
-            self.num_features * 2,
-            self.num_features * 2,
-            (self.input_height // 8, 1),
-            padding=0,
-        )
-        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
-        self.att = SelfAttention(self.num_features * 2, self.num_features * 2)
-
-        # decoder
-        self.conv_up_2 = ConvBlockUpF(
-            self.num_features * 2,
-            self.num_features // 2,
-            self.input_height // 8,
-        )
-        self.conv_up_3 = ConvBlockUpF(
-            self.num_features // 2,
-            self.num_features // 4,
-            self.input_height // 4,
-        )
-        self.conv_up_4 = ConvBlockUpF(
-            self.num_features // 4,
-            self.num_features // 4,
-            self.input_height // 2,
+        self.conv_vert = VerticalConv(
+            in_channels=bottleneck_channels,
+            out_channels=bottleneck_channels,
+            input_height=self.input_height // (2**self.encoder.depth),
         )
 
-        self.conv_op = nn.Conv2d(
-            self.num_features // 4,
-            self.num_features // 4,
-            kernel_size=3,
-            padding=1,
+        self.decoder = Decoder(
+            channels=decoder_channels,
+            input_height=self.input_height,
+            layer_type=self.upscaling_layer_type,
         )
-        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
 
-        self.out_channels = self.num_features // 4
-
-    def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
-        h, w = spec.shape[2:]
-        h_pad = (32 - h % 32) % 32
-        w_pad = (32 - w % 32) % 32
-        return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
+        self.conv_same_2 = ConvBlock(
+            in_channels=decoder_channels[-1],
+            out_channels=out_channels,
+        )
 
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        # encoder
-        spec, h_pad, w_pad = self.pad_adjust(spec)
-
-        x1 = self.conv_dn_0(spec)
-        x2 = self.conv_dn_1(x1)
-        x3 = self.conv_dn_2(x2)
-        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
+        spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
+
+        # encoder
+        residuals = self.encoder(spec)
+        residuals[-1] = self.conv_same_1(residuals[-1])
 
         # bottleneck
-        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
-        x = self.att(x)
-        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
+        x = self.conv_vert(residuals[-1])
+        x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
 
         # decoder
-        x = self.conv_up_2(x + x3)
-        x = self.conv_up_3(x + x2)
-        x = self.conv_up_4(x + x1)
+        x = self.decoder(x, residuals=residuals)
 
         # Restore original size
-        if h_pad > 0:
-            x = x[:, :, :-h_pad, :]
-
-        if w_pad > 0:
-            x = x[:, :, :, :-w_pad]
-
-        return F.relu_(self.conv_op_bn(self.conv_op(x)))
-
-
-class Net2DFastNoAttn(BackboneModel):
+        x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
+
+        return self.conv_same_2(x)
+
+
+class Net2DFast(Net2DPlain):
+    downscaling_layer_type = "ConvBlockDownCoordF"
+    upscaling_layer_type = "ConvBlockUpF"
+
     def __init__(
         self,
-        num_features: int,
         input_height: int = 128,
+        encoder_channels: Sequence[int] = (1, 32, 64, 128),
+        bottleneck_channels: int = 256,
+        decoder_channels: Sequence[int] = (256, 64, 32, 32),
+        out_channels: int = 32,
     ):
-        super().__init__()
-
-        self.num_features = num_features
-        self.input_height = input_height
-        self.bottleneck_height = self.input_height // 32
-
-        self.conv_dn_0 = ConvBlockDownCoordF(
-            1,
-            self.num_features // 4,
-            self.input_height,
-            kernel_size=3,
-            pad_size=1,
-            stride=1,
-        )
-        self.conv_dn_1 = ConvBlockDownCoordF(
-            self.num_features // 4,
-            self.num_features // 2,
-            self.input_height // 2,
-            kernel_size=3,
-            pad_size=1,
-            stride=1,
-        )
-        self.conv_dn_2 = ConvBlockDownCoordF(
-            self.num_features // 2,
-            self.num_features,
-            self.input_height // 4,
-            kernel_size=3,
-            pad_size=1,
-            stride=1,
-        )
-        self.conv_dn_3 = nn.Conv2d(
-            self.num_features,
-            self.num_features * 2,
-            3,
-            padding=1,
-        )
-        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
-
-        self.conv_1d = nn.Conv2d(
-            self.num_features * 2,
-            self.num_features * 2,
-            (self.input_height // 8, 1),
-            padding=0,
-        )
-        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
-
-        self.conv_up_2 = ConvBlockUpF(
-            self.num_features * 2,
-            self.num_features // 2,
-            self.input_height // 8,
-        )
-        self.conv_up_3 = ConvBlockUpF(
-            self.num_features // 2,
-            self.num_features // 4,
-            self.input_height // 4,
-        )
-        self.conv_up_4 = ConvBlockUpF(
-            self.num_features // 4,
-            self.num_features // 4,
-            self.input_height // 2,
+        super().__init__(
+            input_height=input_height,
+            encoder_channels=encoder_channels,
+            bottleneck_channels=bottleneck_channels,
+            decoder_channels=decoder_channels,
+            out_channels=out_channels,
         )
 
-        self.conv_op = nn.Conv2d(
-            self.num_features // 4,
-            self.num_features // 4,
-            kernel_size=3,
-            padding=1,
-        )
-        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
-        self.out_channels = self.num_features // 4
+        self.att = SelfAttention(bottleneck_channels, bottleneck_channels)
 
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        x1 = self.conv_dn_0(spec)
-        x2 = self.conv_dn_1(x1)
-        x3 = self.conv_dn_2(x2)
-        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
-
-        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
-        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
-
-        x = self.conv_up_2(x + x3)
-        x = self.conv_up_3(x + x2)
-        x = self.conv_up_4(x + x1)
-
-        return F.relu_(self.conv_op_bn(self.conv_op(x)))
-
-
-class Net2DFastNoCoordConv(BackboneModel):
-    def __init__(
-        self,
-        num_features: int,
-        input_height: int = 128,
-    ):
-        super().__init__()
-
-        self.num_features = num_features
-        self.input_height = input_height
-        self.bottleneck_height = self.input_height // 32
-
-        self.conv_dn_0 = ConvBlockDownStandard(
-            1,
-            self.num_features // 4,
-            kernel_size=3,
-            pad_size=1,
-            stride=1,
-        )
-        self.conv_dn_1 = ConvBlockDownStandard(
-            self.num_features // 4,
-            self.num_features // 2,
-            kernel_size=3,
-            pad_size=1,
-            stride=1,
-        )
-        self.conv_dn_2 = ConvBlockDownStandard(
-            self.num_features // 2,
-            self.num_features,
-            kernel_size=3,
-            pad_size=1,
-            stride=1,
-        )
-        self.conv_dn_3 = nn.Conv2d(
-            self.num_features,
-            self.num_features * 2,
-            3,
-            padding=1,
-        )
-        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_features * 2)
-
-        self.conv_1d = nn.Conv2d(
-            self.num_features * 2,
-            self.num_features * 2,
-            (self.input_height // 8, 1),
-            padding=0,
-        )
-        self.conv_1d_bn = nn.BatchNorm2d(self.num_features * 2)
-
-        self.att = SelfAttention(self.num_features * 2, self.num_features * 2)
-
-        self.conv_up_2 = ConvBlockUpStandard(
-            self.num_features * 2,
-            self.num_features // 2,
-            self.input_height // 8,
-        )
-        self.conv_up_3 = ConvBlockUpStandard(
-            self.num_features // 2,
-            self.num_features // 4,
-            self.input_height // 4,
-        )
-        self.conv_up_4 = ConvBlockUpStandard(
-            self.num_features // 4,
-            self.num_features // 4,
-            self.input_height // 2,
-        )
-
-        self.conv_op = nn.Conv2d(
-            self.num_features // 4,
-            self.num_features // 4,
-            kernel_size=3,
-            padding=1,
-        )
-        self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
-        self.out_channels = self.num_features // 4
-
-    def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        x1 = self.conv_dn_0(spec)
-        x2 = self.conv_dn_1(x1)
-        x3 = self.conv_dn_2(x2)
-        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
-
-        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
+        spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
+
+        # encoder
+        residuals = self.encoder(spec)
+        residuals[-1] = self.conv_same_1(residuals[-1])
+
+        # bottleneck
+        x = self.conv_vert(residuals[-1])
         x = self.att(x)
-        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
-
-        x = self.conv_up_2(x + x3)
-        x = self.conv_up_3(x + x2)
-        x = self.conv_up_4(x + x1)
-
-        return F.relu_(self.conv_op_bn(self.conv_op(x)))
+        x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
+
+        # decoder
+        x = self.decoder(x, residuals=residuals)
+
+        # Restore original size
+        x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
+
+        return self.conv_same_2(x)
+
+
+class Net2DFastNoAttn(Net2DPlain):
+    downscaling_layer_type = "ConvBlockDownCoordF"
+    upscaling_layer_type = "ConvBlockUpF"
+
+
+class Net2DFastNoCoordConv(Net2DFast):
+    downscaling_layer_type = "ConvBlockDownStandard"
+    upscaling_layer_type = "ConvBlockUpStandard"
+
+
+def pad_adjust(
+    spec: torch.Tensor,
+    factor: int = 32,
+) -> Tuple[torch.Tensor, int, int]:
+    h, w = spec.shape[2:]
+    h_pad = -h % factor
+    w_pad = -w % factor
+    return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
+
+
+def restore_pad(
+    x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
+) -> torch.Tensor:
+    # Restore original size
+    if h_pad > 0:
+        x = x[:, :, :-h_pad, :]
+
+    if w_pad > 0:
+        x = x[:, :, :, :-w_pad]
+
+    return x
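As a side note (not part of the diff), the new module-level `pad_adjust`/`restore_pad` helpers pad the spectrogram up to a multiple of the divide factor and crop it back afterwards. A minimal round-trip check, assuming the two functions as defined above are importable from this module:

    import torch

    spec = torch.rand(1, 1, 100, 70)             # height/width not divisible by 32
    padded, h_pad, w_pad = pad_adjust(spec, factor=32)
    assert padded.shape[-2:] == (128, 96)        # padded up to multiples of 32
    restored = restore_pad(padded, h_pad=h_pad, w_pad=w_pad)
    assert restored.shape == spec.shape          # original size recovered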
@@ -4,29 +4,35 @@ All these classes are subclasses of `torch.nn.Module` and can be used to build
 complex neural network architectures.
 """
 
-from typing import Tuple
+import sys
+from typing import Iterable, List, Literal, Sequence, Tuple
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
-from batdetect2.configs import BaseConfig
+if sys.version_info >= (3, 10):
+    from itertools import pairwise
+else:
+
+    def pairwise(iterable: Sequence) -> Iterable:
+        for x, y in zip(iterable[:-1], iterable[1:]):
+            yield x, y
+
 
 __all__ = [
-    "SelfAttention",
+    "ConvBlock",
     "ConvBlockDownCoordF",
     "ConvBlockDownStandard",
     "ConvBlockUpF",
     "ConvBlockUpStandard",
+    "SelfAttention",
+    "VerticalConv",
+    "DownscalingLayer",
+    "UpscalingLayer",
 ]
 
 
-class SelfAttentionConfig(BaseConfig):
-    temperature: float = 1.0
-    input_channels: int = 128
-    attention_channels: int = 128
-
-
 class SelfAttention(nn.Module):
     """Self-Attention module.
 
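For readers on Python < 3.10 (where the fallback above is used), `pairwise` simply yields consecutive element pairs; this is what drives the per-layer channel wiring in the encoder and decoder builders further down. A quick illustration, not part of the commit:

    channels = (1, 32, 64, 128)
    print(list(pairwise(channels)))  # [(1, 32), (32, 64), (64, 128)]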
@@ -76,13 +82,64 @@ class SelfAttention(nn.Module):
         return op
 
 
-class ConvBlockDownCoordFConfig(BaseConfig):
-    in_channels: int
-    out_channels: int
-    input_height: int
-    kernel_size: int = 3
-    pad_size: int = 1
-    stride: int = 1
+class ConvBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        pad_size: int = 1,
+    ):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            padding=pad_size,
+        )
+        self.conv_bn = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.relu_(self.conv_bn(self.conv(x)))
+
+
+class VerticalConv(nn.Module):
+    """Convolutional layer over full height.
+
+    This layer applies a convolution that captures information across the
+    entire height of the input image. It uses a kernel with the same height as
+    the input, effectively condensing the vertical information into a single
+    output row.
+
+    More specifically:
+
+    * **Input:** (B, C, H, W) where B is the batch size, C is the number of
+      input channels, H is the image height, and W is the image width.
+    * **Kernel:** (C', H, 1) where C' is the number of output channels.
+    * **Output:** (B, C', 1, W) - The height dimension is 1 because the
+      convolution integrates information from all rows of the input.
+
+    This process effectively extracts features that span the full height of
+    the input image.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        input_height: int,
+    ):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=(input_height, 1),
+            padding=0,
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.relu_(self.bn(self.conv(x)))
 
 
 class ConvBlockDownCoordF(nn.Module):
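To make the shape contract in the `VerticalConv` docstring concrete, here is a small check (illustrative only; it assumes the class as defined in the hunk above is importable):

    import torch

    layer = VerticalConv(in_channels=256, out_channels=256, input_height=4)
    x = torch.rand(2, 256, 4, 128)   # (B, C, H, W)
    y = layer(x)
    print(y.shape)                   # torch.Size([2, 256, 1, 128]) — height collapsed to 1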
@@ -124,14 +181,6 @@ class ConvBlockDownCoordF(nn.Module):
         return x
 
 
-class ConvBlockDownStandardConfig(BaseConfig):
-    in_channels: int
-    out_channels: int
-    kernel_size: int = 3
-    pad_size: int = 1
-    stride: int = 1
-
-
 class ConvBlockDownStandard(nn.Module):
     """Convolutional Block with Downsampling.
 
@@ -158,18 +207,10 @@ class ConvBlockDownStandard(nn.Module):
 
     def forward(self, x):
         x = F.max_pool2d(self.conv(x), 2, 2)
-        x = F.relu(self.conv_bn(x), inplace=True)
-        return x
-
-
-class ConvBlockUpFConfig(BaseConfig):
-    inp_channels: int
-    out_channels: int
-    input_height: int
-    kernel_size: int = 3
-    pad_size: int = 1
-    up_mode: str = "bilinear"
-    up_scale: Tuple[int, int] = (2, 2)
+        return F.relu(self.conv_bn(x), inplace=True)
+
+
+DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"]
 
 
 class ConvBlockUpF(nn.Module):
@@ -224,15 +265,6 @@ class ConvBlockUpF(nn.Module):
         return op
 
 
-class ConvBlockUpStandardConfig(BaseConfig):
-    in_channels: int
-    out_channels: int
-    kernel_size: int = 3
-    pad_size: int = 1
-    up_mode: str = "bilinear"
-    up_scale: Tuple[int, int] = (2, 2)
-
-
 class ConvBlockUpStandard(nn.Module):
     """Convolutional Block with Upsampling.
 
@@ -272,3 +304,143 @@ class ConvBlockUpStandard(nn.Module):
         op = self.conv(op)
         op = F.relu(self.conv_bn(op), inplace=True)
         return op
+
+
+UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"]
+
+
+def build_downscaling_layer(
+    in_channels: int,
+    out_channels: int,
+    input_height: int,
+    layer_type: DownscalingLayer,
+) -> nn.Module:
+    if layer_type == "ConvBlockDownStandard":
+        return ConvBlockDownStandard(
+            in_channels=in_channels,
+            out_channels=out_channels,
+        )
+
+    if layer_type == "ConvBlockDownCoordF":
+        return ConvBlockDownCoordF(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            input_height=input_height,
+        )
+
+    raise ValueError(
+        f"Invalid downscaling layer type {layer_type}. "
+        f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
+    )
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        channels: Sequence[int] = (1, 32, 62, 128),
+        input_height: int = 128,
+        layer_type: Literal[
+            "ConvBlockDownStandard", "ConvBlockDownCoordF"
+        ] = "ConvBlockDownStandard",
+    ):
+        super().__init__()
+
+        self.channels = channels
+        self.input_height = input_height
+
+        self.layers = nn.ModuleList(
+            [
+                build_downscaling_layer(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    input_height=input_height // (2**layer_num),
+                    layer_type=layer_type,
+                )
+                for layer_num, (in_channels, out_channels) in enumerate(
+                    pairwise(channels)
+                )
+            ]
+        )
+        self.depth = len(self.layers)
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        outputs = []
+
+        for layer in self.layers:
+            x = layer(x)
+            outputs.append(x)
+
+        return outputs
+
+
+def build_upscaling_layer(
+    in_channels: int,
+    out_channels: int,
+    input_height: int,
+    layer_type: UpscalingLayer,
+) -> nn.Module:
+    if layer_type == "ConvBlockUpStandard":
+        return ConvBlockUpStandard(
+            in_channels=in_channels,
+            out_channels=out_channels,
+        )
+
+    if layer_type == "ConvBlockUpF":
+        return ConvBlockUpF(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            input_height=input_height,
+        )
+
+    raise ValueError(
+        f"Invalid upscaling layer type {layer_type}. "
+        f"Valid values: ConvBlockUpStandard, ConvBlockUpF"
+    )
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        channels: Sequence[int] = (256, 62, 32, 32),
+        input_height: int = 128,
+        layer_type: Literal[
+            "ConvBlockUpStandard", "ConvBlockUpF"
+        ] = "ConvBlockUpStandard",
+    ):
+        super().__init__()
+
+        self.channels = channels
+        self.input_height = input_height
+        self.depth = len(self.channels) - 1
+
+        self.layers = nn.ModuleList(
+            [
+                build_upscaling_layer(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    input_height=input_height
+                    // (2 ** (self.depth - layer_num)),
+                    layer_type=layer_type,
+                )
+                for layer_num, (in_channels, out_channels) in enumerate(
+                    pairwise(channels)
+                )
+            ]
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residuals: List[torch.Tensor],
+    ) -> torch.Tensor:
+        if len(residuals) != len(self.layers):
+            raise ValueError(
+                f"Incorrect number of residuals provided. "
+                f"Expected {len(self.layers)} (matching the number of layers), "
+                f"but got {len(residuals)}."
+            )
+
+        for layer, res in zip(self.layers, residuals[::-1]):
+            x = layer(x + res)
+
+        return x
batdetect2/models/decoder.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+import sys
+from typing import Iterable, List, Literal, Sequence
+
+import torch
+from torch import nn
+
+from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard
+
+if sys.version_info >= (3, 10):
+    from itertools import pairwise
+else:
+
+    def pairwise(iterable: Sequence) -> Iterable:
+        for x, y in zip(iterable[:-1], iterable[1:]):
+            yield x, y
+
+

batdetect2/models/encoder.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+import sys
+from typing import Iterable, List, Literal, Sequence
+
+import torch
+from torch import nn
+
+from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard
+
+if sys.version_info >= (3, 10):
+    from itertools import pairwise
+else:
+
+    def pairwise(iterable: Sequence) -> Iterable:
+        for x, y in zip(iterable[:-1], iterable[1:]):
+            yield x, y
+
+
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import NamedTuple
+from typing import NamedTuple, Tuple
 
 import torch
 import torch.nn as nn
@@ -45,15 +45,27 @@ class BackboneModel(ABC, nn.Module):
     input_height: int
     """Height of the input spectrogram."""
 
-    num_features: int
-    """Dimension of the feature tensor."""
+    encoder_channels: Tuple[int, ...]
+    """Tuple specifying the number of channels for each convolutional layer
+    in the encoder. The length of this tuple determines the number of
+    encoder layers."""
+
+    decoder_channels: Tuple[int, ...]
+    """Tuple specifying the number of channels for each convolutional layer
+    in the decoder. The length of this tuple determines the number of
+    decoder layers."""
+
+    bottleneck_channels: int
+    """Number of channels in the bottleneck layer, which connects the
+    encoder and decoder."""
 
     out_channels: int
-    """Number of output channels of the feature extractor."""
+    """Number of channels in the final output feature map produced by the
+    backbone model."""
 
     @abstractmethod
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        """Forward pass of the encoder model."""
+        """Forward pass of the model."""
 
 
 class DetectionModel(ABC, nn.Module):
@@ -13,7 +13,7 @@ from batdetect2.preprocess.audio import (
 )
 from batdetect2.preprocess.spectrogram import (
     AmplitudeScaleConfig,
-    FFTConfig,
+    STFTConfig,
     FrequencyConfig,
     LogScaleConfig,
     PcenScaleConfig,
@@ -27,7 +27,7 @@ __all__ = [
     "AudioConfig",
     "ResampleConfig",
     "SpectrogramConfig",
-    "FFTConfig",
+    "STFTConfig",
     "FrequencyConfig",
     "PcenScaleConfig",
     "LogScaleConfig",
batdetect2/preprocess/arrays.py (new file, 61 lines)
@@ -0,0 +1,61 @@
+import numpy as np
+
+
+def extend_width(
+    array: np.ndarray,
+    extra: int,
+    axis: int = -1,
+    value: float = 0,
+) -> np.ndarray:
+    dims = len(array.shape)
+    axis = axis % dims
+    pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)]
+    return np.pad(
+        array,
+        pad,
+        mode="constant",
+        constant_values=value,
+    )
+
+
+def make_width_divisible(
+    array: np.ndarray,
+    factor: int,
+    axis: int = -1,
+    value: float = 0,
+) -> np.ndarray:
+    width = array.shape[axis]
+
+    if width % factor == 0:
+        return array
+
+    extra = (-width) % factor
+    return extend_width(array, extra, axis=axis, value=value)
+
+
+def adjust_width(
+    array: np.ndarray,
+    width: int,
+    axis: int = -1,
+    value: float = 0,
+) -> np.ndarray:
+    dims = len(array.shape)
+    axis = axis % dims
+    current_width = array.shape[axis]
+
+    if current_width == width:
+        return array
+
+    if current_width < width:
+        return extend_width(
+            array,
+            extra=width - current_width,
+            axis=axis,
+            value=value,
+        )
+
+    slices = [
+        slice(None, None) if index != axis else slice(None, width)
+        for index in range(dims)
+    ]
+    return array[tuple(slices)]
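Illustration (not part of the commit) of what the new array helpers do: `make_width_divisible` zero-pads the chosen axis up to the next multiple of the factor, while `adjust_width` pads or crops to an exact size.

    import numpy as np

    spec = np.random.rand(128, 70)                  # (freq, time)

    padded = make_width_divisible(spec, factor=32)  # zero-padded on the right
    print(padded.shape)                             # (128, 96)

    exact = adjust_width(spec, width=50)            # cropped to 50 time bins
    print(exact.shape)                              # (128, 50)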
@@ -12,7 +12,7 @@ from soundevent.arrays import operations as ops
 from batdetect2.configs import BaseConfig
 
 
-class FFTConfig(BaseConfig):
+class STFTConfig(BaseConfig):
     window_duration: float = Field(default=0.002, gt=0)
     window_overlap: float = Field(default=0.75, ge=0, lt=1)
     window_fn: str = "hann"
@@ -24,9 +24,15 @@ class FrequencyConfig(BaseConfig):
 
 
 class SpecSizeConfig(BaseConfig):
-    height: int = 256
+    height: int = 128
+    """Height of the spectrogram in pixels. This value determines the
+    number of frequency bands and corresponds to the vertical dimension
+    of the spectrogram."""
+
     resize_factor: Optional[float] = 0.5
-    divide_factor: Optional[int] = 32
+    """Factor by which to resize the spectrogram along the time axis.
+    A value of 0.5 reduces the temporal dimension by half, while a
+    value of 2.0 doubles it. If None, no resizing is performed."""
 
 
 class LogScaleConfig(BaseConfig):
@@ -50,13 +56,13 @@ Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]
 
 
 class SpectrogramConfig(BaseConfig):
-    fft: FFTConfig = Field(default_factory=FFTConfig)
+    stft: STFTConfig = Field(default_factory=STFTConfig)
     frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
     scale: Scales = Field(
         default_factory=PcenScaleConfig,
         discriminator="name",
     )
-    size: SpecSizeConfig = Field(default_factory=SpecSizeConfig)
+    size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
     denoise: bool = True
     max_scale: bool = False
 
@@ -68,22 +74,11 @@ def compute_spectrogram(
 ) -> xr.DataArray:
     config = config or SpectrogramConfig()
 
-    if config.size.divide_factor:
-        # Need to pad the audio to make sure the spectrogram has a
-        # width compatible with the divide factor
-        resize_factor = config.size.resize_factor or 1
-        wav = pad_audio(
-            wav,
-            window_duration=config.fft.window_duration,
-            window_overlap=config.fft.window_overlap,
-            divide_factor=int(config.size.divide_factor / resize_factor),
-        )
-
     spec = stft(
         wav,
-        window_duration=config.fft.window_duration,
-        window_overlap=config.fft.window_overlap,
-        window_fn=config.fft.window_fn,
+        window_duration=config.stft.window_duration,
+        window_overlap=config.stft.window_overlap,
+        window_fn=config.stft.window_fn,
         dtype=dtype,
     )
 
@@ -98,11 +93,12 @@ def compute_spectrogram(
     if config.denoise:
         spec = denoise_spectrogram(spec)
 
-    spec = resize_spectrogram(
-        spec,
-        height=config.size.height,
-        resize_factor=config.size.resize_factor,
-    )
+    if config.size:
+        spec = resize_spectrogram(
+            spec,
+            height=config.size.height,
+            resize_factor=config.size.resize_factor,
+        )
 
     if config.max_scale:
         spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
@@ -257,7 +253,7 @@ def resize_spectrogram(
     return ops.resize(
         spec,
         time=int(resize_factor * current_width),
-        frequency=int(resize_factor * height),
+        frequency=height,
         dtype=np.float32,
     )
 
@@ -285,43 +281,6 @@ def adjust_spectrogram_width(
     return resized
 
 
-def pad_audio(
-    wave: xr.DataArray,
-    window_duration: float,
-    window_overlap: float,
-    divide_factor: int = 32,
-) -> xr.DataArray:
-    current_duration = arrays.get_dim_width(wave, dim="time")
-    step = arrays.get_dim_step(wave, dim="time")
-    samplerate = int(1 / step)
-
-    estimated_spec_width = duration_to_spec_width(
-        current_duration,
-        samplerate=samplerate,
-        window_duration=window_duration,
-        window_overlap=window_overlap,
-    )
-
-    if estimated_spec_width % divide_factor == 0:
-        return wave
-
-    target_spec_width = int(
-        np.ceil(estimated_spec_width / divide_factor) * divide_factor
-    )
-    target_samples = spec_width_to_samples(
-        target_spec_width,
-        samplerate=samplerate,
-        window_duration=window_duration,
-        window_overlap=window_overlap,
-    )
-    return ops.adjust_dim_width(
-        wave,
-        dim="time",
-        width=target_samples,
-        position="start",
-    )
-
-
 def duration_to_spec_width(
     duration: float,
     samplerate: int,
@@ -357,6 +316,8 @@ def get_spectrogram_resolution(
 
     spec_height = config.size.height
     resize_factor = config.size.resize_factor or 1
-    freq_bin_width = (max_freq - min_freq) / (spec_height * resize_factor)
-    hop_duration = config.fft.window_duration * (1 - config.fft.window_overlap)
+    freq_bin_width = (max_freq - min_freq) / spec_height
+    hop_duration = config.stft.window_duration * (
+        1 - config.stft.window_overlap
+    )
     return freq_bin_width, hop_duration / resize_factor
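For reference (not shown in the diff), a config built against the renamed fields might look like the following; note that `size` is now optional, so passing `size=None` would skip resizing entirely:

    config = SpectrogramConfig(
        stft=STFTConfig(window_duration=0.002, window_overlap=0.75, window_fn="hann"),
        size=SpecSizeConfig(height=128, resize_factor=0.5),
        denoise=True,
    )
    hop = config.stft.window_duration * (1 - config.stft.window_overlap)
    print(hop)  # 0.0005 s between STFT frames, before any time-axis resizing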
batdetect2/preprocess/tensors.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+from typing import Union
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+
+def extend_width(
+    array: Union[np.ndarray, torch.Tensor],
+    extra: int,
+    axis: int = -1,
+    value: float = 0,
+) -> torch.Tensor:
+    if not isinstance(array, torch.Tensor):
+        array = torch.Tensor(array)
+
+    dims = len(array.shape)
+    axis = axis % dims
+    pad = [
+        [0, 0] if index != axis else [0, extra]
+        for index in range(axis, dims)[::-1]
+    ]
+    return F.pad(
+        array,
+        [x for y in pad for x in y],
+        value=value,
+    )
+
+
+def make_width_divisible(
+    array: Union[np.ndarray, torch.Tensor],
+    factor: int,
+    axis: int = -1,
+    value: float = 0,
+) -> torch.Tensor:
+    if not isinstance(array, torch.Tensor):
+        array = torch.Tensor(array)
+
+    width = array.shape[axis]
+
+    if width % factor == 0:
+        return array
+
+    extra = (-width) % factor
+    return extend_width(array, extra, axis=axis, value=value)
+
+
+def adjust_width(
+    array: Union[np.ndarray, torch.Tensor],
+    width: int,
+    axis: int = -1,
+    value: float = 0,
+) -> torch.Tensor:
+    if not isinstance(array, torch.Tensor):
+        array = torch.Tensor(array)
+
+    dims = len(array.shape)
+    axis = axis % dims
+    current_width = array.shape[axis]
+
+    if current_width == width:
+        return array
+
+    if current_width < width:
+        return extend_width(
+            array,
+            extra=width - current_width,
+            axis=axis,
+            value=value,
+        )
+
+    slices = [
+        slice(None, None) if index != axis else slice(None, width)
+        for index in range(dims)
+    ]
+    return array[tuple(slices)]
@@ -4,10 +4,10 @@ import numpy as np
 import xarray as xr
 from pydantic import Field
 from soundevent import arrays
-from soundevent.arrays import operations as ops
 
 from batdetect2.configs import BaseConfig
 from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
+from batdetect2.preprocess.arrays import adjust_width
 
 Augmentation = Callable[[xr.Dataset], xr.Dataset]
 
@@ -17,47 +17,12 @@ class AugmentationConfig(BaseConfig):
     probability: float = 0.2
 
 
-class SubclipConfig(BaseConfig):
-    enable: bool = True
-    duration: Optional[float] = None
-    width: Optional[int] = 512
-
-
-def adjust_dataset_width(
-    example: xr.Dataset,
-    duration: Optional[float] = None,
-    width: Optional[int] = None,
-) -> xr.Dataset:
-    step = arrays.get_dim_step(example, "time")  # type: ignore
-
-    if width is None:
-        if duration is None:
-            raise ValueError("Either duration or width must be provided")
-
-        width = int(np.floor(duration / step))
-
-    adjusted_arrays = {
-        name: ops.adjust_dim_width(array, "time", width)
-        for name, array in example.items()
-        if name != "audio"
-    }
-
-    ratio = width / example.spectrogram.sizes["time"]
-    audio_width = int(example.audio.sizes["audio_time"] * ratio)
-    adjusted_arrays["audio"] = ops.adjust_dim_width(
-        example["audio"],
-        "audio_time",
-        audio_width,
-    )
-
-    return xr.Dataset(data_vars=adjusted_arrays)
-
-
-def select_random_subclip(
+def select_subclip(
     example: xr.Dataset,
     start_time: Optional[float] = None,
     duration: Optional[float] = None,
     width: Optional[int] = None,
+    random: bool = False,
 ) -> xr.Dataset:
     """Select a random subclip from a clip."""
     step = arrays.get_dim_step(example, "time")  # type: ignore
@@ -73,10 +38,13 @@ def select_random_subclip(
         duration = width * step
 
     if start_time is None:
-        start_time = np.random.uniform(start, max(stop - duration, start))
+        if random:
+            start_time = np.random.uniform(start, max(stop - duration, start))
+        else:
+            start_time = start
 
     if start_time + duration > stop:
-        example = adjust_dataset_width(example, width=width)
+        return example
 
     start_index = arrays.get_coord_index(
         example,  # type: ignore
@@ -91,7 +59,7 @@ def select_random_subclip(
 
     return example.sel(
         time=slice(start_time, end_time),
-        audio_time=slice(start_time, end_time),
+        audio_time=slice(start_time, end_time + step),
     )
 
@@ -115,40 +83,45 @@ def mix_examples(
     weight = np.random.uniform(min_weight, max_weight)
 
     audio1 = example["audio"]
-    audio2 = ops.adjust_dim_width(
-        other["audio"], "audio_time", len(audio1)
-    ).values
-
-    if len(audio2) > len(audio1):
-        audio2 = audio2[: len(audio1)]
+    audio2 = adjust_width(other["audio"].values, len(audio1))
 
     combined = weight * audio1 + (1 - weight) * audio2
 
-    spec = compute_spectrogram(
+    spectrogram = compute_spectrogram(
         combined.rename({"audio_time": "time"}),
         config=config.spectrogram,
-    )
+    ).data
+
+    # NOTE: The subclip's spectrogram might be slightly longer than the
+    # spectrogram computed from the subclip's audio. This is due to a
+    # simplification in the subclip process: It doesn't account for the
+    # spectrogram parameters to precisely determine the corresponding audio
+    # samples. To work around this, we pad the computed spectrogram with zeros
+    # as needed.
+    previous_width = len(example["time"])
+    spectrogram = adjust_width(spectrogram, previous_width)
 
     detection_heatmap = xr.apply_ufunc(
         np.maximum,
         example["detection"],
-        other["detection"].values,
+        adjust_width(other["detection"].values, previous_width),
     )
 
     class_heatmap = xr.apply_ufunc(
         np.maximum,
         example["class"],
-        other["class"].values,
+        adjust_width(other["class"].values, previous_width),
     )
 
-    size_heatmap = example["size"] + other["size"].values
+    size_heatmap = example["size"] + adjust_width(
+        other["size"].values, previous_width
+    )
 
     return xr.Dataset(
         {
             "audio": combined,
             "spectrogram": xr.DataArray(
-                data=spec.data,
+                data=spectrogram,
                 dims=example["spectrogram"].dims,
                 coords=example["spectrogram"].coords,
             ),
@@ -192,12 +165,23 @@ def add_echo(
     spectrogram = compute_spectrogram(
         audio.rename({"audio_time": "time"}),
         config=config.spectrogram,
+    ).data
+
+    # NOTE: The subclip's spectrogram might be slightly longer than the
+    # spectrogram computed from the subclip's audio. This is due to a
+    # simplification in the subclip process: It doesn't account for the
+    # spectrogram parameters to precisely determine the corresponding audio
+    # samples. To work around this, we pad the computed spectrogram with zeros
+    # as needed.
+    spectrogram = adjust_width(
+        spectrogram,
+        example["spectrogram"].sizes["time"],
     )
 
     return example.assign(
         audio=audio,
         spectrogram=xr.DataArray(
-            data=spectrogram.data,
+            data=spectrogram,
             dims=example["spectrogram"].dims,
             coords=example["spectrogram"].coords,
         ),
@@ -359,7 +343,6 @@ def mask_frequency(
 
 
 class AugmentationsConfig(BaseConfig):
-    subclip: SubclipConfig = Field(default_factory=SubclipConfig)
     mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig)
     echo: EchoAugmentationConfig = Field(
         default_factory=EchoAugmentationConfig
@@ -391,23 +374,8 @@ def augment_example(
     preprocessing_config: Optional[PreprocessingConfig] = None,
     others: Optional[Callable[[], xr.Dataset]] = None,
 ) -> xr.Dataset:
-    if config.subclip.enable:
-        example = select_random_subclip(
-            example,
-            duration=config.subclip.duration,
-            width=config.subclip.width,
-        )
-
     if should_apply(config.mix) and (others is not None):
         other = others()
 
-        if config.subclip.enable:
-            other = select_random_subclip(
-                other,
-                duration=config.subclip.duration,
-                width=config.subclip.width,
-            )
-
         example = mix_examples(
             example,
             other,
@ -5,10 +5,17 @@ from typing import NamedTuple, Optional, Sequence, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from batdetect2.train.augmentations import AugmentationsConfig, augment_example
|
from batdetect2.configs import BaseConfig
|
||||||
|
from batdetect2.preprocess.tensors import adjust_width
|
||||||
|
from batdetect2.train.augmentations import (
|
||||||
|
AugmentationsConfig,
|
||||||
|
augment_example,
|
||||||
|
select_subclip,
|
||||||
|
)
|
||||||
from batdetect2.train.preprocess import PreprocessingConfig
|
from batdetect2.train.preprocess import PreprocessingConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -28,20 +35,36 @@ class TrainExample(NamedTuple):
     idx: torch.Tensor


+class SubclipConfig(BaseConfig):
+    duration: Optional[float] = None
+    width: int = 512
+    random: bool = False
+
+
+class DatasetConfig(BaseConfig):
+    subclip: SubclipConfig = Field(default_factory=SubclipConfig)
+    preprocessing: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    augmentation: AugmentationsConfig = Field(
+        default_factory=AugmentationsConfig
+    )
+
+
 class LabeledDataset(Dataset):
+    config: DatasetConfig
+
     def __init__(
         self,
         filenames: Sequence[PathLike],
         augment: bool = False,
-        preprocessing_config: Optional[PreprocessingConfig] = None,
-        augmentation_config: Optional[AugmentationsConfig] = None,
+        subclip: bool = False,
+        config: Optional[DatasetConfig] = None,
     ):
         self.filenames = filenames
         self.augment = augment
-        self.preprocessing_config = (
-            preprocessing_config or PreprocessingConfig()
-        )
-        self.agumentation_config = augmentation_config or AugmentationsConfig()
+        self.subclip = subclip
+        self.config = config or DatasetConfig()

     def __len__(self):
         return len(self.filenames)
@ -49,27 +72,27 @@ class LabeledDataset(Dataset):
     def __getitem__(self, idx) -> TrainExample:
         dataset = self.get_dataset(idx)

+        if self.subclip:
+            dataset = select_subclip(
+                dataset,
+                duration=self.config.subclip.duration,
+                width=self.config.subclip.width,
+                random=self.config.subclip.random,
+            )
+
         if self.augment:
             dataset = augment_example(
                 dataset,
-                self.agumentation_config,
-                preprocessing_config=self.preprocessing_config,
+                self.config.augmentation,
+                preprocessing_config=self.config.preprocessing,
                 others=self.get_random_example,
             )

         return TrainExample(
-            spec=torch.tensor(
-                dataset["spectrogram"].values.astype(np.float32)
-            ).unsqueeze(0),
-            detection_heatmap=torch.tensor(
-                dataset["detection"].values.astype(np.float32)
-            ),
-            class_heatmap=torch.tensor(
-                dataset["class"].values.astype(np.float32)
-            ),
-            size_heatmap=torch.tensor(
-                dataset["size"].values.astype(np.float32)
-            ),
+            spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
+            detection_heatmap=self.to_tensor(dataset["detection"]),
+            class_heatmap=self.to_tensor(dataset["class"]),
+            size_heatmap=self.to_tensor(dataset["size"]),
             idx=torch.tensor(idx),
         )
@ -78,20 +101,30 @@ class LabeledDataset(Dataset):
         cls,
         directory: PathLike,
         extension: str = ".nc",
+        config: Optional[DatasetConfig] = None,
         augment: bool = False,
-        preprocessing_config: Optional[PreprocessingConfig] = None,
-        augmentation_config: Optional[AugmentationsConfig] = None,
+        subclip: bool = False,
     ):
         return cls(
             get_files(directory, extension),
+            config=config,
             augment=augment,
-            preprocessing_config=preprocessing_config,
-            augmentation_config=augmentation_config,
+            subclip=subclip,
         )

     def get_random_example(self) -> xr.Dataset:
         idx = np.random.randint(0, len(self))
-        return self.get_dataset(idx)
+        dataset = self.get_dataset(idx)
+
+        if self.subclip:
+            dataset = select_subclip(
+                dataset,
+                duration=self.config.subclip.duration,
+                width=self.config.subclip.width,
+                random=self.config.subclip.random,
+            )
+
+        return dataset

     def get_dataset(self, idx) -> xr.Dataset:
         return xr.open_dataset(self.filenames[idx])
@ -101,6 +134,19 @@ class LabeledDataset(Dataset):
             self.get_dataset(idx).attrs["clip_annotation"]
         )

+    def to_tensor(
+        self,
+        array: xr.DataArray,
+        dtype=np.float32,
+    ) -> torch.Tensor:
+        tensor = torch.tensor(array.values.astype(dtype))
+
+        if not self.subclip:
+            return tensor
+
+        width = self.config.subclip.width
+        return adjust_width(tensor, width)
+

 def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
     return list(Path(directory).glob(f"*{extension}"))
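A minimal usage sketch of the new dataset API introduced above — not part of this commit; the directory path and the subclip settings are illustrative values only:

from batdetect2.train.dataset import (
    DatasetConfig,
    LabeledDataset,
    SubclipConfig,
)

# Hypothetical example: build a training dataset from a folder of .nc files,
# cropping every example to a fixed 512-pixel-wide random subclip.
config = DatasetConfig(subclip=SubclipConfig(width=512, random=True))
dataset = LabeledDataset.from_directory(
    "path/to/preprocessed_examples",  # illustrative path
    config=config,
    augment=True,
    subclip=True,
)
example = dataset[0]  # TrainExample with spectrogram and heatmap tensors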
@ -1,16 +1,56 @@
+from typing import List
+
 import numpy as np
-from sklearn.metrics import (
-    accuracy_score,
-    auc,
-    balanced_accuracy_score,
-    roc_curve,
-)
+from sklearn.metrics import auc, roc_curve
+from soundevent import data
+from soundevent.evaluation import match_geometries
+
+
+def match_predictions_and_annotations(
+    clip_annotation: data.ClipAnnotation,
+    clip_prediction: data.ClipPrediction,
+) -> List[data.Match]:
+    annotated_sound_events = [
+        sound_event_annotation
+        for sound_event_annotation in clip_annotation.sound_events
+        if sound_event_annotation.sound_event.geometry is not None
+    ]
+
+    predicted_sound_events = [
+        sound_event_prediction
+        for sound_event_prediction in clip_prediction.sound_events
+        if sound_event_prediction.sound_event.geometry is not None
+    ]
+
+    annotated_geometries: List[data.Geometry] = [
+        sound_event.sound_event.geometry
+        for sound_event in annotated_sound_events
+        if sound_event.sound_event.geometry is not None
+    ]
+
+    predicted_geometries: List[data.Geometry] = [
+        sound_event.sound_event.geometry
+        for sound_event in predicted_sound_events
+        if sound_event.sound_event.geometry is not None
+    ]
+
+    matches = []
+    for id1, id2, affinity in match_geometries(
+        annotated_geometries,
+        predicted_geometries,
+    ):
+        target = annotated_sound_events[id1] if id1 is not None else None
+        source = predicted_sound_events[id2] if id2 is not None else None
+        matches.append(
+            data.Match(source=source, target=target, affinity=affinity)
+        )
+
+    return matches


 def compute_error_auc(op_str, gt, pred, prob):

     # classification error
-    pred_int = (pred > prob).astype(np.int)
+    pred_int = (pred > prob).astype(np.int32)
     class_acc = (pred_int == gt).mean() * 100.0

     # ROC - area under curve
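The Match objects built above pair each predicted geometry with an annotated one (or with None when unmatched). A rough sketch — not part of this commit, with an illustrative helper name and threshold — of how the list could be summarised into detection counts:

from typing import List, Tuple

from soundevent import data


def summarise_matches(
    matches: List[data.Match],
    affinity_threshold: float = 0.5,  # illustrative value
) -> Tuple[int, int, int]:
    # Count true positives, false positives and false negatives, treating
    # `source` as the prediction and `target` as the annotation, as in
    # match_predictions_and_annotations above.
    tp = sum(
        1
        for match in matches
        if match.source is not None
        and match.target is not None
        and match.affinity >= affinity_threshold
    )
    fp = sum(
        1
        for match in matches
        if match.source is not None
        and (match.target is None or match.affinity < affinity_threshold)
    )
    fn = sum(
        1
        for match in matches
        if match.target is not None
        and (match.source is None or match.affinity < affinity_threshold)
    )
    return tp, fp, fn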
@ -25,7 +65,6 @@ def compute_error_auc(op_str, gt, pred, prob):


 def calc_average_precision(recall, precision):
-
     precision[np.isnan(precision)] = 0
     recall[np.isnan(recall)] = 0

@ -91,7 +130,6 @@ def compute_pre_rec(
     pred_class = []
     file_ids = []
     for pid, pp in enumerate(preds):
-
         # filter predicted calls that are too near the start or end of the file
         file_dur = gts[pid]["duration"]
         valid_inds = (pp["start_times"] >= ignore_start_end) & (
@ -141,7 +179,6 @@ def compute_pre_rec(
     gt_generic_class = []
     num_positives = 0
     for gg in gts:
-
         # filter ground truth calls that are too near the start or end of the file
         file_dur = gg["duration"]
         valid_inds = (gg["start_times"] >= ignore_start_end) & (
@ -205,7 +242,6 @@ def compute_pre_rec(

         # valid detection that has not already been assigned
         if valid_det and (gt_assigned[gt_id][det_ind] == 0):
-
             count_as_true_pos = True
             if eval_mode == "top_class" and (
                 gt_class[gt_id][det_ind] != pred_class[ind]
@ -3,7 +3,9 @@ from typing import Optional
 import pytorch_lightning as L
 import torch
 from pydantic import Field
-from torch import optim
+from torch.optim.adam import Adam
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from torch.utils.data import DataLoader

 from batdetect2.configs import BaseConfig
 from batdetect2.models import (
@ -13,11 +15,20 @@ from batdetect2.models import (
     get_backbone,
 )
 from batdetect2.models.typing import ModelOutput
-from batdetect2.post_process import PostprocessConfig
+from batdetect2.post_process import (
+    PostprocessConfig,
+    postprocess_model_outputs,
+)
 from batdetect2.preprocess import PreprocessingConfig
-from batdetect2.train.dataset import TrainExample
+from batdetect2.train.dataset import LabeledDataset, TrainExample
+from batdetect2.train.evaluate import match_predictions_and_annotations
 from batdetect2.train.losses import LossConfig, compute_loss
-from batdetect2.train.targets import TargetConfig
+from batdetect2.train.targets import (
+    TargetConfig,
+    build_decoder,
+    build_encoder,
+    get_class_names,
+)


 class OptimizerConfig(BaseConfig):
@ -55,10 +66,8 @@ class DetectorModel(L.LightningModule):
         self.save_hyperparameters()

         size = self.config.preprocessing.spectrogram.size
-        self.backbone = get_backbone(
-            input_height=int(size.height * (size.resize_factor or 1)),
-            config=self.config.backbone,
-        )
+        assert size is not None
+        self.backbone = get_backbone(self.config.backbone)

         self.classifier = ClassifierHead(
             num_classes=len(self.config.targets.classes),
@ -74,6 +83,13 @@ class DetectorModel(L.LightningModule):

         self.validation_predictions = []

+        self.class_names = get_class_names(self.config.targets.classes)
+        self.encoder = build_encoder(
+            self.config.targets.classes,
+            replacement_rules=self.config.targets.replace,
+        )
+        self.decoder = build_decoder(self.config.targets.classes)
+
     def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
         features = self.backbone(spec)
         detection_probs, classification_probs = self.classifier(features)
@ -117,14 +133,29 @@ class DetectorModel(L.LightningModule):
         self.log("val/loss/classification", losses.total, logger=True)

         dataloaders = self.trainer.val_dataloaders
-        print(dataloaders)
+        assert isinstance(dataloaders, DataLoader)
+        dataset = dataloaders.dataset
+        assert isinstance(dataset, LabeledDataset)
+        clip_annotation = dataset.get_clip_annotation(batch_idx)
+
+        clip_prediction = postprocess_model_outputs(
+            outputs,
+            clips=[clip_annotation.clip],
+            classes=self.class_names,
+            decoder=self.decoder,
+            config=self.config.postprocessing,
+        )[0]
+
+        self.validation_predictions.extend(
+            match_predictions_and_annotations(clip_annotation, clip_prediction)
+        )
+
+    def on_validation_epoch_end(self) -> None:
+        print(len(self.validation_predictions))
+        self.validation_predictions.clear()

     def configure_optimizers(self):
         conf = self.config.train.optimizer
-        optimizer = optim.Adam(self.parameters(), lr=conf.learning_rate)
-        scheduler = optim.lr_scheduler.CosineAnnealingLR(
-            optimizer,
-            T_max=conf.t_max,
-        )
+        optimizer = Adam(self.parameters(), lr=conf.learning_rate)
+        scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
         return [optimizer], [scheduler]
@ -101,7 +101,11 @@ def generate_train_example(
     return dataset.assign_attrs(
         title=f"Training example for {clip_annotation.uuid}",
         config=config.model_dump_json(),
-        clip_annotation=clip_annotation.model_dump_json(),
+        clip_annotation=clip_annotation.model_dump_json(
+            exclude_none=True,
+            exclude_defaults=True,
+            exclude_unset=True,
+        ),
     )

@ -1,11 +0,0 @@
-from pathlib import Path
-
-from batdetect2.configs import load_config
-from batdetect2.data import DatasetsConfig, load_datasets
-
-
-def test_can_load_dataset_configs():
-    root = Path(__file__).parent.parent
-    path = root / "conf.yaml"
-    config = load_config(path, schema=DatasetsConfig, field="datasets")
-    load_datasets(config)
@ -4,7 +4,7 @@ from pathlib import Path

 from soundevent import data

-from batdetect2.data.compat import load_annotation_project
+from batdetect2.compat.data import load_annotation_project_from_dir

 ROOT_DIR = Path(__file__).parent.parent.parent

@ -12,7 +12,7 @@ ROOT_DIR = Path(__file__).parent.parent.parent
 def test_load_example_annotation_project():
     path = ROOT_DIR / "example_data" / "anns"
     audio_dir = ROOT_DIR / "example_data" / "audio"
-    project = load_annotation_project(path, audio_dir=audio_dir)
+    project = load_annotation_project_from_dir(path, audio_dir=audio_dir)
     assert isinstance(project, data.AnnotationProject)
     assert project.name == str(path)
     assert len(project.clip_annotations) == 3
@ -88,7 +88,7 @@ def test_spectrogram_generation_hasnt_changed(
     scale = preprocess.AmplitudeScaleConfig()

     config = preprocess.SpectrogramConfig(
-        fft=preprocess.FFTConfig(
+        stft=preprocess.STFTConfig(
             window_overlap=fft_overlap,
             window_duration=fft_win_length,
         ),
tests/test_models/__init__.py  (new file, 0 lines)

tests/test_models/test_inputs.py  (new file, 34 lines)
@ -0,0 +1,34 @@
+import torch
+from hypothesis import given
+from hypothesis import strategies as st
+
+from batdetect2.models import ModelConfig, ModelType, get_backbone
+
+
+@given(
+    input_width=st.integers(min_value=50, max_value=1500),
+    input_height=st.integers(min_value=1, max_value=16),
+    model_type=st.sampled_from(ModelType),
+)
+def test_model_can_process_spectrograms_of_any_width(
+    input_width,
+    input_height,
+    model_type,
+):
+    # Input height must be divisible by 8
+    input_height = 8 * input_height
+
+    input = torch.rand([1, 1, input_height, input_width])
+
+    model = get_backbone(
+        config=ModelConfig(
+            name=model_type,  # type: ignore
+            input_height=input_height,
+        ),
+    )
+
+    output = model(input)
+    assert output.shape[0] == 1
+    assert output.shape[1] == model.out_channels
+    assert output.shape[2] == input_height
+    assert output.shape[3] == input_width
tests/test_preprocessing/test_arrays.py  (new file, 23 lines)
@ -0,0 +1,23 @@
+import numpy as np
+
+from batdetect2.preprocess.arrays import adjust_width, extend_width
+
+
+def test_extend_width():
+    array = np.random.random([1, 1, 128, 100])
+
+    extended = extend_width(array, 100)
+
+    assert extended.shape == (1, 1, 128, 200)
+
+
+def test_can_adjust_short_width():
+    array = np.random.random([1, 1, 128, 100])
+    extended = adjust_width(array, 512)
+    assert extended.shape == (1, 1, 128, 512)
+
+
+def test_can_adjust_long_width():
+    array = np.random.random([1, 1, 128, 512])
+    extended = adjust_width(array, 256)
+    assert extended.shape == (1, 1, 128, 256)
@ -8,14 +8,13 @@ from soundevent import arrays

 from batdetect2.preprocess.audio import AudioConfig, load_file_audio
 from batdetect2.preprocess.spectrogram import (
-    FFTConfig,
+    STFTConfig,
     FrequencyConfig,
     SpecSizeConfig,
     SpectrogramConfig,
     compute_spectrogram,
     duration_to_spec_width,
     get_spectrogram_resolution,
-    pad_audio,
     spec_width_to_samples,
     stft,
 )
@ -68,45 +67,6 @@ def test_can_estimate_correctly_spectrogram_width_from_duration(
     )


-@settings(
-    suppress_health_check=[HealthCheck.function_scoped_fixture],
-    deadline=400,
-)
-@given(
-    duration=st.floats(min_value=0.1, max_value=1.0),
-    window_duration=st.floats(min_value=0.001, max_value=0.01),
-    window_overlap=st.floats(min_value=0.2, max_value=0.9),
-    samplerate=st.integers(min_value=256_000, max_value=512_000),
-    divide_factor=st.integers(min_value=16, max_value=64),
-)
-def test_can_pad_audio_to_adjust_spectrogram_width(
-    duration: float,
-    window_duration: float,
-    window_overlap: float,
-    samplerate: int,
-    divide_factor: int,
-    wav_factory: Callable[..., Path],
-):
-    path = wav_factory(duration=duration, samplerate=samplerate)
-
-    audio = load_file_audio(
-        path,
-        # NOTE: Dont resample nor adjust duration to test if the width
-        # estimation works on all scenarios
-        config=AudioConfig(resample=None, duration=None),
-    )
-
-    audio = pad_audio(
-        audio,
-        window_duration=window_duration,
-        window_overlap=window_overlap,
-        divide_factor=divide_factor,
-    )
-
-    spectrogram = stft(audio, window_duration, window_overlap)
-    assert spectrogram.sizes["time"] % divide_factor == 0
-
-
 def test_can_estimate_spectrogram_resolution(
     wav_factory: Callable[..., Path],
 ):
@ -120,7 +80,7 @@ def test_can_estimate_spectrogram_resolution(
     )

     config = SpectrogramConfig(
-        fft=FFTConfig(),
+        stft=STFTConfig(),
         size=SpecSizeConfig(height=256, resize_factor=0.5),
         frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
     )
tests/test_preprocessing/test_tensors.py  (new file, 42 lines)
@ -0,0 +1,42 @@
+import numpy as np
+import torch
+
+from batdetect2.preprocess.tensors import adjust_width, make_width_divisible
+
+
+def test_width_is_divisible_after_adjustment():
+    tensor = torch.rand([1, 1, 128, 374])
+    adjusted = make_width_divisible(tensor, 32)
+    assert adjusted.shape[-1] % 32 == 0
+    assert adjusted.shape == (1, 1, 128, 384)
+
+
+def test_non_last_axis_is_divisible_after_adjustment():
+    tensor = torch.rand([1, 1, 77, 124])
+    adjusted = make_width_divisible(tensor, 32, axis=-2)
+    assert adjusted.shape[-2] % 32 == 0
+    assert adjusted.shape == (1, 1, 96, 124)
+
+
+def test_make_width_divisible_can_handle_numpy_array():
+    array = np.random.random([1, 1, 128, 374])
+    adjusted = make_width_divisible(array, 32)
+    assert adjusted.shape[-1] % 32 == 0
+    assert adjusted.shape == (1, 1, 128, 384)
+    assert isinstance(adjusted, torch.Tensor)
+
+
+def test_adjust_last_axis_width_by_default():
+    tensor = torch.rand([1, 1, 128, 374])
+    adjusted = adjust_width(tensor, 512)
+    assert adjusted.shape == (1, 1, 128, 512)
+    assert (tensor == adjusted[:, :, :, :374]).all()
+    assert (adjusted[:, :, :, 374:] == 0).all()
+
+
+def test_can_adjust_second_to_last_axis():
+    tensor = torch.rand([1, 1, 89, 512])
+    adjusted = adjust_width(tensor, 128, axis=-2)
+    assert adjusted.shape == (1, 1, 128, 512)
+    assert (tensor == adjusted[:, :, :89, :]).all()
+    assert (adjusted[:, :, 89:, :] == 0).all()
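These tests pin down the expected behaviour of adjust_width and make_width_divisible: an axis is zero-padded on the right up to the target width (or cropped when too long), and make_width_divisible grows an axis to the next multiple of the given factor. A rough torch-only equivalent — not the actual batdetect2.preprocess.tensors implementation, and ignoring the numpy-input handling exercised above — might look like:

import torch
import torch.nn.functional as F


def adjust_width_sketch(
    tensor: torch.Tensor, width: int, axis: int = -1
) -> torch.Tensor:
    # Crop, or right-pad with zeros, so `axis` has exactly `width` elements.
    dim = axis if axis >= 0 else tensor.dim() + axis
    current = tensor.shape[dim]
    if current > width:
        return tensor.narrow(dim, 0, width)
    if current < width:
        # F.pad lists padding from the last dimension backwards, two values
        # per dimension, so only `dim` receives right-padding here.
        pad = [0, 0] * (tensor.dim() - dim - 1) + [0, width - current]
        return F.pad(tensor, pad)
    return tensor


def make_width_divisible_sketch(
    tensor: torch.Tensor, factor: int, axis: int = -1
) -> torch.Tensor:
    # Grow `axis` to the next multiple of `factor`, e.g. 374 -> 384 for 32.
    current = tensor.shape[axis]
    target = ((current + factor - 1) // factor) * factor
    return adjust_width_sketch(tensor, target, axis=axis)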
@ -1,14 +1,14 @@
 from collections.abc import Callable

 import numpy as np
+import pytest
 import xarray as xr
-from soundevent import data
+from soundevent import arrays, data

 from batdetect2.train.augmentations import (
     add_echo,
-    adjust_dataset_width,
     mix_examples,
-    select_random_subclip,
+    select_subclip,
 )
 from batdetect2.train.preprocess import (
     TrainPreprocessingConfig,
@ -23,11 +23,9 @@ def test_mix_examples(
     recording2 = recording_factory()

     clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
-
     clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8)

     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
-
     clip_annotation_2 = data.ClipAnnotation(clip=clip2)

     config = TrainPreprocessingConfig()
@ -43,6 +41,36 @@ def test_mix_examples(
     assert mixed["class"].shape == example1["class"].shape


+@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
+@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
+def test_mix_examples_of_different_durations(
+    recording_factory: Callable[..., data.Recording],
+    duration1: float,
+    duration2: float,
+):
+    recording1 = recording_factory()
+    recording2 = recording_factory()
+
+    clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1)
+    clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2)
+
+    clip_annotation_1 = data.ClipAnnotation(clip=clip1)
+    clip_annotation_2 = data.ClipAnnotation(clip=clip2)
+
+    config = TrainPreprocessingConfig()
+
+    example1 = generate_train_example(clip_annotation_1, config)
+    example2 = generate_train_example(clip_annotation_2, config)
+
+    mixed = mix_examples(example1, example2, config=config.preprocessing)
+
+    # Check the spectrogram has the expected duration
+    step = arrays.get_dim_step(mixed["spectrogram"], "time")
+    start, stop = arrays.get_dim_range(mixed["spectrogram"], "time")
+    assert start == 0
+    assert np.isclose(stop + step, duration1, atol=2 * step)
+
+
 def test_add_echo(
     recording_factory: Callable[..., data.Recording],
 ):
@ -67,70 +95,23 @@ def test_selected_random_subclip_has_the_correct_width(
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
     config = TrainPreprocessingConfig()
     original = generate_train_example(clip_annotation_1, config)
-    subclip = select_random_subclip(original, width=100)
+    subclip = select_subclip(original, width=100)

     assert subclip["spectrogram"].shape[1] == 100


-def test_adjust_dataset_width():
-    height = 128
-    width = 512
-    samplerate = 48_000
-
-    times = np.linspace(0, 1, width)
-    audio_times = np.linspace(0, 1, samplerate)
-    frequency = np.linspace(0, 24_000, height)
-
-    width_subset = 356
-    audio_width_subset = int(samplerate * width_subset / width)
-
-    times_subset = times[:width_subset]
-    audio_times_subset = audio_times[:audio_width_subset]
-    dimensions = ["width", "height"]
-    class_names = [f"species_{i}" for i in range(17)]
-
-    spectrogram = np.random.random([height, width_subset])
-    sizes = np.random.random([len(dimensions), height, width_subset])
-    classes = np.random.random([len(class_names), height, width_subset])
-    audio = np.random.random([int(samplerate * width_subset / width)])
-
-    dataset = xr.Dataset(
-        data_vars={
-            "audio": (("audio_time",), audio),
-            "spectrogram": (("frequency", "time"), spectrogram),
-            "sizes": (("dimension", "frequency", "time"), sizes),
-            "classes": (("class", "frequency", "time"), classes),
-        },
-        coords={
-            "audio_time": audio_times_subset,
-            "time": times_subset,
-            "frequency": frequency,
-            "dimension": dimensions,
-            "class": class_names,
-        },
-    )
-
-    adjusted = adjust_dataset_width(dataset, width=width)
-
-    # Spectrogram was adjusted correctly
-    assert np.isclose(adjusted["spectrogram"].time, times).all()
-    assert (adjusted["spectrogram"].frequency == frequency).all()
-
-    # Sizes was adjusted correctly
-    assert np.isclose(adjusted["sizes"].time, times).all()
-    assert (adjusted["sizes"].frequency == frequency).all()
-    assert list(adjusted["sizes"].dimension.values) == dimensions
-
-    # Sizes was adjusted correctly
-    assert np.isclose(adjusted["classes"].time, times).all()
-    assert (adjusted["sizes"].frequency == frequency).all()
-    assert list(adjusted["classes"]["class"].values) == class_names
-
-    # Audio time was adjusted corretly
-    assert np.isclose(
-        len(adjusted["audio"].audio_time), len(audio_times), atol=2
-    )
-    assert np.isclose(
-        adjusted["audio"].audio_time[-1], audio_times[-1], atol=1e-3
-    )
+def test_add_echo_after_subclip(
+    recording_factory: Callable[..., data.Recording],
+):
+    recording1 = recording_factory(duration=2)
+    clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
+    clip_annotation_1 = data.ClipAnnotation(clip=clip1)
+    config = TrainPreprocessingConfig()
+    original = generate_train_example(clip_annotation_1, config)
+
+    assert original.sizes["time"] > 512
+
+    subclip = select_subclip(original, width=512)
+    with_echo = add_echo(subclip)
+
+    assert with_echo.sizes["time"] == 512
@ -3,7 +3,6 @@ from pathlib import Path
 import numpy as np
 import xarray as xr
 from soundevent import data
-from soundevent.types import ClassMapper

 from batdetect2.train.labels import generate_heatmaps

@ -23,16 +22,6 @@ clip = data.Clip(
 )


-class Mapper(ClassMapper):
-    class_labels = ["bat", "cat"]
-
-    def encode(self, sound_event_annotation: data.SoundEventAnnotation) -> str:
-        return "bat"
-
-    def decode(self, label: str) -> list:
-        return [data.Tag(term=data.term_from_key("species"), value="bat")]
-
-
 def test_generated_heatmaps_have_correct_dimensions():
     spec = xr.DataArray(
         data=np.random.rand(100, 100),
@ -57,12 +46,11 @@ def test_generated_heatmaps_have_correct_dimensions():
         ],
     )

-    class_mapper = Mapper()
-
     detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
         clip_annotation.sound_events,
         spec,
-        class_mapper,
+        class_names=["bat", "cat"],
+        encoder=lambda _: "bat",
     )

     assert isinstance(detection_heatmap, xr.DataArray)
@ -107,11 +95,13 @@ def test_generated_heatmap_are_non_zero_at_correct_positions():
         ],
     )

-    class_mapper = Mapper()
     detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
         clip_annotation.sound_events,
         spec,
-        class_mapper,
+        class_names=["bat", "cat"],
+        encoder=lambda _: "bat",
+        time_scale=1,
+        frequency_scale=1,
     )
     assert size_heatmap.sel(time=10, frequency=10, dimension="width") == 10
     assert size_heatmap.sel(time=10, frequency=10, dimension="height") == 10