diff --git a/bat_detect/command.py b/bat_detect/command.py index 35ea257..b680b5d 100644 --- a/bat_detect/command.py +++ b/bat_detect/command.py @@ -71,7 +71,7 @@ def parse_args(): parser.add_argument( "--model_path", type=str, - default=os.path.join(CURRENT_DIR, "models/Net2DFast_UK_same.pth.tar"), + default=du.DEFAULT_MODEL_PATH, help="Path to trained BatDetect2 model", ) args = vars(parser.parse_args()) @@ -97,8 +97,11 @@ def main(): print(f"Number of audio files: {len(files)}") print("\nSaving results to: " + args["ann_dir"]) + default_config = du.get_default_config() + # set up run config run_config = { + **default_config, **args, **params, } @@ -119,6 +122,7 @@ def main(): except (RuntimeError, ValueError, LookupError) as err: error_files.append(audio_file) print(f"Error processing file!: {err}") + raise err print("\nResults saved to: " + args["ann_dir"]) diff --git a/bat_detect/detector/parameters.py b/bat_detect/detector/parameters.py index bb705dd..b7d9244 100644 --- a/bat_detect/detector/parameters.py +++ b/bat_detect/detector/parameters.py @@ -13,6 +13,9 @@ SCALE_RAW_AUDIO = False DETECTION_THRESHOLD = 0.01 NMS_KERNEL_SIZE = 9 NMS_TOP_K_PER_SEC = 200 +SPEC_SCALE = "pcen" +DENOISE_SPEC_AVG = True +MAX_SCALE_SPEC = False def mk_dir(path): @@ -70,14 +73,14 @@ def get_params(make_dirs=False, exps_dir="../../experiments/"): # spec processing params params[ "denoise_spec_avg" - ] = True # removes the mean for each frequency band + ] = DENOISE_SPEC_AVG # removes the mean for each frequency band params[ "scale_raw_audio" ] = SCALE_RAW_AUDIO # scales the raw audio to [-1, 1] params[ "max_scale_spec" - ] = False # scales the spectrogram so that it is max 1 - params["spec_scale"] = "pcen" # 'log', 'pcen', 'none' + ] = MAX_SCALE_SPEC # scales the spectrogram so that it is max 1 + params["spec_scale"] = SPEC_SCALE # 'log', 'pcen', 'none' # detection params params[ diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py index 448cba2..f748396 100644 --- a/bat_detect/utils/detector_utils.py +++ b/bat_detect/utils/detector_utils.py @@ -12,10 +12,12 @@ import bat_detect.detector.post_process as pp import bat_detect.utils.audio_utils as au from bat_detect.detector import models from bat_detect.detector.parameters import ( + DENOISE_SPEC_AVG, DETECTION_THRESHOLD, FFT_OVERLAP, FFT_WIN_LENGTH_S, MAX_FREQ_HZ, + MAX_SCALE_SPEC, MIN_FREQ_HZ, NMS_KERNEL_SIZE, NMS_TOP_K_PER_SEC, @@ -23,6 +25,7 @@ from bat_detect.detector.parameters import ( SCALE_RAW_AUDIO, SPEC_DIVIDE_FACTOR, SPEC_HEIGHT, + SPEC_SCALE, TARGET_SAMPLERATE_HZ, ) @@ -35,12 +38,13 @@ except ImportError: DEFAULT_MODEL_PATH = os.path.join( os.path.dirname(os.path.dirname(__file__)), "models", - "model.pth", + "Net2DFast_UK_same.pth.tar", ) __all__ = [ "load_model", "get_audio_files", + "get_default_config", "format_results", "save_results_to_file", "iterate_over_chunks", @@ -313,7 +317,7 @@ def format_results( annotations: List[Annotation] = [ { "start_time": round(float(start_time), 4), - "end_time": round(end_time, 4), + "end_time": round(float(end_time), 4), "low_freq": int(low_freq), "high_freq": int(high_freq), "class": str(class_names[class_index]), @@ -331,7 +335,7 @@ def format_results( class_prob, det_prob, ) in zip( - predictions["start_time"], + predictions["start_times"], predictions["end_times"], predictions["low_freqs"], predictions["high_freqs"], @@ -347,7 +351,7 @@ def format_results( "issues": False, "notes": "Automatically generated.", "time_exp": time_exp, - "duration": round(duration, 4), + "duration": round(float(duration), 4), "annotation": annotations, "class_name": class_names[np.argmax(class_overall)], } @@ -458,7 +462,8 @@ def save_results_to_file(results, op_path: str) -> None: if "spec_feats" in results.keys(): # create csv file with spectrogram features spec_feats_df = pd.DataFrame( - results["spec_feats"], columns=results["spec_feat_names"] + results["spec_feats"], + columns=results["spec_feat_names"], ) spec_feats_df.to_csv( op_path + "_spec_features.csv", @@ -506,6 +511,21 @@ class SpectrogramParameters(TypedDict): device: torch.device """Device to store the spectrogram on.""" + max_freq: int + """Maximum frequency to display in the spectrogram.""" + + min_freq: int + """Minimum frequency to display in the spectrogram.""" + + spec_scale: str + """Scale to use for the spectrogram.""" + + denoise_spec_avg: bool + """Whether to denoise the spectrogram by averaging.""" + + max_scale_spec: bool + """Whether to scale the spectrogram so that its max is 1.""" + def compute_spectrogram( audio: np.ndarray, @@ -640,6 +660,15 @@ class ProcessingConfiguration(TypedDict): spec_height: int """Height of the spectrogram in pixels.""" + spec_scale: str + """Scale to use for the spectrogram.""" + + denoise_spec_avg: bool + """Whether to denoise the spectrogram by averaging.""" + + max_scale_spec: bool + """Whether to scale the spectrogram so that its max is 1.""" + scale_raw_audio: bool """Whether to scale the raw audio to be between -1 and 1.""" @@ -735,6 +764,7 @@ def process_spectrogram( "resize_factor": config["resize_factor"], "nms_top_k_per_sec": config["nms_top_k_per_sec"], "detection_threshold": config["detection_threshold"], + "max_scale_spec": config["max_scale_spec"], }, np.array([float(samplerate)]), ) @@ -788,6 +818,11 @@ def process_audio_array( "resize_factor": config["resize_factor"], "spec_divide_factor": config["spec_divide_factor"], "device": config["device"], + "max_freq": config["max_freq"], + "min_freq": config["min_freq"], + "spec_scale": config["spec_scale"], + "denoise_spec_avg": config["denoise_spec_avg"], + "max_scale_spec": config["max_scale_spec"], }, return_np=config["spec_features"] or config["spec_slices"], ) @@ -842,7 +877,7 @@ def process_file( time_exp_fact=config.get("time_expansion", 1) or 1, target_samp_rate=config["target_samp_rate"], scale=config["scale_raw_audio"], - max_duration=config["max_duration"], + max_duration=config.get("max_duration"), ) # loop through larger file and split into chunks @@ -930,7 +965,7 @@ def summarize_results(results, predictions, config): ) -def get_default_run_config(**kwargs) -> ProcessingConfiguration: +def get_default_config(**kwargs) -> ProcessingConfiguration: """Get default configuration for running detection model.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -958,6 +993,9 @@ def get_default_run_config(**kwargs) -> ProcessingConfiguration: "max_freq": MAX_FREQ_HZ, "min_freq": MIN_FREQ_HZ, "nms_top_k_per_sec": NMS_TOP_K_PER_SEC, + "spec_scale": SPEC_SCALE, + "denoise_spec_avg": DENOISE_SPEC_AVG, + "max_scale_spec": MAX_SCALE_SPEC, } return { **args, diff --git a/run_batdetect.py b/run_batdetect.py index e9e06da..b2c5230 100644 --- a/run_batdetect.py +++ b/run_batdetect.py @@ -1,112 +1,5 @@ -import argparse -import os - -import bat_detect.utils.detector_utils as du - - -def main(args): - print("Loading model: " + args["model_path"]) - model, params = du.load_model(args["model_path"]) - - print("\nInput directory: " + args["audio_dir"]) - files = du.get_audio_files(args["audio_dir"]) - print("Number of audio files: {}".format(len(files))) - print("\nSaving results to: " + args["ann_dir"]) - - # process files - error_files = [] - for ii, audio_file in enumerate(files): - print("\n" + str(ii).ljust(6) + os.path.basename(audio_file)) - try: - results = du.process_file(audio_file, model, params, args) - if args["save_preds_if_empty"] or ( - len(results["pred_dict"]["annotation"]) > 0 - ): - results_path = audio_file.replace( - args["audio_dir"], args["ann_dir"] - ) - du.save_results_to_file(results, results_path) - except: - error_files.append(audio_file) - print("Error processing file!") - - print("\nResults saved to: " + args["ann_dir"]) - - if len(error_files) > 0: - print("\nUnable to process the follow files:") - for err in error_files: - print(" " + err) - +"""Run bat_detect.command.main() from the command line.""" +from bat_detect.command import main if __name__ == "__main__": - - info_str = ( - "\nBatDetect2 - Detection and Classification\n" - + " Assumes audio files are mono, not stereo.\n" - + ' Spaces in the input paths will throw an error. Wrap in quotes "".\n' - + " Input files should be short in duration e.g. < 30 seconds.\n" - ) - - print(info_str) - parser = argparse.ArgumentParser() - parser.add_argument("audio_dir", type=str, help="Input directory for audio") - parser.add_argument( - "ann_dir", - type=str, - help="Output directory for where the predictions will be stored", - ) - parser.add_argument( - "detection_threshold", - type=float, - help="Cut-off probability for detector e.g. 0.1", - ) - parser.add_argument( - "--cnn_features", - action="store_true", - default=False, - dest="cnn_features", - help="Extracts CNN call features", - ) - parser.add_argument( - "--spec_features", - action="store_true", - default=False, - dest="spec_features", - help="Extracts low level call features", - ) - parser.add_argument( - "--time_expansion_factor", - type=int, - default=1, - dest="time_expansion_factor", - help="The time expansion factor used for all files (default is 1)", - ) - parser.add_argument( - "--quiet", - action="store_true", - default=False, - dest="quiet", - help="Minimize output printing", - ) - parser.add_argument( - "--save_preds_if_empty", - action="store_true", - default=False, - dest="save_preds_if_empty", - help="Save empty annotation file if no detections made.", - ) - parser.add_argument( - "--model_path", - type=str, - default="models/Net2DFast_UK_same.pth.tar", - help="Path to trained BatDetect2 model", - ) - args = vars(parser.parse_args()) - - args["spec_slices"] = False # used for visualization - args[ - "chunk_size" - ] = 2 # if files greater than this amount (seconds) they will be broken down into small chunks - args["ann_dir"] = os.path.join(args["ann_dir"], "") - - main(args) + main()