Added types to most functions

This commit is contained in:
Santiago Martinez 2024-01-14 17:15:22 +00:00
parent 458e11cf73
commit 0aa61af445
15 changed files with 1216 additions and 902 deletions

View File

@ -226,11 +226,10 @@ def generate_spectrogram(
if config is None:
config = DEFAULT_SPECTROGRAM_PARAMETERS
_, spec, _ = du.compute_spectrogram(
_, spec = du.compute_spectrogram(
audio,
samp_rate,
config,
return_np=False,
device=device,
)

View File

@ -1,5 +1,5 @@
"""Functions to compute features from predictions."""
from typing import Dict, Optional
from typing import Dict, List, Optional
import numpy as np
@ -7,15 +7,26 @@ from batdetect2 import types
from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ
def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
def convert_int_to_freq(
spec_ind: int,
spec_height: int,
min_freq: float,
max_freq: float,
) -> int:
"""Convert spectrogram index to frequency in Hz.""" ""
spec_ind = spec_height - spec_ind
return round(
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
return int(
round(
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq,
2,
)
)
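For illustration, a quick sanity check of the conversion above (the spectrogram size and band limits below are hypothetical, not taken from the test suite): a 128-row spectrogram spanning 10-120 kHz, where row 0 is the highest frequency.
assert convert_int_to_freq(0, 128, 10_000, 120_000) == 120_000
assert convert_int_to_freq(64, 128, 10_000, 120_000) == 65_000
assert isinstance(convert_int_to_freq(64, 128, 10_000, 120_000), int)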
def extract_spec_slices(spec, pred_nms):
def extract_spec_slices(
spec: np.ndarray,
pred_nms: types.PredictionResults,
) -> List[np.ndarray]:
"""Extract spectrogram slices from spectrogram.
The slices are extracted based on detected call locations.
@ -109,7 +120,7 @@ def compute_max_power_bb(
return int(
convert_int_to_freq(
y_high + max_power_ind,
int(y_high + max_power_ind),
spec.shape[0],
min_freq,
max_freq,
@ -135,14 +146,12 @@ def compute_max_power(
spec_call = spec[:, x_start:x_end]
power_per_freq_band = np.sum(spec_call, axis=1)
max_power_ind = np.argmax(power_per_freq_band)
return int(
convert_int_to_freq(
max_power_ind,
return convert_int_to_freq(
int(max_power_ind),
spec.shape[0],
min_freq,
max_freq,
)
)
def compute_max_power_first(
@ -164,14 +173,12 @@ def compute_max_power_first(
first_half = spec_call[:, : int(spec_call.shape[1] / 2)]
power_per_freq_band = np.sum(first_half, axis=1)
max_power_ind = np.argmax(power_per_freq_band)
return int(
convert_int_to_freq(
max_power_ind,
return convert_int_to_freq(
int(max_power_ind),
spec.shape[0],
min_freq,
max_freq,
)
)
def compute_max_power_second(
@ -193,14 +200,12 @@ def compute_max_power_second(
second_half = spec_call[:, int(spec_call.shape[1] / 2) :]
power_per_freq_band = np.sum(second_half, axis=1)
max_power_ind = np.argmax(power_per_freq_band)
return int(
convert_int_to_freq(
max_power_ind,
return convert_int_to_freq(
int(max_power_ind),
spec.shape[0],
min_freq,
max_freq,
)
)
def compute_call_interval(
@ -214,6 +219,7 @@ def compute_call_interval(
return round(prediction["start_time"] - previous["end_time"], 5)
# NOTE: The order of the features in this dictionary is important. The
# features are extracted in this order and the order of the columns in the
# output csv file is determined by this order. In order to avoid breaking
@ -236,7 +242,7 @@ def get_feats(
spec: np.ndarray,
pred_nms: types.PredictionResults,
params: types.FeatureExtractionParameters,
):
) -> np.ndarray:
"""Extract features from spectrogram based on detected call locations.
The features extracted are:

View File

@ -79,7 +79,13 @@ class ConvBlockDownCoordF(nn.Module):
class ConvBlockDownStandard(nn.Module):
def __init__(
self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1
self,
in_chn,
out_chn,
ip_height=None,
k_size=3,
pad_size=1,
stride=1,
):
super(ConvBlockDownStandard, self).__init__()
self.conv = nn.Conv2d(

View File

@ -103,15 +103,15 @@ class Net2DFast(nn.Module):
num_filts, self.emb_dim, kernel_size=1, padding=0
)
def forward(self, ip, return_feats=False) -> ModelOutput:
def forward(self, spec: torch.Tensor) -> ModelOutput:
# encoder
x1 = self.conv_dn_0(ip)
x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
# bottleneck
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = self.att(x)
x = x.repeat([1, 1, self.bneck_height * 4, 1])
@ -121,13 +121,13 @@ class Net2DFast(nn.Module):
x = self.conv_up_4(x + x1)
# output
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
x = F.relu_(self.conv_op_bn(self.conv_op(x)))
cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1)
return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_size=F.relu(self.conv_size_op(x)),
pred_class=comb,
pred_class_un_norm=cls,
features=x,
@ -215,26 +215,26 @@ class Net2DFastNoAttn(nn.Module):
num_filts, self.emb_dim, kernel_size=1, padding=0
)
def forward(self, ip, return_feats=False) -> ModelOutput:
x1 = self.conv_dn_0(ip)
def forward(self, spec: torch.Tensor) -> ModelOutput:
x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = x.repeat([1, 1, self.bneck_height * 4, 1])
x = self.conv_up_2(x + x3)
x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1)
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
x = F.relu_(self.conv_op_bn(self.conv_op(x)))
cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1)
return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_size=F.relu_(self.conv_size_op(x)),
pred_class=comb,
pred_class_un_norm=cls,
features=x,
@ -324,13 +324,13 @@ class Net2DFastNoCoordConv(nn.Module):
num_filts, self.emb_dim, kernel_size=1, padding=0
)
def forward(self, ip, return_feats=False) -> ModelOutput:
x1 = self.conv_dn_0(ip)
def forward(self, spec: torch.Tensor) -> ModelOutput:
x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = self.att(x)
x = x.repeat([1, 1, self.bneck_height * 4, 1])
@ -338,15 +338,13 @@ class Net2DFastNoCoordConv(nn.Module):
x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1)
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
x = F.relu_(self.conv_op_bn(self.conv_op(x)))
cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1)
pred_emb = (self.conv_emb(x) if self.emb_dim > 0 else None,)
return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_size=F.relu_(self.conv_size_op(x)),
pred_class=comb,
pred_class_un_norm=cls,
features=x,

View File

@ -1,6 +1,11 @@
import datetime
import os
from pathlib import Path
from typing import List, Optional, Union
from pydantic import BaseModel, Field, computed_field
from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names
from batdetect2.types import ProcessingConfiguration, SpectrogramParameters
TARGET_SAMPLERATE_HZ = 256000
@ -75,103 +80,7 @@ def mk_dir(path):
os.makedirs(path)
def get_params(make_dirs=False, exps_dir="../../experiments/"):
params = {}
params[
"model_name"
] = "Net2DFast" # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
params["num_filters"] = 128
now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
model_name = now_str + ".pth.tar"
params["experiment"] = os.path.join(exps_dir, now_str, "")
params["model_file_name"] = os.path.join(params["experiment"], model_name)
params["op_im_dir"] = os.path.join(params["experiment"], "op_ims", "")
params["op_im_dir_test"] = os.path.join(
params["experiment"], "op_ims_test", ""
)
# params['notes'] = '' # can save notes about an experiment here
# spec parameters
params[
"target_samp_rate"
] = TARGET_SAMPLERATE_HZ # resamples all audio so that it is at this rate
params[
"fft_win_length"
] = FFT_WIN_LENGTH_S # in milliseconds, amount of time per stft time step
params["fft_overlap"] = FFT_OVERLAP # stft window overlap
params[
"max_freq"
] = MAX_FREQ_HZ # in Hz, everything above this will be discarded
params[
"min_freq"
] = MIN_FREQ_HZ # in Hz, everything below this will be discarded
params[
"resize_factor"
] = RESIZE_FACTOR # resize so the spectrogram at the input of the network
params[
"spec_height"
] = SPEC_HEIGHT # units are number of frequency bins (before resizing is performed)
params[
"spec_train_width"
] = 512 # units are number of time steps (before resizing is performed)
params[
"spec_divide_factor"
] = SPEC_DIVIDE_FACTOR # spectrogram should be divisible by this amount in width and height
# spec processing params
params[
"denoise_spec_avg"
] = DENOISE_SPEC_AVG # removes the mean for each frequency band
params[
"scale_raw_audio"
] = SCALE_RAW_AUDIO # scales the raw audio to [-1, 1]
params[
"max_scale_spec"
] = MAX_SCALE_SPEC # scales the spectrogram so that it is max 1
params["spec_scale"] = SPEC_SCALE # 'log', 'pcen', 'none'
# detection params
params[
"detection_overlap"
] = 0.01 # has to be within this number of ms to count as detection
params[
"ignore_start_end"
] = 0.01 # if start of GT calls are within this time from the start/end of file ignore
params[
"detection_threshold"
] = DETECTION_THRESHOLD # the smaller this is the better the recall will be
params[
"nms_kernel_size"
] = NMS_KERNEL_SIZE # size of the kernel for non-max suppression
params[
"nms_top_k_per_sec"
] = NMS_TOP_K_PER_SEC # keep top K highest predictions per second of audio
params["target_sigma"] = 2.0
# augmentation params
params[
"aug_prob"
] = 0.20 # augmentations will be performed with this probability
params["augment_at_train"] = True
params["augment_at_train_combine"] = True
params[
"echo_max_delay"
] = 0.005 # simulate echo by adding copy of raw audio
params["stretch_squeeze_delta"] = 0.04 # stretch or squeeze spec
params[
"mask_max_time_perc"
] = 0.05 # max mask size - here percentage, not ideal
params[
"mask_max_freq_perc"
] = 0.10 # max mask size - here percentage, not ideal
params[
"spec_amp_scaling"
] = 2.0 # multiply the "volume" by 0:X times current amount
params["aug_sampling_rates"] = [
AUG_SAMPLING_RATES = [
220500,
256000,
300000,
@ -180,53 +89,145 @@ def get_params(make_dirs=False, exps_dir="../../experiments/"):
441000,
500000,
]
CLASSES_TO_IGNORE = ["", " ", "Unknown", "Not Bat"]
GENERIC_CLASSES = ["Bat"]
EVENTS_OF_INTEREST = ["Echolocation"]
# loss params
params["train_loss"] = "focal" # mse or focal
params["det_loss_weight"] = 1.0 # weight for the detection part of the loss
params["size_loss_weight"] = 0.1 # weight for the bbox size loss
params["class_loss_weight"] = 2.0 # weight for the classification loss
params["individual_loss_weight"] = 0.0 # not used
if params["individual_loss_weight"] == 0.0:
params[
"emb_dim"
] = 0 # number of dimensions used for individual id embedding
else:
params["emb_dim"] = 3
# train params
params["lr"] = 0.001
params["batch_size"] = 8
params["num_workers"] = 4
params["num_epochs"] = 200
params["num_eval_epochs"] = 5 # run evaluation every X epochs
params["device"] = "cuda"
params["save_test_image_during_train"] = False
params["save_test_image_after_train"] = True
class TrainingParameters(BaseModel):
# Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
model_name: str = "Net2DFast"
num_filters: int = 128
params["convert_to_genus"] = False
params["genus_mapping"] = []
params["class_names"] = []
params["classes_to_ignore"] = ["", " ", "Unknown", "Not Bat"]
params["generic_class"] = ["Bat"]
params["events_of_interest"] = [
"Echolocation"
] # will ignore all other types of events e.g. social calls
experiment: Path
model_file_name: Path
# the classes in this list are standardized during training so that the same low and high freq are used
params["standardize_classs_names"] = []
op_im_dir: Path
op_im_dir_test: Path
notes: str = ""
target_samp_rate: int = TARGET_SAMPLERATE_HZ
fft_win_length: float = FFT_WIN_LENGTH_S
fft_overlap: float = FFT_OVERLAP
max_freq: int = MAX_FREQ_HZ
min_freq: int = MIN_FREQ_HZ
resize_factor: float = RESIZE_FACTOR
spec_height: int = SPEC_HEIGHT
spec_train_width: int = 512
spec_divide_factor: int = SPEC_DIVIDE_FACTOR
denoise_spec_avg: bool = DENOISE_SPEC_AVG
scale_raw_audio: bool = SCALE_RAW_AUDIO
max_scale_spec: bool = MAX_SCALE_SPEC
spec_scale: str = SPEC_SCALE
detection_overlap: float = 0.01
ignore_start_end: float = 0.01
detection_threshold: float = DETECTION_THRESHOLD
nms_kernel_size: int = NMS_KERNEL_SIZE
nms_top_k_per_sec: int = NMS_TOP_K_PER_SEC
aug_prob: float = 0.20
augment_at_train: bool = True
augment_at_train_combine: bool = True
echo_max_delay: float = 0.005
stretch_squeeze_delta: float = 0.04
mask_max_time_perc: float = 0.05
mask_max_freq_perc: float = 0.10
spec_amp_scaling: float = 2.0
aug_sampling_rates: List[int] = AUG_SAMPLING_RATES
train_loss: str = "focal"
det_loss_weight: float = 1.0
size_loss_weight: float = 0.1
class_loss_weight: float = 2.0
individual_loss_weight: float = 0.0
lr: float = 0.001
batch_size: int = 8
num_workers: int = 4
num_epochs: int = 200
num_eval_epochs: int = 5
device: str = "cuda"
save_test_image_during_train: bool = False
save_test_image_after_train: bool = True
convert_to_genus: bool = False
class_names: List[str] = Field(default_factory=list)
classes_to_ignore: List[str] = Field(
default_factory=lambda: CLASSES_TO_IGNORE
)
generic_class: List[str] = Field(default_factory=lambda: GENERIC_CLASSES)
events_of_interest: List[str] = Field(
default_factory=lambda: EVENTS_OF_INTEREST
)
standardize_classs_names: List[str] = Field(default_factory=list)
@computed_field
@property
def emb_dim(self) -> int:
if self.individual_loss_weight == 0.0:
return 0
return 3
@computed_field
@property
def genus_mapping(self) -> List[int]:
_, mapping = get_genus_mapping(self.class_names)
return mapping
@computed_field
@property
def genus_classes(self) -> List[str]:
names, _ = get_genus_mapping(self.class_names)
return names
@computed_field
@property
def class_names_short(self) -> List[str]:
return get_short_class_names(self.class_names)
def get_params(
make_dirs: bool = False,
exps_dir: str = "../../experiments/",
model_name: Optional[str] = None,
experiment: Union[Path, str, None] = None,
**kwargs,
) -> TrainingParameters:
experiments_dir = Path(exps_dir)
now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
if model_name is None:
model_name = f"{now_str}.pth.tar"
if experiment is None:
experiment = experiments_dir / now_str
experiment = Path(experiment)
model_file_name = experiment / model_name
op_ims_dir = experiment / "op_ims"
op_ims_test_dir = experiment / "op_ims_test"
params = TrainingParameters(
model_name=model_name,
experiment=experiment,
model_file_name=model_file_name,
op_im_dir=op_ims_dir,
op_im_dir_test=op_ims_test_dir,
**kwargs,
)
# create directories
if make_dirs:
print("Model name : " + params["model_name"])
print("Model file : " + params["model_file_name"])
print("Experiment : " + params["experiment"])
mk_dir(params["experiment"])
if params["save_test_image_during_train"]:
mk_dir(params["op_im_dir"])
if params["save_test_image_after_train"]:
mk_dir(params["op_im_dir_test"])
mk_dir(os.path.dirname(params["model_file_name"]))
mk_dir(experiment)
mk_dir(params.model_file_name.parent)
if params.save_test_image_during_train:
mk_dir(params.op_im_dir)
if params.save_test_image_after_train:
mk_dir(params.op_im_dir_test)
return params
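A minimal usage sketch of the refactored configuration, assuming a hypothetical experiments path and class names (not from the repository):
params = get_params(
make_dirs=False,
exps_dir="/tmp/experiments",
batch_size=16,
class_names=["Myotis daubentonii", "Pipistrellus pipistrellus"],
)
print(params.batch_size)  # 16, overridden through **kwargs
print(params.emb_dim)  # 0, computed from individual_loss_weight == 0.0
print(params.genus_classes)  # genus names derived from class_names via get_genus_mapping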

View File

@ -1,33 +1,31 @@
import argparse
import glob
import json
import os
import sys
import warnings
from typing import List, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR
import batdetect2.detector.models as models
import batdetect2.detector.parameters as parameters
import batdetect2.detector.post_process as pp
import batdetect2.train.audio_dataloader as adl
import batdetect2.train.evaluate as evl
import batdetect2.train.losses as losses
import batdetect2.train.train_model as tm
import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
from batdetect2 import types
from batdetect2.detector.models import Net2DFast
if __name__ == "__main__":
info_str = "\nBatDetect - Finetune Model\n"
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
print(info_str)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"audio_path", type=str, help="Input directory for audio"
"audio_path",
type=str,
help="Input directory for audio",
)
parser.add_argument(
"train_ann_path",
@ -39,7 +37,15 @@ if __name__ == "__main__":
type=str,
help="Path to where test annotation file is stored",
)
parser.add_argument("model_path", type=str, help="Path to pretrained model")
parser.add_argument(
"model_path", type=str, help="Path to pretrained model"
)
parser.add_argument(
"--experiment_dir",
type=str,
default=os.path.join(BASE_DIR, "experiments"),
help="Path to where experiment files are stored",
)
parser.add_argument(
"--op_model_name",
type=str,
@ -71,107 +77,63 @@ if __name__ == "__main__":
parser.add_argument(
"--notes", type=str, default="", help="Notes to save in text file"
)
args = vars(parser.parse_args())
args = parser.parse_args()
return args
params = parameters.get_params(True, "../../experiments/")
def select_device(warn=True) -> str:
if torch.cuda.is_available():
params["device"] = "cuda"
else:
params["device"] = "cpu"
print(
"\nNote, this will be a lot faster if you use computer with a GPU.\n"
return "cuda"
if warn:
warnings.warn(
"No GPU available, using the CPU instead. Please consider using a GPU "
"to speed up training."
)
print("\nAudio directory: " + args["audio_path"])
print("Train file: " + args["train_ann_path"])
print("Test file: " + args["test_ann_path"])
print("Loading model: " + args["model_path"])
return "cpu"
dataset_name = (
os.path.basename(args["train_ann_path"])
.replace(".json", "")
.replace("_TRAIN", "")
)
if args["train_from_scratch"]:
print("\nTraining model from scratch i.e. not using pretrained weights")
model, params_train = du.load_model(args["model_path"], False)
else:
model, params_train = du.load_model(args["model_path"], True)
model.to(params["device"])
params["num_epochs"] = args["num_epochs"]
if args["op_model_name"] != "":
params["model_file_name"] = args["op_model_name"]
classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
# save notes file
params["notes"] = args["notes"]
if args["notes"] != "":
tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"])
# load train annotations
train_sets = []
def load_annotations(
dataset_name: str,
ann_path: str,
audio_path: str,
classes_to_ignore: Optional[List[str]] = None,
events_of_interest: Optional[List[str]] = None,
) -> List[types.FileAnnotation]:
train_sets: List[types.DatasetDict] = []
train_sets.append(
tu.get_blank_dataset_dict(
dataset_name, False, args["train_ann_path"], args["audio_path"]
)
)
params["train_sets"] = [
tu.get_blank_dataset_dict(
dataset_name,
False,
os.path.basename(args["train_ann_path"]),
args["audio_path"],
is_test=False,
ann_path=ann_path,
wav_path=audio_path,
)
]
print("\nTrain set:")
(
data_train,
params["class_names"],
params["class_inv_freq"],
) = tu.load_set_of_anns(
train_sets, classes_to_ignore, params["events_of_interest"]
)
print("Number of files", len(data_train))
params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
params["class_names"]
)
params["class_names_short"] = tu.get_short_class_names(
params["class_names"]
)
# load test annotations
test_sets = []
test_sets.append(
tu.get_blank_dataset_dict(
dataset_name, True, args["test_ann_path"], args["audio_path"]
return tu.load_set_of_anns(
train_sets,
events_of_interest=events_of_interest,
classes_to_ignore=classes_to_ignore,
)
)
params["test_sets"] = [
tu.get_blank_dataset_dict(
dataset_name,
True,
os.path.basename(args["test_ann_path"]),
args["audio_path"],
)
]
print("\nTest set:")
data_test, _, _ = tu.load_set_of_anns(
test_sets, classes_to_ignore, params["events_of_interest"]
)
print("Number of files", len(data_test))
def finetune_model(
model: types.DetectionModel,
data_train: List[types.FileAnnotation],
data_test: List[types.FileAnnotation],
params: parameters.TrainingParameters,
model_params: types.ModelParameters,
finetune_only_last_layer: bool = False,
save_images: bool = True,
):
# train loader
train_dataset = adl.AudioLoader(data_train, params, is_train=True)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=params["batch_size"],
batch_size=params.batch_size,
shuffle=True,
num_workers=params["num_workers"],
num_workers=params.num_workers,
pin_memory=True,
)
@ -181,32 +143,36 @@ if __name__ == "__main__":
test_dataset,
batch_size=1,
shuffle=False,
num_workers=params["num_workers"],
num_workers=params.num_workers,
pin_memory=True,
)
inputs_train = next(iter(train_loader))
params["ip_height"] = inputs_train["spec"].shape[2]
params.ip_height = inputs_train["spec"].shape[2]
print("\ntrain batch size :", inputs_train["spec"].shape)
assert params_train["model_name"] == "Net2DFast"
# Check that the model is the same as the one used to train the pretrained
# weights
assert model_params["model_name"] == "Net2DFast"
assert isinstance(model, Net2DFast)
print(
"\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n"
"\n\nSOME hyperparams need to be the same as the loaded model "
"(e.g. FFT) - currently they are getting overwritten.\n\n"
)
# set the number of output classes
num_filts = model.conv_classes_op.in_channels
k_size = model.conv_classes_op.kernel_size
pad = model.conv_classes_op.padding
k_size, _ = model.conv_classes_op.kernel_size
pad, _ = model.conv_classes_op.padding
model.conv_classes_op = torch.nn.Conv2d(
num_filts,
len(params["class_names"]) + 1,
len(params.class_names) + 1,
kernel_size=k_size,
padding=pad,
)
model.conv_classes_op.to(params["device"])
model.conv_classes_op.to(params.device)
if args["finetune_only_last_layer"]:
if finetune_only_last_layer:
print("\nOnly finetuning the final layers.\n")
train_layers_i = [
"conv_classes",
@ -223,19 +189,26 @@ if __name__ == "__main__":
else:
param.requires_grad = False
optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
scheduler = CosineAnnealingLR(
optimizer, params["num_epochs"] * len(train_loader)
optimizer = torch.optim.Adam(
model.parameters(),
lr=params.lr,
)
if params["train_loss"] == "mse":
scheduler = CosineAnnealingLR(
optimizer,
params.num_epochs * len(train_loader),
)
if params.train_loss == "mse":
det_criterion = losses.mse_loss
elif params["train_loss"] == "focal":
elif params.train_loss == "focal":
det_criterion = losses.focal_loss
else:
raise ValueError("Unknown loss function")
# plotting
train_plt_ls = pu.LossPlotter(
params["experiment"] + "train_loss.png",
params["num_epochs"] + 1,
params.experiment / "train_loss.png",
params.num_epochs + 1,
["train_loss"],
None,
None,
@ -243,8 +216,8 @@ if __name__ == "__main__":
logy=True,
)
test_plt_ls = pu.LossPlotter(
params["experiment"] + "test_loss.png",
params["num_epochs"] + 1,
params.experiment / "test_loss.png",
params.num_epochs + 1,
["test_loss"],
None,
None,
@ -252,24 +225,24 @@ if __name__ == "__main__":
logy=True,
)
test_plt = pu.LossPlotter(
params["experiment"] + "test.png",
params["num_epochs"] + 1,
params.experiment / "test.png",
params.num_epochs + 1,
["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
[0, 1],
None,
["epoch", ""],
)
test_plt_class = pu.LossPlotter(
params["experiment"] + "test_avg_prec.png",
params["num_epochs"] + 1,
params["class_names_short"],
params.experiment / "test_avg_prec.png",
params.num_epochs + 1,
params.class_names_short,
[0, 1],
params["class_names_short"],
params.class_names_short,
["epoch", "avg_prec"],
)
# main train loop
for epoch in range(0, params["num_epochs"] + 1):
for epoch in range(0, params.num_epochs + 1):
train_loss = tm.train(
model,
epoch,
@ -281,10 +254,14 @@ if __name__ == "__main__":
)
train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])
if epoch % params["num_eval_epochs"] == 0:
if epoch % params.num_eval_epochs == 0:
# detection accuracy on test set
test_res, test_loss = tm.test(
model, epoch, test_loader, det_criterion, params
model,
epoch,
test_loader,
det_criterion,
params,
)
test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
test_plt.update_and_save(
@ -301,18 +278,106 @@ if __name__ == "__main__":
epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
)
pu.plot_pr_curve_class(
params["experiment"], "test_pr", "test_pr", test_res
params.experiment, "test_pr", "test_pr", test_res
)
# save finetuned model
print("saving model to: " + params["model_file_name"])
print(f"saving model to: {params.model_file_name}")
op_state = {
"epoch": epoch + 1,
"state_dict": model.state_dict(),
"params": params,
}
torch.save(op_state, params["model_file_name"])
torch.save(op_state, params.model_file_name)
# save an image with associated prediction for each batch in the test set
if not args["do_not_save_images"]:
if save_images:
tm.save_images_batch(model, test_loader, params)
def main():
info_str = "\nBatDetect - Finetune Model\n"
print(info_str)
args = parse_arguments()
# Load experiment parameters
params = parameters.get_params(
make_dirs=True,
exps_dir=args.experiment_dir,
device=select_device(),
num_epochs=args.num_epochs,
notes=args.notes,
)
print("\nAudio directory: " + args.audio_path)
print("Train file: " + args.train_ann_path)
print("Test file: " + args.test_ann_path)
print("Loading model: " + args.model_path)
if args.train_from_scratch:
print(
"\nTraining model from scratch i.e. not using pretrained weights"
)
model, model_params = du.load_model(
args.model_path,
load_weights=not args.train_from_scratch,
device=params.device,
)
if args.op_model_name != "":
params.model_file_name = args.op_model_name
classes_to_ignore = params.classes_to_ignore + params.generic_class
# save notes file
if params.notes:
tu.write_notes_file(
params.experiment / "notes.txt",
args.notes,
)
# NOTE: derive the dataset name from the train annotation file name
dataset_name = (
os.path.basename(args.train_ann_path)
.replace(".json", "")
.replace("_TRAIN", "")
)
# ==== LOAD DATA ====
# load train annotations
data_train = load_annotations(
dataset_name,
args.train_ann_path,
args.audio_path,
classes_to_ignore,
params.events_of_interest,
)
print("\nTrain set:")
print("Number of files", len(data_train))
# load test annotations
data_test = load_annotations(
dataset_name,
args.test_ann_path,
args.audio_path,
classes_to_ignore,
params.events_of_interest,
)
print("\nTrain set:")
print("Number of files", len(data_train))
finetune_model(
model,
data_train,
data_test,
params,
model_params,
finetune_only_last_layer=args.finetune_only_last_layer,
save_images=not args.do_not_save_images,
)
if __name__ == "__main__":
main()

View File

@ -1,62 +1,54 @@
import argparse
import json
import os
from collections import Counter
from typing import List, Optional, Tuple
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold
import batdetect2.train.train_utils as tu
from batdetect2 import types
def print_dataset_stats(data, split_name, classes_to_ignore):
print("\nSplit:", split_name)
def print_dataset_stats(
data: List[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None,
) -> Counter[str]:
print("Num files:", len(data))
class_cnts = {}
for dd in data:
for aa in dd["annotation"]:
if aa["class"] not in classes_to_ignore:
if aa["class"] in class_cnts:
class_cnts[aa["class"]] += 1
else:
class_cnts[aa["class"]] = 1
if len(class_cnts) == 0:
class_names = []
else:
class_names = np.sort([*class_cnts]).tolist()
print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for ii, cc in enumerate(class_names):
print(str(ii).ljust(5) + cc.ljust(str_len) + str(class_cnts[cc]))
return class_names
counts, _ = tu.get_class_names(data, classes_to_ignore)
if len(counts) > 0:
tu.report_class_counts(counts)
return counts
def load_file_names(file_name):
if os.path.isfile(file_name):
def load_file_names(file_name: str) -> List[str]:
if not os.path.isfile(file_name):
raise FileNotFoundError(f"Input file not found - {file_name}")
with open(file_name) as da:
files = [line.rstrip() for line in da.readlines()]
for ff in files:
if ff.lower()[-3:] != "wav":
print("Error: Filenames need to end in .wav - ", ff)
assert False
else:
print("Error: Input file not found - ", file_name)
assert False
for path in files:
if path.lower()[-3:] != "wav":
raise ValueError(
f"Invalid file name - {path}. Must be a .wav file"
)
return files
if __name__ == "__main__":
def parse_args():
info_str = "\nBatDetect - Prepare Data for Finetuning\n"
print(info_str)
parser = argparse.ArgumentParser()
parser.add_argument(
"dataset_name", type=str, help="Name to call your dataset"
)
parser.add_argument("audio_dir", type=str, help="Input directory for audio")
parser.add_argument(
"audio_dir", type=str, help="Input directory for audio"
)
parser.add_argument(
"ann_dir",
type=str,
@ -104,86 +96,124 @@ if __name__ == "__main__":
help='New class names to use instead. One to one mapping with "--input_class_names". \
Separate with ";"',
)
args = vars(parser.parse_args())
return parser.parse_args()
np.random.seed(args["rand_seed"])
def split_data(
data: List[types.FileAnnotation],
train_file: str,
test_file: str,
n_splits: int = 5,
random_state: int = 0,
) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]:
if train_file != "" and test_file != "":
# user has specified the train / test split
mapping = {
file_annotation["id"]: file_annotation for file_annotation in data
}
train_files = load_file_names(train_file)
test_files = load_file_names(test_file)
data_train = [
mapping[file_id] for file_id in train_files if file_id in mapping
]
data_test = [
mapping[file_id] for file_id in test_files if file_id in mapping
]
return data_train, data_test
# NOTE: Using StratifiedGroupKFold to ensure that the same file does not
# appear in both the training and test sets and trying to keep the
# distribution of classes the same in both sets.
splitter = StratifiedGroupKFold(
n_splits=n_splits,
shuffle=True,
random_state=random_state,
)
anns = np.array(
[
[dd["id"], ann["class"], ann["event"]]
for dd in data
for ann in dd["annotation"]
]
)
y = anns[:, 1]
group = anns[:, 0]
train_idx, test_idx = next(splitter.split(X=anns, y=y, groups=group))
train_ids = set(anns[train_idx, 0])
test_ids = set(anns[test_idx, 0])
assert not (train_ids & test_ids)
data_train = [dd for dd in data if dd["id"] in train_ids]
data_test = [dd for dd in data if dd["id"] in test_ids]
return data_train, data_test
def main():
args = parse_args()
np.random.seed(args.rand_seed)
classes_to_ignore = ["", " ", "Unknown", "Not Bat"]
generic_class = ["Bat"]
events_of_interest = ["Echolocation"]
if args["input_class_names"] != "" and args["output_class_names"] != "":
name_dict = None
if args.input_class_names != "" and args.output_class_names != "":
# change the names of the classes
ip_names = args["input_class_names"].split(";")
op_names = args["output_class_names"].split(";")
ip_names = args.input_class_names.split(";")
op_names = args.output_class_names.split(";")
name_dict = dict(zip(ip_names, op_names))
else:
name_dict = False
# load annotations
data_all, _, _ = tu.load_set_of_anns(
{"ann_path": args["ann_dir"], "wav_path": args["audio_dir"]},
classes_to_ignore,
events_of_interest,
False,
False,
list_of_anns=True,
data_all = tu.load_set_of_anns(
[
{
"dataset_name": args.dataset_name,
"ann_path": args.ann_dir,
"wav_path": args.audio_dir,
"is_test": False,
"is_binary": False,
}
],
classes_to_ignore=classes_to_ignore,
events_of_interest=events_of_interest,
convert_to_genus=False,
filter_issues=True,
name_replace=name_dict,
)
print("Dataset name: " + args["dataset_name"])
print("Audio directory: " + args["audio_dir"])
print("Annotation directory: " + args["ann_dir"])
print("Ouput directory: " + args["op_dir"])
print("Dataset name: " + args.dataset_name)
print("Audio directory: " + args.audio_dir)
print("Annotation directory: " + args.ann_dir)
print("Ouput directory: " + args.op_dir)
print("Num annotated files: " + str(len(data_all)))
if args["train_file"] != "" and args["test_file"] != "":
# user has specified the train / test split
train_files = load_file_names(args["train_file"])
test_files = load_file_names(args["test_file"])
file_names_all = [dd["id"] for dd in data_all]
train_inds = [
file_names_all.index(ff)
for ff in train_files
if ff in file_names_all
]
test_inds = [
file_names_all.index(ff)
for ff in test_files
if ff in file_names_all
]
else:
# split the data into train and test at the file level
num_exs = len(data_all)
test_inds = np.random.choice(
np.arange(num_exs),
int(num_exs * args["percent_val"]),
replace=False,
data_train, data_test = split_data(
data=data_all,
train_file=args.train_file,
test_file=args.test_file,
n_splits=5,
random_state=args.rand_seed,
)
test_inds = np.sort(test_inds)
train_inds = np.setdiff1d(np.arange(num_exs), test_inds)
data_train = [data_all[ii] for ii in train_inds]
data_test = [data_all[ii] for ii in test_inds]
if not os.path.isdir(args["op_dir"]):
os.makedirs(args["op_dir"])
op_name = os.path.join(args["op_dir"], args["dataset_name"])
if not os.path.isdir(args.op_dir):
os.makedirs(args.op_dir)
op_name = os.path.join(args.op_dir, args.dataset_name)
op_name_train = op_name + "_TRAIN.json"
op_name_test = op_name + "_TEST.json"
class_un_train = print_dataset_stats(data_train, "Train", classes_to_ignore)
class_un_test = print_dataset_stats(data_test, "Test", classes_to_ignore)
print("\nSplit: Train")
class_un_train = print_dataset_stats(data_train, classes_to_ignore)
print("\nSplit: Test")
class_un_test = print_dataset_stats(data_test, classes_to_ignore)
if len(data_train) > 0 and len(data_test) > 0:
if class_un_train != class_un_test:
print(
'\nError: some classes are not in both the training and test sets.\
\nTry a different random seed "--rand_seed".'
if set(class_un_train.keys()) != set(class_un_test.keys()):
raise RuntimeError(
"Error: some classes are not in both the training and test sets."
'Try a different random seed "--rand_seed".'
)
assert False
print("\n")
if len(data_train) == 0:
@ -199,3 +229,7 @@ if __name__ == "__main__":
print("Saving: ", op_name_test)
with open(op_name_test, "w") as da:
json.dump(data_test, da, indent=2)
if __name__ == "__main__":
main()

View File

@ -12,19 +12,24 @@ import torchaudio
import batdetect2.utils.audio_utils as au
from batdetect2.types import (
Annotation,
AnnotationGroup,
AudioLoaderAnnotationGroup,
FileAnnotations,
HeatmapParameters,
AudioLoaderParameters,
FileAnnotation,
)
def generate_gt_heatmaps(
spec_op_shape: Tuple[int, int],
sampling_rate: int,
ann: AnnotationGroup,
params: HeatmapParameters,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]:
sampling_rate: float,
ann: AudioLoaderAnnotationGroup,
class_names: List[str],
fft_win_length: float,
fft_overlap: float,
max_freq: float,
min_freq: float,
resize_factor: float,
target_sigma: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AudioLoaderAnnotationGroup]:
"""Generate ground truth heatmaps from annotations.
Parameters
@ -53,31 +58,31 @@ def generate_gt_heatmaps(
the x and y indices of their pixel location in the input spectrogram.
"""
# spec may be resized on input into the network
num_classes = len(params["class_names"])
num_classes = len(class_names)
op_height = spec_op_shape[0]
op_width = spec_op_shape[1]
freq_per_bin = (params["max_freq"] - params["min_freq"]) / op_height
freq_per_bin = (max_freq - min_freq) / op_height
# start and end times
x_pos_start = au.time_to_x_coords(
ann["start_times"],
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
fft_win_length,
fft_overlap,
)
x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int32)
x_pos_start = (resize_factor * x_pos_start).astype(np.int32)
x_pos_end = au.time_to_x_coords(
ann["end_times"],
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
fft_win_length,
fft_overlap,
)
x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int32)
x_pos_end = (resize_factor * x_pos_end).astype(np.int32)
# location on y axis i.e. frequency
y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin
y_pos_low = (ann["low_freqs"] - min_freq) / freq_per_bin
y_pos_low = (op_height - y_pos_low).astype(np.int32)
y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin
y_pos_high = (ann["high_freqs"] - min_freq) / freq_per_bin
y_pos_high = (op_height - y_pos_high).astype(np.int32)
bb_widths = x_pos_end - x_pos_start
bb_heights = y_pos_low - y_pos_high
@ -90,26 +95,17 @@ def generate_gt_heatmaps(
& (y_pos_low < (op_height - 1))
)[0]
ann_aug: AnnotationGroup = {
ann_aug: AudioLoaderAnnotationGroup = {
**ann,
"start_times": ann["start_times"][valid_inds],
"end_times": ann["end_times"][valid_inds],
"high_freqs": ann["high_freqs"][valid_inds],
"low_freqs": ann["low_freqs"][valid_inds],
"class_ids": ann["class_ids"][valid_inds],
"individual_ids": ann["individual_ids"][valid_inds],
"x_inds": x_pos_start[valid_inds],
"y_inds": y_pos_low[valid_inds],
}
ann_aug["x_inds"] = x_pos_start[valid_inds]
ann_aug["y_inds"] = y_pos_low[valid_inds]
# keys = [
# "start_times",
# "end_times",
# "high_freqs",
# "low_freqs",
# "class_ids",
# "individual_ids",
# ]
# for kk in keys:
# ann_aug[kk] = ann[kk][valid_inds]
# if the number of calls is only 1, then it is unique
# TODO would be better if we found these unique calls at the merging stage
@ -118,6 +114,7 @@ def generate_gt_heatmaps(
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
# num classes and "background" class
y_2d_classes: np.ndarray = np.zeros(
(num_classes + 1, op_height, op_width), dtype=np.float32
@ -128,14 +125,8 @@ def generate_gt_heatmaps(
draw_gaussian(
y_2d_det[0, :],
(x_pos_start[ii], y_pos_low[ii]),
params["target_sigma"],
target_sigma,
)
# draw_gaussian(
# y_2d_det[0, :],
# (x_pos_start[ii], y_pos_low[ii]),
# params["target_sigma"],
# params["target_sigma"] * 2,
# )
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
@ -144,14 +135,8 @@ def generate_gt_heatmaps(
draw_gaussian(
y_2d_classes[cls_id, :],
(x_pos_start[ii], y_pos_low[ii]),
params["target_sigma"],
target_sigma,
)
# draw_gaussian(
# y_2d_classes[cls_id, :],
# (x_pos_start[ii], y_pos_low[ii]),
# params["target_sigma"],
# params["target_sigma"] * 2,
# )
# be careful as this will have a 1.0 in places where we have an event but
# don't know the gt class - this will be masked in training anyway
@ -235,8 +220,8 @@ def pad_aray(ip_array: np.ndarray, pad_size: int) -> np.ndarray:
def warp_spec_aug(
spec: torch.Tensor,
ann: AnnotationGroup,
params: dict,
ann: AudioLoaderAnnotationGroup,
stretch_squeeze_delta: float,
) -> torch.Tensor:
"""Warp spectrogram by randomly stretching and squeezing.
@ -247,8 +232,8 @@ def warp_spec_aug(
ann: AnnotationGroup
Annotation group for the spectrogram. Must be provided to sync
the start and stop times with the spectrogram after warping.
params: dict
Parameters for the augmentation.
stretch_squeeze_delta: float
Maximum amount to stretch or squeeze the spectrogram.
Returns
-------
@ -259,11 +244,10 @@ def warp_spec_aug(
-----
This function modifies the annotation group in place.
"""
# This is messy
# Augment spectrogram by randomly stretching and squeezing
# NOTE this also changes the start and stop time in place
delta = params["stretch_squeeze_delta"]
delta = stretch_squeeze_delta
op_size = (spec.shape[1], spec.shape[2])
resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
resize_amt = int(spec.shape[2] * resize_fract_r)
@ -277,7 +261,7 @@ def warp_spec_aug(
dtype=spec.dtype,
),
),
2,
dim=2,
)
else:
spec_r = spec[:, :, :resize_amt]
@ -297,7 +281,10 @@ def warp_spec_aug(
return spec
def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
def mask_time_aug(
spec: torch.Tensor,
mask_max_time_perc: float,
) -> torch.Tensor:
"""Mask out random blocks of time.
Will randomly mask out a block of time in the spectrogram. The block
@ -308,8 +295,8 @@ def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
----------
spec: torch.Tensor
Spectrogram to mask.
params: dict
Parameters for the augmentation.
mask_max_time_perc: float
Maximum percentage of time to mask out.
Returns
-------
@ -324,14 +311,17 @@ def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
Recognition
"""
fm = torchaudio.transforms.TimeMasking(
int(spec.shape[1] * params["mask_max_time_perc"])
int(spec.shape[1] * mask_max_time_perc)
)
for _ in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
def mask_freq_aug(
spec: torch.Tensor,
mask_max_freq_perc: float,
) -> torch.Tensor:
"""Mask out random blocks of frequency.
Will randomly mask out a block of frequency in the spectrogram. The block
@ -342,8 +332,8 @@ def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
----------
spec: torch.Tensor
Spectrogram to mask.
params: dict
Parameters for the augmentation.
mask_max_freq_perc: float
Maximum percentage of frequency to mask out.
Returns
-------
@ -358,41 +348,48 @@ def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
Recognition
"""
fm = torchaudio.transforms.FrequencyMasking(
int(spec.shape[1] * params["mask_max_freq_perc"])
int(spec.shape[1] * mask_max_freq_perc)
)
for _ in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
def scale_vol_aug(spec: torch.Tensor, params: dict) -> torch.Tensor:
def scale_vol_aug(
spec: torch.Tensor,
spec_amp_scaling: float,
) -> torch.Tensor:
"""Scale the volume of the spectrogram.
Parameters
----------
spec: torch.Tensor
Spectrogram to scale.
params: dict
Parameters for the augmentation.
spec_amp_scaling: float
Maximum scaling factor.
Returns
-------
torch.Tensor
"""
return spec * np.random.random() * params["spec_amp_scaling"]
return spec * np.random.random() * spec_amp_scaling
def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
def echo_aug(
audio: np.ndarray,
sampling_rate: float,
echo_max_delay: float,
) -> np.ndarray:
"""Add echo to audio.
Parameters
----------
audio: np.ndarray
Audio to add echo to.
sampling_rate: int
sampling_rate: float
Sampling rate of the audio.
params: dict
Parameters for the augmentation.
echo_max_delay: float
Maximum delay of the echo in seconds.
Returns
-------
@ -400,7 +397,7 @@ def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
Audio with echo added.
"""
sample_offset = (
int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1
int(echo_max_delay * np.random.random() * sampling_rate) + 1
)
audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
return audio
@ -408,9 +405,14 @@ def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray:
def resample_aug(
audio: np.ndarray,
sampling_rate: int,
params: dict,
) -> Tuple[np.ndarray, int, float]:
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
resize_factor: float,
spec_divide_factor: float,
spec_train_width: int,
aug_sampling_rates: List[int],
) -> Tuple[np.ndarray, float, float]:
"""Resample audio augmentation.
Will resample the audio to a random sampling rate from the list of
@ -420,23 +422,32 @@ def resample_aug(
----------
audio: np.ndarray
Audio to resample.
sampling_rate: int
sampling_rate: float
Original sampling rate of the audio.
params: dict
Parameters for the augmentation. Includes the list of sampling rates
to choose from for resampling in `aug_sampling_rates`.
fft_win_length: float
Length of the FFT window in seconds.
fft_overlap: float
Amount of overlap between FFT windows.
resize_factor: float
Factor to resize the spectrogram by.
spec_divide_factor: float
Factor to divide the spectrogram by.
spec_train_width: int
Width of the spectrogram.
aug_sampling_rates: List[int]
List of sampling rates to resample to.
Returns
-------
audio : np.ndarray
Resampled audio.
sampling_rate : int
sampling_rate : float
New sampling rate.
duration : float
Duration of the audio in seconds.
"""
sampling_rate_old = sampling_rate
sampling_rate = np.random.choice(params["aug_sampling_rates"])
sampling_rate = np.random.choice(aug_sampling_rates)
audio = librosa.resample(
audio,
orig_sr=sampling_rate_old,
@ -447,11 +458,11 @@ def resample_aug(
audio = au.pad_audio(
audio,
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
params["resize_factor"],
params["spec_divide_factor"],
params["spec_train_width"],
fft_win_length,
fft_overlap,
resize_factor,
spec_divide_factor,
spec_train_width,
)
duration = audio.shape[0] / float(sampling_rate)
return audio, sampling_rate, duration
@ -459,28 +470,28 @@ def resample_aug(
def resample_audio(
num_samples: int,
sampling_rate: int,
sampling_rate: float,
audio2: np.ndarray,
sampling_rate2: int,
) -> Tuple[np.ndarray, int]:
sampling_rate2: float,
) -> Tuple[np.ndarray, float]:
"""Resample audio.
Parameters
----------
num_samples: int
Expected number of samples for the output audio.
sampling_rate: int
sampling_rate: float
Original sampling rate of the audio.
audio2: np.ndarray
Audio to resample.
sampling_rate2: int
sampling_rate2: float
Target sampling rate of the audio.
Returns
-------
audio2 : np.ndarray
Resampled audio.
sampling_rate2 : int
sampling_rate2 : float
New sampling rate.
"""
# resample to target sampling rate
@ -509,12 +520,12 @@ def resample_audio(
def combine_audio_aug(
audio: np.ndarray,
sampling_rate: int,
ann: AnnotationGroup,
sampling_rate: float,
ann: AudioLoaderAnnotationGroup,
audio2: np.ndarray,
sampling_rate2: int,
ann2: AnnotationGroup,
) -> Tuple[np.ndarray, AnnotationGroup]:
sampling_rate2: float,
ann2: AudioLoaderAnnotationGroup,
) -> Tuple[np.ndarray, AudioLoaderAnnotationGroup]:
"""Combine two audio files.
Will combine two audio files by resampling them to the same sampling rate
@ -570,7 +581,9 @@ def combine_audio_aug(
# from different individuals
if kk == "individual_ids":
if (ann[kk] > -1).sum() > 0:
ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1
ann2[kk][ann2[kk] > -1] += (
np.max(ann[kk][ann[kk] > -1]) + 1
)
if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
@ -579,7 +592,8 @@ def combine_audio_aug(
def _prepare_annotation(
annotation: Annotation, class_names: List[str]
annotation: Annotation,
class_names: List[str],
) -> Annotation:
try:
class_id = class_names.index(annotation["class"])
@ -598,7 +612,7 @@ def _prepare_annotation(
def _prepare_file_annotation(
annotation: FileAnnotations,
annotation: FileAnnotation,
class_names: List[str],
classes_to_ignore: List[str],
) -> AudioLoaderAnnotationGroup:
@ -626,7 +640,9 @@ def _prepare_file_annotation(
"end_times": np.array([ann["end_time"] for ann in annotations]),
"high_freqs": np.array([ann["high_freq"] for ann in annotations]),
"low_freqs": np.array([ann["low_freq"] for ann in annotations]),
"class_ids": np.array([ann.get("class_id", -1) for ann in annotations]),
"class_ids": np.array(
[ann.get("class_id", -1) for ann in annotations]
),
"individual_ids": np.array([ann["individual"] for ann in annotations]),
"class_id_file": class_id_file,
}
@ -639,15 +655,15 @@ class AudioLoader(torch.utils.data.Dataset):
def __init__(
self,
data_anns_ip: List[FileAnnotations],
params,
data_anns_ip: List[FileAnnotation],
params: AudioLoaderParameters,
dataset_name: Optional[str] = None,
is_train: bool = False,
return_spec_for_viz: bool = False,
):
self.is_train: bool = is_train
self.params: dict = params
self.return_spec_for_viz: bool = False
self.is_train = is_train
self.params = params
self.return_spec_for_viz = return_spec_for_viz
self.data_anns: List[AudioLoaderAnnotationGroup] = [
_prepare_file_annotation(
ann,
@ -657,61 +673,6 @@ class AudioLoader(torch.utils.data.Dataset):
for ann in data_anns_ip
]
# for ii in range(len(data_anns_ip)):
# dd = copy.deepcopy(data_anns_ip[ii])
#
# # filter out unused annotation here
# filtered_annotations = []
# for ii, aa in enumerate(dd["annotation"]):
# if "individual" in aa.keys():
# aa["individual"] = int(aa["individual"])
#
# # if only one call labeled it has to be from the same
# # individual
# if len(dd["annotation"]) == 1:
# aa["individual"] = 0
#
# # convert class name into class label
# if aa["class"] in self.params["class_names"]:
# aa["class_id"] = self.params["class_names"].index(
# aa["class"]
# )
# else:
# aa["class_id"] = -1
#
# if aa["class"] not in self.params["classes_to_ignore"]:
# filtered_annotations.append(aa)
#
# dd["annotation"] = filtered_annotations
# dd["start_times"] = np.array(
# [aa["start_time"] for aa in dd["annotation"]]
# )
# dd["end_times"] = np.array(
# [aa["end_time"] for aa in dd["annotation"]]
# )
# dd["high_freqs"] = np.array(
# [float(aa["high_freq"]) for aa in dd["annotation"]]
# )
# dd["low_freqs"] = np.array(
# [float(aa["low_freq"]) for aa in dd["annotation"]]
# )
# dd["class_ids"] = np.array(
# [aa["class_id"] for aa in dd["annotation"]]
# ).astype(np.int32)
# dd["individual_ids"] = np.array(
# [aa["individual"] for aa in dd["annotation"]]
# ).astype(np.int32)
#
# # file level class name
# dd["class_id_file"] = -1
# if "class_name" in dd.keys():
# if dd["class_name"] in self.params["class_names"]:
# dd["class_id_file"] = self.params["class_names"].index(
# dd["class_name"]
# )
#
# self.data_anns.append(dd)
ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
self.max_num_anns = 2 * np.max(
ann_cnt
@ -730,7 +691,7 @@ class AudioLoader(torch.utils.data.Dataset):
def get_file_and_anns(
self,
index: Optional[int] = None,
) -> Tuple[np.ndarray, int, float, AudioLoaderAnnotationGroup]:
) -> Tuple[np.ndarray, float, float, AudioLoaderAnnotationGroup]:
"""Get an audio file and its annotations.
Parameters
@ -742,7 +703,7 @@ class AudioLoader(torch.utils.data.Dataset):
-------
audio_raw : np.ndarray
Loaded audio file.
sampling_rate : int
sampling_rate : float
Sampling rate of the audio file.
duration : float
Duration of the audio file in seconds.
@ -837,7 +798,7 @@ class AudioLoader(torch.utils.data.Dataset):
(
audio2,
sampling_rate2,
duration2,
_,
ann2,
) = self.get_file_and_anns()
audio, ann = combine_audio_aug(
@ -846,7 +807,11 @@ class AudioLoader(torch.utils.data.Dataset):
# simulate echo by adding delayed copy of the file
if np.random.random() < self.params["aug_prob"]:
audio = echo_aug(audio, sampling_rate, self.params)
audio = echo_aug(
audio,
sampling_rate,
echo_max_delay=self.params["echo_max_delay"],
)
# resample the audio
# if np.random.random() < self.params["aug_prob"]:
@ -855,11 +820,16 @@ class AudioLoader(torch.utils.data.Dataset):
# )
# create spectrogram
spec, spec_for_viz = au.generate_spectrogram(
spec = au.generate_spectrogram(
audio,
sampling_rate,
self.params,
self.return_spec_for_viz,
fft_win_length=self.params["fft_win_length"],
fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
min_freq=self.params["min_freq"],
spec_scale=self.params["spec_scale"],
denoise_spec_avg=self.params["denoise_spec_avg"],
max_scale_spec=self.params["max_scale_spec"],
)
rsf = self.params["resize_factor"]
spec_op_shape = (
@ -879,20 +849,29 @@ class AudioLoader(torch.utils.data.Dataset):
# augment spectrogram
if self.is_train and self.params["augment_at_train"]:
if np.random.random() < self.params["aug_prob"]:
spec = scale_vol_aug(spec, self.params)
spec = scale_vol_aug(
spec,
spec_amp_scaling=self.params["spec_amp_scaling"],
)
if np.random.random() < self.params["aug_prob"]:
spec = warp_spec_aug(
spec,
ann,
self.params,
stretch_squeeze_delta=self.params["stretch_squeeze_delta"],
)
if np.random.random() < self.params["aug_prob"]:
spec = mask_time_aug(spec, self.params)
spec = mask_time_aug(
spec,
mask_max_time_perc=self.params["mask_max_time_perc"],
)
if np.random.random() < self.params["aug_prob"]:
spec = mask_freq_aug(spec, self.params)
spec = mask_freq_aug(
spec,
mask_max_freq_perc=self.params["mask_max_freq_perc"],
)
outputs = {}
outputs["spec"] = spec
@ -911,7 +890,13 @@ class AudioLoader(torch.utils.data.Dataset):
spec_op_shape,
sampling_rate,
ann,
self.params,
class_names=self.params["class_names"],
fft_win_length=self.params["fft_win_length"],
fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
min_freq=self.params["min_freq"],
resize_factor=self.params["resize_factor"],
target_sigma=self.params["target_sigma"],
)
# hack to get around requirement that all vectors are the same length

View File

@ -1,8 +1,13 @@
from typing import Optional
import torch
import torch.nn.functional as F
def bbox_size_loss(pred_size, gt_size):
def bbox_size_loss(
pred_size: torch.Tensor,
gt_size: torch.Tensor,
) -> torch.Tensor:
"""
Bounding box size loss. Only compute loss where there is a bounding box.
"""
@ -12,7 +17,12 @@ def bbox_size_loss(pred_size, gt_size):
)
def focal_loss(pred, gt, weights=None, valid_mask=None):
def focal_loss(
pred: torch.Tensor,
gt: torch.Tensor,
weights: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
pred (batch x c x h x w)
@ -52,7 +62,11 @@ def focal_loss(pred, gt, weights=None, valid_mask=None):
return loss
def mse_loss(pred, gt, weights=None, valid_mask=None):
def mse_loss(
pred: torch.Tensor,
gt: torch.Tensor,
valid_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Mean squared error loss.
"""

View File

@ -5,6 +5,7 @@ import warnings
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR
import batdetect2.detector.post_process as pp
@ -29,7 +30,7 @@ def save_images_batch(model, data_loader, params):
ind = 0 # first image in each batch
with torch.no_grad():
for batch_idx, inputs in enumerate(data_loader):
for inputs in data_loader:
data = inputs["spec"].to(params["device"])
outputs = model(data)
@ -81,7 +82,12 @@ def save_image(
def loss_fun(
outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
outputs,
gt_det,
gt_size,
gt_class,
det_criterion,
params,
):
# detection loss
loss = params["det_loss_weight"] * det_criterion(
@ -104,7 +110,13 @@ def loss_fun(
def train(
model, epoch, data_loader, det_criterion, optimizer, scheduler, params
model,
epoch,
data_loader,
det_criterion,
optimizer,
scheduler,
params,
):
model.train()
@ -309,7 +321,7 @@ def select_model(params):
resize_factor=params["resize_factor"],
)
else:
print("No valid network specified")
raise ValueError("No valid network specified")
return model
@ -319,9 +331,9 @@ def main():
params = parameters.get_params(True)
if torch.cuda.is_available():
params["device"] = "cuda"
params.device = "cuda"
else:
params["device"] = "cpu"
params.device = "cpu"
# setup arg parser and populate it with existing parameters - will not work with lists
parser = argparse.ArgumentParser()
@ -349,13 +361,16 @@ def main():
default="Rhinolophus ferrumequinum;Rhinolophus hipposideros",
help='Will set low and high frequency the same for these classes. Separate names with ";"',
)
for key, val in params.items():
parser.add_argument("--" + key, type=type(val), default=val)
params = vars(parser.parse_args())
# save notes file
if params["notes"] != "":
tu.write_notes_file(params["experiment"] + "notes.txt", params["notes"])
tu.write_notes_file(
params["experiment"] + "notes.txt", params["notes"]
)
# load the training and test meta data - there are different splits defined
train_sets, test_sets = ts.get_train_test_data(
@ -374,15 +389,11 @@ def main():
for tt in train_sets:
print(tt["ann_path"])
classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
(
data_train,
params["class_names"],
params["class_inv_freq"],
) = tu.load_set_of_anns(
data_train = tu.load_set_of_anns(
train_sets,
classes_to_ignore,
params["events_of_interest"],
params["convert_to_genus"],
classes_to_ignore=classes_to_ignore,
events_of_interest=params["events_of_interest"],
convert_to_genus=params["convert_to_genus"],
)
params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
params["class_names"]
@ -415,11 +426,12 @@ def main():
print("\nTesting on:")
for tt in test_sets:
print(tt["ann_path"])
data_test, _, _ = tu.load_set_of_anns(
data_test = tu.load_set_of_anns(
test_sets,
classes_to_ignore,
params["events_of_interest"],
params["convert_to_genus"],
classes_to_ignore=classes_to_ignore,
events_of_interest=params["events_of_interest"],
convert_to_genus=params["convert_to_genus"],
)
data_train = tu.remove_dupes(data_train, data_test)
test_dataset = adl.AudioLoader(data_test, params, is_train=False)
@ -447,10 +459,13 @@ def main():
scheduler = CosineAnnealingLR(
optimizer, params["num_epochs"] * len(train_loader)
)
if params["train_loss"] == "mse":
det_criterion = losses.mse_loss
elif params["train_loss"] == "focal":
det_criterion = losses.focal_loss
else:
raise ValueError("No valid loss specified")
# save parameters to file
with open(params["experiment"] + "params.json", "w") as da:

View File

@ -1,28 +1,37 @@
import glob
import json
import os
import random
from collections import Counter
from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple
import numpy as np
from batdetect2 import types
def write_notes_file(file_name, text):
def write_notes_file(file_name: str, text: str):
with open(file_name, "a") as da:
da.write(text + "\n")
def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path):
ddict = {
def get_blank_dataset_dict(
dataset_name: str,
is_test: bool,
ann_path: str,
wav_path: str,
) -> types.DatasetDict:
return {
"dataset_name": dataset_name,
"is_test": is_test,
"is_binary": False,
"ann_path": ann_path,
"wav_path": wav_path,
}
return ddict
def get_short_class_names(class_names, str_len=3):
def get_short_class_names(
class_names: List[str],
str_len: int = 3,
) -> List[str]:
class_names_short = []
for cc in class_names:
class_names_short.append(
@ -31,7 +40,10 @@ def get_short_class_names(class_names, str_len=3):
return class_names_short
def remove_dupes(data_train, data_test):
def remove_dupes(
data_train: List[types.FileAnnotation],
data_test: List[types.FileAnnotation],
) -> List[types.FileAnnotation]:
test_ids = [dd["id"] for dd in data_test]
data_train_prune = []
for aa in data_train:
@ -43,14 +55,16 @@ def remove_dupes(data_train, data_test):
return data_train_prune
def get_genus_mapping(class_names):
def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]:
genus_names, genus_mapping = np.unique(
[cc.split(" ")[0] for cc in class_names], return_inverse=True
)
return genus_names.tolist(), genus_mapping.tolist()
def standardize_low_freq(data, class_of_interest):
def standardize_low_freq(
data: List[types.FileAnnotation], class_of_interest: str,
) -> List[types.FileAnnotation]:
# address the issue of highly variable low frequency annotations
# this often happens for constant frequency calls
# for the class of interest, set the low and high freq to the dataset mean
@ -62,8 +76,8 @@ def standardize_low_freq(data, class_of_interest):
low_freqs.append(aa["low_freq"])
high_freqs.append(aa["high_freq"])
low_mean = np.mean(low_freqs)
high_mean = np.mean(high_freqs)
low_mean = float(np.mean(low_freqs))
high_mean = float(np.mean(high_freqs))
assert low_mean < high_mean
print("\nStandardizing low and high frequency for:")
@ -83,115 +97,148 @@ def standardize_low_freq(data, class_of_interest):
return data
def load_set_of_anns(
data,
classes_to_ignore=[],
events_of_interest=None,
convert_to_genus=False,
verbose=True,
list_of_anns=False,
filter_issues=False,
name_replace=False,
def format_annotation(
annotation: types.FileAnnotation,
events_of_interest: Optional[List[str]] = None,
name_replace: Optional[Dict[str, str]] = None,
convert_to_genus: bool = False,
classes_to_ignore: Optional[List[str]] = None,
) -> types.FileAnnotation:
formated = []
for aa in annotation["annotation"]:
if (
events_of_interest is not None
and aa["event"] not in events_of_interest
):
# Omit sound events that are not of interest
continue
# load the annotations
anns = []
if list_of_anns:
# path to list of individual json files
anns.extend(load_anns_from_path(data["ann_path"], data["wav_path"]))
else:
# dictionary of datasets
for dd in data:
anns.extend(load_anns(dd["ann_path"], dd["wav_path"]))
# remove leading and trailing spaces
class_name = aa["class"].strip()
# discarding unannotated files
anns = [aa for aa in anns if aa["annotated"] is True]
# filter files that have annotation issues - if the input is a dictionary of
# datasets, this will likely have already been done
if filter_issues:
anns = [aa for aa in anns if aa["issues"] is False]
# check for some basic formatting errors with class names
for ann in anns:
for aa in ann["annotation"]:
aa["class"] = aa["class"].strip()
# only load specified events - i.e. types of calls
if events_of_interest is not None:
for ann in anns:
filtered_events = []
for aa in ann["annotation"]:
if aa["event"] in events_of_interest:
filtered_events.append(aa)
ann["annotation"] = filtered_events
# change class names
if name_replace is not None:
# replace_names will be a dictionary mapping input name to output
if type(name_replace) is dict:
for ann in anns:
for aa in ann["annotation"]:
if aa["class"] in name_replace:
aa["class"] = name_replace[aa["class"]]
class_name = name_replace.get(class_name, class_name)
# convert everything to genus name
if convert_to_genus:
for ann in anns:
for aa in ann["annotation"]:
aa["class"] = aa["class"].split(" ")[0]
# convert everything to genus name
class_name = class_name.split(" ")[0]
# get unique class names
class_names_all = []
for ann in anns:
for aa in ann["annotation"]:
if aa["class"] not in classes_to_ignore:
class_names_all.append(aa["class"])
# NOTE: It is important to acknowledge that the class names filtering
# is done after the name replacement and the conversion to
# genus name. This allows filtering converted genus names and names
# that were replaced with a name that should be ignored.
if classes_to_ignore is not None and class_name in classes_to_ignore:
# Omit annotations with ignored classes
continue
class_names, class_cnts = np.unique(class_names_all, return_counts=True)
class_inv_freq = class_cnts.sum() / (
len(class_names) * class_cnts.astype(np.float32)
formated.append(
{
**aa,
"class": class_name,
}
)
if verbose:
return {
**annotation,
"annotation": formated,
}
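As the NOTE above points out, the ignore filter runs after name replacement and genus conversion, so a replaced or genus-level name can itself be ignored. A small illustration of the new format_annotation (the annotation dict and the class/event names below are made-up examples, not repository data):

example = {
    "id": "rec_001.wav",
    "annotated": True,
    "annotation": [
        {"class": " Myotis daubentonii", "event": "Echolocation"},
        {"class": "Unknown", "event": "Echolocation"},
        {"class": "Myotis nattereri", "event": "Social"},
    ],
}

result = format_annotation(
    example,
    events_of_interest=["Echolocation"],  # drops the "Social" event
    convert_to_genus=True,                # " Myotis daubentonii" -> "Myotis"
    classes_to_ignore=["Unknown"],        # applied after stripping/conversion
)
# result["annotation"] now holds a single event whose class is "Myotis"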
def get_class_names(
data: List[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None,
) -> Tuple[Counter[str], List[float]]:
"""Extracts class names and their inverse frequencies.
Parameters
----------
data
A list of file annotations, where each annotation contains a list of
sound events with associated class names.
classes_to_ignore
A list of class names to ignore.
Returns
-------
class_names
    A Counter mapping each class name found in the annotations to the
    number of annotated sound events of that class.
class_inv_freq
    Inverse frequency weight for each annotated sound event: the mean
    class count divided by the count of that event's class.
"""
if classes_to_ignore is None:
classes_to_ignore = []
class_names_list: List[str] = []
for annotation in data:
for sound_event in annotation["annotation"]:
if sound_event["class"] in classes_to_ignore:
continue
class_names_list.append(sound_event["class"])
counts = Counter(class_names_list)
mean_counts = float(np.mean(list(counts.values())))
return counts, [mean_counts / counts[cc] for cc in class_names_list]
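A quick illustration of the weights get_class_names produces (the minimal dicts below stand in for full FileAnnotation entries): with three "a" events and one "b" event the mean class count is 2.0, so every "a" event gets weight 2/3 and the "b" event gets 2.0.

counts, inv_freq = get_class_names([
    {"annotation": [{"class": "a"}, {"class": "a"}, {"class": "b"}]},
    {"annotation": [{"class": "a"}]},
])
# counts   == Counter({"a": 3, "b": 1})
# inv_freq == [2/3, 2/3, 2.0, 2/3]  (one weight per annotated event)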
def report_class_counts(class_names: Counter[str]):
print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for cc in range(len(class_names)):
print(
str(cc).ljust(5)
+ class_names[cc].ljust(str_len)
+ str(class_cnts[cc])
for index, (class_name, count) in enumerate(class_names.most_common()):
print(f"{index:<5}{class_name:<{str_len}}{count}")
def load_set_of_anns(
data: List[types.DatasetDict],
*,
convert_to_genus: bool = False,
filter_issues: bool = False,
events_of_interest: Optional[List[str]] = None,
classes_to_ignore: Optional[List[str]] = None,
name_replace: Optional[Dict[str, str]] = None,
) -> List[types.FileAnnotation]:
# load the annotations
anns = []
# dictionary of datasets
for dataset in data:
for ann in load_anns(dataset["ann_path"], dataset["wav_path"]):
if not ann["annotated"]:
# Omit unannotated files
continue
if filter_issues and ann["issues"]:
# Omit files with annotation issues
continue
anns.append(
format_annotation(
ann,
events_of_interest=events_of_interest,
name_replace=name_replace,
convert_to_genus=convert_to_genus,
classes_to_ignore=classes_to_ignore,
)
)
if len(classes_to_ignore) == 0:
return anns
else:
return anns, class_names.tolist(), class_inv_freq.tolist()
def load_anns(ann_file_name, raw_audio_dir):
with open(ann_file_name) as da:
anns = json.load(da)
for aa in anns:
aa["file_path"] = raw_audio_dir + aa["id"]
return anns
def load_anns_from_path(ann_file_dir, raw_audio_dir):
files = glob.glob(ann_file_dir + "*.json")
anns = []
for ff in files:
with open(ff) as da:
ann = json.load(da)
ann["file_path"] = raw_audio_dir + ann["id"]
anns.append(ann)
def load_anns(
ann_dir: str,
raw_audio_dir: str,
) -> Generator[types.FileAnnotation, None, None]:
for path in Path(ann_dir).rglob("*.json"):
with open(path) as fp:
file_annotation = json.load(fp)
return anns
file_annotation["file_path"] = raw_audio_dir + file_annotation["id"]
yield file_annotation
class AverageMeter(object):
"""Computes and stores the average and current value"""
class AverageMeter:
"""Computes and stores the average and current value."""
def __init__(self):
self.reset()

View File

@ -1,5 +1,5 @@
"""Types used in the code base."""
from typing import List, NamedTuple, Optional, Union
from typing import Any, List, NamedTuple, Optional
import numpy as np
import torch
@ -26,8 +26,7 @@ __all__ = [
"Annotation",
"DetectionModel",
"FeatureExtractionParameters",
"FeatureExtractor",
"FileAnnotations",
"FileAnnotation",
"ModelOutput",
"ModelParameters",
"NonMaximumSuppressionConfig",
@ -94,7 +93,10 @@ class ModelParameters(TypedDict):
"""Resize factor."""
class_names: List[str]
"""Class names. The model is trained to detect these classes."""
"""Class names.
The model is trained to detect these classes.
"""
DictWithClass = TypedDict("DictWithClass", {"class": str})
@ -103,8 +105,8 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})
class Annotation(DictWithClass):
"""Format of annotations.
This is the format of a single annotation as expected by the annotation
tool.
This is the format of a single annotation as expected by the
annotation tool.
"""
start_time: float
@ -113,10 +115,10 @@ class Annotation(DictWithClass):
end_time: float
"""End time in seconds."""
low_freq: int
low_freq: float
"""Low frequency in Hz."""
high_freq: int
high_freq: float
"""High frequency in Hz."""
class_prob: float
@ -135,7 +137,7 @@ class Annotation(DictWithClass):
"""Numeric ID for the class of the annotation."""
class FileAnnotations(TypedDict):
class FileAnnotation(TypedDict):
"""Format of results.
This is the format of the results expected by the annotation tool.
@ -157,7 +159,7 @@ class FileAnnotations(TypedDict):
"""Time expansion factor."""
class_name: str
"""Class predicted at file level"""
"""Class predicted at file level."""
notes: str
"""Notes of file."""
@ -169,7 +171,7 @@ class FileAnnotations(TypedDict):
class RunResults(TypedDict):
"""Run results."""
pred_dict: FileAnnotations
pred_dict: FileAnnotation
"""Predictions in the format expected by the annotation tool."""
spec_feats: NotRequired[List[np.ndarray]]
@ -394,9 +396,9 @@ class PredictionResults(TypedDict):
class DetectionModel(Protocol):
"""Protocol for detection models.
This protocol is used to define the interface for the detection models.
This allows us to use the same code for training and inference, even
though the models are different.
This protocol is used to define the interface for the detection
models. This allows us to use the same code for training and
inference, even though the models are different.
"""
num_classes: int
@ -416,16 +418,14 @@ class DetectionModel(Protocol):
def forward(
self,
ip: torch.Tensor,
return_feats: bool = False,
spec: torch.Tensor,
) -> ModelOutput:
"""Forward pass of the model."""
...
def __call__(
self,
ip: torch.Tensor,
return_feats: bool = False,
spec: torch.Tensor,
) -> ModelOutput:
"""Forward pass of the model."""
...
@ -490,8 +490,10 @@ class HeatmapParameters(TypedDict):
"""Maximum frequency to consider in Hz."""
target_sigma: float
"""Sigma for the Gaussian kernel. Controls the width of the points in
the heatmap."""
"""Sigma for the Gaussian kernel.
Controls the width of the points in the heatmap.
"""
class AnnotationGroup(TypedDict):
@ -522,10 +524,10 @@ class AnnotationGroup(TypedDict):
annotated: NotRequired[bool]
"""Wether the annotation group is complete or not.
Usually annotation groups are associated to a single
audio clip. If the annotation group is complete, it means that all
relevant sound events have been annotated. If it is not complete, it
means that some sound events might not have been annotated.
Usually annotation groups are associated to a single audio clip. If
the annotation group is complete, it means that all relevant sound
events have been annotated. If it is not complete, it means that
some sound events might not have been annotated.
"""
x_inds: NotRequired[np.ndarray]
@ -535,12 +537,88 @@ class AnnotationGroup(TypedDict):
"""Y coordinate of the annotations in the spectrogram."""
class AudioLoaderAnnotationGroup(AnnotationGroup, FileAnnotations):
class AudioLoaderAnnotationGroup(TypedDict):
"""Group of annotation items for the training audio loader.
This class is used to store the annotations for the training audio
loader. It combines the fields of `AnnotationGroup` and `FileAnnotation`.
"""
id: str
duration: float
issues: bool
file_path: str
time_exp: float
class_name: str
notes: str
start_times: np.ndarray
end_times: np.ndarray
low_freqs: np.ndarray
high_freqs: np.ndarray
class_ids: np.ndarray
individual_ids: np.ndarray
x_inds: np.ndarray
y_inds: np.ndarray
annotation: List[Annotation]
annotated: bool
class_id_file: int
"""ID of the class of the file."""
class AudioLoaderParameters(TypedDict):
class_names: List[str]
classes_to_ignore: List[str]
target_samp_rate: int
scale_raw_audio: bool
fft_win_length: float
fft_overlap: float
spec_train_width: int
resize_factor: float
spec_divide_factor: int
augment_at_train: bool
augment_at_train_combine: bool
aug_prob: float
spec_height: int
echo_max_delay: float
spec_amp_scaling: float
stretch_squeeze_delta: float
mask_max_time_perc: float
mask_max_freq_perc: float
max_freq: float
min_freq: float
spec_scale: str
denoise_spec_avg: bool
max_scale_spec: bool
target_sigma: float
class FeatureExtractor(Protocol):
def __call__(
self,
prediction: Prediction,
**kwargs: Any,
) -> float:
...
class DatasetDict(TypedDict):
"""Dataset dictionary.
This is the format of the dictionary that contains the dataset
information.
"""
dataset_name: str
"""Name of the dataset."""
is_test: bool
"""Whether the dataset is a test set."""
is_binary: bool
"""Whether the dataset is binary."""
ann_path: str
"""Path to the annotations."""
wav_path: str
"""Path to the audio files."""

View File

@ -1,13 +1,11 @@
import warnings
from typing import Optional, Tuple
from typing import Optional, Tuple, Union, overload
import librosa
import librosa.core.spectrum
import numpy as np
import torch
from . import wavfile
__all__ = [
"load_audio",
"generate_spectrogram",
@ -15,113 +13,171 @@ __all__ = [
]
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
@overload
def time_to_x_coords(
time_in_file: np.ndarray,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
) -> np.ndarray:
...
@overload
def time_to_x_coords(
time_in_file: float,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
) -> float:
...
def time_to_x_coords(
time_in_file: Union[float, np.ndarray],
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
) -> Union[float, np.ndarray]:
nfft = np.floor(fft_win_length * sampling_rate)
noverlap = np.floor(fft_overlap * nfft)
return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
# NOTE this is also defined in post_process
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
def x_coords_to_time(
x_pos: float,
sampling_rate: int,
fft_win_length: float,
fft_overlap: float,
) -> float:
nfft = np.floor(fft_win_length * sampling_rate)
noverlap = np.floor(fft_overlap * nfft)
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for
# center of temporal window
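A quick round-trip check of the two conversions above, with illustrative parameter values rather than the repository defaults: at 256 kHz with a 2 ms window and 75 % overlap, nfft is 512 samples and the hop is 128 samples, so one spectrogram column covers 0.5 ms.

sr, win_len, overlap = 256_000, 0.002, 0.75

x = time_to_x_coords(0.01, sr, win_len, overlap)  # (2560 - 384) / 128 = 17.0 columns
t = x_coords_to_time(x, sr, win_len, overlap)     # back to 0.01 seconds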
def generate_spectrogram(
audio,
sampling_rate,
params,
return_spec_for_viz=False,
check_spec_size=True,
):
audio: np.ndarray,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
max_freq: float,
min_freq: float,
spec_scale: str,
denoise_spec_avg: bool = False,
max_scale_spec: bool = False,
) -> np.ndarray:
# generate spectrogram
spec = gen_mag_spectrogram(
audio,
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
window_len=fft_win_length,
overlap_perc=fft_overlap,
)
spec = crop_spectrogram(
spec,
fft_win_length=fft_win_length,
max_freq=max_freq,
min_freq=min_freq,
)
spec = scale_spectrogram(
spec,
sampling_rate,
spec_scale=spec_scale,
fft_win_length=fft_win_length,
)
if denoise_spec_avg:
spec = denoise_spectrogram(spec)
if max_scale_spec:
spec = max_scale_spectrogram(spec)
return spec
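After the refactor the spectrogram parameters are passed individually instead of through a params dict. A hedged usage sketch with illustrative values (the stand-in audio is random noise, not a real recording):

import numpy as np

sampling_rate = 256_000
audio = np.random.randn(sampling_rate).astype(np.float32)  # 1 s of stand-in audio

spec = generate_spectrogram(
    audio,
    sampling_rate,
    fft_win_length=0.002,
    fft_overlap=0.75,
    max_freq=120_000,
    min_freq=10_000,
    spec_scale="log",
    denoise_spec_avg=True,
    max_scale_spec=False,
)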
def crop_spectrogram(
spec: np.ndarray,
fft_win_length: float,
max_freq: float,
min_freq: float,
) -> np.ndarray:
# crop to min/max freq
max_freq = round(params["max_freq"] * params["fft_win_length"])
min_freq = round(params["min_freq"] * params["fft_win_length"])
max_freq = round(max_freq * fft_win_length)
min_freq = round(min_freq * fft_win_length)
if spec.shape[0] < max_freq:
freq_pad = max_freq - spec.shape[0]
spec = np.vstack(
(np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
)
spec_cropped = spec[-max_freq : spec.shape[0] - min_freq, :]
return spec[-max_freq : spec.shape[0] - min_freq, :]
if params["spec_scale"] == "log":
log_scaling = (
2.0
* (1.0 / sampling_rate)
* (
1.0
/ (
np.abs(
np.hanning(
int(params["fft_win_length"] * sampling_rate)
)
)
** 2
).sum()
)
)
# log_scaling = (1.0 / sampling_rate)*0.1
# log_scaling = (1.0 / sampling_rate)*10e4
spec = np.log1p(log_scaling * spec_cropped)
elif params["spec_scale"] == "pcen":
spec = pcen(spec_cropped, sampling_rate)
elif params["spec_scale"] == "none":
pass
if params["denoise_spec_avg"]:
def denoise_spectrogram(spec: np.ndarray) -> np.ndarray:
spec = spec - np.mean(spec, 1)[:, np.newaxis]
spec.clip(min=0, out=spec)
return spec.clip(min=0)
if params["max_scale_spec"]:
spec = spec / (spec.max() + 10e-6)
# needs to be divisible by a specific factor - if not, it should have been padded
# if check_spec_size:
# assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
# assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)
def max_scale_spectrogram(spec: np.ndarray) -> np.ndarray:
return spec / (spec.max() + 10e-6)
# for visualization purposes - use log scaled spectrogram
if return_spec_for_viz:
def log_scale(
spec: np.ndarray,
sampling_rate: float,
fft_win_length: float,
) -> np.ndarray:
log_scaling = (
2.0
* (1.0 / sampling_rate)
* (
1.0
/ (
np.abs(
np.hanning(
int(params["fft_win_length"] * sampling_rate)
)
)
** 2
np.abs(np.hanning(int(fft_win_length * sampling_rate))) ** 2
).sum()
)
)
spec_for_viz = np.log1p(log_scaling * spec_cropped).astype(np.float32)
else:
spec_for_viz = None
return np.log1p(log_scaling * spec)
return spec, spec_for_viz
def scale_spectrogram(
spec: np.ndarray,
sampling_rate: float,
spec_scale: str,
fft_win_length: float,
) -> np.ndarray:
if spec_scale == "log":
return log_scale(spec, sampling_rate, fft_win_length)
if spec_scale == "pcen":
return pcen(spec, sampling_rate)
return spec
def prepare_spec_for_viz(
spec: np.ndarray,
sampling_rate: int,
fft_win_length: float,
) -> np.ndarray:
# for visualization purposes - use log scaled spectrogram
return log_scale(
spec,
sampling_rate,
fft_win_length=fft_win_length,
).astype(np.float32)
def load_audio(
audio_file: str,
time_exp_fact: float,
target_samp_rate: int,
target_sampling_rate: int,
scale: bool = False,
max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray]:
) -> Tuple[float, np.ndarray]:
"""Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
@ -152,63 +208,82 @@ def load_audio(
"""
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
# sampling_rate, audio_raw = wavfile.read(audio_file)
audio_raw, sampling_rate = librosa.load(
audio, sampling_rate = librosa.load(
audio_file,
sr=None,
dtype=np.float32,
)
if len(audio_raw.shape) > 1:
if len(audio.shape) > 1:
raise ValueError("Currently does not handle stereo files")
sampling_rate = sampling_rate * time_exp_fact
# resample - need to do this after correcting for time expansion
sampling_rate_old = sampling_rate
sampling_rate = target_samp_rate
if sampling_rate_old != sampling_rate:
audio_raw = librosa.resample(
audio_raw,
orig_sr=sampling_rate_old,
target_sr=sampling_rate,
res_type="polyphase",
)
audio = resample_audio(audio, sampling_rate, target_sampling_rate)
# clipping maximum duration
if max_duration is not None:
max_duration = int(
np.minimum(
int(sampling_rate * max_duration),
audio_raw.shape[0],
)
)
audio_raw = audio_raw[:max_duration]
audio = clip_audio(audio, target_sampling_rate, max_duration)
# scale to [-1, 1]
if scale:
audio_raw = audio_raw - audio_raw.mean()
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
audio = scale_audio(audio)
return sampling_rate, audio_raw
return target_sampling_rate, audio
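A usage sketch of the updated load_audio signature (the path and parameter values are placeholders):

sampling_rate, audio = load_audio(
    "example_recording.wav",        # placeholder path
    time_exp_fact=1.0,
    target_sampling_rate=256_000,
    scale=True,                     # mean-centre and scale to roughly [-1, 1]
    max_duration=2.0,               # keep at most the first two seconds
)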
def resample_audio(
audio: np.ndarray,
sr_orig: float,
sr_target: float,
) -> np.ndarray:
if sr_orig != sr_target:
return librosa.resample(
audio,
orig_sr=sr_orig,
target_sr=sr_target,
res_type="polyphase",
)
return audio
def clip_audio(
audio: np.ndarray,
sampling_rate: float,
max_duration: float,
) -> np.ndarray:
max_duration = int(
np.minimum(
int(sampling_rate * max_duration),
audio.shape[0],
)
)
return audio[:max_duration]
def scale_audio(
audio: np.ndarray,
eps: float = 10e-6,
) -> np.ndarray:
return (audio - audio.mean()) / (np.abs(audio).max() + eps)
def pad_audio(
audio_raw,
fs,
ms,
overlap_perc,
resize_factor,
divide_factor,
fixed_width=None,
):
audio_raw: np.ndarray,
sampling_rate: float,
window_len: float,
overlap_perc: float,
resize_factor: float,
divide_factor: float,
fixed_width: Optional[int] = None,
) -> np.ndarray:
# Adds zeros to the end of the raw data so that the generated spectrogram
# will be evenly divisible by `divide_factor`
# Also deals with very short audio clips and fixed_width during training
# This code could be clearer, clean up
nfft = int(ms * fs)
nfft = int(window_len * sampling_rate)
noverlap = int(overlap_perc * nfft)
step = nfft - noverlap
min_size = int(divide_factor * (1.0 / resize_factor))
@ -245,19 +320,24 @@ def pad_audio(
return audio_raw
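With illustrative values the padding target can be seen directly: a 2 ms window with 75 % overlap at 256 kHz gives nfft = 512 and a 128-sample hop, and with resize_factor = 0.5 and divide_factor = 32 the pre-resize spectrogram width has to be a multiple of 64 columns, which is what the zero-padding enforces (a sketch, not repository defaults):

import numpy as np

sampling_rate, window_len, overlap_perc = 256_000, 0.002, 0.75
audio = np.random.randn(10_000).astype(np.float32)  # arbitrary-length stand-in clip

padded = pad_audio(
    audio,
    sampling_rate,
    window_len,
    overlap_perc,
    resize_factor=0.5,
    divide_factor=32,
)
spec = gen_mag_spectrogram(padded, sampling_rate, window_len, overlap_perc)
# int(spec.shape[1] * 0.5) should now be divisible by 32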
def gen_mag_spectrogram(x, fs, ms, overlap_perc):
def gen_mag_spectrogram(
audio: np.ndarray,
sampling_rate: float,
window_len: float,
overlap_perc: float,
) -> np.ndarray:
# Computes magnitude spectrogram by specifying time.
x = x.astype(np.float32)
nfft = int(ms * fs)
audio = audio.astype(np.float32)
nfft = int(window_len * sampling_rate)
noverlap = int(overlap_perc * nfft)
# window data
step = nfft - noverlap
# compute spec
spec, _ = librosa.core.spectrum._spectrogram(
y=x, power=1, n_fft=nfft, hop_length=step, center=False
y=audio,
power=1,
n_fft=nfft,
hop_length=nfft - noverlap,
center=False,
)
# remove DC component and flip vertical orientation
@ -266,24 +346,25 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):
return spec.astype(np.float32)
def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc):
nfft = int(ms * fs)
def gen_mag_spectrogram_pt(
audio: torch.Tensor,
sampling_rate: float,
window_len: float,
overlap_perc: float,
) -> torch.Tensor:
nfft = int(window_len * sampling_rate)
nstep = round((1.0 - overlap_perc) * nfft)
han_win = torch.hann_window(nfft, periodic=False).to(audio.device)
han_win = torch.hann_window(nfft, periodic=False).to(x.device)
complex_spec = torch.stft(x, nfft, nstep, window=han_win, center=False)
complex_spec = torch.stft(audio, nfft, nstep, window=han_win, center=False)
spec = complex_spec.pow(2.0).sum(-1)
# remove DC component and flip vertically
spec = torch.flipud(spec[0, 1:, :])
return spec
return torch.flipud(spec[0, 1:, :])
def pcen(spec_cropped, sampling_rate):
def pcen(spec: np.ndarray, sampling_rate: float) -> np.ndarray:
# TODO should be passing hop_length too i.e. step
spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate / 10).astype(
return librosa.pcen(spec * (2**31), sr=sampling_rate / 10).astype(
np.float32
)
return spec

View File

@ -16,7 +16,7 @@ from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import (
Annotation,
DetectionModel,
FileAnnotations,
FileAnnotation,
ModelOutput,
ModelParameters,
PredictionResults,
@ -79,7 +79,7 @@ def list_audio_files(ip_dir: str) -> List[str]:
def load_model(
model_path: str = DEFAULT_MODEL_PATH,
load_weights: bool = True,
device: Optional[torch.device] = None,
device: Union[torch.device, str, None] = None,
) -> Tuple[DetectionModel, ModelParameters]:
"""Load model from file.
@ -222,7 +222,7 @@ def format_single_result(
duration: float,
predictions: PredictionResults,
class_names: List[str],
) -> FileAnnotations:
) -> FileAnnotation:
"""Format results into the format expected by the annotation tool.
Args:
@ -399,11 +399,10 @@ def save_results_to_file(results, op_path: str) -> None:
def compute_spectrogram(
audio: np.ndarray,
sampling_rate: int,
sampling_rate: float,
params: SpectrogramParameters,
device: torch.device,
return_np: bool = False,
) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
) -> Tuple[float, torch.Tensor]:
"""Compute a spectrogram from an audio array.
Will pad the audio array so that it is evenly divisible by the
@ -412,24 +411,16 @@ def compute_spectrogram(
Parameters
----------
audio : np.ndarray
sampling_rate : int
params : SpectrogramParameters
The parameters to use for generating the spectrogram.
return_np : bool, optional
Whether to return the spectrogram as a numpy array as well as a
torch tensor. The default is False.
Returns
-------
duration : float
The duration of the spectrogram in seconds.
spec : torch.Tensor
The spectrogram as a torch tensor.
spec_np : np.ndarray, optional
The spectrogram as a numpy array. Only returned if `return_np` is
True, otherwise None.
@ -446,7 +437,7 @@ def compute_spectrogram(
)
# generate spectrogram
spec, _ = au.generate_spectrogram(audio, sampling_rate, params)
spec = au.generate_spectrogram(audio, sampling_rate, params)
# convert to pytorch
spec = torch.from_numpy(spec).to(device)
@ -466,18 +457,12 @@ def compute_spectrogram(
mode="bilinear",
align_corners=False,
)
if return_np:
spec_np = spec[0, 0, :].cpu().data.numpy()
else:
spec_np = None
return duration, spec, spec_np
return duration, spec
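Since the return_np flag and the third return value were removed, a caller that still needs a NumPy copy can take it from the returned tensor; a two-line sketch, assuming audio, sampling_rate, a filled-in SpectrogramParameters dict params, and device are already in scope as in _process_audio_array below:

duration, spec = compute_spectrogram(audio, sampling_rate, params, device)
spec_np = spec[0, 0, :].cpu().data.numpy()  # same array the removed return_np branch produced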
def iterate_over_chunks(
audio: np.ndarray,
samplerate: int,
samplerate: float,
chunk_size: float,
) -> Iterator[Tuple[float, np.ndarray]]:
"""Iterate over audio in chunks of size chunk_size.
@ -510,7 +495,7 @@ def iterate_over_chunks(
def _process_spectrogram(
spec: torch.Tensor,
samplerate: int,
samplerate: float,
model: DetectionModel,
config: ProcessingConfiguration,
) -> Tuple[PredictionResults, np.ndarray]:
@ -632,13 +617,13 @@ def process_spectrogram(
def _process_audio_array(
audio: np.ndarray,
sampling_rate: int,
sampling_rate: float,
model: DetectionModel,
config: ProcessingConfiguration,
device: torch.device,
) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
# load audio file and compute spectrogram
_, spec, _ = compute_spectrogram(
_, spec = compute_spectrogram(
audio,
sampling_rate,
{
@ -654,7 +639,6 @@ def _process_audio_array(
"max_scale_spec": config["max_scale_spec"],
},
device,
return_np=False,
)
# process spectrogram with model
@ -754,13 +738,15 @@ def process_file(
# Get original sampling rate
file_samp_rate = librosa.get_samplerate(audio_file)
orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1
orig_samp_rate = file_samp_rate * float(
config.get("time_expansion", 1.0) or 1.0
)
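To make the time-expansion correction concrete (numbers are illustrative): a 10x time-expanded file reported at 38 400 Hz corresponds to an original rate of 384 000 Hz, and a missing or zero factor falls back to 1.0.

file_samp_rate = 38_400                        # rate reported for the expanded file
time_expansion = 10.0                          # config.get("time_expansion")
orig_samp_rate = file_samp_rate * float(time_expansion or 1.0)  # 384000.0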
# load audio file
sampling_rate, audio_full = au.load_audio(
audio_file,
time_exp_fact=config.get("time_expansion", 1) or 1,
target_samp_rate=config["target_samp_rate"],
target_sampling_rate=config["target_samp_rate"],
scale=config["scale_raw_audio"],
max_duration=config.get("max_duration"),
)
@ -802,7 +788,6 @@ def process_file(
cnn_feats.append(features[0])
if config["spec_slices"]:
# FIX: This is not currently working. Returns empty slices
spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))
# Merge results from chunks

View File

@ -152,7 +152,7 @@ def test_compute_max_power_bb(max_power: int):
target_samp_rate=samplerate,
)
spec, _ = au.generate_spectrogram(
spec = au.generate_spectrogram(
audio,
samplerate,
params,
@ -240,7 +240,7 @@ def test_compute_max_power():
target_samp_rate=samplerate,
)
spec, _ = au.generate_spectrogram(
spec = au.generate_spectrogram(
audio,
samplerate,
params,