Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 14:41:58 +02:00)

Commit 0aa61af445 (parent 458e11cf73): Added types to most functions
@@ -226,11 +226,10 @@ def generate_spectrogram(
     if config is None:
         config = DEFAULT_SPECTROGRAM_PARAMETERS
 
-    _, spec, _ = du.compute_spectrogram(
+    _, spec = du.compute_spectrogram(
        audio,
        samp_rate,
        config,
-        return_np=False,
        device=device,
    )
 
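Note: this hunk tracks a changed return signature; `du.compute_spectrogram` now yields two values instead of three, and the `return_np=False` argument is dropped at this call site. A toy illustration of why the unpacking had to change (a hypothetical stand-in function, not code from this commit):

    def compute_spectrogram_v2():
        # Hypothetical stand-in: the new version returns (duration, spec).
        return 1.0, [[0.0]]

    _, spec = compute_spectrogram_v2()  # matches the new two-value return
    try:
        _, spec, _ = compute_spectrogram_v2()  # the old three-name unpacking
    except ValueError as err:
        print(err)  # not enough values to unpack (expected 3, got 2)
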
@@ -1,5 +1,5 @@
 """Functions to compute features from predictions."""
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 
 import numpy as np
 
@@ -7,15 +7,26 @@ from batdetect2 import types
 from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ
 
 
-def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
+def convert_int_to_freq(
+    spec_ind: int,
+    spec_height: int,
+    min_freq: float,
+    max_freq: float,
+) -> int:
     """Convert spectrogram index to frequency in Hz."""
     spec_ind = spec_height - spec_ind
-    return round(
-        (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
+    return int(
+        round(
+            (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq,
+            2,
+        )
     )
 
 
-def extract_spec_slices(spec, pred_nms):
+def extract_spec_slices(
+    spec: np.ndarray,
+    pred_nms: types.PredictionResults,
+) -> List[np.ndarray]:
     """Extract spectrogram slices from spectrogram.
 
     The slices are extracted based on detected call locations.
@@ -109,7 +120,7 @@ def compute_max_power_bb(
 
     return int(
         convert_int_to_freq(
-            y_high + max_power_ind,
+            int(y_high + max_power_ind),
             spec.shape[0],
             min_freq,
             max_freq,
@@ -135,13 +146,11 @@ def compute_max_power(
     spec_call = spec[:, x_start:x_end]
     power_per_freq_band = np.sum(spec_call, axis=1)
     max_power_ind = np.argmax(power_per_freq_band)
-    return int(
-        convert_int_to_freq(
-            max_power_ind,
-            spec.shape[0],
-            min_freq,
-            max_freq,
-        )
+    return convert_int_to_freq(
+        int(max_power_ind),
+        spec.shape[0],
+        min_freq,
+        max_freq,
     )
 
 
@@ -164,13 +173,11 @@ def compute_max_power_first(
     first_half = spec_call[:, : int(spec_call.shape[1] / 2)]
     power_per_freq_band = np.sum(first_half, axis=1)
     max_power_ind = np.argmax(power_per_freq_band)
-    return int(
-        convert_int_to_freq(
-            max_power_ind,
-            spec.shape[0],
-            min_freq,
-            max_freq,
-        )
+    return convert_int_to_freq(
+        int(max_power_ind),
+        spec.shape[0],
+        min_freq,
+        max_freq,
     )
 
 
@@ -193,13 +200,11 @@ def compute_max_power_second(
     second_half = spec_call[:, int(spec_call.shape[1] / 2) :]
     power_per_freq_band = np.sum(second_half, axis=1)
     max_power_ind = np.argmax(power_per_freq_band)
-    return int(
-        convert_int_to_freq(
-            max_power_ind,
-            spec.shape[0],
-            min_freq,
-            max_freq,
-        )
+    return convert_int_to_freq(
+        int(max_power_ind),
+        spec.shape[0],
+        min_freq,
+        max_freq,
    )
 
 
@@ -214,6 +219,7 @@ def compute_call_interval(
     return round(prediction["start_time"] - previous["end_time"], 5)
 
 
+
 # NOTE: The order of the features in this dictionary is important. The
 # features are extracted in this order and the order of the columns in the
 # output csv file is determined by this order. In order to avoid breaking
@@ -236,7 +242,7 @@ def get_feats(
     spec: np.ndarray,
     pred_nms: types.PredictionResults,
     params: types.FeatureExtractionParameters,
-):
+) -> np.ndarray:
     """Extract features from spectrogram based on detected call locations.
 
     The features extracted are:
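The `convert_int_to_freq` change above wraps the rounded value in `int(...)`, so callers now get a plain integer frequency, matching the new `-> int` annotation. A quick sanity check of the index-to-frequency mapping (a standalone sketch; the 128-bin / 10-120 kHz figures are made up for illustration):

    def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
        # Row 0 is the top of the spectrogram, so flip the index first.
        spec_ind = spec_height - spec_ind
        return int(
            round((spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2)
        )

    assert convert_int_to_freq(0, 128, 10_000, 120_000) == 120_000   # top row -> max_freq
    assert convert_int_to_freq(128, 128, 10_000, 120_000) == 10_000  # bottom row -> min_freq
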
@@ -79,7 +79,13 @@ class ConvBlockDownCoordF(nn.Module):
 
 class ConvBlockDownStandard(nn.Module):
     def __init__(
-        self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1
+        self,
+        in_chn,
+        out_chn,
+        ip_height=None,
+        k_size=3,
+        pad_size=1,
+        stride=1,
     ):
         super(ConvBlockDownStandard, self).__init__()
         self.conv = nn.Conv2d(
@@ -103,15 +103,15 @@ class Net2DFast(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )
 
-    def forward(self, ip, return_feats=False) -> ModelOutput:
+    def forward(self, spec: torch.Tensor) -> ModelOutput:
         # encoder
-        x1 = self.conv_dn_0(ip)
+        x1 = self.conv_dn_0(spec)
         x2 = self.conv_dn_1(x1)
         x3 = self.conv_dn_2(x2)
-        x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
+        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
 
         # bottleneck
-        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
+        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
         x = self.att(x)
         x = x.repeat([1, 1, self.bneck_height * 4, 1])
 
@@ -121,13 +121,13 @@ class Net2DFast(nn.Module):
         x = self.conv_up_4(x + x1)
 
         # output
-        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
+        x = F.relu_(self.conv_op_bn(self.conv_op(x)))
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)
 
         return ModelOutput(
             pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
-            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_size=F.relu(self.conv_size_op(x)),
             pred_class=comb,
             pred_class_un_norm=cls,
             features=x,
@@ -215,26 +215,26 @@ class Net2DFastNoAttn(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )
 
-    def forward(self, ip, return_feats=False) -> ModelOutput:
-        x1 = self.conv_dn_0(ip)
+    def forward(self, spec: torch.Tensor) -> ModelOutput:
+        x1 = self.conv_dn_0(spec)
         x2 = self.conv_dn_1(x1)
         x3 = self.conv_dn_2(x2)
-        x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
+        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
 
-        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
+        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
         x = x.repeat([1, 1, self.bneck_height * 4, 1])
 
         x = self.conv_up_2(x + x3)
         x = self.conv_up_3(x + x2)
         x = self.conv_up_4(x + x1)
 
-        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
+        x = F.relu_(self.conv_op_bn(self.conv_op(x)))
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)
 
         return ModelOutput(
             pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
-            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_size=F.relu_(self.conv_size_op(x)),
             pred_class=comb,
             pred_class_un_norm=cls,
             features=x,
@@ -324,13 +324,13 @@ class Net2DFastNoCoordConv(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )
 
-    def forward(self, ip, return_feats=False) -> ModelOutput:
-        x1 = self.conv_dn_0(ip)
+    def forward(self, spec: torch.Tensor) -> ModelOutput:
+        x1 = self.conv_dn_0(spec)
         x2 = self.conv_dn_1(x1)
         x3 = self.conv_dn_2(x2)
-        x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
+        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
 
-        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
+        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
         x = self.att(x)
         x = x.repeat([1, 1, self.bneck_height * 4, 1])
 
@@ -338,15 +338,13 @@ class Net2DFastNoCoordConv(nn.Module):
         x = self.conv_up_3(x + x2)
         x = self.conv_up_4(x + x1)
 
-        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
+        x = F.relu_(self.conv_op_bn(self.conv_op(x)))
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)
 
-        pred_emb = (self.conv_emb(x) if self.emb_dim > 0 else None,)
-
         return ModelOutput(
             pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
-            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_size=F.relu_(self.conv_size_op(x)),
             pred_class=comb,
             pred_class_un_norm=cls,
             features=x,
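The recurring edit in these forward passes swaps `F.relu(..., inplace=True)` for `F.relu_(...)`; the two spellings are equivalent, since `relu_` is PyTorch's in-place variant of `relu`. (One call, `Net2DFast`'s `pred_size`, comes out as plain `F.relu(...)`, i.e. no longer in-place.) A minimal equivalence check, separate from the commit:

    import torch
    import torch.nn.functional as F

    x = torch.tensor([-1.0, 0.5])
    y = F.relu_(x.clone())               # in-place variant
    z = F.relu(x.clone(), inplace=True)  # same operation, spelled differently
    assert torch.equal(y, z)             # both yield tensor([0.0000, 0.5000])
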
@@ -1,6 +1,11 @@
 import datetime
 import os
+from pathlib import Path
+from typing import List, Optional, Union
 
+from pydantic import BaseModel, Field, computed_field
+
+from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
 from batdetect2.types import ProcessingConfiguration, SpectrogramParameters
 
 TARGET_SAMPLERATE_HZ = 256000
@@ -75,158 +80,154 @@ def mk_dir(path):
         os.makedirs(path)
 
 
-def get_params(make_dirs=False, exps_dir="../../experiments/"):
-    params = {}
-
-    params[
-        "model_name"
-    ] = "Net2DFast"  # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
-    params["num_filters"] = 128
+AUG_SAMPLING_RATES = [
+    220500,
+    256000,
+    300000,
+    312500,
+    384000,
+    441000,
+    500000,
+]
+CLASSES_TO_IGNORE = ["", " ", "Unknown", "Not Bat"]
+GENERIC_CLASSES = ["Bat"]
+EVENTS_OF_INTEREST = ["Echolocation"]
+
+
+class TrainingParameters(BaseModel):
+    # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
+    model_name: str = "Net2DFast"
+    num_filters: int = 128
+
+    experiment: Path
+    model_file_name: Path
+
+    op_im_dir: Path
+    op_im_dir_test: Path
+
+    notes: str = ""
+
+    target_samp_rate: int = TARGET_SAMPLERATE_HZ
+    fft_win_length: float = FFT_WIN_LENGTH_S
+    fft_overlap: float = FFT_OVERLAP
+
+    max_freq: int = MAX_FREQ_HZ
+    min_freq: int = MIN_FREQ_HZ
+
+    resize_factor: float = RESIZE_FACTOR
+    spec_height: int = SPEC_HEIGHT
+    spec_train_width: int = 512
+    spec_divide_factor: int = SPEC_DIVIDE_FACTOR
+
+    denoise_spec_avg: bool = DENOISE_SPEC_AVG
+    scale_raw_audio: bool = SCALE_RAW_AUDIO
+    max_scale_spec: bool = MAX_SCALE_SPEC
+    spec_scale: str = SPEC_SCALE
+
+    detection_overlap: float = 0.01
+    ignore_start_end: float = 0.01
+    detection_threshold: float = DETECTION_THRESHOLD
+    nms_kernel_size: int = NMS_KERNEL_SIZE
+    nms_top_k_per_sec: int = NMS_TOP_K_PER_SEC
+
+    aug_prob: float = 0.20
+    augment_at_train: bool = True
+    augment_at_train_combine: bool = True
+    echo_max_delay: float = 0.005
+    stretch_squeeze_delta: float = 0.04
+    mask_max_time_perc: float = 0.05
+    mask_max_freq_perc: float = 0.10
+    spec_amp_scaling: float = 2.0
+    aug_sampling_rates: List[int] = AUG_SAMPLING_RATES
+
+    train_loss: str = "focal"
+    det_loss_weight: float = 1.0
+    size_loss_weight: float = 0.1
+    class_loss_weight: float = 2.0
+    individual_loss_weight: float = 0.0
+
+    lr: float = 0.001
+    batch_size: int = 8
+    num_workers: int = 4
+    num_epochs: int = 200
+    num_eval_epochs: int = 5
+    device: str = "cuda"
+    save_test_image_during_train: bool = False
+    save_test_image_after_train: bool = True
+
+    convert_to_genus: bool = False
+    class_names: List[str] = Field(default_factory=list)
+    classes_to_ignore: List[str] = Field(
+        default_factory=lambda: CLASSES_TO_IGNORE
+    )
+    generic_class: List[str] = Field(default_factory=lambda: GENERIC_CLASSES)
+    events_of_interest: List[str] = Field(
+        default_factory=lambda: EVENTS_OF_INTEREST
+    )
+    standardize_classs_names: List[str] = Field(default_factory=list)
+
+    @computed_field
+    @property
+    def emb_dim(self) -> int:
+        if self.individual_loss_weight == 0.0:
+            return 0
+        return 3
+
+    @computed_field
+    @property
+    def genus_mapping(self) -> List[int]:
+        _, mapping = get_genus_mapping(self.class_names)
+        return mapping
+
+    @computed_field
+    @property
+    def genus_classes(self) -> List[str]:
+        names, _ = get_genus_mapping(self.class_names)
+        return names
+
+    @computed_field
+    @property
+    def class_names_short(self) -> List[str]:
+        return get_short_class_names(self.class_names)
+
+
+def get_params(
+    make_dirs: bool = False,
+    exps_dir: str = "../../experiments/",
+    model_name: Optional[str] = None,
+    experiment: Union[Path, str, None] = None,
+    **kwargs,
+) -> TrainingParameters:
+    experiments_dir = Path(exps_dir)
 
     now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
-    model_name = now_str + ".pth.tar"
-    params["experiment"] = os.path.join(exps_dir, now_str, "")
-    params["model_file_name"] = os.path.join(params["experiment"], model_name)
-    params["op_im_dir"] = os.path.join(params["experiment"], "op_ims", "")
-    params["op_im_dir_test"] = os.path.join(
-        params["experiment"], "op_ims_test", ""
-    )
-    # params['notes'] = ''  # can save notes about an experiment here
-
-    # spec parameters
-    params[
-        "target_samp_rate"
-    ] = TARGET_SAMPLERATE_HZ  # resamples all audio so that it is at this rate
-    params[
-        "fft_win_length"
-    ] = FFT_WIN_LENGTH_S  # in milliseconds, amount of time per stft time step
-    params["fft_overlap"] = FFT_OVERLAP  # stft window overlap
-
-    params[
-        "max_freq"
-    ] = MAX_FREQ_HZ  # in Hz, everything above this will be discarded
-    params[
-        "min_freq"
-    ] = MIN_FREQ_HZ  # in Hz, everything below this will be discarded
-
-    params[
-        "resize_factor"
-    ] = RESIZE_FACTOR  # resize so the spectrogram at the input of the network
-    params[
-        "spec_height"
-    ] = SPEC_HEIGHT  # units are number of frequency bins (before resizing is performed)
-    params[
-        "spec_train_width"
-    ] = 512  # units are number of time steps (before resizing is performed)
-    params[
-        "spec_divide_factor"
-    ] = SPEC_DIVIDE_FACTOR  # spectrogram should be divisible by this amount in width and height
-
-    # spec processing params
-    params[
-        "denoise_spec_avg"
-    ] = DENOISE_SPEC_AVG  # removes the mean for each frequency band
-    params[
-        "scale_raw_audio"
-    ] = SCALE_RAW_AUDIO  # scales the raw audio to [-1, 1]
-    params[
-        "max_scale_spec"
-    ] = MAX_SCALE_SPEC  # scales the spectrogram so that it is max 1
-    params["spec_scale"] = SPEC_SCALE  # 'log', 'pcen', 'none'
-
-    # detection params
-    params[
-        "detection_overlap"
-    ] = 0.01  # has to be within this number of ms to count as detection
-    params[
-        "ignore_start_end"
-    ] = 0.01  # if start of GT calls are within this time from the start/end of file ignore
-    params[
-        "detection_threshold"
-    ] = DETECTION_THRESHOLD  # the smaller this is the better the recall will be
-    params[
-        "nms_kernel_size"
-    ] = NMS_KERNEL_SIZE  # size of the kernel for non-max suppression
-    params[
-        "nms_top_k_per_sec"
-    ] = NMS_TOP_K_PER_SEC  # keep top K highest predictions per second of audio
-    params["target_sigma"] = 2.0
-
-    # augmentation params
-    params[
-        "aug_prob"
-    ] = 0.20  # augmentations will be performed with this probability
-    params["augment_at_train"] = True
-    params["augment_at_train_combine"] = True
-    params[
-        "echo_max_delay"
-    ] = 0.005  # simulate echo by adding copy of raw audio
-    params["stretch_squeeze_delta"] = 0.04  # stretch or squeeze spec
-    params[
-        "mask_max_time_perc"
-    ] = 0.05  # max mask size - here percentage, not ideal
-    params[
-        "mask_max_freq_perc"
-    ] = 0.10  # max mask size - here percentage, not ideal
-    params[
-        "spec_amp_scaling"
-    ] = 2.0  # multiply the "volume" by 0:X times current amount
-    params["aug_sampling_rates"] = [
-        220500,
-        256000,
-        300000,
-        312500,
-        384000,
-        441000,
-        500000,
-    ]
-
-    # loss params
-    params["train_loss"] = "focal"  # mse or focal
-    params["det_loss_weight"] = 1.0  # weight for the detection part of the loss
-    params["size_loss_weight"] = 0.1  # weight for the bbox size loss
-    params["class_loss_weight"] = 2.0  # weight for the classification loss
-    params["individual_loss_weight"] = 0.0  # not used
-    if params["individual_loss_weight"] == 0.0:
-        params[
-            "emb_dim"
-        ] = 0  # number of dimensions used for individual id embedding
-    else:
-        params["emb_dim"] = 3
-
-    # train params
-    params["lr"] = 0.001
-    params["batch_size"] = 8
-    params["num_workers"] = 4
-    params["num_epochs"] = 200
-    params["num_eval_epochs"] = 5  # run evaluation every X epochs
-    params["device"] = "cuda"
-    params["save_test_image_during_train"] = False
-    params["save_test_image_after_train"] = True
-
-    params["convert_to_genus"] = False
-    params["genus_mapping"] = []
-    params["class_names"] = []
-    params["classes_to_ignore"] = ["", " ", "Unknown", "Not Bat"]
-    params["generic_class"] = ["Bat"]
-    params["events_of_interest"] = [
-        "Echolocation"
-    ]  # will ignore all other types of events e.g. social calls
-
-    # the classes in this list are standardized during training so that the same low and high freq are used
-    params["standardize_classs_names"] = []
-
-    # create directories
+
+    if model_name is None:
+        model_name = f"{now_str}.pth.tar"
+
+    if experiment is None:
+        experiment = experiments_dir / now_str
+    experiment = Path(experiment)
+
+    model_file_name = experiment / model_name
+    op_ims_dir = experiment / "op_ims"
+    op_ims_test_dir = experiment / "op_ims_test"
+
+    params = TrainingParameters(
+        model_name=model_name,
+        experiment=experiment,
+        model_file_name=model_file_name,
+        op_im_dir=op_ims_dir,
+        op_im_dir_test=op_ims_test_dir,
+        **kwargs,
+    )
+
     if make_dirs:
-        print("Model name : " + params["model_name"])
-        print("Model file : " + params["model_file_name"])
-        print("Experiment : " + params["experiment"])
-
-        mk_dir(params["experiment"])
-        if params["save_test_image_during_train"]:
-            mk_dir(params["op_im_dir"])
-        if params["save_test_image_after_train"]:
-            mk_dir(params["op_im_dir_test"])
-        mk_dir(os.path.dirname(params["model_file_name"]))
+        mk_dir(experiment)
+        mk_dir(params.model_file_name.parent)
+        if params.save_test_image_during_train:
+            mk_dir(params.op_im_dir)
+        if params.save_test_image_after_train:
+            mk_dir(params.op_im_dir_test)
 
     return params
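The params dict is replaced here by a pydantic model, with derived values such as `emb_dim` exposed as `computed_field` properties so they cannot drift out of sync with the fields they depend on (this decorator requires pydantic v2). The same pattern in isolation, with made-up field names:

    from typing import List

    from pydantic import BaseModel, Field, computed_field


    class Config(BaseModel):
        individual_loss_weight: float = 0.0
        class_names: List[str] = Field(default_factory=list)

        @computed_field
        @property
        def emb_dim(self) -> int:
            # Recomputed on access instead of being set once in a dict.
            return 0 if self.individual_loss_weight == 0.0 else 3


    cfg = Config(class_names=["Myotis", "Pipistrellus"])
    print(cfg.emb_dim)       # 0
    print(cfg.model_dump())  # computed fields appear in the dump too
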
@@ -1,33 +1,31 @@
 import argparse
-import glob
-import json
 import os
-import sys
+import warnings
+from typing import List, Optional
 
-import matplotlib.pyplot as plt
-import numpy as np
 import torch
-import torch.nn.functional as F
+import torch.utils.data
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-import batdetect2.detector.models as models
 import batdetect2.detector.parameters as parameters
-import batdetect2.detector.post_process as pp
 import batdetect2.train.audio_dataloader as adl
-import batdetect2.train.evaluate as evl
 import batdetect2.train.losses as losses
 import batdetect2.train.train_model as tm
 import batdetect2.train.train_utils as tu
 import batdetect2.utils.detector_utils as du
 import batdetect2.utils.plot_utils as pu
+from batdetect2 import types
+from batdetect2.detector.models import Net2DFast
 
-if __name__ == "__main__":
-    info_str = "\nBatDetect - Finetune Model\n"
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
-    print(info_str)
 
+def parse_arugments():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "audio_path", type=str, help="Input directory for audio"
+        "audio_path",
+        type=str,
+        help="Input directory for audio",
     )
     parser.add_argument(
         "train_ann_path",
@@ -39,7 +37,15 @@ if __name__ == "__main__":
         type=str,
         help="Path to where test annotation file is stored",
     )
-    parser.add_argument("model_path", type=str, help="Path to pretrained model")
+    parser.add_argument(
+        "model_path", type=str, help="Path to pretrained model"
+    )
+    parser.add_argument(
+        "--experiment_dir",
+        type=str,
+        default=os.path.join(BASE_DIR, "experiments"),
+        help="Path to where experiment files are stored",
+    )
     parser.add_argument(
         "--op_model_name",
         type=str,
@@ -71,107 +77,63 @@ if __name__ == "__main__":
     parser.add_argument(
         "--notes", type=str, default="", help="Notes to save in text file"
     )
-    args = vars(parser.parse_args())
-
-    params = parameters.get_params(True, "../../experiments/")
+    args = parser.parse_args()
+    return args
+
+
+def select_device(warn=True) -> str:
     if torch.cuda.is_available():
-        params["device"] = "cuda"
-    else:
-        params["device"] = "cpu"
-        print(
-            "\nNote, this will be a lot faster if you use computer with a GPU.\n"
+        return "cuda"
+
+    if warn:
+        warnings.warn(
+            "No GPU available, using the CPU instead. Please consider using a GPU "
+            "to speed up training."
         )
 
-    print("\nAudio directory: " + args["audio_path"])
-    print("Train file: " + args["train_ann_path"])
-    print("Test file: " + args["test_ann_path"])
-    print("Loading model: " + args["model_path"])
-
-    dataset_name = (
-        os.path.basename(args["train_ann_path"])
-        .replace(".json", "")
-        .replace("_TRAIN", "")
-    )
-
-    if args["train_from_scratch"]:
-        print("\nTraining model from scratch i.e. not using pretrained weights")
-        model, params_train = du.load_model(args["model_path"], False)
-    else:
-        model, params_train = du.load_model(args["model_path"], True)
-    model.to(params["device"])
-
-    params["num_epochs"] = args["num_epochs"]
-    if args["op_model_name"] != "":
-        params["model_file_name"] = args["op_model_name"]
-    classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
-
-    # save notes file
-    params["notes"] = args["notes"]
-    if args["notes"] != "":
-        tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"])
-
-    # load train annotations
-    train_sets = []
+    return "cpu"
+
+
+def load_annotations(
+    dataset_name: str,
+    ann_path: str,
+    audio_path: str,
+    classes_to_ignore: Optional[List[str]] = None,
+    events_of_interest: Optional[List[str]] = None,
+) -> List[types.FileAnnotation]:
+    train_sets: List[types.DatasetDict] = []
     train_sets.append(
         tu.get_blank_dataset_dict(
-            dataset_name, False, args["train_ann_path"], args["audio_path"]
-        )
-    )
-    params["train_sets"] = [
-        tu.get_blank_dataset_dict(
             dataset_name,
-            False,
-            os.path.basename(args["train_ann_path"]),
-            args["audio_path"],
+            is_test=False,
+            ann_path=ann_path,
+            wav_path=audio_path,
         )
-    ]
-
-    print("\nTrain set:")
-    (
-        data_train,
-        params["class_names"],
-        params["class_inv_freq"],
-    ) = tu.load_set_of_anns(
-        train_sets, classes_to_ignore, params["events_of_interest"]
-    )
-    print("Number of files", len(data_train))
-
-    params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
-        params["class_names"]
-    )
-    params["class_names_short"] = tu.get_short_class_names(
-        params["class_names"]
-    )
-
-    # load test annotations
-    test_sets = []
-    test_sets.append(
-        tu.get_blank_dataset_dict(
-            dataset_name, True, args["test_ann_path"], args["audio_path"]
-        )
-    )
-    params["test_sets"] = [
-        tu.get_blank_dataset_dict(
-            dataset_name,
-            True,
-            os.path.basename(args["test_ann_path"]),
-            args["audio_path"],
-        )
-    ]
-
-    print("\nTest set:")
-    data_test, _, _ = tu.load_set_of_anns(
-        test_sets, classes_to_ignore, params["events_of_interest"]
-    )
-    print("Number of files", len(data_test))
+    )
+    return tu.load_set_of_anns(
+        train_sets,
+        events_of_interest=events_of_interest,
+        classes_to_ignore=classes_to_ignore,
+    )
+
+
+def finetune_model(
+    model: types.DetectionModel,
+    data_train: List[types.FileAnnotation],
+    data_test: List[types.FileAnnotation],
+    params: parameters.TrainingParameters,
+    model_params: types.ModelParameters,
+    finetune_only_last_layer: bool = False,
+    save_images: bool = True,
+):
     # train loader
     train_dataset = adl.AudioLoader(data_train, params, is_train=True)
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
-        batch_size=params["batch_size"],
+        batch_size=params.batch_size,
         shuffle=True,
-        num_workers=params["num_workers"],
+        num_workers=params.num_workers,
         pin_memory=True,
     )
 
@@ -181,32 +143,36 @@ if __name__ == "__main__":
         test_dataset,
         batch_size=1,
         shuffle=False,
-        num_workers=params["num_workers"],
+        num_workers=params.num_workers,
         pin_memory=True,
     )
 
     inputs_train = next(iter(train_loader))
-    params["ip_height"] = inputs_train["spec"].shape[2]
+    params.ip_height = inputs_train["spec"].shape[2]
     print("\ntrain batch size :", inputs_train["spec"].shape)
 
-    assert params_train["model_name"] == "Net2DFast"
+    # Check that the model is the same as the one used to train the pretrained
+    # weights
+    assert model_params["model_name"] == "Net2DFast"
+    assert isinstance(model, Net2DFast)
     print(
-        "\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n"
+        "\n\nSOME hyperparams need to be the same as the loaded model "
+        "(e.g. FFT) - currently they are getting overwritten.\n\n"
     )
 
     # set the number of output classes
     num_filts = model.conv_classes_op.in_channels
-    k_size = model.conv_classes_op.kernel_size
-    pad = model.conv_classes_op.padding
+    (k_size,) = model.conv_classes_op.kernel_size
+    (pad,) = model.conv_classes_op.padding
     model.conv_classes_op = torch.nn.Conv2d(
         num_filts,
-        len(params["class_names"]) + 1,
+        len(params.class_names) + 1,
         kernel_size=k_size,
         padding=pad,
     )
-    model.conv_classes_op.to(params["device"])
+    model.conv_classes_op.to(params.device)
 
-    if args["finetune_only_last_layer"]:
+    if finetune_only_last_layer:
         print("\nOnly finetuning the final layers.\n")
         train_layers_i = [
             "conv_classes",
@@ -223,19 +189,26 @@ if __name__ == "__main__":
         else:
             param.requires_grad = False
 
-    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
-    scheduler = CosineAnnealingLR(
-        optimizer, params["num_epochs"] * len(train_loader)
-    )
-    if params["train_loss"] == "mse":
+    optimizer = torch.optim.Adam(
+        model.parameters(),
+        lr=params.lr,
+    )
+    scheduler = CosineAnnealingLR(
+        optimizer,
+        params.num_epochs * len(train_loader),
+    )
+
+    if params.train_loss == "mse":
         det_criterion = losses.mse_loss
-    elif params["train_loss"] == "focal":
+    elif params.train_loss == "focal":
         det_criterion = losses.focal_loss
+    else:
+        raise ValueError("Unknown loss function")
 
     # plotting
     train_plt_ls = pu.LossPlotter(
-        params["experiment"] + "train_loss.png",
-        params["num_epochs"] + 1,
+        params.experiment / "train_loss.png",
+        params.num_epochs + 1,
         ["train_loss"],
         None,
         None,
@@ -243,8 +216,8 @@ if __name__ == "__main__":
         logy=True,
     )
     test_plt_ls = pu.LossPlotter(
-        params["experiment"] + "test_loss.png",
-        params["num_epochs"] + 1,
+        params.experiment / "test_loss.png",
+        params.num_epochs + 1,
         ["test_loss"],
         None,
         None,
@@ -252,24 +225,24 @@ if __name__ == "__main__":
         logy=True,
     )
     test_plt = pu.LossPlotter(
-        params["experiment"] + "test.png",
-        params["num_epochs"] + 1,
+        params.experiment / "test.png",
+        params.num_epochs + 1,
         ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
         [0, 1],
         None,
         ["epoch", ""],
     )
     test_plt_class = pu.LossPlotter(
-        params["experiment"] + "test_avg_prec.png",
-        params["num_epochs"] + 1,
-        params["class_names_short"],
+        params.experiment / "test_avg_prec.png",
+        params.num_epochs + 1,
+        params.class_names_short,
         [0, 1],
-        params["class_names_short"],
+        params.class_names_short,
         ["epoch", "avg_prec"],
     )
 
     # main train loop
-    for epoch in range(0, params["num_epochs"] + 1):
+    for epoch in range(0, params.num_epochs + 1):
         train_loss = tm.train(
             model,
             epoch,
@@ -281,10 +254,14 @@ if __name__ == "__main__":
         )
         train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])
 
-        if epoch % params["num_eval_epochs"] == 0:
+        if epoch % params.num_eval_epochs == 0:
             # detection accuracy on test set
             test_res, test_loss = tm.test(
-                model, epoch, test_loader, det_criterion, params
+                model,
+                epoch,
+                test_loader,
+                det_criterion,
+                params,
             )
             test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
             test_plt.update_and_save(
@@ -301,18 +278,106 @@ if __name__ == "__main__":
                 epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
             )
             pu.plot_pr_curve_class(
-                params["experiment"], "test_pr", "test_pr", test_res
+                params.experiment, "test_pr", "test_pr", test_res
             )
 
     # save finetuned model
-    print("saving model to: " + params["model_file_name"])
+    print(f"saving model to: {params.model_file_name}")
     op_state = {
         "epoch": epoch + 1,
         "state_dict": model.state_dict(),
         "params": params,
     }
-    torch.save(op_state, params["model_file_name"])
+    torch.save(op_state, params.model_file_name)
 
     # save an image with associated prediction for each batch in the test set
-    if not args["do_not_save_images"]:
+    if save_images:
         tm.save_images_batch(model, test_loader, params)
+
+
+def main():
+    info_str = "\nBatDetect - Finetune Model\n"
+    print(info_str)
+
+    args = parse_arugments()
+
+    # Load experiment parameters
+    params = parameters.get_params(
+        make_dirs=True,
+        exps_dir=args.experiment_dir,
+        device=select_device(),
+        num_epochs=args.num_epochs,
+        notes=args.notes,
+    )
+
+    print("\nAudio directory: " + args.audio_path)
+    print("Train file: " + args.train_ann_path)
+    print("Test file: " + args.test_ann_path)
+    print("Loading model: " + args.model_path)
+
+    if args.train_from_scratch:
+        print(
+            "\nTraining model from scratch i.e. not using pretrained weights"
+        )
+
+    model, model_params = du.load_model(
+        args.model_path,
+        load_weights=not args.train_from_scratch,
+        device=params.device,
+    )
+
+    if args.op_model_name != "":
+        params.model_file_name = args.op_model_name
+
+    classes_to_ignore = params.classes_to_ignore + params.generic_class
+
+    # save notes file
+    if params.notes:
+        tu.write_notes_file(
+            params.experiment / "notes.txt",
+            args.notes,
+        )
+
+    # NOTE:??
+    dataset_name = (
+        os.path.basename(args.train_ann_path)
+        .replace(".json", "")
+        .replace("_TRAIN", "")
+    )
+
+    # ==== LOAD DATA ====
+
+    # load train annotations
+    data_train = load_annotations(
+        dataset_name,
+        args.train_ann_path,
+        args.audio_path,
+        params.events_of_interest,
+    )
+    print("\nTrain set:")
+    print("Number of files", len(data_train))
+
+    # load test annotations
+    data_test = load_annotations(
+        dataset_name,
+        args.test_ann_path,
+        args.audio_path,
+        classes_to_ignore,
+        params.events_of_interest,
+    )
+    print("\nTrain set:")
+    print("Number of files", len(data_train))
+
+    finetune_model(
+        model,
+        data_train,
+        data_test,
+        params,
+        model_params,
+        finetune_only_last_layer=args.finetune_only_last_layer,
+        save_images=args.do_not_save_images,
+    )
+
+
+if __name__ == "__main__":
+    main()
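The new `finetune_model` keeps the pretrained backbone but rebuilds the classifier head (`conv_classes_op`) so its output channels match the new class count, and can freeze everything else via `requires_grad`. The head-swap/freeze pattern in isolation (a sketch under the assumption, taken from the diff, that the model exposes a final `conv_classes_op` layer):

    import torch

    def replace_head(model: torch.nn.Module, num_classes: int) -> None:
        # Rebuild the final conv with the new number of output channels
        # (+1 for the "background" class, as above).
        old = model.conv_classes_op
        model.conv_classes_op = torch.nn.Conv2d(
            old.in_channels,
            num_classes + 1,
            kernel_size=old.kernel_size,
            padding=old.padding,
        )

    def freeze_all_but_head(model: torch.nn.Module) -> None:
        for name, param in model.named_parameters():
            # Only the freshly initialised head keeps training.
            param.requires_grad = name.startswith("conv_classes")
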
@@ -1,62 +1,54 @@
 import argparse
 import json
 import os
+from collections import Counter
+from typing import List, Optional, Tuple
 
 import numpy as np
+from sklearn.model_selection import StratifiedGroupKFold
 
 import batdetect2.train.train_utils as tu
+from batdetect2 import types
 
 
-def print_dataset_stats(data, split_name, classes_to_ignore):
-    print("\nSplit:", split_name)
+def print_dataset_stats(
+    data: List[types.FileAnnotation],
+    classes_to_ignore: Optional[List[str]] = None,
+) -> Counter[str]:
     print("Num files:", len(data))
-
-    class_cnts = {}
-    for dd in data:
-        for aa in dd["annotation"]:
-            if aa["class"] not in classes_to_ignore:
-                if aa["class"] in class_cnts:
-                    class_cnts[aa["class"]] += 1
-                else:
-                    class_cnts[aa["class"]] = 1
-
-    if len(class_cnts) == 0:
-        class_names = []
-    else:
-        class_names = np.sort([*class_cnts]).tolist()
-        print("Class count:")
-        str_len = np.max([len(cc) for cc in class_names]) + 5
-
-        for ii, cc in enumerate(class_names):
-            print(str(ii).ljust(5) + cc.ljust(str_len) + str(class_cnts[cc]))
-
-    return class_names
+    counts, _ = tu.get_class_names(data, classes_to_ignore)
+    if len(counts) > 0:
+        tu.report_class_counts(counts)
+    return counts
 
 
-def load_file_names(file_name):
-    if os.path.isfile(file_name):
-        with open(file_name) as da:
-            files = [line.rstrip() for line in da.readlines()]
-        for ff in files:
-            if ff.lower()[-3:] != "wav":
-                print("Error: Filenames need to end in .wav - ", ff)
-                assert False
-    else:
-        print("Error: Input file not found - ", file_name)
-        assert False
+def load_file_names(file_name: str) -> List[str]:
+    if not os.path.isfile(file_name):
+        raise FileNotFoundError(f"Input file not found - {file_name}")
+
+    with open(file_name) as da:
+        files = [line.rstrip() for line in da.readlines()]
+
+    for path in files:
+        if path.lower()[-3:] != "wav":
+            raise ValueError(
+                f"Invalid file name - {path}. Must be a .wav file"
+            )
 
     return files
 
 
-if __name__ == "__main__":
+def parse_args():
     info_str = "\nBatDetect - Prepare Data for Finetuning\n"
 
     print(info_str)
 
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "dataset_name", type=str, help="Name to call your dataset"
     )
-    parser.add_argument("audio_dir", type=str, help="Input directory for audio")
+    parser.add_argument(
+        "audio_dir", type=str, help="Input directory for audio"
+    )
     parser.add_argument(
         "ann_dir",
         type=str,
@@ -102,88 +94,126 @@ if __name__ == "__main__":
         type=str,
         default="",
         help='New class names to use instead. One to one mapping with "--input_class_names". \
         Separate with ";"',
     )
-    args = vars(parser.parse_args())
-
-    np.random.seed(args["rand_seed"])
+    return parser.parse_args()
+
+
+def split_data(
+    data: List[types.FileAnnotation],
+    train_file: str,
+    test_file: str,
+    n_splits: int = 5,
+    random_state: int = 0,
+) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]:
+    if train_file != "" and test_file != "":
+        # user has specifed the train / test split
+        mapping = {
+            file_annotation["id"]: file_annotation for file_annotation in data
+        }
+        train_files = load_file_names(train_file)
+        test_files = load_file_names(test_file)
+        data_train = [
+            mapping[file_id] for file_id in train_files if file_id in mapping
+        ]
+        data_test = [
+            mapping[file_id] for file_id in test_files if file_id in mapping
+        ]
+        return data_train, data_test
+
+    # NOTE: Using StratifiedGroupKFold to ensure that the same file does not
+    # appear in both the training and test sets and trying to keep the
+    # distribution of classes the same in both sets.
+    splitter = StratifiedGroupKFold(
+        n_splits=n_splits,
+        shuffle=True,
+        random_state=random_state,
+    )
+    anns = np.array(
+        [
+            [dd["id"], ann["class"], ann["event"]]
+            for dd in data
+            for ann in dd["annotation"]
+        ]
+    )
+    y = anns[:, 1]
+    group = anns[:, 0]
+
+    train_idx, test_idx = next(splitter.split(X=anns, y=y, groups=group))
+    train_ids = set(anns[train_idx, 0])
+    test_ids = set(anns[test_idx, 0])
+
+    assert not (train_ids & test_ids)
+    data_train = [dd for dd in data if dd["id"] in train_ids]
+    data_test = [dd for dd in data if dd["id"] in test_ids]
+    return data_train, data_test
+
+
+def main():
+    args = parse_args()
+
+    np.random.seed(args.rand_seed)
 
     classes_to_ignore = ["", " ", "Unknown", "Not Bat"]
-    generic_class = ["Bat"]
     events_of_interest = ["Echolocation"]
 
-    if args["input_class_names"] != "" and args["output_class_names"] != "":
+    name_dict = None
+    if args.input_class_names != "" and args.output_class_names != "":
         # change the names of the classes
-        ip_names = args["input_class_names"].split(";")
-        op_names = args["output_class_names"].split(";")
+        ip_names = args.input_class_names.split(";")
+        op_names = args.output_class_names.split(";")
         name_dict = dict(zip(ip_names, op_names))
-    else:
-        name_dict = False
 
     # load annotations
-    data_all, _, _ = tu.load_set_of_anns(
-        {"ann_path": args["ann_dir"], "wav_path": args["audio_dir"]},
-        classes_to_ignore,
-        events_of_interest,
-        False,
-        False,
-        list_of_anns=True,
+    data_all = tu.load_set_of_anns(
+        [
+            {
+                "dataset_name": args.dataset_name,
+                "ann_path": args.ann_dir,
+                "wav_path": args.audio_dir,
+                "is_test": False,
+                "is_binary": False,
+            }
+        ],
+        classes_to_ignore=classes_to_ignore,
+        events_of_interest=events_of_interest,
+        convert_to_genus=False,
         filter_issues=True,
         name_replace=name_dict,
     )
 
-    print("Dataset name: " + args["dataset_name"])
-    print("Audio directory: " + args["audio_dir"])
-    print("Annotation directory: " + args["ann_dir"])
-    print("Ouput directory: " + args["op_dir"])
+    print("Dataset name: " + args.dataset_name)
+    print("Audio directory: " + args.audio_dir)
+    print("Annotation directory: " + args.ann_dir)
+    print("Ouput directory: " + args.op_dir)
     print("Num annotated files: " + str(len(data_all)))
 
-    if args["train_file"] != "" and args["test_file"] != "":
-        # user has specifed the train / test split
-        train_files = load_file_names(args["train_file"])
-        test_files = load_file_names(args["test_file"])
-        file_names_all = [dd["id"] for dd in data_all]
-        train_inds = [
-            file_names_all.index(ff)
-            for ff in train_files
-            if ff in file_names_all
-        ]
-        test_inds = [
-            file_names_all.index(ff)
-            for ff in test_files
-            if ff in file_names_all
-        ]
-
-    else:
-        # split the data into train and test at the file level
-        num_exs = len(data_all)
-        test_inds = np.random.choice(
-            np.arange(num_exs),
-            int(num_exs * args["percent_val"]),
-            replace=False,
-        )
-        test_inds = np.sort(test_inds)
-        train_inds = np.setdiff1d(np.arange(num_exs), test_inds)
-
-    data_train = [data_all[ii] for ii in train_inds]
-    data_test = [data_all[ii] for ii in test_inds]
-
-    if not os.path.isdir(args["op_dir"]):
-        os.makedirs(args["op_dir"])
-    op_name = os.path.join(args["op_dir"], args["dataset_name"])
+    data_train, data_test = split_data(
+        data=data_all,
+        train_file=args.train_file,
+        test_file=args.test_file,
+        n_splits=5,
+        random_state=args.rand_seed,
+    )
+
+    if not os.path.isdir(args.op_dir):
+        os.makedirs(args.op_dir)
+    op_name = os.path.join(args.op_dir, args.dataset_name)
     op_name_train = op_name + "_TRAIN.json"
     op_name_test = op_name + "_TEST.json"
 
-    class_un_train = print_dataset_stats(data_train, "Train", classes_to_ignore)
-    class_un_test = print_dataset_stats(data_test, "Test", classes_to_ignore)
+    print("\nSplit: Train")
+    class_un_train = print_dataset_stats(data_train, classes_to_ignore)
+
+    print("\nSplit: Test")
+    class_un_test = print_dataset_stats(data_test, classes_to_ignore)
 
     if len(data_train) > 0 and len(data_test) > 0:
-        if class_un_train != class_un_test:
-            print(
-                '\nError: some classes are not in both the training and test sets.\
-                \nTry a different random seed "--rand_seed".'
-            )
-            assert False
+        if set(class_un_train.keys()) != set(class_un_test.keys()):
+            raise RuntimeError(
+                "Error: some classes are not in both the training and test sets."
+                'Try a different random seed "--rand_seed".'
+            )
 
     print("\n")
     if len(data_train) == 0:
@@ -199,3 +229,7 @@ if __name__ == "__main__":
     print("Saving: ", op_name_test)
     with open(op_name_test, "w") as da:
         json.dump(data_test, da, indent=2)
+
+
+if __name__ == "__main__":
+    main()
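The new `split_data` leans on scikit-learn's `StratifiedGroupKFold`: stratifying on the annotation class keeps the class distribution similar across splits, while grouping on the file id guarantees no recording lands in both sets. Stripped to its essentials, with synthetic labels and groups for illustration:

    import numpy as np
    from sklearn.model_selection import StratifiedGroupKFold

    # One row per annotation; several annotations share a file (the group).
    y = np.array(["A", "A", "B", "B", "A", "B", "A", "B"])
    groups = np.array(["f1", "f1", "f2", "f2", "f3", "f3", "f4", "f4"])

    splitter = StratifiedGroupKFold(n_splits=2, shuffle=True, random_state=0)
    train_idx, test_idx = next(
        splitter.split(X=y.reshape(-1, 1), y=y, groups=groups)
    )

    # A whole file is either train or test, never both.
    assert not set(groups[train_idx]) & set(groups[test_idx])
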
@@ -12,19 +12,24 @@ import torchaudio
 import batdetect2.utils.audio_utils as au
 from batdetect2.types import (
     Annotation,
-    AnnotationGroup,
     AudioLoaderAnnotationGroup,
-    FileAnnotations,
-    HeatmapParameters,
+    AudioLoaderParameters,
+    FileAnnotation,
 )
 
 
 def generate_gt_heatmaps(
     spec_op_shape: Tuple[int, int],
-    sampling_rate: int,
-    ann: AnnotationGroup,
-    params: HeatmapParameters,
-) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]:
+    sampling_rate: float,
+    ann: AudioLoaderAnnotationGroup,
+    class_names: List[str],
+    fft_win_length: float,
+    fft_overlap: float,
+    max_freq: float,
+    min_freq: float,
+    resize_factor: float,
+    target_sigma: float,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AudioLoaderAnnotationGroup]:
     """Generate ground truth heatmaps from annotations.
 
     Parameters
@@ -53,31 +58,31 @@
     the x and y indices of their pixel location in the input spectrogram.
     """
     # spec may be resized on input into the network
-    num_classes = len(params["class_names"])
+    num_classes = len(class_names)
     op_height = spec_op_shape[0]
     op_width = spec_op_shape[1]
-    freq_per_bin = (params["max_freq"] - params["min_freq"]) / op_height
+    freq_per_bin = (max_freq - min_freq) / op_height
 
     # start and end times
     x_pos_start = au.time_to_x_coords(
         ann["start_times"],
         sampling_rate,
-        params["fft_win_length"],
-        params["fft_overlap"],
+        fft_win_length,
+        fft_overlap,
     )
-    x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int32)
+    x_pos_start = (resize_factor * x_pos_start).astype(np.int32)
     x_pos_end = au.time_to_x_coords(
         ann["end_times"],
         sampling_rate,
-        params["fft_win_length"],
-        params["fft_overlap"],
+        fft_win_length,
+        fft_overlap,
     )
-    x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int32)
+    x_pos_end = (resize_factor * x_pos_end).astype(np.int32)
 
     # location on y axis i.e. frequency
-    y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin
+    y_pos_low = (ann["low_freqs"] - min_freq) / freq_per_bin
     y_pos_low = (op_height - y_pos_low).astype(np.int32)
-    y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin
+    y_pos_high = (ann["high_freqs"] - min_freq) / freq_per_bin
     y_pos_high = (op_height - y_pos_high).astype(np.int32)
     bb_widths = x_pos_end - x_pos_start
     bb_heights = y_pos_low - y_pos_high
@@ -90,26 +95,17 @@ def generate_gt_heatmaps(
         & (y_pos_low < (op_height - 1))
     )[0]
 
-    ann_aug: AnnotationGroup = {
+    ann_aug: AudioLoaderAnnotationGroup = {
+        **ann,
         "start_times": ann["start_times"][valid_inds],
         "end_times": ann["end_times"][valid_inds],
         "high_freqs": ann["high_freqs"][valid_inds],
         "low_freqs": ann["low_freqs"][valid_inds],
         "class_ids": ann["class_ids"][valid_inds],
         "individual_ids": ann["individual_ids"][valid_inds],
+        "x_inds": x_pos_start[valid_inds],
+        "y_inds": y_pos_low[valid_inds],
     }
-    ann_aug["x_inds"] = x_pos_start[valid_inds]
-    ann_aug["y_inds"] = y_pos_low[valid_inds]
-    # keys = [
-    #     "start_times",
-    #     "end_times",
-    #     "high_freqs",
-    #     "low_freqs",
-    #     "class_ids",
-    #     "individual_ids",
-    # ]
-    # for kk in keys:
-    #     ann_aug[kk] = ann[kk][valid_inds]
 
     # if the number of calls is only 1, then it is unique
     # TODO would be better if we found these unique calls at the merging stage
@@ -118,6 +114,7 @@ def generate_gt_heatmaps(
 
     y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
     y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
+
     # num classes and "background" class
     y_2d_classes: np.ndarray = np.zeros(
         (num_classes + 1, op_height, op_width), dtype=np.float32
@@ -128,14 +125,8 @@ def generate_gt_heatmaps(
         draw_gaussian(
             y_2d_det[0, :],
             (x_pos_start[ii], y_pos_low[ii]),
-            params["target_sigma"],
+            target_sigma,
         )
-        # draw_gaussian(
-        #     y_2d_det[0, :],
-        #     (x_pos_start[ii], y_pos_low[ii]),
-        #     params["target_sigma"],
-        #     params["target_sigma"] * 2,
-        # )
         y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
         y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
 
@@ -144,14 +135,8 @@ def generate_gt_heatmaps(
         draw_gaussian(
             y_2d_classes[cls_id, :],
|
||||||
(x_pos_start[ii], y_pos_low[ii]),
|
(x_pos_start[ii], y_pos_low[ii]),
|
||||||
params["target_sigma"],
|
target_sigma,
|
||||||
)
|
)
|
||||||
# draw_gaussian(
|
|
||||||
# y_2d_classes[cls_id, :],
|
|
||||||
# (x_pos_start[ii], y_pos_low[ii]),
|
|
||||||
# params["target_sigma"],
|
|
||||||
# params["target_sigma"] * 2,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# be careful as this will have a 1.0 places where we have event but
|
# be careful as this will have a 1.0 places where we have event but
|
||||||
# dont know gt class this will be masked in training anyway
|
# dont know gt class this will be masked in training anyway
|
||||||
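After this refactor, `generate_gt_heatmaps` takes its configuration as explicit keyword arguments rather than a `HeatmapParameters` dict. A minimal sketch of the new call site under assumed, illustrative values (the real defaults live elsewhere in the package):

# `ann` is an AudioLoaderAnnotationGroup; all numeric values below are
# hypothetical and for illustration only.
y_2d_det, y_2d_size, y_2d_classes, ann_aug = generate_gt_heatmaps(
    spec_op_shape=(128, 512),
    sampling_rate=256000.0,
    ann=ann,
    class_names=["Myotis daubentonii", "Pipistrellus pipistrellus"],
    fft_win_length=0.02,
    fft_overlap=0.75,
    max_freq=120000.0,
    min_freq=10000.0,
    resize_factor=0.5,
    target_sigma=2.0,
)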
@ -235,8 +220,8 @@ def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:

def warp_spec_aug(
spec: torch.Tensor,
ann: AnnotationGroup,
params: dict,
ann: AudioLoaderAnnotationGroup,
stretch_squeeze_delta: float,
) -> torch.Tensor:
"""Warp spectrogram by randomly stretching and squeezing.

@ -247,8 +232,8 @@ def warp_spec_aug(
ann: AnnotationGroup
Annotation group for the spectrogram. Must be provided to sync
the start and stop times with the spectrogram after warping.
params: dict
Parameters for the augmentation.
stretch_squeeze_delta: float
Maximum amount to stretch or squeeze the spectrogram.

Returns
-------
@ -259,11 +244,10 @@ def warp_spec_aug(
-----
This function modifies the annotation group in place.
"""
# This is messy
# Augment spectrogram by randomly stretching and squeezing
# NOTE this also changes the start and stop time in place

delta = params["stretch_squeeze_delta"]
delta = stretch_squeeze_delta
op_size = (spec.shape[1], spec.shape[2])
resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
resize_amt = int(spec.shape[2] * resize_fract_r)
@ -277,7 +261,7 @@ def warp_spec_aug(
dtype=spec.dtype,
),
),
2,
dim=2,
)
else:
spec_r = spec[:, :, :resize_amt]
@ -297,7 +281,10 @@ def warp_spec_aug(
return spec
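The stretch factor `resize_fract_r` above is drawn uniformly from [1 - delta, 1 + delta]; a standalone sketch of the arithmetic with an assumed delta of 0.1:

import numpy as np

delta = 0.1  # illustrative stretch_squeeze_delta
resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0  # in [0.9, 1.1)
# a 512-column spectrogram is resized to between 460 and 563 columns
resize_amt = int(512 * resize_fract_r)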
def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
def mask_time_aug(
spec: torch.Tensor,
mask_max_time_perc: float,
) -> torch.Tensor:
"""Mask out random blocks of time.

Will randomly mask out a block of time in the spectrogram. The block
@ -308,8 +295,8 @@ def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
----------
spec: torch.Tensor
Spectrogram to mask.
params: dict
Parameters for the augmentation.
mask_max_time_perc: float
Maximum percentage of time to mask out.

Returns
-------
@ -324,14 +311,17 @@ def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
Recognition
"""
fm = torchaudio.transforms.TimeMasking(
int(spec.shape[1] * params["mask_max_time_perc"])
int(spec.shape[1] * mask_max_time_perc)
)
for _ in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec


def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
def mask_freq_aug(
spec: torch.Tensor,
mask_max_freq_perc: float,
) -> torch.Tensor:
"""Mask out random blocks of frequency.

Will randomly mask out a block of frequency in the spectrogram. The block
@ -342,8 +332,8 @@ def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
----------
spec: torch.Tensor
Spectrogram to mask.
params: dict
Parameters for the augmentation.
mask_max_freq_perc: float
Maximum percentage of frequency to mask out.

Returns
-------
@ -358,41 +348,48 @@ def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
Recognition
"""
fm = torchaudio.transforms.FrequencyMasking(
int(spec.shape[1] * params["mask_max_freq_perc"])
int(spec.shape[1] * mask_max_freq_perc)
)
for _ in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
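Both masking functions follow the SpecAugment recipe: one to three random masks, each bounded by the given fraction of the axis. A minimal, self-contained sketch with an assumed 5% maximum mask width:

import torch
import torchaudio

spec = torch.rand(1, 128, 512)  # (channels, freq bins, time frames)
masker = torchaudio.transforms.TimeMasking(int(spec.shape[-1] * 0.05))
masked = masker(spec)  # zeroes a random block of up to ~25 time frames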
def scale_vol_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
def scale_vol_aug(
spec: torch.Tensor,
spec_amp_scaling: float,
) -> torch.Tensor:
"""Scale the volume of the spectrogram.

Parameters
----------
spec: torch.Tensor
Spectrogram to scale.
params: dict
Parameters for the augmentation.
spec_amp_scaling: float
Maximum scaling factor.

Returns
-------
torch.Tensor
"""
return spec * np.random.random() * params["spec_amp_scaling"]
return spec * np.random.random() * spec_amp_scaling


def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
def echo_aug(
audio: np.ndarray,
sampling_rate: float,
echo_max_delay: float,
) -> np.ndarray:
"""Add echo to audio.

Parameters
----------
audio: np.ndarray
Audio to add echo to.
sampling_rate: int
sampling_rate: float
Sampling rate of the audio.
params: dict
Parameters for the augmentation.
echo_max_delay: float
Maximum delay of the echo in seconds.

Returns
-------
@ -400,7 +397,7 @@ def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
Audio with echo added.
"""
sample_offset = (
int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1
int(echo_max_delay * np.random.random() * sampling_rate) + 1
)
audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
return audio
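The echo is simply a delayed, randomly attenuated copy of the waveform added back onto itself; the same idea in isolation (all values illustrative):

import numpy as np

audio = np.random.randn(256000).astype(np.float32)  # stand-in for 1 s of audio
sampling_rate = 256000
echo_max_delay = 0.005  # assumed maximum delay in seconds

sample_offset = int(echo_max_delay * np.random.random() * sampling_rate) + 1
audio[:-sample_offset] += np.random.random() * audio[sample_offset:]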
@ -408,9 +405,14 @@ def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:

def resample_aug(
audio: np.ndarray,
sampling_rate: int,
params: dict,
) -> Tuple[np.ndarray, int, float]:
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
resize_factor: float,
spec_divide_factor: float,
spec_train_width: int,
aug_sampling_rates: List[int],
) -> Tuple[np.ndarray, float, float]:
"""Resample audio augmentation.

Will resample the audio to a random sampling rate from the list of
@ -420,23 +422,32 @@ def resample_aug(
----------
audio: np.ndarray
Audio to resample.
sampling_rate: int
sampling_rate: float
Original sampling rate of the audio.
params: dict
Parameters for the augmentation. Includes the list of sampling rates
to choose from for resampling in `aug_sampling_rates`.
fft_win_length: float
Length of the FFT window in seconds.
fft_overlap: float
Amount of overlap between FFT windows.
resize_factor: float
Factor to resize the spectrogram by.
spec_divide_factor: float
Factor to divide the spectrogram by.
spec_train_width: int
Width of the spectrogram.
aug_sampling_rates: List[int]
List of sampling rates to resample to.

Returns
-------
audio : np.ndarray
Resampled audio.
sampling_rate : int
sampling_rate : float
New sampling rate.
duration : float
Duration of the audio in seconds.
"""
sampling_rate_old = sampling_rate
sampling_rate = np.random.choice(params["aug_sampling_rates"])
sampling_rate = np.random.choice(aug_sampling_rates)
audio = librosa.resample(
audio,
orig_sr=sampling_rate_old,
@ -447,11 +458,11 @@ def resample_aug(
audio = au.pad_audio(
audio,
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
params["resize_factor"],
params["spec_divide_factor"],
params["spec_train_width"],
fft_win_length,
fft_overlap,
resize_factor,
spec_divide_factor,
spec_train_width,
)
duration = audio.shape[0] / float(sampling_rate)
return audio, sampling_rate, duration
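The resampling step itself is plain librosa; a hedged sketch with an assumed list of candidate rates:

import librosa
import numpy as np

audio = np.random.randn(256000).astype(np.float32)
orig_sr = 256000
target_sr = int(np.random.choice([192000, 256000, 384000]))  # illustrative rates
audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)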
@ -459,28 +470,28 @@ def resample_aug(

def resample_audio(
num_samples: int,
sampling_rate: int,
sampling_rate: float,
audio2: np.ndarray,
sampling_rate2: int,
sampling_rate2: float,
) -> Tuple[np.ndarray, int]:
) -> Tuple[np.ndarray, float]:
"""Resample audio.

Parameters
----------
num_samples: int
Expected number of samples for the output audio.
sampling_rate: int
sampling_rate: float
Original sampling rate of the audio.
audio2: np.ndarray
Audio to resample.
sampling_rate2: int
sampling_rate2: float
Target sampling rate of the audio.

Returns
-------
audio2 : np.ndarray
Resampled audio.
sampling_rate2 : int
sampling_rate2 : float
New sampling rate.
"""
# resample to target sampling rate
@ -509,12 +520,12 @@ def resample_audio(

def combine_audio_aug(
audio: np.ndarray,
sampling_rate: int,
sampling_rate: float,
ann: AnnotationGroup,
ann: AudioLoaderAnnotationGroup,
audio2: np.ndarray,
sampling_rate2: int,
sampling_rate2: float,
ann2: AnnotationGroup,
ann2: AudioLoaderAnnotationGroup,
) -> Tuple[np.ndarray, AnnotationGroup]:
) -> Tuple[np.ndarray, AudioLoaderAnnotationGroup]:
"""Combine two audio files.

Will combine two audio files by resampling them to the same sampling rate
@ -570,7 +581,9 @@ def combine_audio_aug(
# from different individuals
if kk == "individual_ids":
if (ann[kk] > -1).sum() > 0:
ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1
ann2[kk][ann2[kk] > -1] += (
np.max(ann[kk][ann[kk] > -1]) + 1
)

if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
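When two clips are merged, individual IDs from the second clip are shifted past the maximum ID of the first so they stay distinct; a toy example of that offset:

import numpy as np

ids1 = np.array([0, 1, 1, -1])  # -1 marks an unknown individual
ids2 = np.array([0, 0, -1])
if (ids1 > -1).sum() > 0:
    ids2[ids2 > -1] += np.max(ids1[ids1 > -1]) + 1
merged = np.hstack((ids1, ids2))  # [0, 1, 1, -1, 2, 2, -1]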
@ -579,7 +592,8 @@ def combine_audio_aug(


def _prepare_annotation(
annotation: Annotation, class_names: List[str]
annotation: Annotation,
class_names: List[str],
) -> Annotation:
try:
class_id = class_names.index(annotation["class"])
@ -598,7 +612,7 @@ def _prepare_annotation(


def _prepare_file_annotation(
annotation: FileAnnotations,
annotation: FileAnnotation,
class_names: List[str],
classes_to_ignore: List[str],
) -> AudioLoaderAnnotationGroup:
@ -626,7 +640,9 @@ def _prepare_file_annotation(
"end_times": np.array([ann["end_time"] for ann in annotations]),
"high_freqs": np.array([ann["high_freq"] for ann in annotations]),
"low_freqs": np.array([ann["low_freq"] for ann in annotations]),
"class_ids": np.array([ann.get("class_id", -1) for ann in annotations]),
"class_ids": np.array(
[ann.get("class_id", -1) for ann in annotations]
),
"individual_ids": np.array([ann["individual"] for ann in annotations]),
"class_id_file": class_id_file,
}
@ -639,15 +655,15 @@ class AudioLoader(torch.utils.data.Dataset):

def __init__(
self,
data_anns_ip: List[FileAnnotations],
data_anns_ip: List[FileAnnotation],
params,
params: AudioLoaderParameters,
dataset_name: Optional[str] = None,
is_train: bool = False,
return_spec_for_viz: bool = False,
):
self.is_train: bool = is_train
self.is_train = is_train
self.params: dict = params
self.params = params
self.return_spec_for_viz: bool = False
self.return_spec_for_viz = return_spec_for_viz

self.data_anns: List[AudioLoaderAnnotationGroup] = [
_prepare_file_annotation(
ann,
@ -657,61 +673,6 @@ class AudioLoader(torch.utils.data.Dataset):
for ann in data_anns_ip
]

# for ii in range(len(data_anns_ip)):
# dd = copy.deepcopy(data_anns_ip[ii])
#
# # filter out unused annotation here
# filtered_annotations = []
# for ii, aa in enumerate(dd["annotation"]):
# if "individual" in aa.keys():
# aa["individual"] = int(aa["individual"])
#
# # if only one call labeled it has to be from the same
# # individual
# if len(dd["annotation"]) == 1:
# aa["individual"] = 0
#
# # convert class name into class label
# if aa["class"] in self.params["class_names"]:
# aa["class_id"] = self.params["class_names"].index(
# aa["class"]
# )
# else:
# aa["class_id"] = -1
#
# if aa["class"] not in self.params["classes_to_ignore"]:
# filtered_annotations.append(aa)
#
# dd["annotation"] = filtered_annotations
# dd["start_times"] = np.array(
# [aa["start_time"] for aa in dd["annotation"]]
# )
# dd["end_times"] = np.array(
# [aa["end_time"] for aa in dd["annotation"]]
# )
# dd["high_freqs"] = np.array(
# [float(aa["high_freq"]) for aa in dd["annotation"]]
# )
# dd["low_freqs"] = np.array(
# [float(aa["low_freq"]) for aa in dd["annotation"]]
# )
# dd["class_ids"] = np.array(
# [aa["class_id"] for aa in dd["annotation"]]
# ).astype(np.int32)
# dd["individual_ids"] = np.array(
# [aa["individual"] for aa in dd["annotation"]]
# ).astype(np.int32)
#
# # file level class name
# dd["class_id_file"] = -1
# if "class_name" in dd.keys():
# if dd["class_name"] in self.params["class_names"]:
# dd["class_id_file"] = self.params["class_names"].index(
# dd["class_name"]
# )
#
# self.data_anns.append(dd)

ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
self.max_num_anns = 2 * np.max(
ann_cnt
@ -730,7 +691,7 @@ class AudioLoader(torch.utils.data.Dataset):
def get_file_and_anns(
self,
index: Optional[int] = None,
) -> Tuple[np.ndarray, int, float, AudioLoaderAnnotationGroup]:
) -> Tuple[np.ndarray, float, float, AudioLoaderAnnotationGroup]:
"""Get an audio file and its annotations.

Parameters
@ -742,7 +703,7 @@ class AudioLoader(torch.utils.data.Dataset):
-------
audio_raw : np.ndarray
Loaded audio file.
sampling_rate : int
sampling_rate : float
Sampling rate of the audio file.
duration : float
Duration of the audio file in seconds.
@ -837,7 +798,7 @@ class AudioLoader(torch.utils.data.Dataset):
(
audio2,
sampling_rate2,
duration2,
_,
ann2,
) = self.get_file_and_anns()
audio, ann = combine_audio_aug(
@ -846,7 +807,11 @@ class AudioLoader(torch.utils.data.Dataset):

# simulate echo by adding delayed copy of the file
if np.random.random() < self.params["aug_prob"]:
audio = echo_aug(audio, sampling_rate, self.params)
audio = echo_aug(
audio,
sampling_rate,
echo_max_delay=self.params["echo_max_delay"],
)

# resample the audio
# if np.random.random() < self.params["aug_prob"]:
@ -855,11 +820,16 @@ class AudioLoader(torch.utils.data.Dataset):
# )

# create spectrogram
spec, spec_for_viz = au.generate_spectrogram(
spec = au.generate_spectrogram(
audio,
sampling_rate,
self.params,
self.return_spec_for_viz,
fft_win_length=self.params["fft_win_length"],
fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
min_freq=self.params["min_freq"],
spec_scale=self.params["spec_scale"],
denoise_spec_avg=self.params["denoise_spec_avg"],
max_scale_spec=self.params["max_scale_spec"],
)
rsf = self.params["resize_factor"]
spec_op_shape = (
@ -879,20 +849,29 @@ class AudioLoader(torch.utils.data.Dataset):
# augment spectrogram
if self.is_train and self.params["augment_at_train"]:
if np.random.random() < self.params["aug_prob"]:
spec = scale_vol_aug(spec, self.params)
spec = scale_vol_aug(
spec,
spec_amp_scaling=self.params["spec_amp_scaling"],
)

if np.random.random() < self.params["aug_prob"]:
spec = warp_spec_aug(
spec,
ann,
self.params,
stretch_squeeze_delta=self.params["stretch_squeeze_delta"],
)

if np.random.random() < self.params["aug_prob"]:
spec = mask_time_aug(spec, self.params)
spec = mask_time_aug(
spec,
mask_max_time_perc=self.params["mask_max_time_perc"],
)

if np.random.random() < self.params["aug_prob"]:
spec = mask_freq_aug(spec, self.params)
spec = mask_freq_aug(
spec,
mask_max_freq_perc=self.params["mask_max_freq_perc"],
)

outputs = {}
outputs["spec"] = spec
@ -911,7 +890,13 @@ class AudioLoader(torch.utils.data.Dataset):
spec_op_shape,
sampling_rate,
ann,
self.params,
class_names=self.params["class_names"],
fft_win_length=self.params["fft_win_length"],
fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
min_freq=self.params["min_freq"],
resize_factor=self.params["resize_factor"],
target_sigma=self.params["target_sigma"],
)

# hack to get around requirement that all vectors are the same length
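With the typed `params` and the new `return_spec_for_viz` flag, constructing the loader looks roughly like this (a sketch: `data_train` and `params` are assumed to come from the annotation-loading and parameter helpers shown elsewhere in this commit, and batching variable-size spectrograms may need a custom collate function):

import torch

train_dataset = AudioLoader(data_train, params, is_train=True)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,  # illustrative value
    shuffle=True,
)
batch = next(iter(train_loader))
print(batch["spec"].shape)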
@ -1,8 +1,13 @@
from typing import Optional

import torch
import torch.nn.functional as F


def bbox_size_loss(pred_size, gt_size):
def bbox_size_loss(
pred_size: torch.Tensor,
gt_size: torch.Tensor,
) -> torch.Tensor:
"""
Bounding box size loss. Only compute loss where there is a bounding box.
"""
@ -12,7 +17,12 @@ def bbox_size_loss(pred_size, gt_size):
)


def focal_loss(pred, gt, weights=None, valid_mask=None):
def focal_loss(
pred: torch.Tensor,
gt: torch.Tensor,
weights: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
pred (batch x c x h x w)
@ -52,7 +62,11 @@ def focal_loss(pred, gt, weights=None, valid_mask=None):
return loss


def mse_loss(pred, gt, weights=None, valid_mask=None):
def mse_loss(
pred: torch.Tensor,
gt: torch.Tensor,
valid_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Mean squared error loss.
"""
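For reference, the CornerNet-style focal loss named in the docstring down-weights easy negatives near keypoints via the Gaussian-decayed ground truth; a compact sketch of the standard formulation (alpha = 2, beta = 4 as in the paper; this is the textbook version, not necessarily this file's exact code):

import torch

def cornernet_focal_loss_sketch(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    # pred, gt: (batch, c, h, w); gt is 1.0 at keypoints, Gaussian elsewhere
    eps = 1e-6
    pos = gt.eq(1).float()
    neg = 1.0 - pos
    pos_loss = torch.log(pred + eps) * (1 - pred) ** 2 * pos
    neg_loss = torch.log(1 - pred + eps) * pred ** 2 * (1 - gt) ** 4 * neg
    num_pos = pos.sum().clamp(min=1.0)
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos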
@ -5,6 +5,7 @@ import warnings
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR

import batdetect2.detector.post_process as pp
@ -29,7 +30,7 @@ def save_images_batch(model, data_loader, params):

ind = 0 # first image in each batch
with torch.no_grad():
for batch_idx, inputs in enumerate(data_loader):
for inputs in data_loader:
data = inputs["spec"].to(params["device"])
outputs = model(data)

@ -81,7 +82,12 @@ def save_image(


def loss_fun(
outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
outputs,
gt_det,
gt_size,
gt_class,
det_criterion,
params,
):
# detection loss
loss = params["det_loss_weight"] * det_criterion(
@ -104,7 +110,13 @@ def loss_fun(


def train(
model, epoch, data_loader, det_criterion, optimizer, scheduler, params
model,
epoch,
data_loader,
det_criterion,
optimizer,
scheduler,
params,
):
model.train()

@ -309,7 +321,7 @@ def select_model(params):
resize_factor=params["resize_factor"],
)
else:
print("No valid network specified")
raise ValueError("No valid network specified")
return model


@ -319,9 +331,9 @@ def main():
params = parameters.get_params(True)

if torch.cuda.is_available():
params["device"] = "cuda"
params.device = "cuda"
else:
params["device"] = "cpu"
params.device = "cpu"

# setup arg parser and populate it with existing parameters - will not work with lists
parser = argparse.ArgumentParser()
@ -349,13 +361,16 @@ def main():
default="Rhinolophus ferrumequinum;Rhinolophus hipposideros",
help='Will set low and high frequency the same for these classes. Separate names with ";"',
)

for key, val in params.items():
parser.add_argument("--" + key, type=type(val), default=val)
params = vars(parser.parse_args())

# save notes file
if params["notes"] != "":
tu.write_notes_file(params["experiment"] + "notes.txt", params["notes"])
tu.write_notes_file(
params["experiment"] + "notes.txt", params["notes"]
)

# load the training and test meta data - there are different splits defined
train_sets, test_sets = ts.get_train_test_data(
@ -374,15 +389,11 @@ def main():
for tt in train_sets:
print(tt["ann_path"])
classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
(
data_train,
params["class_names"],
params["class_inv_freq"],
) = tu.load_set_of_anns(
data_train = tu.load_set_of_anns(
train_sets,
classes_to_ignore,
classes_to_ignore=classes_to_ignore,
params["events_of_interest"],
events_of_interest=params["events_of_interest"],
params["convert_to_genus"],
convert_to_genus=params["convert_to_genus"],
)
params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
params["class_names"]
@ -415,11 +426,12 @@ def main():
print("\nTesting on:")
for tt in test_sets:
print(tt["ann_path"])
data_test, _, _ = tu.load_set_of_anns(
data_test = tu.load_set_of_anns(
test_sets,
classes_to_ignore,
classes_to_ignore=classes_to_ignore,
params["events_of_interest"],
events_of_interest=params["events_of_interest"],
params["convert_to_genus"],
convert_to_genus=params["convert_to_genus"],
)
data_train = tu.remove_dupes(data_train, data_test)
test_dataset = adl.AudioLoader(data_test, params, is_train=False)
@ -447,10 +459,13 @@ def main():
scheduler = CosineAnnealingLR(
optimizer, params["num_epochs"] * len(train_loader)
)

if params["train_loss"] == "mse":
det_criterion = losses.mse_loss
elif params["train_loss"] == "focal":
det_criterion = losses.focal_loss
else:
raise ValueError("No valid loss specified")

# save parameters to file
with open(params["experiment"] + "params.json", "w") as da:
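The `main` function mirrors every `params` entry into an argparse flag; a standalone sketch of the pattern (hypothetical keys, and it indeed breaks for list-valued defaults, as the comment warns):

import argparse

params = {"lr": 0.001, "num_epochs": 200, "train_loss": "focal"}  # illustrative
parser = argparse.ArgumentParser()
for key, val in params.items():
    parser.add_argument("--" + key, type=type(val), default=val)
params = vars(parser.parse_args([]))  # pass [] to keep the defaults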
@ -1,28 +1,37 @@
import glob
import json
import os
import random
from collections import Counter
from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple

import numpy as np

from batdetect2 import types


def write_notes_file(file_name, text):
def write_notes_file(file_name: str, text: str):
with open(file_name, "a") as da:
da.write(text + "\n")


def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path):
ddict = {
def get_blank_dataset_dict(
dataset_name: str,
is_test: bool,
ann_path: str,
wav_path: str,
) -> types.DatasetDict:
return {
"dataset_name": dataset_name,
"is_test": is_test,
"is_binary": False,
"ann_path": ann_path,
"wav_path": wav_path,
}
return ddict


def get_short_class_names(class_names, str_len=3):
def get_short_class_names(
class_names: List[str],
str_len: int = 3,
) -> List[str]:
class_names_short = []
for cc in class_names:
class_names_short.append(
@ -31,7 +40,10 @@ def get_short_class_names(class_names, str_len=3):
return class_names_short


def remove_dupes(data_train, data_test):
def remove_dupes(
data_train: List[types.FileAnnotation],
data_test: List[types.FileAnnotation],
) -> List[types.FileAnnotation]:
test_ids = [dd["id"] for dd in data_test]
data_train_prune = []
for aa in data_train:
@ -43,14 +55,16 @@ def remove_dupes(data_train, data_test):
return data_train_prune


def get_genus_mapping(class_names):
def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:
genus_names, genus_mapping = np.unique(
[cc.split(" ")[0] for cc in class_names], return_inverse=True
)
return genus_names.tolist(), genus_mapping.tolist()


def standardize_low_freq(data, class_of_interest):
def standardize_low_freq(
data: List[types.FileAnnotation], class_of_interest: str,
) -> List[types.FileAnnotation]:
# address the issue of highly variable low frequency annotations
# this often happens for constant frequency calls
# for the class of interest sets the low and high freq to be the dataset mean
@ -62,8 +76,8 @@ def standardize_low_freq(data, class_of_interest):
low_freqs.append(aa["low_freq"])
high_freqs.append(aa["high_freq"])

low_mean = np.mean(low_freqs)
low_mean = float(np.mean(low_freqs))
high_mean = np.mean(high_freqs)
high_mean = float(np.mean(high_freqs))
assert low_mean < high_mean

print("\nStandardizing low and high frequency for:")
@ -83,115 +97,148 @@ def standardize_low_freq(data, class_of_interest):
return data
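`get_genus_mapping` relies on `np.unique(..., return_inverse=True)`, which returns the sorted unique genus names plus, for each input class, the index of its genus; a toy example:

import numpy as np

class_names = ["Myotis daubentonii", "Myotis nattereri", "Plecotus auritus"]
genus_names, genus_mapping = np.unique(
    [cc.split(" ")[0] for cc in class_names], return_inverse=True
)
# genus_names -> ['Myotis', 'Plecotus']; genus_mapping -> [0, 0, 1]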
def load_set_of_anns(
data,
classes_to_ignore=[],
events_of_interest=None,
convert_to_genus=False,
verbose=True,
list_of_anns=False,
filter_issues=False,
name_replace=False,
):
# load the annotations
anns = []
if list_of_anns:
# path to list of individual json files
anns.extend(load_anns_from_path(data["ann_path"], data["wav_path"]))
else:
# dictionary of datasets
for dd in data:
anns.extend(load_anns(dd["ann_path"], dd["wav_path"]))

# discarding unannotated files
anns = [aa for aa in anns if aa["annotated"] is True]

# filter files that have annotation issues - if the input is a dictionary of
# datasets, this will likely have already been done
if filter_issues:
anns = [aa for aa in anns if aa["issues"] is False]

# check for some basic formatting errors with class names
for ann in anns:
for aa in ann["annotation"]:
aa["class"] = aa["class"].strip()

# only load specified events - i.e. types of calls
if events_of_interest is not None:
for ann in anns:
filtered_events = []
for aa in ann["annotation"]:
if aa["event"] in events_of_interest:
filtered_events.append(aa)
ann["annotation"] = filtered_events

# change class names
# replace_names will be a dictionary mapping input name to output
if type(name_replace) is dict:
for ann in anns:
for aa in ann["annotation"]:
if aa["class"] in name_replace:
aa["class"] = name_replace[aa["class"]]

# convert everything to genus name
if convert_to_genus:
for ann in anns:
for aa in ann["annotation"]:
aa["class"] = aa["class"].split(" ")[0]

# get unique class names
class_names_all = []
for ann in anns:
for aa in ann["annotation"]:
if aa["class"] not in classes_to_ignore:
class_names_all.append(aa["class"])

class_names, class_cnts = np.unique(class_names_all, return_counts=True)
class_inv_freq = class_cnts.sum() / (
len(class_names) * class_cnts.astype(np.float32)
)

if verbose:
print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for cc in range(len(class_names)):
print(
str(cc).ljust(5)
+ class_names[cc].ljust(str_len)
+ str(class_cnts[cc])
)

if len(classes_to_ignore) == 0:
return anns
else:
return anns, class_names.tolist(), class_inv_freq.tolist()


def load_anns(ann_file_name, raw_audio_dir):
with open(ann_file_name) as da:
anns = json.load(da)

for aa in anns:
aa["file_path"] = raw_audio_dir + aa["id"]

return anns


def load_anns_from_path(ann_file_dir, raw_audio_dir):
files = glob.glob(ann_file_dir + "*.json")
anns = []
for ff in files:
with open(ff) as da:
ann = json.load(da)
ann["file_path"] = raw_audio_dir + ann["id"]
anns.append(ann)

return anns


def format_annotation(
annotation: types.FileAnnotation,
events_of_interest: Optional[List[str]] = None,
name_replace: Optional[Dict[str, str]] = None,
convert_to_genus: bool = False,
classes_to_ignore: Optional[List[str]] = None,
) -> types.FileAnnotation:
formated = []
for aa in annotation["annotation"]:
if (
events_of_interest is not None
and aa["event"] not in events_of_interest
):
# Omit events that are not of interest
continue

# remove leading and trailing spaces
class_name = aa["class"].strip()

if name_replace is not None:
# name_replace is a dictionary mapping input name to output
class_name = name_replace.get(class_name, class_name)

if convert_to_genus:
# convert everything to genus name
class_name = class_name.split(" ")[0]

# NOTE: It is important to acknowledge that the class names filtering
# is done after the name replacement and the conversion to
# genus name. This allows filtering converted genus names and names
# that were replaced with a name that should be ignored.
if classes_to_ignore is not None and class_name in classes_to_ignore:
# Omit annotations with ignored classes
continue

formated.append(
{
**aa,
"class": class_name,
}
)

return {
**annotation,
"annotation": formated,
}


def get_class_names(
data: List[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None,
) -> Tuple[Counter[str], List[float]]:
"""Extract class names and their inverse frequencies.

Parameters
----------
data
A list of file annotations, where each annotation contains a list of
sound events with associated class names.
classes_to_ignore
A list of class names to ignore.

Returns
-------
class_names
A counter with the number of occurrences of each class name.
class_inv_freq
List of inverse frequencies of each class name in the provided data.
"""
if classes_to_ignore is None:
classes_to_ignore = []

class_names_list: List[str] = []
for annotation in data:
for sound_event in annotation["annotation"]:
if sound_event["class"] in classes_to_ignore:
continue

class_names_list.append(sound_event["class"])

counts = Counter(class_names_list)
mean_counts = float(np.mean(list(counts.values())))
return counts, [mean_counts / counts[cc] for cc in class_names_list]


def report_class_counts(class_names: Counter[str]):
print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for index, (class_name, count) in enumerate(class_names.most_common()):
print(f"{index:<5}{class_name:<{str_len}}{count}")


def load_set_of_anns(
data: List[types.DatasetDict],
*,
convert_to_genus: bool = False,
filter_issues: bool = False,
events_of_interest: Optional[List[str]] = None,
classes_to_ignore: Optional[List[str]] = None,
name_replace: Optional[Dict[str, str]] = None,
) -> List[types.FileAnnotation]:
# load the annotations
anns = []

# dictionary of datasets
for dataset in data:
for ann in load_anns(dataset["ann_path"], dataset["wav_path"]):
if not ann["annotated"]:
# Omit unannotated files
continue

if filter_issues and ann["issues"]:
# Omit files with annotation issues
continue

anns.append(
format_annotation(
ann,
events_of_interest=events_of_interest,
name_replace=name_replace,
convert_to_genus=convert_to_genus,
classes_to_ignore=classes_to_ignore,
)
)

return anns


def load_anns(
ann_dir: str,
raw_audio_dir: str,
) -> Generator[types.FileAnnotation, None, None]:
for path in Path(ann_dir).rglob("*.json"):
with open(path) as fp:
file_annotation = json.load(fp)

file_annotation["file_path"] = raw_audio_dir + file_annotation["id"]
yield file_annotation
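The new `get_class_names` returns a `Counter` together with one inverse-frequency weight per annotation, which can serve as class weights during training; a toy run:

from collections import Counter

import numpy as np

names = ["Myotis", "Myotis", "Myotis", "Plecotus"]  # illustrative class list
counts = Counter(names)  # Counter({'Myotis': 3, 'Plecotus': 1})
mean_counts = float(np.mean(list(counts.values())))  # 2.0
inv_freq = [mean_counts / counts[cc] for cc in names]
# [0.666..., 0.666..., 0.666..., 2.0] - rarer classes get larger weights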
class AverageMeter(object):
|
class AverageMeter:
|
||||||
"""Computes and stores the average and current value"""
|
"""Computes and stores the average and current value."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Types used in the code base."""
|
"""Types used in the code base."""
|
||||||
from typing import List, NamedTuple, Optional, Union
|
from typing import Any, List, NamedTuple, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -26,8 +26,7 @@ __all__ = [
|
|||||||
"Annotation",
|
"Annotation",
|
||||||
"DetectionModel",
|
"DetectionModel",
|
||||||
"FeatureExtractionParameters",
|
"FeatureExtractionParameters",
|
||||||
"FeatureExtractor",
|
"FileAnnotation",
|
||||||
"FileAnnotations",
|
|
||||||
"ModelOutput",
|
"ModelOutput",
|
||||||
"ModelParameters",
|
"ModelParameters",
|
||||||
"NonMaximumSuppressionConfig",
|
"NonMaximumSuppressionConfig",
|
||||||
@ -94,7 +93,10 @@ class ModelParameters(TypedDict):
|
|||||||
"""Resize factor."""
|
"""Resize factor."""
|
||||||
|
|
||||||
class_names: List[str]
|
class_names: List[str]
|
||||||
"""Class names. The model is trained to detect these classes."""
|
"""Class names.
|
||||||
|
|
||||||
|
The model is trained to detect these classes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
DictWithClass = TypedDict("DictWithClass", {"class": str})
|
DictWithClass = TypedDict("DictWithClass", {"class": str})
|
||||||
@ -103,8 +105,8 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})
|
|||||||
class Annotation(DictWithClass):
|
class Annotation(DictWithClass):
|
||||||
"""Format of annotations.
|
"""Format of annotations.
|
||||||
|
|
||||||
This is the format of a single annotation as expected by the annotation
|
This is the format of a single annotation as expected by the
|
||||||
tool.
|
annotation tool.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start_time: float
|
start_time: float
|
||||||
@ -113,10 +115,10 @@ class Annotation(DictWithClass):
|
|||||||
end_time: float
|
end_time: float
|
||||||
"""End time in seconds."""
|
"""End time in seconds."""
|
||||||
|
|
||||||
low_freq: int
|
low_freq: float
|
||||||
"""Low frequency in Hz."""
|
"""Low frequency in Hz."""
|
||||||
|
|
||||||
high_freq: int
|
high_freq: float
|
||||||
"""High frequency in Hz."""
|
"""High frequency in Hz."""
|
||||||
|
|
||||||
class_prob: float
|
class_prob: float
|
||||||
@ -135,7 +137,7 @@ class Annotation(DictWithClass):
|
|||||||
"""Numeric ID for the class of the annotation."""
|
"""Numeric ID for the class of the annotation."""
|
||||||
|
|
||||||
|
|
||||||
class FileAnnotations(TypedDict):
|
class FileAnnotation(TypedDict):
|
||||||
"""Format of results.
|
"""Format of results.
|
||||||
|
|
||||||
This is the format of the results expected by the annotation tool.
|
This is the format of the results expected by the annotation tool.
|
||||||
@ -157,7 +159,7 @@ class FileAnnotations(TypedDict):
|
|||||||
"""Time expansion factor."""
|
"""Time expansion factor."""
|
||||||
|
|
||||||
class_name: str
|
class_name: str
|
||||||
"""Class predicted at file level"""
|
"""Class predicted at file level."""
|
||||||
|
|
||||||
notes: str
|
notes: str
|
||||||
"""Notes of file."""
|
"""Notes of file."""
|
||||||
@ -169,7 +171,7 @@ class FileAnnotations(TypedDict):
|
|||||||
class RunResults(TypedDict):
|
class RunResults(TypedDict):
|
||||||
"""Run results."""
|
"""Run results."""
|
||||||
|
|
||||||
pred_dict: FileAnnotations
|
pred_dict: FileAnnotation
|
||||||
"""Predictions in the format expected by the annotation tool."""
|
"""Predictions in the format expected by the annotation tool."""
|
||||||
|
|
||||||
spec_feats: NotRequired[List[np.ndarray]]
|
spec_feats: NotRequired[List[np.ndarray]]
|
||||||
@ -394,9 +396,9 @@ class PredictionResults(TypedDict):
|
|||||||
class DetectionModel(Protocol):
|
class DetectionModel(Protocol):
|
||||||
"""Protocol for detection models.
|
"""Protocol for detection models.
|
||||||
|
|
||||||
This protocol is used to define the interface for the detection models.
|
This protocol is used to define the interface for the detection
|
||||||
This allows us to use the same code for training and inference, even
|
models. This allows us to use the same code for training and
|
||||||
though the models are different.
|
inference, even though the models are different.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
num_classes: int
|
num_classes: int
|
||||||
@ -416,16 +418,14 @@ class DetectionModel(Protocol):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ip: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
return_feats: bool = False,
|
|
||||||
) -> ModelOutput:
|
) -> ModelOutput:
|
||||||
"""Forward pass of the model."""
|
"""Forward pass of the model."""
|
||||||
...
|
...
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
ip: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
return_feats: bool = False,
|
|
||||||
) -> ModelOutput:
|
) -> ModelOutput:
|
||||||
"""Forward pass of the model."""
|
"""Forward pass of the model."""
|
||||||
...
|
...
|
||||||
@ -490,8 +490,10 @@ class HeatmapParameters(TypedDict):
|
|||||||
"""Maximum frequency to consider in Hz."""
|
"""Maximum frequency to consider in Hz."""
|
||||||
|
|
||||||
target_sigma: float
|
target_sigma: float
|
||||||
"""Sigma for the Gaussian kernel. Controls the width of the points in
|
"""Sigma for the Gaussian kernel.
|
||||||
the heatmap."""
|
|
||||||
|
Controls the width of the points in the heatmap.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class AnnotationGroup(TypedDict):
|
class AnnotationGroup(TypedDict):
|
||||||
@ -522,10 +524,10 @@ class AnnotationGroup(TypedDict):
|
|||||||
annotated: NotRequired[bool]
|
annotated: NotRequired[bool]
|
||||||
"""Wether the annotation group is complete or not.
|
"""Wether the annotation group is complete or not.
|
||||||
|
|
||||||
Usually annotation groups are associated to a single
|
Usually annotation groups are associated to a single audio clip. If
|
||||||
audio clip. If the annotation group is complete, it means that all
|
the annotation group is complete, it means that all relevant sound
|
||||||
relevant sound events have been annotated. If it is not complete, it
|
events have been annotated. If it is not complete, it means that
|
||||||
means that some sound events might not have been annotated.
|
some sound events might not have been annotated.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
x_inds: NotRequired[np.ndarray]
|
x_inds: NotRequired[np.ndarray]
|
||||||
@ -535,12 +537,88 @@ class AnnotationGroup(TypedDict):
|
|||||||
"""Y coordinate of the annotations in the spectrogram."""
|
"""Y coordinate of the annotations in the spectrogram."""
|
||||||
|
|
||||||
|
|
||||||
class AudioLoaderAnnotationGroup(AnnotationGroup, FileAnnotations):
|
class AudioLoaderAnnotationGroup(TypedDict):
|
||||||
"""Group of annotation items for the training audio loader.
|
"""Group of annotation items for the training audio loader.
|
||||||
|
|
||||||
This class is used to store the annotations for the training audio
|
This class is used to store the annotations for the training audio
|
||||||
loader. It inherits from `AnnotationGroup` and `FileAnnotations`.
|
loader. It inherits from `AnnotationGroup` and `FileAnnotations`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
duration: float
|
||||||
|
issues: bool
|
||||||
|
file_path: str
|
||||||
|
time_exp: float
|
||||||
|
class_name: str
|
||||||
|
notes: str
|
||||||
|
start_times: np.ndarray
|
||||||
|
end_times: np.ndarray
|
||||||
|
low_freqs: np.ndarray
|
||||||
|
high_freqs: np.ndarray
|
||||||
|
class_ids: np.ndarray
|
||||||
|
individual_ids: np.ndarray
|
||||||
|
x_inds: np.ndarray
|
||||||
|
y_inds: np.ndarray
|
||||||
|
annotation: List[Annotation]
|
||||||
|
annotated: bool
|
||||||
class_id_file: int
|
class_id_file: int
|
||||||
"""ID of the class of the file."""
|
"""ID of the class of the file."""
|
||||||
|
|
||||||
|
|
||||||
class AudioLoaderParameters(TypedDict):
    class_names: List[str]
    classes_to_ignore: List[str]
    target_samp_rate: int
    scale_raw_audio: bool
    fft_win_length: float
    fft_overlap: float
    spec_train_width: int
    resize_factor: float
    spec_divide_factor: int
    augment_at_train: bool
    augment_at_train_combine: bool
    aug_prob: float
    spec_height: int
    echo_max_delay: float
    spec_amp_scaling: float
    stretch_squeeze_delta: float
    mask_max_time_perc: float
    mask_max_freq_perc: float
    max_freq: float
    min_freq: float
    spec_scale: str
    denoise_spec_avg: bool
    max_scale_spec: bool
    target_sigma: float


class FeatureExtractor(Protocol):
    def __call__(
        self,
        prediction: Prediction,
        **kwargs: Any,
    ) -> float:
        ...


class DatasetDict(TypedDict):
    """Dataset dictionary.

    This is the format of the dictionary that contains the dataset
    information.
    """

    dataset_name: str
    """Name of the dataset."""

    is_test: bool
    """Whether the dataset is a test set."""

    is_binary: bool
    """Whether the dataset is binary."""

    ann_path: str
    """Path to the annotations."""

    wav_path: str
    """Path to the audio files."""
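For orientation, a DatasetDict literal might look like the following sketch; the name and paths are invented for illustration, assuming the type is importable from batdetect2.types:

from batdetect2.types import DatasetDict

dataset: DatasetDict = {
    "dataset_name": "example_bats",        # hypothetical name
    "is_test": False,
    "is_binary": False,
    "ann_path": "anns/example_bats.json",  # hypothetical path
    "wav_path": "audio/example_bats/",     # hypothetical path
}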
@ -1,13 +1,11 @@
import warnings
import warnings
from typing import Optional, Tuple
from typing import Optional, Tuple, Union, overload

import librosa
import librosa
import librosa.core.spectrum
import librosa.core.spectrum
import numpy as np
import numpy as np
import torch
import torch

from . import wavfile

__all__ = [
__all__ = [
    "load_audio",
    "load_audio",
    "generate_spectrogram",
    "generate_spectrogram",
@ -15,113 +13,171 @@ __all__ = [
]
]


def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
@overload
    nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
def time_to_x_coords(
    time_in_file: np.ndarray,
    sampling_rate: float,
    fft_win_length: float,
    fft_overlap: float,
) -> np.ndarray:
    ...


@overload
def time_to_x_coords(
    time_in_file: float,
    sampling_rate: float,
    fft_win_length: float,
    fft_overlap: float,
) -> float:
    ...


def time_to_x_coords(
    time_in_file: Union[float, np.ndarray],
    sampling_rate: float,
    fft_win_length: float,
    fft_overlap: float,
) -> Union[float, np.ndarray]:
    nfft = np.floor(fft_win_length * sampling_rate)
    noverlap = np.floor(fft_overlap * nfft)
    noverlap = np.floor(fft_overlap * nfft)
    return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
    return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)


# NOTE this is also defined in post_process
# NOTE this is also defined in post_process
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
def x_coords_to_time(
    x_pos: float,
    sampling_rate: int,
    fft_win_length: float,
    fft_overlap: float,
) -> float:
    nfft = np.floor(fft_win_length * sampling_rate)
    nfft = np.floor(fft_win_length * sampling_rate)
    noverlap = np.floor(fft_overlap * nfft)
    noverlap = np.floor(fft_overlap * nfft)
    return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
    return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
    # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5)  # 0.5 is for center of temporal window
    # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5)  # 0.5 is for
    # center of temporal window
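A quick numeric check of this pair of inverse conversions, using illustrative values rather than project defaults: with sampling_rate=256000, fft_win_length=0.02 and fft_overlap=0.75, nfft = 5120, noverlap = 3840, and the hop is 1280 samples.

x = time_to_x_coords(0.5, 256000, 0.02, 0.75)
# (0.5 * 256000 - 3840) / 1280 == 97.0
t = x_coords_to_time(97.0, 256000, 0.02, 0.75)
# (97.0 * 1280 + 3840) / 256000 == 0.5, recovering the input exactly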


def generate_spectrogram(
def generate_spectrogram(
    audio,
    audio: np.ndarray,
    sampling_rate,
    sampling_rate: float,
    params,
    fft_win_length: float,
    return_spec_for_viz=False,
    fft_overlap: float,
    check_spec_size=True,
    max_freq: float,
):
    min_freq: float,
    spec_scale: str,
    denoise_spec_avg: bool = False,
    max_scale_spec: bool = False,
) -> np.ndarray:
    # generate spectrogram
    # generate spectrogram
    spec = gen_mag_spectrogram(
    spec = gen_mag_spectrogram(
        audio,
        audio,
        sampling_rate,
        sampling_rate,
        params["fft_win_length"],
        window_len=fft_win_length,
        params["fft_overlap"],
        overlap_perc=fft_overlap,
    )
    spec = crop_spectrogram(
        spec,
        fft_win_length=fft_win_length,
        max_freq=max_freq,
        min_freq=min_freq,
    )
    spec = scale_spectrogram(
        spec,
        sampling_rate,
        spec_scale=spec_scale,
        fft_win_length=fft_win_length,
    )
    )

    if denoise_spec_avg:
        spec = denoise_spectrogram(spec)

    if max_scale_spec:
        spec = max_scale_spectrogram(spec)

    return spec


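A sketch of calling the refactored function; the parameter values below are placeholders for illustration, not defaults taken from the repository:

import numpy as np

audio = np.random.randn(256000).astype(np.float32)  # 1 s of synthetic audio
spec = generate_spectrogram(
    audio,
    sampling_rate=256000,
    fft_win_length=0.02,
    fft_overlap=0.75,
    max_freq=120000,
    min_freq=10000,
    spec_scale="log",
    denoise_spec_avg=True,
)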
def crop_spectrogram(
    spec: np.ndarray,
    fft_win_length: float,
    max_freq: float,
    min_freq: float,
) -> np.ndarray:
    # crop to min/max freq
    # crop to min/max freq
    max_freq = round(params["max_freq"] * params["fft_win_length"])
    max_freq = round(max_freq * fft_win_length)
    min_freq = round(params["min_freq"] * params["fft_win_length"])
    min_freq = round(min_freq * fft_win_length)
    if spec.shape[0] < max_freq:
    if spec.shape[0] < max_freq:
        freq_pad = max_freq - spec.shape[0]
        freq_pad = max_freq - spec.shape[0]
        spec = np.vstack(
        spec = np.vstack(
            (np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
            (np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
        )
        )
    spec_cropped = spec[-max_freq : spec.shape[0] - min_freq, :]
    return spec[-max_freq : spec.shape[0] - min_freq, :]

    if params["spec_scale"] == "log":
        log_scaling = (
def denoise_spectrogram(spec: np.ndarray) -> np.ndarray:
            2.0
    spec = spec - np.mean(spec, 1)[:, np.newaxis]
            * (1.0 / sampling_rate)
    return spec.clip(min=0)
            * (
                1.0
                / (
def max_scale_spectrogram(spec: np.ndarray) -> np.ndarray:
                    np.abs(
    return spec / (spec.max() + 10e-6)
                        np.hanning(
                            int(params["fft_win_length"] * sampling_rate)
                        )
def log_scale(
                    )
    spec: np.ndarray,
                    ** 2
    sampling_rate: float,
                ).sum()
    fft_win_length: float,
            )
) -> np.ndarray:
    log_scaling = (
        2.0
        * (1.0 / sampling_rate)
        * (
            1.0
            / (
                np.abs(np.hanning(int(fft_win_length * sampling_rate))) ** 2
            ).sum()
        )
        )
        # log_scaling = (1.0 / sampling_rate)*0.1
    )
        # log_scaling = (1.0 / sampling_rate)*10e4
    return np.log1p(log_scaling * spec)
        spec = np.log1p(log_scaling * spec_cropped)

    elif params["spec_scale"] == "pcen":
        spec = pcen(spec_cropped, sampling_rate)

    elif params["spec_scale"] == "none":
        pass

    if params["denoise_spec_avg"]:
def scale_spectrogram(
        spec = spec - np.mean(spec, 1)[:, np.newaxis]
    spec: np.ndarray,
        spec.clip(min=0, out=spec)
    sampling_rate: float,
    spec_scale: str,
    fft_win_length: float,
) -> np.ndarray:
    if spec_scale == "log":
        return log_scale(spec, sampling_rate, fft_win_length)

    if params["max_scale_spec"]:
    if spec_scale == "pcen":
        spec = spec / (spec.max() + 10e-6)
        return pcen(spec, sampling_rate)

    # needs to be divisible by specific factor - if not it should have been padded
    return spec
    # if check_spec_size:
    #     assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
    #     assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)


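The log_scaling constant in log_scale is 2 / sampling_rate divided by the Hann window's energy, which normalises the magnitudes before the log1p. A standalone check of the constant, with an assumed 5120-sample window (e.g. fft_win_length=0.02 at 256 kHz):

import numpy as np

nfft = 5120  # illustrative window length in samples
window_energy = (np.abs(np.hanning(nfft)) ** 2).sum()  # ~0.375 * nfft
log_scaling = 2.0 * (1.0 / 256000) * (1.0 / window_energy)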
def prepare_spec_for_viz(
    spec: np.ndarray,
    sampling_rate: int,
    fft_win_length: float,
) -> np.ndarray:
    # for visualization purposes - use log scaled spectrogram
    # for visualization purposes - use log scaled spectrogram
    if return_spec_for_viz:
    return log_scale(
        log_scaling = (
        spec,
            2.0
        sampling_rate,
            * (1.0 / sampling_rate)
        fft_win_length=fft_win_length,
            * (
    ).astype(np.float32)
                1.0
                / (
                    np.abs(
                        np.hanning(
                            int(params["fft_win_length"] * sampling_rate)
                        )
                    )
                    ** 2
                ).sum()
            )
        )
        spec_for_viz = np.log1p(log_scaling * spec_cropped).astype(np.float32)
    else:
        spec_for_viz = None

    return spec, spec_for_viz


def load_audio(
def load_audio(
    audio_file: str,
    audio_file: str,
    time_exp_fact: float,
    time_exp_fact: float,
    target_samp_rate: int,
    target_sampling_rate: int,
    scale: bool = False,
    scale: bool = False,
    max_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray]:
) -> Tuple[float, np.ndarray]:
    """Load an audio file and resample it to the target sampling rate.
    """Load an audio file and resample it to the target sampling rate.

    The audio is also scaled to [-1, 1] and clipped to the maximum duration.
    The audio is also scaled to [-1, 1] and clipped to the maximum duration.
@ -152,63 +208,82 @@ def load_audio(

    """
    """
    with warnings.catch_warnings():
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
        audio, sampling_rate = librosa.load(
        # sampling_rate, audio_raw = wavfile.read(audio_file)
        audio_raw, sampling_rate = librosa.load(
            audio_file,
            audio_file,
            sr=None,
            sr=None,
            dtype=np.float32,
            dtype=np.float32,
        )
        )

    if len(audio_raw.shape) > 1:
    if len(audio.shape) > 1:
        raise ValueError("Currently does not handle stereo files")
        raise ValueError("Currently does not handle stereo files")

    sampling_rate = sampling_rate * time_exp_fact
    sampling_rate = sampling_rate * time_exp_fact

    # resample - need to do this after correcting for time expansion
    # resample - need to do this after correcting for time expansion
    sampling_rate_old = sampling_rate
    audio = resample_audio(audio, sampling_rate, target_sampling_rate)
    sampling_rate = target_samp_rate
    if sampling_rate_old != sampling_rate:
        audio_raw = librosa.resample(
            audio_raw,
            orig_sr=sampling_rate_old,
            target_sr=sampling_rate,
            res_type="polyphase",
        )

    # clipping maximum duration
    if max_duration is not None:
    if max_duration is not None:
        max_duration = int(
        audio = clip_audio(audio, target_sampling_rate, max_duration)
            np.minimum(
                int(sampling_rate * max_duration),
                audio_raw.shape[0],
            )
        )
        audio_raw = audio_raw[:max_duration]

    # scale to [-1, 1]
    # scale to [-1, 1]
    if scale:
    if scale:
        audio_raw = audio_raw - audio_raw.mean()
        audio = scale_audio(audio)
        audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)

    return sampling_rate, audio_raw
    return target_sampling_rate, audio

def resample_audio(
    audio: np.ndarray,
    sr_orig: float,
    sr_target: float,
) -> np.ndarray:
    if sr_orig != sr_target:
        return librosa.resample(
            audio,
            orig_sr=sr_orig,
            target_sr=sr_target,
            res_type="polyphase",
        )

    return audio


def clip_audio(
    audio: np.ndarray,
    sampling_rate: float,
    max_duration: float,
) -> np.ndarray:
    max_duration = int(
        np.minimum(
            int(sampling_rate * max_duration),
            audio.shape[0],
        )
    )
    return audio[:max_duration]


def scale_audio(
    audio: np.ndarray,
    eps: float = 10e-6,
) -> np.ndarray:
    return (audio - audio.mean()) / (np.abs(audio).max() + eps)


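A sketch of the updated call site; the file path is hypothetical, and a time-expansion factor of 10 is the conventional value for time-expanded bat recordings:

sampling_rate, audio = load_audio(
    "night_recording.wav",  # hypothetical path
    time_exp_fact=10.0,
    target_sampling_rate=256000,
    scale=True,
    max_duration=2.0,
)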
def pad_audio(
def pad_audio(
    audio_raw,
    audio_raw: np.ndarray,
    fs,
    sampling_rate: float,
    ms,
    window_len: float,
    overlap_perc,
    overlap_perc: float,
    resize_factor,
    resize_factor: float,
    divide_factor,
    divide_factor: float,
    fixed_width=None,
    fixed_width: Optional[int] = None,
):
) -> np.ndarray:
    # Adds zeros to the end of the raw data so that the generated spectrogram
    # Adds zeros to the end of the raw data so that the generated spectrogram
    # will be evenly divisible by `divide_factor`
    # will be evenly divisible by `divide_factor`
    # Also deals with very short audio clips and fixed_width during training
    # Also deals with very short audio clips and fixed_width during training

    # This code could be clearer, clean up
    # This code could be clearer, clean up
    nfft = int(ms * fs)
    nfft = int(window_len * sampling_rate)
    noverlap = int(overlap_perc * nfft)
    noverlap = int(overlap_perc * nfft)
    step = nfft - noverlap
    step = nfft - noverlap
    min_size = int(divide_factor * (1.0 / resize_factor))
    min_size = int(divide_factor * (1.0 / resize_factor))
@ -245,19 +320,24 @@ def pad_audio(
    return audio_raw
    return audio_raw


def gen_mag_spectrogram(x, fs, ms, overlap_perc):
def gen_mag_spectrogram(
    audio: np.ndarray,
    sampling_rate: float,
    window_len: float,
    overlap_perc: float,
) -> np.ndarray:
    # Computes magnitude spectrogram by specifying time.
    # Computes magnitude spectrogram by specifying time.
    audio = audio.astype(np.float32)
    x = x.astype(np.float32)
    nfft = int(window_len * sampling_rate)
    nfft = int(ms * fs)
    noverlap = int(overlap_perc * nfft)
    noverlap = int(overlap_perc * nfft)

    # window data
    step = nfft - noverlap

    # compute spec
    # compute spec
    spec, _ = librosa.core.spectrum._spectrogram(
    spec, _ = librosa.core.spectrum._spectrogram(
        y=x, power=1, n_fft=nfft, hop_length=step, center=False
        y=audio,
        power=1,
        n_fft=nfft,
        hop_length=nfft - noverlap,
        center=False,
    )
    )

    # remove DC component and flip vertical orientation
    # remove DC component and flip vertical orientation
@ -266,24 +346,25 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):
    return spec.astype(np.float32)
    return spec.astype(np.float32)


def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc):
def gen_mag_spectrogram_pt(
    nfft = int(ms * fs)
    audio: torch.Tensor,
    sampling_rate: float,
    window_len: float,
    overlap_perc: float,
) -> torch.Tensor:
    nfft = int(window_len * sampling_rate)
    nstep = round((1.0 - overlap_perc) * nfft)
    nstep = round((1.0 - overlap_perc) * nfft)
    han_win = torch.hann_window(nfft, periodic=False).to(audio.device)

    han_win = torch.hann_window(nfft, periodic=False).to(x.device)
    complex_spec = torch.stft(audio, nfft, nstep, window=han_win, center=False)

    complex_spec = torch.stft(x, nfft, nstep, window=han_win, center=False)
    spec = complex_spec.pow(2.0).sum(-1)
    spec = complex_spec.pow(2.0).sum(-1)

    # remove DC component and flip vertically
    # remove DC component and flip vertically
    spec = torch.flipud(spec[0, 1:, :])
    return torch.flipud(spec[0, 1:, :])

    return spec


def pcen(spec_cropped, sampling_rate):
def pcen(spec: np.ndarray, sampling_rate: float) -> np.ndarray:
    # TODO should be passing hop_length too i.e. step
    # TODO should be passing hop_length too i.e. step
    spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate / 10).astype(
    return librosa.pcen(spec * (2**31), sr=sampling_rate / 10).astype(
        np.float32
        np.float32
    )
    )
    return spec
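For orientation, the expected output shape of gen_mag_spectrogram under assumed parameters (a 0.002 s window at 256 kHz with 75% overlap, so nfft = 512 and hop = 128 samples): librosa returns 1 + nfft/2 frequency rows, and dropping the DC row leaves nfft/2.

import numpy as np

audio = np.zeros(25600, dtype=np.float32)  # 0.1 s at 256 kHz
spec = gen_mag_spectrogram(audio, 256000, 0.002, 0.75)
assert spec.shape == (256, (25600 - 512) // 128 + 1)  # (256, 197)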
@ -16,7 +16,7 @@ from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import (
from batdetect2.types import (
    Annotation,
    Annotation,
    DetectionModel,
    DetectionModel,
    FileAnnotations,
    FileAnnotation,
    ModelOutput,
    ModelOutput,
    ModelParameters,
    ModelParameters,
    PredictionResults,
    PredictionResults,
@ -79,7 +79,7 @@ def list_audio_files(ip_dir: str) -> List[str]:
def load_model(
def load_model(
    model_path: str = DEFAULT_MODEL_PATH,
    model_path: str = DEFAULT_MODEL_PATH,
    load_weights: bool = True,
    load_weights: bool = True,
    device: Optional[torch.device] = None,
    device: Union[torch.device, str, None] = None,
) -> Tuple[DetectionModel, ModelParameters]:
) -> Tuple[DetectionModel, ModelParameters]:
    """Load model from file.
    """Load model from file.

@ -222,7 +222,7 @@ def format_single_result(
    duration: float,
    duration: float,
    predictions: PredictionResults,
    predictions: PredictionResults,
    class_names: List[str],
    class_names: List[str],
) -> FileAnnotations:
) -> FileAnnotation:
    """Format results into the format expected by the annotation tool.
    """Format results into the format expected by the annotation tool.

    Args:
    Args:
@ -399,11 +399,10 @@ def save_results_to_file(results, op_path: str) -> None:

def compute_spectrogram(
def compute_spectrogram(
    audio: np.ndarray,
    audio: np.ndarray,
    sampling_rate: int,
    sampling_rate: float,
    params: SpectrogramParameters,
    params: SpectrogramParameters,
    device: torch.device,
    device: torch.device,
    return_np: bool = False,
) -> Tuple[float, torch.Tensor]:
) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
    """Compute a spectrogram from an audio array.
    """Compute a spectrogram from an audio array.

    Will pad the audio array so that it is evenly divisible by the
    Will pad the audio array so that it is evenly divisible by the
@ -412,24 +411,16 @@ def compute_spectrogram(
    Parameters
    Parameters
    ----------
    ----------
    audio : np.ndarray
    audio : np.ndarray

    sampling_rate : int
    sampling_rate : int

    params : SpectrogramParameters
    params : SpectrogramParameters
        The parameters to use for generating the spectrogram.
        The parameters to use for generating the spectrogram.

    return_np : bool, optional
        Whether to return the spectrogram as a numpy array as well as a
        torch tensor. The default is False.

    Returns
    Returns
    -------
    -------
    duration : float
    duration : float
        The duration of the spectrogram in seconds.
        The duration of the spectrogram in seconds.

    spec : torch.Tensor
    spec : torch.Tensor
        The spectrogram as a torch tensor.
        The spectrogram as a torch tensor.

    spec_np : np.ndarray, optional
    spec_np : np.ndarray, optional
        The spectrogram as a numpy array. Only returned if `return_np` is
        The spectrogram as a numpy array. Only returned if `return_np` is
        True, otherwise None.
        True, otherwise None.
@ -446,7 +437,7 @@ def compute_spectrogram(
    )
    )

    # generate spectrogram
    # generate spectrogram
    spec, _ = au.generate_spectrogram(audio, sampling_rate, params)
    spec = au.generate_spectrogram(audio, sampling_rate, params)

    # convert to pytorch
    # convert to pytorch
    spec = torch.from_numpy(spec).to(device)
    spec = torch.from_numpy(spec).to(device)
@ -466,18 +457,12 @@ def compute_spectrogram(
mode="bilinear",
|
mode="bilinear",
|
||||||
align_corners=False,
|
align_corners=False,
|
||||||
)
|
)
|
||||||
|
return duration, spec
|
||||||
if return_np:
|
|
||||||
spec_np = spec[0, 0, :].cpu().data.numpy()
|
|
||||||
else:
|
|
||||||
spec_np = None
|
|
||||||
|
|
||||||
return duration, spec, spec_np
|
|
||||||
|
|
||||||
|
|
||||||
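Callers that relied on return_np=True can recover the numpy copy themselves; a sketch of the equivalent conversion after this change:

duration, spec = compute_spectrogram(audio, sampling_rate, params, device)
# same array the removed `return_np=True` branch used to build
spec_np = spec[0, 0].cpu().numpy()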
def iterate_over_chunks(
def iterate_over_chunks(
    audio: np.ndarray,
    audio: np.ndarray,
    samplerate: int,
    samplerate: float,
    chunk_size: float,
    chunk_size: float,
) -> Iterator[Tuple[float, np.ndarray]]:
) -> Iterator[Tuple[float, np.ndarray]]:
    """Iterate over audio in chunks of size chunk_size.
    """Iterate over audio in chunks of size chunk_size.
@ -510,7 +495,7 @@ def iterate_over_chunks(

def _process_spectrogram(
def _process_spectrogram(
    spec: torch.Tensor,
    spec: torch.Tensor,
    samplerate: int,
    samplerate: float,
    model: DetectionModel,
    model: DetectionModel,
    config: ProcessingConfiguration,
    config: ProcessingConfiguration,
) -> Tuple[PredictionResults, np.ndarray]:
) -> Tuple[PredictionResults, np.ndarray]:
@ -632,13 +617,13 @@ def process_spectrogram(

def _process_audio_array(
def _process_audio_array(
    audio: np.ndarray,
    audio: np.ndarray,
    sampling_rate: int,
    sampling_rate: float,
    model: DetectionModel,
    model: DetectionModel,
    config: ProcessingConfiguration,
    config: ProcessingConfiguration,
    device: torch.device,
    device: torch.device,
) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
    # load audio file and compute spectrogram
    # load audio file and compute spectrogram
    _, spec, _ = compute_spectrogram(
    _, spec = compute_spectrogram(
        audio,
        audio,
        sampling_rate,
        sampling_rate,
        {
        {
@ -654,7 +639,6 @@ def _process_audio_array(
"max_scale_spec": config["max_scale_spec"],
|
"max_scale_spec": config["max_scale_spec"],
|
||||||
},
|
},
|
||||||
device,
|
device,
|
||||||
return_np=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# process spectrogram with model
|
# process spectrogram with model
|
||||||
@ -754,13 +738,15 @@ def process_file(

    # Get original sampling rate
    # Get original sampling rate
    file_samp_rate = librosa.get_samplerate(audio_file)
    file_samp_rate = librosa.get_samplerate(audio_file)
    orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1
    orig_samp_rate = file_samp_rate * float(
        config.get("time_expansion", 1.0) or 1.0
    )

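The old right-hand side was a precedence bug: `*` binds tighter than `or`, so the `or 1` guarded the whole product rather than the missing factor. A standalone illustration of the difference:

file_samp_rate = 48000
time_expansion = None  # e.g. key present but unset
# old: 48000 * None or 1          -> TypeError before `or` is reached
# new: 48000 * float(None or 1.0) -> 48000.0
orig_samp_rate = file_samp_rate * float(time_expansion or 1.0)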
    # load audio file
    # load audio file
    sampling_rate, audio_full = au.load_audio(
    sampling_rate, audio_full = au.load_audio(
        audio_file,
        audio_file,
        time_exp_fact=config.get("time_expansion", 1) or 1,
        time_exp_fact=config.get("time_expansion", 1) or 1,
        target_samp_rate=config["target_samp_rate"],
        target_sampling_rate=config["target_samp_rate"],
        scale=config["scale_raw_audio"],
        scale=config["scale_raw_audio"],
        max_duration=config.get("max_duration"),
        max_duration=config.get("max_duration"),
    )
    )
@ -802,7 +788,6 @@ def process_file(
        cnn_feats.append(features[0])
        cnn_feats.append(features[0])

    if config["spec_slices"]:
    if config["spec_slices"]:
        # FIX: This is not currently working. Returns empty slices
        spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))
        spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))

    # Merge results from chunks
    # Merge results from chunks
@ -152,7 +152,7 @@ def test_compute_max_power_bb(max_power: int):
        target_samp_rate=samplerate,
        target_samp_rate=samplerate,
    )
    )

    spec, _ = au.generate_spectrogram(
    spec = au.generate_spectrogram(
        audio,
        audio,
        samplerate,
        samplerate,
        params,
        params,
@ -240,7 +240,7 @@ def test_compute_max_power():
        target_samp_rate=samplerate,
        target_samp_rate=samplerate,
    )
    )

    spec, _ = au.generate_spectrogram(
    spec = au.generate_spectrogram(
        audio,
        audio,
        samplerate,
        samplerate,
        params,
        params,