Added types to most functions

This commit is contained in:
Santiago Martinez 2024-01-14 17:15:22 +00:00
parent 458e11cf73
commit 0aa61af445
15 changed files with 1216 additions and 902 deletions

View File

@ -226,11 +226,10 @@ def generate_spectrogram(
if config is None: if config is None:
config = DEFAULT_SPECTROGRAM_PARAMETERS config = DEFAULT_SPECTROGRAM_PARAMETERS
_, spec, _ = du.compute_spectrogram( _, spec = du.compute_spectrogram(
audio, audio,
samp_rate, samp_rate,
config, config,
return_np=False,
device=device, device=device,
) )

View File

@ -1,5 +1,5 @@
"""Functions to compute features from predictions.""" """Functions to compute features from predictions."""
from typing import Dict, Optional from typing import Dict, List, Optional
import numpy as np import numpy as np
@ -7,15 +7,26 @@ from batdetect2 import types
from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ
def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq): def convert_int_to_freq(
spec_ind: int,
spec_height: int,
min_freq: float,
max_freq: float,
) -> int:
"""Convert spectrogram index to frequency in Hz.""" "" """Convert spectrogram index to frequency in Hz.""" ""
spec_ind = spec_height - spec_ind spec_ind = spec_height - spec_ind
return round( return int(
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2 round(
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq,
2,
)
) )
def extract_spec_slices(spec, pred_nms): def extract_spec_slices(
spec: np.ndarray,
pred_nms: types.PredictionResults,
) -> List[np.ndarray]:
"""Extract spectrogram slices from spectrogram. """Extract spectrogram slices from spectrogram.
The slices are extracted based on detected call locations. The slices are extracted based on detected call locations.
@ -109,7 +120,7 @@ def compute_max_power_bb(
return int( return int(
convert_int_to_freq( convert_int_to_freq(
y_high + max_power_ind, int(y_high + max_power_ind),
spec.shape[0], spec.shape[0],
min_freq, min_freq,
max_freq, max_freq,
@ -135,13 +146,11 @@ def compute_max_power(
spec_call = spec[:, x_start:x_end] spec_call = spec[:, x_start:x_end]
power_per_freq_band = np.sum(spec_call, axis=1) power_per_freq_band = np.sum(spec_call, axis=1)
max_power_ind = np.argmax(power_per_freq_band) max_power_ind = np.argmax(power_per_freq_band)
return int( return convert_int_to_freq(
convert_int_to_freq( int(max_power_ind),
max_power_ind, spec.shape[0],
spec.shape[0], min_freq,
min_freq, max_freq,
max_freq,
)
) )
@ -164,13 +173,11 @@ def compute_max_power_first(
first_half = spec_call[:, : int(spec_call.shape[1] / 2)] first_half = spec_call[:, : int(spec_call.shape[1] / 2)]
power_per_freq_band = np.sum(first_half, axis=1) power_per_freq_band = np.sum(first_half, axis=1)
max_power_ind = np.argmax(power_per_freq_band) max_power_ind = np.argmax(power_per_freq_band)
return int( return convert_int_to_freq(
convert_int_to_freq( int(max_power_ind),
max_power_ind, spec.shape[0],
spec.shape[0], min_freq,
min_freq, max_freq,
max_freq,
)
) )
@ -193,13 +200,11 @@ def compute_max_power_second(
second_half = spec_call[:, int(spec_call.shape[1] / 2) :] second_half = spec_call[:, int(spec_call.shape[1] / 2) :]
power_per_freq_band = np.sum(second_half, axis=1) power_per_freq_band = np.sum(second_half, axis=1)
max_power_ind = np.argmax(power_per_freq_band) max_power_ind = np.argmax(power_per_freq_band)
return int( return convert_int_to_freq(
convert_int_to_freq( int(max_power_ind),
max_power_ind, spec.shape[0],
spec.shape[0], min_freq,
min_freq, max_freq,
max_freq,
)
) )
@ -214,6 +219,7 @@ def compute_call_interval(
return round(prediction["start_time"] - previous["end_time"], 5) return round(prediction["start_time"] - previous["end_time"], 5)
# NOTE: The order of the features in this dictionary is important. The # NOTE: The order of the features in this dictionary is important. The
# features are extracted in this order and the order of the columns in the # features are extracted in this order and the order of the columns in the
# output csv file is determined by this order. In order to avoid breaking # output csv file is determined by this order. In order to avoid breaking
@ -236,7 +242,7 @@ def get_feats(
spec: np.ndarray, spec: np.ndarray,
pred_nms: types.PredictionResults, pred_nms: types.PredictionResults,
params: types.FeatureExtractionParameters, params: types.FeatureExtractionParameters,
): ) -> np.ndarray:
"""Extract features from spectrogram based on detected call locations. """Extract features from spectrogram based on detected call locations.
The features extracted are: The features extracted are:

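The typed convert_int_to_freq above is a plain linear map from spectrogram row index to Hz, with the index flipped because row 0 sits at the top (highest frequency) of the spectrogram. A minimal, self-contained sketch of that arithmetic; the 128-row height and 10-120 kHz range are made-up values, not taken from this commit:

def convert_int_to_freq(
    spec_ind: int,
    spec_height: int,
    min_freq: float,
    max_freq: float,
) -> int:
    # Row 0 is max_freq and row spec_height is min_freq, so flip the index
    # before interpolating linearly between the two bounds.
    spec_ind = spec_height - spec_ind
    return int(
        round((spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2)
    )

# Example: on a 128-row spectrogram spanning 10-120 kHz, the top row maps
# to 120 kHz and the bottom row to 10 kHz.
assert convert_int_to_freq(0, 128, 10_000, 120_000) == 120_000
assert convert_int_to_freq(128, 128, 10_000, 120_000) == 10_000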
View File

@ -79,7 +79,13 @@ class ConvBlockDownCoordF(nn.Module):
class ConvBlockDownStandard(nn.Module): class ConvBlockDownStandard(nn.Module):
def __init__( def __init__(
self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1 self,
in_chn,
out_chn,
ip_height=None,
k_size=3,
pad_size=1,
stride=1,
): ):
super(ConvBlockDownStandard, self).__init__() super(ConvBlockDownStandard, self).__init__()
self.conv = nn.Conv2d( self.conv = nn.Conv2d(

View File

@ -103,15 +103,15 @@ class Net2DFast(nn.Module):
num_filts, self.emb_dim, kernel_size=1, padding=0 num_filts, self.emb_dim, kernel_size=1, padding=0
) )
def forward(self, ip, return_feats=False) -> ModelOutput: def forward(self, spec: torch.Tensor) -> ModelOutput:
# encoder # encoder
x1 = self.conv_dn_0(ip) x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1) x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2) x3 = self.conv_dn_2(x2)
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True) x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
# bottleneck # bottleneck
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True) x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = self.att(x) x = self.att(x)
x = x.repeat([1, 1, self.bneck_height * 4, 1]) x = x.repeat([1, 1, self.bneck_height * 4, 1])
@ -121,13 +121,13 @@ class Net2DFast(nn.Module):
x = self.conv_up_4(x + x1) x = self.conv_up_4(x + x1)
# output # output
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True) x = F.relu_(self.conv_op_bn(self.conv_op(x)))
cls = self.conv_classes_op(x) cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1) comb = torch.softmax(cls, 1)
return ModelOutput( return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu(self.conv_size_op(x), inplace=True), pred_size=F.relu(self.conv_size_op(x)),
pred_class=comb, pred_class=comb,
pred_class_un_norm=cls, pred_class_un_norm=cls,
features=x, features=x,
@ -215,26 +215,26 @@ class Net2DFastNoAttn(nn.Module):
num_filts, self.emb_dim, kernel_size=1, padding=0 num_filts, self.emb_dim, kernel_size=1, padding=0
) )
def forward(self, ip, return_feats=False) -> ModelOutput: def forward(self, spec: torch.Tensor) -> ModelOutput:
x1 = self.conv_dn_0(ip) x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1) x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2) x3 = self.conv_dn_2(x2)
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True) x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True) x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = x.repeat([1, 1, self.bneck_height * 4, 1]) x = x.repeat([1, 1, self.bneck_height * 4, 1])
x = self.conv_up_2(x + x3) x = self.conv_up_2(x + x3)
x = self.conv_up_3(x + x2) x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1) x = self.conv_up_4(x + x1)
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True) x = F.relu_(self.conv_op_bn(self.conv_op(x)))
cls = self.conv_classes_op(x) cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1) comb = torch.softmax(cls, 1)
return ModelOutput( return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu(self.conv_size_op(x), inplace=True), pred_size=F.relu_(self.conv_size_op(x)),
pred_class=comb, pred_class=comb,
pred_class_un_norm=cls, pred_class_un_norm=cls,
features=x, features=x,
@ -324,13 +324,13 @@ class Net2DFastNoCoordConv(nn.Module):
num_filts, self.emb_dim, kernel_size=1, padding=0 num_filts, self.emb_dim, kernel_size=1, padding=0
) )
def forward(self, ip, return_feats=False) -> ModelOutput: def forward(self, spec: torch.Tensor) -> ModelOutput:
x1 = self.conv_dn_0(ip) x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1) x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2) x3 = self.conv_dn_2(x2)
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True) x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True) x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = self.att(x) x = self.att(x)
x = x.repeat([1, 1, self.bneck_height * 4, 1]) x = x.repeat([1, 1, self.bneck_height * 4, 1])
@ -338,15 +338,13 @@ class Net2DFastNoCoordConv(nn.Module):
x = self.conv_up_3(x + x2) x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1) x = self.conv_up_4(x + x1)
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True) x = F.relu_(self.conv_op_bn(self.conv_op(x)))
cls = self.conv_classes_op(x) cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1) comb = torch.softmax(cls, 1)
pred_emb = (self.conv_emb(x) if self.emb_dim > 0 else None,)
return ModelOutput( return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu(self.conv_size_op(x), inplace=True), pred_size=F.relu_(self.conv_size_op(x)),
pred_class=comb, pred_class=comb,
pred_class_un_norm=cls, pred_class_un_norm=cls,
features=x, features=x,

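The forward passes above swap F.relu(..., inplace=True) for F.relu_(...); both are the same in-place ReLU, the underscore form is simply PyTorch's conventional spelling. A tiny standalone check of that equivalence:

import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, 0.5])
y = F.relu_(x)  # in-place: clamps x itself at zero and returns it
assert y is x
assert torch.equal(x, torch.tensor([0.0, 0.5]))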
View File

@ -1,6 +1,11 @@
import datetime import datetime
import os import os
from pathlib import Path
from typing import List, Optional, Union
from pydantic import BaseModel, Field, computed_field
from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
from batdetect2.types import ProcessingConfiguration, SpectrogramParameters from batdetect2.types import ProcessingConfiguration, SpectrogramParameters
TARGET_SAMPLERATE_HZ = 256000 TARGET_SAMPLERATE_HZ = 256000
@ -75,158 +80,154 @@ def mk_dir(path):
os.makedirs(path) os.makedirs(path)
def get_params(make_dirs=False, exps_dir="../../experiments/"): AUG_SAMPLING_RATES = [
params = {} 220500,
256000,
300000,
312500,
384000,
441000,
500000,
]
CLASSES_TO_IGNORE = ["", " ", "Unknown", "Not Bat"]
GENERIC_CLASSES = ["Bat"]
EVENTS_OF_INTEREST = ["Echolocation"]
params[
"model_name" class TrainingParameters(BaseModel):
] = "Net2DFast" # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
params["num_filters"] = 128 model_name: str = "Net2DFast"
num_filters: int = 128
experiment: Path
model_file_name: Path
op_im_dir: Path
op_im_dir_test: Path
notes: str = ""
target_samp_rate: int = TARGET_SAMPLERATE_HZ
fft_win_length: float = FFT_WIN_LENGTH_S
fft_overlap: float = FFT_OVERLAP
max_freq: int = MAX_FREQ_HZ
min_freq: int = MIN_FREQ_HZ
resize_factor: float = RESIZE_FACTOR
spec_height: int = SPEC_HEIGHT
spec_train_width: int = 512
spec_divide_factor: int = SPEC_DIVIDE_FACTOR
denoise_spec_avg: bool = DENOISE_SPEC_AVG
scale_raw_audio: bool = SCALE_RAW_AUDIO
max_scale_spec: bool = MAX_SCALE_SPEC
spec_scale: str = SPEC_SCALE
detection_overlap: float = 0.01
ignore_start_end: float = 0.01
detection_threshold: float = DETECTION_THRESHOLD
nms_kernel_size: int = NMS_KERNEL_SIZE
nms_top_k_per_sec: int = NMS_TOP_K_PER_SEC
aug_prob: float = 0.20
augment_at_train: bool = True
augment_at_train_combine: bool = True
echo_max_delay: float = 0.005
stretch_squeeze_delta: float = 0.04
mask_max_time_perc: float = 0.05
mask_max_freq_perc: float = 0.10
spec_amp_scaling: float = 2.0
aug_sampling_rates: List[int] = AUG_SAMPLING_RATES
train_loss: str = "focal"
det_loss_weight: float = 1.0
size_loss_weight: float = 0.1
class_loss_weight: float = 2.0
individual_loss_weight: float = 0.0
lr: float = 0.001
batch_size: int = 8
num_workers: int = 4
num_epochs: int = 200
num_eval_epochs: int = 5
device: str = "cuda"
save_test_image_during_train: bool = False
save_test_image_after_train: bool = True
convert_to_genus: bool = False
class_names: List[str] = Field(default_factory=list)
classes_to_ignore: List[str] = Field(
default_factory=lambda: CLASSES_TO_IGNORE
)
generic_class: List[str] = Field(default_factory=lambda: GENERIC_CLASSES)
events_of_interest: List[str] = Field(
default_factory=lambda: EVENTS_OF_INTEREST
)
standardize_classs_names: List[str] = Field(default_factory=list)
@computed_field
@property
def emb_dim(self) -> int:
if self.individual_loss_weight == 0.0:
return 0
return 3
@computed_field
@property
def genus_mapping(self) -> List[int]:
_, mapping = get_genus_mapping(self.class_names)
return mapping
@computed_field
@property
def genus_classes(self) -> List[str]:
names, _ = get_genus_mapping(self.class_names)
return names
@computed_field
@property
def class_names_short(self) -> List[str]:
return get_short_class_names(self.class_names)
def get_params(
make_dirs: bool = False,
exps_dir: str = "../../experiments/",
model_name: Optional[str] = None,
experiment: Union[Path, str, None] = None,
**kwargs,
) -> TrainingParameters:
experiments_dir = Path(exps_dir)
now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S") now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
model_name = now_str + ".pth.tar"
params["experiment"] = os.path.join(exps_dir, now_str, "") if model_name is None:
params["model_file_name"] = os.path.join(params["experiment"], model_name) model_name = f"{now_str}.pth.tar"
params["op_im_dir"] = os.path.join(params["experiment"], "op_ims", "")
params["op_im_dir_test"] = os.path.join( if experiment is None:
params["experiment"], "op_ims_test", "" experiment = experiments_dir / now_str
experiment = Path(experiment)
model_file_name = experiment / model_name
op_ims_dir = experiment / "op_ims"
op_ims_test_dir = experiment / "op_ims_test"
params = TrainingParameters(
model_name=model_name,
experiment=experiment,
model_file_name=model_file_name,
op_im_dir=op_ims_dir,
op_im_dir_test=op_ims_test_dir,
**kwargs,
) )
# params['notes'] = '' # can save notes about an experiment here
# spec parameters
params[
"target_samp_rate"
] = TARGET_SAMPLERATE_HZ # resamples all audio so that it is at this rate
params[
"fft_win_length"
] = FFT_WIN_LENGTH_S # in milliseconds, amount of time per stft time step
params["fft_overlap"] = FFT_OVERLAP # stft window overlap
params[
"max_freq"
] = MAX_FREQ_HZ # in Hz, everything above this will be discarded
params[
"min_freq"
] = MIN_FREQ_HZ # in Hz, everything below this will be discarded
params[
"resize_factor"
] = RESIZE_FACTOR # resize so the spectrogram at the input of the network
params[
"spec_height"
] = SPEC_HEIGHT # units are number of frequency bins (before resizing is performed)
params[
"spec_train_width"
] = 512 # units are number of time steps (before resizing is performed)
params[
"spec_divide_factor"
] = SPEC_DIVIDE_FACTOR # spectrogram should be divisible by this amount in width and height
# spec processing params
params[
"denoise_spec_avg"
] = DENOISE_SPEC_AVG # removes the mean for each frequency band
params[
"scale_raw_audio"
] = SCALE_RAW_AUDIO # scales the raw audio to [-1, 1]
params[
"max_scale_spec"
] = MAX_SCALE_SPEC # scales the spectrogram so that it is max 1
params["spec_scale"] = SPEC_SCALE # 'log', 'pcen', 'none'
# detection params
params[
"detection_overlap"
] = 0.01 # has to be within this number of ms to count as detection
params[
"ignore_start_end"
] = 0.01 # if start of GT calls are within this time from the start/end of file ignore
params[
"detection_threshold"
] = DETECTION_THRESHOLD # the smaller this is the better the recall will be
params[
"nms_kernel_size"
] = NMS_KERNEL_SIZE # size of the kernel for non-max suppression
params[
"nms_top_k_per_sec"
] = NMS_TOP_K_PER_SEC # keep top K highest predictions per second of audio
params["target_sigma"] = 2.0
# augmentation params
params[
"aug_prob"
] = 0.20 # augmentations will be performed with this probability
params["augment_at_train"] = True
params["augment_at_train_combine"] = True
params[
"echo_max_delay"
] = 0.005 # simulate echo by adding copy of raw audio
params["stretch_squeeze_delta"] = 0.04 # stretch or squeeze spec
params[
"mask_max_time_perc"
] = 0.05 # max mask size - here percentage, not ideal
params[
"mask_max_freq_perc"
] = 0.10 # max mask size - here percentage, not ideal
params[
"spec_amp_scaling"
] = 2.0 # multiply the "volume" by 0:X times current amount
params["aug_sampling_rates"] = [
220500,
256000,
300000,
312500,
384000,
441000,
500000,
]
# loss params
params["train_loss"] = "focal" # mse or focal
params["det_loss_weight"] = 1.0 # weight for the detection part of the loss
params["size_loss_weight"] = 0.1 # weight for the bbox size loss
params["class_loss_weight"] = 2.0 # weight for the classification loss
params["individual_loss_weight"] = 0.0 # not used
if params["individual_loss_weight"] == 0.0:
params[
"emb_dim"
] = 0 # number of dimensions used for individual id embedding
else:
params["emb_dim"] = 3
# train params
params["lr"] = 0.001
params["batch_size"] = 8
params["num_workers"] = 4
params["num_epochs"] = 200
params["num_eval_epochs"] = 5 # run evaluation every X epochs
params["device"] = "cuda"
params["save_test_image_during_train"] = False
params["save_test_image_after_train"] = True
params["convert_to_genus"] = False
params["genus_mapping"] = []
params["class_names"] = []
params["classes_to_ignore"] = ["", " ", "Unknown", "Not Bat"]
params["generic_class"] = ["Bat"]
params["events_of_interest"] = [
"Echolocation"
] # will ignore all other types of events e.g. social calls
# the classes in this list are standardized during training so that the same low and high freq are used
params["standardize_classs_names"] = []
# create directories
if make_dirs: if make_dirs:
print("Model name : " + params["model_name"]) mk_dir(experiment)
print("Model file : " + params["model_file_name"]) mk_dir(params.model_file_name.parent)
print("Experiment : " + params["experiment"]) if params.save_test_image_during_train:
mk_dir(params.op_im_dir)
mk_dir(params["experiment"]) if params.save_test_image_after_train:
if params["save_test_image_during_train"]: mk_dir(params.op_im_dir_test)
mk_dir(params["op_im_dir"])
if params["save_test_image_after_train"]:
mk_dir(params["op_im_dir_test"])
mk_dir(os.path.dirname(params["model_file_name"]))
return params return params

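The old mutable params dict is replaced above by a pydantic TrainingParameters model whose derived values (such as emb_dim) are exposed as computed fields. A minimal, self-contained sketch of that pattern, assuming pydantic v2; TinyParams is a made-up stand-in, not part of the commit:

from typing import List
from pydantic import BaseModel, Field, computed_field

class TinyParams(BaseModel):
    individual_loss_weight: float = 0.0
    class_names: List[str] = Field(default_factory=list)

    @computed_field
    @property
    def emb_dim(self) -> int:
        # mirrors TrainingParameters.emb_dim in the hunk above
        return 0 if self.individual_loss_weight == 0.0 else 3

p = TinyParams(class_names=["Myotis myotis"])
assert p.emb_dim == 0
assert "emb_dim" in p.model_dump()  # computed fields are included in dumps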
View File

@ -1,33 +1,31 @@
import argparse import argparse
import glob
import json
import os import os
import sys import warnings
from typing import List, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
import batdetect2.detector.models as models
import batdetect2.detector.parameters as parameters import batdetect2.detector.parameters as parameters
import batdetect2.detector.post_process as pp
import batdetect2.train.audio_dataloader as adl import batdetect2.train.audio_dataloader as adl
import batdetect2.train.evaluate as evl
import batdetect2.train.losses as losses import batdetect2.train.losses as losses
import batdetect2.train.train_model as tm import batdetect2.train.train_model as tm
import batdetect2.train.train_utils as tu import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu import batdetect2.utils.plot_utils as pu
from batdetect2 import types
from batdetect2.detector.models import Net2DFast
if __name__ == "__main__": BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
info_str = "\nBatDetect - Finetune Model\n"
print(info_str)
def parse_arguments():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"audio_path", type=str, help="Input directory for audio" "audio_path",
type=str,
help="Input directory for audio",
) )
parser.add_argument( parser.add_argument(
"train_ann_path", "train_ann_path",
@ -39,7 +37,15 @@ if __name__ == "__main__":
type=str, type=str,
help="Path to where test annotation file is stored", help="Path to where test annotation file is stored",
) )
parser.add_argument("model_path", type=str, help="Path to pretrained model") parser.add_argument(
"model_path", type=str, help="Path to pretrained model"
)
parser.add_argument(
"--experiment_dir",
type=str,
default=os.path.join(BASE_DIR, "experiments"),
help="Path to where experiment files are stored",
)
parser.add_argument( parser.add_argument(
"--op_model_name", "--op_model_name",
type=str, type=str,
@ -71,107 +77,63 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--notes", type=str, default="", help="Notes to save in text file" "--notes", type=str, default="", help="Notes to save in text file"
) )
args = vars(parser.parse_args()) args = parser.parse_args()
return args
params = parameters.get_params(True, "../../experiments/")
def select_device(warn=True) -> str:
if torch.cuda.is_available(): if torch.cuda.is_available():
params["device"] = "cuda" return "cuda"
else:
params["device"] = "cpu" if warn:
print( warnings.warn(
"\nNote, this will be a lot faster if you use computer with a GPU.\n" "No GPU available, using the CPU instead. Please consider using a GPU "
"to speed up training."
) )
print("\nAudio directory: " + args["audio_path"]) return "cpu"
print("Train file: " + args["train_ann_path"])
print("Test file: " + args["test_ann_path"])
print("Loading model: " + args["model_path"])
dataset_name = (
os.path.basename(args["train_ann_path"])
.replace(".json", "")
.replace("_TRAIN", "")
)
if args["train_from_scratch"]: def load_annotations(
print("\nTraining model from scratch i.e. not using pretrained weights") dataset_name: str,
model, params_train = du.load_model(args["model_path"], False) ann_path: str,
else: audio_path: str,
model, params_train = du.load_model(args["model_path"], True) classes_to_ignore: Optional[List[str]] = None,
model.to(params["device"]) events_of_interest: Optional[List[str]] = None,
) -> List[types.FileAnnotation]:
params["num_epochs"] = args["num_epochs"] train_sets: List[types.DatasetDict] = []
if args["op_model_name"] != "":
params["model_file_name"] = args["op_model_name"]
classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
# save notes file
params["notes"] = args["notes"]
if args["notes"] != "":
tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"])
# load train annotations
train_sets = []
train_sets.append( train_sets.append(
tu.get_blank_dataset_dict(
dataset_name, False, args["train_ann_path"], args["audio_path"]
)
)
params["train_sets"] = [
tu.get_blank_dataset_dict( tu.get_blank_dataset_dict(
dataset_name, dataset_name,
False, is_test=False,
os.path.basename(args["train_ann_path"]), ann_path=ann_path,
args["audio_path"], wav_path=audio_path,
)
]
print("\nTrain set:")
(
data_train,
params["class_names"],
params["class_inv_freq"],
) = tu.load_set_of_anns(
train_sets, classes_to_ignore, params["events_of_interest"]
)
print("Number of files", len(data_train))
params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
params["class_names"]
)
params["class_names_short"] = tu.get_short_class_names(
params["class_names"]
)
# load test annotations
test_sets = []
test_sets.append(
tu.get_blank_dataset_dict(
dataset_name, True, args["test_ann_path"], args["audio_path"]
) )
) )
params["test_sets"] = [
tu.get_blank_dataset_dict(
dataset_name,
True,
os.path.basename(args["test_ann_path"]),
args["audio_path"],
)
]
print("\nTest set:") return tu.load_set_of_anns(
data_test, _, _ = tu.load_set_of_anns( train_sets,
test_sets, classes_to_ignore, params["events_of_interest"] events_of_interest=events_of_interest,
classes_to_ignore=classes_to_ignore,
) )
print("Number of files", len(data_test))
def finetune_model(
model: types.DetectionModel,
data_train: List[types.FileAnnotation],
data_test: List[types.FileAnnotation],
params: parameters.TrainingParameters,
model_params: types.ModelParameters,
finetune_only_last_layer: bool = False,
save_images: bool = True,
):
# train loader # train loader
train_dataset = adl.AudioLoader(data_train, params, is_train=True) train_dataset = adl.AudioLoader(data_train, params, is_train=True)
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, train_dataset,
batch_size=params["batch_size"], batch_size=params.batch_size,
shuffle=True, shuffle=True,
num_workers=params["num_workers"], num_workers=params.num_workers,
pin_memory=True, pin_memory=True,
) )
@ -181,32 +143,36 @@ if __name__ == "__main__":
test_dataset, test_dataset,
batch_size=1, batch_size=1,
shuffle=False, shuffle=False,
num_workers=params["num_workers"], num_workers=params.num_workers,
pin_memory=True, pin_memory=True,
) )
inputs_train = next(iter(train_loader)) inputs_train = next(iter(train_loader))
params["ip_height"] = inputs_train["spec"].shape[2] params.ip_height = inputs_train["spec"].shape[2]
print("\ntrain batch size :", inputs_train["spec"].shape) print("\ntrain batch size :", inputs_train["spec"].shape)
assert params_train["model_name"] == "Net2DFast" # Check that the model is the same as the one used to train the pretrained
# weights
assert model_params["model_name"] == "Net2DFast"
assert isinstance(model, Net2DFast)
print( print(
"\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n" "\n\nSOME hyperparams need to be the same as the loaded model "
"(e.g. FFT) - currently they are getting overwritten.\n\n"
) )
# set the number of output classes # set the number of output classes
num_filts = model.conv_classes_op.in_channels num_filts = model.conv_classes_op.in_channels
k_size = model.conv_classes_op.kernel_size k_size = model.conv_classes_op.kernel_size
pad = model.conv_classes_op.padding pad = model.conv_classes_op.padding
model.conv_classes_op = torch.nn.Conv2d( model.conv_classes_op = torch.nn.Conv2d(
num_filts, num_filts,
len(params["class_names"]) + 1, len(params.class_names) + 1,
kernel_size=k_size, kernel_size=k_size,
padding=pad, padding=pad,
) )
model.conv_classes_op.to(params["device"]) model.conv_classes_op.to(params.device)
if args["finetune_only_last_layer"]: if finetune_only_last_layer:
print("\nOnly finetuning the final layers.\n") print("\nOnly finetuning the final layers.\n")
train_layers_i = [ train_layers_i = [
"conv_classes", "conv_classes",
@ -223,19 +189,26 @@ if __name__ == "__main__":
else: else:
param.requires_grad = False param.requires_grad = False
optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"]) optimizer = torch.optim.Adam(
scheduler = CosineAnnealingLR( model.parameters(),
optimizer, params["num_epochs"] * len(train_loader) lr=params.lr,
) )
if params["train_loss"] == "mse": scheduler = CosineAnnealingLR(
optimizer,
params.num_epochs * len(train_loader),
)
if params.train_loss == "mse":
det_criterion = losses.mse_loss det_criterion = losses.mse_loss
elif params["train_loss"] == "focal": elif params.train_loss == "focal":
det_criterion = losses.focal_loss det_criterion = losses.focal_loss
else:
raise ValueError("Unknown loss function")
# plotting # plotting
train_plt_ls = pu.LossPlotter( train_plt_ls = pu.LossPlotter(
params["experiment"] + "train_loss.png", params.experiment / "train_loss.png",
params["num_epochs"] + 1, params.num_epochs + 1,
["train_loss"], ["train_loss"],
None, None,
None, None,
@ -243,8 +216,8 @@ if __name__ == "__main__":
logy=True, logy=True,
) )
test_plt_ls = pu.LossPlotter( test_plt_ls = pu.LossPlotter(
params["experiment"] + "test_loss.png", params.experiment / "test_loss.png",
params["num_epochs"] + 1, params.num_epochs + 1,
["test_loss"], ["test_loss"],
None, None,
None, None,
@ -252,24 +225,24 @@ if __name__ == "__main__":
logy=True, logy=True,
) )
test_plt = pu.LossPlotter( test_plt = pu.LossPlotter(
params["experiment"] + "test.png", params.experiment / "test.png",
params["num_epochs"] + 1, params.num_epochs + 1,
["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"], ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
[0, 1], [0, 1],
None, None,
["epoch", ""], ["epoch", ""],
) )
test_plt_class = pu.LossPlotter( test_plt_class = pu.LossPlotter(
params["experiment"] + "test_avg_prec.png", params.experiment / "test_avg_prec.png",
params["num_epochs"] + 1, params.num_epochs + 1,
params["class_names_short"], params.class_names_short,
[0, 1], [0, 1],
params["class_names_short"], params.class_names_short,
["epoch", "avg_prec"], ["epoch", "avg_prec"],
) )
# main train loop # main train loop
for epoch in range(0, params["num_epochs"] + 1): for epoch in range(0, params.num_epochs + 1):
train_loss = tm.train( train_loss = tm.train(
model, model,
epoch, epoch,
@ -281,10 +254,14 @@ if __name__ == "__main__":
) )
train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]]) train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])
if epoch % params["num_eval_epochs"] == 0: if epoch % params.num_eval_epochs == 0:
# detection accuracy on test set # detection accuracy on test set
test_res, test_loss = tm.test( test_res, test_loss = tm.test(
model, epoch, test_loader, det_criterion, params model,
epoch,
test_loader,
det_criterion,
params,
) )
test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]]) test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
test_plt.update_and_save( test_plt.update_and_save(
@ -301,18 +278,106 @@ if __name__ == "__main__":
epoch, [rs["avg_prec"] for rs in test_res["class_pr"]] epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
) )
pu.plot_pr_curve_class( pu.plot_pr_curve_class(
params["experiment"], "test_pr", "test_pr", test_res params.experiment, "test_pr", "test_pr", test_res
) )
# save finetuned model # save finetuned model
print("saving model to: " + params["model_file_name"]) print(f"saving model to: {params.model_file_name}")
op_state = { op_state = {
"epoch": epoch + 1, "epoch": epoch + 1,
"state_dict": model.state_dict(), "state_dict": model.state_dict(),
"params": params, "params": params,
} }
torch.save(op_state, params["model_file_name"]) torch.save(op_state, params.model_file_name)
# save an image with associated prediction for each batch in the test set # save an image with associated prediction for each batch in the test set
if not args["do_not_save_images"]: if save_images:
tm.save_images_batch(model, test_loader, params) tm.save_images_batch(model, test_loader, params)
def main():
info_str = "\nBatDetect - Finetune Model\n"
print(info_str)
args = parse_arguments()
# Load experiment parameters
params = parameters.get_params(
make_dirs=True,
exps_dir=args.experiment_dir,
device=select_device(),
num_epochs=args.num_epochs,
notes=args.notes,
)
print("\nAudio directory: " + args.audio_path)
print("Train file: " + args.train_ann_path)
print("Test file: " + args.test_ann_path)
print("Loading model: " + args.model_path)
if args.train_from_scratch:
print(
"\nTraining model from scratch i.e. not using pretrained weights"
)
model, model_params = du.load_model(
args.model_path,
load_weights=not args.train_from_scratch,
device=params.device,
)
if args.op_model_name != "":
params.model_file_name = args.op_model_name
classes_to_ignore = params.classes_to_ignore + params.generic_class
# save notes file
if params.notes:
tu.write_notes_file(
params.experiment / "notes.txt",
args.notes,
)
# NOTE: the dataset name is derived from the train annotation file name
dataset_name = (
os.path.basename(args.train_ann_path)
.replace(".json", "")
.replace("_TRAIN", "")
)
# ==== LOAD DATA ====
# load train annotations
data_train = load_annotations(
dataset_name,
args.train_ann_path,
args.audio_path,
classes_to_ignore,
params.events_of_interest,
)
print("\nTrain set:")
print("Number of files", len(data_train))
# load test annotations
data_test = load_annotations(
dataset_name,
args.test_ann_path,
args.audio_path,
classes_to_ignore,
params.events_of_interest,
)
print("\nTrain set:")
print("Number of files", len(data_train))
finetune_model(
model,
data_train,
data_test,
params,
model_params,
finetune_only_last_layer=args.finetune_only_last_layer,
save_images=not args.do_not_save_images,
)
if __name__ == "__main__":
main()

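When finetune_only_last_layer is set, the script above keeps gradients only for parameters whose names match a small allow-list (the train_layers_i prefixes) and freezes everything else. A minimal, self-contained sketch of that freezing pattern with a throwaway model; the layer names are illustrative, not the real architecture:

import torch

model = torch.nn.Sequential()
model.add_module("conv_dn_0", torch.nn.Conv2d(1, 4, kernel_size=3))
model.add_module("conv_classes_op", torch.nn.Conv2d(4, 3, kernel_size=1))

train_layers = ["conv_classes"]  # substrings of parameter names to keep trainable
for name, param in model.named_parameters():
    param.requires_grad = any(layer in name for layer in train_layers)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
assert trainable == ["conv_classes_op.weight", "conv_classes_op.bias"]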
View File

@ -1,62 +1,54 @@
import argparse import argparse
import json import json
import os import os
from collections import Counter
from typing import List, Optional, Tuple
import numpy as np import numpy as np
from sklearn.model_selection import StratifiedGroupKFold
import batdetect2.train.train_utils as tu import batdetect2.train.train_utils as tu
from batdetect2 import types
def print_dataset_stats(data, split_name, classes_to_ignore): def print_dataset_stats(
print("\nSplit:", split_name) data: List[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None,
) -> Counter[str]:
print("Num files:", len(data)) print("Num files:", len(data))
counts, _ = tu.get_class_names(data, classes_to_ignore)
class_cnts = {} if len(counts) > 0:
for dd in data: tu.report_class_counts(counts)
for aa in dd["annotation"]: return counts
if aa["class"] not in classes_to_ignore:
if aa["class"] in class_cnts:
class_cnts[aa["class"]] += 1
else:
class_cnts[aa["class"]] = 1
if len(class_cnts) == 0:
class_names = []
else:
class_names = np.sort([*class_cnts]).tolist()
print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for ii, cc in enumerate(class_names):
print(str(ii).ljust(5) + cc.ljust(str_len) + str(class_cnts[cc]))
return class_names
def load_file_names(file_name): def load_file_names(file_name: str) -> List[str]:
if os.path.isfile(file_name): if not os.path.isfile(file_name):
with open(file_name) as da: raise FileNotFoundError(f"Input file not found - {file_name}")
files = [line.rstrip() for line in da.readlines()]
for ff in files: with open(file_name) as da:
if ff.lower()[-3:] != "wav": files = [line.rstrip() for line in da.readlines()]
print("Error: Filenames need to end in .wav - ", ff)
assert False for path in files:
else: if path.lower()[-3:] != "wav":
print("Error: Input file not found - ", file_name) raise ValueError(
assert False f"Invalid file name - {path}. Must be a .wav file"
)
return files return files
if __name__ == "__main__": def parse_args():
info_str = "\nBatDetect - Prepare Data for Finetuning\n" info_str = "\nBatDetect - Prepare Data for Finetuning\n"
print(info_str) print(info_str)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"dataset_name", type=str, help="Name to call your dataset" "dataset_name", type=str, help="Name to call your dataset"
) )
parser.add_argument("audio_dir", type=str, help="Input directory for audio") parser.add_argument(
"audio_dir", type=str, help="Input directory for audio"
)
parser.add_argument( parser.add_argument(
"ann_dir", "ann_dir",
type=str, type=str,
@ -102,88 +94,126 @@ if __name__ == "__main__":
type=str, type=str,
default="", default="",
help='New class names to use instead. One to one mapping with "--input_class_names". \ help='New class names to use instead. One to one mapping with "--input_class_names". \
Separate with ";"', Separate with ";"',
) )
args = vars(parser.parse_args()) return parser.parse_args()
np.random.seed(args["rand_seed"])
def split_data(
data: List[types.FileAnnotation],
train_file: str,
test_file: str,
n_splits: int = 5,
random_state: int = 0,
) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]:
if train_file != "" and test_file != "":
# user has specified the train / test split
mapping = {
file_annotation["id"]: file_annotation for file_annotation in data
}
train_files = load_file_names(train_file)
test_files = load_file_names(test_file)
data_train = [
mapping[file_id] for file_id in train_files if file_id in mapping
]
data_test = [
mapping[file_id] for file_id in test_files if file_id in mapping
]
return data_train, data_test
# NOTE: Using StratifiedGroupKFold to ensure that the same file does not
# appear in both the training and test sets and trying to keep the
# distribution of classes the same in both sets.
splitter = StratifiedGroupKFold(
n_splits=n_splits,
shuffle=True,
random_state=random_state,
)
anns = np.array(
[
[dd["id"], ann["class"], ann["event"]]
for dd in data
for ann in dd["annotation"]
]
)
y = anns[:, 1]
group = anns[:, 0]
train_idx, test_idx = next(splitter.split(X=anns, y=y, groups=group))
train_ids = set(anns[train_idx, 0])
test_ids = set(anns[test_idx, 0])
assert not (train_ids & test_ids)
data_train = [dd for dd in data if dd["id"] in train_ids]
data_test = [dd for dd in data if dd["id"] in test_ids]
return data_train, data_test
def main():
args = parse_args()
np.random.seed(args.rand_seed)
classes_to_ignore = ["", " ", "Unknown", "Not Bat"] classes_to_ignore = ["", " ", "Unknown", "Not Bat"]
generic_class = ["Bat"]
events_of_interest = ["Echolocation"] events_of_interest = ["Echolocation"]
if args["input_class_names"] != "" and args["output_class_names"] != "": name_dict = None
if args.input_class_names != "" and args.output_class_names != "":
# change the names of the classes # change the names of the classes
ip_names = args["input_class_names"].split(";") ip_names = args.input_class_names.split(";")
op_names = args["output_class_names"].split(";") op_names = args.output_class_names.split(";")
name_dict = dict(zip(ip_names, op_names)) name_dict = dict(zip(ip_names, op_names))
else:
name_dict = False
# load annotations # load annotations
data_all, _, _ = tu.load_set_of_anns( data_all = tu.load_set_of_anns(
{"ann_path": args["ann_dir"], "wav_path": args["audio_dir"]}, [
classes_to_ignore, {
events_of_interest, "dataset_name": args.dataset_name,
False, "ann_path": args.ann_dir,
False, "wav_path": args.audio_dir,
list_of_anns=True, "is_test": False,
"is_binary": False,
}
],
classes_to_ignore=classes_to_ignore,
events_of_interest=events_of_interest,
convert_to_genus=False,
filter_issues=True, filter_issues=True,
name_replace=name_dict, name_replace=name_dict,
) )
print("Dataset name: " + args["dataset_name"]) print("Dataset name: " + args.dataset_name)
print("Audio directory: " + args["audio_dir"]) print("Audio directory: " + args.audio_dir)
print("Annotation directory: " + args["ann_dir"]) print("Annotation directory: " + args.ann_dir)
print("Ouput directory: " + args["op_dir"]) print("Ouput directory: " + args.op_dir)
print("Num annotated files: " + str(len(data_all))) print("Num annotated files: " + str(len(data_all)))
if args["train_file"] != "" and args["test_file"] != "": data_train, data_test = split_data(
# user has specifed the train / test split data=data_all,
train_files = load_file_names(args["train_file"]) train_file=args.train_file,
test_files = load_file_names(args["test_file"]) test_file=args.test_file,
file_names_all = [dd["id"] for dd in data_all] n_splits=5,
train_inds = [ random_state=args.rand_seed,
file_names_all.index(ff) )
for ff in train_files
if ff in file_names_all
]
test_inds = [
file_names_all.index(ff)
for ff in test_files
if ff in file_names_all
]
else: if not os.path.isdir(args.op_dir):
# split the data into train and test at the file level os.makedirs(args.op_dir)
num_exs = len(data_all) op_name = os.path.join(args.op_dir, args.dataset_name)
test_inds = np.random.choice(
np.arange(num_exs),
int(num_exs * args["percent_val"]),
replace=False,
)
test_inds = np.sort(test_inds)
train_inds = np.setdiff1d(np.arange(num_exs), test_inds)
data_train = [data_all[ii] for ii in train_inds]
data_test = [data_all[ii] for ii in test_inds]
if not os.path.isdir(args["op_dir"]):
os.makedirs(args["op_dir"])
op_name = os.path.join(args["op_dir"], args["dataset_name"])
op_name_train = op_name + "_TRAIN.json" op_name_train = op_name + "_TRAIN.json"
op_name_test = op_name + "_TEST.json" op_name_test = op_name + "_TEST.json"
class_un_train = print_dataset_stats(data_train, "Train", classes_to_ignore) print("\nSplit: Train")
class_un_test = print_dataset_stats(data_test, "Test", classes_to_ignore) class_un_train = print_dataset_stats(data_train, classes_to_ignore)
print("\nSplit: Test")
class_un_test = print_dataset_stats(data_test, classes_to_ignore)
if len(data_train) > 0 and len(data_test) > 0: if len(data_train) > 0 and len(data_test) > 0:
if class_un_train != class_un_test: if set(class_un_train.keys()) != set(class_un_test.keys()):
print( raise RuntimeError(
'\nError: some classes are not in both the training and test sets.\ "Error: some classes are not in both the training and test sets."
\nTry a different random seed "--rand_seed".' 'Try a different random seed "--rand_seed".'
) )
assert False
print("\n") print("\n")
if len(data_train) == 0: if len(data_train) == 0:
@ -199,3 +229,7 @@ if __name__ == "__main__":
print("Saving: ", op_name_test) print("Saving: ", op_name_test)
with open(op_name_test, "w") as da: with open(op_name_test, "w") as da:
json.dump(data_test, da, indent=2) json.dump(data_test, da, indent=2)
if __name__ == "__main__":
main()

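The new split_data above relies on StratifiedGroupKFold so that every annotation from one recording lands entirely in train or entirely in test while the class balance is roughly preserved. A self-contained sketch of that idea with made-up file ids and classes:

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

# one row per annotation: (file id used as the group, class label)
anns = np.array([
    ["file_a", "Myotis"],
    ["file_a", "Myotis"],
    ["file_b", "Pipistrellus"],
    ["file_c", "Myotis"],
    ["file_d", "Pipistrellus"],
])
splitter = StratifiedGroupKFold(n_splits=2, shuffle=True, random_state=0)
train_idx, test_idx = next(splitter.split(X=anns, y=anns[:, 1], groups=anns[:, 0]))

# no file ever appears on both sides of the split
assert not set(anns[train_idx, 0]) & set(anns[test_idx, 0])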
View File

@ -12,19 +12,24 @@ import torchaudio
import batdetect2.utils.audio_utils as au import batdetect2.utils.audio_utils as au
from batdetect2.types import ( from batdetect2.types import (
Annotation, Annotation,
AnnotationGroup,
AudioLoaderAnnotationGroup, AudioLoaderAnnotationGroup,
FileAnnotations, AudioLoaderParameters,
HeatmapParameters, FileAnnotation,
) )
def generate_gt_heatmaps( def generate_gt_heatmaps(
spec_op_shape: Tuple[int, int], spec_op_shape: Tuple[int, int],
sampling_rate: int, sampling_rate: float,
ann: AnnotationGroup, ann: AudioLoaderAnnotationGroup,
params: HeatmapParameters, class_names: List[str],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]: fft_win_length: float,
fft_overlap: float,
max_freq: float,
min_freq: float,
resize_factor: float,
target_sigma: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AudioLoaderAnnotationGroup]:
"""Generate ground truth heatmaps from annotations. """Generate ground truth heatmaps from annotations.
Parameters Parameters
@ -53,31 +58,31 @@ def generate_gt_heatmaps(
the x and y indices of their pixel location in the input spectrogram. the x and y indices of their pixel location in the input spectrogram.
""" """
# spec may be resized on input into the network # spec may be resized on input into the network
num_classes = len(params["class_names"]) num_classes = len(class_names)
op_height = spec_op_shape[0] op_height = spec_op_shape[0]
op_width = spec_op_shape[1] op_width = spec_op_shape[1]
freq_per_bin = (params["max_freq"] - params["min_freq"]) / op_height freq_per_bin = (max_freq - min_freq) / op_height
# start and end times # start and end times
x_pos_start = au.time_to_x_coords( x_pos_start = au.time_to_x_coords(
ann["start_times"], ann["start_times"],
sampling_rate, sampling_rate,
params["fft_win_length"], fft_win_length,
params["fft_overlap"], fft_overlap,
) )
x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int32) x_pos_start = (resize_factor * x_pos_start).astype(np.int32)
x_pos_end = au.time_to_x_coords( x_pos_end = au.time_to_x_coords(
ann["end_times"], ann["end_times"],
sampling_rate, sampling_rate,
params["fft_win_length"], fft_win_length,
params["fft_overlap"], fft_overlap,
) )
x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int32) x_pos_end = (resize_factor * x_pos_end).astype(np.int32)
# location on y axis i.e. frequency # location on y axis i.e. frequency
y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin y_pos_low = (ann["low_freqs"] - min_freq) / freq_per_bin
y_pos_low = (op_height - y_pos_low).astype(np.int32) y_pos_low = (op_height - y_pos_low).astype(np.int32)
y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin y_pos_high = (ann["high_freqs"] - min_freq) / freq_per_bin
y_pos_high = (op_height - y_pos_high).astype(np.int32) y_pos_high = (op_height - y_pos_high).astype(np.int32)
bb_widths = x_pos_end - x_pos_start bb_widths = x_pos_end - x_pos_start
bb_heights = y_pos_low - y_pos_high bb_heights = y_pos_low - y_pos_high
@ -90,26 +95,17 @@ def generate_gt_heatmaps(
& (y_pos_low < (op_height - 1)) & (y_pos_low < (op_height - 1))
)[0] )[0]
ann_aug: AnnotationGroup = { ann_aug: AudioLoaderAnnotationGroup = {
**ann,
"start_times": ann["start_times"][valid_inds], "start_times": ann["start_times"][valid_inds],
"end_times": ann["end_times"][valid_inds], "end_times": ann["end_times"][valid_inds],
"high_freqs": ann["high_freqs"][valid_inds], "high_freqs": ann["high_freqs"][valid_inds],
"low_freqs": ann["low_freqs"][valid_inds], "low_freqs": ann["low_freqs"][valid_inds],
"class_ids": ann["class_ids"][valid_inds], "class_ids": ann["class_ids"][valid_inds],
"individual_ids": ann["individual_ids"][valid_inds], "individual_ids": ann["individual_ids"][valid_inds],
"x_inds": x_pos_start[valid_inds],
"y_inds": y_pos_low[valid_inds],
} }
ann_aug["x_inds"] = x_pos_start[valid_inds]
ann_aug["y_inds"] = y_pos_low[valid_inds]
# keys = [
# "start_times",
# "end_times",
# "high_freqs",
# "low_freqs",
# "class_ids",
# "individual_ids",
# ]
# for kk in keys:
# ann_aug[kk] = ann[kk][valid_inds]
# if the number of calls is only 1, then it is unique # if the number of calls is only 1, then it is unique
# TODO would be better if we found these unique calls at the merging stage # TODO would be better if we found these unique calls at the merging stage
@ -118,6 +114,7 @@ def generate_gt_heatmaps(
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32) y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32) y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
# num classes and "background" class # num classes and "background" class
y_2d_classes: np.ndarray = np.zeros( y_2d_classes: np.ndarray = np.zeros(
(num_classes + 1, op_height, op_width), dtype=np.float32 (num_classes + 1, op_height, op_width), dtype=np.float32
@ -128,14 +125,8 @@ def generate_gt_heatmaps(
draw_gaussian( draw_gaussian(
y_2d_det[0, :], y_2d_det[0, :],
(x_pos_start[ii], y_pos_low[ii]), (x_pos_start[ii], y_pos_low[ii]),
params["target_sigma"], target_sigma,
) )
# draw_gaussian(
# y_2d_det[0, :],
# (x_pos_start[ii], y_pos_low[ii]),
# params["target_sigma"],
# params["target_sigma"] * 2,
# )
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii] y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii] y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
@ -144,14 +135,8 @@ def generate_gt_heatmaps(
draw_gaussian( draw_gaussian(
y_2d_classes[cls_id, :], y_2d_classes[cls_id, :],
(x_pos_start[ii], y_pos_low[ii]), (x_pos_start[ii], y_pos_low[ii]),
params["target_sigma"], target_sigma,
) )
# draw_gaussian(
# y_2d_classes[cls_id, :],
# (x_pos_start[ii], y_pos_low[ii]),
# params["target_sigma"],
# params["target_sigma"] * 2,
# )
# be careful as this will have a 1.0 places where we have event but # be careful as this will have a 1.0 places where we have event but
# dont know gt class this will be masked in training anyway # dont know gt class this will be masked in training anyway
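For reference, the frequency-to-row mapping used above works out like this on made-up numbers (a 128-row output spanning 10-120 kHz, a call whose low frequency is 40 kHz):

op_height = 128
min_freq, max_freq = 10_000, 120_000
freq_per_bin = (max_freq - min_freq) / op_height          # 859.375 Hz per row
y_pos_low = int(op_height - (40_000 - min_freq) / freq_per_bin)
assert y_pos_low == 93  # rows count down from the top of the spectrogram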
@ -235,8 +220,8 @@ def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:
def warp_spec_aug( def warp_spec_aug(
spec: torch.Tensor, spec: torch.Tensor,
ann: AnnotationGroup, ann: AudioLoaderAnnotationGroup,
params: dict, stretch_squeeze_delta: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Warp spectrogram by randomly stretching and squeezing. """Warp spectrogram by randomly stretching and squeezing.
@ -247,8 +232,8 @@ def warp_spec_aug(
ann: AnnotationGroup ann: AnnotationGroup
Annotation group for the spectrogram. Must be provided to sync Annotation group for the spectrogram. Must be provided to sync
the start and stop times with the spectrogram after warping. the start and stop times with the spectrogram after warping.
params: dict stretch_squeeze_delta: float
Parameters for the augmentation. Maximum amount to stretch or squeeze the spectrogram.
Returns Returns
------- -------
@ -259,11 +244,10 @@ def warp_spec_aug(
----- -----
This function modifies the annotation group in place. This function modifies the annotation group in place.
""" """
# This is messy
# Augment spectrogram by randomly stretching and squeezing # Augment spectrogram by randomly stretching and squeezing
# NOTE this also changes the start and stop time in place # NOTE this also changes the start and stop time in place
delta = params["stretch_squeeze_delta"] delta = stretch_squeeze_delta
op_size = (spec.shape[1], spec.shape[2]) op_size = (spec.shape[1], spec.shape[2])
resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0 resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
resize_amt = int(spec.shape[2] * resize_fract_r) resize_amt = int(spec.shape[2] * resize_fract_r)
@ -277,7 +261,7 @@ def warp_spec_aug(
dtype=spec.dtype, dtype=spec.dtype,
), ),
), ),
2, dim=2,
) )
else: else:
spec_r = spec[:, :, :resize_amt] spec_r = spec[:, :, :resize_amt]
@ -297,7 +281,10 @@ def warp_spec_aug(
return spec return spec
def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: def mask_time_aug(
spec: torch.Tensor,
mask_max_time_perc: float,
) -> torch.Tensor:
"""Mask out random blocks of time. """Mask out random blocks of time.
Will randomly mask out a block of time in the spectrogram. The block Will randomly mask out a block of time in the spectrogram. The block
@ -308,8 +295,8 @@ def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
---------- ----------
spec: torch.Tensor spec: torch.Tensor
Spectrogram to mask. Spectrogram to mask.
params: dict mask_max_time_perc: float
Parameters for the augmentation. Maximum percentage of time to mask out.
Returns Returns
------- -------
@ -324,14 +311,17 @@ def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
Recognition Recognition
""" """
fm = torchaudio.transforms.TimeMasking( fm = torchaudio.transforms.TimeMasking(
int(spec.shape[1] * params["mask_max_time_perc"]) int(spec.shape[1] * mask_max_time_perc)
) )
for _ in range(np.random.randint(1, 4)): for _ in range(np.random.randint(1, 4)):
spec = fm(spec) spec = fm(spec)
return spec return spec
def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: def mask_freq_aug(
spec: torch.Tensor,
mask_max_freq_perc: float,
) -> torch.Tensor:
"""Mask out random blocks of frequency. """Mask out random blocks of frequency.
Will randomly mask out a block of frequency in the spectrogram. The block Will randomly mask out a block of frequency in the spectrogram. The block
@ -342,8 +332,8 @@ def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
---------- ----------
spec: torch.Tensor spec: torch.Tensor
Spectrogram to mask. Spectrogram to mask.
params: dict mask_max_freq_perc: float
Parameters for the augmentation. Maximum percentage of frequency to mask out.
Returns Returns
------- -------
@ -358,41 +348,48 @@ def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
Recognition Recognition
""" """
fm = torchaudio.transforms.FrequencyMasking( fm = torchaudio.transforms.FrequencyMasking(
int(spec.shape[1] * params["mask_max_freq_perc"]) int(spec.shape[1] * mask_max_freq_perc)
) )
for _ in range(np.random.randint(1, 4)): for _ in range(np.random.randint(1, 4)):
spec = fm(spec) spec = fm(spec)
return spec return spec
def scale_vol_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: def scale_vol_aug(
spec: torch.Tensor,
spec_amp_scaling: float,
) -> torch.Tensor:
"""Scale the volume of the spectrogram. """Scale the volume of the spectrogram.
Parameters Parameters
---------- ----------
spec: torch.Tensor spec: torch.Tensor
Spectrogram to scale. Spectrogram to scale.
params: dict spec_amp_scaling: float
Parameters for the augmentation. Maximum scaling factor.
Returns Returns
------- -------
torch.Tensor torch.Tensor
""" """
return spec * np.random.random() * params["spec_amp_scaling"] return spec * np.random.random() * spec_amp_scaling
def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray: def echo_aug(
audio: np.ndarray,
sampling_rate: float,
echo_max_delay: float,
) -> np.ndarray:
"""Add echo to audio. """Add echo to audio.
Parameters Parameters
---------- ----------
audio: np.ndarray audio: np.ndarray
Audio to add echo to. Audio to add echo to.
sampling_rate: int sampling_rate: float
Sampling rate of the audio. Sampling rate of the audio.
params: dict echo_max_delay: float
Parameters for the augmentation. Maximum delay of the echo in seconds.
Returns Returns
------- -------
@ -400,7 +397,7 @@ def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
Audio with echo added. Audio with echo added.
""" """
sample_offset = ( sample_offset = (
int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1 int(echo_max_delay * np.random.random() * sampling_rate) + 1
) )
audio[:-sample_offset] += np.random.random() * audio[sample_offset:] audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
return audio return audio
@ -408,9 +405,14 @@ def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
def resample_aug( def resample_aug(
audio: np.ndarray, audio: np.ndarray,
sampling_rate: int, sampling_rate: float,
params: dict, fft_win_length: float,
) -> Tuple[np.ndarray, int, float]: fft_overlap: float,
resize_factor: float,
spec_divide_factor: float,
spec_train_width: int,
aug_sampling_rates: List[int],
) -> Tuple[np.ndarray, float, float]:
"""Resample audio augmentation. """Resample audio augmentation.
Will resample the audio to a random sampling rate from the list of Will resample the audio to a random sampling rate from the list of
@ -420,23 +422,32 @@ def resample_aug(
---------- ----------
audio: np.ndarray audio: np.ndarray
Audio to resample. Audio to resample.
sampling_rate: int sampling_rate: float
Original sampling rate of the audio. Original sampling rate of the audio.
params: dict fft_win_length: float
Parameters for the augmentation. Includes the list of sampling rates Length of the FFT window in seconds.
to choose from for resampling in `aug_sampling_rates`. fft_overlap: float
Amount of overlap between FFT windows.
resize_factor: float
Factor to resize the spectrogram by.
spec_divide_factor: float
Factor to divide the spectrogram by.
spec_train_width: int
Width of the spectrogram.
aug_sampling_rates: List[int]
List of sampling rates to resample to.
Returns Returns
------- -------
audio : np.ndarray audio : np.ndarray
Resampled audio. Resampled audio.
sampling_rate : int sampling_rate : float
New sampling rate. New sampling rate.
duration : float duration : float
Duration of the audio in seconds. Duration of the audio in seconds.
""" """
sampling_rate_old = sampling_rate sampling_rate_old = sampling_rate
sampling_rate = np.random.choice(params["aug_sampling_rates"]) sampling_rate = np.random.choice(aug_sampling_rates)
audio = librosa.resample( audio = librosa.resample(
audio, audio,
orig_sr=sampling_rate_old, orig_sr=sampling_rate_old,
@ -447,11 +458,11 @@ def resample_aug(
audio = au.pad_audio( audio = au.pad_audio(
audio, audio,
sampling_rate, sampling_rate,
params["fft_win_length"], fft_win_length,
params["fft_overlap"], fft_overlap,
params["resize_factor"], resize_factor,
params["spec_divide_factor"], spec_divide_factor,
params["spec_train_width"], spec_train_width,
) )
duration = audio.shape[0] / float(sampling_rate) duration = audio.shape[0] / float(sampling_rate)
return audio, sampling_rate, duration return audio, sampling_rate, duration
@ -459,28 +470,28 @@ def resample_aug(
def resample_audio( def resample_audio(
num_samples: int, num_samples: int,
sampling_rate: int, sampling_rate: float,
audio2: np.ndarray, audio2: np.ndarray,
sampling_rate2: int, sampling_rate2: float,
) -> Tuple[np.ndarray, int]: ) -> Tuple[np.ndarray, float]:
"""Resample audio. """Resample audio.
Parameters Parameters
---------- ----------
num_samples: int num_samples: int
Expected number of samples for the output audio. Expected number of samples for the output audio.
sampling_rate: int sampling_rate: float
Original sampling rate of the audio. Original sampling rate of the audio.
audio2: np.ndarray audio2: np.ndarray
Audio to resample. Audio to resample.
sampling_rate2: int sampling_rate2: float
Target sampling rate of the audio. Target sampling rate of the audio.
Returns Returns
------- -------
audio2 : np.ndarray audio2 : np.ndarray
Resampled audio. Resampled audio.
sampling_rate2 : int sampling_rate2 : float
New sampling rate. New sampling rate.
""" """
# resample to target sampling rate # resample to target sampling rate
@ -509,12 +520,12 @@ def resample_audio(
def combine_audio_aug( def combine_audio_aug(
audio: np.ndarray, audio: np.ndarray,
sampling_rate: int, sampling_rate: float,
ann: AnnotationGroup, ann: AudioLoaderAnnotationGroup,
audio2: np.ndarray, audio2: np.ndarray,
sampling_rate2: int, sampling_rate2: float,
ann2: AnnotationGroup, ann2: AudioLoaderAnnotationGroup,
) -> Tuple[np.ndarray, AnnotationGroup]: ) -> Tuple[np.ndarray, AudioLoaderAnnotationGroup]:
"""Combine two audio files. """Combine two audio files.
Will combine two audio files by resampling them to the same sampling rate Will combine two audio files by resampling them to the same sampling rate
@ -570,7 +581,9 @@ def combine_audio_aug(
# from different individuals # from different individuals
if kk == "individual_ids": if kk == "individual_ids":
if (ann[kk] > -1).sum() > 0: if (ann[kk] > -1).sum() > 0:
ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1 ann2[kk][ann2[kk] > -1] += (
np.max(ann[kk][ann[kk] > -1]) + 1
)
if (kk != "class_id_file") and (kk != "annotated"): if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds] ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
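The `individual_ids` handling above offsets the second file's ids so that individuals stay distinct after merging; a small self-contained sketch of that trick (names illustrative):

import numpy as np

def merge_individual_ids(ids_a, ids_b):
    # shift the second group's ids past the first group's maximum so that
    # individuals from different recordings never collide (-1 means unknown)
    ids_b = ids_b.copy()
    if (ids_a > -1).any():
        ids_b[ids_b > -1] += ids_a[ids_a > -1].max() + 1
    return np.concatenate([ids_a, ids_b])

# merge_individual_ids(np.array([0, 1, -1]), np.array([0, -1]))
# -> array([ 0,  1, -1,  2, -1])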
@ -579,7 +592,8 @@ def combine_audio_aug(
def _prepare_annotation( def _prepare_annotation(
annotation: Annotation, class_names: List[str] annotation: Annotation,
class_names: List[str],
) -> Annotation: ) -> Annotation:
try: try:
class_id = class_names.index(annotation["class"]) class_id = class_names.index(annotation["class"])
@ -598,7 +612,7 @@ def _prepare_annotation(
def _prepare_file_annotation( def _prepare_file_annotation(
annotation: FileAnnotations, annotation: FileAnnotation,
class_names: List[str], class_names: List[str],
classes_to_ignore: List[str], classes_to_ignore: List[str],
) -> AudioLoaderAnnotationGroup: ) -> AudioLoaderAnnotationGroup:
@ -626,7 +640,9 @@ def _prepare_file_annotation(
"end_times": np.array([ann["end_time"] for ann in annotations]), "end_times": np.array([ann["end_time"] for ann in annotations]),
"high_freqs": np.array([ann["high_freq"] for ann in annotations]), "high_freqs": np.array([ann["high_freq"] for ann in annotations]),
"low_freqs": np.array([ann["low_freq"] for ann in annotations]), "low_freqs": np.array([ann["low_freq"] for ann in annotations]),
"class_ids": np.array([ann.get("class_id", -1) for ann in annotations]), "class_ids": np.array(
[ann.get("class_id", -1) for ann in annotations]
),
"individual_ids": np.array([ann["individual"] for ann in annotations]), "individual_ids": np.array([ann["individual"] for ann in annotations]),
"class_id_file": class_id_file, "class_id_file": class_id_file,
} }
@ -639,15 +655,15 @@ class AudioLoader(torch.utils.data.Dataset):
def __init__( def __init__(
self, self,
data_anns_ip: List[FileAnnotations], data_anns_ip: List[FileAnnotation],
params, params: AudioLoaderParameters,
dataset_name: Optional[str] = None, dataset_name: Optional[str] = None,
is_train: bool = False, is_train: bool = False,
return_spec_for_viz: bool = False,
): ):
self.is_train: bool = is_train self.is_train = is_train
self.params: dict = params self.params = params
self.return_spec_for_viz: bool = False self.return_spec_for_viz = return_spec_for_viz
self.data_anns: List[AudioLoaderAnnotationGroup] = [ self.data_anns: List[AudioLoaderAnnotationGroup] = [
_prepare_file_annotation( _prepare_file_annotation(
ann, ann,
@ -657,61 +673,6 @@ class AudioLoader(torch.utils.data.Dataset):
for ann in data_anns_ip for ann in data_anns_ip
] ]
# for ii in range(len(data_anns_ip)):
# dd = copy.deepcopy(data_anns_ip[ii])
#
# # filter out unused annotation here
# filtered_annotations = []
# for ii, aa in enumerate(dd["annotation"]):
# if "individual" in aa.keys():
# aa["individual"] = int(aa["individual"])
#
# # if only one call labeled it has to be from the same
# # individual
# if len(dd["annotation"]) == 1:
# aa["individual"] = 0
#
# # convert class name into class label
# if aa["class"] in self.params["class_names"]:
# aa["class_id"] = self.params["class_names"].index(
# aa["class"]
# )
# else:
# aa["class_id"] = -1
#
# if aa["class"] not in self.params["classes_to_ignore"]:
# filtered_annotations.append(aa)
#
# dd["annotation"] = filtered_annotations
# dd["start_times"] = np.array(
# [aa["start_time"] for aa in dd["annotation"]]
# )
# dd["end_times"] = np.array(
# [aa["end_time"] for aa in dd["annotation"]]
# )
# dd["high_freqs"] = np.array(
# [float(aa["high_freq"]) for aa in dd["annotation"]]
# )
# dd["low_freqs"] = np.array(
# [float(aa["low_freq"]) for aa in dd["annotation"]]
# )
# dd["class_ids"] = np.array(
# [aa["class_id"] for aa in dd["annotation"]]
# ).astype(np.int32)
# dd["individual_ids"] = np.array(
# [aa["individual"] for aa in dd["annotation"]]
# ).astype(np.int32)
#
# # file level class name
# dd["class_id_file"] = -1
# if "class_name" in dd.keys():
# if dd["class_name"] in self.params["class_names"]:
# dd["class_id_file"] = self.params["class_names"].index(
# dd["class_name"]
# )
#
# self.data_anns.append(dd)
ann_cnt = [len(aa["annotation"]) for aa in self.data_anns] ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
self.max_num_anns = 2 * np.max( self.max_num_anns = 2 * np.max(
ann_cnt ann_cnt
@ -730,7 +691,7 @@ class AudioLoader(torch.utils.data.Dataset):
def get_file_and_anns( def get_file_and_anns(
self, self,
index: Optional[int] = None, index: Optional[int] = None,
) -> Tuple[np.ndarray, int, float, AudioLoaderAnnotationGroup]: ) -> Tuple[np.ndarray, float, float, AudioLoaderAnnotationGroup]:
"""Get an audio file and its annotations. """Get an audio file and its annotations.
Parameters Parameters
@ -742,7 +703,7 @@ class AudioLoader(torch.utils.data.Dataset):
------- -------
audio_raw : np.ndarray audio_raw : np.ndarray
Loaded audio file. Loaded audio file.
sampling_rate : int sampling_rate : float
Sampling rate of the audio file. Sampling rate of the audio file.
duration : float duration : float
Duration of the audio file in seconds. Duration of the audio file in seconds.
@ -837,7 +798,7 @@ class AudioLoader(torch.utils.data.Dataset):
( (
audio2, audio2,
sampling_rate2, sampling_rate2,
duration2, _,
ann2, ann2,
) = self.get_file_and_anns() ) = self.get_file_and_anns()
audio, ann = combine_audio_aug( audio, ann = combine_audio_aug(
@ -846,7 +807,11 @@ class AudioLoader(torch.utils.data.Dataset):
# simulate echo by adding delayed copy of the file # simulate echo by adding delayed copy of the file
if np.random.random() < self.params["aug_prob"]: if np.random.random() < self.params["aug_prob"]:
audio = echo_aug(audio, sampling_rate, self.params) audio = echo_aug(
audio,
sampling_rate,
echo_max_delay=self.params["echo_max_delay"],
)
# resample the audio # resample the audio
# if np.random.random() < self.params["aug_prob"]: # if np.random.random() < self.params["aug_prob"]:
@ -855,11 +820,16 @@ class AudioLoader(torch.utils.data.Dataset):
# ) # )
# create spectrogram # create spectrogram
spec, spec_for_viz = au.generate_spectrogram( spec = au.generate_spectrogram(
audio, audio,
sampling_rate, sampling_rate,
self.params, fft_win_length=self.params["fft_win_length"],
self.return_spec_for_viz, fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
min_freq=self.params["min_freq"],
spec_scale=self.params["spec_scale"],
denoise_spec_avg=self.params["denoise_spec_avg"],
max_scale_spec=self.params["max_scale_spec"],
) )
rsf = self.params["resize_factor"] rsf = self.params["resize_factor"]
spec_op_shape = ( spec_op_shape = (
@ -879,20 +849,29 @@ class AudioLoader(torch.utils.data.Dataset):
# augment spectrogram # augment spectrogram
if self.is_train and self.params["augment_at_train"]: if self.is_train and self.params["augment_at_train"]:
if np.random.random() < self.params["aug_prob"]: if np.random.random() < self.params["aug_prob"]:
spec = scale_vol_aug(spec, self.params) spec = scale_vol_aug(
spec,
spec_amp_scaling=self.params["spec_amp_scaling"],
)
if np.random.random() < self.params["aug_prob"]: if np.random.random() < self.params["aug_prob"]:
spec = warp_spec_aug( spec = warp_spec_aug(
spec, spec,
ann, ann,
self.params, stretch_squeeze_delta=self.params["stretch_squeeze_delta"],
) )
if np.random.random() < self.params["aug_prob"]: if np.random.random() < self.params["aug_prob"]:
spec = mask_time_aug(spec, self.params) spec = mask_time_aug(
spec,
mask_max_time_perc=self.params["mask_max_time_perc"],
)
if np.random.random() < self.params["aug_prob"]: if np.random.random() < self.params["aug_prob"]:
spec = mask_freq_aug(spec, self.params) spec = mask_freq_aug(
spec,
mask_max_freq_perc=self.params["mask_max_freq_perc"],
)
outputs = {} outputs = {}
outputs["spec"] = spec outputs["spec"] = spec
@ -911,7 +890,13 @@ class AudioLoader(torch.utils.data.Dataset):
spec_op_shape, spec_op_shape,
sampling_rate, sampling_rate,
ann, ann,
self.params, class_names=self.params["class_names"],
fft_win_length=self.params["fft_win_length"],
fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
min_freq=self.params["min_freq"],
resize_factor=self.params["resize_factor"],
target_sigma=self.params["target_sigma"],
) )
# hack to get around requirement that all vectors are the same length # hack to get around requirement that all vectors are the same length
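Each spectrogram augmentation above is gated independently by `aug_prob`; a minimal sketch of that pattern (function and variable names here are illustrative, not the loader's API):

import numpy as np

def apply_augmentations(spec, aug_prob, augmentations):
    # apply each augmentation with probability `aug_prob`, independently
    for augment in augmentations:
        if np.random.random() < aug_prob:
            spec = augment(spec)
    return spec

# e.g. random volume scaling as one of the augmentations:
# apply_augmentations(spec, 0.2, [lambda s: s * np.random.uniform(0.5, 2.0)])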
@ -1,8 +1,13 @@
from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
def bbox_size_loss(pred_size, gt_size): def bbox_size_loss(
pred_size: torch.Tensor,
gt_size: torch.Tensor,
) -> torch.Tensor:
""" """
Bounding box size loss. Only compute loss where there is a bounding box. Bounding box size loss. Only compute loss where there is a bounding box.
""" """
@ -12,7 +17,12 @@ def bbox_size_loss(pred_size, gt_size):
) )
def focal_loss(pred, gt, weights=None, valid_mask=None): def focal_loss(
pred: torch.Tensor,
gt: torch.Tensor,
weights: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
""" """
Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
pred (batch x c x h x w) pred (batch x c x h x w)
@ -52,7 +62,11 @@ def focal_loss(pred, gt, weights=None, valid_mask=None):
return loss return loss
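For reference, a self-contained sketch of the CornerNet-style focal loss the docstring refers to, without the optional `weights`/`valid_mask` handling of the function above (the alpha/beta values are the usual CornerNet choices, not necessarily this repository's):

import torch

def cornernet_focal_loss(pred, gt, alpha=2.0, beta=4.0, eps=1e-6):
    # positives are cells where the ground-truth heatmap equals 1; negatives
    # near a positive centre are down-weighted by (1 - gt) ** beta
    pos_mask = gt.eq(1).float()
    neg_mask = 1.0 - pos_mask
    pos_loss = torch.log(pred + eps) * (1 - pred) ** alpha * pos_mask
    neg_loss = (
        torch.log(1 - pred + eps) * pred ** alpha * (1 - gt) ** beta * neg_mask
    )
    num_pos = pos_mask.sum().clamp(min=1.0)
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos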
def mse_loss(pred, gt, weights=None, valid_mask=None): def mse_loss(
pred: torch.Tensor,
gt: torch.Tensor,
valid_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
""" """
Mean squared error loss. Mean squared error loss.
""" """
@ -5,6 +5,7 @@ import warnings
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
import batdetect2.detector.post_process as pp import batdetect2.detector.post_process as pp
@ -29,7 +30,7 @@ def save_images_batch(model, data_loader, params):
ind = 0 # first image in each batch ind = 0 # first image in each batch
with torch.no_grad(): with torch.no_grad():
for batch_idx, inputs in enumerate(data_loader): for inputs in data_loader:
data = inputs["spec"].to(params["device"]) data = inputs["spec"].to(params["device"])
outputs = model(data) outputs = model(data)
@ -81,7 +82,12 @@ def save_image(
def loss_fun( def loss_fun(
outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq outputs,
gt_det,
gt_size,
gt_class,
det_criterion,
params,
): ):
# detection loss # detection loss
loss = params["det_loss_weight"] * det_criterion( loss = params["det_loss_weight"] * det_criterion(
@ -104,7 +110,13 @@ def loss_fun(
def train( def train(
model, epoch, data_loader, det_criterion, optimizer, scheduler, params model,
epoch,
data_loader,
det_criterion,
optimizer,
scheduler,
params,
): ):
model.train() model.train()
@ -309,7 +321,7 @@ def select_model(params):
resize_factor=params["resize_factor"], resize_factor=params["resize_factor"],
) )
else: else:
print("No valid network specified") raise ValueError("No valid network specified")
return model return model
@ -319,9 +331,9 @@ def main():
params = parameters.get_params(True) params = parameters.get_params(True)
if torch.cuda.is_available(): if torch.cuda.is_available():
params["device"] = "cuda" params.device = "cuda"
else: else:
params["device"] = "cpu" params.device = "cpu"
# setup arg parser and populate it with exiting parameters - will not work with lists # setup arg parser and populate it with exiting parameters - will not work with lists
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -349,13 +361,16 @@ def main():
default="Rhinolophus ferrumequinum;Rhinolophus hipposideros", default="Rhinolophus ferrumequinum;Rhinolophus hipposideros",
help='Will set low and high frequency the same for these classes. Separate names with ";"', help='Will set low and high frequency the same for these classes. Separate names with ";"',
) )
for key, val in params.items(): for key, val in params.items():
parser.add_argument("--" + key, type=type(val), default=val) parser.add_argument("--" + key, type=type(val), default=val)
params = vars(parser.parse_args()) params = vars(parser.parse_args())
# save notes file # save notes file
if params["notes"] != "": if params["notes"] != "":
tu.write_notes_file(params["experiment"] + "notes.txt", params["notes"]) tu.write_notes_file(
params["experiment"] + "notes.txt", params["notes"]
)
# load the training and test meta data - there are different splits defined # load the training and test meta data - there are different splits defined
train_sets, test_sets = ts.get_train_test_data( train_sets, test_sets = ts.get_train_test_data(
@ -374,15 +389,11 @@ def main():
for tt in train_sets: for tt in train_sets:
print(tt["ann_path"]) print(tt["ann_path"])
classes_to_ignore = params["classes_to_ignore"] + params["generic_class"] classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
( data_train = tu.load_set_of_anns(
data_train,
params["class_names"],
params["class_inv_freq"],
) = tu.load_set_of_anns(
train_sets, train_sets,
classes_to_ignore, classes_to_ignore=classes_to_ignore,
params["events_of_interest"], events_of_interest=params["events_of_interest"],
params["convert_to_genus"], convert_to_genus=params["convert_to_genus"],
) )
params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping( params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
params["class_names"] params["class_names"]
@ -415,11 +426,12 @@ def main():
print("\nTesting on:") print("\nTesting on:")
for tt in test_sets: for tt in test_sets:
print(tt["ann_path"]) print(tt["ann_path"])
data_test, _, _ = tu.load_set_of_anns(
data_test = tu.load_set_of_anns(
test_sets, test_sets,
classes_to_ignore, classes_to_ignore=classes_to_ignore,
params["events_of_interest"], events_of_interest=params["events_of_interest"],
params["convert_to_genus"], convert_to_genus=params["convert_to_genus"],
) )
data_train = tu.remove_dupes(data_train, data_test) data_train = tu.remove_dupes(data_train, data_test)
test_dataset = adl.AudioLoader(data_test, params, is_train=False) test_dataset = adl.AudioLoader(data_test, params, is_train=False)
@ -447,10 +459,13 @@ def main():
scheduler = CosineAnnealingLR( scheduler = CosineAnnealingLR(
optimizer, params["num_epochs"] * len(train_loader) optimizer, params["num_epochs"] * len(train_loader)
) )
if params["train_loss"] == "mse": if params["train_loss"] == "mse":
det_criterion = losses.mse_loss det_criterion = losses.mse_loss
elif params["train_loss"] == "focal": elif params["train_loss"] == "focal":
det_criterion = losses.focal_loss det_criterion = losses.focal_loss
else:
raise ValueError("No valid loss specified")
# save parameters to file # save parameters to file
with open(params["experiment"] + "params.json", "w") as da: with open(params["experiment"] + "params.json", "w") as da:

import glob
import json import json
import os from collections import Counter
import random from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple
import numpy as np import numpy as np
from batdetect2 import types
def write_notes_file(file_name, text):
def write_notes_file(file_name: str, text: str):
with open(file_name, "a") as da: with open(file_name, "a") as da:
da.write(text + "\n") da.write(text + "\n")
def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path): def get_blank_dataset_dict(
ddict = { dataset_name: str,
is_test: bool,
ann_path: str,
wav_path: str,
) -> types.DatasetDict:
return {
"dataset_name": dataset_name, "dataset_name": dataset_name,
"is_test": is_test, "is_test": is_test,
"is_binary": False, "is_binary": False,
"ann_path": ann_path, "ann_path": ann_path,
"wav_path": wav_path, "wav_path": wav_path,
} }
return ddict
def get_short_class_names(class_names, str_len=3): def get_short_class_names(
class_names: List[str],
str_len: int = 3,
) -> List[str]:
class_names_short = [] class_names_short = []
for cc in class_names: for cc in class_names:
class_names_short.append( class_names_short.append(
@ -31,7 +40,10 @@ def get_short_class_names(class_names, str_len=3):
return class_names_short return class_names_short
def remove_dupes(data_train, data_test): def remove_dupes(
data_train: List[types.FileAnnotation],
data_test: List[types.FileAnnotation],
) -> List[types.FileAnnotation]:
test_ids = [dd["id"] for dd in data_test] test_ids = [dd["id"] for dd in data_test]
data_train_prune = [] data_train_prune = []
for aa in data_train: for aa in data_train:
@ -43,14 +55,16 @@ def remove_dupes(data_train, data_test):
return data_train_prune return data_train_prune
def get_genus_mapping(class_names): def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:
genus_names, genus_mapping = np.unique( genus_names, genus_mapping = np.unique(
[cc.split(" ")[0] for cc in class_names], return_inverse=True [cc.split(" ")[0] for cc in class_names], return_inverse=True
) )
return genus_names.tolist(), genus_mapping.tolist() return genus_names.tolist(), genus_mapping.tolist()
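For example (species names purely illustrative):

genus_names, genus_mapping = get_genus_mapping(
    ["Myotis daubentonii", "Myotis nattereri", "Plecotus auritus"]
)
# genus_names   == ["Myotis", "Plecotus"]
# genus_mapping == [0, 0, 1]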
def standardize_low_freq(data, class_of_interest): def standardize_low_freq(
data: List[types.FileAnnotation], class_of_interest: str,
) -> List[types.FileAnnotation]:
# address the issue of highly variable low frequency annotations # address the issue of highly variable low frequency annotations
# this often happens for constant frequency calls # this often happens for constant frequency calls
# for the class of interest sets the low and high freq to be the dataset mean # for the class of interest sets the low and high freq to be the dataset mean
@ -62,8 +76,8 @@ def standardize_low_freq(data, class_of_interest):
low_freqs.append(aa["low_freq"]) low_freqs.append(aa["low_freq"])
high_freqs.append(aa["high_freq"]) high_freqs.append(aa["high_freq"])
low_mean = np.mean(low_freqs) low_mean = float(np.mean(low_freqs))
high_mean = np.mean(high_freqs) high_mean = float(np.mean(high_freqs))
assert low_mean < high_mean assert low_mean < high_mean
print("\nStandardizing low and high frequency for:") print("\nStandardizing low and high frequency for:")
@ -83,115 +97,148 @@ def standardize_low_freq(data, class_of_interest):
return data return data
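A standalone sketch of what the comments above describe: collect the low/high frequencies of the class of interest and overwrite every annotation of that class with the dataset means (details of the real function may differ):

import numpy as np

def standardize_freqs_for_class(data, class_of_interest):
    lows = [aa["low_freq"] for dd in data for aa in dd["annotation"]
            if aa["class"] == class_of_interest]
    highs = [aa["high_freq"] for dd in data for aa in dd["annotation"]
             if aa["class"] == class_of_interest]
    low_mean, high_mean = float(np.mean(lows)), float(np.mean(highs))
    for dd in data:
        for aa in dd["annotation"]:
            if aa["class"] == class_of_interest:
                aa["low_freq"], aa["high_freq"] = low_mean, high_mean
    return data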
def load_set_of_anns( def format_annotation(
data, annotation: types.FileAnnotation,
classes_to_ignore=[], events_of_interest: Optional[List[str]] = None,
events_of_interest=None, name_replace: Optional[Dict[str, str]] = None,
convert_to_genus=False, convert_to_genus: bool = False,
verbose=True, classes_to_ignore: Optional[List[str]] = None,
list_of_anns=False, ) -> types.FileAnnotation:
filter_issues=False, formated = []
name_replace=False, for aa in annotation["annotation"]:
): if (
events_of_interest is not None
and aa["event"] not in events_of_interest
):
# Omit events that are not of interest
continue
# remove leading and trailing spaces
class_name = aa["class"].strip()
if name_replace is not None:
# replace_names will be a dictionary mapping input name to output
class_name = name_replace.get(class_name, class_name)
if convert_to_genus:
# convert everything to genus name
class_name = class_name.split(" ")[0]
# NOTE: It is important to acknowledge that the class names filtering
# is done after the name replacement and the conversion to
# genus name. This allows filtering converted genus names and names
# that were replaced with a name that should be ignored.
if classes_to_ignore is not None and class_name in classes_to_ignore:
# Omit annotations with ignored classes
continue
formated.append(
{
**aa,
"class": class_name,
}
)
return {
**annotation,
"annotation": formated,
}
def get_class_names(
data: List[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None,
) -> Tuple[Counter[str], List[float]]:
"""Extracts class names and their inverse frequencies.
Parameters
----------
data
A list of file annotations, where each annotation contains a list of
sound events with associated class names.
classes_to_ignore
A list of class names to ignore.
Returns
-------
class_names
A counter mapping each class name found in the annotations to its
number of occurrences.
class_inv_freq
Inverse frequency weight for each annotation, in the same order as the
annotations appear in the data.
"""
if classes_to_ignore is None:
classes_to_ignore = []
class_names_list: List[str] = []
for annotation in data:
for sound_event in annotation["annotation"]:
if sound_event["class"] in classes_to_ignore:
continue
class_names_list.append(sound_event["class"])
counts = Counter(class_names_list)
mean_counts = float(np.mean(list(counts.values())))
return counts, [mean_counts / counts[cc] for cc in class_names_list]
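A small worked example of the inverse-frequency weighting computed above (class names and counts illustrative):

from collections import Counter
import numpy as np

counts = Counter({"Myotis daubentonii": 3, "Plecotus auritus": 1})
mean_count = float(np.mean(list(counts.values())))  # 2.0
inv_freq = {name: mean_count / n for name, n in counts.items()}
# {"Myotis daubentonii": 0.666..., "Plecotus auritus": 2.0}
# rarer classes therefore receive larger weights during training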
def report_class_counts(class_names: Counter[str]):
print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for index, (class_name, count) in enumerate(class_names.most_common()):
print(f"{index:<5}{class_name:<{str_len}}{count}")
def load_set_of_anns(
data: List[types.DatasetDict],
*,
convert_to_genus: bool = False,
filter_issues: bool = False,
events_of_interest: Optional[List[str]] = None,
classes_to_ignore: Optional[List[str]] = None,
name_replace: Optional[Dict[str, str]] = None,
) -> List[types.FileAnnotation]:
# load the annotations # load the annotations
anns = [] anns = []
if list_of_anns:
# path to list of individual json files
anns.extend(load_anns_from_path(data["ann_path"], data["wav_path"]))
else:
# dictionary of datasets
for dd in data:
anns.extend(load_anns(dd["ann_path"], dd["wav_path"]))
# discarding unannotated files # dictionary of datasets
anns = [aa for aa in anns if aa["annotated"] is True] for dataset in data:
for ann in load_anns(dataset["ann_path"], dataset["wav_path"]):
if not ann["annotated"]:
# Omit unannotated files
continue
# filter files that have annotation issues - if the input is a dictionary of if filter_issues and ann["issues"]:
# datasets, this will likely have already been done # Omit files with annotation issues
if filter_issues: continue
anns = [aa for aa in anns if aa["issues"] is False]
# check for some basic formatting errors with class names anns.append(
for ann in anns: format_annotation(
for aa in ann["annotation"]: ann,
aa["class"] = aa["class"].strip() events_of_interest=events_of_interest,
name_replace=name_replace,
# only load specified events - i.e. types of calls convert_to_genus=convert_to_genus,
if events_of_interest is not None: classes_to_ignore=classes_to_ignore,
for ann in anns: )
filtered_events = []
for aa in ann["annotation"]:
if aa["event"] in events_of_interest:
filtered_events.append(aa)
ann["annotation"] = filtered_events
# change class names
# replace_names will be a dictionary mapping input name to output
if type(name_replace) is dict:
for ann in anns:
for aa in ann["annotation"]:
if aa["class"] in name_replace:
aa["class"] = name_replace[aa["class"]]
# convert everything to genus name
if convert_to_genus:
for ann in anns:
for aa in ann["annotation"]:
aa["class"] = aa["class"].split(" ")[0]
# get unique class names
class_names_all = []
for ann in anns:
for aa in ann["annotation"]:
if aa["class"] not in classes_to_ignore:
class_names_all.append(aa["class"])
class_names, class_cnts = np.unique(class_names_all, return_counts=True)
class_inv_freq = class_cnts.sum() / (
len(class_names) * class_cnts.astype(np.float32)
)
if verbose:
print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for cc in range(len(class_names)):
print(
str(cc).ljust(5)
+ class_names[cc].ljust(str_len)
+ str(class_cnts[cc])
) )
if len(classes_to_ignore) == 0:
return anns
else:
return anns, class_names.tolist(), class_inv_freq.tolist()
def load_anns(ann_file_name, raw_audio_dir):
with open(ann_file_name) as da:
anns = json.load(da)
for aa in anns:
aa["file_path"] = raw_audio_dir + aa["id"]
return anns return anns
def load_anns_from_path(ann_file_dir, raw_audio_dir): def load_anns(
files = glob.glob(ann_file_dir + "*.json") ann_dir: str,
anns = [] raw_audio_dir: str,
for ff in files: ) -> Generator[types.FileAnnotation, None, None]:
with open(ff) as da: for path in Path(ann_dir).rglob("*.json"):
ann = json.load(da) with open(path) as fp:
ann["file_path"] = raw_audio_dir + ann["id"] file_annotation = json.load(fp)
anns.append(ann)
return anns file_annotation["file_path"] = raw_audio_dir + file_annotation["id"]
yield file_annotation
class AverageMeter(object): class AverageMeter:
"""Computes and stores the average and current value""" """Computes and stores the average and current value."""
def __init__(self): def __init__(self):
self.reset() self.reset()
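The remaining methods fall outside this hunk; the usual implementation of this running-average pattern looks roughly like the sketch below (illustrative, not necessarily identical to the code here):

class RunningAverage:
    """Minimal running-average tracker."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count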
@ -1,5 +1,5 @@
"""Types used in the code base.""" """Types used in the code base."""
from typing import List, NamedTuple, Optional, Union from typing import Any, List, NamedTuple, Optional
import numpy as np import numpy as np
import torch import torch
@ -26,8 +26,7 @@ __all__ = [
"Annotation", "Annotation",
"DetectionModel", "DetectionModel",
"FeatureExtractionParameters", "FeatureExtractionParameters",
"FeatureExtractor", "FileAnnotation",
"FileAnnotations",
"ModelOutput", "ModelOutput",
"ModelParameters", "ModelParameters",
"NonMaximumSuppressionConfig", "NonMaximumSuppressionConfig",
@ -94,7 +93,10 @@ class ModelParameters(TypedDict):
"""Resize factor.""" """Resize factor."""
class_names: List[str] class_names: List[str]
"""Class names. The model is trained to detect these classes.""" """Class names.
The model is trained to detect these classes.
"""
DictWithClass = TypedDict("DictWithClass", {"class": str}) DictWithClass = TypedDict("DictWithClass", {"class": str})
@ -103,8 +105,8 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})
class Annotation(DictWithClass): class Annotation(DictWithClass):
"""Format of annotations. """Format of annotations.
This is the format of a single annotation as expected by the annotation This is the format of a single annotation as expected by the
tool. annotation tool.
""" """
start_time: float start_time: float
@ -113,10 +115,10 @@ class Annotation(DictWithClass):
end_time: float end_time: float
"""End time in seconds.""" """End time in seconds."""
low_freq: int low_freq: float
"""Low frequency in Hz.""" """Low frequency in Hz."""
high_freq: int high_freq: float
"""High frequency in Hz.""" """High frequency in Hz."""
class_prob: float class_prob: float
@ -135,7 +137,7 @@ class Annotation(DictWithClass):
"""Numeric ID for the class of the annotation.""" """Numeric ID for the class of the annotation."""
class FileAnnotations(TypedDict): class FileAnnotation(TypedDict):
"""Format of results. """Format of results.
This is the format of the results expected by the annotation tool. This is the format of the results expected by the annotation tool.
@ -157,7 +159,7 @@ class FileAnnotations(TypedDict):
"""Time expansion factor.""" """Time expansion factor."""
class_name: str class_name: str
"""Class predicted at file level""" """Class predicted at file level."""
notes: str notes: str
"""Notes of file.""" """Notes of file."""
@ -169,7 +171,7 @@ class FileAnnotations(TypedDict):
class RunResults(TypedDict): class RunResults(TypedDict):
"""Run results.""" """Run results."""
pred_dict: FileAnnotations pred_dict: FileAnnotation
"""Predictions in the format expected by the annotation tool.""" """Predictions in the format expected by the annotation tool."""
spec_feats: NotRequired[List[np.ndarray]] spec_feats: NotRequired[List[np.ndarray]]
@ -394,9 +396,9 @@ class PredictionResults(TypedDict):
class DetectionModel(Protocol): class DetectionModel(Protocol):
"""Protocol for detection models. """Protocol for detection models.
This protocol is used to define the interface for the detection models. This protocol is used to define the interface for the detection
This allows us to use the same code for training and inference, even models. This allows us to use the same code for training and
though the models are different. inference, even though the models are different.
""" """
num_classes: int num_classes: int
@ -416,16 +418,14 @@ class DetectionModel(Protocol):
def forward( def forward(
self, self,
ip: torch.Tensor, spec: torch.Tensor,
return_feats: bool = False,
) -> ModelOutput: ) -> ModelOutput:
"""Forward pass of the model.""" """Forward pass of the model."""
... ...
def __call__( def __call__(
self, self,
ip: torch.Tensor, spec: torch.Tensor,
return_feats: bool = False,
) -> ModelOutput: ) -> ModelOutput:
"""Forward pass of the model.""" """Forward pass of the model."""
... ...
@ -490,8 +490,10 @@ class HeatmapParameters(TypedDict):
"""Maximum frequency to consider in Hz.""" """Maximum frequency to consider in Hz."""
target_sigma: float target_sigma: float
"""Sigma for the Gaussian kernel. Controls the width of the points in """Sigma for the Gaussian kernel.
the heatmap."""
Controls the width of the points in the heatmap.
"""
class AnnotationGroup(TypedDict): class AnnotationGroup(TypedDict):
@ -522,10 +524,10 @@ class AnnotationGroup(TypedDict):
annotated: NotRequired[bool] annotated: NotRequired[bool]
"""Wether the annotation group is complete or not. """Wether the annotation group is complete or not.
Usually annotation groups are associated to a single Usually annotation groups are associated to a single audio clip. If
audio clip. If the annotation group is complete, it means that all the annotation group is complete, it means that all relevant sound
relevant sound events have been annotated. If it is not complete, it events have been annotated. If it is not complete, it means that
means that some sound events might not have been annotated. some sound events might not have been annotated.
""" """
x_inds: NotRequired[np.ndarray] x_inds: NotRequired[np.ndarray]
@ -535,12 +537,88 @@ class AnnotationGroup(TypedDict):
"""Y coordinate of the annotations in the spectrogram.""" """Y coordinate of the annotations in the spectrogram."""
class AudioLoaderAnnotationGroup(AnnotationGroup, FileAnnotations): class AudioLoaderAnnotationGroup(TypedDict):
"""Group of annotation items for the training audio loader. """Group of annotation items for the training audio loader.
This class is used to store the annotations for the training audio This class is used to store the annotations for the training audio
loader. It inherits from `AnnotationGroup` and `FileAnnotations`. loader. It inherits from `AnnotationGroup` and `FileAnnotations`.
""" """
id: str
duration: float
issues: bool
file_path: str
time_exp: float
class_name: str
notes: str
start_times: np.ndarray
end_times: np.ndarray
low_freqs: np.ndarray
high_freqs: np.ndarray
class_ids: np.ndarray
individual_ids: np.ndarray
x_inds: np.ndarray
y_inds: np.ndarray
annotation: List[Annotation]
annotated: bool
class_id_file: int class_id_file: int
"""ID of the class of the file.""" """ID of the class of the file."""
class AudioLoaderParameters(TypedDict):
class_names: List[str]
classes_to_ignore: List[str]
target_samp_rate: int
scale_raw_audio: bool
fft_win_length: float
fft_overlap: float
spec_train_width: int
resize_factor: float
spec_divide_factor: int
augment_at_train: bool
augment_at_train_combine: bool
aug_prob: float
spec_height: int
echo_max_delay: float
spec_amp_scaling: float
stretch_squeeze_delta: float
mask_max_time_perc: float
mask_max_freq_perc: float
max_freq: float
min_freq: float
spec_scale: str
denoise_spec_avg: bool
max_scale_spec: bool
target_sigma: float
class FeatureExtractor(Protocol):
def __call__(
self,
prediction: Prediction,
**kwargs: Any,
) -> float:
...
class DatasetDict(TypedDict):
"""Dataset dictionary.
This is the format of the dictionary that contains the dataset
information.
"""
dataset_name: str
"""Name of the dataset."""
is_test: bool
"""Whether the dataset is a test set."""
is_binary: bool
"""Whether the dataset is binary."""
ann_path: str
"""Path to the annotations."""
wav_path: str
"""Path to the audio files."""

import warnings import warnings
from typing import Optional, Tuple from typing import Optional, Tuple, Union, overload
import librosa import librosa
import librosa.core.spectrum import librosa.core.spectrum
import numpy as np import numpy as np
import torch import torch
from . import wavfile
__all__ = [ __all__ = [
"load_audio", "load_audio",
"generate_spectrogram", "generate_spectrogram",
@ -15,113 +13,171 @@ __all__ = [
] ]
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap): @overload
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor def time_to_x_coords(
time_in_file: np.ndarray,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
) -> np.ndarray:
...
@overload
def time_to_x_coords(
time_in_file: float,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
) -> float:
...
def time_to_x_coords(
time_in_file: Union[float, np.ndarray],
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
) -> Union[float, np.ndarray]:
nfft = np.floor(fft_win_length * sampling_rate)
noverlap = np.floor(fft_overlap * nfft) noverlap = np.floor(fft_overlap * nfft)
return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap) return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
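A quick worked example of this mapping with illustrative parameter values (not the project's defaults):

import numpy as np

sampling_rate, fft_win_length, fft_overlap = 250_000, 0.004, 0.75
nfft = np.floor(fft_win_length * sampling_rate)   # 1000 samples per window
noverlap = np.floor(fft_overlap * nfft)           # 750 samples of overlap
x = (0.1 * sampling_rate - noverlap) / (nfft - noverlap)
# x == 97.0: a call starting at 0.1 s lands in spectrogram column 97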
# NOTE this is also defined in post_process # NOTE this is also defined in post_process
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap): def x_coords_to_time(
x_pos: float,
sampling_rate: int,
fft_win_length: float,
fft_overlap: float,
) -> float:
nfft = np.floor(fft_win_length * sampling_rate) nfft = np.floor(fft_win_length * sampling_rate)
noverlap = np.floor(fft_overlap * nfft) noverlap = np.floor(fft_overlap * nfft)
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for
# center of temporal window
def generate_spectrogram( def generate_spectrogram(
audio, audio: np.ndarray,
sampling_rate, sampling_rate: float,
params, fft_win_length: float,
return_spec_for_viz=False, fft_overlap: float,
check_spec_size=True, max_freq: float,
): min_freq: float,
spec_scale: str,
denoise_spec_avg: bool = False,
max_scale_spec: bool = False,
) -> np.ndarray:
# generate spectrogram # generate spectrogram
spec = gen_mag_spectrogram( spec = gen_mag_spectrogram(
audio, audio,
sampling_rate, sampling_rate,
params["fft_win_length"], window_len=fft_win_length,
params["fft_overlap"], overlap_perc=fft_overlap,
)
spec = crop_spectrogram(
spec,
fft_win_length=fft_win_length,
max_freq=max_freq,
min_freq=min_freq,
)
spec = scale_spectrogram(
spec,
sampling_rate,
spec_scale=spec_scale,
fft_win_length=fft_win_length,
) )
if denoise_spec_avg:
spec = denoise_spectrogram(spec)
if max_scale_spec:
spec = max_scale_spectrogram(spec)
return spec
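A hedged usage sketch of the refactored pipeline above; the parameter values are illustrative rather than the project's defaults, and the import path is assumed:

import numpy as np
from batdetect2.utils import audio_utils as au

audio = np.random.randn(250_000).astype(np.float32)  # 1 s of placeholder audio
spec = au.generate_spectrogram(
    audio,
    250_000,
    fft_win_length=0.004,
    fft_overlap=0.75,
    max_freq=120_000,
    min_freq=10_000,
    spec_scale="log",
    denoise_spec_avg=True,
)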
def crop_spectrogram(
spec: np.ndarray,
fft_win_length: float,
max_freq: float,
min_freq: float,
) -> np.ndarray:
# crop to min/max freq # crop to min/max freq
max_freq = round(params["max_freq"] * params["fft_win_length"]) max_freq = round(max_freq * fft_win_length)
min_freq = round(params["min_freq"] * params["fft_win_length"]) min_freq = round(min_freq * fft_win_length)
if spec.shape[0] < max_freq: if spec.shape[0] < max_freq:
freq_pad = max_freq - spec.shape[0] freq_pad = max_freq - spec.shape[0]
spec = np.vstack( spec = np.vstack(
(np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec) (np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
) )
spec_cropped = spec[-max_freq : spec.shape[0] - min_freq, :] return spec[-max_freq : spec.shape[0] - min_freq, :]
if params["spec_scale"] == "log":
log_scaling = ( def denoise_spectrogram(spec: np.ndarray) -> np.ndarray:
2.0 spec = spec - np.mean(spec, 1)[:, np.newaxis]
* (1.0 / sampling_rate) return spec.clip(min=0)
* (
1.0
/ ( def max_scale_spectrogram(spec: np.ndarray) -> np.ndarray:
np.abs( return spec / (spec.max() + 10e-6)
np.hanning(
int(params["fft_win_length"] * sampling_rate)
) def log_scale(
) spec: np.ndarray,
** 2 sampling_rate: float,
).sum() fft_win_length: float,
) ) -> np.ndarray:
log_scaling = (
2.0
* (1.0 / sampling_rate)
* (
1.0
/ (
np.abs(np.hanning(int(fft_win_length * sampling_rate))) ** 2
).sum()
) )
# log_scaling = (1.0 / sampling_rate)*0.1 )
# log_scaling = (1.0 / sampling_rate)*10e4 return np.log1p(log_scaling * spec)
spec = np.log1p(log_scaling * spec_cropped)
elif params["spec_scale"] == "pcen":
spec = pcen(spec_cropped, sampling_rate)
elif params["spec_scale"] == "none":
pass
if params["denoise_spec_avg"]: def scale_spectrogram(
spec = spec - np.mean(spec, 1)[:, np.newaxis] spec: np.ndarray,
spec.clip(min=0, out=spec) sampling_rate: float,
spec_scale: str,
fft_win_length: float,
) -> np.ndarray:
if spec_scale == "log":
return log_scale(spec, sampling_rate, fft_win_length)
if params["max_scale_spec"]: if spec_scale == "pcen":
spec = spec / (spec.max() + 10e-6) return pcen(spec, sampling_rate)
# needs to be divisible by specific factor - if not it should have been padded return spec
# if check_spec_size:
# assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
# assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)
def prepare_spec_for_viz(
spec: np.ndarray,
sampling_rate: int,
fft_win_length: float,
) -> np.ndarray:
# for visualization purposes - use log scaled spectrogram # for visualization purposes - use log scaled spectrogram
if return_spec_for_viz: return log_scale(
log_scaling = ( spec,
2.0 sampling_rate,
* (1.0 / sampling_rate) fft_win_length=fft_win_length,
* ( ).astype(np.float32)
1.0
/ (
np.abs(
np.hanning(
int(params["fft_win_length"] * sampling_rate)
)
)
** 2
).sum()
)
)
spec_for_viz = np.log1p(log_scaling * spec_cropped).astype(np.float32)
else:
spec_for_viz = None
return spec, spec_for_viz
def load_audio( def load_audio(
audio_file: str, audio_file: str,
time_exp_fact: float, time_exp_fact: float,
target_samp_rate: int, target_sampling_rate: int,
scale: bool = False, scale: bool = False,
max_duration: Optional[float] = None, max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray]: ) -> Tuple[float, np.ndarray]:
"""Load an audio file and resample it to the target sampling rate. """Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration. The audio is also scaled to [-1, 1] and clipped to the maximum duration.
@ -152,63 +208,82 @@ def load_audio(
""" """
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning) audio, sampling_rate = librosa.load(
# sampling_rate, audio_raw = wavfile.read(audio_file)
audio_raw, sampling_rate = librosa.load(
audio_file, audio_file,
sr=None, sr=None,
dtype=np.float32, dtype=np.float32,
) )
if len(audio_raw.shape) > 1: if len(audio.shape) > 1:
raise ValueError("Currently does not handle stereo files") raise ValueError("Currently does not handle stereo files")
sampling_rate = sampling_rate * time_exp_fact sampling_rate = sampling_rate * time_exp_fact
# resample - need to do this after correcting for time expansion # resample - need to do this after correcting for time expansion
sampling_rate_old = sampling_rate audio = resample_audio(audio, sampling_rate, target_sampling_rate)
sampling_rate = target_samp_rate
if sampling_rate_old != sampling_rate:
audio_raw = librosa.resample(
audio_raw,
orig_sr=sampling_rate_old,
target_sr=sampling_rate,
res_type="polyphase",
)
# clipping maximum duration
if max_duration is not None: if max_duration is not None:
max_duration = int( audio = clip_audio(audio, target_sampling_rate, max_duration)
np.minimum(
int(sampling_rate * max_duration),
audio_raw.shape[0],
)
)
audio_raw = audio_raw[:max_duration]
# scale to [-1, 1] # scale to [-1, 1]
if scale: if scale:
audio_raw = audio_raw - audio_raw.mean() audio = scale_audio(audio)
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
return sampling_rate, audio_raw return target_sampling_rate, audio
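A usage sketch of the refactored loader above (file path, target rate, and import path are assumptions for illustration):

from batdetect2.utils import audio_utils as au

sampling_rate, audio = au.load_audio(
    "example_recording.wav",
    time_exp_fact=1.0,
    target_sampling_rate=256_000,
    scale=True,
    max_duration=2.0,
)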
def resample_audio(
audio: np.ndarray,
sr_orig: float,
sr_target: float,
) -> np.ndarray:
if sr_orig != sr_target:
return librosa.resample(
audio,
orig_sr=sr_orig,
target_sr=sr_target,
res_type="polyphase",
)
return audio
def clip_audio(
audio: np.ndarray,
sampling_rate: float,
max_duration: float,
) -> np.ndarray:
max_duration = int(
np.minimum(
int(sampling_rate * max_duration),
audio.shape[0],
)
)
return audio[:max_duration]
def scale_audio(
audio: np.ndarray,
eps: float = 10e-6,
) -> np.ndarray:
return (audio - audio.mean()) / (np.abs(audio).max() + eps)
def pad_audio( def pad_audio(
audio_raw, audio_raw: np.ndarray,
fs, sampling_rate: float,
ms, window_len: float,
overlap_perc, overlap_perc: float,
resize_factor, resize_factor: float,
divide_factor, divide_factor: float,
fixed_width=None, fixed_width: Optional[int] = None,
): ) -> np.ndarray:
# Adds zeros to the end of the raw data so that the generated spectrogram # Adds zeros to the end of the raw data so that the generated spectrogram
# will be evenly divisible by `divide_factor` # will be evenly divisible by `divide_factor`
# Also deals with very short audio clips and fixed_width during training # Also deals with very short audio clips and fixed_width during training
# This code could be clearer, clean up # This code could be clearer, clean up
nfft = int(ms * fs) nfft = int(window_len * sampling_rate)
noverlap = int(overlap_perc * nfft) noverlap = int(overlap_perc * nfft)
step = nfft - noverlap step = nfft - noverlap
min_size = int(divide_factor * (1.0 / resize_factor)) min_size = int(divide_factor * (1.0 / resize_factor))
@ -245,19 +320,24 @@ def pad_audio(
return audio_raw return audio_raw
def gen_mag_spectrogram(x, fs, ms, overlap_perc): def gen_mag_spectrogram(
audio: np.ndarray,
sampling_rate: float,
window_len: float,
overlap_perc: float,
) -> np.ndarray:
# Computes magnitude spectrogram by specifying time. # Computes magnitude spectrogram by specifying time.
audio = audio.astype(np.float32)
x = x.astype(np.float32) nfft = int(window_len * sampling_rate)
nfft = int(ms * fs)
noverlap = int(overlap_perc * nfft) noverlap = int(overlap_perc * nfft)
# window data
step = nfft - noverlap
# compute spec # compute spec
spec, _ = librosa.core.spectrum._spectrogram( spec, _ = librosa.core.spectrum._spectrogram(
y=x, power=1, n_fft=nfft, hop_length=step, center=False y=audio,
power=1,
n_fft=nfft,
hop_length=nfft - noverlap,
center=False,
) )
# remove DC component and flip vertical orientation # remove DC component and flip vertical orientation
@ -266,24 +346,25 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):
return spec.astype(np.float32) return spec.astype(np.float32)
def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc): def gen_mag_spectrogram_pt(
nfft = int(ms * fs) audio: torch.Tensor,
sampling_rate: float,
window_len: float,
overlap_perc: float,
) -> torch.Tensor:
nfft = int(window_len * sampling_rate)
nstep = round((1.0 - overlap_perc) * nfft) nstep = round((1.0 - overlap_perc) * nfft)
han_win = torch.hann_window(nfft, periodic=False).to(audio.device)
han_win = torch.hann_window(nfft, periodic=False).to(x.device) complex_spec = torch.stft(audio, nfft, nstep, window=han_win, center=False)
complex_spec = torch.stft(x, nfft, nstep, window=han_win, center=False)
spec = complex_spec.pow(2.0).sum(-1) spec = complex_spec.pow(2.0).sum(-1)
# remove DC component and flip vertically # remove DC component and flip vertically
spec = torch.flipud(spec[0, 1:, :]) return torch.flipud(spec[0, 1:, :])
return spec
def pcen(spec_cropped, sampling_rate): def pcen(spec: np.ndarray, sampling_rate: float) -> np.ndarray:
# TODO should be passing hop_length too i.e. step # TODO should be passing hop_length too i.e. step
spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate / 10).astype( return librosa.pcen(spec * (2**31), sr=sampling_rate / 10).astype(
np.float32 np.float32
) )
return spec
@ -16,7 +16,7 @@ from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import ( from batdetect2.types import (
Annotation, Annotation,
DetectionModel, DetectionModel,
FileAnnotations, FileAnnotation,
ModelOutput, ModelOutput,
ModelParameters, ModelParameters,
PredictionResults, PredictionResults,
@ -79,7 +79,7 @@ def list_audio_files(ip_dir: str) -> List[str]:
def load_model( def load_model(
model_path: str = DEFAULT_MODEL_PATH, model_path: str = DEFAULT_MODEL_PATH,
load_weights: bool = True, load_weights: bool = True,
device: Optional[torch.device] = None, device: Union[torch.device, str, None] = None,
) -> Tuple[DetectionModel, ModelParameters]: ) -> Tuple[DetectionModel, ModelParameters]:
"""Load model from file. """Load model from file.
@ -222,7 +222,7 @@ def format_single_result(
duration: float, duration: float,
predictions: PredictionResults, predictions: PredictionResults,
class_names: List[str], class_names: List[str],
) -> FileAnnotations: ) -> FileAnnotation:
"""Format results into the format expected by the annotation tool. """Format results into the format expected by the annotation tool.
Args: Args:
@ -399,11 +399,10 @@ def save_results_to_file(results, op_path: str) -> None:
def compute_spectrogram( def compute_spectrogram(
audio: np.ndarray, audio: np.ndarray,
sampling_rate: int, sampling_rate: float,
params: SpectrogramParameters, params: SpectrogramParameters,
device: torch.device, device: torch.device,
return_np: bool = False, ) -> Tuple[float, torch.Tensor]:
) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
"""Compute a spectrogram from an audio array. """Compute a spectrogram from an audio array.
Will pad the audio array so that it is evenly divisible by the Will pad the audio array so that it is evenly divisible by the
@ -412,24 +411,16 @@ def compute_spectrogram(
Parameters Parameters
---------- ----------
audio : np.ndarray audio : np.ndarray
sampling_rate : int sampling_rate : int
params : SpectrogramParameters params : SpectrogramParameters
The parameters to use for generating the spectrogram. The parameters to use for generating the spectrogram.
return_np : bool, optional
Whether to return the spectrogram as a numpy array as well as a
torch tensor. The default is False.
Returns Returns
------- -------
duration : float duration : float
The duration of the spectrogram in seconds. The duration of the spectrogram in seconds.
spec : torch.Tensor spec : torch.Tensor
The spectrogram as a torch tensor. The spectrogram as a torch tensor.
spec_np : np.ndarray, optional spec_np : np.ndarray, optional
The spectrogram as a numpy array. Only returned if `return_np` is The spectrogram as a numpy array. Only returned if `return_np` is
True, otherwise None. True, otherwise None.
@ -446,7 +437,7 @@ def compute_spectrogram(
) )
# generate spectrogram # generate spectrogram
spec, _ = au.generate_spectrogram(audio, sampling_rate, params) spec = au.generate_spectrogram(audio, sampling_rate, params)
# convert to pytorch # convert to pytorch
spec = torch.from_numpy(spec).to(device) spec = torch.from_numpy(spec).to(device)
@ -466,18 +457,12 @@ def compute_spectrogram(
mode="bilinear", mode="bilinear",
align_corners=False, align_corners=False,
) )
return duration, spec
if return_np:
spec_np = spec[0, 0, :].cpu().data.numpy()
else:
spec_np = None
return duration, spec, spec_np
def iterate_over_chunks( def iterate_over_chunks(
audio: np.ndarray, audio: np.ndarray,
samplerate: int, samplerate: float,
chunk_size: float, chunk_size: float,
) -> Iterator[Tuple[float, np.ndarray]]: ) -> Iterator[Tuple[float, np.ndarray]]:
"""Iterate over audio in chunks of size chunk_size. """Iterate over audio in chunks of size chunk_size.
@ -510,7 +495,7 @@ def iterate_over_chunks(
def _process_spectrogram( def _process_spectrogram(
spec: torch.Tensor, spec: torch.Tensor,
samplerate: int, samplerate: float,
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
) -> Tuple[PredictionResults, np.ndarray]: ) -> Tuple[PredictionResults, np.ndarray]:
@ -632,13 +617,13 @@ def process_spectrogram(
def _process_audio_array( def _process_audio_array(
audio: np.ndarray, audio: np.ndarray,
sampling_rate: int, sampling_rate: float,
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
device: torch.device, device: torch.device,
) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]: ) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
# load audio file and compute spectrogram # load audio file and compute spectrogram
_, spec, _ = compute_spectrogram( _, spec = compute_spectrogram(
audio, audio,
sampling_rate, sampling_rate,
{ {
@ -654,7 +639,6 @@ def _process_audio_array(
"max_scale_spec": config["max_scale_spec"], "max_scale_spec": config["max_scale_spec"],
}, },
device, device,
return_np=False,
) )
# process spectrogram with model # process spectrogram with model
@ -754,13 +738,15 @@ def process_file(
# Get original sampling rate # Get original sampling rate
file_samp_rate = librosa.get_samplerate(audio_file) file_samp_rate = librosa.get_samplerate(audio_file)
orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1 orig_samp_rate = file_samp_rate * float(
config.get("time_expansion", 1.0) or 1.0
)
# load audio file # load audio file
sampling_rate, audio_full = au.load_audio( sampling_rate, audio_full = au.load_audio(
audio_file, audio_file,
time_exp_fact=config.get("time_expansion", 1) or 1, time_exp_fact=config.get("time_expansion", 1) or 1,
target_samp_rate=config["target_samp_rate"], target_sampling_rate=config["target_samp_rate"],
scale=config["scale_raw_audio"], scale=config["scale_raw_audio"],
max_duration=config.get("max_duration"), max_duration=config.get("max_duration"),
) )
@ -802,7 +788,6 @@ def process_file(
cnn_feats.append(features[0]) cnn_feats.append(features[0])
if config["spec_slices"]: if config["spec_slices"]:
# FIX: This is not currently working. Returns empty slices
spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms)) spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))
# Merge results from chunks # Merge results from chunks
@ -152,7 +152,7 @@ def test_compute_max_power_bb(max_power: int):
target_samp_rate=samplerate, target_samp_rate=samplerate,
) )
spec, _ = au.generate_spectrogram( spec = au.generate_spectrogram(
audio, audio,
samplerate, samplerate,
params, params,
@ -240,7 +240,7 @@ def test_compute_max_power():
target_samp_rate=samplerate, target_samp_rate=samplerate,
) )
spec, _ = au.generate_spectrogram( spec = au.generate_spectrogram(
audio, audio,
samplerate, samplerate,
params, params,