mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Added command module
This commit is contained in:
parent
7050f858aa
commit
9fd3412d5b
117
bat_detect/command.py
Normal file
117
bat_detect/command.py
Normal file
@ -0,0 +1,117 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import bat_detect.utils.detector_utils as du
|
||||
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
||||
info_str = (
|
||||
"\nBatDetect2 - Detection and Classification\n"
|
||||
+ " Assumes audio files are mono, not stereo.\n"
|
||||
+ ' Spaces in the input paths will throw an error. Wrap in quotes "".\n'
|
||||
+ " Input files should be short in duration e.g. < 30 seconds.\n"
|
||||
)
|
||||
|
||||
print(info_str)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("audio_dir", type=str, help="Input directory for audio")
|
||||
parser.add_argument(
|
||||
"ann_dir",
|
||||
type=str,
|
||||
help="Output directory for where the predictions will be stored",
|
||||
)
|
||||
parser.add_argument(
|
||||
"detection_threshold",
|
||||
type=float,
|
||||
help="Cut-off probability for detector e.g. 0.1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cnn_features",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="cnn_features",
|
||||
help="Extracts CNN call features",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--spec_features",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="spec_features",
|
||||
help="Extracts low level call features",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--time_expansion_factor",
|
||||
type=int,
|
||||
default=1,
|
||||
dest="time_expansion_factor",
|
||||
help="The time expansion factor used for all files (default is 1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quiet",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="quiet",
|
||||
help="Minimize output printing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_preds_if_empty",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="save_preds_if_empty",
|
||||
help="Save empty annotation file if no detections made.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default=os.path.join(CURRENT_DIR, "models/Net2DFast_UK_same.pth.tar"),
|
||||
help="Path to trained BatDetect2 model",
|
||||
)
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
args["spec_slices"] = False # used for visualization
|
||||
# if files greater than this amount (seconds) they will be broken down into small chunks
|
||||
args["chunk_size"] = 2
|
||||
args["ann_dir"] = os.path.join(args["ann_dir"], "")
|
||||
args["quiet"] = True
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
print("Loading model: " + args["model_path"])
|
||||
model, params = du.load_model(args["model_path"])
|
||||
|
||||
print("\nInput directory: " + args["audio_dir"])
|
||||
files = du.get_audio_files(args["audio_dir"])
|
||||
print("Number of audio files: {}".format(len(files)))
|
||||
print("\nSaving results to: " + args["ann_dir"])
|
||||
|
||||
# process files
|
||||
error_files = []
|
||||
for ii, audio_file in enumerate(files):
|
||||
try:
|
||||
results = du.process_file(audio_file, model, params, args)
|
||||
if args["save_preds_if_empty"] or (
|
||||
len(results["pred_dict"]["annotation"]) > 0
|
||||
):
|
||||
results_path = audio_file.replace(args["audio_dir"], args["ann_dir"])
|
||||
du.save_results_to_file(results, results_path)
|
||||
except:
|
||||
error_files.append(audio_file)
|
||||
print("Error processing file!")
|
||||
|
||||
print("\nResults saved to: " + args["ann_dir"])
|
||||
|
||||
if len(error_files) > 0:
|
||||
print("\nUnable to process the follow files:")
|
||||
for err in error_files:
|
||||
print(" " + err)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user