mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
fixed missing get_default_db_config function
This commit is contained in:
parent
178fa518c3
commit
c865b53c17
@ -7,6 +7,7 @@ import copy
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
@ -739,7 +740,7 @@ if __name__ == "__main__":
|
||||
#
|
||||
if args["bd_model_path"] != "":
|
||||
# load model
|
||||
bd_args = du.get_default_run_config()
|
||||
bd_args = du.get_default_bd_args()
|
||||
model, params_bd = du.load_model(args["bd_model_path"])
|
||||
|
||||
# check if the class names are the same
|
||||
@ -754,11 +755,13 @@ if __name__ == "__main__":
|
||||
}
|
||||
|
||||
preds_bd = []
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
for ii, gg in enumerate(gt_test):
|
||||
pred = du.process_file(
|
||||
gg["file_path"],
|
||||
model,
|
||||
run_config,
|
||||
device,
|
||||
)
|
||||
preds_bd.append(pred)
|
||||
|
||||
|
@ -34,9 +34,26 @@ __all__ = [
|
||||
"process_spectrogram",
|
||||
"process_audio_array",
|
||||
"process_file",
|
||||
"get_default_bd_args",
|
||||
]
|
||||
|
||||
|
||||
def get_default_bd_args():
|
||||
args = {}
|
||||
args["detection_threshold"] = 0.001
|
||||
args["time_expansion_factor"] = 1
|
||||
args["audio_dir"] = ""
|
||||
args["ann_dir"] = ""
|
||||
args["spec_slices"] = False
|
||||
args["chunk_size"] = 3
|
||||
args["spec_features"] = False
|
||||
args["cnn_features"] = False
|
||||
args["quiet"] = True
|
||||
args["save_preds_if_empty"] = True
|
||||
args["ann_dir"] = os.path.join(args["ann_dir"], "")
|
||||
return args
|
||||
|
||||
|
||||
def list_audio_files(ip_dir: str) -> List[str]:
|
||||
"""Get all audio files in directory.
|
||||
|
||||
|
@ -12,6 +12,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
@ -85,7 +86,7 @@ if __name__ == "__main__":
|
||||
args_cmd = vars(parser.parse_args())
|
||||
|
||||
# load the model
|
||||
bd_args = du.get_default_run_config()
|
||||
bd_args = du.get_default_bd_args()
|
||||
model, params_bd = du.load_model(args_cmd["model_path"])
|
||||
bd_args["detection_threshold"] = args_cmd["detection_threshold"]
|
||||
bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
|
||||
@ -141,7 +142,8 @@ if __name__ == "__main__":
|
||||
}
|
||||
|
||||
# run model and filter detections so only keep ones in relevant time range
|
||||
results = du.process_file(args_cmd["audio_file"], model, run_config)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
results = du.process_file(args_cmd["audio_file"], model, run_config, device)
|
||||
pred_anns = filter_anns(
|
||||
results["pred_dict"]["annotation"],
|
||||
args_cmd["start_time"],
|
||||
|
@ -15,6 +15,7 @@ import sys
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.io import wavfile
|
||||
|
||||
import batdetect2.detector.parameters as parameters
|
||||
@ -23,7 +24,6 @@ import batdetect2.utils.detector_utils as du
|
||||
import batdetect2.utils.plot_utils as viz
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("audio_file", type=str, help="Path to input audio file")
|
||||
parser.add_argument(
|
||||
@ -72,7 +72,7 @@ if __name__ == "__main__":
|
||||
sys.exit()
|
||||
|
||||
if not os.path.isfile(args_cmd["model_path"]):
|
||||
print("Model not found: ", model_path)
|
||||
print("Model not found: ", args_cmd["model_path"])
|
||||
sys.exit()
|
||||
|
||||
start_time = 0.0
|
||||
@ -88,7 +88,7 @@ if __name__ == "__main__":
|
||||
os.makedirs(op_dir)
|
||||
|
||||
params = parameters.get_params(False)
|
||||
args = du.get_default_run_config()
|
||||
args = du.get_default_bd_args()
|
||||
args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
|
||||
args["detection_threshold"] = args_cmd["detection_threshold"]
|
||||
|
||||
@ -118,6 +118,8 @@ if __name__ == "__main__":
|
||||
max_val = spec.max() * 1.1
|
||||
|
||||
if not args_cmd["no_detector"]:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
print(" Loading model and running detector on entire file ...")
|
||||
model, det_params = du.load_model(args_cmd["model_path"])
|
||||
det_params["detection_threshold"] = args["detection_threshold"]
|
||||
@ -126,7 +128,12 @@ if __name__ == "__main__":
|
||||
**det_params,
|
||||
**args,
|
||||
}
|
||||
results = du.process_file(audio_file, model, run_config)
|
||||
results = du.process_file(
|
||||
audio_file,
|
||||
model,
|
||||
run_config,
|
||||
device,
|
||||
)
|
||||
|
||||
print(" Processing detections and plotting ...")
|
||||
detections = []
|
||||
|
Loading…
Reference in New Issue
Block a user