"""BatDetect2 command line interface.""" import os import warnings warnings.filterwarnings("ignore", category=UserWarning) import click # noqa: E402 from bat_detect import api # noqa: E402 from bat_detect.detector.parameters import DEFAULT_MODEL_PATH # noqa: E402 from bat_detect.utils.detector_utils import save_results_to_file # noqa: E402 CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) INFO_STR = """ BatDetect2 - Detection and Classification Assumes audio files are mono, not stereo. Spaces in the input paths will throw an error. Wrap in quotes. Input files should be short in duration e.g. < 30 seconds. """ @click.group() def cli(): """BatDetect2 - Bat Call Detection and Classification.""" click.echo(INFO_STR) @cli.command() @click.argument( "audio_dir", type=click.Path(exists=True), ) @click.argument( "ann_dir", type=click.Path(exists=False), ) @click.argument( "detection_threshold", type=float, ) @click.option( "--cnn_features", is_flag=True, default=False, help="Extracts CNN call features", ) @click.option( "--spec_features", is_flag=True, default=False, help="Extracts low level call features", ) @click.option( "--time_expansion_factor", type=int, default=1, help="The time expansion factor used for all files (default is 1)", ) @click.option( "--quiet", is_flag=True, default=False, help="Minimize output printing", ) @click.option( "--save_preds_if_empty", is_flag=True, default=False, help="Save empty annotation file if no detections made.", ) @click.option( "--model_path", type=str, default=DEFAULT_MODEL_PATH, help="Path to trained BatDetect2 model", ) def detect( audio_dir: str, ann_dir: str, detection_threshold: float, **args, ): """Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR. DETECTION_THRESHOLD is the detection threshold. All predictions with a score below this threshold will be discarded. Values between 0 and 1. Assumes audio files are mono, not stereo. Spaces in the input paths will throw an error. Wrap in quotes. Input files should be short in duration e.g. < 30 seconds. """ click.echo(f"Loading model: {args['model_path']}") model, params = api.load_model(args["model_path"]) click.echo(f"\nInput directory: {audio_dir}") files = api.list_audio_files(audio_dir) click.echo(f"Number of audio files: {len(files)}") click.echo(f"\nSaving results to: {ann_dir}") config = api.get_config( **{ **params, **args, "spec_slices": False, "chunk_size": 2, "detection_threshold": detection_threshold, } ) # process files error_files = [] for audio_file in files: try: results = api.process_file(audio_file, model, config=config) if args["save_preds_if_empty"] or ( len(results["pred_dict"]["annotation"]) > 0 ): results_path = audio_file.replace(audio_dir, ann_dir) save_results_to_file(results, results_path) except (RuntimeError, ValueError, LookupError) as err: # TODO: Check what other errors can be thrown error_files.append(audio_file) click.echo(f"Error processing file!: {err}") raise err click.echo(f"\nResults saved to: {ann_dir}") if len(error_files) > 0: click.echo("\nUnable to process the follow files:") for err in error_files: click.echo(f" {err}") if __name__ == "__main__": cli()