diff --git a/app.py b/app.py
new file mode 100644
index 0000000..d3af4c7
--- /dev/null
+++ b/app.py
@@ -0,0 +1,96 @@
+import gradio as gr
+import os
+import matplotlib.pyplot as plt
+import pandas as pd
+import numpy as np
+
+import bat_detect.utils.detector_utils as du
+import bat_detect.utils.audio_utils as au
+import bat_detect.utils.plot_utils as viz
+
+# setup the arguments
+args = du.get_default_bd_args()
+args['detection_threshold'] = 0.3
+args['time_expansion_factor'] = 1
+args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
+
+# load the model
+model, params = du.load_model(args['model_path'])
+
+"""
+# read the audio file
+sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate'], params['scale_raw_audio'])
+duration = audio.shape[0] / sampling_rate
+print('File duration: {} seconds'.format(duration))
+
+# generate spectrogram for visualization
+spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False)
+
+
+# display the detections on top of the spectrogram
+# note, if the audio file is very long, this image will be very large - best to crop the audio first
+start_time = 0.0
+detections = results['pred_dict']['annotation']
+fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False)
+spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
+viz.create_box_image(spec, fig, detections, start_time, start_time+spec_duration, spec_duration, params, spec.max()*1.1, False)
+plt.ylabel('Freq - kHz')
+plt.xlabel('Time - secs')
+plt.title(os.path.basename(audio_file))
+plt.show()
+"""
+
+# output table for the predicted detections
+output_df = gr.Dataframe(
+    headers=["species", "time_in_file", "species_prob"],
+    datatype=["str", "str", "str"],
+    row_count=1,
+    col_count=(3, "fixed"),
+)
+
+
+examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
+ ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
+ ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]
+
+
+def make_prediction(file_name=None, detection_threshold=0.3):
+
+    if file_name is None:
+        return "You must provide an input audio file."
+    audio_file = file_name
+
+    if detection_threshold is not None and detection_threshold != '':
+        args['detection_threshold'] = float(detection_threshold)
+
+    results = du.process_file(audio_file, model, params, args, max_duration=5.0)
+
+    anns = results['pred_dict']['annotation']
+    clss = [aa['class'] for aa in anns]
+    st_time = [aa['start_time'] for aa in anns]
+    cls_prob = [aa['class_prob'] for aa in anns]
+
+    data = {'species': clss, 'time_in_file': st_time, 'species_prob': cls_prob}
+    df = pd.DataFrame(data=data)
+
+    return df
+
+
+descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
+ "
This model is only trained on bat species from the UK. If the input " \
+ "file is longer than 5 seconds, only the first 5 seconds will be processed." \
+ "
Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
+
+gr.Interface(
+    fn=make_prediction,
+    inputs=[gr.Audio(source="upload", type="filepath", optional=True),
+            gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
+    outputs=output_df,
+    theme="huggingface",
+    title="BatDetect2 Demo",
+    description=descr_txt,
+    examples=examples,
+    allow_flagging='never',
+).launch()
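
For reviewers who want to exercise the new `max_duration` path without launching the Gradio UI, here is a minimal sketch that uses only calls already present in `app.py`; it assumes the repo's model checkpoint and one of the bundled example recordings are on disk:

```python
import bat_detect.utils.detector_utils as du

# same setup as app.py
args = du.get_default_bd_args()
args['detection_threshold'] = 0.3
args['time_expansion_factor'] = 1
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
model, params = du.load_model(args['model_path'])

# with max_duration=5.0, only the first 5 seconds of the recording are processed
results = du.process_file('example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav',
                          model, params, args, max_duration=5.0)
for ann in results['pred_dict']['annotation']:
    print(ann['class'], ann['start_time'], ann['class_prob'])
```
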
diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py
index ebfedd2..fef9828 100644
--- a/bat_detect/utils/detector_utils.py
+++ b/bat_detect/utils/detector_utils.py
@@ -197,7 +197,7 @@ def compute_spectrogram(audio, sampling_rate, params, return_np=False):
return duration, spec, spec_np
-def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return_raw_preds=False):
+def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return_raw_preds=False, max_duration=None):
# store temporary results here
predictions = []
@@ -214,6 +214,12 @@ def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return
# load audio file
sampling_rate, audio_full = au.load_audio_file(audio_file, time_exp,
params['target_samp_rate'], params['scale_raw_audio'])
+
+    # optionally clip the audio to a maximum duration (in seconds)
+    if max_duration is not None:
+        max_samples = min(int(sampling_rate * max_duration), audio_full.shape[0])
+        audio_full = audio_full[:max_samples]
+
duration_full = audio_full.shape[0] / float(sampling_rate)
return_np_spec = args['spec_features'] or args['spec_slices']
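
The new branch converts the duration cap from seconds to a sample count before truncating the waveform, so recordings shorter than the cap pass through untouched. A standalone sketch of the same arithmetic (plain numpy; `clip_audio` is a hypothetical helper, not part of the repo):

```python
import numpy as np

def clip_audio(audio, sampling_rate, max_duration=None):
    # hypothetical helper mirroring the new branch in process_file:
    # keep at most max_duration seconds, or pass the waveform through unchanged
    if max_duration is None:
        return audio
    max_samples = min(int(sampling_rate * max_duration), audio.shape[0])
    return audio[:max_samples]

# a 10 s buffer at 256 kHz clipped to 5 s keeps 5 * 256000 samples
audio = np.zeros(10 * 256000, dtype=np.float32)
assert clip_audio(audio, 256000, 5.0).shape[0] == 5 * 256000
assert clip_audio(audio, 256000).shape[0] == audio.shape[0]  # default: no clipping
```

Because the default skips the branch entirely, existing callers of `process_file` that omit `max_duration` are unaffected.
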
diff --git a/batdetect2_notebook.ipynb b/batdetect2_notebook.ipynb
index 37580ec..2ec28f5 100644
--- a/batdetect2_notebook.ipynb
+++ b/batdetect2_notebook.ipynb
@@ -101,7 +101,7 @@
"outputs": [],
"source": [
"# run the model\n",
- "results = du.process_file(audio_file, model, params, args)"
+ "results = du.process_file(audio_file, model, params, args, max_duration=5.0)"
]
},
{
@@ -133,7 +133,7 @@
"0.2195\t0.503\t48671\tPipistrellus pipistrellus\n",
"0.2315\t0.672\t27187\tMyotis mystacinus\n",
"0.2995\t0.65\t48671\tPipistrellus pipistrellus\n",
- "0.3245\t0.688\t27187\tMyotis mystacinus\n",
+ "0.3245\t0.687\t27187\tMyotis mystacinus\n",
"0.3705\t0.547\t34062\tMyotis mystacinus\n",
"0.4125\t0.664\t28906\tMyotis mystacinus\n",
"0.4365\t0.544\t36640\tMyotis mystacinus\n",
@@ -236,7 +236,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.8"
+ "version": "3.9.13"
}
},
"nbformat": 4,