diff --git a/app.py b/app.py
index 2244048..ca13f8e 100644
--- a/app.py
+++ b/app.py
@@ -15,19 +15,19 @@ args = du.get_default_bd_args()
args['detection_threshold'] = 0.3
args['time_expansion_factor'] = 1
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
+max_duration = 2.0
# load the model
model, params = du.load_model(args['model_path'])
df = gr.Dataframe(
- headers=["species", "time_in_file", "species_prob"],
- datatype=["str", "str", "str"],
+ headers=["species", "time", "detection_prob", "species_prob"],
+ datatype=["str", "str", "str", "str"],
row_count=1,
- col_count=(3, "fixed"),
+ col_count=(4, "fixed"),
)
-
examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]
@@ -40,31 +40,63 @@ def make_prediction(file_name=None, detection_threshold=0.3):
else:
return "You must provide an input audio file."
- if detection_threshold != '':
+ if detection_threshold is not None and detection_threshold != '':
args['detection_threshold'] = float(detection_threshold)
- results = du.process_file(audio_file, model, params, args, max_duration=5.0)
+ # process the file to generate predictions
+ results = du.process_file(audio_file, model, params, args, max_duration=max_duration)
+
+ anns = [ann for ann in results['pred_dict']['annotation']]
+ clss = [aa['class'] for aa in anns]
+ st_time = [aa['start_time'] for aa in anns]
+ cls_prob = [aa['class_prob'] for aa in anns]
+ det_prob = [aa['det_prob'] for aa in anns]
+ data = {'species': clss, 'time': st_time, 'detection_prob': det_prob, 'species_prob': cls_prob}
- clss = [aa['class'] for aa in results['pred_dict']['annotation']]
- st_time = [aa['start_time'] for aa in results['pred_dict']['annotation']]
- cls_prob = [aa['class_prob'] for aa in results['pred_dict']['annotation']]
-
- data = {'species': clss, 'time_in_file': st_time, 'species_prob': cls_prob}
df = pd.DataFrame(data=data)
+ im = generate_results_image(audio_file, anns)
+
+ return [df, im]
- return df
+
+def generate_results_image(audio_file, anns):
+
+ # load audio
+ sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'],
+ params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration)
+ duration = audio.shape[0] / sampling_rate
+
+ # generate spec
+ spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False)
+
+ # create fig
+ plt.close('all')
+ fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False)
+ spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
+ viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration, params, spec.max()*1.1, False, True)
+ plt.ylabel('Freq - kHz')
+ plt.xlabel('Time - secs')
+ plt.tight_layout()
+
+ # convert fig to image
+ fig.canvas.draw()
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ w, h = fig.canvas.get_width_height()
+ im = data.reshape((int(h), int(w), -1))
+
+ return im
descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
"
This model is only trained on bat species from the UK. If the input " \
- "file is longer than 5 seconds, only the first 5 seconds will be processed." \
+ "file is longer than 2 seconds, only the first 2 seconds will be processed." \
"
Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
gr.Interface(
fn = make_prediction,
inputs = [gr.Audio(source="upload", type="filepath", optional=True),
gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
- outputs = df,
+ outputs = [df, "image"],
theme = "huggingface",
title = "BatDetect2 Demo",
description = descr_txt,
diff --git a/bat_detect/utils/audio_utils.py b/bat_detect/utils/audio_utils.py
index bbbc2a3..4a18d74 100644
--- a/bat_detect/utils/audio_utils.py
+++ b/bat_detect/utils/audio_utils.py
@@ -64,7 +64,7 @@ def generate_spectrogram(audio, sampling_rate, params, return_spec_for_viz=False
return spec, spec_for_viz
-def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False):
+def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False, max_duration=False):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=wavfile.WavFileWarning)
#sampling_rate, audio_raw = wavfile.read(audio_file)
@@ -79,6 +79,11 @@ def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False):
sampling_rate = target_samp_rate
audio_raw = librosa.resample(audio_raw, orig_sr=sampling_rate_old, target_sr=sampling_rate, res_type='polyphase')
+ # clipping maximum duration
+ if max_duration is not False:
+ max_duration = np.minimum(int(sampling_rate*max_duration), audio_raw.shape[0])
+ audio_raw = audio_raw[:max_duration]
+
# convert to float32 and scale
audio_raw = audio_raw.astype(np.float32)
if scale:
diff --git a/batdetect2_notebook.ipynb b/batdetect2_notebook.ipynb
index 2ec28f5..035affd 100644
--- a/batdetect2_notebook.ipynb
+++ b/batdetect2_notebook.ipynb
@@ -58,7 +58,8 @@
"args = du.get_default_bd_args()\n",
"args['detection_threshold'] = 0.3\n",
"args['time_expansion_factor'] = 1\n",
- "args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'"
+ "args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'\n",
+ "max_duration = 2.0"
]
},
{
@@ -101,7 +102,7 @@
"outputs": [],
"source": [
"# run the model\n",
- "results = du.process_file(audio_file, model, params, args, max_duration=5.0)"
+ "results = du.process_file(audio_file, model, params, args, max_duration=max_duration)"
]
},
{
@@ -174,7 +175,7 @@
],
"source": [
"# read the audio file \n",
- "sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate'], params['scale_raw_audio'])\n",
+ "sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration)\n",
"duration = audio.shape[0] / sampling_rate\n",
"print('File duration: {} seconds'.format(duration))"
]