From 0a7ad1819318109dba42c44825d363f2b7e1bdf7 Mon Sep 17 00:00:00 2001
From: Santiago Martinez
Date: Fri, 7 Apr 2023 15:20:27 -0600
Subject: [PATCH] Updated gradio app to new api

---
 app.py | 136 +++++++++++++++++++++------------------------------------
 1 file changed, 49 insertions(+), 87 deletions(-)

diff --git a/app.py b/app.py
index 9c82b01..c1f6a60 100644
--- a/app.py
+++ b/app.py
@@ -3,21 +3,10 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 
-import batdetect2.utils.audio_utils as au
-import batdetect2.utils.detector_utils as du
-import batdetect2.utils.plot_utils as viz
-
-# setup the arguments
-args = {}
-args = du.get_default_run_config()
-args["detection_threshold"] = 0.3
-args["time_expansion_factor"] = 1
-args["model_path"] = "models/Net2DFast_UK_same.pth.tar"
-max_duration = 2.0
-
-# load the model
-model, params = du.load_model(args["model_path"])
+from batdetect2 import api, plot
+MAX_DURATION = 2
+DETECTION_THRESHOLD = 0.3
 
 df = gr.Dataframe(
     headers=["species", "time", "detection_prob", "species_prob"],
@@ -28,97 +17,71 @@ df = gr.Dataframe(
 )
 
 examples = [
-    ["example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav", 0.3],
-    ["example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav", 0.3],
-    ["example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav", 0.3],
+    [
+        "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav",
+        DETECTION_THRESHOLD,
+    ],
+    [
+        "example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav",
+        DETECTION_THRESHOLD,
+    ],
+    [
+        "example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav",
+        DETECTION_THRESHOLD,
+    ],
 ]
 
 
-def make_prediction(file_name=None, detection_threshold=0.3):
-    if file_name is not None:
-        audio_file = file_name
-    else:
-        return "You must provide an input audio file."
-
-    if detection_threshold is not None and detection_threshold != "":
-        args["detection_threshold"] = float(detection_threshold)
-
-    run_config = {
-        **params,
-        **args,
-        "max_duration": max_duration,
-    }
-
-    # process the file to generate predictions
-    results = du.process_file(
-        audio_file,
-        model,
-        run_config,
+def make_prediction(file_name, detection_threshold=DETECTION_THRESHOLD):
+    # configure the model run
+    run_config = api.get_config(
+        detection_threshold=detection_threshold,
+        max_duration=MAX_DURATION,
     )
-    anns = [ann for ann in results["pred_dict"]["annotation"]]
-    clss = [aa["class"] for aa in anns]
-    st_time = [aa["start_time"] for aa in anns]
-    cls_prob = [aa["class_prob"] for aa in anns]
-    det_prob = [aa["det_prob"] for aa in anns]
-    data = {
-        "species": clss,
-        "time": st_time,
-        "detection_prob": det_prob,
-        "species_prob": cls_prob,
-    }
+    # process the file to generate predictions
+    results = api.process_file(file_name, config=run_config)
 
-    df = pd.DataFrame(data=data)
-    im = generate_results_image(audio_file, anns)
+    # extract the detections
+    detections = results["pred_dict"]["annotation"]
+
+    # create a dataframe of the predictions
+    df = pd.DataFrame(
+        [
+            {
+                "species": pred["class"],
+                "time": pred["start_time"],
+                "detection_prob": pred["det_prob"],
+                "species_prob": pred["class_prob"],
+            }
+            for pred in detections
+        ]
+    )
+    im = generate_results_image(file_name, detections, run_config)
 
     return [df, im]
 
 
-def generate_results_image(audio_file, anns):
-
-    # load audio
-    sampling_rate, audio = au.load_audio(
-        audio_file,
-        args["time_expansion_factor"],
-        params["target_samp_rate"],
-        params["scale_raw_audio"],
-        max_duration=max_duration,
+def generate_results_image(file_name, detections, config):
+    audio = api.load_audio(
+        file_name,
+        max_duration=config["max_duration"],
+        time_exp_fact=config["time_expansion"],
+        target_samp_rate=config["target_samp_rate"],
     )
-    duration = audio.shape[0] / sampling_rate
 
-    # generate spec
-    spec, spec_viz = au.generate_spectrogram(
-        audio, sampling_rate, params, True, False
-    )
+    spec = api.generate_spectrogram(audio, config=config)
 
     # create fig
     plt.close("all")
     fig = plt.figure(
         1,
-        figsize=(spec.shape[1] / 100, spec.shape[0] / 100),
+        figsize=(15, 4),
         dpi=100,
         frameon=False,
     )
-    spec_duration = au.x_coords_to_time(
-        spec.shape[1],
-        sampling_rate,
-        params["fft_win_length"],
-        params["fft_overlap"],
-    )
-    viz.create_box_image(
-        spec,
-        fig,
-        anns,
-        0,
-        spec_duration,
-        spec_duration,
-        params,
-        spec.max() * 1.1,
-        False,
-        True,
-    )
-    plt.ylabel("Freq - kHz")
-    plt.xlabel("Time - secs")
+    ax = fig.add_subplot(111)
+    plot.spectrogram_with_detections(spec, detections, ax=ax)
     plt.tight_layout()
 
     # convert fig to image
@@ -126,7 +89,6 @@
     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
     w, h = fig.canvas.get_width_height()
     im = data.reshape((int(h), int(w), -1))
-
     return im
 
 
@@ -140,7 +102,7 @@ descr_txt = (
 gr.Interface(
     fn=make_prediction,
     inputs=[
-        gr.Audio(source="upload", type="filepath", optional=True),
+        gr.Audio(source="upload", type="filepath"),
        gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
     ],
     outputs=[df, gr.Image(label="Visualisation")],
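
For reviewers who want to try the new interface outside of Gradio, this is roughly the calling pattern the patch moves the app to. It is a minimal sketch using only the calls that appear in the diff above (api.get_config, api.process_file) and the result fields the app reads; "example.wav" is a placeholder path, and 0.3 / 2 mirror the app's DETECTION_THRESHOLD and MAX_DURATION defaults:

    from batdetect2 import api

    # build a run configuration: library defaults plus per-run overrides
    config = api.get_config(detection_threshold=0.3, max_duration=2)

    # run the detector on a single recording
    results = api.process_file("example.wav", config=config)

    # each annotation carries the class, timing, and probability fields
    # that make_prediction() tabulates above
    for pred in results["pred_dict"]["annotation"]:
        print(pred["class"], pred["start_time"], pred["det_prob"])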