Fixed errors with command

This commit is contained in:
Santiago Martinez 2023-02-22 23:20:51 +00:00
parent e6a6ad4696
commit 40222d8233
4 changed files with 59 additions and 121 deletions

View File

@ -71,7 +71,7 @@ def parse_args():
parser.add_argument(
"--model_path",
type=str,
default=os.path.join(CURRENT_DIR, "models/Net2DFast_UK_same.pth.tar"),
default=du.DEFAULT_MODEL_PATH,
help="Path to trained BatDetect2 model",
)
args = vars(parser.parse_args())
@ -97,8 +97,11 @@ def main():
print(f"Number of audio files: {len(files)}")
print("\nSaving results to: " + args["ann_dir"])
default_config = du.get_default_config()
# set up run config
run_config = {
**default_config,
**args,
**params,
}
@ -119,6 +122,7 @@ def main():
except (RuntimeError, ValueError, LookupError) as err:
error_files.append(audio_file)
print(f"Error processing file!: {err}")
raise err
print("\nResults saved to: " + args["ann_dir"])

View File

@ -13,6 +13,9 @@ SCALE_RAW_AUDIO = False
DETECTION_THRESHOLD = 0.01
NMS_KERNEL_SIZE = 9
NMS_TOP_K_PER_SEC = 200
# Spectrogram processing defaults; get_params() copies these into its
# "spec_scale", "denoise_spec_avg" and "max_scale_spec" entries.
SPEC_SCALE = "pcen"  # one of 'log', 'pcen', 'none'
DENOISE_SPEC_AVG = True  # removes the mean for each frequency band
MAX_SCALE_SPEC = False  # scales the spectrogram so that it is max 1
def mk_dir(path):
@ -70,14 +73,14 @@ def get_params(make_dirs=False, exps_dir="../../experiments/"):
# spec processing params
params[
"denoise_spec_avg"
] = True # removes the mean for each frequency band
] = DENOISE_SPEC_AVG # removes the mean for each frequency band
params[
"scale_raw_audio"
] = SCALE_RAW_AUDIO # scales the raw audio to [-1, 1]
params[
"max_scale_spec"
] = False # scales the spectrogram so that it is max 1
params["spec_scale"] = "pcen" # 'log', 'pcen', 'none'
] = MAX_SCALE_SPEC # scales the spectrogram so that it is max 1
params["spec_scale"] = SPEC_SCALE # 'log', 'pcen', 'none'
# detection params
params[

View File

@ -12,10 +12,12 @@ import bat_detect.detector.post_process as pp
import bat_detect.utils.audio_utils as au
from bat_detect.detector import models
from bat_detect.detector.parameters import (
DENOISE_SPEC_AVG,
DETECTION_THRESHOLD,
FFT_OVERLAP,
FFT_WIN_LENGTH_S,
MAX_FREQ_HZ,
MAX_SCALE_SPEC,
MIN_FREQ_HZ,
NMS_KERNEL_SIZE,
NMS_TOP_K_PER_SEC,
@ -23,6 +25,7 @@ from bat_detect.detector.parameters import (
SCALE_RAW_AUDIO,
SPEC_DIVIDE_FACTOR,
SPEC_HEIGHT,
SPEC_SCALE,
TARGET_SAMPLERATE_HZ,
)
@ -35,12 +38,13 @@ except ImportError:
DEFAULT_MODEL_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"models",
"model.pth",
"Net2DFast_UK_same.pth.tar",
)
__all__ = [
"load_model",
"get_audio_files",
"get_default_config",
"format_results",
"save_results_to_file",
"iterate_over_chunks",
@ -313,7 +317,7 @@ def format_results(
annotations: List[Annotation] = [
{
"start_time": round(float(start_time), 4),
"end_time": round(end_time, 4),
"end_time": round(float(end_time), 4),
"low_freq": int(low_freq),
"high_freq": int(high_freq),
"class": str(class_names[class_index]),
@ -331,7 +335,7 @@ def format_results(
class_prob,
det_prob,
) in zip(
predictions["start_time"],
predictions["start_times"],
predictions["end_times"],
predictions["low_freqs"],
predictions["high_freqs"],
@ -347,7 +351,7 @@ def format_results(
"issues": False,
"notes": "Automatically generated.",
"time_exp": time_exp,
"duration": round(duration, 4),
"duration": round(float(duration), 4),
"annotation": annotations,
"class_name": class_names[np.argmax(class_overall)],
}
@ -458,7 +462,8 @@ def save_results_to_file(results, op_path: str) -> None:
if "spec_feats" in results.keys():
# create csv file with spectrogram features
spec_feats_df = pd.DataFrame(
results["spec_feats"], columns=results["spec_feat_names"]
results["spec_feats"],
columns=results["spec_feat_names"],
)
spec_feats_df.to_csv(
op_path + "_spec_features.csv",
@ -506,6 +511,21 @@ class SpectrogramParameters(TypedDict):
device: torch.device
"""Device to store the spectrogram on."""
max_freq: int
"""Maximum frequency to display in the spectrogram."""
min_freq: int
"""Minimum frequency to display in the spectrogram."""
spec_scale: str
"""Scale to use for the spectrogram."""
denoise_spec_avg: bool
"""Whether to denoise the spectrogram by averaging."""
max_scale_spec: bool
"""Whether to scale the spectrogram so that its max is 1."""
def compute_spectrogram(
audio: np.ndarray,
@ -640,6 +660,15 @@ class ProcessingConfiguration(TypedDict):
spec_height: int
"""Height of the spectrogram in pixels."""
spec_scale: str
"""Scale to use for the spectrogram."""
denoise_spec_avg: bool
"""Whether to denoise the spectrogram by averaging."""
max_scale_spec: bool
"""Whether to scale the spectrogram so that its max is 1."""
scale_raw_audio: bool
"""Whether to scale the raw audio to be between -1 and 1."""
@ -735,6 +764,7 @@ def process_spectrogram(
"resize_factor": config["resize_factor"],
"nms_top_k_per_sec": config["nms_top_k_per_sec"],
"detection_threshold": config["detection_threshold"],
"max_scale_spec": config["max_scale_spec"],
},
np.array([float(samplerate)]),
)
@ -788,6 +818,11 @@ def process_audio_array(
"resize_factor": config["resize_factor"],
"spec_divide_factor": config["spec_divide_factor"],
"device": config["device"],
"max_freq": config["max_freq"],
"min_freq": config["min_freq"],
"spec_scale": config["spec_scale"],
"denoise_spec_avg": config["denoise_spec_avg"],
"max_scale_spec": config["max_scale_spec"],
},
return_np=config["spec_features"] or config["spec_slices"],
)
@ -842,7 +877,7 @@ def process_file(
time_exp_fact=config.get("time_expansion", 1) or 1,
target_samp_rate=config["target_samp_rate"],
scale=config["scale_raw_audio"],
max_duration=config["max_duration"],
max_duration=config.get("max_duration"),
)
# loop through larger file and split into chunks
@ -930,7 +965,7 @@ def summarize_results(results, predictions, config):
)
def get_default_run_config(**kwargs) -> ProcessingConfiguration:
def get_default_config(**kwargs) -> ProcessingConfiguration:
"""Get default configuration for running detection model."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -958,6 +993,9 @@ def get_default_run_config(**kwargs) -> ProcessingConfiguration:
"max_freq": MAX_FREQ_HZ,
"min_freq": MIN_FREQ_HZ,
"nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
"spec_scale": SPEC_SCALE,
"denoise_spec_avg": DENOISE_SPEC_AVG,
"max_scale_spec": MAX_SCALE_SPEC,
}
return {
**args,

View File

@ -1,112 +1,5 @@
import argparse
import os
import bat_detect.utils.detector_utils as du
def main(args):
    """Run the detection model on every audio file in a directory.

    Loads the model found at ``args["model_path"]``, lists the audio files
    under ``args["audio_dir"]`` and processes them one by one. A failure on
    one file is reported and recorded, and the remaining files are still
    processed.

    Args:
        args: Dictionary of run options. The keys read directly here are
            ``"model_path"``, ``"audio_dir"``, ``"ann_dir"`` and
            ``"save_preds_if_empty"``; the whole dictionary is also passed
            on to ``du.process_file``.
    """
    print("Loading model: " + args["model_path"])
    model, params = du.load_model(args["model_path"])

    print("\nInput directory: " + args["audio_dir"])
    files = du.get_audio_files(args["audio_dir"])
    print(f"Number of audio files: {len(files)}")
    print("\nSaving results to: " + args["ann_dir"])

    # process files
    error_files = []
    for ii, audio_file in enumerate(files):
        print("\n" + str(ii).ljust(6) + os.path.basename(audio_file))
        try:
            results = du.process_file(audio_file, model, params, args)
            if args["save_preds_if_empty"] or (
                len(results["pred_dict"]["annotation"]) > 0
            ):
                # The output path is the input path with the audio
                # directory substituted by the annotation directory.
                results_path = audio_file.replace(
                    args["audio_dir"], args["ann_dir"]
                )
                du.save_results_to_file(results, results_path)
        except Exception as err:
            # Keep the batch going when a single file fails, but do not
            # catch KeyboardInterrupt/SystemExit (a bare ``except:`` would),
            # and report why the file failed.
            error_files.append(audio_file)
            print(f"Error processing file!: {err}")

    print("\nResults saved to: " + args["ann_dir"])

    if error_files:
        print("\nUnable to process the follow files:")
        for failed_file in error_files:
            print(" " + failed_file)
"""Run bat_detect.command.main() from the command line."""
from bat_detect.command import main
# Script entry point: prints a usage banner, parses the command line and
# starts processing.
# NOTE(review): the lines below carry no indentation and the span ends with
# both main(args) and main(); this looks like the removed and the added side
# of the change shown together -- confirm which call is current.
if __name__ == "__main__":
info_str = (
"\nBatDetect2 - Detection and Classification\n"
+ " Assumes audio files are mono, not stereo.\n"
+ ' Spaces in the input paths will throw an error. Wrap in quotes "".\n'
+ " Input files should be short in duration e.g. < 30 seconds.\n"
)
print(info_str)
# Three required positional arguments come first, then the optional flags.
parser = argparse.ArgumentParser()
parser.add_argument("audio_dir", type=str, help="Input directory for audio")
parser.add_argument(
"ann_dir",
type=str,
help="Output directory for where the predictions will be stored",
)
parser.add_argument(
"detection_threshold",
type=float,
help="Cut-off probability for detector e.g. 0.1",
)
parser.add_argument(
"--cnn_features",
action="store_true",
default=False,
dest="cnn_features",
help="Extracts CNN call features",
)
parser.add_argument(
"--spec_features",
action="store_true",
default=False,
dest="spec_features",
help="Extracts low level call features",
)
parser.add_argument(
"--time_expansion_factor",
type=int,
default=1,
dest="time_expansion_factor",
help="The time expansion factor used for all files (default is 1)",
)
parser.add_argument(
"--quiet",
action="store_true",
default=False,
dest="quiet",
help="Minimize output printing",
)
parser.add_argument(
"--save_preds_if_empty",
action="store_true",
default=False,
dest="save_preds_if_empty",
help="Save empty annotation file if no detections made.",
)
parser.add_argument(
"--model_path",
type=str,
default="models/Net2DFast_UK_same.pth.tar",
help="Path to trained BatDetect2 model",
)
# Convert the parsed namespace to a plain dict so extra options can be added.
args = vars(parser.parse_args())
# Options fixed here rather than exposed on the command line.
args["spec_slices"] = False # used for visualization
args[
"chunk_size"
] = 2 # if files greater than this amount (seconds) they will be broken down into small chunks
# Joining with "" leaves ann_dir ending in a path separator.
args["ann_dir"] = os.path.join(args["ann_dir"], "")
main(args)
main()