diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 0000000..b20ceed
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,5 @@
+[TYPECHECK]
+
+# List of members which are set dynamically and missed by Pylint inference
+# system, and so shouldn't trigger E1101 when accessed.
+generated-members=torch.*
diff --git a/app.py b/app.py
index 5e11be2..dd35265 100644
--- a/app.py
+++ b/app.py
@@ -9,7 +9,7 @@ import bat_detect.utils.plot_utils as viz
 
 # setup the arguments
 args = {}
-args = du.get_default_bd_args()
+args = du.get_default_run_config()
 args["detection_threshold"] = 0.3
 args["time_expansion_factor"] = 1
 args["model_path"] = "models/Net2DFast_UK_same.pth.tar"
diff --git a/bat_detect/detector/parameters.py b/bat_detect/detector/parameters.py
index a1fe9c7..bb705dd 100644
--- a/bat_detect/detector/parameters.py
+++ b/bat_detect/detector/parameters.py
@@ -1,7 +1,18 @@
 import datetime
 import os
 
-import numpy as np
+TARGET_SAMPLERATE_HZ = 256000
+FFT_WIN_LENGTH_S = 512 / 256000.0
+FFT_OVERLAP = 0.75
+MAX_FREQ_HZ = 120000
+MIN_FREQ_HZ = 10000
+RESIZE_FACTOR = 0.5
+SPEC_DIVIDE_FACTOR = 32
+SPEC_HEIGHT = 256
+SCALE_RAW_AUDIO = False
+DETECTION_THRESHOLD = 0.01
+NMS_KERNEL_SIZE = 9
+NMS_TOP_K_PER_SEC = 200
 
 
 def mk_dir(path):
@@ -30,35 +41,39 @@
 
     # spec parameters
     params[
         "target_samp_rate"
-    ] = 256000  # resamples all audio so that it is at this rate
-    params["fft_win_length"] = (
-        512 / 256000.0
-    )  # in milliseconds, amount of time per stft time step
-    params["fft_overlap"] = 0.75  # stft window overlap
+    ] = TARGET_SAMPLERATE_HZ  # resamples all audio so that it is at this rate
+    params[
+        "fft_win_length"
+    ] = FFT_WIN_LENGTH_S  # in seconds, amount of time per stft time step
+    params["fft_overlap"] = FFT_OVERLAP  # stft window overlap
     params[
         "max_freq"
-    ] = 120000  # in Hz, everything above this will be discarded
-    params["min_freq"] = 10000  # in Hz, everything below this will be discarded
+    ] = MAX_FREQ_HZ  # in Hz, everything above this will be discarded
+    params[
+        "min_freq"
+    ] = MIN_FREQ_HZ  # in Hz, everything below this will be discarded
     params[
         "resize_factor"
-    ] = 0.5  # resize so the spectrogram at the input of the network
+    ] = RESIZE_FACTOR  # resize so the spectrogram at the input of the network
     params[
         "spec_height"
-    ] = 256  # units are number of frequency bins (before resizing is performed)
+    ] = SPEC_HEIGHT  # units are number of frequency bins (before resizing is performed)
     params[
         "spec_train_width"
     ] = 512  # units are number of time steps (before resizing is performed)
     params[
         "spec_divide_factor"
-    ] = 32  # spectrogram should be divisible by this amount in width and height
+    ] = SPEC_DIVIDE_FACTOR  # spectrogram should be divisible by this amount in width and height
 
     # spec processing params
     params[
         "denoise_spec_avg"
     ] = True  # removes the mean for each frequency band
-    params["scale_raw_audio"] = False  # scales the raw audio to [-1, 1]
+    params[
+        "scale_raw_audio"
+    ] = SCALE_RAW_AUDIO  # scales the raw audio to [-1, 1]
     params[
         "max_scale_spec"
     ] = False  # scales the spectrogram so that it is max 1
@@ -73,11 +88,13 @@
     ] = 0.01  # if start of GT calls are within this time from the start/end of file ignore
     params[
         "detection_threshold"
-    ] = 0.01  # the smaller this is the better the recall will be
-    params["nms_kernel_size"] = 9
+    ] = DETECTION_THRESHOLD  # the smaller this is the better the recall will be
+    params[
+        "nms_kernel_size"
+    ] = NMS_KERNEL_SIZE  # size of the kernel for non-max suppression
     params[
         "nms_top_k_per_sec"
-    ] = 200  # keep top K highest predictions per second of audio
+    ] = NMS_TOP_K_PER_SEC  # keep top K highest predictions per second of audio
     params["target_sigma"] = 2.0
 
     # augmentation params
diff --git a/bat_detect/detector/post_process.py b/bat_detect/detector/post_process.py
index 05fabfc..fbbd410 100644
--- a/bat_detect/detector/post_process.py
+++ b/bat_detect/detector/post_process.py
@@ -1,3 +1,6 @@
+"""Post-processing of the output of the model."""
+from typing import List, Optional, Tuple
+
 import numpy as np
 import torch
 from torch import nn
@@ -10,11 +13,26 @@ except ImportError:
 np.seterr(divide="ignore", invalid="ignore")
 
 
-def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
+def x_coords_to_time(
+    x_pos: float,
+    sampling_rate: int,
+    fft_win_length: float,
+    fft_overlap: float,
+) -> float:
+    """Convert x coordinates of spectrogram to time in seconds.
+
+    Args:
+        x_pos: X position of the detection in pixels.
+        sampling_rate: Sampling rate of the audio in Hz.
+        fft_win_length: Length of the FFT window in seconds.
+        fft_overlap: Fraction of overlap between consecutive FFT windows.
+
+    Returns:
+        Time in seconds.
+    """
     nfft = int(fft_win_length * sampling_rate)
     noverlap = int(fft_overlap * nfft)
     return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
-    # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5)  # 0.5 is for center of temporal window
 
 
 def overall_class_pred(det_prob, class_prob):
@@ -28,10 +46,10 @@ class NonMaximumSuppressionConfig(TypedDict):
     nms_kernel_size: int
     """Size of the kernel for non-maximum suppression."""
 
-    max_freq: float
+    max_freq: int
     """Maximum frequency to consider in Hz."""
 
-    min_freq: float
+    min_freq: int
     """Minimum frequency to consider in Hz."""
 
     fft_win_length: float
@@ -40,6 +58,9 @@
     fft_overlap: float
     """Overlap of the FFT windows in seconds."""
 
+    resize_factor: float
+    """Factor by which the input was resized."""
+
     nms_top_k_per_sec: float
     """Number of top detections to keep per second."""
 
@@ -47,8 +68,73 @@
     """Threshold for detection probability."""
 
 
-def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
-    """Run non-maximum suppression on the output of the model."""
+class PredictionResults(TypedDict):
+    """Results of the prediction.
+
+    Each key maps to an array of length `num_detections` containing the
+    corresponding values for each detection.
+ """ + + det_probs: np.ndarray + """Detection probabilities.""" + + x_pos: np.ndarray + """X position of the detection in pixels.""" + + y_pos: np.ndarray + """Y position of the detection in pixels.""" + + bb_width: np.ndarray + """Width of the detection in pixels.""" + + bb_height: np.ndarray + """Height of the detection in pixels.""" + + start_times: np.ndarray + """Start times of the detections in seconds.""" + + end_times: np.ndarray + """End times of the detections in seconds.""" + + low_freqs: np.ndarray + """Low frequencies of the detections in Hz.""" + + high_freqs: np.ndarray + """High frequencies of the detections in Hz.""" + + class_probs: Optional[np.ndarray] + """Class probabilities.""" + + +class ModelOutputs(TypedDict): + """Outputs of the model.""" + + pred_det: torch.Tensor + """Detection probabilities.""" + + pred_size: torch.Tensor + """Box sizes.""" + + pred_class: Optional[torch.Tensor] + """Class probabilities.""" + + features: Optional[torch.Tensor] + """Features extracted by the model.""" + + +def run_nms( + outputs: ModelOutputs, + params: NonMaximumSuppressionConfig, + sampling_rate: np.ndarray, +) -> Tuple[List[PredictionResults], List[np.ndarray]]: + """Run non-maximum suppression on the output of the model. + + Model outputs processed are expected to have a batch dimension. + Each element of the batch is processed independently. The + result is a pair of lists, one for the predictions and one for + the features. Each element of the lists corresponds to one + element of the batch. + """ pred_det = outputs["pred_det"] # probability of box pred_size = outputs["pred_size"] # box size @@ -62,7 +148,7 @@ def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int): # as we are choosing the same sampling rate for the entire batch duration = x_coords_to_time( pred_det.shape[-1], - sampling_rate[0].item(), + int(sampling_rate[0].item()), params["fft_win_length"], params["fft_overlap"], ) @@ -70,58 +156,72 @@ def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int): scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k) # loop over batch to save outputs - preds = [] - feats = [] - for ii in range(pred_det_nms.shape[0]): + preds: List[PredictionResults] = [] + feats: List[np.ndarray] = [] + for num_detection in range(pred_det_nms.shape[0]): # get valid indices - inds_ord = torch.argsort(x_pos[ii, :]) - valid_inds = scores[ii, inds_ord] > params["detection_threshold"] + inds_ord = torch.argsort(x_pos[num_detection, :]) + valid_inds = ( + scores[num_detection, inds_ord] > params["detection_threshold"] + ) valid_inds = inds_ord[valid_inds] # create result dictionary pred = {} - pred["det_probs"] = scores[ii, valid_inds] - pred["x_pos"] = x_pos[ii, valid_inds] - pred["y_pos"] = y_pos[ii, valid_inds] - pred["bb_width"] = pred_size[ii, 0, pred["y_pos"], pred["x_pos"]] - pred["bb_height"] = pred_size[ii, 1, pred["y_pos"], pred["x_pos"]] + pred["det_probs"] = scores[num_detection, valid_inds] + pred["x_pos"] = x_pos[num_detection, valid_inds] + pred["y_pos"] = y_pos[num_detection, valid_inds] + pred["bb_width"] = pred_size[ + num_detection, 0, pred["y_pos"], pred["x_pos"] + ] + pred["bb_height"] = pred_size[ + num_detection, 1, pred["y_pos"], pred["x_pos"] + ] pred["start_times"] = x_coords_to_time( pred["x_pos"].float() / params["resize_factor"], - sampling_rate[ii].item(), + int(sampling_rate[num_detection].item()), params["fft_win_length"], params["fft_overlap"], ) pred["end_times"] = x_coords_to_time( (pred["x_pos"].float() + 
pred["bb_width"]) / params["resize_factor"], - sampling_rate[ii].item(), + int(sampling_rate[num_detection].item()), params["fft_win_length"], params["fft_overlap"], ) pred["low_freqs"] = ( - pred_size[ii].shape[1] - pred["y_pos"].float() + pred_size[num_detection].shape[1] - pred["y_pos"].float() ) * freq_rescale + params["min_freq"] pred["high_freqs"] = ( pred["low_freqs"] + pred["bb_height"] * freq_rescale ) # extract the per class votes - if "pred_class" in outputs: - pred["class_probs"] = outputs["pred_class"][ - ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds] + pred_class = outputs.get("pred_class") + if pred_class is not None: + pred["class_probs"] = pred_class[ + num_detection, + :, + y_pos[num_detection, valid_inds], + x_pos[num_detection, valid_inds], ] # extract the model features - if "features" in outputs: - feat = outputs["features"][ - ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds] + features = outputs.get("features") + if features is not None: + feat = features[ + num_detection, + :, + y_pos[num_detection, valid_inds], + x_pos[num_detection, valid_inds], ].transpose(0, 1) feat = feat.cpu().numpy().astype(np.float32) feats.append(feat) # convert to numpy - for kk in pred.keys(): - pred[kk] = pred[kk].cpu().numpy().astype(np.float32) + for key, value in pred.items(): + pred[key] = value.cpu().numpy().astype(np.float32) preds.append(pred) @@ -130,7 +230,7 @@ def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int): def non_max_suppression(heat, kernel_size): # kernel can be an int or list/tuple - if type(kernel_size) is int: + if isinstance(kernel_size, int): kernel_size_h = kernel_size kernel_size_w = kernel_size diff --git a/bat_detect/evaluate/evaluate_models.py b/bat_detect/evaluate/evaluate_models.py index 8ee3282..bf70f15 100644 --- a/bat_detect/evaluate/evaluate_models.py +++ b/bat_detect/evaluate/evaluate_models.py @@ -739,7 +739,7 @@ if __name__ == "__main__": # if args["bd_model_path"] != "": # load model - bd_args = du.get_default_bd_args() + bd_args = du.get_default_run_config() model, params_bd = du.load_model(args["bd_model_path"]) # check if the class names are the same diff --git a/bat_detect/train/audio_dataloader.py b/bat_detect/train/audio_dataloader.py index 70ba5b8..f7790a6 100644 --- a/bat_detect/train/audio_dataloader.py +++ b/bat_detect/train/audio_dataloader.py @@ -1,7 +1,4 @@ import copy -import os -import random -import sys import librosa import numpy as np @@ -9,7 +6,6 @@ import torch import torch.nn.functional as F import torchaudio -sys.path.append(os.path.join("..", "..")) import bat_detect.utils.audio_utils as au @@ -218,7 +214,10 @@ def resample_aug(audio, sampling_rate, params): sampling_rate_old = sampling_rate sampling_rate = np.random.choice(params["aug_sampling_rates"]) audio = librosa.resample( - audio, sampling_rate_old, sampling_rate, res_type="polyphase" + audio, + orig_sr=sampling_rate_old, + target_sr=sampling_rate, + res_type="polyphase", ) audio = au.pad_audio( @@ -237,7 +236,10 @@ def resample_aug(audio, sampling_rate, params): def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2): if sampling_rate != sampling_rate2: audio2 = librosa.resample( - audio2, sampling_rate2, sampling_rate, res_type="polyphase" + audio2, + orig_sr=sampling_rate2, + target_sr=sampling_rate, + res_type="polyphase", ) sampling_rate2 = sampling_rate if audio2.shape[0] < num_samples: diff --git a/bat_detect/train/train_model.py b/bat_detect/train/train_model.py index f7504b0..3619576 100644 --- 
a/bat_detect/train/train_model.py +++ b/bat_detect/train/train_model.py @@ -553,5 +553,6 @@ if __name__ == "__main__": torch.save(op_state, params["model_file_name"]) # save an image with associated prediction for each batch in the test set - if not args["do_not_save_images"]: - save_images_batch(model, test_loader, params) + # TODO: args variable does not exist + # if not args["do_not_save_images"]: + # save_images_batch(model, test_loader, params) diff --git a/bat_detect/utils/audio_utils.py b/bat_detect/utils/audio_utils.py index 1a62f0c..23bfc2c 100644 --- a/bat_detect/utils/audio_utils.py +++ b/bat_detect/utils/audio_utils.py @@ -7,7 +7,6 @@ import torch from . import wavfile - __all__ = [ "load_audio_file", ] @@ -163,9 +162,11 @@ def load_audio_file( # clipping maximum duration if max_duration is not None: - max_duration = np.minimum( - int(sampling_rate * max_duration), - audio_raw.shape[0], + max_duration = int( + np.minimum( + int(sampling_rate * max_duration), + audio_raw.shape[0], + ) ) audio_raw = audio_raw[:max_duration] diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py index aec742c..448cba2 100644 --- a/bat_detect/utils/detector_utils.py +++ b/bat_detect/utils/detector_utils.py @@ -11,6 +11,20 @@ import bat_detect.detector.compute_features as feats import bat_detect.detector.post_process as pp import bat_detect.utils.audio_utils as au from bat_detect.detector import models +from bat_detect.detector.parameters import ( + DETECTION_THRESHOLD, + FFT_OVERLAP, + FFT_WIN_LENGTH_S, + MAX_FREQ_HZ, + MIN_FREQ_HZ, + NMS_KERNEL_SIZE, + NMS_TOP_K_PER_SEC, + RESIZE_FACTOR, + SCALE_RAW_AUDIO, + SPEC_DIVIDE_FACTOR, + SPEC_HEIGHT, + TARGET_SAMPLERATE_HZ, +) try: from typing import TypedDict @@ -24,23 +38,17 @@ DEFAULT_MODEL_PATH = os.path.join( "model.pth", ) -__all__ = ["load_model", "get_audio_files", "DEFAULT_MODEL_PATH"] - - -def get_default_bd_args(): - args = {} - args["detection_threshold"] = 0.001 - args["time_expansion_factor"] = 1 - args["audio_dir"] = "" - args["ann_dir"] = "" - args["spec_slices"] = False - args["chunk_size"] = 3 - args["spec_features"] = False - args["cnn_features"] = False - args["quiet"] = True - args["save_preds_if_empty"] = True - args["ann_dir"] = os.path.join(args["ann_dir"], "") - return args +__all__ = [ + "load_model", + "get_audio_files", + "format_results", + "save_results_to_file", + "iterate_over_chunks", + "process_spectrogram", + "process_audio_array", + "process_file", + "DEFAULT_MODEL_PATH", +] def get_audio_files(ip_dir: str) -> List[str]: @@ -80,7 +88,7 @@ class ModelParameters(TypedDict): ip_height: int """Input height in pixels.""" - resize_factor: int + resize_factor: float """Resize factor.""" class_names: List[str] @@ -118,6 +126,8 @@ def load_model( params = net_params["params"] params["device"] = device + model: torch.nn.Module + if params["model_name"] == "Net2DFast": model = models.Net2DFast( params["num_filters"], @@ -159,9 +169,9 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices): num_preds = np.sum([len(pp["det_probs"]) for pp in predictions]) if num_preds > 0: - for kk in predictions[0].keys(): - predictions_m[kk] = np.hstack( - [pp[kk] for pp in predictions if pp["det_probs"].shape[0] > 0] + for key in predictions[0].keys(): + predictions_m[key] = np.hstack( + [pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0] ) else: # hack in case where no detected calls as we need some of the key names in dict @@ -176,7 +186,10 @@ def _merge_results(predictions, spec_feats, 
cnn_feats, spec_slices): return predictions_m, spec_feats, cnn_feats, spec_slices -class Annotation(TypedDict("WithClass", {"class": str})): +DictWithClass = TypedDict("DictWithClass", {"class": str}) + + +class Annotation(DictWithClass): """Format of annotations. This is the format of a single annotation as expected by the annotation @@ -214,7 +227,7 @@ class FileAnnotations(TypedDict): This is the format of the results expected by the annotation tool. """ - file_id: str + id: str """File ID.""" annotated: bool @@ -232,26 +245,32 @@ class FileAnnotations(TypedDict): class_name: str """Class predicted at file level""" + notes: str + """Notes of file.""" + annotation: List[Annotation] + """List of annotations.""" -class Results(TypedDict): +class RunResults(TypedDict): + """Run results.""" + pred_dict: FileAnnotations """Predictions in the format expected by the annotation tool.""" - spec_feats: Optional[np.ndarray] + spec_feats: Optional[List[np.ndarray]] """Spectrogram features.""" spec_feat_names: Optional[List[str]] """Spectrogram feature names.""" - cnn_feats: Optional[np.ndarray] + cnn_feats: Optional[List[np.ndarray]] """CNN features.""" cnn_feat_names: Optional[List[str]] """CNN feature names.""" - spec_slices: Optional[np.ndarray] + spec_slices: Optional[List[np.ndarray]] """Spectrogram slices.""" @@ -343,7 +362,7 @@ def convert_results( spec_feats, cnn_feats, spec_slices, -) -> Results: +) -> RunResults: """Convert results to dictionary as expected by the annotation tool. Args: @@ -369,8 +388,14 @@ def convert_results( ) # combine into final results dictionary - results = {} - results["pred_dict"] = pred_dict + results: RunResults = { + "pred_dict": pred_dict, + "spec_feats": None, + "spec_feat_names": None, + "cnn_feats": None, + "cnn_feat_names": None, + "spec_slices": None, + } # add spectrogram features if they exist if len(spec_feats) > 0: @@ -463,19 +488,16 @@ def save_results_to_file(results, op_path: str) -> None: class SpectrogramParameters(TypedDict): """Parameters for generating spectrograms.""" - fft_win_length: int - """Length of the FFT window in samples.""" + fft_win_length: float + """Length of the FFT window in seconds.""" - fft_overlap: int - """Number of samples to overlap between FFT windows.""" + fft_overlap: float + """Percentage of overlap between FFT windows.""" spec_height: int """Height of the spectrogram in pixels.""" - spec_width: int - """Width of the spectrogram in pixels.""" - - resize_factor: int + resize_factor: float """Factor to resize the spectrogram by.""" spec_divide_factor: int @@ -605,13 +627,14 @@ class ProcessingConfiguration(TypedDict): fft_win_length: float """Length of the FFT window in seconds.""" + fft_overlap: float """Length of the FFT window in samples.""" resize_factor: float """Factor to resize the spectrogram by.""" - spec_divide_factor: float + spec_divide_factor: int """Factor to divide the spectrogram by.""" spec_height: int @@ -644,27 +667,36 @@ class ProcessingConfiguration(TypedDict): nms_kernel_size: int """Size of the kernel for non-maximum suppression.""" - max_freq: float + max_freq: int """Maximum frequency to consider in Hz.""" - min_freq: float + min_freq: int """Minimum frequency to consider in Hz.""" nms_top_k_per_sec: float """Number of top detections to keep per second.""" - detection_threshold: float - """Threshold for detection probability.""" - quiet: bool """Whether to suppress output.""" + chunk_size: float + """Size of chunks to process in seconds.""" + + cnn_features: bool + """Whether to return CNN 
features.""" + + spec_features: bool + """Whether to return spectrogram features.""" + + spec_slices: bool + """Whether to return spectrogram slices.""" + def process_spectrogram( spec: torch.Tensor, samplerate: int, model: torch.nn.Module, - config: pp.NonMaximumSuppressionConfig, + config: ProcessingConfiguration, ): """Process a spectrogram with detection model. @@ -692,17 +724,29 @@ def process_spectrogram( outputs = model(spec, return_feats=config["cnn_features"]) # run non-max suppression - pred_nms, features = pp.run_nms( + pred_nms_list, features = pp.run_nms( outputs, - config, + { + "nms_kernel_size": config["nms_kernel_size"], + "max_freq": config["max_freq"], + "min_freq": config["min_freq"], + "fft_win_length": config["fft_win_length"], + "fft_overlap": config["fft_overlap"], + "resize_factor": config["resize_factor"], + "nms_top_k_per_sec": config["nms_top_k_per_sec"], + "detection_threshold": config["detection_threshold"], + }, np.array([float(samplerate)]), ) - pred_nms = pred_nms[0] + pred_nms = pred_nms_list[0] # if we have a background class - if pred_nms["class_probs"].shape[0] > len(config["class_names"]): - pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :] + class_probs = pred_nms.get("class_probs") + if (class_probs is not None) and ( + class_probs.shape[0] > len(config["class_names"]) + ): + pred_nms["class_probs"] = class_probs[:-1, :] return pred_nms, features @@ -737,7 +781,14 @@ def process_audio_array( _, spec, spec_np = compute_spectrogram( audio, sampling_rate, - config, + { + "fft_win_length": config["fft_win_length"], + "fft_overlap": config["fft_overlap"], + "spec_height": config["spec_height"], + "resize_factor": config["resize_factor"], + "spec_divide_factor": config["spec_divide_factor"], + "device": config["device"], + }, return_np=config["spec_features"] or config["spec_slices"], ) @@ -756,7 +807,7 @@ def process_file( audio_file: str, model: torch.nn.Module, config: ProcessingConfiguration, -) -> Union[Results, Any]: +) -> Union[RunResults, Any]: """Process a single audio file with detection model. 
     Will split the audio file into chunks if it is too long and
@@ -788,7 +839,7 @@
     # load audio file
     sampling_rate, audio_full = au.load_audio_file(
         audio_file,
-        time_exp_fact=config["time_expansion"],
+        time_exp_fact=config.get("time_expansion", 1) or 1,
         target_samp_rate=config["target_samp_rate"],
         scale=config["scale_raw_audio"],
         max_duration=config["max_duration"],
@@ -840,7 +891,7 @@
     # convert results to a dictionary in the right format
     results = convert_results(
         file_id=os.path.basename(audio_file),
-        time_exp=config["time_expansion"],
+        time_exp=config.get("time_expansion", 1) or 1,
         duration=audio_full.shape[0] / float(sampling_rate),
         params=config,
         predictions=predictions,
@@ -877,3 +928,38 @@
             config["class_names"][class_index].ljust(30)
             + str(round(class_overall[class_index], 3))
         )
+
+
+def get_default_run_config(**kwargs) -> ProcessingConfiguration:
+    """Get default configuration for running the detection model."""
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    args: ProcessingConfiguration = {
+        "detection_threshold": DETECTION_THRESHOLD,
+        "spec_slices": False,
+        "chunk_size": 3,
+        "spec_features": False,
+        "cnn_features": False,
+        "quiet": True,
+        "target_samp_rate": TARGET_SAMPLERATE_HZ,
+        "fft_win_length": FFT_WIN_LENGTH_S,
+        "fft_overlap": FFT_OVERLAP,
+        "resize_factor": RESIZE_FACTOR,
+        "spec_divide_factor": SPEC_DIVIDE_FACTOR,
+        "spec_height": SPEC_HEIGHT,
+        "scale_raw_audio": SCALE_RAW_AUDIO,
+        "device": device,
+        "class_names": [],
+        "time_expansion": 1,
+        "top_n": 3,
+        "return_raw_preds": False,
+        "max_duration": None,
+        "nms_kernel_size": NMS_KERNEL_SIZE,
+        "max_freq": MAX_FREQ_HZ,
+        "min_freq": MIN_FREQ_HZ,
+        "nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
+    }
+    return {
+        **args,
+        **kwargs,
+    }
diff --git a/bat_detect/utils/plot_utils.py b/bat_detect/utils/plot_utils.py
index afbbc5f..6d732ec 100644
--- a/bat_detect/utils/plot_utils.py
+++ b/bat_detect/utils/plot_utils.py
@@ -523,7 +523,7 @@ class LossPlotter(object):
     def save_confusion_matrix(self, gt, pred):
         plt.figure(0)
         cm = confusion_matrix(
-            gt, pred, np.arange(len(self.class_names))
+            gt, pred, labels=np.arange(len(self.class_names))
         ).astype(np.float32)
         cm_norm = cm.sum(1)
         valid_inds = np.where(cm_norm > 0)[0]
diff --git a/pyproject.toml b/pyproject.toml
index 2c58647..3be8e8c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,3 +49,10 @@ batdetect2 = "bat_detect.command:main"
 
 [tool.black]
 line-length = 80
+
+[[tool.mypy.overrides]]
+module = [
+    "librosa",
+    "pandas",
+]
+ignore_missing_imports = true
diff --git a/scripts/gen_spec_image.py b/scripts/gen_spec_image.py
index 8269f78..a7cb0a6 100644
--- a/scripts/gen_spec_image.py
+++ b/scripts/gen_spec_image.py
@@ -86,7 +86,7 @@ if __name__ == "__main__":
     args_cmd = vars(parser.parse_args())
 
     # load the model
-    bd_args = du.get_default_bd_args()
+    bd_args = du.get_default_run_config()
     model, params_bd = du.load_model(args_cmd["model_path"])
     bd_args["detection_threshold"] = args_cmd["detection_threshold"]
     bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
diff --git a/scripts/gen_spec_video.py b/scripts/gen_spec_video.py
index cbb443c..f7185a0 100644
--- a/scripts/gen_spec_video.py
+++ b/scripts/gen_spec_video.py
@@ -89,7 +89,7 @@ if __name__ == "__main__":
         os.makedirs(op_dir)
 
     params = parameters.get_params(False)
-    args = du.get_default_bd_args()
+    args = du.get_default_run_config()
    args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
    args["detection_threshold"] = args_cmd["detection_threshold"]
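
Usage note: with get_default_bd_args replaced by get_default_run_config, callers now receive a fully populated ProcessingConfiguration and override fields through keyword arguments. A minimal sketch of the intended call pattern, using only names introduced in this patch; the audio path is a placeholder, not part of the change:

    import bat_detect.utils.detector_utils as du

    # load the bundled model together with its training-time parameters
    model, params = du.load_model(du.DEFAULT_MODEL_PATH)

    # build a run config, overriding selected defaults via kwargs
    config = du.get_default_run_config(
        detection_threshold=0.3,
        class_names=params["class_names"],
    )

    # process one recording; results["pred_dict"] holds the annotations
    # in the format expected by the annotation tool
    results = du.process_file("example.wav", model, config)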
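The constants hoisted into parameters.py also make the x_coords_to_time conversion in post_process.py easy to sanity-check by hand. A quick arithmetic check, assuming the default constants above:

    from bat_detect.detector.parameters import (
        FFT_OVERLAP,
        FFT_WIN_LENGTH_S,
        TARGET_SAMPLERATE_HZ,
    )

    nfft = int(FFT_WIN_LENGTH_S * TARGET_SAMPLERATE_HZ)  # 512 samples
    noverlap = int(FFT_OVERLAP * nfft)  # 384 samples
    hop = nfft - noverlap  # 128 samples, i.e. 0.5 ms at 256 kHz

    # x_coords_to_time(x, TARGET_SAMPLERATE_HZ, FFT_WIN_LENGTH_S, FFT_OVERLAP)
    # therefore reduces to (128 * x + 384) / 256000 seconds, so each
    # spectrogram column spans half a millisecond of audio.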