From 8da98b5258d088ba6d8f6d2c6676e9a0c2ec515a Mon Sep 17 00:00:00 2001
From: Santiago Martinez
Date: Wed, 22 Feb 2023 20:24:43 +0000
Subject: [PATCH] Refactored detector_utils module

---
 app.py                                 |  13 +-
 bat_detect/command.py                  |  25 +-
 bat_detect/detector/post_process.py    |  37 +-
 bat_detect/evaluate/evaluate_models.py |  14 +-
 bat_detect/utils/audio_utils.py        |  65 ++-
 bat_detect/utils/detector_utils.py     | 734 ++++++++++++++++++++-----
 run_batdetect.py                       |   1 -
 scripts/gen_spec_image.py              |   7 +-
 scripts/gen_spec_video.py              |   7 +-
 9 files changed, 716 insertions(+), 187 deletions(-)

diff --git a/app.py b/app.py
index 1c884f0..5e11be2 100644
--- a/app.py
+++ b/app.py
@@ -1,5 +1,3 @@
-import os
-
 import gradio as gr
 import matplotlib.pyplot as plt
 import numpy as np
@@ -37,7 +35,6 @@ examples = [
 
 
 def make_prediction(file_name=None, detection_threshold=0.3):
-
     if file_name is not None:
         audio_file = file_name
     else:
@@ -46,9 +43,17 @@ def make_prediction(file_name=None, detection_threshold=0.3):
     if detection_threshold is not None and detection_threshold != "":
         args["detection_threshold"] = float(detection_threshold)
 
+    run_config = {
+        **params,
+        **args,
+        "max_duration": max_duration,
+    }
+
     # process the file to generate predictions
     results = du.process_file(
-        audio_file, model, params, args, max_duration=max_duration
+        audio_file,
+        model,
+        run_config,
     )
 
     anns = [ann for ann in results["pred_dict"]["annotation"]]

diff --git a/bat_detect/command.py b/bat_detect/command.py
index 9996832..35ea257 100644
--- a/bat_detect/command.py
+++ b/bat_detect/command.py
@@ -1,3 +1,9 @@
+"""Main script for running BatDetect2 on audio files.
+
+Example usage:
+    python command.py /path/to/audio/ /path/to/ann/ 0.1
+
+"""
 import argparse
 import os
 
@@ -7,7 +13,6 @@ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 def parse_args():
-
     info_str = (
         "\nBatDetect2 - Detection and Classification\n"
         + "  Assumes audio files are mono, not stereo.\n"
@@ -88,14 +93,22 @@ def main():
     print("\nInput directory: " + args["audio_dir"])
     files = du.get_audio_files(args["audio_dir"])
-    print("Number of audio files: {}".format(len(files)))
+
+    print(f"Number of audio files: {len(files)}")
     print("\nSaving results to: " + args["ann_dir"])
 
+    # set up run config
+    run_config = {
+        **args,
+        **params,
+    }
+
     # process files
     error_files = []
-    for ii, audio_file in enumerate(files):
+    for audio_file in files:
         try:
-            results = du.process_file(audio_file, model, params, args)
+            results = du.process_file(audio_file, model, run_config)
+
             if args["save_preds_if_empty"] or (
                 len(results["pred_dict"]["annotation"]) > 0
             ):
                 results_path = audio_file.replace(
                     args["audio_dir"], args["ann_dir"]
                 )
                 du.save_results_to_file(results, results_path)
-        except:
+        except (RuntimeError, ValueError, LookupError) as err:
             error_files.append(audio_file)
-            print("Error processing file!")
+            print(f"Error processing file: {err}")
 
     print("\nResults saved to: " + args["ann_dir"])

diff --git a/bat_detect/detector/post_process.py b/bat_detect/detector/post_process.py
index 2745cdf..05fabfc 100644
--- a/bat_detect/detector/post_process.py
+++ b/bat_detect/detector/post_process.py
@@ -1,7 +1,11 @@
 import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn
+
+try:
+    from typing import TypedDict
+except ImportError:
+    from typing_extensions import TypedDict
 
 np.seterr(divide="ignore", invalid="ignore")
 
@@ -18,7 +22,33 @@ def overall_class_pred(det_prob, class_prob):
     return weighted_pred / weighted_pred.sum()
 
 
-def run_nms(outputs, params, sampling_rate):
+class NonMaximumSuppressionConfig(TypedDict):
+    """Configuration for non-maximum suppression."""
+
+    nms_kernel_size: int
+    """Size of the kernel for non-maximum suppression."""
+
+    max_freq: float
+    """Maximum frequency to consider in Hz."""
+
+    min_freq: float
+    """Minimum frequency to consider in Hz."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Overlap of the FFT windows as a fraction of the window length."""
+
+    nms_top_k_per_sec: float
+    """Number of top detections to keep per second."""
+
+    detection_threshold: float
+    """Threshold for detection probability."""
+
+
+def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
+    """Run non-maximum suppression on the output of the model."""
     pred_det = outputs["pred_det"]  # probability of box
     pred_size = outputs["pred_size"]  # box size
 
@@ -92,6 +122,7 @@ def run_nms(outputs, params, sampling_rate):
     # convert to numpy
     for kk in pred.keys():
         pred[kk] = pred[kk].cpu().numpy().astype(np.float32)
+
     preds.append(pred)
 
     return preds, feats

diff --git a/bat_detect/evaluate/evaluate_models.py b/bat_detect/evaluate/evaluate_models.py
index e7ce249..8ee3282 100644
--- a/bat_detect/evaluate/evaluate_models.py
+++ b/bat_detect/evaluate/evaluate_models.py
@@ -6,14 +6,12 @@ import argparse
 import copy
 import json
 import os
-import sys
 
 import numpy as np
 import pandas as pd
 from sklearn.ensemble import RandomForestClassifier
 
-sys.path.append("../../")
-import bat_detect.detector.parameters as parameters
+from bat_detect.detector import parameters
 import bat_detect.train.evaluate as evl
 import bat_detect.train.train_utils as tu
 import bat_detect.utils.detector_utils as du
@@ -749,14 +747,18 @@ if __name__ == "__main__":
             print("Warning: Class names are not the same as the trained model")
             assert False
 
+    run_config = {
+        **bd_args,
+        **params_bd,
+        "return_raw_preds": True,
+    }
+
     preds_bd = []
     for ii, gg in enumerate(gt_test):
         pred = du.process_file(
             gg["file_path"],
             model,
-            params_bd,
-            bd_args,
-            return_raw_preds=True,
+            run_config,
         )
         preds_bd.append(pred)

diff --git a/bat_detect/utils/audio_utils.py b/bat_detect/utils/audio_utils.py
index 318b790..1a62f0c 100644
--- a/bat_detect/utils/audio_utils.py
+++ b/bat_detect/utils/audio_utils.py
@@ -1,4 +1,5 @@
 import warnings
+from typing import Optional, Tuple
 
 import librosa
 import numpy as np
@@ -7,6 +8,11 @@ import torch
 
 from . import wavfile
 
+__all__ = [
+    "load_audio_file",
+]
+
+
 def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
     nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
     noverlap = np.floor(fft_overlap * nfft)
@@ -105,40 +111,65 @@ def generate_spectrogram(
 
 
 def load_audio_file(
-    audio_file,
-    time_exp_fact,
-    target_samp_rate,
-    scale=False,
-    max_duration=False,
+    audio_file: str,
+    time_exp_fact: float,
+    target_samp_rate: int,
+    scale: bool = False,
+    max_duration: Optional[float] = None,
 ):
+    """Load an audio file and resample it to the target sampling rate.
+
+    The audio is optionally scaled to [-1, 1] and clipped to the maximum
+    duration. Only mono files are supported.
+
+    Args:
+        audio_file (str): Path to the audio file.
+        time_exp_fact (float): Time expansion factor of the recording.
+        target_samp_rate (int): Target sampling rate.
+        scale (bool): Whether to scale the audio to [-1, 1].
+        max_duration (float): Maximum duration of the audio in seconds.
+
+    Returns:
+        sampling_rate: The sampling rate of the audio.
+        audio_raw: The audio signal in a numpy array.
+
+    Raises:
+        ValueError: If the audio file is stereo.
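+
+    Example:
+        A minimal usage sketch; the file name, expansion factor and
+        sampling rate are illustrative, not defaults::
+
+            sampling_rate, audio = load_audio_file(
+                "recording.wav",
+                time_exp_fact=1.0,
+                target_samp_rate=256000,
+            )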
+
+    """
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
 
         # sampling_rate, audio_raw = wavfile.read(audio_file)
-        audio_raw, sampling_rate = librosa.load(audio_file, sr=None)
+        audio_raw, sampling_rate = librosa.load(
+            audio_file,
+            sr=None,
+            dtype=np.float32,
+        )
 
         if len(audio_raw.shape) > 1:
-            raise Exception("Currently does not handle stereo files")
+            raise ValueError("Currently does not handle stereo files")
+
         sampling_rate = sampling_rate * time_exp_fact
 
         # resample - need to do this after correcting for time expansion
         sampling_rate_old = sampling_rate
         sampling_rate = target_samp_rate
-        audio_raw = librosa.resample(
-            audio_raw,
-            orig_sr=sampling_rate_old,
-            target_sr=sampling_rate,
-            res_type="polyphase",
-        )
+        if sampling_rate_old != sampling_rate:
+            audio_raw = librosa.resample(
+                audio_raw,
+                orig_sr=sampling_rate_old,
+                target_sr=sampling_rate,
+                res_type="polyphase",
+            )
 
         # clipping maximum duration
-        if max_duration is not False:
+        if max_duration is not None:
             max_duration = np.minimum(
-                int(sampling_rate * max_duration), audio_raw.shape[0]
+                int(sampling_rate * max_duration),
+                audio_raw.shape[0],
             )
             audio_raw = audio_raw[:max_duration]
 
-        # convert to float32 and scale
-        audio_raw = audio_raw.astype(np.float32)
+        # scale to [-1, 1]
         if scale:
             audio_raw = audio_raw - audio_raw.mean()
             audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)

diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py
index 3186ab6..aec742c 100644
--- a/bat_detect/utils/detector_utils.py
+++ b/bat_detect/utils/detector_utils.py
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import List, Tuple
+from typing import Any, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -24,7 +24,7 @@ DEFAULT_MODEL_PATH = os.path.join(
     "model.pth",
 )
 
-__all__ = ["load_model", "DEFAULT_MODEL_PATH"]
+__all__ = ["load_model", "get_audio_files", "DEFAULT_MODEL_PATH"]
 
 
 def get_default_bd_args():
@@ -69,16 +69,29 @@ class ModelParameters(TypedDict):
     """Model parameters."""
 
     model_name: str
+    """Model name."""
+
     num_filters: int
+    """Number of filters."""
+
     emb_dim: int
+    """Embedding dimension."""
+
     ip_height: int
+    """Input height in pixels."""
+
     resize_factor: int
+    """Resize factor."""
+
     class_names: List[str]
+    """Class names. The model is trained to detect these classes."""
+
     device: torch.device
 
 
 def load_model(
-    model_path: str = DEFAULT_MODEL_PATH, load_weights: bool = True
+    model_path: str = DEFAULT_MODEL_PATH,
+    load_weights: bool = True,
 ) -> Tuple[torch.nn.Module, ModelParameters]:
     """Load model from file.
 
@@ -141,8 +154,7 @@ def load_model(
     return model, params
 
 
-def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
-
+def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
     predictions_m = {}
 
     num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
 
@@ -157,82 +169,255 @@ def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
     if len(spec_feats) > 0:
         spec_feats = np.vstack(spec_feats)
+
     if len(cnn_feats) > 0:
         cnn_feats = np.vstack(cnn_feats)
+
     return predictions_m, spec_feats, cnn_feats, spec_slices
 
 
+class Annotation(TypedDict("WithClass", {"class": str})):
+    """Format of annotations.
+
+    This is the format of a single annotation as expected by the annotation
+    tool.
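+
+    Example (the values shown are illustrative)::
+
+        {
+            "class": "Myotis daubentonii",
+            "start_time": 1.2345,
+            "end_time": 1.2405,
+            "low_freq": 30000,
+            "high_freq": 80000,
+            "class_prob": 0.9,
+            "det_prob": 0.95,
+            "individual": "-1",
+            "event": "Echolocation",
+        }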
+    """
+
+    start_time: float
+    """Start time in seconds."""
+
+    end_time: float
+    """End time in seconds."""
+
+    low_freq: int
+    """Low frequency in Hz."""
+
+    high_freq: int
+    """High frequency in Hz."""
+
+    class_prob: float
+    """Probability of class assignment."""
+
+    det_prob: float
+    """Probability of detection."""
+
+    individual: str
+    """Individual ID."""
+
+    event: str
+    """Type of detected event."""
+
+
+class FileAnnotations(TypedDict):
+    """Format of results.
+
+    This is the format of the results expected by the annotation tool.
+    """
+
+    id: str
+    """File ID."""
+
+    annotated: bool
+    """Whether file has been annotated."""
+
+    duration: float
+    """Duration of audio file."""
+
+    issues: bool
+    """Whether file has issues."""
+
+    notes: str
+    """Notes about the file."""
+
+    time_exp: float
+    """Time expansion factor."""
+
+    class_name: str
+    """Class predicted at file level."""
+
+    annotation: List[Annotation]
+
+
+class Results(TypedDict):
+    pred_dict: FileAnnotations
+    """Predictions in the format expected by the annotation tool."""
+
+    spec_feats: Optional[np.ndarray]
+    """Spectrogram features."""
+
+    spec_feat_names: Optional[List[str]]
+    """Spectrogram feature names."""
+
+    cnn_feats: Optional[np.ndarray]
+    """CNN features."""
+
+    cnn_feat_names: Optional[List[str]]
+    """CNN feature names."""
+
+    spec_slices: Optional[np.ndarray]
+    """Spectrogram slices."""
+
+
+class ResultParams(TypedDict):
+    """Result parameters."""
+
+    class_names: List[str]
+    """Class names."""
+
+
+def format_results(
+    file_id: str,
+    time_exp: float,
+    duration: float,
+    predictions,
+    class_names: List[str],
+) -> FileAnnotations:
+    """Format results into the format expected by the annotation tool.
+
+    Args:
+        file_id (str): File ID.
+        time_exp (float): Time expansion factor.
+        duration (float): Duration of audio file.
+        predictions (dict): Predictions.
+        class_names (List[str]): Class names.
+
+    Returns:
+        dict: Results in the format expected by the annotation tool.
+    """
+    # Get a single class prediction for the file
+    class_overall = pp.overall_class_pred(
+        predictions["det_probs"],
+        predictions["class_probs"],
+    )
+
+    # Get the best class prediction probability and index for each detection
+    class_prob_best = predictions["class_probs"].max(0)
+    class_ind_best = predictions["class_probs"].argmax(0)
+
+    # Pack the results into a list of dictionaries
+    annotations: List[Annotation] = [
+        {
+            "start_time": round(float(start_time), 4),
+            "end_time": round(float(end_time), 4),
+            "low_freq": int(low_freq),
+            "high_freq": int(high_freq),
+            "class": str(class_names[class_index]),
+            "class_prob": round(float(class_prob), 3),
+            "det_prob": round(float(det_prob), 3),
+            "individual": "-1",
+            "event": "Echolocation",
+        }
+        for (
+            start_time,
+            end_time,
+            low_freq,
+            high_freq,
+            class_index,
+            class_prob,
+            det_prob,
+        ) in zip(
+            predictions["start_times"],
+            predictions["end_times"],
+            predictions["low_freqs"],
+            predictions["high_freqs"],
+            class_ind_best,
+            class_prob_best,
+            predictions["det_probs"],
+        )
+    ]
+
+    return {
+        "id": file_id,
+        "annotated": False,
+        "issues": False,
+        "notes": "Automatically generated.",
+        "time_exp": time_exp,
+        "duration": round(duration, 4),
+        "annotation": annotations,
+        "class_name": class_names[np.argmax(class_overall)],
+    }
+
+
 def convert_results(
-    file_id,
-    time_exp,
-    duration,
-    params,
+    file_id: str,
+    time_exp: float,
+    duration: float,
+    params: ResultParams,
     predictions,
     spec_feats,
     cnn_feats,
     spec_slices,
-):
+) -> Results:
+    """Convert results to dictionary as expected by the annotation tool.
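+
+    Combines the detections formatted by `format_results` with any
+    extracted features; the feature entries are only attached when the
+    corresponding inputs are non-empty.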
 
-    # create a single dictionary - this is the format used by the annotation tool
-    pred_dict = {}
-    pred_dict["id"] = file_id
-    pred_dict["annotated"] = False
-    pred_dict["issues"] = False
-    pred_dict["notes"] = "Automatically generated."
-    pred_dict["time_exp"] = time_exp
-    pred_dict["duration"] = round(duration, 4)
-    pred_dict["annotation"] = []
+    Args:
+        file_id (str): File ID.
+        time_exp (float): Time expansion factor.
+        duration (float): Duration of audio file.
+        params (dict): Model parameters.
+        predictions (dict): Predictions.
+        spec_feats (np.ndarray): Spectral features.
+        cnn_feats (np.ndarray): CNN features.
+        spec_slices (list): Spectrogram slices.
 
-    class_prob_best = predictions["class_probs"].max(0)
-    class_ind_best = predictions["class_probs"].argmax(0)
-    class_overall = pp.overall_class_pred(
-        predictions["det_probs"], predictions["class_probs"]
+    Returns:
+        dict: Dictionary with results.
+
+    """
+    pred_dict = format_results(
+        file_id,
+        time_exp,
+        duration,
+        predictions,
+        params["class_names"],
     )
-    pred_dict["class_name"] = params["class_names"][np.argmax(class_overall)]
-
-    for ii in range(predictions["det_probs"].shape[0]):
-        res = {}
-        res["start_time"] = round(float(predictions["start_times"][ii]), 4)
-        res["end_time"] = round(float(predictions["end_times"][ii]), 4)
-        res["low_freq"] = int(predictions["low_freqs"][ii])
-        res["high_freq"] = int(predictions["high_freqs"][ii])
-        res["class"] = str(params["class_names"][int(class_ind_best[ii])])
-        res["class_prob"] = round(float(class_prob_best[ii]), 3)
-        res["det_prob"] = round(float(predictions["det_probs"][ii]), 3)
-        res["individual"] = "-1"
-        res["event"] = "Echolocation"
-        pred_dict["annotation"].append(res)
 
     # combine into final results dictionary
     results = {}
     results["pred_dict"] = pred_dict
+
+    # add spectrogram features if they exist
     if len(spec_feats) > 0:
         results["spec_feats"] = spec_feats
         results["spec_feat_names"] = feats.get_feature_names()
+
+    # add CNN features if they exist
     if len(cnn_feats) > 0:
         results["cnn_feats"] = cnn_feats
         results["cnn_feat_names"] = [
             str(ii) for ii in range(cnn_feats.shape[1])
         ]
+
+    # add spectrogram slices if they exist
     if len(spec_slices) > 0:
         results["spec_slices"] = spec_slices
 
     return results
 
 
-def save_results_to_file(results, op_path):
+def save_results_to_file(results, op_path: str) -> None:
+    """Save results to file.
+
+    Args:
+        results (dict): Results.
+        op_path (str): Output path.
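+
+    Example:
+        Assuming `results` came from `process_file`::
+
+            save_results_to_file(results, "outputs/recording")
+
+        This writes "outputs/recording.json" (and "outputs/recording.csv"
+        when there are predicted events), plus feature CSV files when
+        features are present in `results`.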
+
+    """
     # make directory if it does not exist
     if not os.path.isdir(os.path.dirname(op_path)):
         os.makedirs(os.path.dirname(op_path))
 
     # save csv file - if there are predictions
-    result_list = [res for res in results["pred_dict"]["annotation"]]
-    df = pd.DataFrame(result_list)
-    df["file_name"] = [results["pred_dict"]["id"]] * len(result_list)
-    df.index.name = "id"
-    if "class_prob" in df.columns:
-        df = df[
+    result_list = results["pred_dict"]["annotation"]
+
+    results_df = pd.DataFrame(result_list)
+
+    # add file name as a column
+    results_df["file_name"] = results["pred_dict"]["id"]
+
+    # rename index column
+    results_df.index.name = "id"
+
+    # create a csv file with predicted events
+    if "class_prob" in results_df.columns:
+        preds_df = results_df[
             [
                 "det_prob",
                 "start_time",
                 "end_time",
                 "high_freq",
                 "low_freq",
                 "class",
                 "class_prob",
             ]
         ]
-    df.to_csv(op_path + ".csv", sep=",")
+        preds_df.to_csv(op_path + ".csv", sep=",")
 
-    # save features
     if "spec_feats" in results.keys():
-        df = pd.DataFrame(
+        # create csv file with spectrogram features
+        spec_feats_df = pd.DataFrame(
             results["spec_feats"], columns=results["spec_feat_names"]
         )
-        df.to_csv(
+        spec_feats_df.to_csv(
             op_path + "_spec_features.csv",
             sep=",",
             index=False,
         )
 
     if "cnn_feats" in results.keys():
-        df = pd.DataFrame(
-            results["cnn_feats"], columns=results["cnn_feat_names"]
+        # create csv file with cnn extracted features
+        cnn_feats_df = pd.DataFrame(
+            results["cnn_feats"],
+            columns=results["cnn_feat_names"],
         )
-        df.to_csv(
+        cnn_feats_df.to_csv(
             op_path + "_cnn_features.csv",
             sep=",",
             index=False,
         )
 
     # save json file
-    with open(op_path + ".json", "w") as da:
-        json.dump(results["pred_dict"], da, indent=2, sort_keys=True)
+    with open(op_path + ".json", "w", encoding="utf-8") as jsonfile:
+        json.dump(results["pred_dict"], jsonfile, indent=2, sort_keys=True)
 
 
-def compute_spectrogram(audio, sampling_rate, params, return_np=False):
+class SpectrogramParameters(TypedDict):
+    """Parameters for generating spectrograms."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Overlap of the FFT windows as a fraction of the window length."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    spec_width: int
+    """Width of the spectrogram in pixels."""
+
+    resize_factor: float
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: int
+    """Factor to divide the spectrogram by."""
+
+    device: torch.device
+    """Device to store the spectrogram on."""
+
+
+def compute_spectrogram(
+    audio: np.ndarray,
+    sampling_rate: int,
+    params: SpectrogramParameters,
+    return_np: bool = False,
+) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
+    """Compute a spectrogram from an audio array.
+
+    Will pad the audio array so that it is evenly divisible by the
+    downsampling factors.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+
+    sampling_rate : int
+
+    params : SpectrogramParameters
+        The parameters to use for generating the spectrogram.
+
+    return_np : bool, optional
+        Whether to return the spectrogram as a numpy array as well as a
+        torch tensor. The default is False.
+
+    Returns
+    -------
+    duration : float
+        The duration of the spectrogram in seconds.
+
+    spec : torch.Tensor
+        The spectrogram as a torch tensor.
+
+    spec_np : np.ndarray, optional
+        The spectrogram as a numpy array. Only returned if `return_np` is
+        True, otherwise None.
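+
+    Examples
+    --------
+    A minimal sketch; the sampling rate is illustrative and `params`
+    must provide the keys listed in `SpectrogramParameters`::
+
+        duration, spec, spec_np = compute_spectrogram(
+            audio,
+            256000,
+            params,
+            return_np=True,
+        )
+        # spec is a tensor of shape [1, 1, height, width]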
+    """
     # pad audio so it is evenly divisible by downsampling factors
     duration = audio.shape[0] / float(sampling_rate)
     audio = au.pad_audio(
@@ -290,13 +537,21 @@ def compute_spectrogram(audio, sampling_rate, params, return_np=False):
 
     # convert to pytorch
     spec = torch.from_numpy(spec).to(params["device"])
+
+    # add batch and channel dimensions
     spec = spec.unsqueeze(0).unsqueeze(0)
 
     # resize the spec
-    rs = params["resize_factor"]
-    spec_op_shape = (int(params["spec_height"] * rs), int(spec.shape[-1] * rs))
+    resize_factor = params["resize_factor"]
+    spec_op_shape = (
+        int(params["spec_height"] * resize_factor),
+        int(spec.shape[-1] * resize_factor),
+    )
     spec = F.interpolate(
-        spec, size=spec_op_shape, mode="bilinear", align_corners=False
+        spec,
+        size=spec_op_shape,
+        mode="bilinear",
+        align_corners=False,
     )
 
     if return_np:
@@ -307,135 +562,318 @@ def compute_spectrogram(audio, sampling_rate, params, return_np=False):
     return duration, spec, spec_np
 
 
-def process_file(
-    audio_file,
-    model,
-    params,
-    args,
-    time_exp=None,
-    top_n=5,
-    return_raw_preds=False,
-    max_duration=False,
-):
+def iterate_over_chunks(
+    audio: np.ndarray,
+    samplerate: int,
+    chunk_size: float,
+) -> Iterator[Tuple[float, np.ndarray]]:
+    """Iterate over audio in chunks of size chunk_size.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+
+    samplerate : int
+
+    chunk_size : float
+        Size of chunks in seconds.
+
+    Yields
+    ------
+    chunk_start : float
+        Start time of chunk in seconds.
+
+    chunk : np.ndarray
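+
+    Examples
+    --------
+    A minimal sketch; the sampling rate and chunk size are illustrative::
+
+        for chunk_start, chunk in iterate_over_chunks(audio, 256000, 3.0):
+            print(chunk_start, chunk.shape)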
+
+    """
+    nsamples = audio.shape[0]
+    duration_full = nsamples / samplerate
+    num_chunks = int(np.ceil(duration_full / chunk_size))
+    for chunk_id in range(num_chunks):
+        chunk_start = chunk_size * chunk_id
+        chunk_length = int(samplerate * chunk_size)
+        start_sample = chunk_id * chunk_length
+        end_sample = np.minimum((chunk_id + 1) * chunk_length, nsamples)
+        yield chunk_start, audio[start_sample:end_sample]
+
+
+class ProcessingConfiguration(TypedDict):
+    """Parameters for processing audio files."""
+
+    # audio parameters
+    target_samp_rate: int
+    """Target sampling rate of the audio."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Overlap of the FFT windows as a fraction of the window length."""
+
+    resize_factor: float
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: int
+    """Factor to divide the spectrogram by."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    scale_raw_audio: bool
+    """Whether to scale the raw audio to be between -1 and 1."""
+
+    device: torch.device
+    """Device to run the model on."""
+
+    class_names: List[str]
+    """Names of the classes the model can detect."""
+
+    detection_threshold: float
+    """Threshold for detection probability."""
+
+    time_expansion: Optional[float]
+    """Time expansion factor of the processed recordings."""
+
+    top_n: int
+    """Number of top detections to keep."""
+
+    return_raw_preds: bool
+    """Whether to return raw predictions."""
+
+    max_duration: Optional[float]
+    """Maximum duration of audio file to process in seconds."""
+
+    nms_kernel_size: int
+    """Size of the kernel for non-maximum suppression."""
+
+    max_freq: float
+    """Maximum frequency to consider in Hz."""
+
+    min_freq: float
+    """Minimum frequency to consider in Hz."""
+
+    nms_top_k_per_sec: float
+    """Number of top detections to keep per second."""
+
+    chunk_size: float
+    """Duration in seconds of the chunks a long file is split into."""
+
+    cnn_features: bool
+    """Whether to extract CNN features from the model."""
+
+    spec_features: bool
+    """Whether to extract spectrogram call features."""
+
+    spec_slices: bool
+    """Whether to extract spectrogram slices for each detection."""
+
+    quiet: bool
+    """Whether to suppress output."""
+
+
+def process_spectrogram(
+    spec: torch.Tensor,
+    samplerate: int,
+    model: torch.nn.Module,
+    config: ProcessingConfiguration,
+):
+    """Process a spectrogram with detection model.
+
+    Will run non-maximum suppression on the output of the model.
+
+    Parameters
+    ----------
+    spec : torch.Tensor
+
+    samplerate : int
+
+    model : torch.nn.Module
+        Detection model.
+
+    config : ProcessingConfiguration
+        Configuration for processing; this includes the non-maximum
+        suppression parameters.
+
+    Returns
+    -------
+    pred_nms : Dict[str, np.ndarray]
+
+    features : Dict[str, np.ndarray]
+
+    """
+    # evaluate model
+    with torch.no_grad():
+        outputs = model(spec, return_feats=config["cnn_features"])
+
+    # run non-max suppression
+    pred_nms, features = pp.run_nms(
+        outputs,
+        config,
+        np.array([float(samplerate)]),
+    )
+
+    pred_nms = pred_nms[0]
+
+    # if we have a background class
+    if pred_nms["class_probs"].shape[0] > len(config["class_names"]):
+        pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]
+
+    return pred_nms, features
+
+
+def process_audio_array(
+    audio: np.ndarray,
+    sampling_rate: int,
+    model: torch.nn.Module,
+    config: ProcessingConfiguration,
+):
+    """Process a single audio array with detection model.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+
+    sampling_rate : int
+
+    model : torch.nn.Module
+        Detection model.
+
+    config : ProcessingConfiguration
+        Configuration for processing.
+
+    Returns
+    -------
+    pred_nms : Dict[str, np.ndarray]
+
+    features : Dict[str, np.ndarray]
+
+    spec_np : np.ndarray
+
+    """
+    # compute the spectrogram from the audio array
+    _, spec, spec_np = compute_spectrogram(
+        audio,
+        sampling_rate,
+        config,
+        return_np=config["spec_features"] or config["spec_slices"],
+    )
+
+    # process spectrogram with model
+    pred_nms, features = process_spectrogram(
+        spec,
+        sampling_rate,
+        model,
+        config,
+    )
+
+    return pred_nms, features, spec_np
+
+
+def process_file(
+    audio_file: str,
+    model: torch.nn.Module,
+    config: ProcessingConfiguration,
+) -> Union[Results, Any]:
+    """Process a single audio file with detection model.
+
+    Will split the audio file into chunks if it is too long and
+    process each chunk separately.
+
+    Parameters
+    ----------
+    audio_file : str
+        Path to audio file.
+
+    model : torch.nn.Module
+        Detection model.
+
+    config : ProcessingConfiguration
+        Configuration for processing.
+
+    Returns
+    -------
+    results : Results or Any
+        Results of processing audio file with the given detection model.
+        Will be the raw predictions if `config["return_raw_preds"]` is
+        `True`, otherwise a dictionary in the format expected by the
+        annotation tool.
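+
+    Examples
+    --------
+    A rough sketch of the intended call pattern; the exact keys `config`
+    must provide are those listed in `ProcessingConfiguration` (the CLI
+    builds it by merging the model parameters with the parsed arguments)::
+
+        model, params = load_model()
+        config = {**get_default_bd_args(), **params}
+        results = process_file("recording.wav", model, config)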
+    """
     # store temporary results here
     predictions = []
     spec_feats = []
     cnn_feats = []
     spec_slices = []
 
-    # get time expansion factor
-    if time_exp is None:
-        time_exp = args["time_expansion_factor"]
-
-    params["detection_threshold"] = args["detection_threshold"]
-
     # load audio file
     sampling_rate, audio_full = au.load_audio_file(
         audio_file,
-        time_exp,
-        params["target_samp_rate"],
-        params["scale_raw_audio"],
+        time_exp_fact=config["time_expansion"],
+        target_samp_rate=config["target_samp_rate"],
+        scale=config["scale_raw_audio"],
+        max_duration=config["max_duration"],
     )
 
-    # clipping maximum duration
-    if max_duration is not False:
-        max_duration = np.minimum(
-            int(sampling_rate * max_duration), audio_full.shape[0]
-        )
-        audio_full = audio_full[:max_duration]
-
-    duration_full = audio_full.shape[0] / float(sampling_rate)
-
-    return_np_spec = args["spec_features"] or args["spec_slices"]
-
-    # loop through larger file and split into chunks
-    # TODO fix so that it overlaps correctly and takes care of duplicate detections at borders
-    num_chunks = int(np.ceil(duration_full / args["chunk_size"]))
-    for chunk_id in range(num_chunks):
-
-        # chunk
-        chunk_time = args["chunk_size"] * chunk_id
-        chunk_length = int(sampling_rate * args["chunk_size"])
-        start_sample = chunk_id * chunk_length
-        end_sample = np.minimum(
-            (chunk_id + 1) * chunk_length, audio_full.shape[0]
-        )
-        audio = audio_full[start_sample:end_sample]
-
-        # load audio file and compute spectrogram
-        duration, spec, spec_np = compute_spectrogram(
-            audio, sampling_rate, params, return_np_spec
+    # TODO fix so that it overlaps correctly and takes care of
+    # duplicate detections at borders
+    for chunk_time, audio in iterate_over_chunks(
+        audio_full,
+        sampling_rate,
+        config["chunk_size"],
+    ):
+        # Run detection model on chunk
+        pred_nms, features, spec_np = process_audio_array(
+            audio,
+            sampling_rate,
+            model,
+            config,
         )
 
-        # evaluate model
-        with torch.no_grad():
-            outputs = model(spec, return_feats=args["cnn_features"])
-
-        # run non-max suppression
-        pred_nms, features = pp.run_nms(
-            outputs, params, np.array([float(sampling_rate)])
-        )
-        pred_nms = pred_nms[0]
+        # add chunk time to start and end times
         pred_nms["start_times"] += chunk_time
         pred_nms["end_times"] += chunk_time
 
-        # if we have a background class
-        if pred_nms["class_probs"].shape[0] > len(params["class_names"]):
-            pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]
-
         predictions.append(pred_nms)
 
         # extract features - if there are any calls detected
         if pred_nms["det_probs"].shape[0] > 0:
-            if args["spec_features"]:
-                spec_feats.append(feats.get_feats(spec_np, pred_nms, params))
+            if config["spec_features"]:
+                spec_feats.append(feats.get_feats(spec_np, pred_nms, config))
 
-            if args["cnn_features"]:
+            if config["cnn_features"]:
                 cnn_feats.append(features[0])
 
-            if args["spec_slices"]:
+            if config["spec_slices"]:
                 spec_slices.extend(
-                    feats.extract_spec_slices(spec_np, pred_nms, params)
+                    feats.extract_spec_slices(spec_np, pred_nms, config)
                 )
 
-    # convert the predictions into output dictionary
-    file_id = os.path.basename(audio_file)
-    predictions, spec_feats, cnn_feats, spec_slices = merge_results(
-        predictions, spec_feats, cnn_feats, spec_slices
-    )
-
-    results = convert_results(
-        file_id,
-        time_exp,
-        duration_full,
-        params,
+    # Merge results from chunks
+    predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
         predictions,
         spec_feats,
         cnn_feats,
         spec_slices,
     )
 
+    # convert results to a dictionary in the right format
+    results = convert_results(
+        file_id=os.path.basename(audio_file),
+        time_exp=config["time_expansion"],
+        duration=audio_full.shape[0] / float(sampling_rate),
+        params=config,
+        predictions=predictions,
+        spec_feats=spec_feats,
+        cnn_feats=cnn_feats,
+        spec_slices=spec_slices,
+    )
 
     # summarize results
-    if not args["quiet"]:
-        num_detections = len(results["pred_dict"]["annotation"])
-        print(
-            "{}".format(num_detections)
-            + " call(s) detected above the threshold."
-        )
+    if not config["quiet"]:
+        summarize_results(results, predictions, config)
+
+    if config["return_raw_preds"]:
+        return predictions
+
+    return results
+
+
+def summarize_results(results, predictions, config):
+    """Print summary of results."""
+    num_detections = len(results["pred_dict"]["annotation"])
+    print(f"{num_detections} call(s) detected above the threshold.")
 
     # print results for top n classes
-    if not args["quiet"] and (num_detections > 0):
+    if num_detections > 0:
         class_overall = pp.overall_class_pred(
-            predictions["det_probs"], predictions["class_probs"]
+            predictions["det_probs"],
+            predictions["class_probs"],
         )
         print("species name".ljust(30) + "probablity present")
-        for cc in np.argsort(class_overall)[::-1][:top_n]:
-            print(
-                params["class_names"][cc].ljust(30)
-                + str(round(class_overall[cc], 3))
-            )
 
-    if return_raw_preds:
-        return predictions
-    else:
-        return results
+        for class_index in np.argsort(class_overall)[::-1][: config["top_n"]]:
+            print(
+                config["class_names"][class_index].ljust(30)
+                + str(round(class_overall[class_index], 3))
+            )

diff --git a/run_batdetect.py b/run_batdetect.py
index ec3e535..e9e06da 100644
--- a/run_batdetect.py
+++ b/run_batdetect.py
@@ -5,7 +5,6 @@ import bat_detect.utils.detector_utils as du
 
 
 def main(args):
-
     print("Loading model: " + args["model_path"])
     model, params = du.load_model(args["model_path"])
 
diff --git a/scripts/gen_spec_image.py b/scripts/gen_spec_image.py
index ba69481..8269f78 100644
--- a/scripts/gen_spec_image.py
+++ b/scripts/gen_spec_image.py
@@ -136,8 +136,13 @@ if __name__ == "__main__":
         audio, sampling_rate, params_bd, True, False
     )
 
+    run_config = {
+        **params_bd,
+        **bd_args,
+    }
+
     # run model and filter detections so only keep ones in relevant time range
-    results = du.process_file(args_cmd["audio_file"], model, params_bd, bd_args)
+    results = du.process_file(args_cmd["audio_file"], model, run_config)
     pred_anns = filter_anns(
         results["pred_dict"]["annotation"],
         args_cmd["start_time"],
diff --git a/scripts/gen_spec_video.py b/scripts/gen_spec_video.py
index 17354f6..cbb443c 100644
--- a/scripts/gen_spec_video.py
+++ b/scripts/gen_spec_video.py
@@ -122,7 +122,12 @@ if __name__ == "__main__":
     print(" Loading model and running detector on entire file ...")
     model, det_params = du.load_model(args_cmd["model_path"])
     det_params["detection_threshold"] = args["detection_threshold"]
-    results = du.process_file(audio_file, model, det_params, args)
+
+    run_config = {
+        **det_params,
+        **args,
+    }
+    results = du.process_file(audio_file, model, run_config)
 
     print(" Processing detections and plotting ...")
     detections = []