Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 14:41:58 +02:00)

Commit 9cf159efff: Reworking model creation
Parent: 36c90a600f
@@ -1,10 +1,13 @@
 from batdetect2.preprocess import (
+    AmplitudeScaleConfig,
     AudioConfig,
     FFTConfig,
     FrequencyConfig,
-    PcenConfig,
+    LogScaleConfig,
+    PcenScaleConfig,
     PreprocessingConfig,
     ResampleConfig,
+    Scales,
     SpecSizeConfig,
     SpectrogramConfig,
 )

@@ -17,12 +20,12 @@ from batdetect2.train.preprocess import (
 )


-def get_spectrogram_scale(scale: str):
+def get_spectrogram_scale(scale: str) -> Scales:
     if scale == "pcen":
-        return PcenConfig()
+        return PcenScaleConfig()
     if scale == "log":
-        return "log"
-    return None
+        return LogScaleConfig()
+    return AmplitudeScaleConfig()


 def get_preprocessing_config(params: dict) -> PreprocessingConfig:
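The reworked helper now always returns a tagged config object instead of mixing a config instance, a bare string, and None. A self-contained sketch of that pattern (the dataclasses below are stand-ins, not the real batdetect2 configs):

from dataclasses import dataclass
from typing import Union

@dataclass
class LogScale:
    name: str = "log"

@dataclass
class PcenScale:
    name: str = "pcen"
    gain: float = 0.98

@dataclass
class AmplitudeScale:
    name: str = "amplitude"

Scale = Union[LogScale, PcenScale, AmplitudeScale]

def scale_from_param(scale: str) -> Scale:
    # Every branch yields an object with a `name` discriminator,
    # so callers can switch on `.name` rather than on strings/None.
    if scale == "pcen":
        return PcenScale()
    if scale == "log":
        return LogScale()
    return AmplitudeScale()

print(scale_from_param("log").name)  # -> log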
@@ -1,11 +1,54 @@
-from batdetect2.models.feature_extractors import (
+from enum import Enum
+
+from batdetect2.configs import BaseConfig
+from batdetect2.models.backbones import (
     Net2DFast,
     Net2DFastNoAttn,
     Net2DFastNoCoordConv,
 )
+from batdetect2.models.heads import BBoxHead, ClassifierHead
+from batdetect2.models.typing import BackboneModel

 __all__ = [
+    "get_backbone",
     "Net2DFast",
     "Net2DFastNoAttn",
     "Net2DFastNoCoordConv",
+    "ModelType",
+    "BBoxHead",
+    "ClassifierHead",
 ]
+
+
+class ModelType(str, Enum):
+    Net2DFast = "Net2DFast"
+    Net2DFastNoAttn = "Net2DFastNoAttn"
+    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
+
+
+class ModelConfig(BaseConfig):
+    name: ModelType = ModelType.Net2DFast
+    num_features: int = 128
+
+
+def get_backbone(
+    config: ModelConfig,
+    input_height: int = 128,
+) -> BackboneModel:
+    if config.name == ModelType.Net2DFast:
+        return Net2DFast(
+            input_height=input_height,
+            num_features=config.num_features,
+        )
+    elif config.name == ModelType.Net2DFastNoAttn:
+        return Net2DFastNoAttn(
+            num_features=config.num_features,
+            input_height=input_height,
+        )
+    elif config.name == ModelType.Net2DFastNoCoordConv:
+        return Net2DFastNoCoordConv(
+            num_features=config.num_features,
+            input_height=input_height,
+        )
+    else:
+        raise ValueError(f"Unknown model type: {config.name}")
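A hedged usage sketch of the new factory (not part of the commit): ModelConfig, ModelType and get_backbone are the names added above, while the spectrogram shape and feature count are illustrative assumptions.

import torch
from batdetect2.models import ModelConfig, ModelType, get_backbone

# Build a backbone from a declarative config instead of hard-coding a class.
backbone = get_backbone(
    ModelConfig(name=ModelType.Net2DFastNoAttn, num_features=64),
    input_height=128,
)

spec = torch.rand(1, 1, 128, 512)  # (batch, channels, freq bins, time bins)
features = backbone(spec)
print(features.shape, backbone.out_channels)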
@@ -12,7 +12,7 @@ from batdetect2.models.blocks import (
     ConvBlockUpStandard,
     SelfAttention,
 )
-from batdetect2.models.typing import FeatureExtractorModel
+from batdetect2.models.typing import BackboneModel

 __all__ = [
     "Net2DFast",

@@ -21,7 +21,7 @@ __all__ = [
 ]


-class Net2DFast(FeatureExtractorModel):
+class Net2DFast(BackboneModel):
     def __init__(
         self,
         num_features: int,

@@ -37,7 +37,7 @@ class Net2DFast(FeatureExtractorModel):
             1,
             self.num_features // 4,
             self.input_height,
-            k_size=3,
+            kernel_size=3,
             pad_size=1,
             stride=1,
         )

@@ -45,7 +45,7 @@ class Net2DFast(FeatureExtractorModel):
             self.num_features // 4,
             self.num_features // 2,
             self.input_height // 2,
-            k_size=3,
+            kernel_size=3,
             pad_size=1,
             stride=1,
         )

@@ -53,7 +53,7 @@ class Net2DFast(FeatureExtractorModel):
             self.num_features // 2,
             self.num_features,
             self.input_height // 4,
-            k_size=3,
+            kernel_size=3,
             pad_size=1,
             stride=1,
         )

@@ -100,6 +100,8 @@ class Net2DFast(FeatureExtractorModel):
         )
         self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)

+        self.out_channels = self.num_features // 4
+
     def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
         h, w = spec.shape[2:]
         h_pad = (32 - h % 32) % 32

@@ -135,7 +137,7 @@ class Net2DFast(FeatureExtractorModel):
         return F.relu_(self.conv_op_bn(self.conv_op(x)))


-class Net2DFastNoAttn(FeatureExtractorModel):
+class Net2DFastNoAttn(BackboneModel):
     def __init__(
         self,
         num_features: int,

@@ -151,7 +153,7 @@ class Net2DFastNoAttn(FeatureExtractorModel):
             1,
             self.num_features // 4,
             self.input_height,
-            k_size=3,
+            kernel_size=3,
             pad_size=1,
             stride=1,
         )

@@ -159,7 +161,7 @@ class Net2DFastNoAttn(FeatureExtractorModel):
             self.num_features // 4,
             self.num_features // 2,
             self.input_height // 2,
-            k_size=3,
+            kernel_size=3,
             pad_size=1,
             stride=1,
         )

@@ -167,7 +169,7 @@ class Net2DFastNoAttn(FeatureExtractorModel):
             self.num_features // 2,
             self.num_features,
             self.input_height // 4,
-            k_size=3,
+            kernel_size=3,
             pad_size=1,
             stride=1,
         )

@@ -210,6 +212,7 @@ class Net2DFastNoAttn(FeatureExtractorModel):
             padding=1,
         )
         self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
+        self.out_channels = self.num_features // 4

     def forward(self, spec: torch.Tensor) -> torch.Tensor:
         x1 = self.conv_dn_0(spec)

@@ -227,7 +230,7 @@ class Net2DFastNoAttn(FeatureExtractorModel):
         return F.relu_(self.conv_op_bn(self.conv_op(x)))


-class Net2DFastNoCoordConv(FeatureExtractorModel):
+class Net2DFastNoCoordConv(BackboneModel):
     def __init__(
         self,
         num_features: int,

@@ -242,21 +245,21 @@ class Net2DFastNoCoordConv(FeatureExtractorModel):
         self.conv_dn_0 = ConvBlockDownStandard(
             1,
             self.num_features // 4,
-            k_size=3,
+            kernel_size=3,
             pad_size=1,
             stride=1,
         )
         self.conv_dn_1 = ConvBlockDownStandard(
             self.num_features // 4,
             self.num_features // 2,
-            k_size=3,
+            kernel_size=3,
             pad_size=1,
             stride=1,
         )
         self.conv_dn_2 = ConvBlockDownStandard(
             self.num_features // 2,
             self.num_features,
-            k_size=3,
+            kernel_size=3,
             pad_size=1,
             stride=1,
         )

@@ -301,6 +304,7 @@ class Net2DFastNoCoordConv(FeatureExtractorModel):
             padding=1,
         )
         self.conv_op_bn = nn.BatchNorm2d(self.num_features // 4)
+        self.out_channels = self.num_features // 4

     def forward(self, spec: torch.Tensor) -> torch.Tensor:
         x1 = self.conv_dn_0(spec)
@@ -10,6 +10,8 @@ import torch
 import torch.nn.functional as F
 from torch import nn

+from batdetect2.configs import BaseConfig
+
 __all__ = [
     "SelfAttention",
     "ConvBlockDownCoordF",

@@ -19,22 +21,33 @@ __all__ = [
 ]


+class SelfAttentionConfig(BaseConfig):
+    temperature: float = 1.0
+    input_channels: int = 128
+    attention_channels: int = 128
+
+
 class SelfAttention(nn.Module):
     """Self-Attention module.

     This module implements self-attention mechanism.
     """

-    def __init__(self, ip_dim: int, att_dim: int):
+    def __init__(
+        self,
+        in_channels: int,
+        attention_channels: int,
+        temperature: float = 1.0,
+    ):
         super().__init__()

-        # Note, does not encode position information (absolute or realtive)
-        self.temperature = 1.0
-        self.att_dim = att_dim
-        self.key_fun = nn.Linear(ip_dim, att_dim)
-        self.val_fun = nn.Linear(ip_dim, att_dim)
-        self.que_fun = nn.Linear(ip_dim, att_dim)
-        self.pro_fun = nn.Linear(att_dim, ip_dim)
+        # Note, does not encode position information (absolute or relative)
+        self.temperature = temperature
+        self.att_dim = attention_channels
+        self.key_fun = nn.Linear(in_channels, attention_channels)
+        self.value_fun = nn.Linear(in_channels, attention_channels)
+        self.query_fun = nn.Linear(in_channels, attention_channels)
+        self.pro_fun = nn.Linear(attention_channels, in_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.squeeze(2).permute(0, 2, 1)

@@ -43,11 +56,11 @@ class SelfAttention(nn.Module):
             x, self.key_fun.weight.T
         ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
         query = torch.matmul(
-            x, self.que_fun.weight.T
-        ) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
+            x, self.query_fun.weight.T
+        ) + self.query_fun.bias.unsqueeze(0).unsqueeze(0)
         value = torch.matmul(
-            x, self.val_fun.weight.T
-        ) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)
+            x, self.value_fun.weight.T
+        ) + self.value_fun.bias.unsqueeze(0).unsqueeze(0)

         kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
             self.temperature * self.att_dim

@@ -63,6 +76,15 @@ class SelfAttention(nn.Module):
         return op


+class ConvBlockDownCoordFConfig(BaseConfig):
+    in_channels: int
+    out_channels: int
+    input_height: int
+    kernel_size: int = 3
+    pad_size: int = 1
+    stride: int = 1
+
+
 class ConvBlockDownCoordF(nn.Module):
     """Convolutional Block with Downsampling and Coord Feature.

@@ -72,27 +94,27 @@ class ConvBlockDownCoordF(nn.Module):

     def __init__(
         self,
-        in_chn: int,
-        out_chn: int,
-        ip_height: int,
-        k_size: int = 3,
+        in_channels: int,
+        out_channels: int,
+        input_height: int,
+        kernel_size: int = 3,
         pad_size: int = 1,
         stride: int = 1,
     ):
         super().__init__()

         self.coords = nn.Parameter(
-            torch.linspace(-1, 1, ip_height)[None, None, ..., None],
+            torch.linspace(-1, 1, input_height)[None, None, ..., None],
             requires_grad=False,
         )
         self.conv = nn.Conv2d(
-            in_chn + 1,
-            out_chn,
-            kernel_size=k_size,
+            in_channels + 1,
+            out_channels,
+            kernel_size=kernel_size,
             padding=pad_size,
             stride=stride,
         )
-        self.conv_bn = nn.BatchNorm2d(out_chn)
+        self.conv_bn = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])

@@ -102,6 +124,14 @@ class ConvBlockDownCoordF(nn.Module):
         return x


+class ConvBlockDownStandardConfig(BaseConfig):
+    in_channels: int
+    out_channels: int
+    kernel_size: int = 3
+    pad_size: int = 1
+    stride: int = 1
+
+
 class ConvBlockDownStandard(nn.Module):
     """Convolutional Block with Downsampling.

@@ -110,21 +140,21 @@ class ConvBlockDownStandard(nn.Module):

     def __init__(
         self,
-        in_chn,
-        out_chn,
-        k_size=3,
-        pad_size=1,
-        stride=1,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        pad_size: int = 1,
+        stride: int = 1,
     ):
         super(ConvBlockDownStandard, self).__init__()
         self.conv = nn.Conv2d(
-            in_chn,
-            out_chn,
-            kernel_size=k_size,
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
             padding=pad_size,
             stride=stride,
         )
-        self.conv_bn = nn.BatchNorm2d(out_chn)
+        self.conv_bn = nn.BatchNorm2d(out_channels)

     def forward(self, x):
         x = F.max_pool2d(self.conv(x), 2, 2)

@@ -132,6 +162,16 @@ class ConvBlockDownStandard(nn.Module):
         return x


+class ConvBlockUpFConfig(BaseConfig):
+    inp_channels: int
+    out_channels: int
+    input_height: int
+    kernel_size: int = 3
+    pad_size: int = 1
+    up_mode: str = "bilinear"
+    up_scale: Tuple[int, int] = (2, 2)
+
+
 class ConvBlockUpF(nn.Module):
     """Convolutional Block with Upsampling and Coord Feature.

@@ -141,10 +181,10 @@ class ConvBlockUpF(nn.Module):

     def __init__(
         self,
-        in_chn: int,
-        out_chn: int,
-        ip_height: int,
-        k_size: int = 3,
+        in_channels: int,
+        out_channels: int,
+        input_height: int,
+        kernel_size: int = 3,
         pad_size: int = 1,
         up_mode: str = "bilinear",
         up_scale: Tuple[int, int] = (2, 2),

@@ -154,15 +194,18 @@ class ConvBlockUpF(nn.Module):
         self.up_scale = up_scale
         self.up_mode = up_mode
         self.coords = nn.Parameter(
-            torch.linspace(-1, 1, ip_height * up_scale[0])[
+            torch.linspace(-1, 1, input_height * up_scale[0])[
                 None, None, ..., None
             ],
             requires_grad=False,
         )
         self.conv = nn.Conv2d(
-            in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size
+            in_channels + 1,
+            out_channels,
+            kernel_size=kernel_size,
+            padding=pad_size,
         )
-        self.conv_bn = nn.BatchNorm2d(out_chn)
+        self.conv_bn = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         op = F.interpolate(

@@ -181,6 +224,15 @@ class ConvBlockUpF(nn.Module):
         return op


+class ConvBlockUpStandardConfig(BaseConfig):
+    in_channels: int
+    out_channels: int
+    kernel_size: int = 3
+    pad_size: int = 1
+    up_mode: str = "bilinear"
+    up_scale: Tuple[int, int] = (2, 2)
+
+
 class ConvBlockUpStandard(nn.Module):
     """Convolutional Block with Upsampling.

@@ -189,9 +241,9 @@ class ConvBlockUpStandard(nn.Module):

     def __init__(
         self,
-        in_chn: int,
-        out_chn: int,
-        k_size: int = 3,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
         pad_size: int = 1,
         up_mode: str = "bilinear",
         up_scale: Tuple[int, int] = (2, 2),

@@ -200,9 +252,12 @@ class ConvBlockUpStandard(nn.Module):
         self.up_scale = up_scale
         self.up_mode = up_mode
         self.conv = nn.Conv2d(
-            in_chn, out_chn, kernel_size=k_size, padding=pad_size
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            padding=pad_size,
         )
-        self.conv_bn = nn.BatchNorm2d(out_chn)
+        self.conv_bn = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         op = F.interpolate(
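For reference, a short sketch (illustrative channel counts only, not code from the commit) of how the renamed keyword arguments on the building blocks read after this change:

from batdetect2.models.blocks import ConvBlockDownCoordF, SelfAttention

# The old in_chn/out_chn/ip_height/k_size names become explicit keywords.
block = ConvBlockDownCoordF(
    in_channels=1,
    out_channels=32,
    input_height=128,
    kernel_size=3,
    pad_size=1,
    stride=1,
)
attention = SelfAttention(in_channels=128, attention_channels=128, temperature=1.0)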
@@ -1,142 +0,0 @@
-from typing import Optional, Type
-
-import pytorch_lightning as L
-import torch
-import xarray as xr
-from soundevent import data
-from torch import nn, optim
-
-from batdetect2.data.labels import ClassMapper
-from batdetect2.data.preprocessing import (
-    PreprocessingConfig,
-    preprocess,
-)
-from batdetect2.models.feature_extractors import Net2DFast
-from batdetect2.models.post_process import (
-    PostprocessConfig,
-    postprocess_model_outputs,
-)
-from batdetect2.models.typing import FeatureExtractorModel, ModelOutput
-from batdetect2.train import losses
-from batdetect2.train.dataset import TrainExample
-
-
-class DetectorModel(L.LightningModule):
-    def __init__(
-        self,
-        class_mapper: ClassMapper,
-        feature_extractor_class: Type[FeatureExtractorModel] = Net2DFast,
-        learning_rate: float = 1e-3,
-        input_height: int = 128,
-        num_features: int = 32,
-        preprocessing_config: Optional[PreprocessingConfig] = None,
-        postprocessing_config: Optional[PostprocessConfig] = None,
-    ):
-        super().__init__()
-
-        preprocessing_config = preprocessing_config or PreprocessingConfig()
-        postprocessing_config = postprocessing_config or PostprocessConfig()
-
-        self.save_hyperparameters()
-
-        self.preprocessing_config = preprocessing_config
-        self.postprocessing_config = postprocessing_config
-        self.class_mapper = class_mapper
-        self.learning_rate = learning_rate
-        self.input_height = input_height
-        self.num_features = num_features
-        self.num_classes = class_mapper.num_classes
-
-        self.feature_extractor = feature_extractor_class(
-            input_height=input_height,
-            num_features=num_features,
-        )
-
-        self.classifier = nn.Conv2d(
-            self.feature_extractor.num_features // 4,
-            self.num_classes + 1,
-            kernel_size=1,
-            padding=0,
-        )
-
-        self.bbox = nn.Conv2d(
-            self.feature_extractor.num_features // 4,
-            2,
-            kernel_size=1,
-            padding=0,
-        )
-
-    def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
-        features = self.feature_extractor(spec)
-        classification_logits = self.classifier(features)
-        classification_probs = torch.softmax(classification_logits, dim=1)
-        detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)
-        return ModelOutput(
-            detection_probs=detection_probs,
-            size_preds=self.bbox(features),
-            class_probs=classification_probs[:, :-1],
-            features=features,
-        )
-
-    def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
-        return preprocess(
-            clip,
-            config=self.preprocessing_config,
-        )
-
-    def compute_clip_features(self, clip: data.Clip) -> torch.Tensor:
-        spectrogram = self.compute_spectrogram(clip)
-        return self.feature_extractor(
-            torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
-        )
-
-    def compute_clip_predictions(self, clip: data.Clip) -> data.ClipPrediction:
-        spectrogram = self.compute_spectrogram(clip)
-        spec_tensor = (
-            torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
-        )
-        outputs = self(spec_tensor)
-        return postprocess_model_outputs(
-            outputs,
-            [clip],
-            class_mapper=self.class_mapper,
-            config=self.postprocessing_config,
-        )[0]
-
-    def compute_loss(
-        self,
-        outputs: ModelOutput,
-        batch: TrainExample,
-    ) -> torch.Tensor:
-        detection_loss = losses.focal_loss(
-            outputs.detection_probs,
-            batch.detection_heatmap,
-        )
-
-        size_loss = losses.bbox_size_loss(
-            outputs.size_preds,
-            batch.size_heatmap,
-        )
-
-        valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
-        classification_loss = losses.focal_loss(
-            outputs.class_probs,
-            batch.class_heatmap,
-            valid_mask=valid_mask,
-        )
-
-        return detection_loss + size_loss + classification_loss
-
-    def training_step(  # type: ignore
-        self,
-        batch: TrainExample,
-    ):
-        outputs = self.forward(batch.spec)
-        loss = self.compute_loss(outputs, batch)
-        self.log("train_loss", loss)
-        return loss
-
-    def configure_optimizers(self):
-        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
-        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
-        return [optimizer], [scheduler]
batdetect2/models/heads.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+from typing import NamedTuple
+
+import torch
+from torch import nn
+
+__all__ = ["ClassifierHead"]
+
+
+class Output(NamedTuple):
+    detection: torch.Tensor
+    classification: torch.Tensor
+
+
+class ClassifierHead(nn.Module):
+    def __init__(self, num_classes: int, in_channels: int):
+        super().__init__()
+
+        self.num_classes = num_classes
+        self.in_channels = in_channels
+        self.classifier = nn.Conv2d(
+            self.in_channels,
+            # Add one to account for the background class
+            self.num_classes + 1,
+            kernel_size=1,
+            padding=0,
+        )
+
+    def forward(self, features: torch.Tensor) -> Output:
+        logits = self.classifier(features)
+        probs = torch.softmax(logits, dim=1)
+        detection_probs = probs[:, :-1].sum(dim=1, keepdim=True)
+        return Output(
+            detection=detection_probs,
+            classification=probs[:, :-1],
+        )
+
+
+class BBoxHead(nn.Module):
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.bbox = nn.Conv2d(
+            self.feature_extractor.out_channels,
+            2,
+            kernel_size=1,
+            padding=0,
+        )
+
+    def forward(self, features: torch.Tensor) -> torch.Tensor:
+        return self.bbox(features)
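A sketch, not code from the commit, of wiring the new ClassifierHead to a backbone; it mirrors the forward logic of the DetectorModel removed above, with an assumed class count and input shape:

import torch
from batdetect2.models import ModelConfig, get_backbone
from batdetect2.models.heads import ClassifierHead

backbone = get_backbone(ModelConfig(), input_height=128)
# The head sizes itself from the backbone's new out_channels attribute.
classifier = ClassifierHead(num_classes=17, in_channels=backbone.out_channels)

spec = torch.rand(1, 1, 128, 512)
features = backbone(spec)
detection_probs, class_probs = classifier(features)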
@@ -6,7 +6,7 @@ import torch.nn as nn

 __all__ = [
     "ModelOutput",
-    "FeatureExtractorModel",
+    "BackboneModel",
 ]


@@ -41,13 +41,16 @@ class ModelOutput(NamedTuple):
     """Tensor with intermediate features."""


-class FeatureExtractorModel(ABC, nn.Module):
+class BackboneModel(ABC, nn.Module):
     input_height: int
     """Height of the input spectrogram."""

     num_features: int
     """Dimension of the feature tensor."""

+    out_channels: int
+    """Number of output channels of the feature extractor."""
+
     @abstractmethod
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
         """Forward pass of the encoder model."""
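An illustrative subclass of the BackboneModel interface (assumed for demonstration, not part of the commit), showing how the new out_channels attribute is meant to be populated alongside input_height and num_features:

import torch
from torch import nn
from batdetect2.models.typing import BackboneModel

class TinyBackbone(BackboneModel):
    def __init__(self, input_height: int = 128, num_features: int = 32):
        super().__init__()
        self.input_height = input_height
        self.num_features = num_features
        # Downstream heads read this instead of recomputing num_features // 4.
        self.out_channels = num_features // 4
        self.conv = nn.Conv2d(1, self.out_channels, kernel_size=3, padding=1)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        return self.conv(spec)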
@@ -8,7 +8,6 @@ from pydantic import BaseModel, Field
 from soundevent import data
 from torch import nn

-from batdetect2.data.labels import ClassMapper
 from batdetect2.models.typing import ModelOutput

 __all__ = [

@@ -37,7 +36,8 @@ TagFunction = Callable[[int], List[data.Tag]]
 def postprocess_model_outputs(
     outputs: ModelOutput,
     clips: List[data.Clip],
-    class_mapper: ClassMapper,
+    classes: List[str],
+    decoder: Callable[[str], List[data.Tag]],
     config: PostprocessConfig,
 ) -> List[data.ClipPrediction]:
     """Postprocesses model outputs to generate clip predictions.

@@ -108,7 +108,8 @@ def postprocess_model_outputs(
             size_preds,
             class_probs,
             features,
-            class_mapper=class_mapper,
+            classes=classes,
+            decoder=decoder,
             min_freq=config.min_freq,
             max_freq=config.max_freq,
             detection_threshold=config.detection_threshold,

@@ -132,7 +133,8 @@ def compute_sound_events_from_outputs(
     size_preds: torch.Tensor,
     class_probs: torch.Tensor,
     features: torch.Tensor,
-    class_mapper: ClassMapper,
+    classes: List[str],
+    decoder: Callable[[str], List[data.Tag]],
     min_freq: int = 10000,
     max_freq: int = 120000,
     detection_threshold: float = DETECTION_THRESHOLD,

@@ -181,7 +183,8 @@ def compute_sound_events_from_outputs(
         predicted_tags: List[data.PredictedTag] = []

         for label_id, class_score in enumerate(class_prob):
-            corresponding_tags = class_mapper.inverse_transform(label_id)
+            class_name = classes[label_id]
+            corresponding_tags = decoder(class_name)
             predicted_tags.extend(
                 [
                     data.PredictedTag(
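A hedged sketch of the classes/decoder pair the reworked postprocessing signature expects; the class names are examples, and the tag mapping is assumed to be built elsewhere in the project (by whatever replaces the old ClassMapper):

from typing import Dict, List
from soundevent import data

# Assumed: a precomputed mapping from class name to soundevent tags.
class_tags: Dict[str, List[data.Tag]] = {}
classes: List[str] = ["Myotis myotis", "Pipistrellus pipistrellus"]

def decoder(class_name: str) -> List[data.Tag]:
    # Look the name up in the project's class-to-tag mapping;
    # returning an empty list keeps this stub runnable without it.
    return class_tags.get(class_name, [])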
@@ -12,9 +12,12 @@ from batdetect2.preprocess.audio import (
     load_clip_audio,
 )
 from batdetect2.preprocess.spectrogram import (
+    AmplitudeScaleConfig,
     FFTConfig,
     FrequencyConfig,
-    PcenConfig,
+    LogScaleConfig,
+    PcenScaleConfig,
+    Scales,
     SpecSizeConfig,
     SpectrogramConfig,
     compute_spectrogram,

@@ -26,7 +29,10 @@ __all__ = [
     "SpectrogramConfig",
     "FFTConfig",
     "FrequencyConfig",
-    "PcenConfig",
+    "PcenScaleConfig",
+    "LogScaleConfig",
+    "AmplitudeScaleConfig",
+    "Scales",
     "SpecSizeConfig",
     "PreprocessingConfig",
     "preprocess_audio_clip",
@@ -23,7 +23,18 @@ class FrequencyConfig(BaseConfig):
     min_freq: int = Field(default=10_000, gt=0)


-class PcenConfig(BaseConfig):
+class SpecSizeConfig(BaseConfig):
+    height: int = 256
+    resize_factor: Optional[float] = 0.5
+    divide_factor: Optional[int] = 32
+
+
+class LogScaleConfig(BaseConfig):
+    name: Literal["log"] = "log"
+
+
+class PcenScaleConfig(BaseConfig):
+    name: Literal["pcen"] = "pcen"
     time_constant: float = 0.4
     hop_length: int = 512
     gain: float = 0.98

@@ -31,19 +42,21 @@ class PcenConfig(BaseConfig):
     power: float = 0.5


-class SpecSizeConfig(BaseConfig):
-    height: int = 256
-    resize_factor: Optional[float] = 0.5
-    divide_factor: Optional[int] = 32
+class AmplitudeScaleConfig(BaseConfig):
+    name: Literal["amplitude"] = "amplitude"
+
+
+Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]


 class SpectrogramConfig(BaseConfig):
     fft: FFTConfig = Field(default_factory=FFTConfig)
     frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
-    scale: Union[Literal["log"], None, PcenConfig] = Field(
-        default_factory=PcenConfig
+    scale: Scales = Field(
+        default_factory=PcenScaleConfig,
+        discriminator="name",
     )
-    size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
+    size: SpecSizeConfig = Field(default_factory=SpecSizeConfig)
     denoise: bool = True
     max_scale: bool = False

@@ -55,7 +68,7 @@ def compute_spectrogram(
 ) -> xr.DataArray:
     config = config or SpectrogramConfig()

-    if config.size and config.size.divide_factor:
+    if config.size.divide_factor:
         # Need to pad the audio to make sure the spectrogram has a
         # width compatible with the divide factor
         wav = pad_audio(

@@ -84,7 +97,6 @@ def compute_spectrogram(
     if config.denoise:
         spec = denoise_spectrogram(spec)

-    if config.size:
     spec = resize_spectrogram(
         spec,
         height=config.size.height,

@@ -180,13 +192,13 @@ def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:

 def scale_spectrogram(
     spec: xr.DataArray,
-    scale: Union[Literal["log"], None, PcenConfig],
+    scale: Scales,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
-    if scale == "log":
+    if scale.name == "log":
         return scale_log(spec, dtype=dtype)

-    if isinstance(scale, PcenConfig):
+    if scale.name == "pcen":
         return scale_pcen(
             spec,
             time_constant=scale.time_constant,
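A sketch of what the new `name` discriminator enables (stand-in pydantic models, not the real batdetect2 classes, and assuming standard pydantic discriminated-union behaviour): a plain dict, for example loaded from a YAML config, is routed to the matching scale config during validation.

from typing import Literal, Union
from pydantic import BaseModel, Field

class LogScaleConfig(BaseModel):
    name: Literal["log"] = "log"

class PcenScaleConfig(BaseModel):
    name: Literal["pcen"] = "pcen"
    gain: float = 0.98

class AmplitudeScaleConfig(BaseModel):
    name: Literal["amplitude"] = "amplitude"

Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]

class SpectrogramConfig(BaseModel):
    # The "name" field decides which member of the union is constructed.
    scale: Scales = Field(default_factory=PcenScaleConfig, discriminator="name")

config = SpectrogramConfig(scale={"name": "log"})
print(type(config.scale).__name__)  # -> LogScaleConfig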
@@ -1,941 +0,0 @@
-"""Functions and dataloaders for training and testing the model."""
-
-import copy
-from typing import List, Optional, Tuple
-
-import librosa
-import numpy as np
-import torch
-import torch.nn.functional as F
-import torch.utils.data
-import torchaudio
-
-import batdetect2.utils.audio_utils as au
-from batdetect2.types import (
-    Annotation,
-    AudioLoaderAnnotationGroup,
-    AudioLoaderParameters,
-    FileAnnotation,
-)
-
-
-def generate_gt_heatmaps(
-    spec_op_shape: Tuple[int, int],
-    sampling_rate: float,
-    ann: AudioLoaderAnnotationGroup,
-    class_names: List[str],
-    fft_win_length: float,
-    fft_overlap: float,
-    max_freq: float,
-    min_freq: float,
-    resize_factor: float,
-    target_sigma: float,
-) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AudioLoaderAnnotationGroup]:
-    """Generate ground truth heatmaps from annotations.
-
-    Parameters
-    ----------
-    spec_op_shape : Tuple[int, int]
-        Shape of the input spectrogram.
-    sampling_rate : int
-        Sampling rate of the input audio in Hz.
-    ann : AnnotationGroup
-        Dictionary containing the annotation information.
-    params : HeatmapParameters
-        Parameters controlling the generation of the heatmaps.
-
-    Returns
-    -------
-    y_2d_det : np.ndarray
-        2D heatmap of the presence of an event.
-    y_2d_size : np.ndarray
-        2D heatmap of the size of the bounding box associated to event.
-    y_2d_classes : np.ndarray
-        3D array containing the ground-truth class probabilities for each
-        pixel.
-    ann_aug : AnnotationGroup
-        A dictionary containing the annotation information of the
-        annotations that are within the input spectrogram, augmented with
-        the x and y indices of their pixel location in the input spectrogram.
-    """
-    # spec may be resized on input into the network
-    num_classes = len(class_names)
-    op_height = spec_op_shape[0]
-    op_width = spec_op_shape[1]
-    freq_per_bin = (max_freq - min_freq) / op_height
-
-    # start and end times
-    x_pos_start = au.time_to_x_coords(
-        ann["start_times"],
-        sampling_rate,
-        fft_win_length,
-        fft_overlap,
-    )
-    x_pos_start = (resize_factor * x_pos_start).astype(np.int32)
-    x_pos_end = au.time_to_x_coords(
-        ann["end_times"],
-        sampling_rate,
-        fft_win_length,
-        fft_overlap,
-    )
-    x_pos_end = (resize_factor * x_pos_end).astype(np.int32)
-
-    # location on y axis i.e. frequency
-    y_pos_low = (ann["low_freqs"] - min_freq) / freq_per_bin
-    y_pos_low = (op_height - y_pos_low).astype(np.int32)
-    y_pos_high = (ann["high_freqs"] - min_freq) / freq_per_bin
-    y_pos_high = (op_height - y_pos_high).astype(np.int32)
-    bb_widths = x_pos_end - x_pos_start
-    bb_heights = y_pos_low - y_pos_high
-
-    # Only include annotations that are within the input spectrogram
-    valid_inds = np.where(
-        (x_pos_start >= 0)
-        & (x_pos_start < op_width)
-        & (y_pos_low >= 0)
-        & (y_pos_low < (op_height - 1))
-    )[0]
-
-    ann_aug: AudioLoaderAnnotationGroup = {
-        **ann,
-        "start_times": ann["start_times"][valid_inds],
-        "end_times": ann["end_times"][valid_inds],
-        "high_freqs": ann["high_freqs"][valid_inds],
-        "low_freqs": ann["low_freqs"][valid_inds],
-        "class_ids": ann["class_ids"][valid_inds],
-        "individual_ids": ann["individual_ids"][valid_inds],
-        "x_inds": x_pos_start[valid_inds],
-        "y_inds": y_pos_low[valid_inds],
-    }
-
-    # if the number of calls is only 1, then it is unique
-    # TODO would be better if we found these unique calls at the merging stage
-    if len(ann_aug["individual_ids"]) == 1:
-        ann_aug["individual_ids"][0] = 0
-
-    y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
-    y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
-
-    # num classes and "background" class
-    y_2d_classes: np.ndarray = np.zeros(
-        (num_classes + 1, op_height, op_width), dtype=np.float32
-    )
-
-    # create 2D ground truth heatmaps
-    for ii in valid_inds:
-        draw_gaussian(
-            y_2d_det[0, :],
-            (x_pos_start[ii], y_pos_low[ii]),
-            target_sigma,
-        )
-        y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
-        y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
-
-        cls_id = ann["class_ids"][ii]
-        if cls_id > -1:
-            draw_gaussian(
-                y_2d_classes[cls_id, :],
-                (x_pos_start[ii], y_pos_low[ii]),
-                target_sigma,
-            )
-
-    # be careful as this will have a 1.0 places where we have event but
-    # dont know gt class this will be masked in training anyway
-    y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0)
-    y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...]
-    y_2d_classes[np.isnan(y_2d_classes)] = 0.0
-
-    return y_2d_det, y_2d_size, y_2d_classes, ann_aug
-
-
-def draw_gaussian(
-    heatmap: np.ndarray,
-    center: Tuple[int, int],
-    sigmax: float,
-    sigmay: Optional[float] = None,
-) -> bool:
-    """Draw a 2D gaussian into the heatmap.
-
-    If the gaussian center is outside the heatmap, then the gaussian is not
-    drawn.
-
-    Parameters
-    ----------
-    heatmap : np.ndarray
-        The heatmap to draw into. Should be of shape (height, width).
-    center : Tuple[int, int]
-        The center of the gaussian in (x, y) format.
-    sigmax : float
-        The standard deviation of the gaussian in the x direction.
-    sigmay : Optional[float], optional
-        The standard deviation of the gaussian in the y direction. If None,
-        then sigmay = sigmax, by default None.
-
-    Returns
-    -------
-    bool
-        True if the gaussian was drawn, False if it was not (because
-        the center was outside the heatmap).
-
-
-    """
-    # center is (x, y)
-    # this edits the heatmap inplace
-
-    if sigmay is None:
-        sigmay = sigmax
-    tmp_size = np.maximum(sigmax, sigmay) * 3
-    mu_x = int(center[0] + 0.5)
-    mu_y = int(center[1] + 0.5)
-    w, h = heatmap.shape[0], heatmap.shape[1]
-    ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
-    br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
-
-    if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
-        return False
-
-    size = 2 * tmp_size + 1
-    x = np.arange(0, size, 1, np.float32)
-    y = x[:, np.newaxis]
-    x0 = y0 = size // 2
-    # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
-    g = np.exp(
-        -((x - x0) ** 2) / (2 * sigmax**2) - ((y - y0) ** 2) / (2 * sigmay**2)
-    )
-    g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
-    g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
-    img_x = max(0, ul[0]), min(br[0], h)
-    img_y = max(0, ul[1]), min(br[1], w)
-    heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum(
-        heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]],
-        g[g_y[0] : g_y[1], g_x[0] : g_x[1]],
-    )
-    return True
-
-
-def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:
-    """Pad array with -1s."""
-    return np.hstack((ip_array, np.ones(pad_size, dtype=np.int32) * -1))
-
-
-def warp_spec_aug(
-    spec: torch.Tensor,
-    ann: AudioLoaderAnnotationGroup,
-    stretch_squeeze_delta: float,
-) -> torch.Tensor:
-    """Warp spectrogram by randomly stretching and squeezing.
-
-    Parameters
-    ----------
-    spec: torch.Tensor
-        Spectrogram to warp.
-    ann: AnnotationGroup
-        Annotation group for the spectrogram. Must be provided to sync
-        the start and stop times with the spectrogram after warping.
-    stretch_squeeze_delta: float
-        Maximum amount to stretch or squeeze the spectrogram.
-
-    Returns
-    -------
-    torch.Tensor
-        Warped spectrogram.
-
-    Notes
-    -----
-    This function modifies the annotation group in place.
-    """
-    # Augment spectrogram by randomly stretch and squeezing
-    # NOTE this also changes the start and stop time in place
-
-    delta = stretch_squeeze_delta
-    op_size = (spec.shape[1], spec.shape[2])
-    resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
-    resize_amt = int(spec.shape[2] * resize_fract_r)
-
-    if resize_amt >= spec.shape[2]:
-        spec_r = torch.cat(
-            (
-                spec,
-                torch.zeros(
-                    (1, spec.shape[1], resize_amt - spec.shape[2]),
-                    dtype=spec.dtype,
-                ),
-            ),
-            dim=2,
-        )
-    else:
-        spec_r = spec[:, :, :resize_amt]
-
-    # Resize the spectrogram
-    spec = F.interpolate(
-        spec_r.unsqueeze(0),
-        size=op_size,
-        mode="bilinear",
-        align_corners=False,
-    ).squeeze(0)
-
-    # Update the start and stop times
-    ann["start_times"] *= 1.0 / resize_fract_r
-    ann["end_times"] *= 1.0 / resize_fract_r
-
-    return spec
-
-
-def mask_time_aug(
-    spec: torch.Tensor,
-    mask_max_time_perc: float,
-) -> torch.Tensor:
-    """Mask out random blocks of time.
-
-    Will randomly mask out a block of time in the spectrogram. The block
-    will be between 0.0 and `mask_max_time_perc` of the total time.
-    A random number of blocks will be masked out between 1 and 3.
-
-    Parameters
-    ----------
-    spec: torch.Tensor
-        Spectrogram to mask.
-    mask_max_time_perc: float
-        Maximum percentage of time to mask out.
-
-    Returns
-    -------
-    torch.Tensor
-        Spectrogram with masked out time blocks.
-
-    Notes
-    -----
-    This function is based on the implementation in::
-
-        SpecAugment: A Simple Data Augmentation Method for Automatic Speech
-        Recognition
-    """
-    fm = torchaudio.transforms.TimeMasking(
-        int(spec.shape[1] * mask_max_time_perc)
-    )
-    for _ in range(np.random.randint(1, 4)):
-        spec = fm(spec)
-    return spec
-
-
-def mask_freq_aug(
-    spec: torch.Tensor,
-    mask_max_freq_perc: float,
-) -> torch.Tensor:
-    """Mask out random blocks of frequency.
-
-    Will randomly mask out a block of frequency in the spectrogram. The block
-    will be between 0.0 and `mask_max_freq_perc` of the total frequency.
-    A random number of blocks will be masked out between 1 and 3.
-
-    Parameters
-    ----------
-    spec: torch.Tensor
-        Spectrogram to mask.
-    mask_max_freq_perc: float
-        Maximum percentage of frequency to mask out.
-
-    Returns
-    -------
-    torch.Tensor
-        Spectrogram with masked out frequency blocks.
-
-    Notes
-    -----
-    This function is based on the implementation in::
-
-        SpecAugment: A Simple Data Augmentation Method for Automatic Speech
-        Recognition
-    """
-    fm = torchaudio.transforms.FrequencyMasking(
-        int(spec.shape[1] * mask_max_freq_perc)
-    )
-    for _ in range(np.random.randint(1, 4)):
-        spec = fm(spec)
-    return spec
-
-
-def scale_vol_aug(
-    spec: torch.Tensor,
-    spec_amp_scaling: float,
-) -> torch.Tensor:
-    """Scale the volume of the spectrogram.
-
-    Parameters
-    ----------
-    spec: torch.Tensor
-        Spectrogram to scale.
-    spec_amp_scaling: float
-        Maximum scaling factor.
-
-    Returns
-    -------
-    torch.Tensor
-    """
-    return spec * np.random.random() * spec_amp_scaling
-
-
-def echo_aug(
-    audio: np.ndarray,
-    sampling_rate: float,
-    echo_max_delay: float,
-) -> np.ndarray:
-    """Add echo to audio.
-
-    Parameters
-    ----------
-    audio: np.ndarray
-        Audio to add echo to.
-    sampling_rate: float
-        Sampling rate of the audio.
-    echo_max_delay: float
-        Maximum delay of the echo in seconds.
-
-    Returns
-    -------
-    np.ndarray
-        Audio with echo added.
-    """
-    sample_offset = (
-        int(echo_max_delay * np.random.random() * sampling_rate) + 1
-    )
-    # NOTE: This seems to be wrong, as the echo should be added to the
-    # end of the audio, not the beginning.
-    audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
-    return audio
-
-
-def resample_aug(
-    audio: np.ndarray,
-    sampling_rate: float,
-    fft_win_length: float,
-    fft_overlap: float,
-    resize_factor: float,
-    spec_divide_factor: float,
-    spec_train_width: int,
-    aug_sampling_rates: List[int],
-) -> Tuple[np.ndarray, float, float]:
-    """Resample audio augmentation.
-
-    Will resample the audio to a random sampling rate from the list of
-    sampling rates in `aug_sampling_rates`.
-
-    Parameters
-    ----------
-    audio: np.ndarray
-        Audio to resample.
-    sampling_rate: float
-        Original sampling rate of the audio.
-    fft_win_length: float
-        Length of the FFT window in seconds.
-    fft_overlap: float
-        Amount of overlap between FFT windows.
-    resize_factor: float
-        Factor to resize the spectrogram by.
-    spec_divide_factor: float
-        Factor to divide the spectrogram by.
-    spec_train_width: int
-        Width of the spectrogram.
-    aug_sampling_rates: List[int]
-        List of sampling rates to resample to.
-
-    Returns
-    -------
-    audio : np.ndarray
-        Resampled audio.
-    sampling_rate : float
-        New sampling rate.
-    duration : float
-        Duration of the audio in seconds.
-    """
-    sampling_rate_old = sampling_rate
-    sampling_rate = np.random.choice(aug_sampling_rates)
-    audio = librosa.resample(
-        audio,
-        orig_sr=sampling_rate_old,
-        target_sr=sampling_rate,
-        res_type="polyphase",
-    )
-
-    audio = au.pad_audio(
-        audio,
-        sampling_rate,
-        fft_win_length,
-        fft_overlap,
-        resize_factor,
-        spec_divide_factor,
-        spec_train_width,
-    )
-    duration = audio.shape[0] / float(sampling_rate)
-    return audio, sampling_rate, duration
-
-
-def resample_audio(
-    num_samples: int,
-    sampling_rate: float,
-    audio2: np.ndarray,
-    sampling_rate2: float,
-) -> Tuple[np.ndarray, float]:
-    """Resample audio.
-
-    Parameters
-    ----------
-    num_samples: int
-        Expected number of samples for the output audio.
-    sampling_rate: float
-        Original sampling rate of the audio.
-    audio2: np.ndarray
-        Audio to resample.
-    sampling_rate2: float
-        Target sampling rate of the audio.
-
-    Returns
-    -------
-    audio2 : np.ndarray
-        Resampled audio.
-    sampling_rate2 : float
-        New sampling rate.
-    """
-    # resample to target sampling rate
-    if sampling_rate != sampling_rate2:
-        audio2 = librosa.resample(
-            audio2,
-            orig_sr=sampling_rate2,
-            target_sr=sampling_rate,
-            res_type="polyphase",
-        )
-        sampling_rate2 = sampling_rate
-
-    # pad or trim to the correct length
-    if audio2.shape[0] < num_samples:
-        audio2 = np.hstack(
-            (
-                audio2,
-                np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype),
-            )
-        )
-    elif audio2.shape[0] > num_samples:
-        audio2 = audio2[:num_samples]
-
-    return audio2, sampling_rate2
-
-
-def combine_audio_aug(
-    audio: np.ndarray,
-    sampling_rate: float,
-    ann: AudioLoaderAnnotationGroup,
-    audio2: np.ndarray,
-    sampling_rate2: float,
-    ann2: AudioLoaderAnnotationGroup,
-) -> Tuple[np.ndarray, AudioLoaderAnnotationGroup]:
-    """Combine two audio files.
-
-    Will combine two audio files by resampling them to the same sampling rate
-    and then combining them with a random weight. The annotations will be
-    combined by taking the union of the two sets of annotations.
-
-    Parameters
-    ----------
-    audio: np.ndarray
-        First Audio to combine.
-    sampling_rate: int
-        Sampling rate of the first audio.
-    ann: AnnotationGroup
-        Annotations for the first audio.
-    audio2: np.ndarray
-        Second Audio to combine.
-    sampling_rate2: int
-        Sampling rate of the second audio.
-    ann2: AnnotationGroup
-        Annotations for the second audio.
-
-    Returns
-    -------
-    audio : np.ndarray
-        Combined audio.
-    ann : AnnotationGroup
-        Combined annotations.
-    """
-    # resample so they are the same
-    audio2, sampling_rate2 = resample_audio(
-        audio.shape[0],
-        sampling_rate,
-        audio2,
-        sampling_rate2,
-    )
-
-    # # set mean and std to be the same
-    # audio2 = (audio2 - audio2.mean())
-    # audio2 = (audio2/audio2.std())*audio.std()
-    # audio2 = audio2 + audio.mean()
-
-    if (
-        ann.get("annotated", False)
-        and (ann2.get("annotated", False))
-        and (sampling_rate2 == sampling_rate)
-        and (audio.shape[0] == audio2.shape[0])
-    ):
-        comb_weight = 0.3 + np.random.random() * 0.4
-        audio = comb_weight * audio + (1 - comb_weight) * audio2
-        inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
-        for kk in ann.keys():
-            # when combining calls from different files, assume they come
-            # from different individuals
-            if kk == "individual_ids":
-                if (ann[kk] > -1).sum() > 0:
-                    ann2[kk][ann2[kk] > -1] += (
-                        np.max(ann[kk][ann[kk] > -1]) + 1
-                    )
-
-            if (kk != "class_id_file") and (kk != "annotated"):
-                ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
-
-    return audio, ann
-
-
-def _prepare_annotation(
-    annotation: Annotation,
-    class_names: List[str],
-) -> Annotation:
-    try:
-        class_id = class_names.index(annotation["class"])
-    except ValueError:
-        class_id = -1
-
-    ann: Annotation = {
-        **annotation,
-        "class_id": class_id,
-    }
-
-    if "individual" in ann:
-        ann["individual"] = int(ann["individual"])  # type: ignore
-
-    return ann
-
-
-def _prepare_file_annotation(
-    annotation: FileAnnotation,
-    class_names: List[str],
-    classes_to_ignore: List[str],
-) -> AudioLoaderAnnotationGroup:
-    annotations = [
-        _prepare_annotation(ann, class_names)
-        for ann in annotation["annotation"]
-        if ann["class"] not in classes_to_ignore
-    ]
-
-    try:
-        class_id_file = class_names.index(annotation["class_name"])
-    except ValueError:
-        class_id_file = -1
-
-    ret: AudioLoaderAnnotationGroup = {
-        "id": annotation["id"],
-        "annotated": annotation["annotated"],
-        "duration": annotation["duration"],
-        "issues": annotation["issues"],
-        "time_exp": annotation["time_exp"],
-        "class_name": annotation["class_name"],
-        "notes": annotation["notes"],
-        "annotation": annotations,
-        "start_times": np.array([ann["start_time"] for ann in annotations]),
-        "end_times": np.array([ann["end_time"] for ann in annotations]),
-        "high_freqs": np.array([ann["high_freq"] for ann in annotations]),
-        "low_freqs": np.array([ann["low_freq"] for ann in annotations]),
-        "class_ids": np.array(
-            [ann.get("class_id", -1) for ann in annotations]
-        ),
-        "individual_ids": np.array([ann["individual"] for ann in annotations]),
-        "class_id_file": class_id_file,
-    }
-
-    return ret
-
-
-class AudioLoader(torch.utils.data.Dataset):
-    """Main AudioLoader for training and testing."""
-
-    def __init__(
-        self,
-        data_anns_ip: List[FileAnnotation],
-        params: AudioLoaderParameters,
-        dataset_name: Optional[str] = None,
-        is_train: bool = False,
-        return_spec_for_viz: bool = False,
-    ):
-        self.is_train = is_train
-        self.params = params
-        self.return_spec_for_viz = return_spec_for_viz
-        self.data_anns: List[AudioLoaderAnnotationGroup] = [
-            _prepare_file_annotation(
-                ann,
-                params["class_names"],
-                params["classes_to_ignore"],
-            )
-            for ann in data_anns_ip
-        ]
-
-        ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
-        self.max_num_anns = 2 * np.max(
-            ann_cnt
-        )  # x2 because we may be combining files during training
-
-        print("\n")
-        if dataset_name is not None:
-            print("Dataset : " + dataset_name)
-        if self.is_train:
-            print("Split type : train")
-        else:
-            print("Split type : test")
-        print("Num files : " + str(len(self.data_anns)))
-        print("Num calls : " + str(np.sum(ann_cnt)))
-
-    def get_file_and_anns(
-        self,
-        index: Optional[int] = None,
-    ) -> Tuple[np.ndarray, float, float, AudioLoaderAnnotationGroup]:
-        """Get an audio file and its annotations.
-
-        Parameters
-        ----------
-        index : int, optional
-            Index of the file to be loaded. If None, a random file is chosen.
-
-        Returns
-        -------
-        audio_raw : np.ndarray
-            Loaded audio file.
-        sampling_rate : float
-            Sampling rate of the audio file.
-        duration : float
-            Duration of the audio file in seconds.
|
|
||||||
ann : AnnotationGroup
|
|
||||||
AnnotationGroup object containing the annotations for the audio file.
|
|
||||||
"""
|
|
||||||
# if no file specified, choose random one
|
|
||||||
if index is None:
|
|
||||||
index = np.random.randint(0, len(self.data_anns))
|
|
||||||
|
|
||||||
audio_file = self.data_anns[index]["file_path"]
|
|
||||||
sampling_rate, audio_raw = au.load_audio(
|
|
||||||
audio_file,
|
|
||||||
self.data_anns[index]["time_exp"],
|
|
||||||
self.params["target_samp_rate"],
|
|
||||||
self.params["scale_raw_audio"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# copy annotation
|
|
||||||
ann = copy.deepcopy(self.data_anns[index])
|
|
||||||
# ann["annotated"] = self.data_anns[index]["annotated"]
|
|
||||||
# ann["class_id_file"] = self.data_anns[index]["class_id_file"]
|
|
||||||
# keys = [
|
|
||||||
# "start_times",
|
|
||||||
# "end_times",
|
|
||||||
# "high_freqs",
|
|
||||||
# "low_freqs",
|
|
||||||
# "class_ids",
|
|
||||||
# "individual_ids",
|
|
||||||
# ]
|
|
||||||
# for kk in keys:
|
|
||||||
# ann[kk] = self.data_anns[index][kk].copy()
|
|
||||||
|
|
||||||
# if train then grab a random crop
|
|
||||||
if self.is_train:
|
|
||||||
nfft = int(self.params["fft_win_length"] * sampling_rate)
|
|
||||||
noverlap = int(self.params["fft_overlap"] * nfft)
|
|
||||||
length_samples = (
|
|
||||||
self.params["spec_train_width"] * (nfft - noverlap) + noverlap
|
|
||||||
)
|
|
||||||
|
|
||||||
if audio_raw.shape[0] - length_samples > 0:
|
|
||||||
sample_crop = np.random.randint(
|
|
||||||
audio_raw.shape[0] - length_samples
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sample_crop = 0
|
|
||||||
audio_raw = audio_raw[sample_crop : sample_crop + length_samples]
|
|
||||||
ann["start_times"] = ann["start_times"] - sample_crop / float(
|
|
||||||
sampling_rate
|
|
||||||
)
|
|
||||||
ann["end_times"] = ann["end_times"] - sample_crop / float(
|
|
||||||
sampling_rate
|
|
||||||
)
|
|
||||||
|
|
||||||
# pad audio
|
|
||||||
if self.is_train:
|
|
||||||
op_spec_target_size = self.params["spec_train_width"]
|
|
||||||
else:
|
|
||||||
op_spec_target_size = None
|
|
||||||
audio_raw = au.pad_audio(
|
|
||||||
audio_raw,
|
|
||||||
sampling_rate,
|
|
||||||
self.params["fft_win_length"],
|
|
||||||
self.params["fft_overlap"],
|
|
||||||
self.params["resize_factor"],
|
|
||||||
self.params["spec_divide_factor"],
|
|
||||||
op_spec_target_size,
|
|
||||||
)
|
|
||||||
duration = audio_raw.shape[0] / float(sampling_rate)
|
|
||||||
|
|
||||||
# sort based on time
|
|
||||||
inds = np.argsort(ann["start_times"])
|
|
||||||
for kk in ann.keys():
|
|
||||||
if (kk != "class_id_file") and (kk != "annotated"):
|
|
||||||
ann[kk] = ann[kk][inds]
|
|
||||||
|
|
||||||
return audio_raw, sampling_rate, duration, ann
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
"""Get an item from the dataset."""
|
|
||||||
# load audio file
|
|
||||||
audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
|
|
||||||
|
|
||||||
# augment on raw audio
|
|
||||||
if self.is_train and self.params["augment_at_train"]:
|
|
||||||
# augment - combine with random audio file
|
|
||||||
if (
|
|
||||||
self.params["augment_at_train_combine"]
|
|
||||||
and np.random.random() < self.params["aug_prob"]
|
|
||||||
):
|
|
||||||
(
|
|
||||||
audio2,
|
|
||||||
sampling_rate2,
|
|
||||||
_,
|
|
||||||
ann2,
|
|
||||||
) = self.get_file_and_anns()
|
|
||||||
audio, ann = combine_audio_aug(
|
|
||||||
audio, sampling_rate, ann, audio2, sampling_rate2, ann2
|
|
||||||
)
|
|
||||||
|
|
||||||
# simulate echo by adding delayed copy of the file
|
|
||||||
if np.random.random() < self.params["aug_prob"]:
|
|
||||||
audio = echo_aug(
|
|
||||||
audio,
|
|
||||||
sampling_rate,
|
|
||||||
echo_max_delay=self.params["echo_max_delay"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# resample the audio
|
|
||||||
# if np.random.random() < self.params["aug_prob"]:
|
|
||||||
# audio, sampling_rate, duration = resample_aug(
|
|
||||||
# audio, sampling_rate, self.params
|
|
||||||
# )
|
|
||||||
|
|
||||||
# create spectrogram
|
|
||||||
spec, _ = au.generate_spectrogram(
|
|
||||||
audio,
|
|
||||||
sampling_rate,
|
|
||||||
params=dict(
|
|
||||||
fft_win_length=self.params["fft_win_length"],
|
|
||||||
fft_overlap=self.params["fft_overlap"],
|
|
||||||
max_freq=self.params["max_freq"],
|
|
||||||
min_freq=self.params["min_freq"],
|
|
||||||
spec_scale=self.params["spec_scale"],
|
|
||||||
denoise_spec_avg=self.params["denoise_spec_avg"],
|
|
||||||
max_scale_spec=self.params["max_scale_spec"],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
rsf = self.params["resize_factor"]
|
|
||||||
spec_op_shape = (
|
|
||||||
int(self.params["spec_height"] * rsf),
|
|
||||||
int(spec.shape[1] * rsf),
|
|
||||||
)
|
|
||||||
|
|
||||||
# resize the spec
|
|
||||||
spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
|
|
||||||
spec = F.interpolate(
|
|
||||||
spec,
|
|
||||||
size=spec_op_shape,
|
|
||||||
mode="bilinear",
|
|
||||||
align_corners=False,
|
|
||||||
).squeeze(0)
|
|
||||||
|
|
||||||
# augment spectrogram
|
|
||||||
if self.is_train and self.params["augment_at_train"]:
|
|
||||||
if np.random.random() < self.params["aug_prob"]:
|
|
||||||
spec = scale_vol_aug(
|
|
||||||
spec,
|
|
||||||
spec_amp_scaling=self.params["spec_amp_scaling"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if np.random.random() < self.params["aug_prob"]:
|
|
||||||
spec = warp_spec_aug(
|
|
||||||
spec,
|
|
||||||
ann,
|
|
||||||
stretch_squeeze_delta=self.params["stretch_squeeze_delta"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if np.random.random() < self.params["aug_prob"]:
|
|
||||||
spec = mask_time_aug(
|
|
||||||
spec,
|
|
||||||
mask_max_time_perc=self.params["mask_max_time_perc"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if np.random.random() < self.params["aug_prob"]:
|
|
||||||
spec = mask_freq_aug(
|
|
||||||
spec,
|
|
||||||
mask_max_freq_perc=self.params["mask_max_freq_perc"],
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = {}
|
|
||||||
outputs["spec"] = spec
|
|
||||||
if self.return_spec_for_viz:
|
|
||||||
outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze(
|
|
||||||
0
|
|
||||||
)
|
|
||||||
|
|
||||||
# create ground truth heatmaps
|
|
||||||
(
|
|
||||||
outputs["y_2d_det"],
|
|
||||||
outputs["y_2d_size"],
|
|
||||||
outputs["y_2d_classes"],
|
|
||||||
ann_aug,
|
|
||||||
) = generate_gt_heatmaps(
|
|
||||||
spec_op_shape,
|
|
||||||
sampling_rate,
|
|
||||||
ann,
|
|
||||||
class_names=self.params["class_names"],
|
|
||||||
fft_win_length=self.params["fft_win_length"],
|
|
||||||
fft_overlap=self.params["fft_overlap"],
|
|
||||||
max_freq=self.params["max_freq"],
|
|
||||||
min_freq=self.params["min_freq"],
|
|
||||||
resize_factor=self.params["resize_factor"],
|
|
||||||
target_sigma=self.params["target_sigma"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# hack to get around requirement that all vectors are the same length
|
|
||||||
# in the output batch
|
|
||||||
pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
|
|
||||||
outputs["is_valid"] = pad_aray(
|
|
||||||
np.ones(len(ann_aug["individual_ids"])), pad_size
|
|
||||||
)
|
|
||||||
keys = [
|
|
||||||
"class_ids",
|
|
||||||
"individual_ids",
|
|
||||||
"x_inds",
|
|
||||||
"y_inds",
|
|
||||||
"start_times",
|
|
||||||
"end_times",
|
|
||||||
"low_freqs",
|
|
||||||
"high_freqs",
|
|
||||||
]
|
|
||||||
for kk in keys:
|
|
||||||
outputs[kk] = pad_aray(ann_aug[kk], pad_size)
|
|
||||||
|
|
||||||
# convert to pytorch
|
|
||||||
for kk in outputs.keys():
|
|
||||||
if type(outputs[kk]) != torch.Tensor:
|
|
||||||
outputs[kk] = torch.from_numpy(outputs[kk])
|
|
||||||
|
|
||||||
# scalars
|
|
||||||
outputs["class_id_file"] = ann["class_id_file"]
|
|
||||||
outputs["annotated"] = ann["annotated"]
|
|
||||||
outputs["duration"] = duration
|
|
||||||
outputs["sampling_rate"] = sampling_rate
|
|
||||||
outputs["file_id"] = index
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
"""Denotes the total number of samples."""
|
|
||||||
return len(self.data_anns)
|
|
@@ -1,133 +1,161 @@
-from functools import wraps
-from typing import Callable, List, Optional
+from typing import Callable, Optional, Union

import numpy as np
import xarray as xr
+from pydantic import Field
+from soundevent import arrays
+from soundevent.arrays import operations as ops
+
+from batdetect2.configs import BaseConfig
+from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram

Augmentation = Callable[[xr.Dataset], xr.Dataset]


-AUGMENTATION_PROBABILITY = 0.2
-MAX_DELAY = 0.005
-STRETCH_SQUEEZE_DELTA = 0.04
-MASK_MAX_TIME_PERC: float = 0.05
-MASK_MAX_FREQ_PERC: float = 0.10
+class AugmentationConfig(BaseConfig):
+    enable: bool = True
+    probability: float = 0.2


-def maybe_apply(
-    augmentation: Callable,
-    prob: float = AUGMENTATION_PROBABILITY,
-) -> Callable:
-    """Apply an augmentation with a given probability."""
-
-    @wraps(augmentation)
-    def _augmentation(x):
-        if np.random.rand() > prob:
-            return x
-        return augmentation(x)
-
-    return _augmentation
+class SubclipConfig(BaseConfig):
+    enable: bool = True
+    duration: Optional[float] = None


def select_random_subclip(
-    train_example: xr.Dataset,
+    example: xr.Dataset,
+    start_time: Optional[float] = None,
    duration: Optional[float] = None,
-    proportion: float = 0.9,
+    width: Optional[int] = None,
) -> xr.Dataset:
    """Select a random subclip from a clip."""
-    time_coords = train_example.coords["time"]
-
-    start_time = time_coords.attrs.get("min", time_coords.min())
-    end_time = time_coords.attrs.get("max", time_coords.max())
-
-    if duration is None:
-        duration = (end_time - start_time) * proportion
-
-    start_time = np.random.uniform(start_time, end_time - duration)
-    return train_example.sel(time=slice(start_time, start_time + duration))
+    step = arrays.get_dim_step(example, "time")  # type: ignore
+
+    if width is None:
+        if duration is None:
+            raise ValueError("Either duration or width must be provided")
+
+        width = int(np.floor(duration / step))
+
+    if duration is None:
+        duration = width * step
+
+    if start_time is None:
+        start, stop = arrays.get_dim_range(example, "time")  # type: ignore
+        start_time = np.random.uniform(start, stop - duration)
+
+    start_index = arrays.get_coord_index(
+        example,  # type: ignore
+        "time",
+        start_time,
+    )
+    end_index = start_index + width - 1
+    start_time = example.time.values[start_index]
+    end_time = example.time.values[end_index]
+
+    return example.sel(
+        time=slice(start_time, end_time),
+        audio_time=slice(start_time, end_time),
+    )


-def combine_audio(
-    audio1: xr.DataArray,
-    audio2: xr.DataArray,
-    alpha: Optional[float] = None,
-    min_alpha: float = 0.3,
-    max_alpha: float = 0.7,
-) -> xr.DataArray:
-    """Combine two audio clips."""
-    if alpha is None:
-        alpha = np.random.uniform(min_alpha, max_alpha)
-
-    return alpha * audio1 + (1 - alpha) * audio2.data
+class MixAugmentationConfig(AugmentationConfig):
+    min_weight: float = 0.3
+    max_weight: float = 0.7
+
+
+def mix_examples(
+    example: xr.Dataset,
+    other: xr.Dataset,
+    weight: Optional[float] = None,
+    min_weight: float = 0.3,
+    max_weight: float = 0.7,
+    config: Optional[PreprocessingConfig] = None,
+) -> xr.Dataset:
+    """Combine two audio clips."""
+    config = config or PreprocessingConfig()
+
+    if weight is None:
+        weight = np.random.uniform(min_weight, max_weight)
+
+    audio2 = other["audio"].values
+    audio1 = ops.adjust_dim_width(example["audio"], "audio_time", len(audio2))
+    combined = weight * audio1 + (1 - weight) * audio2
+
+    spec = compute_spectrogram(
+        combined.rename({"audio_time": "time"}),
+        config=config.spectrogram,
+    )
+
+    detection_heatmap = xr.apply_ufunc(
+        np.maximum,
+        example["detection"],
+        other["detection"].values,
+    )
+
+    class_heatmap = xr.apply_ufunc(
+        np.maximum,
+        example["class"],
+        other["class"].values,
+    )
+
+    size_heatmap = example["size"] + other["size"].values
+
+    return xr.Dataset(
+        {
+            "audio": combined,
+            "spectrogram": spec,
+            "detection": detection_heatmap,
+            "class": class_heatmap,
+            "size": size_heatmap,
+        }
+    )


-# def random_mix(
-#     audio: xr.DataArray,
-#     clip: data.ClipAnnotation,
-#     provider: Optional[ClipProvider] = None,
-#     alpha: Optional[float] = None,
-#     min_alpha: float = 0.3,
-#     max_alpha: float = 0.7,
-#     join_annotations: bool = True,
-# ) -> Tuple[xr.DataArray, data.ClipAnnotation]:
-#     """Mix two audio clips."""
-#     if provider is None:
-#         raise ValueError("No audio provider given.")
-#
-#     try:
-#         other_audio, other_clip = provider(clip)
-#     except (StopIteration, ValueError):
-#         raise ValueError("No more audio sources available.")
-#
-#     new_audio = combine_audio(
-#         audio,
-#         other_audio,
-#         alpha=alpha,
-#         min_alpha=min_alpha,
-#         max_alpha=max_alpha,
-#     )
-#
-#     if join_annotations:
-#         clip = clip.model_copy(
-#             update=dict(
-#                 sound_events=clip.sound_events + other_clip.sound_events,
-#             )
-#         )
-#
-#     return new_audio, clip
+class EchoAugmentationConfig(AugmentationConfig):
+    max_delay: float = 0.005
+    min_weight: float = 0.0
+    max_weight: float = 1.0


def add_echo(
-    train_example: xr.Dataset,
+    example: xr.Dataset,
    delay: Optional[float] = None,
-    alpha: Optional[float] = None,
-    min_alpha: float = 0.0,
-    max_alpha: float = 1.0,
-    max_delay: float = MAX_DELAY,
+    weight: Optional[float] = None,
+    min_weight: float = 0.1,
+    max_weight: float = 1.0,
+    max_delay: float = 0.005,
+    config: Optional[PreprocessingConfig] = None,
) -> xr.Dataset:
    """Add a delay to the audio."""
+    config = config or PreprocessingConfig()
+
    if delay is None:
        delay = np.random.uniform(0, max_delay)

-    if alpha is None:
-        alpha = np.random.uniform(min_alpha, max_alpha)
-
-    spec = train_example["spectrogram"]
-
-    time_coords = spec.coords["time"]
-    start_time = time_coords.attrs["min"]
-    end_time = time_coords.attrs["max"]
-    step = (end_time - start_time) / time_coords.size
-
-    spec_delay = spec.shift(time=int(delay / step), fill_value=0)
-
-    return train_example.assign(spectrogram=spec + alpha * spec_delay)
+    if weight is None:
+        weight = np.random.uniform(min_weight, max_weight)
+
+    audio = example["audio"]
+    step = arrays.get_dim_step(audio, "audio_time")
+    audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
+    audio = audio + weight * audio_delay
+
+    spectrogram = compute_spectrogram(
+        audio.rename({"audio_time": "time"}),
+        config=config.spectrogram,
+    )
+
+    return example.assign(audio=audio, spectrogram=spectrogram)
+
+
+class VolumeAugmentationConfig(AugmentationConfig):
+    min_scaling: float = 0.0
+    max_scaling: float = 2.0


def scale_volume(
-    train_example: xr.Dataset,
+    example: xr.Dataset,
    factor: Optional[float] = None,
    max_scaling: float = 2,
    min_scaling: float = 0,
@@ -136,106 +164,235 @@ def scale_volume(
    if factor is None:
        factor = np.random.uniform(min_scaling, max_scaling)

-    return train_example.assign(
-        spectrogram=train_example["spectrogram"] * factor
-    )
+    return example.assign(spectrogram=example["spectrogram"] * factor)
+
+
+class WarpAugmentationConfig(AugmentationConfig):
+    delta: float = 0.04


def warp_spectrogram(
-    train_example: xr.Dataset,
+    example: xr.Dataset,
    factor: Optional[float] = None,
-    delta: float = STRETCH_SQUEEZE_DELTA,
+    delta: float = 0.04,
) -> xr.Dataset:
    """Warp a spectrogram."""
    if factor is None:
        factor = np.random.uniform(1 - delta, 1 + delta)

-    time_coords = train_example.coords["time"]
-    start_time = time_coords.attrs["min"]
-    end_time = time_coords.attrs["max"]
+    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
    duration = end_time - start_time

    new_time = np.linspace(
        start_time,
        start_time + duration * factor,
-        train_example.time.size,
+        example.time.size,
    )

-    return train_example.interp(time=new_time)
+    spectrogram = (
+        example["spectrogram"]
+        .interp(
+            coords={"time": new_time},
+            method="linear",
+            kwargs=dict(
+                fill_value=0,
+            ),
+        )
+        .clip(min=0)
+    )
+
+    detection = example["detection"].interp(
+        time=new_time,
+        method="nearest",
+        kwargs=dict(
+            fill_value=0,
+        ),
+    )
+
+    classification = example["class"].interp(
+        time=new_time,
+        method="nearest",
+        kwargs=dict(
+            fill_value=0,
+        ),
+    )
+
+    size = example["size"].interp(
+        time=new_time,
+        method="nearest",
+        kwargs=dict(
+            fill_value=0,
+        ),
+    )
+
+    return example.assign(
+        {
+            "time": new_time,
+            "spectrogram": spectrogram,
+            "detection": detection,
+            "class": classification,
+            "size": size,
+        }
+    )


def mask_axis(
-    train_example: xr.Dataset,
+    array: xr.DataArray,
    dim: str,
    start: float,
    end: float,
-    mask_all: bool = False,
-    mask_value: float = 0,
-) -> xr.Dataset:
-    if dim not in train_example.dims:
+    mask_value: Union[float, Callable[[xr.DataArray], float]] = np.mean,
+) -> xr.DataArray:
+    if dim not in array.dims:
        raise ValueError(f"Axis {dim} not found in array")

-    coord = train_example.coords[dim]
+    coord = array.coords[dim]
    condition = (coord < start) | (coord > end)

-    if mask_all:
-        return train_example.where(condition, other=mask_value)
-
-    return train_example.assign(
-        spectrogram=train_example.spectrogram.where(
-            condition, other=mask_value
-        )
-    )
+    if callable(mask_value):
+        mask_value = mask_value(array)
+
+    return array.where(condition, other=mask_value)
+
+
+class TimeMaskAugmentationConfig(AugmentationConfig):
+    max_perc: float = 0.05
+    max_masks: int = 3


def mask_time(
-    train_example: xr.Dataset,
-    max_time_mask: float = MASK_MAX_TIME_PERC,
-    max_num_masks: int = 3,
+    example: xr.Dataset,
+    max_perc: float = 0.05,
+    max_mask: int = 3,
) -> xr.Dataset:
    """Mask a random section of the time axis."""
-    num_masks = np.random.randint(1, max_num_masks + 1)
-
-    time_coord = train_example.coords["time"]
-    start_time = time_coord.attrs.get("min", time_coord.min())
-    end_time = time_coord.attrs.get("max", time_coord.max())
+    num_masks = np.random.randint(1, max_mask + 1)
+    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
+
+    spectrogram = example["spectrogram"]

    for _ in range(num_masks):
-        mask_size = np.random.uniform(0, max_time_mask)
+        mask_size = np.random.uniform(0, max_perc) * (end_time - start_time)
        start = np.random.uniform(start_time, end_time - mask_size)
        end = start + mask_size
-        train_example = mask_axis(train_example, "time", start, end)
-
-    return train_example
+        spectrogram = mask_axis(spectrogram, "time", start, end)
+
+    return example.assign(spectrogram=spectrogram)
+
+
+class FrequencyMaskAugmentationConfig(AugmentationConfig):
+    max_perc: float = 0.10
+    max_masks: int = 3


def mask_frequency(
-    train_example: xr.Dataset,
-    max_freq_mask: float = MASK_MAX_FREQ_PERC,
-    max_num_masks: int = 3,
+    example: xr.Dataset,
+    max_perc: float = 0.10,
+    max_masks: int = 3,
) -> xr.Dataset:
    """Mask a random section of the frequency axis."""
-    num_masks = np.random.randint(1, max_num_masks + 1)
-
-    freq_coord = train_example.coords["frequency"]
-    min_freq = float(freq_coord.min())
-    max_freq = float(freq_coord.max())
+    num_masks = np.random.randint(1, max_masks + 1)
+    min_freq, max_freq = arrays.get_dim_range(example, "frequency")  # type: ignore
+
+    spectrogram = example["spectrogram"]

    for _ in range(num_masks):
-        mask_size = np.random.uniform(0, max_freq_mask)
+        mask_size = np.random.uniform(0, max_perc) * (max_freq - min_freq)
        start = np.random.uniform(min_freq, max_freq - mask_size)
        end = start + mask_size
-        train_example = mask_axis(train_example, "frequency", start, end)
-
-    return train_example
+        spectrogram = mask_axis(spectrogram, "frequency", start, end)
+
+    return example.assign(spectrogram=spectrogram)


-AUGMENTATIONS: List[Augmentation] = [
-    select_random_subclip,
-    add_echo,
-    scale_volume,
-    mask_time,
-    mask_frequency,
-]
+class AugmentationsConfig(BaseConfig):
+    subclip: SubclipConfig = Field(default_factory=SubclipConfig)
+    mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig)
+    echo: EchoAugmentationConfig = Field(
+        default_factory=EchoAugmentationConfig
+    )
+    volume: VolumeAugmentationConfig = Field(
+        default_factory=VolumeAugmentationConfig
+    )
+    warp: WarpAugmentationConfig = Field(
+        default_factory=WarpAugmentationConfig
+    )
+    time_mask: TimeMaskAugmentationConfig = Field(
+        default_factory=TimeMaskAugmentationConfig
+    )
+    frequency_mask: FrequencyMaskAugmentationConfig = Field(
+        default_factory=FrequencyMaskAugmentationConfig
+    )
+
+
+def should_apply(config: AugmentationConfig) -> bool:
+    if not config.enable:
+        return False
+
+    return np.random.uniform() < config.probability
+
+
+def augment_example(
+    example: xr.Dataset,
+    config: AugmentationsConfig,
+    preprocessing_config: Optional[PreprocessingConfig] = None,
+    others: Optional[Callable[[], xr.Dataset]] = None,
+) -> xr.Dataset:
+    if config.subclip.enable:
+        example = select_random_subclip(
+            example,
+            duration=config.subclip.duration,
+        )
+
+    if should_apply(config.mix) and others is not None:
+        other = others()
+
+        if config.subclip.enable:
+            other = select_random_subclip(
+                other,
+                duration=config.subclip.duration,
+            )
+
+        example = mix_examples(
+            example,
+            other,
+            min_weight=config.mix.min_weight,
+            max_weight=config.mix.max_weight,
+            config=preprocessing_config,
+        )
+
+    if should_apply(config.echo):
+        example = add_echo(
+            example,
+            max_delay=config.echo.max_delay,
+            min_weight=config.echo.min_weight,
+            max_weight=config.echo.max_weight,
+            config=preprocessing_config,
+        )
+
+    if should_apply(config.volume):
+        example = scale_volume(
+            example,
+            max_scaling=config.volume.max_scaling,
+            min_scaling=config.volume.min_scaling,
+        )
+
+    if should_apply(config.warp):
+        example = warp_spectrogram(
+            example,
+            delta=config.warp.delta,
+        )
+
+    if should_apply(config.time_mask):
+        example = mask_time(
+            example,
+            max_perc=config.time_mask.max_perc,
+            max_mask=config.time_mask.max_masks,
+        )
+
+    if should_apply(config.frequency_mask):
+        example = mask_frequency(
+            example,
+            max_perc=config.frequency_mask.max_perc,
+            max_masks=config.frequency_mask.max_masks,
+        )
+
+    return example
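The augmentation pipeline above is now driven entirely by config objects rather than module-level constants. A minimal usage sketch follows; it is an editorial illustration, not part of the commit, and the file names and the 0.5 s subclip duration are assumptions.

import xarray as xr

from batdetect2.preprocess import PreprocessingConfig
from batdetect2.train.augmentations import (
    AugmentationsConfig,
    SubclipConfig,
    augment_example,
)

# Pre-computed training examples stored as xarray Datasets with "audio",
# "spectrogram", "detection", "class" and "size" variables (hypothetical paths).
example = xr.open_dataset("example.nc")
other = xr.open_dataset("other_example.nc")

config = AugmentationsConfig(
    subclip=SubclipConfig(duration=0.5),  # assumed window length in seconds
)

augmented = augment_example(
    example,
    config,
    preprocessing_config=PreprocessingConfig(),
    others=lambda: other,  # supplies a second example when mixing is applied
)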
@@ -1,12 +1,14 @@
import os
from pathlib import Path
-from typing import Callable, Dict, NamedTuple, Optional, Sequence, Union
+from typing import NamedTuple, Optional, Sequence, Union

+import numpy as np
import torch
import xarray as xr
from soundevent import data
from torch.utils.data import Dataset

+from batdetect2.train.augmentations import AugmentationsConfig, augment_example
from batdetect2.train.preprocess import PreprocessingConfig

__all__ = [
@@ -34,21 +36,44 @@ class LabeledDataset(Dataset):
    def __init__(
        self,
        filenames: Sequence[PathLike],
-        transform: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
+        augment: bool = False,
+        preprocessing_config: Optional[PreprocessingConfig] = None,
+        augmentation_config: Optional[AugmentationsConfig] = None,
    ):
        self.filenames = filenames
-        self.transform = transform
+        self.augment = augment
+        self.preprocessing_config = (
+            preprocessing_config or PreprocessingConfig()
+        )
+        self.agumentation_config = augmentation_config or AugmentationsConfig()

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx) -> TrainExample:
-        data = self.load(idx)
+        dataset = self.get_dataset(idx)
+
+        if self.augment:
+            dataset = augment_example(
+                dataset,
+                self.agumentation_config,
+                preprocessing_config=self.preprocessing_config,
+                others=self.get_random_example,
+            )
+
        return TrainExample(
-            spec=data["spectrogram"],
-            detection_heatmap=data["detection"],
-            class_heatmap=data["class"],
-            size_heatmap=data["size"],
+            spec=torch.tensor(
+                dataset["spectrogram"].values.astype(np.float32)
+            ).unsqueeze(0),
+            detection_heatmap=torch.tensor(
+                dataset["detection"].values.astype(np.float32)
+            ),
+            class_heatmap=torch.tensor(
+                dataset["class"].values.astype(np.float32)
+            ),
+            size_heatmap=torch.tensor(
+                dataset["size"].values.astype(np.float32)
+            ),
            idx=torch.tensor(idx),
        )

@@ -56,44 +81,14 @@ class LabeledDataset(Dataset):
    def from_directory(cls, directory: PathLike, extension: str = ".nc"):
        return cls(get_files(directory, extension))

-    def load(self, idx) -> Dict[str, torch.Tensor]:
-        dataset = self.get_dataset(idx)
-        return {
-            "spectrogram": torch.tensor(
-                dataset["spectrogram"].values
-            ).unsqueeze(0),
-            "detection": torch.tensor(dataset["detection"].values),
-            "class": torch.tensor(dataset["class"].values),
-            "size": torch.tensor(dataset["size"].values),
-        }
-
-    def apply_augmentation(self, dataset: xr.Dataset) -> xr.Dataset:
-        if self.transform is not None:
-            return self.transform(dataset)
-
-        return dataset
-
-    def get_dataset(self, idx):
-        return xr.open_dataset(self.filenames[idx])
-
-    def get_spectrogram(self, idx):
-        return xr.open_dataset(self.filenames[idx])["spectrogram"]
-
-    def get_detection_mask(self, idx):
-        return xr.open_dataset(self.filenames[idx])["detection"]
-
-    def get_class_mask(self, idx):
-        return xr.open_dataset(self.filenames[idx])["class"]
-
-    def get_size_mask(self, idx):
-        return xr.open_dataset(self.filenames[idx])["size"]
-
-    def get_clip_annotation(self, idx):
-        filename = self.filenames[idx]
-        dataset = xr.open_dataset(filename)
-        clip_annotation = dataset.attrs["clip_annotation"]
-        return data.ClipAnnotation.model_validate_json(clip_annotation)
-
-    def get_preprocessing_configuration(self, idx):
-        config = xr.open_dataset(self.filenames[idx]).attrs["configuration"]
-        return PreprocessingConfig.model_validate_json(config)
+    def get_random_example(self) -> xr.Dataset:
+        idx = np.random.randint(0, len(self))
+        return self.get_dataset(idx)
+
+    def get_dataset(self, idx) -> xr.Dataset:
+        return xr.open_dataset(self.filenames[idx])
+
+    def get_clip_annotation(self, idx) -> data.ClipAnnotation:
+        return data.ClipAnnotation.model_validate_json(
+            self.get_dataset(idx).attrs["clip_annotation"]
+        )
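A short sketch of how the reworked dataset might be consumed; this is illustrative only, and the directory name and batch size are assumptions.

from torch.utils.data import DataLoader

from batdetect2.train.dataset import LabeledDataset

# Directory of pre-computed .nc training examples (hypothetical path).
dataset = LabeledDataset.from_directory("train_examples/", extension=".nc")
loader = DataLoader(dataset, batch_size=8, shuffle=True)

batch = next(iter(loader))  # a TrainExample of batched tensors
print(batch.spec.shape, batch.detection_heatmap.shape)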
@@ -1,18 +1,14 @@
-from typing import Sequence, Tuple
+from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np
import xarray as xr
from scipy.ndimage import gaussian_filter
from soundevent import arrays, data, geometry
from soundevent.geometry.operations import Positions
-from soundevent.types import ClassMapper

from batdetect2.configs import BaseConfig

-__all__ = [
-    "ClassMapper",
-    "generate_heatmaps",
-]
+__all__ = ["generate_heatmaps"]


class HeatmapsConfig(BaseConfig):
@@ -25,7 +21,8 @@ class HeatmapsConfig(BaseConfig):
def generate_heatmaps(
    sound_events: Sequence[data.SoundEventAnnotation],
    spec: xr.DataArray,
-    class_mapper: ClassMapper,
+    class_names: List[str],
+    encoder: Callable[[data.SoundEventAnnotation], Optional[str]],
    target_sigma: float = 3.0,
    position: Positions = "bottom-left",
    time_scale: float = 1000.0,
@@ -42,10 +39,10 @@ def generate_heatmaps(
    # Initialize heatmaps
    detection_heatmap = xr.zeros_like(spec, dtype=dtype)
    class_heatmap = xr.DataArray(
-        data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype),
+        data=np.zeros((len(class_names), *spec.shape), dtype=dtype),
        dims=["category", *spec.dims],
        coords={
-            "category": [*class_mapper.class_labels],
+            "category": [*class_names],
            **spec.coords,
        },
    )
@@ -92,7 +89,7 @@ def generate_heatmaps(
        )

        # Get the class name of the sound event
-        class_name = class_mapper.encode(sound_event_annotation)
+        class_name = encoder(sound_event_annotation)

        if class_name is None:
            # If the label is None skip the sound event
@@ -1,56 +0,0 @@
import pytorch_lightning as L
from torch import Tensor, optim

from batdetect2.models.typing import DetectionModel, ModelOutput
from batdetect2.train import losses

from batdetect2.train.dataset import TrainExample


__all__ = [
    "LitDetectorModel",
]


class LitDetectorModel(L.LightningModule):
    model: DetectionModel

    def __init__(self, model: DetectionModel, learning_rate: float = 1e-3):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def compute_loss(
        self,
        outputs: ModelOutput,
        batch: TrainExample,
    ) -> Tensor:
        detection_loss = losses.focal_loss(
            outputs.detection_probs,
            batch.detection_heatmap,
        )

        size_loss = losses.bbox_size_loss(
            outputs.size_preds,
            batch.size_heatmap,
        )

        valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
        classification_loss = losses.focal_loss(
            outputs.class_probs,
            batch.class_heatmap,
            valid_mask=valid_mask,
        )

        return detection_loss + size_loss + classification_loss

    def training_step(self, batch: TrainExample, batch_idx: int):  # type: ignore
        outputs: ModelOutput = self.model(batch.spec)
        loss = self.compute_loss(outputs, batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
        return [optimizer], [scheduler]
@@ -1,7 +1,16 @@
-from typing import Optional
+from typing import NamedTuple, Optional

import torch
import torch.nn.functional as F
+from pydantic import Field
+
+from batdetect2.configs import BaseConfig
+from batdetect2.models.typing import ModelOutput
+from batdetect2.train.dataset import TrainExample
+
+
+class SizeLossConfig(BaseConfig):
+    weight: float = 0.1


def bbox_size_loss(
@@ -17,6 +26,11 @@ def bbox_size_loss(
    )


+class FocalLossConfig(BaseConfig):
+    beta: float = 4
+    alpha: float = 2
+
+
def focal_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
@@ -44,7 +58,7 @@ def focal_loss(
    )

    if weights is not None:
-        pos_loss = pos_loss * weights
+        pos_loss = pos_loss * torch.tensor(weights)
        # neg_loss = neg_loss*weights

    if valid_mask is not None:
@@ -75,3 +89,71 @@ def mse_loss(
    else:
        op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
    return op
+
+
+class DetectionLossConfig(BaseConfig):
+    weight: float = 1.0
+    focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
+
+
+class ClassificationLossConfig(BaseConfig):
+    weight: float = 2.0
+    focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
+    class_weights: Optional[list[float]] = None
+
+
+class LossConfig(BaseConfig):
+    detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig)
+    size: SizeLossConfig = Field(default_factory=SizeLossConfig)
+    classification: ClassificationLossConfig = Field(
+        default_factory=ClassificationLossConfig
+    )
+
+
+class Losses(NamedTuple):
+    detection: torch.Tensor
+    size: torch.Tensor
+    classification: torch.Tensor
+    total: torch.Tensor
+
+
+def compute_loss(
+    batch: TrainExample,
+    outputs: ModelOutput,
+    conf: LossConfig,
+    class_weights: Optional[torch.Tensor] = None,
+) -> Losses:
+    detection_loss = focal_loss(
+        outputs.detection_probs,
+        batch.detection_heatmap,
+        beta=conf.detection.focal.beta,
+        alpha=conf.detection.focal.alpha,
+    )
+
+    size_loss = bbox_size_loss(
+        outputs.size_preds,
+        batch.size_heatmap,
+    )
+
+    valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
+    classification_loss = focal_loss(
+        outputs.class_probs,
+        batch.class_heatmap,
+        weights=class_weights,
+        valid_mask=valid_mask,
+        beta=conf.classification.focal.beta,
+        alpha=conf.classification.focal.alpha,
+    )
+
+    total = (
+        detection_loss * conf.detection.weight
+        + size_loss * conf.size.weight
+        + classification_loss * conf.classification.weight
+    )
+
+    return Losses(
+        detection=detection_loss,
+        size=size_loss,
+        classification=classification_loss,
+        total=total,
+    )
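A hedged sketch of driving the new functional loss API with dummy tensors; the class count, heatmap sizes and the two-channel size head are assumptions made only to illustrate the expected (batch, channel, freq, time) layout.

import torch

from batdetect2.models.typing import ModelOutput
from batdetect2.train.dataset import TrainExample
from batdetect2.train.losses import LossConfig, compute_loss

num_classes, height, width = 17, 128, 512  # assumed sizes

batch = TrainExample(
    spec=torch.rand(1, 1, height, width),
    detection_heatmap=torch.rand(1, 1, height, width),
    class_heatmap=torch.rand(1, num_classes, height, width),
    size_heatmap=torch.rand(1, 2, height, width),
    idx=torch.tensor(0),
)
outputs = ModelOutput(
    detection_probs=torch.sigmoid(torch.randn(1, 1, height, width)),
    size_preds=torch.rand(1, 2, height, width),
    class_probs=torch.sigmoid(torch.randn(1, num_classes, height, width)),
    features=torch.randn(1, 32, height, width),
)

losses = compute_loss(batch, outputs, conf=LossConfig())
print(float(losses.total), float(losses.detection), float(losses.classification))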
batdetect2/train/modules.py (new file, 98 lines)
@@ -0,0 +1,98 @@
from typing import Optional

import pytorch_lightning as L
import torch
from pydantic import Field
from torch import optim

from batdetect2.configs import BaseConfig
from batdetect2.models import (
    BBoxHead,
    ClassifierHead,
    ModelConfig,
    get_backbone,
)
from batdetect2.models.typing import ModelOutput
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.train.dataset import TrainExample
from batdetect2.train.losses import LossConfig, compute_loss
from batdetect2.train.targets import TargetConfig


class OptimizerConfig(BaseConfig):
    learning_rate: float = 1e-3
    t_max: int = 100


class TrainingConfig(BaseConfig):
    loss: LossConfig = Field(default_factory=LossConfig)
    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)


class ModuleConfig(BaseConfig):
    train: TrainingConfig = Field(default_factory=TrainingConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)
    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
    backbone: ModelConfig = Field(default_factory=ModelConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )


class DetectorModel(L.LightningModule):
    config: ModuleConfig

    def __init__(
        self,
        config: Optional[ModuleConfig] = None,
    ):
        super().__init__()
        self.config = config or ModuleConfig()
        self.save_hyperparameters()

        self.backbone = get_backbone(
            input_height=self.config.preprocessing.spectrogram.size.height,
            config=self.config.backbone,
        )

        self.classifier = ClassifierHead(
            num_classes=len(self.config.targets.classes),
            in_channels=self.backbone.out_channels,
        )

        self.bbox = BBoxHead(in_channels=self.backbone.out_channels)

        conf = self.config.train.loss.classification
        self.class_weights = (
            torch.tensor(conf.class_weights) if conf.class_weights else None
        )

    def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
        features = self.backbone(spec)
        detection_probs, classification_probs = self.classifier(features)
        size_preds = self.bbox(features)
        return ModelOutput(
            detection_probs=detection_probs,
            size_preds=size_preds,
            class_probs=classification_probs,
            features=features,
        )

    def training_step(self, batch: TrainExample):
        outputs = self.forward(batch.spec)
        losses = compute_loss(
            batch,
            outputs,
            conf=self.config.train.loss,
            class_weights=self.class_weights,
        )
        return losses.total

    def configure_optimizers(self):
        conf = self.config.train.optimizer
        optimizer = optim.Adam(self.parameters(), lr=conf.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=conf.t_max,
        )
        return [optimizer], [scheduler]
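A rough end-to-end training sketch, an assumption rather than part of the commit: wiring DetectorModel and LabeledDataset into a PyTorch Lightning Trainer. The directory path, batch size and epoch count are illustrative only.

import pytorch_lightning as pl
from torch.utils.data import DataLoader

from batdetect2.train.dataset import LabeledDataset
from batdetect2.train.modules import DetectorModel, ModuleConfig

# Hypothetical directory of pre-computed training examples.
dataset = LabeledDataset.from_directory("train_examples/", extension=".nc")
loader = DataLoader(dataset, batch_size=8, shuffle=True)

model = DetectorModel(ModuleConfig())  # default backbone, losses and targets
trainer = pl.Trainer(max_epochs=100)
trainer.fit(model, train_dataloaders=loader)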
@ -20,8 +20,9 @@ from batdetect2.preprocess import (
|
|||||||
from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps
|
from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps
|
||||||
from batdetect2.train.targets import (
|
from batdetect2.train.targets import (
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
build_class_mapper,
|
build_encoder,
|
||||||
build_sound_event_filter,
|
build_sound_event_filter,
|
||||||
|
get_class_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
PathLike = Union[Path, str, os.PathLike]
|
PathLike = Union[Path, str, os.PathLike]
|
||||||
@ -64,11 +65,15 @@ def generate_train_example(
|
|||||||
selected_events = [
|
selected_events = [
|
||||||
event for event in clip_annotation.sound_events if filter_fn(event)
|
event for event in clip_annotation.sound_events if filter_fn(event)
|
||||||
]
|
]
|
||||||
class_mapper = build_class_mapper(config.target.classes)
|
|
||||||
|
class_names = get_class_names(config.target.classes)
|
||||||
|
encoder = build_encoder(config.target.classes)
|
||||||
|
|
||||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||||
selected_events,
|
selected_events,
|
||||||
spectrogram,
|
spectrogram,
|
||||||
class_mapper,
|
class_names,
|
||||||
|
encoder,
|
||||||
target_sigma=config.heatmaps.sigma,
|
target_sigma=config.heatmaps.sigma,
|
||||||
position=config.heatmaps.position,
|
position=config.heatmaps.position,
|
||||||
time_scale=config.heatmaps.time_scale,
|
time_scale=config.heatmaps.time_scale,
|
||||||
|
@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Set
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.types import ClassMapper
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.terms import TagInfo, get_tag_from_info
|
from batdetect2.terms import TagInfo, get_tag_from_info
|
||||||
@ -12,11 +11,25 @@ from batdetect2.terms import TagInfo, get_tag_from_info
|
|||||||
class TargetConfig(BaseConfig):
|
class TargetConfig(BaseConfig):
|
||||||
"""Configuration for target generation."""
|
"""Configuration for target generation."""
|
||||||
|
|
||||||
classes: List[TagInfo] = Field(default_factory=list)
|
classes: List[TagInfo] = Field(
|
||||||
generic_class: Optional[TagInfo] = None
|
default_factory=lambda: [
|
||||||
|
TagInfo(key="class", value=value) for value in DEFAULT_SPECIES_LIST
|
||||||
|
]
|
||||||
|
)
|
||||||
|
generic_class: Optional[TagInfo] = Field(
|
||||||
|
default_factory=lambda: TagInfo(key="class", value="Bat")
|
||||||
|
)
|
||||||
|
|
||||||
include: Optional[List[TagInfo]] = None
|
include: Optional[List[TagInfo]] = Field(
|
||||||
exclude: Optional[List[TagInfo]] = None
|
default_factory=lambda: [TagInfo(key="event", value="Echolocation")]
|
||||||
|
)
|
||||||
|
exclude: Optional[List[TagInfo]] = Field(
|
||||||
|
default_factory=lambda: [
|
||||||
|
TagInfo(key="class", value=""),
|
||||||
|
TagInfo(key="class", value=" "),
|
||||||
|
TagInfo(key="class", value="Unknown"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_sound_event_filter(
|
def build_sound_event_filter(
|
||||||
@ -36,13 +49,54 @@ def build_sound_event_filter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_class_mapper(classes: List[TagInfo]) -> ClassMapper:
|
def get_tag_label(tag_info: TagInfo) -> str:
|
||||||
target_tags = [get_tag_from_info(tag) for tag in classes]
|
return tag_info.label if tag_info.label else tag_info.value
|
||||||
labels = [tag.label if tag.label else tag.value for tag in classes]
|
|
||||||
return GenericMapper(
|
|
||||||
classes=target_tags,
|
def get_class_names(classes: List[TagInfo]) -> List[str]:
|
||||||
labels=labels,
|
return sorted({get_tag_label(tag) for tag in classes})
|
||||||
)
|
|
||||||
|
|
||||||
|
def build_encoder(
|
||||||
|
classes: List[TagInfo],
|
||||||
|
) -> Callable[[data.SoundEventAnnotation], Optional[str]]:
|
||||||
|
target_tags = set([get_tag_from_info(tag) for tag in classes])
|
||||||
|
|
||||||
|
tag_mapping = {
|
||||||
|
tag: get_tag_label(tag_info)
|
||||||
|
for tag, tag_info in zip(target_tags, classes)
|
||||||
|
}
|
||||||
|
|
||||||
|
def encoder(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> Optional[str]:
|
||||||
|
tags = set(sound_event_annotation.tags)
|
||||||
|
|
||||||
|
intersection = tags & target_tags
|
||||||
|
|
||||||
|
if not intersection:
|
||||||
|
return None
|
||||||
|
|
||||||
|
first = intersection.pop()
|
||||||
|
return tag_mapping[first]
|
||||||
|
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
def build_decoder(
|
||||||
|
classes: List[TagInfo],
|
||||||
|
) -> Callable[[str], List[data.Tag]]:
|
||||||
|
target_tags = set([get_tag_from_info(tag) for tag in classes])
|
||||||
|
tag_mapping = {
|
||||||
|
get_tag_label(tag_info): tag
|
||||||
|
for tag, tag_info in zip(target_tags, classes)
|
||||||
|
}
|
||||||
|
|
||||||
|
def decoder(label: str) -> List[data.Tag]:
|
||||||
|
tag = tag_mapping.get(label)
|
||||||
|
return [tag] if tag else []
|
||||||
|
|
||||||
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def filter_sound_event(
|
def filter_sound_event(
|
||||||
@ -61,39 +115,22 @@ def filter_sound_event(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class GenericMapper(ClassMapper):
|
DEFAULT_SPECIES_LIST = [
|
||||||
"""Generic class mapper configuration."""
|
"Barbastellus barbastellus",
|
||||||
|
"Eptesicus serotinus",
|
||||||
def __init__(
|
"Myotis alcathoe",
|
||||||
self,
|
"Myotis bechsteinii",
|
||||||
classes: List[data.Tag],
|
"Myotis brandtii",
|
||||||
labels: List[str],
|
"Myotis daubentonii",
|
||||||
):
|
"Myotis mystacinus",
|
||||||
if not len(classes) == len(labels):
|
"Myotis nattereri",
|
||||||
raise ValueError("Number of targets and class labels must match.")
|
"Nyctalus leisleri",
|
||||||
|
"Nyctalus noctula",
|
||||||
self.targets = set(classes)
|
"Pipistrellus nathusii",
|
||||||
self.class_labels = list(dict.fromkeys(labels))
|
"Pipistrellus pipistrellus",
|
||||||
|
"Pipistrellus pygmaeus",
|
||||||
self._mapping = {tag: label for tag, label in zip(classes, labels)}
|
"Plecotus auritus",
|
||||||
self._inverse_mapping = {
|
"Plecotus austriacus",
|
||||||
label: tag for tag, label in zip(classes, labels)
|
"Rhinolophus ferrumequinum",
|
||||||
}
|
"Rhinolophus hipposideros",
|
||||||
|
]
|
||||||
def encode(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> Optional[str]:
|
|
||||||
tags = set(sound_event_annotation.tags)
|
|
||||||
|
|
||||||
intersection = tags & self.targets
|
|
||||||
if not intersection:
|
|
||||||
return None
|
|
||||||
|
|
||||||
tag = intersection.pop()
|
|
||||||
return self._mapping[tag]
|
|
||||||
|
|
||||||
def decode(self, label: str) -> List[data.Tag]:
|
|
||||||
if label not in self._inverse_mapping:
|
|
||||||
return []
|
|
||||||
return [self._inverse_mapping[label]]
|
|
||||||
|
@ -91,7 +91,7 @@ def recording_factory(wav_factory: Callable[..., Path]):
|
|||||||
recording_id: Optional[uuid.UUID] = None,
|
recording_id: Optional[uuid.UUID] = None,
|
||||||
duration: float = 1,
|
duration: float = 1,
|
||||||
channels: int = 1,
|
channels: int = 1,
|
||||||
samplerate: int = 44100,
|
samplerate: int = 256_000,
|
||||||
time_expansion: float = 1,
|
time_expansion: float = 1,
|
||||||
) -> data.Recording:
|
) -> data.Recording:
|
||||||
path = path or wav_factory(
|
path = path or wav_factory(
|
||||||
|
@@ -80,11 +80,12 @@ def test_spectrogram_generation_hasnt_changed(
     max_freq = 120_000
     fft_overlap = 0.75

-    scale = None
     if spec_scale == "log":
-        scale = "log"
+        scale = preprocess.LogScaleConfig()
     elif spec_scale == "pcen":
-        scale = preprocess.PcenConfig()
+        scale = preprocess.PcenScaleConfig()
+    else:
+        scale = preprocess.AmplitudeScaleConfig()

     config = preprocess.SpectrogramConfig(
         fft=preprocess.FFTConfig(
0   tests/test_train/__init__.py   Normal file
70  tests/test_train/test_augmentations.py   Normal file
@@ -0,0 +1,70 @@
+from collections.abc import Callable
+
+import xarray as xr
+from soundevent import data
+
+from batdetect2.train.augmentations import (
+    add_echo,
+    mix_examples,
+    select_random_subclip,
+)
+from batdetect2.train.preprocess import (
+    TrainPreprocessingConfig,
+    generate_train_example,
+)
+
+
+def test_mix_examples(
+    recording_factory: Callable[..., data.Recording],
+):
+    recording1 = recording_factory()
+    recording2 = recording_factory()
+
+    clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
+
+    clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8)
+
+    clip_annotation_1 = data.ClipAnnotation(clip=clip1)
+
+    clip_annotation_2 = data.ClipAnnotation(clip=clip2)
+
+    config = TrainPreprocessingConfig()
+
+    example1 = generate_train_example(clip_annotation_1, config)
+    example2 = generate_train_example(clip_annotation_2, config)
+
+    mixed = mix_examples(example1, example2, config=config.preprocessing)
+
+    assert mixed["spectrogram"].shape == example1["spectrogram"].shape
+    assert mixed["detection"].shape == example1["detection"].shape
+    assert mixed["size"].shape == example1["size"].shape
+    assert mixed["class"].shape == example1["class"].shape
+
+
+def test_add_echo(
+    recording_factory: Callable[..., data.Recording],
+):
+    recording1 = recording_factory()
+    clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
+    clip_annotation_1 = data.ClipAnnotation(clip=clip1)
+    config = TrainPreprocessingConfig()
+    original = generate_train_example(clip_annotation_1, config)
+    with_echo = add_echo(original, config=config.preprocessing)
+
+    assert with_echo["spectrogram"].shape == original["spectrogram"].shape
+    xr.testing.assert_identical(with_echo["size"], original["size"])
+    xr.testing.assert_identical(with_echo["class"], original["class"])
+    xr.testing.assert_identical(with_echo["detection"], original["detection"])
+
+
+def test_selected_random_subclip_has_the_correct_width(
+    recording_factory: Callable[..., data.Recording],
+):
+    recording1 = recording_factory()
+    clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
+    clip_annotation_1 = data.ClipAnnotation(clip=clip1)
+    config = TrainPreprocessingConfig()
+    original = generate_train_example(clip_annotation_1, config)
+    subclip = select_random_subclip(original, width=100)
+
+    assert subclip["spectrogram"].shape[1] == 100
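The assertions in this new test module imply that generate_train_example returns a mapping with at least "spectrogram", "detection", "size" and "class" entries, each an array-like object with a .shape attribute (the xr.testing calls treat them as xarray data). A small sketch of inspecting such an example, reusing names defined in the tests above; anything beyond those names is assumed:

    example = generate_train_example(clip_annotation_1, TrainPreprocessingConfig())
    for key in ("spectrogram", "detection", "size", "class"):
        # Each entry exposes a .shape attribute, which is all the tests rely on.
        print(key, example[key].shape)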
@@ -5,7 +5,7 @@ import xarray as xr
 from soundevent import data
 from soundevent.types import ClassMapper

-from batdetect2.data.labels import generate_heatmaps
+from batdetect2.train.labels import generate_heatmaps

 recording = data.Recording(
     samplerate=256_000,
@@ -60,7 +60,7 @@ def test_generated_heatmaps_have_correct_dimensions():
     class_mapper = Mapper()

     detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
-        clip_annotation,
+        clip_annotation.sound_events,
         spec,
         class_mapper,
     )
@@ -109,7 +109,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions():

     class_mapper = Mapper()
     detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
-        clip_annotation,
+        clip_annotation.sound_events,
         spec,
         class_mapper,
     )