mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
fixed missing get_default_db_config function
This commit is contained in:
parent
178fa518c3
commit
c865b53c17
@ -7,6 +7,7 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sklearn.ensemble import RandomForestClassifier
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
@ -739,7 +740,7 @@ if __name__ == "__main__":
|
|||||||
#
|
#
|
||||||
if args["bd_model_path"] != "":
|
if args["bd_model_path"] != "":
|
||||||
# load model
|
# load model
|
||||||
bd_args = du.get_default_run_config()
|
bd_args = du.get_default_bd_args()
|
||||||
model, params_bd = du.load_model(args["bd_model_path"])
|
model, params_bd = du.load_model(args["bd_model_path"])
|
||||||
|
|
||||||
# check if the class names are the same
|
# check if the class names are the same
|
||||||
@ -754,11 +755,13 @@ if __name__ == "__main__":
|
|||||||
}
|
}
|
||||||
|
|
||||||
preds_bd = []
|
preds_bd = []
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
for ii, gg in enumerate(gt_test):
|
for ii, gg in enumerate(gt_test):
|
||||||
pred = du.process_file(
|
pred = du.process_file(
|
||||||
gg["file_path"],
|
gg["file_path"],
|
||||||
model,
|
model,
|
||||||
run_config,
|
run_config,
|
||||||
|
device,
|
||||||
)
|
)
|
||||||
preds_bd.append(pred)
|
preds_bd.append(pred)
|
||||||
|
|
||||||
|
@ -34,9 +34,26 @@ __all__ = [
|
|||||||
"process_spectrogram",
|
"process_spectrogram",
|
||||||
"process_audio_array",
|
"process_audio_array",
|
||||||
"process_file",
|
"process_file",
|
||||||
|
"get_default_bd_args",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_bd_args():
|
||||||
|
args = {}
|
||||||
|
args["detection_threshold"] = 0.001
|
||||||
|
args["time_expansion_factor"] = 1
|
||||||
|
args["audio_dir"] = ""
|
||||||
|
args["ann_dir"] = ""
|
||||||
|
args["spec_slices"] = False
|
||||||
|
args["chunk_size"] = 3
|
||||||
|
args["spec_features"] = False
|
||||||
|
args["cnn_features"] = False
|
||||||
|
args["quiet"] = True
|
||||||
|
args["save_preds_if_empty"] = True
|
||||||
|
args["ann_dir"] = os.path.join(args["ann_dir"], "")
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
def list_audio_files(ip_dir: str) -> List[str]:
|
def list_audio_files(ip_dir: str) -> List[str]:
|
||||||
"""Get all audio files in directory.
|
"""Get all audio files in directory.
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -85,7 +86,7 @@ if __name__ == "__main__":
|
|||||||
args_cmd = vars(parser.parse_args())
|
args_cmd = vars(parser.parse_args())
|
||||||
|
|
||||||
# load the model
|
# load the model
|
||||||
bd_args = du.get_default_run_config()
|
bd_args = du.get_default_bd_args()
|
||||||
model, params_bd = du.load_model(args_cmd["model_path"])
|
model, params_bd = du.load_model(args_cmd["model_path"])
|
||||||
bd_args["detection_threshold"] = args_cmd["detection_threshold"]
|
bd_args["detection_threshold"] = args_cmd["detection_threshold"]
|
||||||
bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
|
bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
|
||||||
@ -141,7 +142,8 @@ if __name__ == "__main__":
|
|||||||
}
|
}
|
||||||
|
|
||||||
# run model and filter detections so only keep ones in relevant time range
|
# run model and filter detections so only keep ones in relevant time range
|
||||||
results = du.process_file(args_cmd["audio_file"], model, run_config)
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
results = du.process_file(args_cmd["audio_file"], model, run_config, device)
|
||||||
pred_anns = filter_anns(
|
pred_anns = filter_anns(
|
||||||
results["pred_dict"]["annotation"],
|
results["pred_dict"]["annotation"],
|
||||||
args_cmd["start_time"],
|
args_cmd["start_time"],
|
||||||
|
@ -15,6 +15,7 @@ import sys
|
|||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from scipy.io import wavfile
|
from scipy.io import wavfile
|
||||||
|
|
||||||
import batdetect2.detector.parameters as parameters
|
import batdetect2.detector.parameters as parameters
|
||||||
@ -23,7 +24,6 @@ import batdetect2.utils.detector_utils as du
|
|||||||
import batdetect2.utils.plot_utils as viz
|
import batdetect2.utils.plot_utils as viz
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("audio_file", type=str, help="Path to input audio file")
|
parser.add_argument("audio_file", type=str, help="Path to input audio file")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -72,7 +72,7 @@ if __name__ == "__main__":
|
|||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
if not os.path.isfile(args_cmd["model_path"]):
|
if not os.path.isfile(args_cmd["model_path"]):
|
||||||
print("Model not found: ", model_path)
|
print("Model not found: ", args_cmd["model_path"])
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
start_time = 0.0
|
start_time = 0.0
|
||||||
@ -88,7 +88,7 @@ if __name__ == "__main__":
|
|||||||
os.makedirs(op_dir)
|
os.makedirs(op_dir)
|
||||||
|
|
||||||
params = parameters.get_params(False)
|
params = parameters.get_params(False)
|
||||||
args = du.get_default_run_config()
|
args = du.get_default_bd_args()
|
||||||
args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
|
args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
|
||||||
args["detection_threshold"] = args_cmd["detection_threshold"]
|
args["detection_threshold"] = args_cmd["detection_threshold"]
|
||||||
|
|
||||||
@ -118,6 +118,8 @@ if __name__ == "__main__":
|
|||||||
max_val = spec.max() * 1.1
|
max_val = spec.max() * 1.1
|
||||||
|
|
||||||
if not args_cmd["no_detector"]:
|
if not args_cmd["no_detector"]:
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
print(" Loading model and running detector on entire file ...")
|
print(" Loading model and running detector on entire file ...")
|
||||||
model, det_params = du.load_model(args_cmd["model_path"])
|
model, det_params = du.load_model(args_cmd["model_path"])
|
||||||
det_params["detection_threshold"] = args["detection_threshold"]
|
det_params["detection_threshold"] = args["detection_threshold"]
|
||||||
@ -126,7 +128,12 @@ if __name__ == "__main__":
|
|||||||
**det_params,
|
**det_params,
|
||||||
**args,
|
**args,
|
||||||
}
|
}
|
||||||
results = du.process_file(audio_file, model, run_config)
|
results = du.process_file(
|
||||||
|
audio_file,
|
||||||
|
model,
|
||||||
|
run_config,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
|
||||||
print(" Processing detections and plotting ...")
|
print(" Processing detections and plotting ...")
|
||||||
detections = []
|
detections = []
|
||||||
|
Loading…
Reference in New Issue
Block a user