fixed missing get_default_db_config function

This commit is contained in:
Santiago Martinez 2023-04-07 16:26:01 -06:00
parent 178fa518c3
commit c865b53c17
4 changed files with 36 additions and 7 deletions

View File

@ -7,6 +7,7 @@ import copy
import json
import os
import torch
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
@ -739,7 +740,7 @@ if __name__ == "__main__":
#
if args["bd_model_path"] != "":
# load model
bd_args = du.get_default_run_config()
bd_args = du.get_default_bd_args()
model, params_bd = du.load_model(args["bd_model_path"])
# check if the class names are the same
@ -754,11 +755,13 @@ if __name__ == "__main__":
}
preds_bd = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for ii, gg in enumerate(gt_test):
pred = du.process_file(
gg["file_path"],
model,
run_config,
device,
)
preds_bd.append(pred)

View File

@ -34,9 +34,26 @@ __all__ = [
"process_spectrogram",
"process_audio_array",
"process_file",
"get_default_bd_args",
]
def get_default_bd_args():
args = {}
args["detection_threshold"] = 0.001
args["time_expansion_factor"] = 1
args["audio_dir"] = ""
args["ann_dir"] = ""
args["spec_slices"] = False
args["chunk_size"] = 3
args["spec_features"] = False
args["cnn_features"] = False
args["quiet"] = True
args["save_preds_if_empty"] = True
args["ann_dir"] = os.path.join(args["ann_dir"], "")
return args
def list_audio_files(ip_dir: str) -> List[str]:
"""Get all audio files in directory.

View File

@ -12,6 +12,7 @@ import json
import os
import sys
import torch
import matplotlib.pyplot as plt
import numpy as np
@ -85,7 +86,7 @@ if __name__ == "__main__":
args_cmd = vars(parser.parse_args())
# load the model
bd_args = du.get_default_run_config()
bd_args = du.get_default_bd_args()
model, params_bd = du.load_model(args_cmd["model_path"])
bd_args["detection_threshold"] = args_cmd["detection_threshold"]
bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
@ -141,7 +142,8 @@ if __name__ == "__main__":
}
# run model and filter detections so only keep ones in relevant time range
results = du.process_file(args_cmd["audio_file"], model, run_config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
results = du.process_file(args_cmd["audio_file"], model, run_config, device)
pred_anns = filter_anns(
results["pred_dict"]["annotation"],
args_cmd["start_time"],

View File

@ -15,6 +15,7 @@ import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.io import wavfile
import batdetect2.detector.parameters as parameters
@ -23,7 +24,6 @@ import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as viz
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("audio_file", type=str, help="Path to input audio file")
parser.add_argument(
@ -72,7 +72,7 @@ if __name__ == "__main__":
sys.exit()
if not os.path.isfile(args_cmd["model_path"]):
print("Model not found: ", model_path)
print("Model not found: ", args_cmd["model_path"])
sys.exit()
start_time = 0.0
@ -88,7 +88,7 @@ if __name__ == "__main__":
os.makedirs(op_dir)
params = parameters.get_params(False)
args = du.get_default_run_config()
args = du.get_default_bd_args()
args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
args["detection_threshold"] = args_cmd["detection_threshold"]
@ -118,6 +118,8 @@ if __name__ == "__main__":
max_val = spec.max() * 1.1
if not args_cmd["no_detector"]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(" Loading model and running detector on entire file ...")
model, det_params = du.load_model(args_cmd["model_path"])
det_params["detection_threshold"] = args["detection_threshold"]
@ -126,7 +128,12 @@ if __name__ == "__main__":
**det_params,
**args,
}
results = du.process_file(audio_file, model, run_config)
results = du.process_file(
audio_file,
model,
run_config,
device,
)
print(" Processing detections and plotting ...")
detections = []