From c865b53c17970f72c2993855353cb5e384c4f2ef Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Fri, 7 Apr 2023 16:26:01 -0600 Subject: [PATCH] fixed missing get_default_db_config function --- batdetect2/evaluate/evaluate_models.py | 5 ++++- batdetect2/utils/detector_utils.py | 17 +++++++++++++++++ scripts/gen_spec_image.py | 6 ++++-- scripts/gen_spec_video.py | 15 +++++++++++---- 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/batdetect2/evaluate/evaluate_models.py b/batdetect2/evaluate/evaluate_models.py index 97c1bd1..3303c92 100644 --- a/batdetect2/evaluate/evaluate_models.py +++ b/batdetect2/evaluate/evaluate_models.py @@ -7,6 +7,7 @@ import copy import json import os +import torch import numpy as np import pandas as pd from sklearn.ensemble import RandomForestClassifier @@ -739,7 +740,7 @@ if __name__ == "__main__": # if args["bd_model_path"] != "": # load model - bd_args = du.get_default_run_config() + bd_args = du.get_default_bd_args() model, params_bd = du.load_model(args["bd_model_path"]) # check if the class names are the same @@ -754,11 +755,13 @@ if __name__ == "__main__": } preds_bd = [] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") for ii, gg in enumerate(gt_test): pred = du.process_file( gg["file_path"], model, run_config, + device, ) preds_bd.append(pred) diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index c5dd4d0..d6d2b13 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -34,9 +34,26 @@ __all__ = [ "process_spectrogram", "process_audio_array", "process_file", + "get_default_bd_args", ] +def get_default_bd_args(): + args = {} + args["detection_threshold"] = 0.001 + args["time_expansion_factor"] = 1 + args["audio_dir"] = "" + args["ann_dir"] = "" + args["spec_slices"] = False + args["chunk_size"] = 3 + args["spec_features"] = False + args["cnn_features"] = False + args["quiet"] = True + args["save_preds_if_empty"] = True + args["ann_dir"] = os.path.join(args["ann_dir"], "") + return args + + def list_audio_files(ip_dir: str) -> List[str]: """Get all audio files in directory. diff --git a/scripts/gen_spec_image.py b/scripts/gen_spec_image.py index e296979..2249b58 100644 --- a/scripts/gen_spec_image.py +++ b/scripts/gen_spec_image.py @@ -12,6 +12,7 @@ import json import os import sys +import torch import matplotlib.pyplot as plt import numpy as np @@ -85,7 +86,7 @@ if __name__ == "__main__": args_cmd = vars(parser.parse_args()) # load the model - bd_args = du.get_default_run_config() + bd_args = du.get_default_bd_args() model, params_bd = du.load_model(args_cmd["model_path"]) bd_args["detection_threshold"] = args_cmd["detection_threshold"] bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"] @@ -141,7 +142,8 @@ if __name__ == "__main__": } # run model and filter detections so only keep ones in relevant time range - results = du.process_file(args_cmd["audio_file"], model, run_config) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + results = du.process_file(args_cmd["audio_file"], model, run_config, device) pred_anns = filter_anns( results["pred_dict"]["annotation"], args_cmd["start_time"], diff --git a/scripts/gen_spec_video.py b/scripts/gen_spec_video.py index e7ffc06..9636cae 100644 --- a/scripts/gen_spec_video.py +++ b/scripts/gen_spec_video.py @@ -15,6 +15,7 @@ import sys import matplotlib.pyplot as plt import numpy as np +import torch from scipy.io import wavfile import batdetect2.detector.parameters as parameters @@ -23,7 +24,6 @@ import batdetect2.utils.detector_utils as du import batdetect2.utils.plot_utils as viz if __name__ == "__main__": - parser = argparse.ArgumentParser() parser.add_argument("audio_file", type=str, help="Path to input audio file") parser.add_argument( @@ -72,7 +72,7 @@ if __name__ == "__main__": sys.exit() if not os.path.isfile(args_cmd["model_path"]): - print("Model not found: ", model_path) + print("Model not found: ", args_cmd["model_path"]) sys.exit() start_time = 0.0 @@ -88,7 +88,7 @@ if __name__ == "__main__": os.makedirs(op_dir) params = parameters.get_params(False) - args = du.get_default_run_config() + args = du.get_default_bd_args() args["time_expansion_factor"] = args_cmd["time_expansion_factor"] args["detection_threshold"] = args_cmd["detection_threshold"] @@ -118,6 +118,8 @@ if __name__ == "__main__": max_val = spec.max() * 1.1 if not args_cmd["no_detector"]: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(" Loading model and running detector on entire file ...") model, det_params = du.load_model(args_cmd["model_path"]) det_params["detection_threshold"] = args["detection_threshold"] @@ -126,7 +128,12 @@ if __name__ == "__main__": **det_params, **args, } - results = du.process_file(audio_file, model, run_config) + results = du.process_file( + audio_file, + model, + run_config, + device, + ) print(" Processing detections and plotting ...") detections = []