batdetect2/bat_detect/cli.py
2023-02-26 18:40:35 +00:00

142 lines
3.6 KiB
Python

"""BatDetect2 command line interface."""
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import click # noqa: E402
from bat_detect import api # noqa: E402
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH # noqa: E402
from bat_detect.utils.detector_utils import save_results_to_file # noqa: E402
# Absolute path of the directory that contains this module.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# Usage notes echoed by the `cli` group callback.
INFO_STR = """
BatDetect2 - Detection and Classification
Assumes audio files are mono, not stereo.
Spaces in the input paths will throw an error. Wrap in quotes.
Input files should be short in duration e.g. < 30 seconds.
"""
@click.group()
def cli() -> None:
    """BatDetect2 - Bat Call Detection and Classification."""
    # The docstring above doubles as the group's help text, so it is kept
    # short; the longer usage notes are printed from INFO_STR instead.
    click.echo(message=INFO_STR)
@cli.command()
@click.argument(
    "audio_dir",
    type=click.Path(exists=True),
)
@click.argument(
    "ann_dir",
    type=click.Path(exists=False),
)
@click.argument(
    "detection_threshold",
    type=float,
)
@click.option(
    "--cnn_features",
    is_flag=True,
    default=False,
    help="Extracts CNN call features",
)
@click.option(
    "--spec_features",
    is_flag=True,
    default=False,
    help="Extracts low level call features",
)
@click.option(
    "--time_expansion_factor",
    type=int,
    default=1,
    help="The time expansion factor used for all files (default is 1)",
)
@click.option(
    "--quiet",
    is_flag=True,
    default=False,
    help="Minimize output printing",
)
@click.option(
    "--save_preds_if_empty",
    is_flag=True,
    default=False,
    help="Save empty annotation file if no detections made.",
)
@click.option(
    "--model_path",
    type=str,
    default=DEFAULT_MODEL_PATH,
    help="Path to trained BatDetect2 model",
)
def detect(
    audio_dir: str,
    ann_dir: str,
    detection_threshold: float,
    **args,
):
    """Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.

    DETECTION_THRESHOLD is the detection threshold. All predictions with a
    score below this threshold will be discarded. Values between 0 and 1.

    Assumes audio files are mono, not stereo.

    Spaces in the input paths will throw an error. Wrap in quotes.

    Input files should be short in duration e.g. < 30 seconds.
    """
    # NOTE: the docstring above is what click prints for `detect --help`, so
    # developer-facing notes live in comments instead. `args` collects every
    # `--option` declared by the decorators, keyed by option name.
    click.echo(f"Loading model: {args['model_path']}")
    model, params = api.load_model(args["model_path"])

    click.echo(f"\nInput directory: {audio_dir}")
    files = api.list_audio_files(audio_dir)
    click.echo(f"Number of audio files: {len(files)}")
    click.echo(f"\nSaving results to: {ann_dir}")

    # Later entries win in the merged dict: command line options override the
    # parameters returned with the model, and the three fixed values below
    # override both.
    config = api.get_config(
        **{
            **params,
            **args,
            "spec_slices": False,
            "chunk_size": 2,
            "detection_threshold": detection_threshold,
        }
    )

    # process files
    error_files = []
    for audio_file in files:
        try:
            results = api.process_file(audio_file, model, config=config)

            if args["save_preds_if_empty"] or (
                len(results["pred_dict"]["annotation"]) > 0
            ):
                # Mirror the file's location below AUDIO_DIR into ANN_DIR.
                # A plain `str.replace(audio_dir, ann_dir)` would rewrite
                # every occurrence of the directory text anywhere in the path
                # and would lose the separator when AUDIO_DIR ends with one.
                # Assumes list_audio_files only yields paths under AUDIO_DIR.
                results_path = os.path.join(
                    ann_dir,
                    os.path.relpath(audio_file, audio_dir),
                )
                save_results_to_file(results, results_path)
        except (RuntimeError, ValueError, LookupError) as err:
            # TODO: Check what other errors can be thrown
            error_files.append(audio_file)
            click.echo(f"Error processing file!: {err}")
            # Fail fast: the first failing file aborts the whole run, so the
            # summary below is only reached when every file was processed.
            raise

    click.echo(f"\nResults saved to: {ann_dir}")

    if error_files:
        click.echo("\nUnable to process the follow files:")
        for failed_file in error_files:
            click.echo(f" {failed_file}")
# Run the command group when this module is executed as a script.
if __name__ == "__main__":
    cli()