From 0aa61af445bcb405cd2f67f07f759fba791b4f53 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Sun, 14 Jan 2024 17:15:22 +0000 Subject: [PATCH] Added types to most functions --- batdetect2/api.py | 3 +- batdetect2/detector/compute_features.py | 62 ++-- batdetect2/detector/model_helpers.py | 8 +- batdetect2/detector/models.py | 38 ++- batdetect2/detector/parameters.py | 295 +++++++++---------- batdetect2/finetune/finetune_model.py | 319 ++++++++++++--------- batdetect2/finetune/prep_data_finetune.py | 226 ++++++++------- batdetect2/train/audio_dataloader.py | 325 ++++++++++----------- batdetect2/train/losses.py | 20 +- batdetect2/train/train_model.py | 53 ++-- batdetect2/train/train_utils.py | 263 ++++++++++------- batdetect2/types.py | 128 +++++++-- batdetect2/utils/audio_utils.py | 329 ++++++++++++++-------- batdetect2/utils/detector_utils.py | 45 +-- tests/test_features.py | 4 +- 15 files changed, 1216 insertions(+), 902 deletions(-) diff --git a/batdetect2/api.py b/batdetect2/api.py index 4d04f42..7006914 100644 --- a/batdetect2/api.py +++ b/batdetect2/api.py @@ -226,11 +226,10 @@ def generate_spectrogram( if config is None: config = DEFAULT_SPECTROGRAM_PARAMETERS - _, spec, _ = du.compute_spectrogram( + _, spec = du.compute_spectrogram( audio, samp_rate, config, - return_np=False, device=device, ) diff --git a/batdetect2/detector/compute_features.py b/batdetect2/detector/compute_features.py index b53b0cb..31b06bc 100644 --- a/batdetect2/detector/compute_features.py +++ b/batdetect2/detector/compute_features.py @@ -1,5 +1,5 @@ """Functions to compute features from predictions.""" -from typing import Dict, Optional +from typing import Dict, List, Optional import numpy as np @@ -7,15 +7,26 @@ from batdetect2 import types from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ -def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq): +def convert_int_to_freq( + spec_ind: int, + spec_height: int, + min_freq: float, + max_freq: float, +) -> int: """Convert spectrogram index to frequency in Hz.""" "" spec_ind = spec_height - spec_ind - return round( - (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2 + return int( + round( + (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, + 2, + ) ) -def extract_spec_slices(spec, pred_nms): +def extract_spec_slices( + spec: np.ndarray, + pred_nms: types.PredictionResults, +) -> List[np.ndarray]: """Extract spectrogram slices from spectrogram. The slices are extracted based on detected call locations. 
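A quick illustration of the index-to-frequency mapping above (example values only; the real MIN_FREQ_HZ and MAX_FREQ_HZ come from batdetect2/detector/parameters.py). Because the row index is flipped before scaling, row 0 maps to the top of the band and row spec_height to the bottom:

    convert_int_to_freq(0, 128, 10_000, 120_000)    # -> 120000 (top row = max_freq)
    convert_int_to_freq(128, 128, 10_000, 120_000)  # -> 10000 (bottom row = min_freq)
    convert_int_to_freq(64, 128, 10_000, 120_000)   # -> 65000 (halfway up the band)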
@@ -109,7 +120,7 @@ def compute_max_power_bb( return int( convert_int_to_freq( - y_high + max_power_ind, + int(y_high + max_power_ind), spec.shape[0], min_freq, max_freq, @@ -135,13 +146,11 @@ def compute_max_power( spec_call = spec[:, x_start:x_end] power_per_freq_band = np.sum(spec_call, axis=1) max_power_ind = np.argmax(power_per_freq_band) - return int( - convert_int_to_freq( - max_power_ind, - spec.shape[0], - min_freq, - max_freq, - ) + return convert_int_to_freq( + int(max_power_ind), + spec.shape[0], + min_freq, + max_freq, ) @@ -164,13 +173,11 @@ def compute_max_power_first( first_half = spec_call[:, : int(spec_call.shape[1] / 2)] power_per_freq_band = np.sum(first_half, axis=1) max_power_ind = np.argmax(power_per_freq_band) - return int( - convert_int_to_freq( - max_power_ind, - spec.shape[0], - min_freq, - max_freq, - ) + return convert_int_to_freq( + int(max_power_ind), + spec.shape[0], + min_freq, + max_freq, ) @@ -193,13 +200,11 @@ def compute_max_power_second( second_half = spec_call[:, int(spec_call.shape[1] / 2) :] power_per_freq_band = np.sum(second_half, axis=1) max_power_ind = np.argmax(power_per_freq_band) - return int( - convert_int_to_freq( - max_power_ind, - spec.shape[0], - min_freq, - max_freq, - ) + return convert_int_to_freq( + int(max_power_ind), + spec.shape[0], + min_freq, + max_freq, ) @@ -214,6 +219,7 @@ def compute_call_interval( return round(prediction["start_time"] - previous["end_time"], 5) + # NOTE: The order of the features in this dictionary is important. The # features are extracted in this order and the order of the columns in the # output csv file is determined by this order. In order to avoid breaking @@ -236,7 +242,7 @@ def get_feats( spec: np.ndarray, pred_nms: types.PredictionResults, params: types.FeatureExtractionParameters, -): +) -> np.ndarray: """Extract features from spectrogram based on detected call locations. 
The features extracted are: diff --git a/batdetect2/detector/model_helpers.py b/batdetect2/detector/model_helpers.py index 789bdb6..f342737 100644 --- a/batdetect2/detector/model_helpers.py +++ b/batdetect2/detector/model_helpers.py @@ -79,7 +79,13 @@ class ConvBlockDownCoordF(nn.Module): class ConvBlockDownStandard(nn.Module): def __init__( - self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1 + self, + in_chn, + out_chn, + ip_height=None, + k_size=3, + pad_size=1, + stride=1, ): super(ConvBlockDownStandard, self).__init__() self.conv = nn.Conv2d( diff --git a/batdetect2/detector/models.py b/batdetect2/detector/models.py index 56e63f3..d0251c2 100644 --- a/batdetect2/detector/models.py +++ b/batdetect2/detector/models.py @@ -103,15 +103,15 @@ class Net2DFast(nn.Module): num_filts, self.emb_dim, kernel_size=1, padding=0 ) - def forward(self, ip, return_feats=False) -> ModelOutput: + def forward(self, spec: torch.Tensor) -> ModelOutput: # encoder - x1 = self.conv_dn_0(ip) + x1 = self.conv_dn_0(spec) x2 = self.conv_dn_1(x1) x3 = self.conv_dn_2(x2) - x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True) + x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3))) # bottleneck - x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True) + x = F.relu_(self.conv_1d_bn(self.conv_1d(x3))) x = self.att(x) x = x.repeat([1, 1, self.bneck_height * 4, 1]) @@ -121,13 +121,13 @@ class Net2DFast(nn.Module): x = self.conv_up_4(x + x1) # output - x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True) + x = F.relu_(self.conv_op_bn(self.conv_op(x))) cls = self.conv_classes_op(x) comb = torch.softmax(cls, 1) return ModelOutput( pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), - pred_size=F.relu(self.conv_size_op(x), inplace=True), + pred_size=F.relu(self.conv_size_op(x)), pred_class=comb, pred_class_un_norm=cls, features=x, @@ -215,26 +215,26 @@ class Net2DFastNoAttn(nn.Module): num_filts, self.emb_dim, kernel_size=1, padding=0 ) - def forward(self, ip, return_feats=False) -> ModelOutput: - x1 = self.conv_dn_0(ip) + def forward(self, spec: torch.Tensor) -> ModelOutput: + x1 = self.conv_dn_0(spec) x2 = self.conv_dn_1(x1) x3 = self.conv_dn_2(x2) - x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True) + x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3))) - x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True) + x = F.relu_(self.conv_1d_bn(self.conv_1d(x3))) x = x.repeat([1, 1, self.bneck_height * 4, 1]) x = self.conv_up_2(x + x3) x = self.conv_up_3(x + x2) x = self.conv_up_4(x + x1) - x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True) + x = F.relu_(self.conv_op_bn(self.conv_op(x))) cls = self.conv_classes_op(x) comb = torch.softmax(cls, 1) return ModelOutput( pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), - pred_size=F.relu(self.conv_size_op(x), inplace=True), + pred_size=F.relu_(self.conv_size_op(x)), pred_class=comb, pred_class_un_norm=cls, features=x, @@ -324,13 +324,13 @@ class Net2DFastNoCoordConv(nn.Module): num_filts, self.emb_dim, kernel_size=1, padding=0 ) - def forward(self, ip, return_feats=False) -> ModelOutput: - x1 = self.conv_dn_0(ip) + def forward(self, spec: torch.Tensor) -> ModelOutput: + x1 = self.conv_dn_0(spec) x2 = self.conv_dn_1(x1) x3 = self.conv_dn_2(x2) - x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True) + x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3))) - x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True) + x = F.relu_(self.conv_1d_bn(self.conv_1d(x3))) x = self.att(x) x = x.repeat([1, 1, self.bneck_height 
* 4, 1]) @@ -338,15 +338,13 @@ class Net2DFastNoCoordConv(nn.Module): x = self.conv_up_3(x + x2) x = self.conv_up_4(x + x1) - x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True) + x = F.relu_(self.conv_op_bn(self.conv_op(x))) cls = self.conv_classes_op(x) comb = torch.softmax(cls, 1) - pred_emb = (self.conv_emb(x) if self.emb_dim > 0 else None,) - return ModelOutput( pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), - pred_size=F.relu(self.conv_size_op(x), inplace=True), + pred_size=F.relu_(self.conv_size_op(x)), pred_class=comb, pred_class_un_norm=cls, features=x, diff --git a/batdetect2/detector/parameters.py b/batdetect2/detector/parameters.py index 04544ed..cce641e 100644 --- a/batdetect2/detector/parameters.py +++ b/batdetect2/detector/parameters.py @@ -1,6 +1,11 @@ import datetime import os +from pathlib import Path +from typing import List, Optional, Union +from pydantic import BaseModel, Field, computed_field + +from batdetect2.train.train_utils import get_genus_mapping, get_short_class_names from batdetect2.types import ProcessingConfiguration, SpectrogramParameters TARGET_SAMPLERATE_HZ = 256000 @@ -75,158 +80,154 @@ def mk_dir(path): os.makedirs(path) -def get_params(make_dirs=False, exps_dir="../../experiments/"): - params = {} +AUG_SAMPLING_RATES = [ + 220500, + 256000, + 300000, + 312500, + 384000, + 441000, + 500000, +] +CLASSES_TO_IGNORE = ["", " ", "Unknown", "Not Bat"] +GENERIC_CLASSES = ["Bat"] +EVENTS_OF_INTEREST = ["Echolocation"] - params[ - "model_name" - ] = "Net2DFast" # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN - params["num_filters"] = 128 + +class TrainingParameters(BaseModel): + # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN + model_name: str = "Net2DFast" + num_filters: int = 128 + + experiment: Path + model_file_name: Path + + op_im_dir: Path + op_im_dir_test: Path + + notes: str = "" + + target_samp_rate: int = TARGET_SAMPLERATE_HZ + fft_win_length: float = FFT_WIN_LENGTH_S + fft_overlap: float = FFT_OVERLAP + + max_freq: int = MAX_FREQ_HZ + min_freq: int = MIN_FREQ_HZ + + resize_factor: float = RESIZE_FACTOR + spec_height: int = SPEC_HEIGHT + spec_train_width: int = 512 + spec_divide_factor: int = SPEC_DIVIDE_FACTOR + + denoise_spec_avg: bool = DENOISE_SPEC_AVG + scale_raw_audio: bool = SCALE_RAW_AUDIO + max_scale_spec: bool = MAX_SCALE_SPEC + spec_scale: str = SPEC_SCALE + + detection_overlap: float = 0.01 + ignore_start_end: float = 0.01 + detection_threshold: float = DETECTION_THRESHOLD + nms_kernel_size: int = NMS_KERNEL_SIZE + nms_top_k_per_sec: int = NMS_TOP_K_PER_SEC + + aug_prob: float = 0.20 + augment_at_train: bool = True + augment_at_train_combine: bool = True + echo_max_delay: float = 0.005 + stretch_squeeze_delta: float = 0.04 + mask_max_time_perc: float = 0.05 + mask_max_freq_perc: float = 0.10 + spec_amp_scaling: float = 2.0 + aug_sampling_rates: List[int] = AUG_SAMPLING_RATES + + train_loss: str = "focal" + det_loss_weight: float = 1.0 + size_loss_weight: float = 0.1 + class_loss_weight: float = 2.0 + individual_loss_weight: float = 0.0 + + lr: float = 0.001 + batch_size: int = 8 + num_workers: int = 4 + num_epochs: int = 200 + num_eval_epochs: int = 5 + device: str = "cuda" + save_test_image_during_train: bool = False + save_test_image_after_train: bool = True + + convert_to_genus: bool = False + class_names: List[str] = Field(default_factory=list) + classes_to_ignore: List[str] = Field( + default_factory=lambda: CLASSES_TO_IGNORE + ) + generic_class: List[str] = Field(default_factory=lambda: GENERIC_CLASSES) + 
events_of_interest: List[str] = Field( + default_factory=lambda: EVENTS_OF_INTEREST + ) + standardize_classs_names: List[str] = Field(default_factory=list) + + @computed_field + @property + def emb_dim(self) -> int: + if self.individual_loss_weight == 0.0: + return 0 + return 3 + + @computed_field + @property + def genus_mapping(self) -> List[int]: + _, mapping = get_genus_mapping(self.class_names) + return mapping + + @computed_field + @property + def genus_classes(self) -> List[str]: + names, _ = get_genus_mapping(self.class_names) + return names + + @computed_field + @property + def class_names_short(self) -> List[str]: + return get_short_class_names(self.class_names) + + +def get_params( + make_dirs: bool = False, + exps_dir: str = "../../experiments/", + model_name: Optional[str] = None, + experiment: Union[Path, str, None] = None, + **kwargs, +) -> TrainingParameters: + experiments_dir = Path(exps_dir) now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S") - model_name = now_str + ".pth.tar" - params["experiment"] = os.path.join(exps_dir, now_str, "") - params["model_file_name"] = os.path.join(params["experiment"], model_name) - params["op_im_dir"] = os.path.join(params["experiment"], "op_ims", "") - params["op_im_dir_test"] = os.path.join( - params["experiment"], "op_ims_test", "" + + if model_name is None: + model_name = f"{now_str}.pth.tar" + + if experiment is None: + experiment = experiments_dir / now_str + experiment = Path(experiment) + + model_file_name = experiment / model_name + op_ims_dir = experiment / "op_ims" + op_ims_test_dir = experiment / "op_ims_test" + + params = TrainingParameters( + model_name=model_name, + experiment=experiment, + model_file_name=model_file_name, + op_im_dir=op_ims_dir, + op_im_dir_test=op_ims_test_dir, + **kwargs, ) - # params['notes'] = '' # can save notes about an experiment here - # spec parameters - params[ - "target_samp_rate" - ] = TARGET_SAMPLERATE_HZ # resamples all audio so that it is at this rate - params[ - "fft_win_length" - ] = FFT_WIN_LENGTH_S # in milliseconds, amount of time per stft time step - params["fft_overlap"] = FFT_OVERLAP # stft window overlap - - params[ - "max_freq" - ] = MAX_FREQ_HZ # in Hz, everything above this will be discarded - params[ - "min_freq" - ] = MIN_FREQ_HZ # in Hz, everything below this will be discarded - - params[ - "resize_factor" - ] = RESIZE_FACTOR # resize so the spectrogram at the input of the network - params[ - "spec_height" - ] = SPEC_HEIGHT # units are number of frequency bins (before resizing is performed) - params[ - "spec_train_width" - ] = 512 # units are number of time steps (before resizing is performed) - params[ - "spec_divide_factor" - ] = SPEC_DIVIDE_FACTOR # spectrogram should be divisible by this amount in width and height - - # spec processing params - params[ - "denoise_spec_avg" - ] = DENOISE_SPEC_AVG # removes the mean for each frequency band - params[ - "scale_raw_audio" - ] = SCALE_RAW_AUDIO # scales the raw audio to [-1, 1] - params[ - "max_scale_spec" - ] = MAX_SCALE_SPEC # scales the spectrogram so that it is max 1 - params["spec_scale"] = SPEC_SCALE # 'log', 'pcen', 'none' - - # detection params - params[ - "detection_overlap" - ] = 0.01 # has to be within this number of ms to count as detection - params[ - "ignore_start_end" - ] = 0.01 # if start of GT calls are within this time from the start/end of file ignore - params[ - "detection_threshold" - ] = DETECTION_THRESHOLD # the smaller this is the better the recall will be - params[ - "nms_kernel_size" - ] = 
NMS_KERNEL_SIZE # size of the kernel for non-max suppression - params[ - "nms_top_k_per_sec" - ] = NMS_TOP_K_PER_SEC # keep top K highest predictions per second of audio - params["target_sigma"] = 2.0 - - # augmentation params - params[ - "aug_prob" - ] = 0.20 # augmentations will be performed with this probability - params["augment_at_train"] = True - params["augment_at_train_combine"] = True - params[ - "echo_max_delay" - ] = 0.005 # simulate echo by adding copy of raw audio - params["stretch_squeeze_delta"] = 0.04 # stretch or squeeze spec - params[ - "mask_max_time_perc" - ] = 0.05 # max mask size - here percentage, not ideal - params[ - "mask_max_freq_perc" - ] = 0.10 # max mask size - here percentage, not ideal - params[ - "spec_amp_scaling" - ] = 2.0 # multiply the "volume" by 0:X times current amount - params["aug_sampling_rates"] = [ - 220500, - 256000, - 300000, - 312500, - 384000, - 441000, - 500000, - ] - - # loss params - params["train_loss"] = "focal" # mse or focal - params["det_loss_weight"] = 1.0 # weight for the detection part of the loss - params["size_loss_weight"] = 0.1 # weight for the bbox size loss - params["class_loss_weight"] = 2.0 # weight for the classification loss - params["individual_loss_weight"] = 0.0 # not used - if params["individual_loss_weight"] == 0.0: - params[ - "emb_dim" - ] = 0 # number of dimensions used for individual id embedding - else: - params["emb_dim"] = 3 - - # train params - params["lr"] = 0.001 - params["batch_size"] = 8 - params["num_workers"] = 4 - params["num_epochs"] = 200 - params["num_eval_epochs"] = 5 # run evaluation every X epochs - params["device"] = "cuda" - params["save_test_image_during_train"] = False - params["save_test_image_after_train"] = True - - params["convert_to_genus"] = False - params["genus_mapping"] = [] - params["class_names"] = [] - params["classes_to_ignore"] = ["", " ", "Unknown", "Not Bat"] - params["generic_class"] = ["Bat"] - params["events_of_interest"] = [ - "Echolocation" - ] # will ignore all other types of events e.g. 
social calls - - # the classes in this list are standardized during training so that the same low and high freq are used - params["standardize_classs_names"] = [] - - # create directories if make_dirs: - print("Model name : " + params["model_name"]) - print("Model file : " + params["model_file_name"]) - print("Experiment : " + params["experiment"]) - - mk_dir(params["experiment"]) - if params["save_test_image_during_train"]: - mk_dir(params["op_im_dir"]) - if params["save_test_image_after_train"]: - mk_dir(params["op_im_dir_test"]) - mk_dir(os.path.dirname(params["model_file_name"])) + mk_dir(experiment) + mk_dir(params.model_file_name.parent) + if params.save_test_image_during_train: + mk_dir(params.op_im_dir) + if params.save_test_image_after_train: + mk_dir(params.op_im_dir_test) return params diff --git a/batdetect2/finetune/finetune_model.py b/batdetect2/finetune/finetune_model.py index 77a2711..4c2d1c2 100644 --- a/batdetect2/finetune/finetune_model.py +++ b/batdetect2/finetune/finetune_model.py @@ -1,33 +1,31 @@ import argparse -import glob -import json import os -import sys +import warnings +from typing import List, Optional -import matplotlib.pyplot as plt -import numpy as np import torch -import torch.nn.functional as F +import torch.utils.data from torch.optim.lr_scheduler import CosineAnnealingLR -import batdetect2.detector.models as models import batdetect2.detector.parameters as parameters -import batdetect2.detector.post_process as pp import batdetect2.train.audio_dataloader as adl -import batdetect2.train.evaluate as evl import batdetect2.train.losses as losses import batdetect2.train.train_model as tm import batdetect2.train.train_utils as tu import batdetect2.utils.detector_utils as du import batdetect2.utils.plot_utils as pu +from batdetect2 import types +from batdetect2.detector.models import Net2DFast -if __name__ == "__main__": - info_str = "\nBatDetect - Finetune Model\n" +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - print(info_str) + +def parse_arugments(): parser = argparse.ArgumentParser() parser.add_argument( - "audio_path", type=str, help="Input directory for audio" + "audio_path", + type=str, + help="Input directory for audio", ) parser.add_argument( "train_ann_path", @@ -39,7 +37,15 @@ if __name__ == "__main__": type=str, help="Path to where test annotation file is stored", ) - parser.add_argument("model_path", type=str, help="Path to pretrained model") + parser.add_argument( + "model_path", type=str, help="Path to pretrained model" + ) + parser.add_argument( + "--experiment_dir", + type=str, + default=os.path.join(BASE_DIR, "experiments"), + help="Path to where experiment files are stored", + ) parser.add_argument( "--op_model_name", type=str, @@ -71,107 +77,63 @@ if __name__ == "__main__": parser.add_argument( "--notes", type=str, default="", help="Notes to save in text file" ) - args = vars(parser.parse_args()) + args = parser.parse_args() + return args - params = parameters.get_params(True, "../../experiments/") + +def select_device(warn=True) -> str: if torch.cuda.is_available(): - params["device"] = "cuda" - else: - params["device"] = "cpu" - print( - "\nNote, this will be a lot faster if you use computer with a GPU.\n" + return "cuda" + + if warn: + warnings.warn( + "No GPU available, using the CPU instead. Please consider using a GPU " + "to speed up training." 
) - print("\nAudio directory: " + args["audio_path"]) - print("Train file: " + args["train_ann_path"]) - print("Test file: " + args["test_ann_path"]) - print("Loading model: " + args["model_path"]) + return "cpu" - dataset_name = ( - os.path.basename(args["train_ann_path"]) - .replace(".json", "") - .replace("_TRAIN", "") - ) - if args["train_from_scratch"]: - print("\nTraining model from scratch i.e. not using pretrained weights") - model, params_train = du.load_model(args["model_path"], False) - else: - model, params_train = du.load_model(args["model_path"], True) - model.to(params["device"]) - - params["num_epochs"] = args["num_epochs"] - if args["op_model_name"] != "": - params["model_file_name"] = args["op_model_name"] - classes_to_ignore = params["classes_to_ignore"] + params["generic_class"] - - # save notes file - params["notes"] = args["notes"] - if args["notes"] != "": - tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"]) - - # load train annotations - train_sets = [] +def load_annotations( + dataset_name: str, + ann_path: str, + audio_path: str, + classes_to_ignore: Optional[List[str]] = None, + events_of_interest: Optional[List[str]] = None, +) -> List[types.FileAnnotation]: + train_sets: List[types.DatasetDict] = [] train_sets.append( - tu.get_blank_dataset_dict( - dataset_name, False, args["train_ann_path"], args["audio_path"] - ) - ) - params["train_sets"] = [ tu.get_blank_dataset_dict( dataset_name, - False, - os.path.basename(args["train_ann_path"]), - args["audio_path"], - ) - ] - - print("\nTrain set:") - ( - data_train, - params["class_names"], - params["class_inv_freq"], - ) = tu.load_set_of_anns( - train_sets, classes_to_ignore, params["events_of_interest"] - ) - print("Number of files", len(data_train)) - - params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping( - params["class_names"] - ) - params["class_names_short"] = tu.get_short_class_names( - params["class_names"] - ) - - # load test annotations - test_sets = [] - test_sets.append( - tu.get_blank_dataset_dict( - dataset_name, True, args["test_ann_path"], args["audio_path"] + is_test=False, + ann_path=ann_path, + wav_path=audio_path, ) ) - params["test_sets"] = [ - tu.get_blank_dataset_dict( - dataset_name, - True, - os.path.basename(args["test_ann_path"]), - args["audio_path"], - ) - ] - print("\nTest set:") - data_test, _, _ = tu.load_set_of_anns( - test_sets, classes_to_ignore, params["events_of_interest"] + return tu.load_set_of_anns( + train_sets, + events_of_interest=events_of_interest, + classes_to_ignore=classes_to_ignore, ) - print("Number of files", len(data_test)) + +def finetune_model( + model: types.DetectionModel, + data_train: List[types.FileAnnotation], + data_test: List[types.FileAnnotation], + params: parameters.TrainingParameters, + model_params: types.ModelParameters, + finetune_only_last_layer: bool = False, + save_images: bool = True, +): # train loader train_dataset = adl.AudioLoader(data_train, params, is_train=True) train_loader = torch.utils.data.DataLoader( train_dataset, - batch_size=params["batch_size"], + batch_size=params.batch_size, shuffle=True, - num_workers=params["num_workers"], + num_workers=params.num_workers, pin_memory=True, ) @@ -181,32 +143,36 @@ if __name__ == "__main__": test_dataset, batch_size=1, shuffle=False, - num_workers=params["num_workers"], + num_workers=params.num_workers, pin_memory=True, ) inputs_train = next(iter(train_loader)) - params["ip_height"] = inputs_train["spec"].shape[2] + params.ip_height = 
inputs_train["spec"].shape[2]

     print("\ntrain batch size :", inputs_train["spec"].shape)

-    assert params_train["model_name"] == "Net2DFast"
+    # Check that the model is the same as the one used to train the pretrained
+    # weights
+    assert model_params["model_name"] == "Net2DFast"
+    assert isinstance(model, Net2DFast)
     print(
-        "\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n"
+        "\n\nSOME hyperparams need to be the same as the loaded model "
+        "(e.g. FFT) - currently they are getting overwritten.\n\n"
     )

     # set the number of output classes
     num_filts = model.conv_classes_op.in_channels
-    k_size = model.conv_classes_op.kernel_size
-    pad = model.conv_classes_op.padding
+    # Conv2d stores kernel_size and padding as tuples; take the scalar values
+    k_size = model.conv_classes_op.kernel_size[0]
+    pad = model.conv_classes_op.padding[0]
     model.conv_classes_op = torch.nn.Conv2d(
         num_filts,
-        len(params["class_names"]) + 1,
+        len(params.class_names) + 1,
         kernel_size=k_size,
         padding=pad,
     )
-    model.conv_classes_op.to(params["device"])
+    model.conv_classes_op.to(params.device)

-    if args["finetune_only_last_layer"]:
+    if finetune_only_last_layer:
         print("\nOnly finetuning the final layers.\n")
         train_layers_i = [
             "conv_classes",
@@ -223,19 +189,26 @@
         else:
             param.requires_grad = False

-    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
-    scheduler = CosineAnnealingLR(
-        optimizer, params["num_epochs"] * len(train_loader)
+    optimizer = torch.optim.Adam(
+        model.parameters(),
+        lr=params.lr,
     )
-    if params["train_loss"] == "mse":
+    scheduler = CosineAnnealingLR(
+        optimizer,
+        params.num_epochs * len(train_loader),
+    )
+
+    if params.train_loss == "mse":
         det_criterion = losses.mse_loss
-    elif params["train_loss"] == "focal":
+    elif params.train_loss == "focal":
         det_criterion = losses.focal_loss
+    else:
+        raise ValueError("Unknown loss function")

     # plotting
     train_plt_ls = pu.LossPlotter(
-        params["experiment"] + "train_loss.png",
-        params["num_epochs"] + 1,
+        params.experiment / "train_loss.png",
+        params.num_epochs + 1,
         ["train_loss"],
         None,
         None,
@@ -243,8 +216,8 @@
         logy=True,
     )
     test_plt_ls = pu.LossPlotter(
-        params["experiment"] + "test_loss.png",
-        params["num_epochs"] + 1,
+        params.experiment / "test_loss.png",
+        params.num_epochs + 1,
         ["test_loss"],
         None,
         None,
@@ -252,24 +225,24 @@
         logy=True,
     )
     test_plt = pu.LossPlotter(
-        params["experiment"] + "test.png",
-        params["num_epochs"] + 1,
+        params.experiment / "test.png",
+        params.num_epochs + 1,
         ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
         [0, 1],
         None,
         ["epoch", ""],
     )
     test_plt_class = pu.LossPlotter(
-        params["experiment"] + "test_avg_prec.png",
-        params["num_epochs"] + 1,
-        params["class_names_short"],
+        params.experiment / "test_avg_prec.png",
+        params.num_epochs + 1,
+        params.class_names_short,
         [0, 1],
-        params["class_names_short"],
+        params.class_names_short,
         ["epoch", "avg_prec"],
     )

     # main train loop
-    for epoch in range(0, params["num_epochs"] + 1):
+    for epoch in range(0, params.num_epochs + 1):
         train_loss = tm.train(
             model,
             epoch,
@@ -281,10 +254,14 @@
         )
         train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])

-        if epoch % params["num_eval_epochs"] == 0:
+        if epoch % params.num_eval_epochs == 0:
             # detection accuracy on test set
             test_res, test_loss = tm.test(
-                model, epoch, test_loader, det_criterion, params
+                model,
+                epoch,
+                test_loader,
+                det_criterion,
+                params,
             )
             test_plt_ls.update_and_save(epoch,
[test_loss["test_loss"]]) test_plt.update_and_save( @@ -301,18 +278,106 @@ if __name__ == "__main__": epoch, [rs["avg_prec"] for rs in test_res["class_pr"]] ) pu.plot_pr_curve_class( - params["experiment"], "test_pr", "test_pr", test_res + params.experiment, "test_pr", "test_pr", test_res ) # save finetuned model - print("saving model to: " + params["model_file_name"]) + print(f"saving model to: {params.model_file_name}") op_state = { "epoch": epoch + 1, "state_dict": model.state_dict(), "params": params, } - torch.save(op_state, params["model_file_name"]) + torch.save(op_state, params.model_file_name) # save an image with associated prediction for each batch in the test set - if not args["do_not_save_images"]: + if save_images: tm.save_images_batch(model, test_loader, params) + + +def main(): + info_str = "\nBatDetect - Finetune Model\n" + print(info_str) + + args = parse_arugments() + + # Load experiment parameters + params = parameters.get_params( + make_dirs=True, + exps_dir=args.experiment_dir, + device=select_device(), + num_epochs=args.num_epochs, + notes=args.notes, + ) + + print("\nAudio directory: " + args.audio_path) + print("Train file: " + args.train_ann_path) + print("Test file: " + args.test_ann_path) + print("Loading model: " + args.model_path) + + if args.train_from_scratch: + print( + "\nTraining model from scratch i.e. not using pretrained weights" + ) + + model, model_params = du.load_model( + args.model_path, + load_weights=not args.train_from_scratch, + device=params.device, + ) + + if args.op_model_name != "": + params.model_file_name = args.op_model_name + + classes_to_ignore = params.classes_to_ignore + params.generic_class + + # save notes file + if params.notes: + tu.write_notes_file( + params.experiment / "notes.txt", + args.notes, + ) + + # NOTE:?? 
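+    # The dataset name is recovered from the training annotation filename,
+    # e.g. "mydataset_TRAIN.json" -> "mydataset", matching the
+    # <dataset_name>_TRAIN.json files written by prep_data_finetune.py.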
+    dataset_name = (
+        os.path.basename(args.train_ann_path)
+        .replace(".json", "")
+        .replace("_TRAIN", "")
+    )
+
+    # ==== LOAD DATA ====
+
+    # load train annotations
+    data_train = load_annotations(
+        dataset_name,
+        args.train_ann_path,
+        args.audio_path,
+        classes_to_ignore,
+        params.events_of_interest,
+    )
+    print("\nTrain set:")
+    print("Number of files", len(data_train))
+
+    # load test annotations
+    data_test = load_annotations(
+        dataset_name,
+        args.test_ann_path,
+        args.audio_path,
+        classes_to_ignore,
+        params.events_of_interest,
+    )
+    print("\nTest set:")
+    print("Number of files", len(data_test))
+
+    finetune_model(
+        model,
+        data_train,
+        data_test,
+        params,
+        model_params,
+        finetune_only_last_layer=args.finetune_only_last_layer,
+        save_images=not args.do_not_save_images,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/batdetect2/finetune/prep_data_finetune.py b/batdetect2/finetune/prep_data_finetune.py
index 11702a9..1aea005 100644
--- a/batdetect2/finetune/prep_data_finetune.py
+++ b/batdetect2/finetune/prep_data_finetune.py
@@ -1,62 +1,54 @@
 import argparse
 import json
 import os
+from collections import Counter
+from typing import List, Optional, Tuple

 import numpy as np
+from sklearn.model_selection import StratifiedGroupKFold

 import batdetect2.train.train_utils as tu
+from batdetect2 import types


-def print_dataset_stats(data, split_name, classes_to_ignore):
-    print("\nSplit:", split_name)
+def print_dataset_stats(
+    data: List[types.FileAnnotation],
+    classes_to_ignore: Optional[List[str]] = None,
+) -> Counter[str]:
     print("Num files:", len(data))
-
-    class_cnts = {}
-    for dd in data:
-        for aa in dd["annotation"]:
-            if aa["class"] not in classes_to_ignore:
-                if aa["class"] in class_cnts:
-                    class_cnts[aa["class"]] += 1
-                else:
-                    class_cnts[aa["class"]] = 1
-
-    if len(class_cnts) == 0:
-        class_names = []
-    else:
-        class_names = np.sort([*class_cnts]).tolist()
-    print("Class count:")
-    str_len = np.max([len(cc) for cc in class_names]) + 5
-
-    for ii, cc in enumerate(class_names):
-        print(str(ii).ljust(5) + cc.ljust(str_len) + str(class_cnts[cc]))
-
-    return class_names
+    counts, _ = tu.get_class_names(data, classes_to_ignore)
+    if len(counts) > 0:
+        tu.report_class_counts(counts)
+    return counts


-def load_file_names(file_name):
-    if os.path.isfile(file_name):
-        with open(file_name) as da:
-            files = [line.rstrip() for line in da.readlines()]
-        for ff in files:
-            if ff.lower()[-3:] != "wav":
-                print("Error: Filenames need to end in .wav - ", ff)
-                assert False
-    else:
-        print("Error: Input file not found - ", file_name)
-        assert False
+def load_file_names(file_name: str) -> List[str]:
+    if not os.path.isfile(file_name):
+        raise FileNotFoundError(f"Input file not found - {file_name}")
+
+    with open(file_name) as da:
+        files = [line.rstrip() for line in da.readlines()]
+
+    for path in files:
+        if path.lower()[-3:] != "wav":
+            raise ValueError(
+                f"Invalid file name - {path}. Must be a .wav file"
+            )

     return files


-if __name__ == "__main__":
+def parse_args():
     info_str = "\nBatDetect - Prepare Data for Finetuning\n"
-
     print(info_str)
+
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "dataset_name", type=str, help="Name to call your dataset"
     )
-    parser.add_argument("audio_dir", type=str, help="Input directory for audio")
+    parser.add_argument(
+        "audio_dir", type=str, help="Input directory for audio"
+    )
     parser.add_argument(
         "ann_dir",
         type=str,
@@ -102,88 +94,126 @@
         type=str,
         default="",
         help='New class names to use instead. 
One to one mapping with "--input_class_names". \ - Separate with ";"', + Separate with ";"', ) - args = vars(parser.parse_args()) + return parser.parse_args() - np.random.seed(args["rand_seed"]) + +def split_data( + data: List[types.FileAnnotation], + train_file: str, + test_file: str, + n_splits: int = 5, + random_state: int = 0, +) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]: + if train_file != "" and test_file != "": + # user has specifed the train / test split + mapping = { + file_annotation["id"]: file_annotation for file_annotation in data + } + train_files = load_file_names(train_file) + test_files = load_file_names(test_file) + data_train = [ + mapping[file_id] for file_id in train_files if file_id in mapping + ] + data_test = [ + mapping[file_id] for file_id in test_files if file_id in mapping + ] + return data_train, data_test + + # NOTE: Using StratifiedGroupKFold to ensure that the same file does not + # appear in both the training and test sets and trying to keep the + # distribution of classes the same in both sets. + splitter = StratifiedGroupKFold( + n_splits=n_splits, + shuffle=True, + random_state=random_state, + ) + anns = np.array( + [ + [dd["id"], ann["class"], ann["event"]] + for dd in data + for ann in dd["annotation"] + ] + ) + y = anns[:, 1] + group = anns[:, 0] + + train_idx, test_idx = next(splitter.split(X=anns, y=y, groups=group)) + train_ids = set(anns[train_idx, 0]) + test_ids = set(anns[test_idx, 0]) + + assert not (train_ids & test_ids) + data_train = [dd for dd in data if dd["id"] in train_ids] + data_test = [dd for dd in data if dd["id"] in test_ids] + return data_train, data_test + + +def main(): + args = parse_args() + + np.random.seed(args.rand_seed) classes_to_ignore = ["", " ", "Unknown", "Not Bat"] - generic_class = ["Bat"] events_of_interest = ["Echolocation"] - if args["input_class_names"] != "" and args["output_class_names"] != "": + name_dict = None + if args.input_class_names != "" and args.output_class_names != "": # change the names of the classes - ip_names = args["input_class_names"].split(";") - op_names = args["output_class_names"].split(";") + ip_names = args.input_class_names.split(";") + op_names = args.output_class_names.split(";") name_dict = dict(zip(ip_names, op_names)) - else: - name_dict = False # load annotations - data_all, _, _ = tu.load_set_of_anns( - {"ann_path": args["ann_dir"], "wav_path": args["audio_dir"]}, - classes_to_ignore, - events_of_interest, - False, - False, - list_of_anns=True, + data_all = tu.load_set_of_anns( + [ + { + "dataset_name": args.dataset_name, + "ann_path": args.ann_dir, + "wav_path": args.audio_dir, + "is_test": False, + "is_binary": False, + } + ], + classes_to_ignore=classes_to_ignore, + events_of_interest=events_of_interest, + convert_to_genus=False, filter_issues=True, name_replace=name_dict, ) - print("Dataset name: " + args["dataset_name"]) - print("Audio directory: " + args["audio_dir"]) - print("Annotation directory: " + args["ann_dir"]) - print("Ouput directory: " + args["op_dir"]) + print("Dataset name: " + args.dataset_name) + print("Audio directory: " + args.audio_dir) + print("Annotation directory: " + args.ann_dir) + print("Ouput directory: " + args.op_dir) print("Num annotated files: " + str(len(data_all))) - if args["train_file"] != "" and args["test_file"] != "": - # user has specifed the train / test split - train_files = load_file_names(args["train_file"]) - test_files = load_file_names(args["test_file"]) - file_names_all = [dd["id"] for dd in data_all] - 
train_inds = [ - file_names_all.index(ff) - for ff in train_files - if ff in file_names_all - ] - test_inds = [ - file_names_all.index(ff) - for ff in test_files - if ff in file_names_all - ] + data_train, data_test = split_data( + data=data_all, + train_file=args.train_file, + test_file=args.test_file, + n_splits=5, + random_state=args.rand_seed, + ) - else: - # split the data into train and test at the file level - num_exs = len(data_all) - test_inds = np.random.choice( - np.arange(num_exs), - int(num_exs * args["percent_val"]), - replace=False, - ) - test_inds = np.sort(test_inds) - train_inds = np.setdiff1d(np.arange(num_exs), test_inds) - - data_train = [data_all[ii] for ii in train_inds] - data_test = [data_all[ii] for ii in test_inds] - - if not os.path.isdir(args["op_dir"]): - os.makedirs(args["op_dir"]) - op_name = os.path.join(args["op_dir"], args["dataset_name"]) + if not os.path.isdir(args.op_dir): + os.makedirs(args.op_dir) + op_name = os.path.join(args.op_dir, args.dataset_name) op_name_train = op_name + "_TRAIN.json" op_name_test = op_name + "_TEST.json" - class_un_train = print_dataset_stats(data_train, "Train", classes_to_ignore) - class_un_test = print_dataset_stats(data_test, "Test", classes_to_ignore) + print("\nSplit: Train") + class_un_train = print_dataset_stats(data_train, classes_to_ignore) + + print("\nSplit: Test") + class_un_test = print_dataset_stats(data_test, classes_to_ignore) if len(data_train) > 0 and len(data_test) > 0: - if class_un_train != class_un_test: - print( - '\nError: some classes are not in both the training and test sets.\ - \nTry a different random seed "--rand_seed".' + if set(class_un_train.keys()) != set(class_un_test.keys()): + raise RuntimeError( + "Error: some classes are not in both the training and test sets." + 'Try a different random seed "--rand_seed".' ) - assert False print("\n") if len(data_train) == 0: @@ -199,3 +229,7 @@ if __name__ == "__main__": print("Saving: ", op_name_test) with open(op_name_test, "w") as da: json.dump(data_test, da, indent=2) + + +if __name__ == "__main__": + main() diff --git a/batdetect2/train/audio_dataloader.py b/batdetect2/train/audio_dataloader.py index 68d86d4..97fbc76 100644 --- a/batdetect2/train/audio_dataloader.py +++ b/batdetect2/train/audio_dataloader.py @@ -12,19 +12,24 @@ import torchaudio import batdetect2.utils.audio_utils as au from batdetect2.types import ( Annotation, - AnnotationGroup, AudioLoaderAnnotationGroup, - FileAnnotations, - HeatmapParameters, + AudioLoaderParameters, + FileAnnotation, ) def generate_gt_heatmaps( spec_op_shape: Tuple[int, int], - sampling_rate: int, - ann: AnnotationGroup, - params: HeatmapParameters, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]: + sampling_rate: float, + ann: AudioLoaderAnnotationGroup, + class_names: List[str], + fft_win_length: float, + fft_overlap: float, + max_freq: float, + min_freq: float, + resize_factor: float, + target_sigma: float, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AudioLoaderAnnotationGroup]: """Generate ground truth heatmaps from annotations. Parameters @@ -53,31 +58,31 @@ def generate_gt_heatmaps( the x and y indices of their pixel location in the input spectrogram. 
""" # spec may be resized on input into the network - num_classes = len(params["class_names"]) + num_classes = len(class_names) op_height = spec_op_shape[0] op_width = spec_op_shape[1] - freq_per_bin = (params["max_freq"] - params["min_freq"]) / op_height + freq_per_bin = (max_freq - min_freq) / op_height # start and end times x_pos_start = au.time_to_x_coords( ann["start_times"], sampling_rate, - params["fft_win_length"], - params["fft_overlap"], + fft_win_length, + fft_overlap, ) - x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int32) + x_pos_start = (resize_factor * x_pos_start).astype(np.int32) x_pos_end = au.time_to_x_coords( ann["end_times"], sampling_rate, - params["fft_win_length"], - params["fft_overlap"], + fft_win_length, + fft_overlap, ) - x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int32) + x_pos_end = (resize_factor * x_pos_end).astype(np.int32) # location on y axis i.e. frequency - y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin + y_pos_low = (ann["low_freqs"] - min_freq) / freq_per_bin y_pos_low = (op_height - y_pos_low).astype(np.int32) - y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin + y_pos_high = (ann["high_freqs"] - min_freq) / freq_per_bin y_pos_high = (op_height - y_pos_high).astype(np.int32) bb_widths = x_pos_end - x_pos_start bb_heights = y_pos_low - y_pos_high @@ -90,26 +95,17 @@ def generate_gt_heatmaps( & (y_pos_low < (op_height - 1)) )[0] - ann_aug: AnnotationGroup = { + ann_aug: AudioLoaderAnnotationGroup = { + **ann, "start_times": ann["start_times"][valid_inds], "end_times": ann["end_times"][valid_inds], "high_freqs": ann["high_freqs"][valid_inds], "low_freqs": ann["low_freqs"][valid_inds], "class_ids": ann["class_ids"][valid_inds], "individual_ids": ann["individual_ids"][valid_inds], + "x_inds": x_pos_start[valid_inds], + "y_inds": y_pos_low[valid_inds], } - ann_aug["x_inds"] = x_pos_start[valid_inds] - ann_aug["y_inds"] = y_pos_low[valid_inds] - # keys = [ - # "start_times", - # "end_times", - # "high_freqs", - # "low_freqs", - # "class_ids", - # "individual_ids", - # ] - # for kk in keys: - # ann_aug[kk] = ann[kk][valid_inds] # if the number of calls is only 1, then it is unique # TODO would be better if we found these unique calls at the merging stage @@ -118,6 +114,7 @@ def generate_gt_heatmaps( y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32) y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32) + # num classes and "background" class y_2d_classes: np.ndarray = np.zeros( (num_classes + 1, op_height, op_width), dtype=np.float32 @@ -128,14 +125,8 @@ def generate_gt_heatmaps( draw_gaussian( y_2d_det[0, :], (x_pos_start[ii], y_pos_low[ii]), - params["target_sigma"], + target_sigma, ) - # draw_gaussian( - # y_2d_det[0, :], - # (x_pos_start[ii], y_pos_low[ii]), - # params["target_sigma"], - # params["target_sigma"] * 2, - # ) y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii] y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii] @@ -144,14 +135,8 @@ def generate_gt_heatmaps( draw_gaussian( y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), - params["target_sigma"], + target_sigma, ) - # draw_gaussian( - # y_2d_classes[cls_id, :], - # (x_pos_start[ii], y_pos_low[ii]), - # params["target_sigma"], - # params["target_sigma"] * 2, - # ) # be careful as this will have a 1.0 places where we have event but # dont know gt class this will be masked in training anyway @@ -235,8 +220,8 @@ def pad_aray(ip_array: np.ndarray, pad_size: int) -> 
np.ndarray: def warp_spec_aug( spec: torch.Tensor, - ann: AnnotationGroup, - params: dict, + ann: AudioLoaderAnnotationGroup, + stretch_squeeze_delta: float, ) -> torch.Tensor: """Warp spectrogram by randomly stretching and squeezing. @@ -247,8 +232,8 @@ def warp_spec_aug( ann: AnnotationGroup Annotation group for the spectrogram. Must be provided to sync the start and stop times with the spectrogram after warping. - params: dict - Parameters for the augmentation. + stretch_squeeze_delta: float + Maximum amount to stretch or squeeze the spectrogram. Returns ------- @@ -259,11 +244,10 @@ def warp_spec_aug( ----- This function modifies the annotation group in place. """ - # This is messy # Augment spectrogram by randomly stretch and squeezing # NOTE this also changes the start and stop time in place - delta = params["stretch_squeeze_delta"] + delta = stretch_squeeze_delta op_size = (spec.shape[1], spec.shape[2]) resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0 resize_amt = int(spec.shape[2] * resize_fract_r) @@ -277,7 +261,7 @@ def warp_spec_aug( dtype=spec.dtype, ), ), - 2, + dim=2, ) else: spec_r = spec[:, :, :resize_amt] @@ -297,7 +281,10 @@ def warp_spec_aug( return spec -def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: +def mask_time_aug( + spec: torch.Tensor, + mask_max_time_perc: float, +) -> torch.Tensor: """Mask out random blocks of time. Will randomly mask out a block of time in the spectrogram. The block @@ -308,8 +295,8 @@ def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: ---------- spec: torch.Tensor Spectrogram to mask. - params: dict - Parameters for the augmentation. + mask_max_time_perc: float + Maximum percentage of time to mask out. Returns ------- @@ -324,14 +311,17 @@ def mask_time_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: Recognition """ fm = torchaudio.transforms.TimeMasking( - int(spec.shape[1] * params["mask_max_time_perc"]) + int(spec.shape[1] * mask_max_time_perc) ) for _ in range(np.random.randint(1, 4)): spec = fm(spec) return spec -def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: +def mask_freq_aug( + spec: torch.Tensor, + mask_max_freq_perc: float, +) -> torch.Tensor: """Mask out random blocks of frequency. Will randomly mask out a block of frequency in the spectrogram. The block @@ -342,8 +332,8 @@ def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: ---------- spec: torch.Tensor Spectrogram to mask. - params: dict - Parameters for the augmentation. + mask_max_freq_perc: float + Maximum percentage of frequency to mask out. Returns ------- @@ -358,41 +348,48 @@ def mask_freq_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: Recognition """ fm = torchaudio.transforms.FrequencyMasking( - int(spec.shape[1] * params["mask_max_freq_perc"]) + int(spec.shape[1] * mask_max_freq_perc) ) for _ in range(np.random.randint(1, 4)): spec = fm(spec) return spec -def scale_vol_aug(spec: torch.Tensor, params: dict) -> torch.Tensor: +def scale_vol_aug( + spec: torch.Tensor, + spec_amp_scaling: float, +) -> torch.Tensor: """Scale the volume of the spectrogram. Parameters ---------- spec: torch.Tensor Spectrogram to scale. - params: dict - Parameters for the augmentation. + spec_amp_scaling: float + Maximum scaling factor. 
Returns ------- torch.Tensor """ - return spec * np.random.random() * params["spec_amp_scaling"] + return spec * np.random.random() * spec_amp_scaling -def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray: +def echo_aug( + audio: np.ndarray, + sampling_rate: float, + echo_max_delay: float, +) -> np.ndarray: """Add echo to audio. Parameters ---------- audio: np.ndarray Audio to add echo to. - sampling_rate: int + sampling_rate: float Sampling rate of the audio. - params: dict - Parameters for the augmentation. + echo_max_delay: float + Maximum delay of the echo in seconds. Returns ------- @@ -400,7 +397,7 @@ def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray: Audio with echo added. """ sample_offset = ( - int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1 + int(echo_max_delay * np.random.random() * sampling_rate) + 1 ) audio[:-sample_offset] += np.random.random() * audio[sample_offset:] return audio @@ -408,9 +405,14 @@ def echo_aug(audio: np.ndarray, sampling_rate: int, params: dict) -> np.ndarray: def resample_aug( audio: np.ndarray, - sampling_rate: int, - params: dict, -) -> Tuple[np.ndarray, int, float]: + sampling_rate: float, + fft_win_length: float, + fft_overlap: float, + resize_factor: float, + spec_divide_factor: float, + spec_train_width: int, + aug_sampling_rates: List[int], +) -> Tuple[np.ndarray, float, float]: """Resample audio augmentation. Will resample the audio to a random sampling rate from the list of @@ -420,23 +422,32 @@ def resample_aug( ---------- audio: np.ndarray Audio to resample. - sampling_rate: int + sampling_rate: float Original sampling rate of the audio. - params: dict - Parameters for the augmentation. Includes the list of sampling rates - to choose from for resampling in `aug_sampling_rates`. + fft_win_length: float + Length of the FFT window in seconds. + fft_overlap: float + Amount of overlap between FFT windows. + resize_factor: float + Factor to resize the spectrogram by. + spec_divide_factor: float + Factor to divide the spectrogram by. + spec_train_width: int + Width of the spectrogram. + aug_sampling_rates: List[int] + List of sampling rates to resample to. Returns ------- audio : np.ndarray Resampled audio. - sampling_rate : int + sampling_rate : float New sampling rate. duration : float Duration of the audio in seconds. """ sampling_rate_old = sampling_rate - sampling_rate = np.random.choice(params["aug_sampling_rates"]) + sampling_rate = np.random.choice(aug_sampling_rates) audio = librosa.resample( audio, orig_sr=sampling_rate_old, @@ -447,11 +458,11 @@ def resample_aug( audio = au.pad_audio( audio, sampling_rate, - params["fft_win_length"], - params["fft_overlap"], - params["resize_factor"], - params["spec_divide_factor"], - params["spec_train_width"], + fft_win_length, + fft_overlap, + resize_factor, + spec_divide_factor, + spec_train_width, ) duration = audio.shape[0] / float(sampling_rate) return audio, sampling_rate, duration @@ -459,28 +470,28 @@ def resample_aug( def resample_audio( num_samples: int, - sampling_rate: int, + sampling_rate: float, audio2: np.ndarray, - sampling_rate2: int, -) -> Tuple[np.ndarray, int]: + sampling_rate2: float, +) -> Tuple[np.ndarray, float]: """Resample audio. Parameters ---------- num_samples: int Expected number of samples for the output audio. - sampling_rate: int + sampling_rate: float Original sampling rate of the audio. audio2: np.ndarray Audio to resample. 
- sampling_rate2: int + sampling_rate2: float Target sampling rate of the audio. Returns ------- audio2 : np.ndarray Resampled audio. - sampling_rate2 : int + sampling_rate2 : float New sampling rate. """ # resample to target sampling rate @@ -509,12 +520,12 @@ def resample_audio( def combine_audio_aug( audio: np.ndarray, - sampling_rate: int, - ann: AnnotationGroup, + sampling_rate: float, + ann: AudioLoaderAnnotationGroup, audio2: np.ndarray, - sampling_rate2: int, - ann2: AnnotationGroup, -) -> Tuple[np.ndarray, AnnotationGroup]: + sampling_rate2: float, + ann2: AudioLoaderAnnotationGroup, +) -> Tuple[np.ndarray, AudioLoaderAnnotationGroup]: """Combine two audio files. Will combine two audio files by resampling them to the same sampling rate @@ -570,7 +581,9 @@ def combine_audio_aug( # from different individuals if kk == "individual_ids": if (ann[kk] > -1).sum() > 0: - ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1 + ann2[kk][ann2[kk] > -1] += ( + np.max(ann[kk][ann[kk] > -1]) + 1 + ) if (kk != "class_id_file") and (kk != "annotated"): ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds] @@ -579,7 +592,8 @@ def combine_audio_aug( def _prepare_annotation( - annotation: Annotation, class_names: List[str] + annotation: Annotation, + class_names: List[str], ) -> Annotation: try: class_id = class_names.index(annotation["class"]) @@ -598,7 +612,7 @@ def _prepare_annotation( def _prepare_file_annotation( - annotation: FileAnnotations, + annotation: FileAnnotation, class_names: List[str], classes_to_ignore: List[str], ) -> AudioLoaderAnnotationGroup: @@ -626,7 +640,9 @@ def _prepare_file_annotation( "end_times": np.array([ann["end_time"] for ann in annotations]), "high_freqs": np.array([ann["high_freq"] for ann in annotations]), "low_freqs": np.array([ann["low_freq"] for ann in annotations]), - "class_ids": np.array([ann.get("class_id", -1) for ann in annotations]), + "class_ids": np.array( + [ann.get("class_id", -1) for ann in annotations] + ), "individual_ids": np.array([ann["individual"] for ann in annotations]), "class_id_file": class_id_file, } @@ -639,15 +655,15 @@ class AudioLoader(torch.utils.data.Dataset): def __init__( self, - data_anns_ip: List[FileAnnotations], - params, + data_anns_ip: List[FileAnnotation], + params: AudioLoaderParameters, dataset_name: Optional[str] = None, is_train: bool = False, + return_spec_for_viz: bool = False, ): - self.is_train: bool = is_train - self.params: dict = params - self.return_spec_for_viz: bool = False - + self.is_train = is_train + self.params = params + self.return_spec_for_viz = return_spec_for_viz self.data_anns: List[AudioLoaderAnnotationGroup] = [ _prepare_file_annotation( ann, @@ -657,61 +673,6 @@ class AudioLoader(torch.utils.data.Dataset): for ann in data_anns_ip ] - # for ii in range(len(data_anns_ip)): - # dd = copy.deepcopy(data_anns_ip[ii]) - # - # # filter out unused annotation here - # filtered_annotations = [] - # for ii, aa in enumerate(dd["annotation"]): - # if "individual" in aa.keys(): - # aa["individual"] = int(aa["individual"]) - # - # # if only one call labeled it has to be from the same - # # individual - # if len(dd["annotation"]) == 1: - # aa["individual"] = 0 - # - # # convert class name into class label - # if aa["class"] in self.params["class_names"]: - # aa["class_id"] = self.params["class_names"].index( - # aa["class"] - # ) - # else: - # aa["class_id"] = -1 - # - # if aa["class"] not in self.params["classes_to_ignore"]: - # filtered_annotations.append(aa) - # - # dd["annotation"] = filtered_annotations - # 
dd["start_times"] = np.array( - # [aa["start_time"] for aa in dd["annotation"]] - # ) - # dd["end_times"] = np.array( - # [aa["end_time"] for aa in dd["annotation"]] - # ) - # dd["high_freqs"] = np.array( - # [float(aa["high_freq"]) for aa in dd["annotation"]] - # ) - # dd["low_freqs"] = np.array( - # [float(aa["low_freq"]) for aa in dd["annotation"]] - # ) - # dd["class_ids"] = np.array( - # [aa["class_id"] for aa in dd["annotation"]] - # ).astype(np.int32) - # dd["individual_ids"] = np.array( - # [aa["individual"] for aa in dd["annotation"]] - # ).astype(np.int32) - # - # # file level class name - # dd["class_id_file"] = -1 - # if "class_name" in dd.keys(): - # if dd["class_name"] in self.params["class_names"]: - # dd["class_id_file"] = self.params["class_names"].index( - # dd["class_name"] - # ) - # - # self.data_anns.append(dd) - ann_cnt = [len(aa["annotation"]) for aa in self.data_anns] self.max_num_anns = 2 * np.max( ann_cnt @@ -730,7 +691,7 @@ class AudioLoader(torch.utils.data.Dataset): def get_file_and_anns( self, index: Optional[int] = None, - ) -> Tuple[np.ndarray, int, float, AudioLoaderAnnotationGroup]: + ) -> Tuple[np.ndarray, float, float, AudioLoaderAnnotationGroup]: """Get an audio file and its annotations. Parameters @@ -742,7 +703,7 @@ class AudioLoader(torch.utils.data.Dataset): ------- audio_raw : np.ndarray Loaded audio file. - sampling_rate : int + sampling_rate : float Sampling rate of the audio file. duration : float Duration of the audio file in seconds. @@ -837,7 +798,7 @@ class AudioLoader(torch.utils.data.Dataset): ( audio2, sampling_rate2, - duration2, + _, ann2, ) = self.get_file_and_anns() audio, ann = combine_audio_aug( @@ -846,7 +807,11 @@ class AudioLoader(torch.utils.data.Dataset): # simulate echo by adding delayed copy of the file if np.random.random() < self.params["aug_prob"]: - audio = echo_aug(audio, sampling_rate, self.params) + audio = echo_aug( + audio, + sampling_rate, + echo_max_delay=self.params["echo_max_delay"], + ) # resample the audio # if np.random.random() < self.params["aug_prob"]: @@ -855,11 +820,16 @@ class AudioLoader(torch.utils.data.Dataset): # ) # create spectrogram - spec, spec_for_viz = au.generate_spectrogram( + spec = au.generate_spectrogram( audio, sampling_rate, - self.params, - self.return_spec_for_viz, + fft_win_length=self.params["fft_win_length"], + fft_overlap=self.params["fft_overlap"], + max_freq=self.params["max_freq"], + min_freq=self.params["min_freq"], + spec_scale=self.params["spec_scale"], + denoise_spec_avg=self.params["denoise_spec_avg"], + max_scale_spec=self.params["max_scale_spec"], ) rsf = self.params["resize_factor"] spec_op_shape = ( @@ -879,20 +849,29 @@ class AudioLoader(torch.utils.data.Dataset): # augment spectrogram if self.is_train and self.params["augment_at_train"]: if np.random.random() < self.params["aug_prob"]: - spec = scale_vol_aug(spec, self.params) + spec = scale_vol_aug( + spec, + spec_amp_scaling=self.params["spec_amp_scaling"], + ) if np.random.random() < self.params["aug_prob"]: spec = warp_spec_aug( spec, ann, - self.params, + stretch_squeeze_delta=self.params["stretch_squeeze_delta"], ) if np.random.random() < self.params["aug_prob"]: - spec = mask_time_aug(spec, self.params) + spec = mask_time_aug( + spec, + mask_max_time_perc=self.params["mask_max_time_perc"], + ) if np.random.random() < self.params["aug_prob"]: - spec = mask_freq_aug(spec, self.params) + spec = mask_freq_aug( + spec, + mask_max_freq_perc=self.params["mask_max_freq_perc"], + ) outputs = {} outputs["spec"] = spec @@ 
-911,7 +890,13 @@ class AudioLoader(torch.utils.data.Dataset): spec_op_shape, sampling_rate, ann, - self.params, + class_names=self.params["class_names"], + fft_win_length=self.params["fft_win_length"], + fft_overlap=self.params["fft_overlap"], + max_freq=self.params["max_freq"], + min_freq=self.params["min_freq"], + resize_factor=self.params["resize_factor"], + target_sigma=self.params["target_sigma"], ) # hack to get around requirement that all vectors are the same length diff --git a/batdetect2/train/losses.py b/batdetect2/train/losses.py index 02bfdd6..1116c50 100644 --- a/batdetect2/train/losses.py +++ b/batdetect2/train/losses.py @@ -1,8 +1,13 @@ +from typing import Optional + import torch import torch.nn.functional as F -def bbox_size_loss(pred_size, gt_size): +def bbox_size_loss( + pred_size: torch.Tensor, + gt_size: torch.Tensor, +) -> torch.Tensor: """ Bounding box size loss. Only compute loss where there is a bounding box. """ @@ -12,7 +17,12 @@ def bbox_size_loss(pred_size, gt_size): ) -def focal_loss(pred, gt, weights=None, valid_mask=None): +def focal_loss( + pred: torch.Tensor, + gt: torch.Tensor, + weights: Optional[torch.Tensor] = None, + valid_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: """ Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints pred (batch x c x h x w) @@ -52,7 +62,11 @@ def focal_loss(pred, gt, weights=None, valid_mask=None): return loss -def mse_loss(pred, gt, weights=None, valid_mask=None): +def mse_loss( + pred: torch.Tensor, + gt: torch.Tensor, + valid_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: """ Mean squared error loss. """ diff --git a/batdetect2/train/train_model.py b/batdetect2/train/train_model.py index e38de39..a56ea36 100644 --- a/batdetect2/train/train_model.py +++ b/batdetect2/train/train_model.py @@ -5,6 +5,7 @@ import warnings import matplotlib.pyplot as plt import numpy as np import torch +import torch.utils.data from torch.optim.lr_scheduler import CosineAnnealingLR import batdetect2.detector.post_process as pp @@ -29,7 +30,7 @@ def save_images_batch(model, data_loader, params): ind = 0 # first image in each batch with torch.no_grad(): - for batch_idx, inputs in enumerate(data_loader): + for inputs in data_loader: data = inputs["spec"].to(params["device"]) outputs = model(data) @@ -81,7 +82,12 @@ def save_image( def loss_fun( - outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq + outputs, + gt_det, + gt_size, + gt_class, + det_criterion, + params, ): # detection loss loss = params["det_loss_weight"] * det_criterion( @@ -104,7 +110,13 @@ def loss_fun( def train( - model, epoch, data_loader, det_criterion, optimizer, scheduler, params + model, + epoch, + data_loader, + det_criterion, + optimizer, + scheduler, + params, ): model.train() @@ -309,7 +321,7 @@ def select_model(params): resize_factor=params["resize_factor"], ) else: - print("No valid network specified") + raise ValueError("No valid network specified") return model @@ -319,9 +331,9 @@ def main(): params = parameters.get_params(True) if torch.cuda.is_available(): - params["device"] = "cuda" + params.device = "cuda" else: - params["device"] = "cpu" + params.device = "cpu" # setup arg parser and populate it with exiting parameters - will not work with lists parser = argparse.ArgumentParser() @@ -349,13 +361,16 @@ def main(): default="Rhinolophus ferrumequinum;Rhinolophus hipposideros", help='Will set low and high frequency the same for these classes. 
Separate names with ";"', ) + for key, val in params.items(): parser.add_argument("--" + key, type=type(val), default=val) params = vars(parser.parse_args()) # save notes file if params["notes"] != "": - tu.write_notes_file(params["experiment"] + "notes.txt", params["notes"]) + tu.write_notes_file( + params["experiment"] + "notes.txt", params["notes"] + ) # load the training and test meta data - there are different splits defined train_sets, test_sets = ts.get_train_test_data( @@ -374,15 +389,11 @@ def main(): for tt in train_sets: print(tt["ann_path"]) classes_to_ignore = params["classes_to_ignore"] + params["generic_class"] - ( - data_train, - params["class_names"], - params["class_inv_freq"], - ) = tu.load_set_of_anns( + data_train = tu.load_set_of_anns( train_sets, - classes_to_ignore, - params["events_of_interest"], - params["convert_to_genus"], + classes_to_ignore=classes_to_ignore, + events_of_interest=params["events_of_interest"], + convert_to_genus=params["convert_to_genus"], ) params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping( params["class_names"] @@ -415,11 +426,12 @@ def main(): print("\nTesting on:") for tt in test_sets: print(tt["ann_path"]) - data_test, _, _ = tu.load_set_of_anns( + + data_test = tu.load_set_of_anns( test_sets, - classes_to_ignore, - params["events_of_interest"], - params["convert_to_genus"], + classes_to_ignore=classes_to_ignore, + events_of_interest=params["events_of_interest"], + convert_to_genus=params["convert_to_genus"], ) data_train = tu.remove_dupes(data_train, data_test) test_dataset = adl.AudioLoader(data_test, params, is_train=False) @@ -447,10 +459,13 @@ def main(): scheduler = CosineAnnealingLR( optimizer, params["num_epochs"] * len(train_loader) ) + if params["train_loss"] == "mse": det_criterion = losses.mse_loss elif params["train_loss"] == "focal": det_criterion = losses.focal_loss + else: + raise ValueError("No valid loss specified") # save parameters to file with open(params["experiment"] + "params.json", "w") as da: diff --git a/batdetect2/train/train_utils.py b/batdetect2/train/train_utils.py index 62441a7..acf8f42 100644 --- a/batdetect2/train/train_utils.py +++ b/batdetect2/train/train_utils.py @@ -1,28 +1,37 @@ -import glob import json -import os -import random +from collections import Counter +from pathlib import Path +from typing import Dict, Generator, List, Optional, Tuple import numpy as np +from batdetect2 import types -def write_notes_file(file_name, text): + +def write_notes_file(file_name: str, text: str): with open(file_name, "a") as da: da.write(text + "\n") -def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path): - ddict = { +def get_blank_dataset_dict( + dataset_name: str, + is_test: bool, + ann_path: str, + wav_path: str, +) -> types.DatasetDict: + return { "dataset_name": dataset_name, "is_test": is_test, "is_binary": False, "ann_path": ann_path, "wav_path": wav_path, } - return ddict -def get_short_class_names(class_names, str_len=3): +def get_short_class_names( + class_names: List[str], + str_len: int = 3, +) -> List[str]: class_names_short = [] for cc in class_names: class_names_short.append( @@ -31,7 +40,10 @@ def get_short_class_names(class_names, str_len=3): return class_names_short -def remove_dupes(data_train, data_test): +def remove_dupes( + data_train: List[types.FileAnnotation], + data_test: List[types.FileAnnotation], +) -> List[types.FileAnnotation]: test_ids = [dd["id"] for dd in data_test] data_train_prune = [] for aa in data_train: @@ -43,14 +55,16 @@ def 
remove_dupes(data_train, data_test): return data_train_prune -def get_genus_mapping(class_names): +def get_genus_mapping(class_names: List[str]) -> Tuple[List[str], List[int]]: genus_names, genus_mapping = np.unique( [cc.split(" ")[0] for cc in class_names], return_inverse=True ) return genus_names.tolist(), genus_mapping.tolist() -def standardize_low_freq(data, class_of_interest): +def standardize_low_freq( + data: List[types.FileAnnotation], class_of_interest: str, +) -> List[types.FileAnnotation]: # address the issue of highly variable low frequency annotations # this often happens for contstant frequency calls # for the class of interest sets the low and high freq to be the dataset mean @@ -62,8 +76,8 @@ def standardize_low_freq(data, class_of_interest): low_freqs.append(aa["low_freq"]) high_freqs.append(aa["high_freq"]) - low_mean = np.mean(low_freqs) - high_mean = np.mean(high_freqs) + low_mean = float(np.mean(low_freqs)) + high_mean = float(np.mean(high_freqs)) assert low_mean < high_mean print("\nStandardizing low and high frequency for:") @@ -83,115 +97,148 @@ def standardize_low_freq(data, class_of_interest): return data -def load_set_of_anns( - data, - classes_to_ignore=[], - events_of_interest=None, - convert_to_genus=False, - verbose=True, - list_of_anns=False, - filter_issues=False, - name_replace=False, -): +def format_annotation( + annotation: types.FileAnnotation, + events_of_interest: Optional[List[str]] = None, + name_replace: Optional[Dict[str, str]] = None, + convert_to_genus: bool = False, + classes_to_ignore: Optional[List[str]] = None, +) -> types.FileAnnotation: + formated = [] + for aa in annotation["annotation"]: + if ( + events_of_interest is not None + and aa["event"] not in events_of_interest + ): + # Omit files with annotation issues + continue + # remove leading and trailing spaces + class_name = aa["class"].strip() + + if name_replace is not None: + # replace_names will be a dictionary mapping input name to output + class_name = name_replace.get(class_name, class_name) + + if convert_to_genus: + # convert everything to genus name + class_name = class_name.split(" ")[0] + + # NOTE: It is important to acknowledge that the class names filtering + # is done after the name replacement and the conversion to + # genus name. This allows filtering converted genus names and names + # that were replaced with a name that should be ignored. + if classes_to_ignore is not None and class_name in classes_to_ignore: + # Omit annotations with ignored classes + continue + + formated.append( + { + **aa, + "class": class_name, + } + ) + + return { + **annotation, + "annotation": formated, + } + + +def get_class_names( + data: List[types.FileAnnotation], + classes_to_ignore: Optional[List[str]] = None, +) -> Tuple[Counter[str], List[float]]: + """Extracts class names and their inverse frequencies. + + Parameters + ---------- + data + A list of file annotations, where each annotation contains a list of + sound events with associated class names. + classes_to_ignore + A list of class names to ignore. + + Returns: + -------- + class_names + A list of unique class names extracted from the annotations. + class_inv_freq + List of inverse frequencies of each class name in the provided data. 
+ """ + if classes_to_ignore is None: + classes_to_ignore = [] + + class_names_list: List[str] = [] + for annotation in data: + for sound_event in annotation["annotation"]: + if sound_event["class"] in classes_to_ignore: + continue + + class_names_list.append(sound_event["class"]) + + counts = Counter(class_names_list) + mean_counts = float(np.mean(list(counts.values()))) + return counts, [mean_counts / counts[cc] for cc in class_names_list] + + +def report_class_counts(class_names: Counter[str]): + print("Class count:") + str_len = np.max([len(cc) for cc in class_names]) + 5 + for index, (class_name, count) in enumerate(class_names.most_common()): + print(f"{index:<5}{class_name:<{str_len}}{count}") + + +def load_set_of_anns( + data: List[types.DatasetDict], + *, + convert_to_genus: bool = False, + filter_issues: bool = False, + events_of_interest: Optional[List[str]] = None, + classes_to_ignore: Optional[List[str]] = None, + name_replace: Optional[Dict[str, str]] = None, +) -> List[types.FileAnnotation]: # load the annotations anns = [] - if list_of_anns: - # path to list of individual json files - anns.extend(load_anns_from_path(data["ann_path"], data["wav_path"])) - else: - # dictionary of datasets - for dd in data: - anns.extend(load_anns(dd["ann_path"], dd["wav_path"])) - # discarding unannoated files - anns = [aa for aa in anns if aa["annotated"] is True] + # dictionary of datasets + for dataset in data: + for ann in load_anns(dataset["ann_path"], dataset["wav_path"]): + if not ann["annotated"]: + # Omit unannotated files + continue - # filter files that have annotation issues - is the input is a dictionary of - # datasets, this will lilely have already been done - if filter_issues: - anns = [aa for aa in anns if aa["issues"] is False] + if filter_issues and ann["issues"]: + # Omit files with annotation issues + continue - # check for some basic formatting errors with class names - for ann in anns: - for aa in ann["annotation"]: - aa["class"] = aa["class"].strip() - - # only load specified events - i.e. 
types of calls - if events_of_interest is not None: - for ann in anns: - filtered_events = [] - for aa in ann["annotation"]: - if aa["event"] in events_of_interest: - filtered_events.append(aa) - ann["annotation"] = filtered_events - - # change class names - # replace_names will be a dictionary mapping input name to output - if type(name_replace) is dict: - for ann in anns: - for aa in ann["annotation"]: - if aa["class"] in name_replace: - aa["class"] = name_replace[aa["class"]] - - # convert everything to genus name - if convert_to_genus: - for ann in anns: - for aa in ann["annotation"]: - aa["class"] = aa["class"].split(" ")[0] - - # get unique class names - class_names_all = [] - for ann in anns: - for aa in ann["annotation"]: - if aa["class"] not in classes_to_ignore: - class_names_all.append(aa["class"]) - - class_names, class_cnts = np.unique(class_names_all, return_counts=True) - class_inv_freq = class_cnts.sum() / ( - len(class_names) * class_cnts.astype(np.float32) - ) - - if verbose: - print("Class count:") - str_len = np.max([len(cc) for cc in class_names]) + 5 - for cc in range(len(class_names)): - print( - str(cc).ljust(5) - + class_names[cc].ljust(str_len) - + str(class_cnts[cc]) + anns.append( + format_annotation( + ann, + events_of_interest=events_of_interest, + name_replace=name_replace, + convert_to_genus=convert_to_genus, + classes_to_ignore=classes_to_ignore, + ) ) - if len(classes_to_ignore) == 0: - return anns - else: - return anns, class_names.tolist(), class_inv_freq.tolist() - - -def load_anns(ann_file_name, raw_audio_dir): - with open(ann_file_name) as da: - anns = json.load(da) - - for aa in anns: - aa["file_path"] = raw_audio_dir + aa["id"] - return anns -def load_anns_from_path(ann_file_dir, raw_audio_dir): - files = glob.glob(ann_file_dir + "*.json") - anns = [] - for ff in files: - with open(ff) as da: - ann = json.load(da) - ann["file_path"] = raw_audio_dir + ann["id"] - anns.append(ann) +def load_anns( + ann_dir: str, + raw_audio_dir: str, +) -> Generator[types.FileAnnotation, None, None]: + for path in Path(ann_dir).rglob("*.json"): + with open(path) as fp: + file_annotation = json.load(fp) - return anns + file_annotation["file_path"] = raw_audio_dir + file_annotation["id"] + yield file_annotation -class AverageMeter(object): - """Computes and stores the average and current value""" +class AverageMeter: + """Computes and stores the average and current value.""" def __init__(self): self.reset() diff --git a/batdetect2/types.py b/batdetect2/types.py index 019564a..1899665 100644 --- a/batdetect2/types.py +++ b/batdetect2/types.py @@ -1,5 +1,5 @@ """Types used in the code base.""" -from typing import List, NamedTuple, Optional, Union +from typing import Any, List, NamedTuple, Optional import numpy as np import torch @@ -26,8 +26,7 @@ __all__ = [ "Annotation", "DetectionModel", "FeatureExtractionParameters", - "FeatureExtractor", - "FileAnnotations", + "FileAnnotation", "ModelOutput", "ModelParameters", "NonMaximumSuppressionConfig", @@ -94,7 +93,10 @@ class ModelParameters(TypedDict): """Resize factor.""" class_names: List[str] - """Class names. The model is trained to detect these classes.""" + """Class names. + + The model is trained to detect these classes. + """ DictWithClass = TypedDict("DictWithClass", {"class": str}) @@ -103,8 +105,8 @@ DictWithClass = TypedDict("DictWithClass", {"class": str}) class Annotation(DictWithClass): """Format of annotations. - This is the format of a single annotation as expected by the annotation - tool. 
+ This is the format of a single annotation as expected by the + annotation tool. """ start_time: float @@ -113,10 +115,10 @@ class Annotation(DictWithClass): end_time: float """End time in seconds.""" - low_freq: int + low_freq: float """Low frequency in Hz.""" - high_freq: int + high_freq: float """High frequency in Hz.""" class_prob: float @@ -135,7 +137,7 @@ class Annotation(DictWithClass): """Numeric ID for the class of the annotation.""" -class FileAnnotations(TypedDict): +class FileAnnotation(TypedDict): """Format of results. This is the format of the results expected by the annotation tool. @@ -157,7 +159,7 @@ class FileAnnotations(TypedDict): """Time expansion factor.""" class_name: str - """Class predicted at file level""" + """Class predicted at file level.""" notes: str """Notes of file.""" @@ -169,7 +171,7 @@ class FileAnnotations(TypedDict): class RunResults(TypedDict): """Run results.""" - pred_dict: FileAnnotations + pred_dict: FileAnnotation """Predictions in the format expected by the annotation tool.""" spec_feats: NotRequired[List[np.ndarray]] @@ -394,9 +396,9 @@ class PredictionResults(TypedDict): class DetectionModel(Protocol): """Protocol for detection models. - This protocol is used to define the interface for the detection models. - This allows us to use the same code for training and inference, even - though the models are different. + This protocol is used to define the interface for the detection + models. This allows us to use the same code for training and + inference, even though the models are different. """ num_classes: int @@ -416,16 +418,14 @@ class DetectionModel(Protocol): def forward( self, - ip: torch.Tensor, - return_feats: bool = False, + spec: torch.Tensor, ) -> ModelOutput: """Forward pass of the model.""" ... def __call__( self, - ip: torch.Tensor, - return_feats: bool = False, + spec: torch.Tensor, ) -> ModelOutput: """Forward pass of the model.""" ... @@ -490,8 +490,10 @@ class HeatmapParameters(TypedDict): """Maximum frequency to consider in Hz.""" target_sigma: float - """Sigma for the Gaussian kernel. Controls the width of the points in - the heatmap.""" + """Sigma for the Gaussian kernel. + + Controls the width of the points in the heatmap. + """ class AnnotationGroup(TypedDict): @@ -522,10 +524,10 @@ class AnnotationGroup(TypedDict): annotated: NotRequired[bool] """Wether the annotation group is complete or not. - Usually annotation groups are associated to a single - audio clip. If the annotation group is complete, it means that all - relevant sound events have been annotated. If it is not complete, it - means that some sound events might not have been annotated. + Usually annotation groups are associated to a single audio clip. If + the annotation group is complete, it means that all relevant sound + events have been annotated. If it is not complete, it means that + some sound events might not have been annotated. """ x_inds: NotRequired[np.ndarray] @@ -535,12 +537,88 @@ class AnnotationGroup(TypedDict): """Y coordinate of the annotations in the spectrogram.""" -class AudioLoaderAnnotationGroup(AnnotationGroup, FileAnnotations): +class AudioLoaderAnnotationGroup(TypedDict): """Group of annotation items for the training audio loader. This class is used to store the annotations for the training audio loader. It inherits from `AnnotationGroup` and `FileAnnotations`. 
""" + id: str + duration: float + issues: bool + file_path: str + time_exp: float + class_name: str + notes: str + start_times: np.ndarray + end_times: np.ndarray + low_freqs: np.ndarray + high_freqs: np.ndarray + class_ids: np.ndarray + individual_ids: np.ndarray + x_inds: np.ndarray + y_inds: np.ndarray + annotation: List[Annotation] + annotated: bool class_id_file: int """ID of the class of the file.""" + + +class AudioLoaderParameters(TypedDict): + class_names: List[str] + classes_to_ignore: List[str] + target_samp_rate: int + scale_raw_audio: bool + fft_win_length: float + fft_overlap: float + spec_train_width: int + resize_factor: float + spec_divide_factor: int + augment_at_train: bool + augment_at_train_combine: bool + aug_prob: float + spec_height: int + echo_max_delay: float + spec_amp_scaling: float + stretch_squeeze_delta: float + mask_max_time_perc: float + mask_max_freq_perc: float + max_freq: float + min_freq: float + spec_scale: str + denoise_spec_avg: bool + max_scale_spec: bool + target_sigma: float + + +class FeatureExtractor(Protocol): + def __call__( + self, + prediction: Prediction, + **kwargs: Any, + ) -> float: + ... + + +class DatasetDict(TypedDict): + """Dataset dictionary. + + This is the format of the dictionary that contains the dataset + information. + """ + + dataset_name: str + """Name of the dataset.""" + + is_test: bool + """Whether the dataset is a test set.""" + + is_binary: bool + """Whether the dataset is binary.""" + + ann_path: str + """Path to the annotations.""" + + wav_path: str + """Path to the audio files.""" diff --git a/batdetect2/utils/audio_utils.py b/batdetect2/utils/audio_utils.py index 908d971..09e76c8 100644 --- a/batdetect2/utils/audio_utils.py +++ b/batdetect2/utils/audio_utils.py @@ -1,13 +1,11 @@ import warnings -from typing import Optional, Tuple +from typing import Optional, Tuple, Union, overload import librosa import librosa.core.spectrum import numpy as np import torch -from . import wavfile - __all__ = [ "load_audio", "generate_spectrogram", @@ -15,113 +13,171 @@ __all__ = [ ] -def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap): - nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor +@overload +def time_to_x_coords( + time_in_file: np.ndarray, + sampling_rate: float, + fft_win_length: float, + fft_overlap: float, +) -> np.ndarray: + ... + + +@overload +def time_to_x_coords( + time_in_file: float, + sampling_rate: float, + fft_win_length: float, + fft_overlap: float, +) -> float: + ... 
+ + +def time_to_x_coords( + time_in_file: Union[float, np.ndarray], + sampling_rate: float, + fft_win_length: float, + fft_overlap: float, +) -> Union[float, np.ndarray]: + nfft = np.floor(fft_win_length * sampling_rate) noverlap = np.floor(fft_overlap * nfft) return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap) # NOTE this is also defined in post_process -def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap): +def x_coords_to_time( + x_pos: float, + sampling_rate: int, + fft_win_length: float, + fft_overlap: float, +) -> float: nfft = np.floor(fft_win_length * sampling_rate) noverlap = np.floor(fft_overlap * nfft) return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate - # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window + + # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for + # center of temporal window def generate_spectrogram( - audio, - sampling_rate, - params, - return_spec_for_viz=False, - check_spec_size=True, -): + audio: np.ndarray, + sampling_rate: float, + fft_win_length: float, + fft_overlap: float, + max_freq: float, + min_freq: float, + spec_scale: str, + denoise_spec_avg: bool = False, + max_scale_spec: bool = False, +) -> np.ndarray: # generate spectrogram spec = gen_mag_spectrogram( audio, sampling_rate, - params["fft_win_length"], - params["fft_overlap"], + window_len=fft_win_length, + overlap_perc=fft_overlap, + ) + spec = crop_spectrogram( + spec, + fft_win_length=fft_win_length, + max_freq=max_freq, + min_freq=min_freq, + ) + spec = scale_spectrogram( + spec, + sampling_rate, + spec_scale=spec_scale, + fft_win_length=fft_win_length, ) + if denoise_spec_avg: + spec = denoise_spectrogram(spec) + + if max_scale_spec: + spec = max_scale_spectrogram(spec) + + return spec + + +def crop_spectrogram( + spec: np.ndarray, + fft_win_length: float, + max_freq: float, + min_freq: float, +) -> np.ndarray: # crop to min/max freq - max_freq = round(params["max_freq"] * params["fft_win_length"]) - min_freq = round(params["min_freq"] * params["fft_win_length"]) + max_freq = round(max_freq * fft_win_length) + min_freq = round(min_freq * fft_win_length) if spec.shape[0] < max_freq: freq_pad = max_freq - spec.shape[0] spec = np.vstack( (np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec) ) - spec_cropped = spec[-max_freq : spec.shape[0] - min_freq, :] + return spec[-max_freq : spec.shape[0] - min_freq, :] - if params["spec_scale"] == "log": - log_scaling = ( - 2.0 - * (1.0 / sampling_rate) - * ( - 1.0 - / ( - np.abs( - np.hanning( - int(params["fft_win_length"] * sampling_rate) - ) - ) - ** 2 - ).sum() - ) + +def denoise_spectrogram(spec: np.ndarray) -> np.ndarray: + spec = spec - np.mean(spec, 1)[:, np.newaxis] + return spec.clip(min=0) + + +def max_scale_spectrogram(spec: np.ndarray) -> np.ndarray: + return spec / (spec.max() + 10e-6) + + +def log_scale( + spec: np.ndarray, + sampling_rate: float, + fft_win_length: float, +) -> np.ndarray: + log_scaling = ( + 2.0 + * (1.0 / sampling_rate) + * ( + 1.0 + / ( + np.abs(np.hanning(int(fft_win_length * sampling_rate))) ** 2 + ).sum() ) - # log_scaling = (1.0 / sampling_rate)*0.1 - # log_scaling = (1.0 / sampling_rate)*10e4 - spec = np.log1p(log_scaling * spec_cropped) - elif params["spec_scale"] == "pcen": - spec = pcen(spec_cropped, sampling_rate) + ) + return np.log1p(log_scaling * spec) - elif params["spec_scale"] == "none": - pass - if params["denoise_spec_avg"]: - spec = spec - np.mean(spec, 1)[:, np.newaxis] - 
spec.clip(min=0, out=spec) +def scale_spectrogram( + spec: np.ndarray, + sampling_rate: float, + spec_scale: str, + fft_win_length: float, +) -> np.ndarray: + if spec_scale == "log": + return log_scale(spec, sampling_rate, fft_win_length) - if params["max_scale_spec"]: - spec = spec / (spec.max() + 10e-6) + if spec_scale == "pcen": + return pcen(spec, sampling_rate) - # needs to be divisible by specific factor - if not it should have been padded - # if check_spec_size: - # assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0) - # assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0) + return spec + +def prepare_spec_for_viz( + spec: np.ndarray, + sampling_rate: int, + fft_win_length: float, +) -> np.ndarray: # for visualization purposes - use log scaled spectrogram - if return_spec_for_viz: - log_scaling = ( - 2.0 - * (1.0 / sampling_rate) - * ( - 1.0 - / ( - np.abs( - np.hanning( - int(params["fft_win_length"] * sampling_rate) - ) - ) - ** 2 - ).sum() - ) - ) - spec_for_viz = np.log1p(log_scaling * spec_cropped).astype(np.float32) - else: - spec_for_viz = None - - return spec, spec_for_viz + return log_scale( + spec, + sampling_rate, + fft_win_length=fft_win_length, + ).astype(np.float32) def load_audio( audio_file: str, time_exp_fact: float, - target_samp_rate: int, + target_sampling_rate: int, scale: bool = False, max_duration: Optional[float] = None, -) -> Tuple[int, np.ndarray]: +) -> Tuple[float, np.ndarray]: """Load an audio file and resample it to the target sampling rate. The audio is also scaled to [-1, 1] and clipped to the maximum duration. @@ -152,63 +208,82 @@ def load_audio( """ with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=wavfile.WavFileWarning) - # sampling_rate, audio_raw = wavfile.read(audio_file) - audio_raw, sampling_rate = librosa.load( + audio, sampling_rate = librosa.load( audio_file, sr=None, dtype=np.float32, ) - if len(audio_raw.shape) > 1: + if len(audio.shape) > 1: raise ValueError("Currently does not handle stereo files") sampling_rate = sampling_rate * time_exp_fact # resample - need to do this after correcting for time expansion - sampling_rate_old = sampling_rate - sampling_rate = target_samp_rate - if sampling_rate_old != sampling_rate: - audio_raw = librosa.resample( - audio_raw, - orig_sr=sampling_rate_old, - target_sr=sampling_rate, - res_type="polyphase", - ) + audio = resample_audio(audio, sampling_rate, target_sampling_rate) - # clipping maximum duration if max_duration is not None: - max_duration = int( - np.minimum( - int(sampling_rate * max_duration), - audio_raw.shape[0], - ) - ) - audio_raw = audio_raw[:max_duration] + audio = clip_audio(audio, target_sampling_rate, max_duration) # scale to [-1, 1] if scale: - audio_raw = audio_raw - audio_raw.mean() - audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6) + audio = scale_audio(audio) - return sampling_rate, audio_raw + return target_sampling_rate, audio + + +def resample_audio( + audio: np.ndarray, + sr_orig: float, + sr_target: float, +) -> np.ndarray: + if sr_orig != sr_target: + return librosa.resample( + audio, + orig_sr=sr_orig, + target_sr=sr_target, + res_type="polyphase", + ) + + return audio + + +def clip_audio( + audio: np.ndarray, + sampling_rate: float, + max_duration: float, +) -> np.ndarray: + max_duration = int( + np.minimum( + int(sampling_rate * max_duration), + audio.shape[0], + ) + ) + return audio[:max_duration] + + +def scale_audio( + audio: np.ndarray, + eps: float = 
10e-6, +) -> np.ndarray: + return (audio - audio.mean()) / (np.abs(audio).max() + eps) def pad_audio( - audio_raw, - fs, - ms, - overlap_perc, - resize_factor, - divide_factor, - fixed_width=None, -): + audio_raw: np.ndarray, + sampling_rate: float, + window_len: float, + overlap_perc: float, + resize_factor: float, + divide_factor: float, + fixed_width: Optional[int] = None, +) -> np.ndarray: # Adds zeros to the end of the raw data so that the generated sepctrogram # will be evenly divisible by `divide_factor` # Also deals with very short audio clips and fixed_width during training # This code could be clearer, clean up - nfft = int(ms * fs) + nfft = int(window_len * sampling_rate) noverlap = int(overlap_perc * nfft) step = nfft - noverlap min_size = int(divide_factor * (1.0 / resize_factor)) @@ -245,19 +320,24 @@ def pad_audio( return audio_raw -def gen_mag_spectrogram(x, fs, ms, overlap_perc): +def gen_mag_spectrogram( + audio: np.ndarray, + sampling_rate: float, + window_len: float, + overlap_perc: float, +) -> np.ndarray: # Computes magnitude spectrogram by specifying time. - - x = x.astype(np.float32) - nfft = int(ms * fs) + audio = audio.astype(np.float32) + nfft = int(window_len * sampling_rate) noverlap = int(overlap_perc * nfft) - # window data - step = nfft - noverlap - # compute spec spec, _ = librosa.core.spectrum._spectrogram( - y=x, power=1, n_fft=nfft, hop_length=step, center=False + y=audio, + power=1, + n_fft=nfft, + hop_length=nfft - noverlap, + center=False, ) # remove DC component and flip vertical orientation @@ -266,24 +346,25 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc): return spec.astype(np.float32) -def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc): - nfft = int(ms * fs) +def gen_mag_spectrogram_pt( + audio: torch.Tensor, + sampling_rate: float, + window_len: float, + overlap_perc: float, +) -> torch.Tensor: + nfft = int(window_len * sampling_rate) nstep = round((1.0 - overlap_perc) * nfft) + han_win = torch.hann_window(nfft, periodic=False).to(audio.device) - han_win = torch.hann_window(nfft, periodic=False).to(x.device) - - complex_spec = torch.stft(x, nfft, nstep, window=han_win, center=False) + complex_spec = torch.stft(audio, nfft, nstep, window=han_win, center=False) spec = complex_spec.pow(2.0).sum(-1) # remove DC component and flip vertically - spec = torch.flipud(spec[0, 1:, :]) - - return spec + return torch.flipud(spec[0, 1:, :]) -def pcen(spec_cropped, sampling_rate): +def pcen(spec: np.ndarray, sampling_rate: float) -> np.ndarray: # TODO should be passing hop_length too i.e. step - spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate / 10).astype( + return librosa.pcen(spec * (2**31), sr=sampling_rate / 10).astype( np.float32 ) - return spec diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index 8d6ca7f..a6eadfd 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -16,7 +16,7 @@ from batdetect2.detector.parameters import DEFAULT_MODEL_PATH from batdetect2.types import ( Annotation, DetectionModel, - FileAnnotations, + FileAnnotation, ModelOutput, ModelParameters, PredictionResults, @@ -79,7 +79,7 @@ def list_audio_files(ip_dir: str) -> List[str]: def load_model( model_path: str = DEFAULT_MODEL_PATH, load_weights: bool = True, - device: Optional[torch.device] = None, + device: Union[torch.device, str, None] = None, ) -> Tuple[DetectionModel, ModelParameters]: """Load model from file. 
@@ -222,7 +222,7 @@ def format_single_result( duration: float, predictions: PredictionResults, class_names: List[str], -) -> FileAnnotations: +) -> FileAnnotation: """Format results into the format expected by the annotation tool. Args: @@ -399,11 +399,10 @@ def save_results_to_file(results, op_path: str) -> None: def compute_spectrogram( audio: np.ndarray, - sampling_rate: int, + sampling_rate: float, params: SpectrogramParameters, device: torch.device, - return_np: bool = False, -) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]: +) -> Tuple[float, torch.Tensor]: """Compute a spectrogram from an audio array. Will pad the audio array so that it is evenly divisible by the @@ -412,24 +411,16 @@ def compute_spectrogram( Parameters ---------- audio : np.ndarray - sampling_rate : int - params : SpectrogramParameters The parameters to use for generating the spectrogram. - return_np : bool, optional - Whether to return the spectrogram as a numpy array as well as a - torch tensor. The default is False. - Returns ------- duration : float The duration of the spectrgram in seconds. - spec : torch.Tensor The spectrogram as a torch tensor. - spec_np : np.ndarray, optional The spectrogram as a numpy array. Only returned if `return_np` is True, otherwise None. @@ -446,7 +437,7 @@ def compute_spectrogram( ) # generate spectrogram - spec, _ = au.generate_spectrogram(audio, sampling_rate, params) + spec = au.generate_spectrogram(audio, sampling_rate, params) # convert to pytorch spec = torch.from_numpy(spec).to(device) @@ -466,18 +457,12 @@ def compute_spectrogram( mode="bilinear", align_corners=False, ) - - if return_np: - spec_np = spec[0, 0, :].cpu().data.numpy() - else: - spec_np = None - - return duration, spec, spec_np + return duration, spec def iterate_over_chunks( audio: np.ndarray, - samplerate: int, + samplerate: float, chunk_size: float, ) -> Iterator[Tuple[float, np.ndarray]]: """Iterate over audio in chunks of size chunk_size. @@ -510,7 +495,7 @@ def iterate_over_chunks( def _process_spectrogram( spec: torch.Tensor, - samplerate: int, + samplerate: float, model: DetectionModel, config: ProcessingConfiguration, ) -> Tuple[PredictionResults, np.ndarray]: @@ -632,13 +617,13 @@ def process_spectrogram( def _process_audio_array( audio: np.ndarray, - sampling_rate: int, + sampling_rate: float, model: DetectionModel, config: ProcessingConfiguration, device: torch.device, ) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]: # load audio file and compute spectrogram - _, spec, _ = compute_spectrogram( + _, spec = compute_spectrogram( audio, sampling_rate, { @@ -654,7 +639,6 @@ def _process_audio_array( "max_scale_spec": config["max_scale_spec"], }, device, - return_np=False, ) # process spectrogram with model @@ -754,13 +738,15 @@ def process_file( # Get original sampling rate file_samp_rate = librosa.get_samplerate(audio_file) - orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1 + orig_samp_rate = file_samp_rate * float( + config.get("time_expansion", 1.0) or 1.0 + ) # load audio file sampling_rate, audio_full = au.load_audio( audio_file, time_exp_fact=config.get("time_expansion", 1) or 1, - target_samp_rate=config["target_samp_rate"], + target_sampling_rate=config["target_samp_rate"], scale=config["scale_raw_audio"], max_duration=config.get("max_duration"), ) @@ -802,7 +788,6 @@ def process_file( cnn_feats.append(features[0]) if config["spec_slices"]: - # FIX: This is not currently working. 
Returns empty slices spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms)) # Merge results from chunks diff --git a/tests/test_features.py b/tests/test_features.py index 1271fda..e34d499 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -152,7 +152,7 @@ def test_compute_max_power_bb(max_power: int): target_samp_rate=samplerate, ) - spec, _ = au.generate_spectrogram( + spec = au.generate_spectrogram( audio, samplerate, params, @@ -240,7 +240,7 @@ def test_compute_max_power(): target_samp_rate=samplerate, ) - spec, _ = au.generate_spectrogram( + spec = au.generate_spectrogram( audio, samplerate, params,
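
Reviewer note (not part of the patch): the hunks above, including the test updates, change au.generate_spectrogram from a params-dict API returning (spec, spec_for_viz) to explicit keyword arguments returning a single array. Below is a minimal migration sketch for callers; it assumes the patched batdetect2.utils.audio_utils module, and the numeric settings are illustrative placeholders rather than project defaults.

    import numpy as np

    from batdetect2.utils import audio_utils as au

    sampling_rate = 256_000
    audio = np.zeros(sampling_rate, dtype=np.float32)  # one second of silence

    # Old: spec, spec_for_viz = au.generate_spectrogram(audio, sampling_rate, params)
    # New: the former params entries become explicit keyword arguments and a
    # single spectrogram array is returned.
    spec = au.generate_spectrogram(
        audio,
        sampling_rate,
        fft_win_length=0.002,
        fft_overlap=0.75,
        max_freq=120_000,
        min_freq=10_000,
        spec_scale="log",
        denoise_spec_avg=False,
        max_scale_spec=False,
    )

load_set_of_anns follows the same pattern in this patch: callers now pass classes_to_ignore, events_of_interest and convert_to_genus as keyword arguments, and the class names and inverse class frequencies are obtained separately via get_class_names rather than as extra return values.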