mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00

Added types to most functions

This commit is contained in:
parent 458e11cf73
commit 0aa61af445
@@ -226,11 +226,10 @@ def generate_spectrogram(
     if config is None:
         config = DEFAULT_SPECTROGRAM_PARAMETERS

-    _, spec, _ = du.compute_spectrogram(
+    _, spec = du.compute_spectrogram(
         audio,
         samp_rate,
         config,
         return_np=False,
         device=device,
     )

@@ -1,5 +1,5 @@
 """Functions to compute features from predictions."""
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 import numpy as np

@@ -7,15 +7,26 @@ from batdetect2 import types
 from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ


-def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
+def convert_int_to_freq(
+    spec_ind: int,
+    spec_height: int,
+    min_freq: float,
+    max_freq: float,
+) -> int:
     """Convert spectrogram index to frequency in Hz."""
     spec_ind = spec_height - spec_ind
-    return round(
-        (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
-    )
+    return int(
+        round(
+            (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq,
+            2,
+        )
+    )


-def extract_spec_slices(spec, pred_nms):
+def extract_spec_slices(
+    spec: np.ndarray,
+    pred_nms: types.PredictionResults,
+) -> List[np.ndarray]:
     """Extract spectrogram slices from spectrogram.

     The slices are extracted based on detected call locations.
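
For reference, the new `int(round(...))` wrapper makes the declared `-> int` return type honest, since `round(x, 2)` on a float returns a float. A quick sanity check of the mapping with hypothetical values (a 128-bin spectrogram spanning 10-120 kHz; row 0 is the top of the spectrogram):

    convert_int_to_freq(0, 128, 10_000, 120_000)    # 120000, top bin -> max_freq
    convert_int_to_freq(64, 128, 10_000, 120_000)   # 65000, midpoint
    convert_int_to_freq(128, 128, 10_000, 120_000)  # 10000, bottom bin -> min_freq
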
@@ -109,7 +120,7 @@ def compute_max_power_bb(

     return int(
         convert_int_to_freq(
-            y_high + max_power_ind,
+            int(y_high + max_power_ind),
             spec.shape[0],
             min_freq,
             max_freq,
@@ -135,14 +146,12 @@ def compute_max_power(
     spec_call = spec[:, x_start:x_end]
     power_per_freq_band = np.sum(spec_call, axis=1)
     max_power_ind = np.argmax(power_per_freq_band)
-    return int(
-        convert_int_to_freq(
-            max_power_ind,
-            spec.shape[0],
-            min_freq,
-            max_freq,
-        )
-    )
+    return convert_int_to_freq(
+        int(max_power_ind),
+        spec.shape[0],
+        min_freq,
+        max_freq,
+    )


 def compute_max_power_first(
@@ -164,14 +173,12 @@ def compute_max_power_first(
     first_half = spec_call[:, : int(spec_call.shape[1] / 2)]
     power_per_freq_band = np.sum(first_half, axis=1)
     max_power_ind = np.argmax(power_per_freq_band)
-    return int(
-        convert_int_to_freq(
-            max_power_ind,
-            spec.shape[0],
-            min_freq,
-            max_freq,
-        )
-    )
+    return convert_int_to_freq(
+        int(max_power_ind),
+        spec.shape[0],
+        min_freq,
+        max_freq,
+    )


 def compute_max_power_second(
@@ -193,14 +200,12 @@ def compute_max_power_second(
     second_half = spec_call[:, int(spec_call.shape[1] / 2) :]
     power_per_freq_band = np.sum(second_half, axis=1)
     max_power_ind = np.argmax(power_per_freq_band)
-    return int(
-        convert_int_to_freq(
-            max_power_ind,
-            spec.shape[0],
-            min_freq,
-            max_freq,
-        )
-    )
+    return convert_int_to_freq(
+        int(max_power_ind),
+        spec.shape[0],
+        min_freq,
+        max_freq,
+    )


 def compute_call_interval(
@@ -214,6 +219,7 @@ def compute_call_interval(
     return round(prediction["start_time"] - previous["end_time"], 5)


+
 # NOTE: The order of the features in this dictionary is important. The
 # features are extracted in this order and the order of the columns in the
 # output csv file is determined by this order. In order to avoid breaking
@@ -236,7 +242,7 @@ def get_feats(
     spec: np.ndarray,
     pred_nms: types.PredictionResults,
     params: types.FeatureExtractionParameters,
-):
+) -> np.ndarray:
     """Extract features from spectrogram based on detected call locations.

     The features extracted are:
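
The NOTE about ordering refers to the feature registry that `get_feats` iterates over: the output CSV takes its column order from the dictionary's insertion order, so new features must be appended, never inserted or reordered. A minimal sketch of the pattern (illustrative keys, not the file's actual dictionary):

    # Insertion order fixes the CSV column order; append new entries at the end.
    FEATURES = {
        "max_power": compute_max_power,
        "max_power_first": compute_max_power_first,
        "max_power_second": compute_max_power_second,
        "call_interval": compute_call_interval,  # any new feature goes last
    }
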
@@ -79,7 +79,13 @@ class ConvBlockDownCoordF(nn.Module):

 class ConvBlockDownStandard(nn.Module):
     def __init__(
-        self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1
+        self,
+        in_chn,
+        out_chn,
+        ip_height=None,
+        k_size=3,
+        pad_size=1,
+        stride=1,
     ):
         super(ConvBlockDownStandard, self).__init__()
         self.conv = nn.Conv2d(
@@ -103,15 +103,15 @@ class Net2DFast(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )

-    def forward(self, ip, return_feats=False) -> ModelOutput:
+    def forward(self, spec: torch.Tensor) -> ModelOutput:
         # encoder
-        x1 = self.conv_dn_0(ip)
+        x1 = self.conv_dn_0(spec)
         x2 = self.conv_dn_1(x1)
         x3 = self.conv_dn_2(x2)
-        x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
+        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

         # bottleneck
-        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
+        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
         x = self.att(x)
         x = x.repeat([1, 1, self.bneck_height * 4, 1])
@@ -121,13 +121,13 @@ class Net2DFast(nn.Module):
         x = self.conv_up_4(x + x1)

         # output
-        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
+        x = F.relu_(self.conv_op_bn(self.conv_op(x)))
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)

         return ModelOutput(
             pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
-            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_size=F.relu(self.conv_size_op(x)),
             pred_class=comb,
             pred_class_un_norm=cls,
             features=x,
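
The activation rewrites in this file are behaviour-preserving: `F.relu_(x)` is PyTorch's documented in-place variant of `F.relu`, i.e. the same operation as `F.relu(x, inplace=True)`. A quick equivalence check:

    import torch
    import torch.nn.functional as F

    x = torch.randn(4)
    assert torch.equal(F.relu(x.clone(), inplace=True), F.relu_(x.clone()))

Dropping `inplace=True` on the `pred_size` branch above only trades a small memory saving for safety; the values are identical.
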
@@ -215,26 +215,26 @@ class Net2DFastNoAttn(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )

-    def forward(self, ip, return_feats=False) -> ModelOutput:
-        x1 = self.conv_dn_0(ip)
+    def forward(self, spec: torch.Tensor) -> ModelOutput:
+        x1 = self.conv_dn_0(spec)
         x2 = self.conv_dn_1(x1)
         x3 = self.conv_dn_2(x2)
-        x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
+        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

-        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
+        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
         x = x.repeat([1, 1, self.bneck_height * 4, 1])

         x = self.conv_up_2(x + x3)
         x = self.conv_up_3(x + x2)
         x = self.conv_up_4(x + x1)

-        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
+        x = F.relu_(self.conv_op_bn(self.conv_op(x)))
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)

         return ModelOutput(
             pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
-            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_size=F.relu_(self.conv_size_op(x)),
             pred_class=comb,
             pred_class_un_norm=cls,
             features=x,
@@ -324,13 +324,13 @@ class Net2DFastNoCoordConv(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )

-    def forward(self, ip, return_feats=False) -> ModelOutput:
-        x1 = self.conv_dn_0(ip)
+    def forward(self, spec: torch.Tensor) -> ModelOutput:
+        x1 = self.conv_dn_0(spec)
         x2 = self.conv_dn_1(x1)
         x3 = self.conv_dn_2(x2)
-        x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
+        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

-        x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
+        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
         x = self.att(x)
         x = x.repeat([1, 1, self.bneck_height * 4, 1])
@@ -338,15 +338,13 @@ class Net2DFastNoCoordConv(nn.Module):
         x = self.conv_up_3(x + x2)
         x = self.conv_up_4(x + x1)

-        x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
+        x = F.relu_(self.conv_op_bn(self.conv_op(x)))
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)

-        pred_emb = (self.conv_emb(x) if self.emb_dim > 0 else None,)
-
         return ModelOutput(
             pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
-            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_size=F.relu_(self.conv_size_op(x)),
             pred_class=comb,
             pred_class_un_norm=cls,
             features=x,
@@ -1,6 +1,11 @@
 import datetime
 import os
+from pathlib import Path
+from typing import List, Optional, Union
+
+from pydantic import BaseModel, Field, computed_field
+
+from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
 from batdetect2.types import ProcessingConfiguration, SpectrogramParameters

 TARGET_SAMPLERATE_HZ = 256000
@@ -75,103 +80,7 @@ def mk_dir(path):
         os.makedirs(path)


-def get_params(make_dirs=False, exps_dir="../../experiments/"):
-    params = {}
-
-    params[
-        "model_name"
-    ] = "Net2DFast"  # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
-    params["num_filters"] = 128
-
-    now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
-    model_name = now_str + ".pth.tar"
-    params["experiment"] = os.path.join(exps_dir, now_str, "")
-    params["model_file_name"] = os.path.join(params["experiment"], model_name)
-    params["op_im_dir"] = os.path.join(params["experiment"], "op_ims", "")
-    params["op_im_dir_test"] = os.path.join(
-        params["experiment"], "op_ims_test", ""
-    )
-    # params['notes'] = ''  # can save notes about an experiment here
-
-    # spec parameters
-    params[
-        "target_samp_rate"
-    ] = TARGET_SAMPLERATE_HZ  # resamples all audio so that it is at this rate
-    params[
-        "fft_win_length"
-    ] = FFT_WIN_LENGTH_S  # in milliseconds, amount of time per stft time step
-    params["fft_overlap"] = FFT_OVERLAP  # stft window overlap
-
-    params[
-        "max_freq"
-    ] = MAX_FREQ_HZ  # in Hz, everything above this will be discarded
-    params[
-        "min_freq"
-    ] = MIN_FREQ_HZ  # in Hz, everything below this will be discarded
-
-    params[
-        "resize_factor"
-    ] = RESIZE_FACTOR  # resize so the spectrogram at the input of the network
-    params[
-        "spec_height"
-    ] = SPEC_HEIGHT  # units are number of frequency bins (before resizing is performed)
-    params[
-        "spec_train_width"
-    ] = 512  # units are number of time steps (before resizing is performed)
-    params[
-        "spec_divide_factor"
-    ] = SPEC_DIVIDE_FACTOR  # spectrogram should be divisible by this amount in width and height
-
-    # spec processing params
-    params[
-        "denoise_spec_avg"
-    ] = DENOISE_SPEC_AVG  # removes the mean for each frequency band
-    params[
-        "scale_raw_audio"
-    ] = SCALE_RAW_AUDIO  # scales the raw audio to [-1, 1]
-    params[
-        "max_scale_spec"
-    ] = MAX_SCALE_SPEC  # scales the spectrogram so that it is max 1
-    params["spec_scale"] = SPEC_SCALE  # 'log', 'pcen', 'none'
-
-    # detection params
-    params[
-        "detection_overlap"
-    ] = 0.01  # has to be within this number of ms to count as detection
-    params[
-        "ignore_start_end"
-    ] = 0.01  # if start of GT calls are within this time from the start/end of file ignore
-    params[
-        "detection_threshold"
-    ] = DETECTION_THRESHOLD  # the smaller this is the better the recall will be
-    params[
-        "nms_kernel_size"
-    ] = NMS_KERNEL_SIZE  # size of the kernel for non-max suppression
-    params[
-        "nms_top_k_per_sec"
-    ] = NMS_TOP_K_PER_SEC  # keep top K highest predictions per second of audio
-    params["target_sigma"] = 2.0
-
-    # augmentation params
-    params[
-        "aug_prob"
-    ] = 0.20  # augmentations will be performed with this probability
-    params["augment_at_train"] = True
-    params["augment_at_train_combine"] = True
-    params[
-        "echo_max_delay"
-    ] = 0.005  # simulate echo by adding copy of raw audio
-    params["stretch_squeeze_delta"] = 0.04  # stretch or squeeze spec
-    params[
-        "mask_max_time_perc"
-    ] = 0.05  # max mask size - here percentage, not ideal
-    params[
-        "mask_max_freq_perc"
-    ] = 0.10  # max mask size - here percentage, not ideal
-    params[
-        "spec_amp_scaling"
-    ] = 2.0  # multiply the "volume" by 0:X times current amount
-    params["aug_sampling_rates"] = [
+AUG_SAMPLING_RATES = [
     220500,
     256000,
     300000,
@@ -179,54 +88,146 @@ def get_params(make_dirs=False, exps_dir="../../experiments/"):
     384000,
     441000,
     500000,
-    ]
+]
+
+CLASSES_TO_IGNORE = ["", " ", "Unknown", "Not Bat"]
+GENERIC_CLASSES = ["Bat"]
+EVENTS_OF_INTEREST = ["Echolocation"]

-    # loss params
-    params["train_loss"] = "focal"  # mse or focal
-    params["det_loss_weight"] = 1.0  # weight for the detection part of the loss
-    params["size_loss_weight"] = 0.1  # weight for the bbox size loss
-    params["class_loss_weight"] = 2.0  # weight for the classification loss
-    params["individual_loss_weight"] = 0.0  # not used
-    if params["individual_loss_weight"] == 0.0:
-        params[
-            "emb_dim"
-        ] = 0  # number of dimensions used for individual id embedding
-    else:
-        params["emb_dim"] = 3
-
-    # train params
-    params["lr"] = 0.001
-    params["batch_size"] = 8
-    params["num_workers"] = 4
-    params["num_epochs"] = 200
-    params["num_eval_epochs"] = 5  # run evaluation every X epochs
-    params["device"] = "cuda"
-    params["save_test_image_during_train"] = False
-    params["save_test_image_after_train"] = True
+class TrainingParameters(BaseModel):
+    # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
+    model_name: str = "Net2DFast"
+    num_filters: int = 128

-    params["convert_to_genus"] = False
-    params["genus_mapping"] = []
-    params["class_names"] = []
-    params["classes_to_ignore"] = ["", " ", "Unknown", "Not Bat"]
-    params["generic_class"] = ["Bat"]
-    params["events_of_interest"] = [
-        "Echolocation"
-    ]  # will ignore all other types of events e.g. social calls
+    experiment: Path
+    model_file_name: Path

-    # the classes in this list are standardized during training so that the same low and high freq are used
-    params["standardize_classs_names"] = []
+    op_im_dir: Path
+    op_im_dir_test: Path
+
+    notes: str = ""
+
+    target_samp_rate: int = TARGET_SAMPLERATE_HZ
+    fft_win_length: float = FFT_WIN_LENGTH_S
+    fft_overlap: float = FFT_OVERLAP
+
+    max_freq: int = MAX_FREQ_HZ
+    min_freq: int = MIN_FREQ_HZ
+
+    resize_factor: float = RESIZE_FACTOR
+    spec_height: int = SPEC_HEIGHT
+    spec_train_width: int = 512
+    spec_divide_factor: int = SPEC_DIVIDE_FACTOR
+
+    denoise_spec_avg: bool = DENOISE_SPEC_AVG
+    scale_raw_audio: bool = SCALE_RAW_AUDIO
+    max_scale_spec: bool = MAX_SCALE_SPEC
+    spec_scale: str = SPEC_SCALE
+
+    detection_overlap: float = 0.01
+    ignore_start_end: float = 0.01
+    detection_threshold: float = DETECTION_THRESHOLD
+    nms_kernel_size: int = NMS_KERNEL_SIZE
+    nms_top_k_per_sec: int = NMS_TOP_K_PER_SEC
+
+    aug_prob: float = 0.20
+    augment_at_train: bool = True
+    augment_at_train_combine: bool = True
+    echo_max_delay: float = 0.005
+    stretch_squeeze_delta: float = 0.04
+    mask_max_time_perc: float = 0.05
+    mask_max_freq_perc: float = 0.10
+    spec_amp_scaling: float = 2.0
+    aug_sampling_rates: List[int] = AUG_SAMPLING_RATES
+
+    train_loss: str = "focal"
+    det_loss_weight: float = 1.0
+    size_loss_weight: float = 0.1
+    class_loss_weight: float = 2.0
+    individual_loss_weight: float = 0.0
+
+    lr: float = 0.001
+    batch_size: int = 8
+    num_workers: int = 4
+    num_epochs: int = 200
+    num_eval_epochs: int = 5
+    device: str = "cuda"
+    save_test_image_during_train: bool = False
+    save_test_image_after_train: bool = True
+
+    convert_to_genus: bool = False
+    class_names: List[str] = Field(default_factory=list)
+    classes_to_ignore: List[str] = Field(
+        default_factory=lambda: CLASSES_TO_IGNORE
+    )
+    generic_class: List[str] = Field(default_factory=lambda: GENERIC_CLASSES)
+    events_of_interest: List[str] = Field(
+        default_factory=lambda: EVENTS_OF_INTEREST
+    )
+    standardize_classs_names: List[str] = Field(default_factory=list)
+
+    @computed_field
+    @property
+    def emb_dim(self) -> int:
+        if self.individual_loss_weight == 0.0:
+            return 0
+        return 3
+
+    @computed_field
+    @property
+    def genus_mapping(self) -> List[int]:
+        _, mapping = get_genus_mapping(self.class_names)
+        return mapping
+
+    @computed_field
+    @property
+    def genus_classes(self) -> List[str]:
+        names, _ = get_genus_mapping(self.class_names)
+        return names
+
+    @computed_field
+    @property
+    def class_names_short(self) -> List[str]:
+        return get_short_class_names(self.class_names)
+
+
+def get_params(
+    make_dirs: bool = False,
+    exps_dir: str = "../../experiments/",
+    model_name: Optional[str] = None,
+    experiment: Union[Path, str, None] = None,
+    **kwargs,
+) -> TrainingParameters:
+    experiments_dir = Path(exps_dir)
+
+    now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
+
+    if model_name is None:
+        model_name = f"{now_str}.pth.tar"
+
+    if experiment is None:
+        experiment = experiments_dir / now_str
+    experiment = Path(experiment)
+
+    model_file_name = experiment / model_name
+    op_ims_dir = experiment / "op_ims"
+    op_ims_test_dir = experiment / "op_ims_test"
+
+    params = TrainingParameters(
+        model_name=model_name,
+        experiment=experiment,
+        model_file_name=model_file_name,
+        op_im_dir=op_ims_dir,
+        op_im_dir_test=op_ims_test_dir,
+        **kwargs,
+    )

     # create directories
     if make_dirs:
-        print("Model name : " + params["model_name"])
-        print("Model file : " + params["model_file_name"])
-        print("Experiment : " + params["experiment"])
-
-        mk_dir(params["experiment"])
-        if params["save_test_image_during_train"]:
-            mk_dir(params["op_im_dir"])
-        if params["save_test_image_after_train"]:
-            mk_dir(params["op_im_dir_test"])
-        mk_dir(os.path.dirname(params["model_file_name"]))
+        mk_dir(experiment)
+        mk_dir(params.model_file_name.parent)
+        if params.save_test_image_during_train:
+            mk_dir(params.op_im_dir)
+        if params.save_test_image_after_train:
+            mk_dir(params.op_im_dir_test)

     return params
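
For context, `@computed_field` is pydantic v2: it exposes a derived, read-only property that is also included when the model is serialized, which is how `emb_dim` and the genus fields now track `class_names` automatically instead of being assigned by hand as in the old dict. A minimal standalone sketch:

    from pydantic import BaseModel, computed_field

    class Example(BaseModel):
        individual_loss_weight: float = 0.0

        @computed_field
        @property
        def emb_dim(self) -> int:
            # same rule as TrainingParameters above
            return 0 if self.individual_loss_weight == 0.0 else 3

    Example().model_dump()  # {'individual_loss_weight': 0.0, 'emb_dim': 0}
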
@@ -1,33 +1,31 @@
 import argparse
 import glob
 import json
 import os
 import sys
+import warnings
+from typing import List, Optional

 import matplotlib.pyplot as plt
 import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.data
 from torch.optim.lr_scheduler import CosineAnnealingLR

 import batdetect2.detector.models as models
 import batdetect2.detector.parameters as parameters
 import batdetect2.detector.post_process as pp
 import batdetect2.train.audio_dataloader as adl
 import batdetect2.train.evaluate as evl
 import batdetect2.train.losses as losses
 import batdetect2.train.train_model as tm
 import batdetect2.train.train_utils as tu
 import batdetect2.utils.detector_utils as du
 import batdetect2.utils.plot_utils as pu
+from batdetect2 import types
+from batdetect2.detector.models import Net2DFast

-if __name__ == "__main__":
-    info_str = "\nBatDetect - Finetune Model\n"
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

-    print(info_str)

+def parse_arugments():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "audio_path", type=str, help="Input directory for audio"
+        "audio_path",
+        type=str,
+        help="Input directory for audio",
     )
     parser.add_argument(
         "train_ann_path",
@@ -39,7 +37,15 @@ if __name__ == "__main__":
         type=str,
         help="Path to where test annotation file is stored",
     )
-    parser.add_argument("model_path", type=str, help="Path to pretrained model")
+    parser.add_argument(
+        "model_path", type=str, help="Path to pretrained model"
+    )
+    parser.add_argument(
+        "--experiment_dir",
+        type=str,
+        default=os.path.join(BASE_DIR, "experiments"),
+        help="Path to where experiment files are stored",
+    )
     parser.add_argument(
         "--op_model_name",
         type=str,
@@ -71,107 +77,63 @@ if __name__ == "__main__":
     parser.add_argument(
         "--notes", type=str, default="", help="Notes to save in text file"
     )
-    args = vars(parser.parse_args())
+    args = parser.parse_args()
+    return args

-params = parameters.get_params(True, "../../experiments/")

-if torch.cuda.is_available():
-    params["device"] = "cuda"
-else:
-    params["device"] = "cpu"
-    print(
-        "\nNote, this will be a lot faster if you use computer with a GPU.\n"
-    )
+def select_device(warn=True) -> str:
+    if torch.cuda.is_available():
+        return "cuda"
+
+    if warn:
+        warnings.warn(
+            "No GPU available, using the CPU instead. Please consider using a GPU "
+            "to speed up training."
+        )

-print("\nAudio directory: " + args["audio_path"])
-print("Train file: " + args["train_ann_path"])
-print("Test file: " + args["test_ann_path"])
-print("Loading model: " + args["model_path"])
+    return "cpu"

-dataset_name = (
-    os.path.basename(args["train_ann_path"])
-    .replace(".json", "")
-    .replace("_TRAIN", "")
-)

-if args["train_from_scratch"]:
-    print("\nTraining model from scratch i.e. not using pretrained weights")
-    model, params_train = du.load_model(args["model_path"], False)
-else:
-    model, params_train = du.load_model(args["model_path"], True)
-model.to(params["device"])

-params["num_epochs"] = args["num_epochs"]
-if args["op_model_name"] != "":
-    params["model_file_name"] = args["op_model_name"]
-classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]

-# save notes file
-params["notes"] = args["notes"]
-if args["notes"] != "":
-    tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"])

-# load train annotations
-train_sets = []
-train_sets.append(
-    tu.get_blank_dataset_dict(
-        dataset_name, False, args["train_ann_path"], args["audio_path"]
-    )
-)
-params["train_sets"] = [
-    tu.get_blank_dataset_dict(
-        dataset_name,
-        False,
-        os.path.basename(args["train_ann_path"]),
-        args["audio_path"],
-    )
-]
+def load_annotations(
+    dataset_name: str,
+    ann_path: str,
+    audio_path: str,
+    classes_to_ignore: Optional[List[str]] = None,
+    events_of_interest: Optional[List[str]] = None,
+) -> List[types.FileAnnotation]:
+    train_sets: List[types.DatasetDict] = []
+    train_sets.append(
+        tu.get_blank_dataset_dict(
+            dataset_name,
+            is_test=False,
+            ann_path=ann_path,
+            wav_path=audio_path,
+        )
+    )

-print("\nTrain set:")
-(
-    data_train,
-    params["class_names"],
-    params["class_inv_freq"],
-) = tu.load_set_of_anns(
-    train_sets, classes_to_ignore, params["events_of_interest"]
-)
-print("Number of files", len(data_train))

-params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
-    params["class_names"]
-)
-params["class_names_short"] = tu.get_short_class_names(
-    params["class_names"]
-)

-# load test annotations
-test_sets = []
-test_sets.append(
-    tu.get_blank_dataset_dict(
-        dataset_name, True, args["test_ann_path"], args["audio_path"]
-    )
-)
-params["test_sets"] = [
-    tu.get_blank_dataset_dict(
-        dataset_name,
-        True,
-        os.path.basename(args["test_ann_path"]),
-        args["audio_path"],
-    )
-]
+    return tu.load_set_of_anns(
+        train_sets,
+        events_of_interest=events_of_interest,
+        classes_to_ignore=classes_to_ignore,
+    )

-print("\nTest set:")
-data_test, _, _ = tu.load_set_of_anns(
-    test_sets, classes_to_ignore, params["events_of_interest"]
-)
-print("Number of files", len(data_test))

+def finetune_model(
+    model: types.DetectionModel,
+    data_train: List[types.FileAnnotation],
+    data_test: List[types.FileAnnotation],
+    params: parameters.TrainingParameters,
+    model_params: types.ModelParameters,
+    finetune_only_last_layer: bool = False,
+    save_images: bool = True,
+):
     # train loader
     train_dataset = adl.AudioLoader(data_train, params, is_train=True)
     train_loader = torch.utils.data.DataLoader(
         train_dataset,
-        batch_size=params["batch_size"],
+        batch_size=params.batch_size,
         shuffle=True,
-        num_workers=params["num_workers"],
+        num_workers=params.num_workers,
         pin_memory=True,
     )
@@ -181,32 +143,36 @@ if __name__ == "__main__":
         test_dataset,
         batch_size=1,
         shuffle=False,
-        num_workers=params["num_workers"],
+        num_workers=params.num_workers,
         pin_memory=True,
     )

     inputs_train = next(iter(train_loader))
-    params["ip_height"] = inputs_train["spec"].shape[2]
+    params.ip_height = inputs_train["spec"].shape[2]
     print("\ntrain batch size :", inputs_train["spec"].shape)

-    assert params_train["model_name"] == "Net2DFast"
+    # Check that the model is the same as the one used to train the pretrained
+    # weights
+    assert model_params["model_name"] == "Net2DFast"
+    assert isinstance(model, Net2DFast)
     print(
-        "\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n"
+        "\n\nSOME hyperparams need to be the same as the loaded model "
+        "(e.g. FFT) - currently they are getting overwritten.\n\n"
     )

     # set the number of output classes
     num_filts = model.conv_classes_op.in_channels
-    k_size = model.conv_classes_op.kernel_size
-    pad = model.conv_classes_op.padding
+    (k_size,) = model.conv_classes_op.kernel_size
+    (pad,) = model.conv_classes_op.padding
     model.conv_classes_op = torch.nn.Conv2d(
         num_filts,
-        len(params["class_names"]) + 1,
+        len(params.class_names) + 1,
         kernel_size=k_size,
         padding=pad,
     )
-    model.conv_classes_op.to(params["device"])
+    model.conv_classes_op.to(params.device)

-    if args["finetune_only_last_layer"]:
+    if finetune_only_last_layer:
         print("\nOnly finetuning the final layers.\n")
         train_layers_i = [
             "conv_classes",
@@ -223,19 +189,26 @@ if __name__ == "__main__":
     else:
         param.requires_grad = False

-    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
-    scheduler = CosineAnnealingLR(
-        optimizer, params["num_epochs"] * len(train_loader)
+    optimizer = torch.optim.Adam(
+        model.parameters(),
+        lr=params.lr,
     )
-    if params["train_loss"] == "mse":
+    scheduler = CosineAnnealingLR(
+        optimizer,
+        params.num_epochs * len(train_loader),
+    )
+
+    if params.train_loss == "mse":
         det_criterion = losses.mse_loss
-    elif params["train_loss"] == "focal":
+    elif params.train_loss == "focal":
         det_criterion = losses.focal_loss
     else:
         raise ValueError("Unknown loss function")

     # plotting
     train_plt_ls = pu.LossPlotter(
-        params["experiment"] + "train_loss.png",
-        params["num_epochs"] + 1,
+        params.experiment / "train_loss.png",
+        params.num_epochs + 1,
         ["train_loss"],
         None,
         None,
@@ -243,8 +216,8 @@ if __name__ == "__main__":
         logy=True,
     )
     test_plt_ls = pu.LossPlotter(
-        params["experiment"] + "test_loss.png",
-        params["num_epochs"] + 1,
+        params.experiment / "test_loss.png",
+        params.num_epochs + 1,
         ["test_loss"],
         None,
         None,
@@ -252,24 +225,24 @@ if __name__ == "__main__":
         logy=True,
     )
     test_plt = pu.LossPlotter(
-        params["experiment"] + "test.png",
-        params["num_epochs"] + 1,
+        params.experiment / "test.png",
+        params.num_epochs + 1,
         ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
         [0, 1],
         None,
         ["epoch", ""],
     )
     test_plt_class = pu.LossPlotter(
-        params["experiment"] + "test_avg_prec.png",
-        params["num_epochs"] + 1,
-        params["class_names_short"],
+        params.experiment / "test_avg_prec.png",
+        params.num_epochs + 1,
+        params.class_names_short,
         [0, 1],
-        params["class_names_short"],
+        params.class_names_short,
         ["epoch", "avg_prec"],
     )

     # main train loop
-    for epoch in range(0, params["num_epochs"] + 1):
+    for epoch in range(0, params.num_epochs + 1):
         train_loss = tm.train(
             model,
             epoch,
@@ -281,10 +254,14 @@ if __name__ == "__main__":
         )
         train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])

-        if epoch % params["num_eval_epochs"] == 0:
+        if epoch % params.num_eval_epochs == 0:
             # detection accuracy on test set
             test_res, test_loss = tm.test(
-                model, epoch, test_loader, det_criterion, params
+                model,
+                epoch,
+                test_loader,
+                det_criterion,
+                params,
             )
             test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
             test_plt.update_and_save(
@@ -301,18 +278,106 @@ if __name__ == "__main__":
                 epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
             )
             pu.plot_pr_curve_class(
-                params["experiment"], "test_pr", "test_pr", test_res
+                params.experiment, "test_pr", "test_pr", test_res
             )

     # save finetuned model
-    print("saving model to: " + params["model_file_name"])
+    print(f"saving model to: {params.model_file_name}")
     op_state = {
         "epoch": epoch + 1,
         "state_dict": model.state_dict(),
         "params": params,
     }
-    torch.save(op_state, params["model_file_name"])
+    torch.save(op_state, params.model_file_name)

     # save an image with associated prediction for each batch in the test set
-    if not args["do_not_save_images"]:
+    if save_images:
         tm.save_images_batch(model, test_loader, params)
+
+
+def main():
+    info_str = "\nBatDetect - Finetune Model\n"
+    print(info_str)
+
+    args = parse_arugments()
+
+    # Load experiment parameters
+    params = parameters.get_params(
+        make_dirs=True,
+        exps_dir=args.experiment_dir,
+        device=select_device(),
+        num_epochs=args.num_epochs,
+        notes=args.notes,
+    )
+
+    print("\nAudio directory: " + args.audio_path)
+    print("Train file: " + args.train_ann_path)
+    print("Test file: " + args.test_ann_path)
+    print("Loading model: " + args.model_path)
+
+    if args.train_from_scratch:
+        print(
+            "\nTraining model from scratch i.e. not using pretrained weights"
+        )
+
+    model, model_params = du.load_model(
+        args.model_path,
+        load_weights=not args.train_from_scratch,
+        device=params.device,
+    )
+
+    if args.op_model_name != "":
+        params.model_file_name = args.op_model_name
+
+    classes_to_ignore = params.classes_to_ignore + params.generic_class
+
+    # save notes file
+    if params.notes:
+        tu.write_notes_file(
+            params.experiment / "notes.txt",
+            args.notes,
+        )
+
+    # NOTE:??
+    dataset_name = (
+        os.path.basename(args.train_ann_path)
+        .replace(".json", "")
+        .replace("_TRAIN", "")
+    )
+
+    # ==== LOAD DATA ====
+
+    # load train annotations
+    data_train = load_annotations(
+        dataset_name,
+        args.train_ann_path,
+        args.audio_path,
+        params.events_of_interest,
+    )
+    print("\nTrain set:")
+    print("Number of files", len(data_train))
+
+    # load test annotations
+    data_test = load_annotations(
+        dataset_name,
+        args.test_ann_path,
+        args.audio_path,
+        classes_to_ignore,
+        params.events_of_interest,
+    )
+    print("\nTrain set:")
+    print("Number of files", len(data_train))
+
+    finetune_model(
+        model,
+        data_train,
+        data_test,
+        params,
+        model_params,
+        finetune_only_last_layer=args.finetune_only_last_layer,
+        save_images=args.do_not_save_images,
+    )
+
+
+if __name__ == "__main__":
+    main()
@@ -1,62 +1,54 @@
 import argparse
 import json
 import os
+from collections import Counter
+from typing import List, Optional, Tuple

 import numpy as np
+from sklearn.model_selection import StratifiedGroupKFold

 import batdetect2.train.train_utils as tu
+from batdetect2 import types


-def print_dataset_stats(data, split_name, classes_to_ignore):
-    print("\nSplit:", split_name)
+def print_dataset_stats(
+    data: List[types.FileAnnotation],
+    classes_to_ignore: Optional[List[str]] = None,
+) -> Counter[str]:
     print("Num files:", len(data))

-    class_cnts = {}
-    for dd in data:
-        for aa in dd["annotation"]:
-            if aa["class"] not in classes_to_ignore:
-                if aa["class"] in class_cnts:
-                    class_cnts[aa["class"]] += 1
-                else:
-                    class_cnts[aa["class"]] = 1
-
-    if len(class_cnts) == 0:
-        class_names = []
-    else:
-        class_names = np.sort([*class_cnts]).tolist()
-        print("Class count:")
-        str_len = np.max([len(cc) for cc in class_names]) + 5
-
-        for ii, cc in enumerate(class_names):
-            print(str(ii).ljust(5) + cc.ljust(str_len) + str(class_cnts[cc]))
-
-    return class_names
+    counts, _ = tu.get_class_names(data, classes_to_ignore)
+    if len(counts) > 0:
+        tu.report_class_counts(counts)
+    return counts


-def load_file_names(file_name):
-    if os.path.isfile(file_name):
+def load_file_names(file_name: str) -> List[str]:
+    if not os.path.isfile(file_name):
+        raise FileNotFoundError(f"Input file not found - {file_name}")
+
     with open(file_name) as da:
         files = [line.rstrip() for line in da.readlines()]
-        for ff in files:
-            if ff.lower()[-3:] != "wav":
-                print("Error: Filenames need to end in .wav - ", ff)
-                assert False
-    else:
-        print("Error: Input file not found - ", file_name)
-        assert False
+
+    for path in files:
+        if path.lower()[-3:] != "wav":
+            raise ValueError(
+                f"Invalid file name - {path}. Must be a .wav file"
+            )

     return files


-if __name__ == "__main__":
+def parse_args():
     info_str = "\nBatDetect - Prepare Data for Finetuning\n"

     print(info_str)

     parser = argparse.ArgumentParser()
     parser.add_argument(
         "dataset_name", type=str, help="Name to call your dataset"
     )
-    parser.add_argument("audio_dir", type=str, help="Input directory for audio")
+    parser.add_argument(
+        "audio_dir", type=str, help="Input directory for audio"
+    )
     parser.add_argument(
         "ann_dir",
         type=str,
@@ -104,86 +96,124 @@ if __name__ == "__main__":
         help='New class names to use instead. One to one mapping with "--input_class_names". \
             Separate with ";"',
     )
-    args = vars(parser.parse_args())
+    return parser.parse_args()

-    np.random.seed(args["rand_seed"])

+def split_data(
+    data: List[types.FileAnnotation],
+    train_file: str,
+    test_file: str,
+    n_splits: int = 5,
+    random_state: int = 0,
+) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]:
+    if train_file != "" and test_file != "":
+        # user has specifed the train / test split
+        mapping = {
+            file_annotation["id"]: file_annotation for file_annotation in data
+        }
+        train_files = load_file_names(train_file)
+        test_files = load_file_names(test_file)
+        data_train = [
+            mapping[file_id] for file_id in train_files if file_id in mapping
+        ]
+        data_test = [
+            mapping[file_id] for file_id in test_files if file_id in mapping
+        ]
+        return data_train, data_test
+
+    # NOTE: Using StratifiedGroupKFold to ensure that the same file does not
+    # appear in both the training and test sets and trying to keep the
+    # distribution of classes the same in both sets.
+    splitter = StratifiedGroupKFold(
+        n_splits=n_splits,
+        shuffle=True,
+        random_state=random_state,
+    )
+    anns = np.array(
+        [
+            [dd["id"], ann["class"], ann["event"]]
+            for dd in data
+            for ann in dd["annotation"]
+        ]
+    )
+    y = anns[:, 1]
+    group = anns[:, 0]
+
+    train_idx, test_idx = next(splitter.split(X=anns, y=y, groups=group))
+    train_ids = set(anns[train_idx, 0])
+    test_ids = set(anns[test_idx, 0])
+
+    assert not (train_ids & test_ids)
+    data_train = [dd for dd in data if dd["id"] in train_ids]
+    data_test = [dd for dd in data if dd["id"] in test_ids]
+    return data_train, data_test
+
+
+def main():
+    args = parse_args()
+
+    np.random.seed(args.rand_seed)
+
     classes_to_ignore = ["", " ", "Unknown", "Not Bat"]
     generic_class = ["Bat"]
     events_of_interest = ["Echolocation"]

-    if args["input_class_names"] != "" and args["output_class_names"] != "":
+    name_dict = None
+    if args.input_class_names != "" and args.output_class_names != "":
         # change the names of the classes
-        ip_names = args["input_class_names"].split(";")
-        op_names = args["output_class_names"].split(";")
+        ip_names = args.input_class_names.split(";")
+        op_names = args.output_class_names.split(";")
         name_dict = dict(zip(ip_names, op_names))
-    else:
-        name_dict = False

     # load annotations
-    data_all, _, _ = tu.load_set_of_anns(
-        {"ann_path": args["ann_dir"], "wav_path": args["audio_dir"]},
-        classes_to_ignore,
-        events_of_interest,
-        False,
-        False,
-        list_of_anns=True,
+    data_all = tu.load_set_of_anns(
+        [
+            {
+                "dataset_name": args.dataset_name,
+                "ann_path": args.ann_dir,
+                "wav_path": args.audio_dir,
+                "is_test": False,
+                "is_binary": False,
+            }
+        ],
+        classes_to_ignore=classes_to_ignore,
+        events_of_interest=events_of_interest,
+        convert_to_genus=False,
+        filter_issues=True,
         name_replace=name_dict,
     )

-    print("Dataset name: " + args["dataset_name"])
-    print("Audio directory: " + args["audio_dir"])
-    print("Annotation directory: " + args["ann_dir"])
-    print("Ouput directory: " + args["op_dir"])
+    print("Dataset name: " + args.dataset_name)
+    print("Audio directory: " + args.audio_dir)
+    print("Annotation directory: " + args.ann_dir)
+    print("Ouput directory: " + args.op_dir)
     print("Num annotated files: " + str(len(data_all)))

-    if args["train_file"] != "" and args["test_file"] != "":
-        # user has specifed the train / test split
-        train_files = load_file_names(args["train_file"])
-        test_files = load_file_names(args["test_file"])
-        file_names_all = [dd["id"] for dd in data_all]
-        train_inds = [
-            file_names_all.index(ff)
-            for ff in train_files
-            if ff in file_names_all
-        ]
-        test_inds = [
-            file_names_all.index(ff)
-            for ff in test_files
-            if ff in file_names_all
-        ]
-
-    else:
-        # split the data into train and test at the file level
-        num_exs = len(data_all)
-        test_inds = np.random.choice(
-            np.arange(num_exs),
-            int(num_exs * args["percent_val"]),
-            replace=False,
-        )
-        test_inds = np.sort(test_inds)
-        train_inds = np.setdiff1d(np.arange(num_exs), test_inds)
-
-    data_train = [data_all[ii] for ii in train_inds]
-    data_test = [data_all[ii] for ii in test_inds]
+    data_train, data_test = split_data(
+        data=data_all,
+        train_file=args.train_file,
+        test_file=args.test_file,
+        n_splits=5,
+        random_state=args.rand_seed,
+    )

-    if not os.path.isdir(args["op_dir"]):
-        os.makedirs(args["op_dir"])
-    op_name = os.path.join(args["op_dir"], args["dataset_name"])
+    if not os.path.isdir(args.op_dir):
+        os.makedirs(args.op_dir)
+    op_name = os.path.join(args.op_dir, args.dataset_name)
     op_name_train = op_name + "_TRAIN.json"
     op_name_test = op_name + "_TEST.json"

-    class_un_train = print_dataset_stats(data_train, "Train", classes_to_ignore)
-    class_un_test = print_dataset_stats(data_test, "Test", classes_to_ignore)
+    print("\nSplit: Train")
+    class_un_train = print_dataset_stats(data_train, classes_to_ignore)
+
+    print("\nSplit: Test")
+    class_un_test = print_dataset_stats(data_test, classes_to_ignore)

     if len(data_train) > 0 and len(data_test) > 0:
-        if class_un_train != class_un_test:
-            print(
-                '\nError: some classes are not in both the training and test sets.\
-                \nTry a different random seed "--rand_seed".'
+        if set(class_un_train.keys()) != set(class_un_test.keys()):
+            raise RuntimeError(
+                "Error: some classes are not in both the training and test sets."
+                'Try a different random seed "--rand_seed".'
             )
-            assert False

     print("\n")
     if len(data_train) == 0:
@@ -199,3 +229,7 @@ if __name__ == "__main__":
     print("Saving: ", op_name_test)
     with open(op_name_test, "w") as da:
         json.dump(data_test, da, indent=2)
+
+
+if __name__ == "__main__":
+    main()
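
For reference, `StratifiedGroupKFold` is what gives the new `split_data` both guarantees at once: passing file ids as `groups` keeps every annotation from one recording on the same side of the split, while stratifying on `y` keeps the class distributions similar. A toy example:

    import numpy as np
    from sklearn.model_selection import StratifiedGroupKFold

    y = np.array(["A", "A", "B", "B", "A", "B"])             # per-annotation class
    groups = np.array(["f1", "f1", "f2", "f2", "f3", "f3"])  # source file ids

    splitter = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=0)
    train_idx, test_idx = next(splitter.split(np.zeros(len(y)), y, groups))
    assert not set(groups[train_idx]) & set(groups[test_idx])  # no file straddles
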
@@ -12,19 +12,24 @@ import torchaudio
 import batdetect2.utils.audio_utils as au
 from batdetect2.types import (
     Annotation,
     AnnotationGroup,
     AudioLoaderAnnotationGroup,
-    FileAnnotations,
-    HeatmapParameters,
+    AudioLoaderParameters,
+    FileAnnotation,
 )


 def generate_gt_heatmaps(
     spec_op_shape: Tuple[int, int],
-    sampling_rate: int,
-    ann: AnnotationGroup,
-    params: HeatmapParameters,
-) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]:
+    sampling_rate: float,
+    ann: AudioLoaderAnnotationGroup,
+    class_names: List[str],
+    fft_win_length: float,
+    fft_overlap: float,
+    max_freq: float,
+    min_freq: float,
+    resize_factor: float,
+    target_sigma: float,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AudioLoaderAnnotationGroup]:
     """Generate ground truth heatmaps from annotations.

     Parameters
@@ -53,31 +58,31 @@ def generate_gt_heatmaps(
     the x and y indices of their pixel location in the input spectrogram.
     """
     # spec may be resized on input into the network
-    num_classes = len(params["class_names"])
+    num_classes = len(class_names)
     op_height = spec_op_shape[0]
     op_width = spec_op_shape[1]
-    freq_per_bin = (params["max_freq"] - params["min_freq"]) / op_height
+    freq_per_bin = (max_freq - min_freq) / op_height

     # start and end times
     x_pos_start = au.time_to_x_coords(
         ann["start_times"],
         sampling_rate,
-        params["fft_win_length"],
-        params["fft_overlap"],
+        fft_win_length,
+        fft_overlap,
     )
-    x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int32)
+    x_pos_start = (resize_factor * x_pos_start).astype(np.int32)
     x_pos_end = au.time_to_x_coords(
         ann["end_times"],
         sampling_rate,
-        params["fft_win_length"],
-        params["fft_overlap"],
+        fft_win_length,
+        fft_overlap,
     )
-    x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int32)
+    x_pos_end = (resize_factor * x_pos_end).astype(np.int32)

     # location on y axis i.e. frequency
-    y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin
+    y_pos_low = (ann["low_freqs"] - min_freq) / freq_per_bin
     y_pos_low = (op_height - y_pos_low).astype(np.int32)
-    y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin
+    y_pos_high = (ann["high_freqs"] - min_freq) / freq_per_bin
     y_pos_high = (op_height - y_pos_high).astype(np.int32)
     bb_widths = x_pos_end - x_pos_start
     bb_heights = y_pos_low - y_pos_high
@@ -90,26 +95,17 @@
         & (y_pos_low < (op_height - 1))
     )[0]

-    ann_aug: AnnotationGroup = {
+    ann_aug: AudioLoaderAnnotationGroup = {
+        **ann,
         "start_times": ann["start_times"][valid_inds],
         "end_times": ann["end_times"][valid_inds],
         "high_freqs": ann["high_freqs"][valid_inds],
         "low_freqs": ann["low_freqs"][valid_inds],
         "class_ids": ann["class_ids"][valid_inds],
         "individual_ids": ann["individual_ids"][valid_inds],
+        "x_inds": x_pos_start[valid_inds],
+        "y_inds": y_pos_low[valid_inds],
     }
-    ann_aug["x_inds"] = x_pos_start[valid_inds]
-    ann_aug["y_inds"] = y_pos_low[valid_inds]
-    # keys = [
-    #     "start_times",
-    #     "end_times",
-    #     "high_freqs",
-    #     "low_freqs",
-    #     "class_ids",
-    #     "individual_ids",
-    # ]
-    # for kk in keys:
-    #     ann_aug[kk] = ann[kk][valid_inds]

     # if the number of calls is only 1, then it is unique
     # TODO would be better if we found these unique calls at the merging stage
@@ -118,6 +114,7 @@

     y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
     y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)

+    # num classes and "background" class
     y_2d_classes: np.ndarray = np.zeros(
         (num_classes + 1, op_height, op_width), dtype=np.float32
@@ -128,14 +125,8 @@
         draw_gaussian(
             y_2d_det[0, :],
             (x_pos_start[ii], y_pos_low[ii]),
-            params["target_sigma"],
+            target_sigma,
         )
-        # draw_gaussian(
-        #     y_2d_det[0, :],
-        #     (x_pos_start[ii], y_pos_low[ii]),
-        #     params["target_sigma"],
-        #     params["target_sigma"] * 2,
-        # )
         y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
         y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
@@ -144,14 +135,8 @@
         draw_gaussian(
             y_2d_classes[cls_id, :],
             (x_pos_start[ii], y_pos_low[ii]),
-            params["target_sigma"],
+            target_sigma,
         )
-        # draw_gaussian(
-        #     y_2d_classes[cls_id, :],
-        #     (x_pos_start[ii], y_pos_low[ii]),
-        #     params["target_sigma"],
-        #     params["target_sigma"] * 2,
-        # )

     # be careful as this will have a 1.0 places where we have event but
     # dont know gt class this will be masked in training anyway
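
The target built in this function is a keypoint-style heatmap: a small 2-D Gaussian is stamped at each call's (start time, low frequency) pixel, with `target_sigma` controlling how forgiving the detection loss is around the exact location. A minimal numpy version of the idea (this `draw_gaussian` is a hypothetical stand-in, not the project's implementation):

    import numpy as np

    def draw_gaussian(heatmap, center, sigma):
        # Stamp max(existing, gaussian) centred at (x, y) into heatmap, in place.
        x0, y0 = center
        ys, xs = np.mgrid[: heatmap.shape[0], : heatmap.shape[1]]
        g = np.exp(-((xs - x0) ** 2 + (ys - y0) ** 2) / (2.0 * sigma**2))
        np.maximum(heatmap, g, out=heatmap)

    y_2d_det = np.zeros((128, 512), dtype=np.float32)  # (freq bins, time steps)
    draw_gaussian(y_2d_det, (100, 40), sigma=2.0)      # one call at x=100, y=40
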
@@ -235,8 +220,8 @@ def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:

 def warp_spec_aug(
     spec: torch.Tensor,
-    ann: AnnotationGroup,
-    params: dict,
+    ann: AudioLoaderAnnotationGroup,
+    stretch_squeeze_delta: float,
 ) -> torch.Tensor:
     """Warp spectrogram by randomly stretching and squeezing.

@@ -247,8 +232,8 @@
     ann: AnnotationGroup
         Annotation group for the spectrogram. Must be provided to sync
         the start and stop times with the spectrogram after warping.
-    params: dict
-        Parameters for the augmentation.
+    stretch_squeeze_delta: float
+        Maximum amount to stretch or squeeze the spectrogram.

     Returns
     -------
@@ -259,11 +244,10 @@
     -----
     This function modifies the annotation group in place.
     """
-    # This is messy
     # Augment spectrogram by randomly stretch and squeezing
     # NOTE this also changes the start and stop time in place

-    delta = params["stretch_squeeze_delta"]
+    delta = stretch_squeeze_delta
     op_size = (spec.shape[1], spec.shape[2])
     resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
     resize_amt = int(spec.shape[2] * resize_fract_r)
@@ -277,7 +261,7 @@
                     dtype=spec.dtype,
                 ),
             ),
-            2,
+            dim=2,
         )
     else:
         spec_r = spec[:, :, :resize_amt]
@@ -297,7 +281,10 @@
     return spec


-def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
+def mask_time_aug(
+    spec: torch.Tensor,
+    mask_max_time_perc: float,
+) -> torch.Tensor:
     """Mask out random blocks of time.

     Will randomly mask out a block of time in the spectrogram. The block
@@ -308,8 +295,8 @@ def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
     ----------
     spec: torch.Tensor
         Spectrogram to mask.
-    params: dict
-        Parameters for the augmentation.
+    mask_max_time_perc: float
+        Maximum percentage of time to mask out.

     Returns
     -------
@@ -324,14 +311,17 @@ def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
     Recognition
     """
     fm = torchaudio.transforms.TimeMasking(
-        int(spec.shape[1] * params["mask_max_time_perc"])
+        int(spec.shape[1] * mask_max_time_perc)
     )
     for _ in range(np.random.randint(1, 4)):
         spec = fm(spec)
     return spec


-def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
+def mask_freq_aug(
+    spec: torch.Tensor,
+    mask_max_freq_perc: float,
+) -> torch.Tensor:
     """Mask out random blocks of frequency.

     Will randomly mask out a block of frequency in the spectrogram. The block
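
`torchaudio.transforms.TimeMasking(time_mask_param=N)` zeroes one random block of at most N consecutive steps along the last (time) axis per call, so the 1-3 applications above lay down several independent masks, as in SpecAugment. A quick illustration:

    import torch
    import torchaudio

    spec = torch.rand(1, 128, 512)  # (channel, freq, time)
    masker = torchaudio.transforms.TimeMasking(time_mask_param=26)
    masked = masker(spec)  # one block of up to 26 time steps set to zero
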
@@ -342,8 +332,8 @@ def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
     ----------
     spec: torch.Tensor
         Spectrogram to mask.
-    params: dict
-        Parameters for the augmentation.
+    mask_max_freq_perc: float
+        Maximum percentage of frequency to mask out.

     Returns
     -------
@@ -358,41 +348,48 @@ def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
     Recognition
     """
     fm = torchaudio.transforms.FrequencyMasking(
-        int(spec.shape[1] * params["mask_max_freq_perc"])
+        int(spec.shape[1] * mask_max_freq_perc)
     )
     for _ in range(np.random.randint(1, 4)):
         spec = fm(spec)
     return spec


-def scale_vol_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
+def scale_vol_aug(
+    spec: torch.Tensor,
+    spec_amp_scaling: float,
+) -> torch.Tensor:
     """Scale the volume of the spectrogram.

     Parameters
     ----------
     spec: torch.Tensor
         Spectrogram to scale.
-    params: dict
-        Parameters for the augmentation.
+    spec_amp_scaling: float
+        Maximum scaling factor.

     Returns
     -------
     torch.Tensor
     """
-    return spec * np.random.random() * params["spec_amp_scaling"]
+    return spec * np.random.random() * spec_amp_scaling


-def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
+def echo_aug(
+    audio: np.ndarray,
+    sampling_rate: float,
+    echo_max_delay: float,
+) -> np.ndarray:
     """Add echo to audio.

     Parameters
     ----------
     audio: np.ndarray
         Audio to add echo to.
-    sampling_rate: int
+    sampling_rate: float
         Sampling rate of the audio.
-    params: dict
-        Parameters for the augmentation.
+    echo_max_delay: float
+        Maximum delay of the echo in seconds.

     Returns
     -------
@@ -400,7 +397,7 @@ def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
         Audio with echo added.
     """
     sample_offset = (
-        int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1
+        int(echo_max_delay * np.random.random() * sampling_rate) + 1
     )
     audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
     return audio
@@ -408,9 +405,14 @@

 def resample_aug(
     audio: np.ndarray,
-    sampling_rate: int,
-    params: dict,
-) -> Tuple[np.ndarray, int, float]:
+    sampling_rate: float,
+    fft_win_length: float,
+    fft_overlap: float,
+    resize_factor: float,
+    spec_divide_factor: float,
+    spec_train_width: int,
+    aug_sampling_rates: List[int],
+) -> Tuple[np.ndarray, float, float]:
     """Resample audio augmentation.

     Will resample the audio to a random sampling rate from the list of
@@ -420,23 +422,32 @@
     ----------
     audio: np.ndarray
         Audio to resample.
-    sampling_rate: int
+    sampling_rate: float
         Original sampling rate of the audio.
-    params: dict
-        Parameters for the augmentation. Includes the list of sampling rates
-        to choose from for resampling in `aug_sampling_rates`.
+    fft_win_length: float
+        Length of the FFT window in seconds.
+    fft_overlap: float
+        Amount of overlap between FFT windows.
+    resize_factor: float
+        Factor to resize the spectrogram by.
+    spec_divide_factor: float
+        Factor to divide the spectrogram by.
+    spec_train_width: int
+        Width of the spectrogram.
+    aug_sampling_rates: List[int]
+        List of sampling rates to resample to.

     Returns
     -------
     audio : np.ndarray
         Resampled audio.
-    sampling_rate : int
+    sampling_rate : float
         New sampling rate.
     duration : float
         Duration of the audio in seconds.
     """
     sampling_rate_old = sampling_rate
-    sampling_rate = np.random.choice(params["aug_sampling_rates"])
+    sampling_rate = np.random.choice(aug_sampling_rates)
     audio = librosa.resample(
         audio,
         orig_sr=sampling_rate_old,
@@ -447,11 +458,11 @@
     audio = au.pad_audio(
         audio,
         sampling_rate,
-        params["fft_win_length"],
-        params["fft_overlap"],
-        params["resize_factor"],
-        params["spec_divide_factor"],
-        params["spec_train_width"],
+        fft_win_length,
+        fft_overlap,
+        resize_factor,
+        spec_divide_factor,
+        spec_train_width,
     )
     duration = audio.shape[0] / float(sampling_rate)
     return audio, sampling_rate, duration
@@ -459,28 +470,28 @@

 def resample_audio(
     num_samples: int,
-    sampling_rate: int,
+    sampling_rate: float,
     audio2: np.ndarray,
-    sampling_rate2: int,
-) -> Tuple[np.ndarray, int]:
+    sampling_rate2: float,
+) -> Tuple[np.ndarray, float]:
     """Resample audio.

     Parameters
     ----------
     num_samples: int
         Expected number of samples for the output audio.
-    sampling_rate: int
+    sampling_rate: float
         Original sampling rate of the audio.
     audio2: np.ndarray
         Audio to resample.
-    sampling_rate2: int
+    sampling_rate2: float
         Target sampling rate of the audio.

     Returns
     -------
     audio2 : np.ndarray
         Resampled audio.
-    sampling_rate2 : int
+    sampling_rate2 : float
         New sampling rate.
     """
     # resample to target sampling rate
@ -509,12 +520,12 @@ def resample_audio(
|
||||
|
||||
def combine_audio_aug(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: int,
|
||||
ann: AnnotationGroup,
|
||||
sampling_rate: float,
|
||||
ann: AudioLoaderAnnotationGroup,
|
||||
audio2: np.ndarray,
|
||||
sampling_rate2: int,
|
||||
ann2: AnnotationGroup,
|
||||
) -> Tuple[np.ndarray, AnnotationGroup]:
|
||||
sampling_rate2: float,
|
||||
ann2: AudioLoaderAnnotationGroup,
|
||||
) -> Tuple[np.ndarray, AudioLoaderAnnotationGroup]:
|
||||
"""Combine two audio files.
|
||||
|
||||
Will combine two audio files by resampling them to the same sampling rate
|
||||
@ -570,7 +581,9 @@ def combine_audio_aug(
|
||||
# from different individuals
|
||||
if kk == "individual_ids":
|
||||
if (ann[kk] > -1).sum() > 0:
|
||||
ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1
|
||||
ann2[kk][ann2[kk] > -1] += (
|
||||
np.max(ann[kk][ann[kk] > -1]) + 1
|
||||
)
|
||||
|
||||
if (kk != "class_id_file") and (kk != "annotated"):
|
||||
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
|
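The `individual_ids` offset keeps IDs from the two combined clips disjoint; a small worked example (toy arrays, assuming -1 marks "no individual"):

import numpy as np

ids1 = np.array([0, 1, -1])  # first clip: individuals 0 and 1
ids2 = np.array([0, 0, -1])  # second clip reuses ID 0
ids2[ids2 > -1] += np.max(ids1[ids1 > -1]) + 1
# ids2 is now [2, 2, -1]: no collision with the first clip's IDs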
@ -579,7 +592,8 @@ def combine_audio_aug(


def _prepare_annotation(
annotation: Annotation, class_names: List[str]
annotation: Annotation,
class_names: List[str],
) -> Annotation:
try:
class_id = class_names.index(annotation["class"])
@ -598,7 +612,7 @@ def _prepare_annotation(


def _prepare_file_annotation(
annotation: FileAnnotations,
annotation: FileAnnotation,
class_names: List[str],
classes_to_ignore: List[str],
) -> AudioLoaderAnnotationGroup:
@ -626,7 +640,9 @@ def _prepare_file_annotation(
"end_times": np.array([ann["end_time"] for ann in annotations]),
"high_freqs": np.array([ann["high_freq"] for ann in annotations]),
"low_freqs": np.array([ann["low_freq"] for ann in annotations]),
"class_ids": np.array([ann.get("class_id", -1) for ann in annotations]),
"class_ids": np.array(
[ann.get("class_id", -1) for ann in annotations]
),
"individual_ids": np.array([ann["individual"] for ann in annotations]),
"class_id_file": class_id_file,
}
@ -639,15 +655,15 @@ class AudioLoader(torch.utils.data.Dataset):

def __init__(
self,
data_anns_ip: List[FileAnnotations],
params,
data_anns_ip: List[FileAnnotation],
params: AudioLoaderParameters,
dataset_name: Optional[str] = None,
is_train: bool = False,
return_spec_for_viz: bool = False,
):
self.is_train: bool = is_train
self.params: dict = params
self.return_spec_for_viz: bool = False

self.is_train = is_train
self.params = params
self.return_spec_for_viz = return_spec_for_viz
self.data_anns: List[AudioLoaderAnnotationGroup] = [
_prepare_file_annotation(
ann,
@ -657,61 +673,6 @@ class AudioLoader(torch.utils.data.Dataset):
for ann in data_anns_ip
]

# for ii in range(len(data_anns_ip)):
# dd = copy.deepcopy(data_anns_ip[ii])
#
# # filter out unused annotation here
# filtered_annotations = []
# for ii, aa in enumerate(dd["annotation"]):
# if "individual" in aa.keys():
# aa["individual"] = int(aa["individual"])
#
# # if only one call labeled it has to be from the same
# # individual
# if len(dd["annotation"]) == 1:
# aa["individual"] = 0
#
# # convert class name into class label
# if aa["class"] in self.params["class_names"]:
# aa["class_id"] = self.params["class_names"].index(
# aa["class"]
# )
# else:
# aa["class_id"] = -1
#
# if aa["class"] not in self.params["classes_to_ignore"]:
# filtered_annotations.append(aa)
#
# dd["annotation"] = filtered_annotations
# dd["start_times"] = np.array(
# [aa["start_time"] for aa in dd["annotation"]]
# )
# dd["end_times"] = np.array(
# [aa["end_time"] for aa in dd["annotation"]]
# )
# dd["high_freqs"] = np.array(
# [float(aa["high_freq"]) for aa in dd["annotation"]]
# )
# dd["low_freqs"] = np.array(
# [float(aa["low_freq"]) for aa in dd["annotation"]]
# )
# dd["class_ids"] = np.array(
# [aa["class_id"] for aa in dd["annotation"]]
# ).astype(np.int32)
# dd["individual_ids"] = np.array(
# [aa["individual"] for aa in dd["annotation"]]
# ).astype(np.int32)
#
# # file level class name
# dd["class_id_file"] = -1
# if "class_name" in dd.keys():
# if dd["class_name"] in self.params["class_names"]:
# dd["class_id_file"] = self.params["class_names"].index(
# dd["class_name"]
# )
#
# self.data_anns.append(dd)

ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
self.max_num_anns = 2 * np.max(
ann_cnt
@ -730,7 +691,7 @@ class AudioLoader(torch.utils.data.Dataset):
def get_file_and_anns(
self,
index: Optional[int] = None,
) -> Tuple[np.ndarray, int, float, AudioLoaderAnnotationGroup]:
) -> Tuple[np.ndarray, float, float, AudioLoaderAnnotationGroup]:
"""Get an audio file and its annotations.

Parameters
@ -742,7 +703,7 @@ class AudioLoader(torch.utils.data.Dataset):
-------
audio_raw : np.ndarray
Loaded audio file.
sampling_rate : int
sampling_rate : float
Sampling rate of the audio file.
duration : float
Duration of the audio file in seconds.
@ -837,7 +798,7 @@ class AudioLoader(torch.utils.data.Dataset):
(
audio2,
sampling_rate2,
duration2,
_,
ann2,
) = self.get_file_and_anns()
audio, ann = combine_audio_aug(
@ -846,7 +807,11 @@ class AudioLoader(torch.utils.data.Dataset):

# simulate echo by adding delayed copy of the file
if np.random.random() < self.params["aug_prob"]:
audio = echo_aug(audio, sampling_rate, self.params)
audio = echo_aug(
audio,
sampling_rate,
echo_max_delay=self.params["echo_max_delay"],
)

# resample the audio
# if np.random.random() < self.params["aug_prob"]:
@ -855,11 +820,16 @@ class AudioLoader(torch.utils.data.Dataset):
# )

# create spectrogram
spec, spec_for_viz = au.generate_spectrogram(
spec = au.generate_spectrogram(
audio,
sampling_rate,
self.params,
self.return_spec_for_viz,
fft_win_length=self.params["fft_win_length"],
fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
min_freq=self.params["min_freq"],
spec_scale=self.params["spec_scale"],
denoise_spec_avg=self.params["denoise_spec_avg"],
max_scale_spec=self.params["max_scale_spec"],
)
rsf = self.params["resize_factor"]
spec_op_shape = (
@ -879,20 +849,29 @@ class AudioLoader(torch.utils.data.Dataset):
# augment spectrogram
if self.is_train and self.params["augment_at_train"]:
if np.random.random() < self.params["aug_prob"]:
spec = scale_vol_aug(spec, self.params)
spec = scale_vol_aug(
spec,
spec_amp_scaling=self.params["spec_amp_scaling"],
)

if np.random.random() < self.params["aug_prob"]:
spec = warp_spec_aug(
spec,
ann,
self.params,
stretch_squeeze_delta=self.params["stretch_squeeze_delta"],
)

if np.random.random() < self.params["aug_prob"]:
spec = mask_time_aug(spec, self.params)
spec = mask_time_aug(
spec,
mask_max_time_perc=self.params["mask_max_time_perc"],
)

if np.random.random() < self.params["aug_prob"]:
spec = mask_freq_aug(spec, self.params)
spec = mask_freq_aug(
spec,
mask_max_freq_perc=self.params["mask_max_freq_perc"],
)

outputs = {}
outputs["spec"] = spec
@ -911,7 +890,13 @@ class AudioLoader(torch.utils.data.Dataset):
spec_op_shape,
sampling_rate,
ann,
self.params,
class_names=self.params["class_names"],
fft_win_length=self.params["fft_win_length"],
fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
min_freq=self.params["min_freq"],
resize_factor=self.params["resize_factor"],
target_sigma=self.params["target_sigma"],
)

# hack to get around requirement that all vectors are the same length
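With typed parameters in place, the loader is driven as before; a hedged sketch (`train_anns` and `params` are assumed to exist with the shapes described above):

loader = AudioLoader(train_anns, params, is_train=True)
sample = loader[0]  # runs the load -> augment -> spectrogram pipeline shown above
print(sample["spec"].shape)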
@ -1,8 +1,13 @@
from typing import Optional

import torch
import torch.nn.functional as F


def bbox_size_loss(pred_size, gt_size):
def bbox_size_loss(
pred_size: torch.Tensor,
gt_size: torch.Tensor,
) -> torch.Tensor:
"""
Bounding box size loss. Only compute loss where there is a bounding box.
"""
@ -12,7 +17,12 @@ def bbox_size_loss(pred_size, gt_size):
)


def focal_loss(pred, gt, weights=None, valid_mask=None):
def focal_loss(
pred: torch.Tensor,
gt: torch.Tensor,
weights: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
pred (batch x c x h x w)
@ -52,7 +62,11 @@ def focal_loss(pred, gt, weights=None, valid_mask=None):
return loss
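A quick smoke test of the typed loss (dummy tensors; shapes follow the `batch x c x h x w` convention in the docstring):

import torch

pred = torch.rand(2, 1, 64, 128).clamp(1e-4, 1 - 1e-4)  # predicted heatmap in (0, 1)
gt = torch.zeros(2, 1, 64, 128)  # ground-truth heatmap
gt[0, 0, 10, 20] = 1.0  # one positive keypoint
loss = focal_loss(pred, gt)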


def mse_loss(pred, gt, weights=None, valid_mask=None):
def mse_loss(
pred: torch.Tensor,
gt: torch.Tensor,
valid_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Mean squared error loss.
"""

@ -5,6 +5,7 @@ import warnings
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR

import batdetect2.detector.post_process as pp
@ -29,7 +30,7 @@ def save_images_batch(model, data_loader, params):

ind = 0  # first image in each batch
with torch.no_grad():
for batch_idx, inputs in enumerate(data_loader):
for inputs in data_loader:
data = inputs["spec"].to(params["device"])
outputs = model(data)

@ -81,7 +82,12 @@ def save_image(


def loss_fun(
outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
outputs,
gt_det,
gt_size,
gt_class,
det_criterion,
params,
):
# detection loss
loss = params["det_loss_weight"] * det_criterion(
@ -104,7 +110,13 @@ def loss_fun(


def train(
model, epoch, data_loader, det_criterion, optimizer, scheduler, params
model,
epoch,
data_loader,
det_criterion,
optimizer,
scheduler,
params,
):
model.train()

@ -309,7 +321,7 @@ def select_model(params):
resize_factor=params["resize_factor"],
)
else:
print("No valid network specified")
raise ValueError("No valid network specified")
return model


@ -319,9 +331,9 @@ def main():
params = parameters.get_params(True)

if torch.cuda.is_available():
params["device"] = "cuda"
params.device = "cuda"
else:
params["device"] = "cpu"
params.device = "cpu"

# setup arg parser and populate it with existing parameters - will not work with lists
parser = argparse.ArgumentParser()
@ -349,13 +361,16 @@ def main():
default="Rhinolophus ferrumequinum;Rhinolophus hipposideros",
help='Will set low and high frequency the same for these classes. Separate names with ";"',
)

for key, val in params.items():
parser.add_argument("--" + key, type=type(val), default=val)
params = vars(parser.parse_args())

# save notes file
if params["notes"] != "":
tu.write_notes_file(params["experiment"] + "notes.txt", params["notes"])
tu.write_notes_file(
params["experiment"] + "notes.txt", params["notes"]
)

# load the training and test meta data - there are different splits defined
train_sets, test_sets = ts.get_train_test_data(
@ -374,15 +389,11 @@ def main():
for tt in train_sets:
print(tt["ann_path"])
classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
(
data_train,
params["class_names"],
params["class_inv_freq"],
) = tu.load_set_of_anns(
data_train = tu.load_set_of_anns(
train_sets,
classes_to_ignore,
params["events_of_interest"],
params["convert_to_genus"],
classes_to_ignore=classes_to_ignore,
events_of_interest=params["events_of_interest"],
convert_to_genus=params["convert_to_genus"],
)
params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
params["class_names"]
@ -415,11 +426,12 @@ def main():
print("\nTesting on:")
for tt in test_sets:
print(tt["ann_path"])
data_test, _, _ = tu.load_set_of_anns(

data_test = tu.load_set_of_anns(
test_sets,
classes_to_ignore,
params["events_of_interest"],
params["convert_to_genus"],
classes_to_ignore=classes_to_ignore,
events_of_interest=params["events_of_interest"],
convert_to_genus=params["convert_to_genus"],
)
data_train = tu.remove_dupes(data_train, data_test)
test_dataset = adl.AudioLoader(data_test, params, is_train=False)
@ -447,10 +459,13 @@ def main():
scheduler = CosineAnnealingLR(
optimizer, params["num_epochs"] * len(train_loader)
)

if params["train_loss"] == "mse":
det_criterion = losses.mse_loss
elif params["train_loss"] == "focal":
det_criterion = losses.focal_loss
else:
raise ValueError("No valid loss specified")

# save parameters to file
with open(params["experiment"] + "params.json", "w") as da:

@ -1,28 +1,37 @@
import glob
import json
import os
import random
from collections import Counter
from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple

import numpy as np

from batdetect2 import types

def write_notes_file(file_name, text):

def write_notes_file(file_name: str, text: str):
with open(file_name, "a") as da:
da.write(text + "\n")


def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path):
ddict = {
def get_blank_dataset_dict(
dataset_name: str,
is_test: bool,
ann_path: str,
wav_path: str,
) -> types.DatasetDict:
return {
"dataset_name": dataset_name,
"is_test": is_test,
"is_binary": False,
"ann_path": ann_path,
"wav_path": wav_path,
}
return ddict
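A hedged usage sketch of the typed helper (paths are placeholders):

dataset = get_blank_dataset_dict(
    dataset_name="example_set",
    is_test=False,
    ann_path="/data/anns/example.json",
    wav_path="/data/audio/",
)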


def get_short_class_names(class_names, str_len=3):
def get_short_class_names(
class_names: List[str],
str_len: int = 3,
) -> List[str]:
class_names_short = []
for cc in class_names:
class_names_short.append(
@ -31,7 +40,10 @@ def get_short_class_names(class_names, str_len=3):
return class_names_short


def remove_dupes(data_train, data_test):
def remove_dupes(
data_train: List[types.FileAnnotation],
data_test: List[types.FileAnnotation],
) -> List[types.FileAnnotation]:
test_ids = [dd["id"] for dd in data_test]
data_train_prune = []
for aa in data_train:
@ -43,14 +55,16 @@ def remove_dupes(data_train, data_test):
return data_train_prune


def get_genus_mapping(class_names):
def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:
genus_names, genus_mapping = np.unique(
[cc.split(" ")[0] for cc in class_names], return_inverse=True
)
return genus_names.tolist(), genus_mapping.tolist()
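A small worked example of the genus mapping (species names illustrative):

names, mapping = get_genus_mapping(
    ["Myotis daubentonii", "Myotis nattereri", "Pipistrellus pipistrellus"]
)
# names   -> ["Myotis", "Pipistrellus"]
# mapping -> [0, 0, 1]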


def standardize_low_freq(data, class_of_interest):
def standardize_low_freq(
data: List[types.FileAnnotation], class_of_interest: str,
) -> List[types.FileAnnotation]:
# address the issue of highly variable low frequency annotations
# this often happens for constant frequency calls
# for the class of interest sets the low and high freq to be the dataset mean
@ -62,8 +76,8 @@ def standardize_low_freq(data, class_of_interest):
low_freqs.append(aa["low_freq"])
high_freqs.append(aa["high_freq"])

low_mean = np.mean(low_freqs)
high_mean = np.mean(high_freqs)
low_mean = float(np.mean(low_freqs))
high_mean = float(np.mean(high_freqs))
assert low_mean < high_mean

print("\nStandardizing low and high frequency for:")
@ -83,115 +97,148 @@ def standardize_low_freq(data, class_of_interest):
return data

def load_set_of_anns(
data,
classes_to_ignore=[],
events_of_interest=None,
convert_to_genus=False,
verbose=True,
list_of_anns=False,
filter_issues=False,
name_replace=False,
):
def format_annotation(
annotation: types.FileAnnotation,
events_of_interest: Optional[List[str]] = None,
name_replace: Optional[Dict[str, str]] = None,
convert_to_genus: bool = False,
classes_to_ignore: Optional[List[str]] = None,
) -> types.FileAnnotation:
formatted = []
for aa in annotation["annotation"]:
if (
events_of_interest is not None
and aa["event"] not in events_of_interest
):
# Omit events that are not of interest
continue

# load the annotations
anns = []
if list_of_anns:
# path to list of individual json files
anns.extend(load_anns_from_path(data["ann_path"], data["wav_path"]))
else:
# dictionary of datasets
for dd in data:
anns.extend(load_anns(dd["ann_path"], dd["wav_path"]))
# remove leading and trailing spaces
class_name = aa["class"].strip()

# discarding unannotated files
anns = [aa for aa in anns if aa["annotated"] is True]

# filter files that have annotation issues - if the input is a dictionary of
# datasets, this will likely have already been done
if filter_issues:
anns = [aa for aa in anns if aa["issues"] is False]

# check for some basic formatting errors with class names
for ann in anns:
for aa in ann["annotation"]:
aa["class"] = aa["class"].strip()

# only load specified events - i.e. types of calls
if events_of_interest is not None:
for ann in anns:
filtered_events = []
for aa in ann["annotation"]:
if aa["event"] in events_of_interest:
filtered_events.append(aa)
ann["annotation"] = filtered_events

# change class names
if name_replace is not None:
# replace_names will be a dictionary mapping input name to output
if type(name_replace) is dict:
for ann in anns:
for aa in ann["annotation"]:
if aa["class"] in name_replace:
aa["class"] = name_replace[aa["class"]]
class_name = name_replace.get(class_name, class_name)

# convert everything to genus name
if convert_to_genus:
for ann in anns:
for aa in ann["annotation"]:
aa["class"] = aa["class"].split(" ")[0]
# convert everything to genus name
class_name = class_name.split(" ")[0]

# get unique class names
class_names_all = []
for ann in anns:
for aa in ann["annotation"]:
if aa["class"] not in classes_to_ignore:
class_names_all.append(aa["class"])
# NOTE: It is important to acknowledge that the class names filtering
# is done after the name replacement and the conversion to
# genus name. This allows filtering converted genus names and names
# that were replaced with a name that should be ignored.
if classes_to_ignore is not None and class_name in classes_to_ignore:
# Omit annotations with ignored classes
continue

class_names, class_cnts = np.unique(class_names_all, return_counts=True)
class_inv_freq = class_cnts.sum() / (
len(class_names) * class_cnts.astype(np.float32)
formatted.append(
{
**aa,
"class": class_name,
}
)

if verbose:
return {
**annotation,
"annotation": formatted,
}

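A hedged sketch of the new per-file formatting path (annotation dict abbreviated to the keys the function reads):

ann = {
    "id": "rec1.wav",
    "annotation": [{"class": " Myotis daubentonii ", "event": "Echolocation"}],
}
out = format_annotation(
    ann,
    events_of_interest=["Echolocation"],
    convert_to_genus=True,
)
# out["annotation"][0]["class"] == "Myotis"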
def get_class_names(
data: List[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None,
) -> Tuple[Counter[str], List[float]]:
"""Extracts class names and their inverse frequencies.

Parameters
----------
data
A list of file annotations, where each annotation contains a list of
sound events with associated class names.
classes_to_ignore
A list of class names to ignore.

Returns
-------
class_names
A list of unique class names extracted from the annotations.
class_inv_freq
List of inverse frequencies of each class name in the provided data.
"""
if classes_to_ignore is None:
classes_to_ignore = []

class_names_list: List[str] = []
for annotation in data:
for sound_event in annotation["annotation"]:
if sound_event["class"] in classes_to_ignore:
continue

class_names_list.append(sound_event["class"])

counts = Counter(class_names_list)
mean_counts = float(np.mean(list(counts.values())))
return counts, [mean_counts / counts[cc] for cc in class_names_list]
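A worked example of the counts and inverse frequencies (toy annotations; the mean class count here is 1.5):

data = [
    {"annotation": [{"class": "A"}, {"class": "A"}, {"class": "B"}]},
]
counts, inv_freq = get_class_names(data)
# counts   -> Counter({"A": 2, "B": 1})
# inv_freq -> [0.75, 0.75, 1.5]  (one entry per retained sound event)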


def report_class_counts(class_names: Counter[str]):
print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for cc in range(len(class_names)):
print(
str(cc).ljust(5)
+ class_names[cc].ljust(str_len)
+ str(class_cnts[cc])
for index, (class_name, count) in enumerate(class_names.most_common()):
print(f"{index:<5}{class_name:<{str_len}}{count}")


def load_set_of_anns(
data: List[types.DatasetDict],
*,
convert_to_genus: bool = False,
filter_issues: bool = False,
events_of_interest: Optional[List[str]] = None,
classes_to_ignore: Optional[List[str]] = None,
name_replace: Optional[Dict[str, str]] = None,
) -> List[types.FileAnnotation]:
# load the annotations
anns = []

# dictionary of datasets
for dataset in data:
for ann in load_anns(dataset["ann_path"], dataset["wav_path"]):
if not ann["annotated"]:
# Omit unannotated files
continue

if filter_issues and ann["issues"]:
# Omit files with annotation issues
continue

anns.append(
format_annotation(
ann,
events_of_interest=events_of_interest,
name_replace=name_replace,
convert_to_genus=convert_to_genus,
classes_to_ignore=classes_to_ignore,
)
)

if len(classes_to_ignore) == 0:
return anns
else:
return anns, class_names.tolist(), class_inv_freq.tolist()

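A hedged sketch of the new keyword-only API, combined with the dataset helper above (paths and class lists are placeholders):

dataset = get_blank_dataset_dict(
    "example_set", False, "/data/anns/example.json", "/data/audio/"
)
anns = load_set_of_anns(
    [dataset],
    filter_issues=True,
    events_of_interest=["Echolocation"],
    classes_to_ignore=["Unknown"],
)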
def load_anns(ann_file_name, raw_audio_dir):
with open(ann_file_name) as da:
anns = json.load(da)

for aa in anns:
aa["file_path"] = raw_audio_dir + aa["id"]

return anns


def load_anns_from_path(ann_file_dir, raw_audio_dir):
files = glob.glob(ann_file_dir + "*.json")
anns = []
for ff in files:
with open(ff) as da:
ann = json.load(da)
ann["file_path"] = raw_audio_dir + ann["id"]
anns.append(ann)
def load_anns(
ann_dir: str,
raw_audio_dir: str,
) -> Generator[types.FileAnnotation, None, None]:
for path in Path(ann_dir).rglob("*.json"):
with open(path) as fp:
file_annotation = json.load(fp)

return anns
file_annotation["file_path"] = raw_audio_dir + file_annotation["id"]
yield file_annotation
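Because `load_anns` is now a lazy generator that recursively globs `*.json` files, callers can stream annotations instead of materialising a list first; a sketch with placeholder directories:

for file_annotation in load_anns("/data/anns/", "/data/audio/"):
    print(file_annotation["file_path"])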
class AverageMeter(object):
"""Computes and stores the average and current value"""
class AverageMeter:
"""Computes and stores the average and current value."""

def __init__(self):
self.reset()
@ -1,5 +1,5 @@
"""Types used in the code base."""
from typing import List, NamedTuple, Optional, Union
from typing import Any, List, NamedTuple, Optional

import numpy as np
import torch
@ -26,8 +26,7 @@ __all__ = [
"Annotation",
"DetectionModel",
"FeatureExtractionParameters",
"FeatureExtractor",
"FileAnnotations",
"FileAnnotation",
"ModelOutput",
"ModelParameters",
"NonMaximumSuppressionConfig",
@ -94,7 +93,7 @@ class ModelParameters(TypedDict):
"""Resize factor."""

class_names: List[str]
"""Class names. The model is trained to detect these classes."""
"""Class names.

The model is trained to detect these classes.
"""


DictWithClass = TypedDict("DictWithClass", {"class": str})
@ -103,8 +105,8 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})
class Annotation(DictWithClass):
"""Format of annotations.

This is the format of a single annotation as expected by the annotation
tool.
This is the format of a single annotation as expected by the
annotation tool.
"""

start_time: float
@ -113,10 +115,10 @@ class Annotation(DictWithClass):
end_time: float
"""End time in seconds."""

low_freq: int
low_freq: float
"""Low frequency in Hz."""

high_freq: int
high_freq: float
"""High frequency in Hz."""

class_prob: float
@ -135,7 +137,7 @@ class Annotation(DictWithClass):
"""Numeric ID for the class of the annotation."""


class FileAnnotations(TypedDict):
class FileAnnotation(TypedDict):
"""Format of results.

This is the format of the results expected by the annotation tool.
@ -157,7 +159,7 @@ class FileAnnotations(TypedDict):
"""Time expansion factor."""

class_name: str
"""Class predicted at file level"""
"""Class predicted at file level."""

notes: str
"""Notes of file."""
@ -169,7 +171,7 @@ class FileAnnotations(TypedDict):
class RunResults(TypedDict):
"""Run results."""

pred_dict: FileAnnotations
pred_dict: FileAnnotation
"""Predictions in the format expected by the annotation tool."""

spec_feats: NotRequired[List[np.ndarray]]
@ -394,9 +396,9 @@ class PredictionResults(TypedDict):
class DetectionModel(Protocol):
"""Protocol for detection models.

This protocol is used to define the interface for the detection models.
This allows us to use the same code for training and inference, even
though the models are different.
This protocol is used to define the interface for the detection
models. This allows us to use the same code for training and
inference, even though the models are different.
"""

num_classes: int
@ -416,16 +418,14 @@ class DetectionModel(Protocol):

def forward(
self,
ip: torch.Tensor,
return_feats: bool = False,
spec: torch.Tensor,
) -> ModelOutput:
"""Forward pass of the model."""
...

def __call__(
self,
ip: torch.Tensor,
return_feats: bool = False,
spec: torch.Tensor,
) -> ModelOutput:
"""Forward pass of the model."""
...
@ -490,8 +490,10 @@ class HeatmapParameters(TypedDict):
"""Maximum frequency to consider in Hz."""

target_sigma: float
"""Sigma for the Gaussian kernel. Controls the width of the points in
the heatmap."""
"""Sigma for the Gaussian kernel.

Controls the width of the points in the heatmap.
"""

class AnnotationGroup(TypedDict):
@ -522,10 +524,10 @@ class AnnotationGroup(TypedDict):
annotated: NotRequired[bool]
"""Whether the annotation group is complete or not.

Usually annotation groups are associated to a single
audio clip. If the annotation group is complete, it means that all
relevant sound events have been annotated. If it is not complete, it
means that some sound events might not have been annotated.
Usually annotation groups are associated to a single audio clip. If
the annotation group is complete, it means that all relevant sound
events have been annotated. If it is not complete, it means that
some sound events might not have been annotated.
"""

x_inds: NotRequired[np.ndarray]
@ -535,12 +537,88 @@ class AnnotationGroup(TypedDict):
"""Y coordinate of the annotations in the spectrogram."""

class AudioLoaderAnnotationGroup(AnnotationGroup, FileAnnotations):
class AudioLoaderAnnotationGroup(TypedDict):
"""Group of annotation items for the training audio loader.

This class is used to store the annotations for the training audio
loader. It combines the fields of `AnnotationGroup` and `FileAnnotation`.
"""

id: str
duration: float
issues: bool
file_path: str
time_exp: float
class_name: str
notes: str
start_times: np.ndarray
end_times: np.ndarray
low_freqs: np.ndarray
high_freqs: np.ndarray
class_ids: np.ndarray
individual_ids: np.ndarray
x_inds: np.ndarray
y_inds: np.ndarray
annotation: List[Annotation]
annotated: bool
class_id_file: int
"""ID of the class of the file."""


class AudioLoaderParameters(TypedDict):
class_names: List[str]
classes_to_ignore: List[str]
target_samp_rate: int
scale_raw_audio: bool
fft_win_length: float
fft_overlap: float
spec_train_width: int
resize_factor: float
spec_divide_factor: int
augment_at_train: bool
augment_at_train_combine: bool
aug_prob: float
spec_height: int
echo_max_delay: float
spec_amp_scaling: float
stretch_squeeze_delta: float
mask_max_time_perc: float
mask_max_freq_perc: float
max_freq: float
min_freq: float
spec_scale: str
denoise_spec_avg: bool
max_scale_spec: bool
target_sigma: float

class FeatureExtractor(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
prediction: Prediction,
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
...
|
||||
|
||||
|
||||
class DatasetDict(TypedDict):
|
||||
"""Dataset dictionary.
|
||||
|
||||
This is the format of the dictionary that contains the dataset
|
||||
information.
|
||||
"""
|
||||
|
||||
dataset_name: str
|
||||
"""Name of the dataset."""
|
||||
|
||||
is_test: bool
|
||||
"""Whether the dataset is a test set."""
|
||||
|
||||
is_binary: bool
|
||||
"""Whether the dataset is binary."""
|
||||
|
||||
ann_path: str
|
||||
"""Path to the annotations."""
|
||||
|
||||
wav_path: str
|
||||
"""Path to the audio files."""
|
||||
|
@ -1,13 +1,11 @@
import warnings
from typing import Optional, Tuple
from typing import Optional, Tuple, Union, overload

import librosa
import librosa.core.spectrum
import numpy as np
import torch

from . import wavfile

__all__ = [
"load_audio",
"generate_spectrogram",
@ -15,113 +13,171 @@ __all__ = [
]


def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
@overload
def time_to_x_coords(
time_in_file: np.ndarray,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
) -> np.ndarray:
...


@overload
def time_to_x_coords(
time_in_file: float,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
) -> float:
...


def time_to_x_coords(
time_in_file: Union[float, np.ndarray],
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
) -> Union[float, np.ndarray]:
nfft = np.floor(fft_win_length * sampling_rate)
noverlap = np.floor(fft_overlap * nfft)
return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
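As a quick sanity check of the mapping (illustrative numbers): with `fft_win_length = 0.002` and `sampling_rate = 256000`, nfft = 512, and at `fft_overlap = 0.75`, noverlap = 384, so

x = time_to_x_coords(0.01, 256000, 0.002, 0.75)
# (0.01 * 256000 - 384) / (512 - 384) = 2176 / 128 = 17.0 spectrogram columns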


# NOTE this is also defined in post_process
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
def x_coords_to_time(
x_pos: float,
sampling_rate: int,
fft_win_length: float,
fft_overlap: float,
) -> float:
nfft = np.floor(fft_win_length * sampling_rate)
noverlap = np.floor(fft_overlap * nfft)
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window

# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for
# center of temporal window

def generate_spectrogram(
audio,
sampling_rate,
params,
return_spec_for_viz=False,
check_spec_size=True,
):
audio: np.ndarray,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
max_freq: float,
min_freq: float,
spec_scale: str,
denoise_spec_avg: bool = False,
max_scale_spec: bool = False,
) -> np.ndarray:
# generate spectrogram
spec = gen_mag_spectrogram(
audio,
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
window_len=fft_win_length,
overlap_perc=fft_overlap,
)
spec = crop_spectrogram(
spec,
fft_win_length=fft_win_length,
max_freq=max_freq,
min_freq=min_freq,
)
spec = scale_spectrogram(
spec,
sampling_rate,
spec_scale=spec_scale,
fft_win_length=fft_win_length,
)

if denoise_spec_avg:
spec = denoise_spectrogram(spec)

if max_scale_spec:
spec = max_scale_spectrogram(spec)

return spec
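A hedged sketch of the rewritten entry point (parameter values illustrative; they only loosely mirror the repository defaults):

spec = generate_spectrogram(
    audio,
    sampling_rate=256000,
    fft_win_length=0.002,
    fft_overlap=0.75,
    max_freq=120000,
    min_freq=10000,
    spec_scale="log",
    denoise_spec_avg=True,
)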


def crop_spectrogram(
spec: np.ndarray,
fft_win_length: float,
max_freq: float,
min_freq: float,
) -> np.ndarray:
# crop to min/max freq
max_freq = round(params["max_freq"] * params["fft_win_length"])
min_freq = round(params["min_freq"] * params["fft_win_length"])
max_freq = round(max_freq * fft_win_length)
min_freq = round(min_freq * fft_win_length)
if spec.shape[0] < max_freq:
freq_pad = max_freq - spec.shape[0]
spec = np.vstack(
(np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
)
spec_cropped = spec[-max_freq : spec.shape[0] - min_freq, :]
return spec[-max_freq : spec.shape[0] - min_freq, :]

if params["spec_scale"] == "log":
log_scaling = (
2.0
* (1.0 / sampling_rate)
* (
1.0
/ (
np.abs(
np.hanning(
int(params["fft_win_length"] * sampling_rate)
)
)
** 2
).sum()
)
)
# log_scaling = (1.0 / sampling_rate)*0.1
# log_scaling = (1.0 / sampling_rate)*10e4
spec = np.log1p(log_scaling * spec_cropped)
elif params["spec_scale"] == "pcen":
spec = pcen(spec_cropped, sampling_rate)

elif params["spec_scale"] == "none":
pass

if params["denoise_spec_avg"]:
def denoise_spectrogram(spec: np.ndarray) -> np.ndarray:
spec = spec - np.mean(spec, 1)[:, np.newaxis]
spec.clip(min=0, out=spec)
return spec.clip(min=0)

if params["max_scale_spec"]:
spec = spec / (spec.max() + 10e-6)

# needs to be divisible by specific factor - if not it should have been padded
# if check_spec_size:
# assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
# assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)
def max_scale_spectrogram(spec: np.ndarray) -> np.ndarray:
return spec / (spec.max() + 10e-6)

# for visualization purposes - use log scaled spectrogram
if return_spec_for_viz:

def log_scale(
spec: np.ndarray,
sampling_rate: float,
fft_win_length: float,
) -> np.ndarray:
log_scaling = (
2.0
* (1.0 / sampling_rate)
* (
1.0
/ (
np.abs(
np.hanning(
int(params["fft_win_length"] * sampling_rate)
)
)
** 2
np.abs(np.hanning(int(fft_win_length * sampling_rate))) ** 2
).sum()
)
)
spec_for_viz = np.log1p(log_scaling * spec_cropped).astype(np.float32)
else:
spec_for_viz = None
return np.log1p(log_scaling * spec)

return spec, spec_for_viz

def scale_spectrogram(
spec: np.ndarray,
sampling_rate: float,
spec_scale: str,
fft_win_length: float,
) -> np.ndarray:
if spec_scale == "log":
return log_scale(spec, sampling_rate, fft_win_length)

if spec_scale == "pcen":
return pcen(spec, sampling_rate)

return spec


def prepare_spec_for_viz(
spec: np.ndarray,
sampling_rate: int,
fft_win_length: float,
) -> np.ndarray:
# for visualization purposes - use log scaled spectrogram
return log_scale(
spec,
sampling_rate,
fft_win_length=fft_win_length,
).astype(np.float32)
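For intuition, the constant built in `log_scale` is 2 / (sampling_rate * sum(w[n]^2)) for a Hann window w of nfft samples, i.e. an energy normalisation applied before `log1p` compression; a standalone sketch (shapes illustrative):

import numpy as np

sr = 256000
nfft = int(0.002 * sr)  # Hann window length in samples
c = 2.0 / (sr * (np.abs(np.hanning(nfft)) ** 2).sum())
spec = np.random.rand(128, 64).astype(np.float32)
log_spec = np.log1p(c * spec)  # the same compression log_scale returns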


def load_audio(
audio_file: str,
time_exp_fact: float,
target_samp_rate: int,
target_sampling_rate: int,
scale: bool = False,
max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray]:
) -> Tuple[float, np.ndarray]:
"""Load an audio file and resample it to the target sampling rate.

The audio is also scaled to [-1, 1] and clipped to the maximum duration.
@ -152,63 +208,82 @@ def load_audio(

"""
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
# sampling_rate, audio_raw = wavfile.read(audio_file)
audio_raw, sampling_rate = librosa.load(
audio, sampling_rate = librosa.load(
audio_file,
sr=None,
dtype=np.float32,
)

if len(audio_raw.shape) > 1:
if len(audio.shape) > 1:
raise ValueError("Currently does not handle stereo files")

sampling_rate = sampling_rate * time_exp_fact

# resample - need to do this after correcting for time expansion
sampling_rate_old = sampling_rate
sampling_rate = target_samp_rate
if sampling_rate_old != sampling_rate:
audio_raw = librosa.resample(
audio_raw,
orig_sr=sampling_rate_old,
target_sr=sampling_rate,
res_type="polyphase",
)
audio = resample_audio(audio, sampling_rate, target_sampling_rate)

# clipping maximum duration
if max_duration is not None:
max_duration = int(
np.minimum(
int(sampling_rate * max_duration),
audio_raw.shape[0],
)
)
audio_raw = audio_raw[:max_duration]
audio = clip_audio(audio, target_sampling_rate, max_duration)

# scale to [-1, 1]
if scale:
audio_raw = audio_raw - audio_raw.mean()
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
audio = scale_audio(audio)

return sampling_rate, audio_raw
return target_sampling_rate, audio
def resample_audio(
audio: np.ndarray,
sr_orig: float,
sr_target: float,
) -> np.ndarray:
if sr_orig != sr_target:
return librosa.resample(
audio,
orig_sr=sr_orig,
target_sr=sr_target,
res_type="polyphase",
)

return audio


def clip_audio(
audio: np.ndarray,
sampling_rate: float,
max_duration: float,
) -> np.ndarray:
max_duration = int(
np.minimum(
int(sampling_rate * max_duration),
audio.shape[0],
)
)
return audio[:max_duration]


def scale_audio(
audio: np.ndarray,
eps: float = 10e-6,
) -> np.ndarray:
return (audio - audio.mean()) / (np.abs(audio).max() + eps)
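A one-line check of the scaling helper (toy array): the signal is mean-centred, then divided by the peak magnitude of the original signal.

import numpy as np

x = np.array([0.0, 1.0, 2.0], dtype=np.float32)
y = scale_audio(x)  # mean 1.0 removed, divided by max |x| = 2 -> [-0.5, 0.0, 0.5]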

def pad_audio(
audio_raw,
fs,
ms,
overlap_perc,
resize_factor,
divide_factor,
fixed_width=None,
):
audio_raw: np.ndarray,
sampling_rate: float,
window_len: float,
overlap_perc: float,
resize_factor: float,
divide_factor: float,
fixed_width: Optional[int] = None,
) -> np.ndarray:
# Adds zeros to the end of the raw data so that the generated spectrogram
# will be evenly divisible by `divide_factor`
# Also deals with very short audio clips and fixed_width during training

# This code could be clearer, clean up
nfft = int(ms * fs)
nfft = int(window_len * sampling_rate)
noverlap = int(overlap_perc * nfft)
step = nfft - noverlap
min_size = int(divide_factor * (1.0 / resize_factor))
@ -245,19 +320,24 @@ def pad_audio(
return audio_raw


def gen_mag_spectrogram(x, fs, ms, overlap_perc):
def gen_mag_spectrogram(
audio: np.ndarray,
sampling_rate: float,
window_len: float,
overlap_perc: float,
) -> np.ndarray:
# Computes magnitude spectrogram by specifying time.

x = x.astype(np.float32)
nfft = int(ms * fs)
audio = audio.astype(np.float32)
nfft = int(window_len * sampling_rate)
noverlap = int(overlap_perc * nfft)

# window data
step = nfft - noverlap

# compute spec
spec, _ = librosa.core.spectrum._spectrogram(
y=x, power=1, n_fft=nfft, hop_length=step, center=False
y=audio,
power=1,
n_fft=nfft,
hop_length=nfft - noverlap,
center=False,
)

# remove DC component and flip vertical orientation
@ -266,24 +346,25 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):
return spec.astype(np.float32)

def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc):
nfft = int(ms * fs)
def gen_mag_spectrogram_pt(
audio: torch.Tensor,
sampling_rate: float,
window_len: float,
overlap_perc: float,
) -> torch.Tensor:
nfft = int(window_len * sampling_rate)
nstep = round((1.0 - overlap_perc) * nfft)
han_win = torch.hann_window(nfft, periodic=False).to(audio.device)

han_win = torch.hann_window(nfft, periodic=False).to(x.device)

complex_spec = torch.stft(x, nfft, nstep, window=han_win, center=False)
complex_spec = torch.stft(audio, nfft, nstep, window=han_win, center=False)
spec = complex_spec.pow(2.0).sum(-1)

# remove DC component and flip vertically
spec = torch.flipud(spec[0, 1:, :])

return spec
return torch.flipud(spec[0, 1:, :])


def pcen(spec_cropped, sampling_rate):
def pcen(spec: np.ndarray, sampling_rate: float) -> np.ndarray:
# TODO should be passing hop_length too i.e. step
spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate / 10).astype(
return librosa.pcen(spec * (2**31), sr=sampling_rate / 10).astype(
np.float32
)
return spec

@ -16,7 +16,7 @@ from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import (
Annotation,
DetectionModel,
FileAnnotations,
FileAnnotation,
ModelOutput,
ModelParameters,
PredictionResults,
@ -79,7 +79,7 @@ def list_audio_files(ip_dir: str) -> List[str]:
def load_model(
model_path: str = DEFAULT_MODEL_PATH,
load_weights: bool = True,
device: Optional[torch.device] = None,
device: Union[torch.device, str, None] = None,
) -> Tuple[DetectionModel, ModelParameters]:
"""Load model from file.

@ -222,7 +222,7 @@ def format_single_result(
duration: float,
predictions: PredictionResults,
class_names: List[str],
) -> FileAnnotations:
) -> FileAnnotation:
"""Format results into the format expected by the annotation tool.

Args:
@ -399,11 +399,10 @@ def save_results_to_file(results, op_path: str) -> None:

def compute_spectrogram(
audio: np.ndarray,
sampling_rate: int,
sampling_rate: float,
params: SpectrogramParameters,
device: torch.device,
return_np: bool = False,
) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
) -> Tuple[float, torch.Tensor]:
"""Compute a spectrogram from an audio array.

Will pad the audio array so that it is evenly divisible by the
@ -412,24 +411,16 @@ def compute_spectrogram(
Parameters
----------
audio : np.ndarray

sampling_rate : int

params : SpectrogramParameters
The parameters to use for generating the spectrogram.

return_np : bool, optional
Whether to return the spectrogram as a numpy array as well as a
torch tensor. The default is False.

Returns
-------
duration : float
The duration of the spectrogram in seconds.

spec : torch.Tensor
The spectrogram as a torch tensor.

spec_np : np.ndarray, optional
The spectrogram as a numpy array. Only returned if `return_np` is
True, otherwise None.
@ -446,7 +437,7 @@ def compute_spectrogram(
)

# generate spectrogram
spec, _ = au.generate_spectrogram(audio, sampling_rate, params)
spec = au.generate_spectrogram(audio, sampling_rate, params)

# convert to pytorch
spec = torch.from_numpy(spec).to(device)
@ -466,18 +457,12 @@ def compute_spectrogram(
mode="bilinear",
align_corners=False,
)

if return_np:
spec_np = spec[0, 0, :].cpu().data.numpy()
else:
spec_np = None

return duration, spec, spec_np
return duration, spec


def iterate_over_chunks(
audio: np.ndarray,
samplerate: int,
samplerate: float,
chunk_size: float,
) -> Iterator[Tuple[float, np.ndarray]]:
"""Iterate over audio in chunks of size chunk_size.
@ -510,7 +495,7 @@ def iterate_over_chunks(

def _process_spectrogram(
spec: torch.Tensor,
samplerate: int,
samplerate: float,
model: DetectionModel,
config: ProcessingConfiguration,
) -> Tuple[PredictionResults, np.ndarray]:
@ -632,13 +617,13 @@ def process_spectrogram(

def _process_audio_array(
audio: np.ndarray,
sampling_rate: int,
sampling_rate: float,
model: DetectionModel,
config: ProcessingConfiguration,
device: torch.device,
) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
# load audio file and compute spectrogram
_, spec, _ = compute_spectrogram(
_, spec = compute_spectrogram(
audio,
sampling_rate,
{
@ -654,7 +639,6 @@ def _process_audio_array(
"max_scale_spec": config["max_scale_spec"],
},
device,
return_np=False,
)

# process spectrogram with model
@ -754,13 +738,15 @@ def process_file(

# Get original sampling rate
file_samp_rate = librosa.get_samplerate(audio_file)
orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1
orig_samp_rate = file_samp_rate * float(
config.get("time_expansion", 1.0) or 1.0
)
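Worth noting for the change above: in the removed line the multiplication binds before `or`, so `file_samp_rate * None` would raise before the `or 1` fallback could ever apply; the new form guards the time-expansion factor itself. A minimal illustration (hypothetical values):

file_samp_rate = 44100
time_expansion = None
# old: file_samp_rate * time_expansion or 1  -> TypeError before the fallback
orig_samp_rate = file_samp_rate * float(time_expansion or 1.0)  # 44100.0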

# load audio file
sampling_rate, audio_full = au.load_audio(
audio_file,
time_exp_fact=config.get("time_expansion", 1) or 1,
target_samp_rate=config["target_samp_rate"],
target_sampling_rate=config["target_samp_rate"],
scale=config["scale_raw_audio"],
max_duration=config.get("max_duration"),
)
@ -802,7 +788,6 @@ def process_file(
cnn_feats.append(features[0])

if config["spec_slices"]:
# FIX: This is not currently working. Returns empty slices
spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))

# Merge results from chunks
@ -152,7 +152,7 @@ def test_compute_max_power_bb(max_power: int):
target_samp_rate=samplerate,
)

spec, _ = au.generate_spectrogram(
spec = au.generate_spectrogram(
audio,
samplerate,
params,
@ -240,7 +240,7 @@ def test_compute_max_power():
target_samp_rate=samplerate,
)

spec, _ = au.generate_spectrogram(
spec = au.generate_spectrogram(
audio,
samplerate,
params,