diff --git a/app.py b/app.py new file mode 100644 index 0000000..d3af4c7 --- /dev/null +++ b/app.py @@ -0,0 +1,96 @@ +import gradio as gr +import os +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np + +import bat_detect.utils.detector_utils as du +import bat_detect.utils.audio_utils as au +import bat_detect.utils.plot_utils as viz + +# setup the arguments +args = {} +args = du.get_default_bd_args() +args['detection_threshold'] = 0.3 +args['time_expansion_factor'] = 1 +args['model_path'] = 'models/Net2DFast_UK_same.pth.tar' + +# load the model +model, params = du.load_model(args['model_path']) + +""" +# read the audio file +sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate'], params['scale_raw_audio']) +duration = audio.shape[0] / sampling_rate +print('File duration: {} seconds'.format(duration)) + +# generate spectrogram for visualization +spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False) + + +# display the detections on top of the spectrogram +# note, if the audio file is very long, this image will be very large - best to crop the audio first +start_time = 0.0 +detections = [ann for ann in results['pred_dict']['annotation']] +fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False) +spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap']) +viz.create_box_image(spec, fig, detections, start_time, start_time+spec_duration, spec_duration, params, spec.max()*1.1, False) +plt.ylabel('Freq - kHz') +plt.xlabel('Time - secs') +plt.title(os.path.basename(audio_file)) +plt.show() +""" + +df = gr.Dataframe( + headers=["species", "time_in_file", "species_prob"], + datatype=["str", "str", "str"], + row_count=1, + col_count=(3, "fixed"), + ) + + +examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3], + ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3], + ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]] + + +def make_prediction(file_name=None, detection_threshold=0.3): + + if file_name is not None: + audio_file = file_name + else: + return "You must provide an input audio file." + + if detection_threshold != '': + args['detection_threshold'] = float(detection_threshold) + + results = du.process_file(audio_file, model, params, args, max_duration=5.0) + + clss = [aa['class'] for aa in results['pred_dict']['annotation']] + st_time = [aa['start_time'] for aa in results['pred_dict']['annotation']] + cls_prob = [aa['class_prob'] for aa in results['pred_dict']['annotation']] + + data = {'species': clss, 'time_in_file': st_time, 'species_prob': cls_prob} + df = pd.DataFrame(data=data) + + return df + + +descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \ + "
This model is only trained on bat species from the UK. If the input " \ + "file is longer than 5 seconds, only the first 5 seconds will be processed." \ + "
Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)." + +gr.Interface( + fn = make_prediction, + inputs = [gr.Audio(source="upload", type="filepath", optional=True), + gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])], + outputs = df, + theme = "huggingface", + title = "BatDetect2 Demo", + description = descr_txt, + examples = examples, + allow_flagging = 'never', +).launch() + + diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py index ebfedd2..fef9828 100644 --- a/bat_detect/utils/detector_utils.py +++ b/bat_detect/utils/detector_utils.py @@ -197,7 +197,7 @@ def compute_spectrogram(audio, sampling_rate, params, return_np=False): return duration, spec, spec_np -def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return_raw_preds=False): +def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return_raw_preds=False, max_duration=False): # store temporary results here predictions = [] @@ -214,6 +214,12 @@ def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return # load audio file sampling_rate, audio_full = au.load_audio_file(audio_file, time_exp, params['target_samp_rate'], params['scale_raw_audio']) + + # clipping maximum duration + if max_duration is not False: + max_duration = np.minimum(int(sampling_rate*max_duration), audio_full.shape[0]) + audio_full = audio_full[:max_duration] + duration_full = audio_full.shape[0] / float(sampling_rate) return_np_spec = args['spec_features'] or args['spec_slices'] diff --git a/batdetect2_notebook.ipynb b/batdetect2_notebook.ipynb index 37580ec..2ec28f5 100644 --- a/batdetect2_notebook.ipynb +++ b/batdetect2_notebook.ipynb @@ -101,7 +101,7 @@ "outputs": [], "source": [ "# run the model\n", - "results = du.process_file(audio_file, model, params, args)" + "results = du.process_file(audio_file, model, params, args, max_duration=5.0)" ] }, { @@ -133,7 +133,7 @@ "0.2195\t0.503\t48671\tPipistrellus pipistrellus\n", "0.2315\t0.672\t27187\tMyotis mystacinus\n", "0.2995\t0.65\t48671\tPipistrellus pipistrellus\n", - "0.3245\t0.688\t27187\tMyotis mystacinus\n", + "0.3245\t0.687\t27187\tMyotis mystacinus\n", "0.3705\t0.547\t34062\tMyotis mystacinus\n", "0.4125\t0.664\t28906\tMyotis mystacinus\n", "0.4365\t0.544\t36640\tMyotis mystacinus\n", @@ -236,7 +236,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.9.13" } }, "nbformat": 4,