Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 22:51:58 +02:00)
Fixed errors with command

commit 40222d8233 (parent e6a6ad4696)
@@ -71,7 +71,7 @@ def parse_args():
     parser.add_argument(
         "--model_path",
         type=str,
-        default=os.path.join(CURRENT_DIR, "models/Net2DFast_UK_same.pth.tar"),
+        default=du.DEFAULT_MODEL_PATH,
         help="Path to trained BatDetect2 model",
     )
     args = vars(parser.parse_args())
@@ -97,8 +97,11 @@ def main():
     print(f"Number of audio files: {len(files)}")
     print("\nSaving results to: " + args["ann_dir"])
 
+    default_config = du.get_default_config()
+
     # set up run config
     run_config = {
+        **default_config,
        **args,
        **params,
    }
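
For orientation, a minimal standalone sketch (illustrative values only, not taken from the repository) of the merge precedence used for run_config above: with dict unpacking, later entries win, so parsed CLI arguments override the library defaults, and the parameters loaded with the model override both.

# Illustrative sketch of the run_config merge order; values are made up.
default_config = {"detection_threshold": 0.01, "spec_scale": "pcen"}
args = {"detection_threshold": 0.3}    # e.g. values parsed from the command line
params = {"spec_scale": "log"}         # e.g. values loaded with the model checkpoint

run_config = {**default_config, **args, **params}
print(run_config)  # {'detection_threshold': 0.3, 'spec_scale': 'log'}
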
@@ -119,6 +122,7 @@ def main():
         except (RuntimeError, ValueError, LookupError) as err:
             error_files.append(audio_file)
             print(f"Error processing file!: {err}")
+            raise err
 
     print("\nResults saved to: " + args["ann_dir"])
 
@@ -13,6 +13,9 @@ SCALE_RAW_AUDIO = False
 DETECTION_THRESHOLD = 0.01
 NMS_KERNEL_SIZE = 9
 NMS_TOP_K_PER_SEC = 200
+SPEC_SCALE = "pcen"
+DENOISE_SPEC_AVG = True
+MAX_SCALE_SPEC = False
 
 
 def mk_dir(path):
@@ -70,14 +73,14 @@ def get_params(make_dirs=False, exps_dir="../../experiments/"):
     # spec processing params
     params[
         "denoise_spec_avg"
-    ] = True  # removes the mean for each frequency band
+    ] = DENOISE_SPEC_AVG  # removes the mean for each frequency band
     params[
         "scale_raw_audio"
     ] = SCALE_RAW_AUDIO  # scales the raw audio to [-1, 1]
     params[
         "max_scale_spec"
-    ] = False  # scales the spectrogram so that it is max 1
-    params["spec_scale"] = "pcen"  # 'log', 'pcen', 'none'
+    ] = MAX_SCALE_SPEC  # scales the spectrogram so that it is max 1
+    params["spec_scale"] = SPEC_SCALE  # 'log', 'pcen', 'none'
 
     # detection params
     params[
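
A hypothetical sanity check that get_params() now reflects the shared module-level defaults introduced in this commit. The import path is inferred from the imports further down in this diff and should be treated as an assumption.

# Assumed import path for the parameters module (inferred, not named in this diff).
from bat_detect.detector.parameters import (
    DENOISE_SPEC_AVG,
    MAX_SCALE_SPEC,
    SPEC_SCALE,
    get_params,
)

params = get_params(make_dirs=False)
assert params["spec_scale"] == SPEC_SCALE            # "pcen"
assert params["denoise_spec_avg"] == DENOISE_SPEC_AVG
assert params["max_scale_spec"] == MAX_SCALE_SPEC
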
@@ -12,10 +12,12 @@ import bat_detect.detector.post_process as pp
 import bat_detect.utils.audio_utils as au
 from bat_detect.detector import models
 from bat_detect.detector.parameters import (
+    DENOISE_SPEC_AVG,
     DETECTION_THRESHOLD,
     FFT_OVERLAP,
     FFT_WIN_LENGTH_S,
     MAX_FREQ_HZ,
+    MAX_SCALE_SPEC,
     MIN_FREQ_HZ,
     NMS_KERNEL_SIZE,
     NMS_TOP_K_PER_SEC,
@@ -23,6 +25,7 @@ from bat_detect.detector.parameters import (
     SCALE_RAW_AUDIO,
     SPEC_DIVIDE_FACTOR,
     SPEC_HEIGHT,
+    SPEC_SCALE,
     TARGET_SAMPLERATE_HZ,
 )
 
@@ -35,12 +38,13 @@ except ImportError:
 DEFAULT_MODEL_PATH = os.path.join(
     os.path.dirname(os.path.dirname(__file__)),
     "models",
-    "model.pth",
+    "Net2DFast_UK_same.pth.tar",
 )
 
 __all__ = [
     "load_model",
     "get_audio_files",
+    "get_default_config",
     "format_results",
     "save_results_to_file",
     "iterate_over_chunks",
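
To make the new default concrete, a small sketch of how the path above resolves. The module location used here is a hypothetical example, since this diff does not name the file.

import os

# Hypothetical module location; two dirname() calls climb to the package root.
module_file = "/site-packages/bat_detect/utils/detector_utils.py"

default_model_path = os.path.join(
    os.path.dirname(os.path.dirname(module_file)),  # /site-packages/bat_detect
    "models",
    "Net2DFast_UK_same.pth.tar",
)
print(default_model_path)
# /site-packages/bat_detect/models/Net2DFast_UK_same.pth.tar
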
@@ -313,7 +317,7 @@ def format_results(
     annotations: List[Annotation] = [
         {
             "start_time": round(float(start_time), 4),
-            "end_time": round(end_time, 4),
+            "end_time": round(float(end_time), 4),
             "low_freq": int(low_freq),
             "high_freq": int(high_freq),
             "class": str(class_names[class_index]),
@@ -331,7 +335,7 @@ def format_results(
             class_prob,
             det_prob,
         ) in zip(
-            predictions["start_time"],
+            predictions["start_times"],
             predictions["end_times"],
             predictions["low_freqs"],
             predictions["high_freqs"],
@@ -347,7 +351,7 @@ def format_results(
         "issues": False,
         "notes": "Automatically generated.",
         "time_exp": time_exp,
-        "duration": round(duration, 4),
+        "duration": round(float(duration), 4),
         "annotation": annotations,
         "class_name": class_names[np.argmax(class_overall)],
     }
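
The float() casts in these hunks matter once the results are serialized: NumPy scalar types such as float32 are rejected by the standard json encoder, whereas plain Python floats are not. A standalone illustration, assuming the prediction values arrive as NumPy scalars:

import json
import numpy as np

duration = np.float32(3.84731)

try:
    json.dumps({"duration": duration})  # NumPy scalar -> TypeError
except TypeError as err:
    print("not JSON serializable:", err)

print(json.dumps({"duration": round(float(duration), 4)}))  # {"duration": 3.8473}
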
@@ -458,7 +462,8 @@ def save_results_to_file(results, op_path: str) -> None:
     if "spec_feats" in results.keys():
         # create csv file with spectrogram features
         spec_feats_df = pd.DataFrame(
-            results["spec_feats"], columns=results["spec_feat_names"]
+            results["spec_feats"],
+            columns=results["spec_feat_names"],
         )
         spec_feats_df.to_csv(
             op_path + "_spec_features.csv",
@@ -506,6 +511,21 @@ class SpectrogramParameters(TypedDict):
     device: torch.device
     """Device to store the spectrogram on."""
 
+    max_freq: int
+    """Maximum frequency to display in the spectrogram."""
+
+    min_freq: int
+    """Minimum frequency to display in the spectrogram."""
+
+    spec_scale: str
+    """Scale to use for the spectrogram."""
+
+    denoise_spec_avg: bool
+    """Whether to denoise the spectrogram by averaging."""
+
+    max_scale_spec: bool
+    """Whether to scale the spectrogram so that its max is 1."""
+
 
 def compute_spectrogram(
     audio: np.ndarray,
@@ -640,6 +660,15 @@ class ProcessingConfiguration(TypedDict):
     spec_height: int
     """Height of the spectrogram in pixels."""
 
+    spec_scale: str
+    """Scale to use for the spectrogram."""
+
+    denoise_spec_avg: bool
+    """Whether to denoise the spectrogram by averaging."""
+
+    max_scale_spec: bool
+    """Whether to scale the spectrogram so that its max is 1."""
+
     scale_raw_audio: bool
     """Whether to scale the raw audio to be between -1 and 1."""
 
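
These TypedDict additions are purely declarative, but they let a type checker verify that config dictionaries actually carry the new spectrogram keys. A trimmed-down sketch (not the repository's actual class) of what that buys:

from typing import TypedDict


class SpecConfig(TypedDict):
    """Trimmed-down stand-in for the fields added above."""

    spec_scale: str
    denoise_spec_avg: bool
    max_scale_spec: bool


config: SpecConfig = {
    "spec_scale": "pcen",
    "denoise_spec_avg": True,
    "max_scale_spec": False,
}
# config_bad: SpecConfig = {"spec_scale": "pcen"}  # mypy: missing required keys
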
@@ -735,6 +764,7 @@ def process_spectrogram(
             "resize_factor": config["resize_factor"],
             "nms_top_k_per_sec": config["nms_top_k_per_sec"],
             "detection_threshold": config["detection_threshold"],
+            "max_scale_spec": config["max_scale_spec"],
         },
         np.array([float(samplerate)]),
     )
@@ -788,6 +818,11 @@ def process_audio_array(
             "resize_factor": config["resize_factor"],
             "spec_divide_factor": config["spec_divide_factor"],
             "device": config["device"],
+            "max_freq": config["max_freq"],
+            "min_freq": config["min_freq"],
+            "spec_scale": config["spec_scale"],
+            "denoise_spec_avg": config["denoise_spec_avg"],
+            "max_scale_spec": config["max_scale_spec"],
         },
         return_np=config["spec_features"] or config["spec_slices"],
     )
@@ -842,7 +877,7 @@ def process_file(
         time_exp_fact=config.get("time_expansion", 1) or 1,
         target_samp_rate=config["target_samp_rate"],
         scale=config["scale_raw_audio"],
-        max_duration=config["max_duration"],
+        max_duration=config.get("max_duration"),
     )
 
     # loop through larger file and split into chunks
||||||
@ -930,7 +965,7 @@ def summarize_results(results, predictions, config):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_default_run_config(**kwargs) -> ProcessingConfiguration:
|
def get_default_config(**kwargs) -> ProcessingConfiguration:
|
||||||
"""Get default configuration for running detection model."""
|
"""Get default configuration for running detection model."""
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
@@ -958,6 +993,9 @@ def get_default_run_config(**kwargs) -> ProcessingConfiguration:
         "max_freq": MAX_FREQ_HZ,
         "min_freq": MIN_FREQ_HZ,
         "nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
+        "spec_scale": SPEC_SCALE,
+        "denoise_spec_avg": DENOISE_SPEC_AVG,
+        "max_scale_spec": MAX_SCALE_SPEC,
     }
     return {
         **args,
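
A short usage sketch for the renamed helper. The module path (bat_detect.utils.detector_utils, the module aliased as du elsewhere in this commit) is inferred from the surrounding imports and should be treated as an assumption.

import bat_detect.utils.detector_utils as du  # assumed location of get_default_config

config = du.get_default_config()
print(config["spec_scale"], config["denoise_spec_avg"], config["max_scale_spec"])
# pcen True False

# Per-run overrides can be layered on top, mirroring how main() builds run_config:
run_config = {**config, "detection_threshold": 0.3}
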
run_batdetect.py (113 lines changed)
@@ -1,112 +1,5 @@
-import argparse
-import os
+"""Run bat_detect.command.main() from the command line."""
+from bat_detect.command import main
 
-import bat_detect.utils.detector_utils as du
-
-
-def main(args):
-    print("Loading model: " + args["model_path"])
-    model, params = du.load_model(args["model_path"])
-
-    print("\nInput directory: " + args["audio_dir"])
-    files = du.get_audio_files(args["audio_dir"])
-    print("Number of audio files: {}".format(len(files)))
-    print("\nSaving results to: " + args["ann_dir"])
-
-    # process files
-    error_files = []
-    for ii, audio_file in enumerate(files):
-        print("\n" + str(ii).ljust(6) + os.path.basename(audio_file))
-        try:
-            results = du.process_file(audio_file, model, params, args)
-            if args["save_preds_if_empty"] or (
-                len(results["pred_dict"]["annotation"]) > 0
-            ):
-                results_path = audio_file.replace(
-                    args["audio_dir"], args["ann_dir"]
-                )
-                du.save_results_to_file(results, results_path)
-        except:
-            error_files.append(audio_file)
-            print("Error processing file!")
-
-    print("\nResults saved to: " + args["ann_dir"])
-
-    if len(error_files) > 0:
-        print("\nUnable to process the follow files:")
-        for err in error_files:
-            print(" " + err)
-
-
 if __name__ == "__main__":
+    main()
-    info_str = (
-        "\nBatDetect2 - Detection and Classification\n"
-        + " Assumes audio files are mono, not stereo.\n"
-        + ' Spaces in the input paths will throw an error. Wrap in quotes "".\n'
-        + " Input files should be short in duration e.g. < 30 seconds.\n"
-    )
-
-    print(info_str)
-    parser = argparse.ArgumentParser()
-    parser.add_argument("audio_dir", type=str, help="Input directory for audio")
-    parser.add_argument(
-        "ann_dir",
-        type=str,
-        help="Output directory for where the predictions will be stored",
-    )
-    parser.add_argument(
-        "detection_threshold",
-        type=float,
-        help="Cut-off probability for detector e.g. 0.1",
-    )
-    parser.add_argument(
-        "--cnn_features",
-        action="store_true",
-        default=False,
-        dest="cnn_features",
-        help="Extracts CNN call features",
-    )
-    parser.add_argument(
-        "--spec_features",
-        action="store_true",
-        default=False,
-        dest="spec_features",
-        help="Extracts low level call features",
-    )
-    parser.add_argument(
-        "--time_expansion_factor",
-        type=int,
-        default=1,
-        dest="time_expansion_factor",
-        help="The time expansion factor used for all files (default is 1)",
-    )
-    parser.add_argument(
-        "--quiet",
-        action="store_true",
-        default=False,
-        dest="quiet",
-        help="Minimize output printing",
-    )
-    parser.add_argument(
-        "--save_preds_if_empty",
-        action="store_true",
-        default=False,
-        dest="save_preds_if_empty",
-        help="Save empty annotation file if no detections made.",
-    )
-    parser.add_argument(
-        "--model_path",
-        type=str,
-        default="models/Net2DFast_UK_same.pth.tar",
-        help="Path to trained BatDetect2 model",
-    )
-    args = vars(parser.parse_args())
-
-    args["spec_slices"] = False  # used for visualization
-    args[
-        "chunk_size"
-    ] = 2  # if files greater than this amount (seconds) they will be broken down into small chunks
-    args["ann_dir"] = os.path.join(args["ann_dir"], "")
-
-    main(args)
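
Finally, the rewritten run_batdetect.py above is now only a thin entry point into the package CLI. Assuming bat_detect.command keeps the script's original positional arguments (audio_dir, ann_dir, detection_threshold), the way the script is invoked does not change; a sketch of the resulting file with that assumption noted inline:

"""Run bat_detect.command.main() from the command line."""
from bat_detect.command import main

if __name__ == "__main__":
    # Assumed invocation, matching the old script's positional arguments:
    #   python run_batdetect.py <audio_dir> <ann_dir> <detection_threshold>
    main()
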