fixed missing get_default_db_config function

This commit is contained in:
Santiago Martinez 2023-04-07 16:26:01 -06:00
parent 178fa518c3
commit c865b53c17
4 changed files with 36 additions and 7 deletions

View File

@ -7,6 +7,7 @@ import copy
import json import json
import os import os
import torch
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
@ -739,7 +740,7 @@ if __name__ == "__main__":
# #
if args["bd_model_path"] != "": if args["bd_model_path"] != "":
# load model # load model
bd_args = du.get_default_run_config() bd_args = du.get_default_bd_args()
model, params_bd = du.load_model(args["bd_model_path"]) model, params_bd = du.load_model(args["bd_model_path"])
# check if the class names are the same # check if the class names are the same
@ -754,11 +755,13 @@ if __name__ == "__main__":
} }
preds_bd = [] preds_bd = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for ii, gg in enumerate(gt_test): for ii, gg in enumerate(gt_test):
pred = du.process_file( pred = du.process_file(
gg["file_path"], gg["file_path"],
model, model,
run_config, run_config,
device,
) )
preds_bd.append(pred) preds_bd.append(pred)

View File

@ -34,9 +34,26 @@ __all__ = [
"process_spectrogram", "process_spectrogram",
"process_audio_array", "process_audio_array",
"process_file", "process_file",
"get_default_bd_args",
] ]
def get_default_bd_args():
args = {}
args["detection_threshold"] = 0.001
args["time_expansion_factor"] = 1
args["audio_dir"] = ""
args["ann_dir"] = ""
args["spec_slices"] = False
args["chunk_size"] = 3
args["spec_features"] = False
args["cnn_features"] = False
args["quiet"] = True
args["save_preds_if_empty"] = True
args["ann_dir"] = os.path.join(args["ann_dir"], "")
return args
def list_audio_files(ip_dir: str) -> List[str]: def list_audio_files(ip_dir: str) -> List[str]:
"""Get all audio files in directory. """Get all audio files in directory.

View File

@ -12,6 +12,7 @@ import json
import os import os
import sys import sys
import torch
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -85,7 +86,7 @@ if __name__ == "__main__":
args_cmd = vars(parser.parse_args()) args_cmd = vars(parser.parse_args())
# load the model # load the model
bd_args = du.get_default_run_config() bd_args = du.get_default_bd_args()
model, params_bd = du.load_model(args_cmd["model_path"]) model, params_bd = du.load_model(args_cmd["model_path"])
bd_args["detection_threshold"] = args_cmd["detection_threshold"] bd_args["detection_threshold"] = args_cmd["detection_threshold"]
bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"] bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
@ -141,7 +142,8 @@ if __name__ == "__main__":
} }
# run model and filter detections so only keep ones in relevant time range # run model and filter detections so only keep ones in relevant time range
results = du.process_file(args_cmd["audio_file"], model, run_config) device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
results = du.process_file(args_cmd["audio_file"], model, run_config, device)
pred_anns = filter_anns( pred_anns = filter_anns(
results["pred_dict"]["annotation"], results["pred_dict"]["annotation"],
args_cmd["start_time"], args_cmd["start_time"],

View File

@ -15,6 +15,7 @@ import sys
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch
from scipy.io import wavfile from scipy.io import wavfile
import batdetect2.detector.parameters as parameters import batdetect2.detector.parameters as parameters
@ -23,7 +24,6 @@ import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as viz import batdetect2.utils.plot_utils as viz
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("audio_file", type=str, help="Path to input audio file") parser.add_argument("audio_file", type=str, help="Path to input audio file")
parser.add_argument( parser.add_argument(
@ -72,7 +72,7 @@ if __name__ == "__main__":
sys.exit() sys.exit()
if not os.path.isfile(args_cmd["model_path"]): if not os.path.isfile(args_cmd["model_path"]):
print("Model not found: ", model_path) print("Model not found: ", args_cmd["model_path"])
sys.exit() sys.exit()
start_time = 0.0 start_time = 0.0
@ -88,7 +88,7 @@ if __name__ == "__main__":
os.makedirs(op_dir) os.makedirs(op_dir)
params = parameters.get_params(False) params = parameters.get_params(False)
args = du.get_default_run_config() args = du.get_default_bd_args()
args["time_expansion_factor"] = args_cmd["time_expansion_factor"] args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
args["detection_threshold"] = args_cmd["detection_threshold"] args["detection_threshold"] = args_cmd["detection_threshold"]
@ -118,6 +118,8 @@ if __name__ == "__main__":
max_val = spec.max() * 1.1 max_val = spec.max() * 1.1
if not args_cmd["no_detector"]: if not args_cmd["no_detector"]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(" Loading model and running detector on entire file ...") print(" Loading model and running detector on entire file ...")
model, det_params = du.load_model(args_cmd["model_path"]) model, det_params = du.load_model(args_cmd["model_path"])
det_params["detection_threshold"] = args["detection_threshold"] det_params["detection_threshold"] = args["detection_threshold"]
@ -126,7 +128,12 @@ if __name__ == "__main__":
**det_params, **det_params,
**args, **args,
} }
results = du.process_file(audio_file, model, run_config) results = du.process_file(
audio_file,
model,
run_config,
device,
)
print(" Processing detections and plotting ...") print(" Processing detections and plotting ...")
detections = [] detections = []