From 995935675067f3531a355f93aaf15b1d9fe3be3f Mon Sep 17 00:00:00 2001 From: Oisin Mac Aodha Date: Tue, 20 Dec 2022 16:25:43 +0000 Subject: [PATCH] limiting file length in demos --- app.py | 60 +++++++++++++++++++++++++-------- bat_detect/utils/audio_utils.py | 7 +++- batdetect2_notebook.ipynb | 7 ++-- 3 files changed, 56 insertions(+), 18 deletions(-) diff --git a/app.py b/app.py index 2244048..ca13f8e 100644 --- a/app.py +++ b/app.py @@ -15,19 +15,19 @@ args = du.get_default_bd_args() args['detection_threshold'] = 0.3 args['time_expansion_factor'] = 1 args['model_path'] = 'models/Net2DFast_UK_same.pth.tar' +max_duration = 2.0 # load the model model, params = du.load_model(args['model_path']) df = gr.Dataframe( - headers=["species", "time_in_file", "species_prob"], - datatype=["str", "str", "str"], + headers=["species", "time", "detection_prob", "species_prob"], + datatype=["str", "str", "str", "str"], row_count=1, - col_count=(3, "fixed"), + col_count=(4, "fixed"), ) - examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3], ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3], ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]] @@ -40,31 +40,63 @@ def make_prediction(file_name=None, detection_threshold=0.3): else: return "You must provide an input audio file." - if detection_threshold != '': + if detection_threshold is not None and detection_threshold != '': args['detection_threshold'] = float(detection_threshold) - results = du.process_file(audio_file, model, params, args, max_duration=5.0) + # process the file to generate predictions + results = du.process_file(audio_file, model, params, args, max_duration=max_duration) + + anns = [ann for ann in results['pred_dict']['annotation']] + clss = [aa['class'] for aa in anns] + st_time = [aa['start_time'] for aa in anns] + cls_prob = [aa['class_prob'] for aa in anns] + det_prob = [aa['det_prob'] for aa in anns] + data = {'species': clss, 'time': st_time, 'detection_prob': det_prob, 'species_prob': cls_prob} - clss = [aa['class'] for aa in results['pred_dict']['annotation']] - st_time = [aa['start_time'] for aa in results['pred_dict']['annotation']] - cls_prob = [aa['class_prob'] for aa in results['pred_dict']['annotation']] - - data = {'species': clss, 'time_in_file': st_time, 'species_prob': cls_prob} df = pd.DataFrame(data=data) + im = generate_results_image(audio_file, anns) + + return [df, im] - return df + +def generate_results_image(audio_file, anns): + + # load audio + sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], + params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration) + duration = audio.shape[0] / sampling_rate + + # generate spec + spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False) + + # create fig + plt.close('all') + fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False) + spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap']) + viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration, params, spec.max()*1.1, False, True) + plt.ylabel('Freq - kHz') + plt.xlabel('Time - secs') + plt.tight_layout() + + # convert fig to image + fig.canvas.draw() + data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + w, h = fig.canvas.get_width_height() + im = data.reshape((int(h), int(w), -1)) + + return im descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \ "
This model is only trained on bat species from the UK. If the input " \ - "file is longer than 5 seconds, only the first 5 seconds will be processed." \ + "file is longer than 2 seconds, only the first 2 seconds will be processed." \ "
Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)." gr.Interface( fn = make_prediction, inputs = [gr.Audio(source="upload", type="filepath", optional=True), gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])], - outputs = df, + outputs = [df, "image"], theme = "huggingface", title = "BatDetect2 Demo", description = descr_txt, diff --git a/bat_detect/utils/audio_utils.py b/bat_detect/utils/audio_utils.py index bbbc2a3..4a18d74 100644 --- a/bat_detect/utils/audio_utils.py +++ b/bat_detect/utils/audio_utils.py @@ -64,7 +64,7 @@ def generate_spectrogram(audio, sampling_rate, params, return_spec_for_viz=False return spec, spec_for_viz -def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False): +def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False, max_duration=False): with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=wavfile.WavFileWarning) #sampling_rate, audio_raw = wavfile.read(audio_file) @@ -79,6 +79,11 @@ def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False): sampling_rate = target_samp_rate audio_raw = librosa.resample(audio_raw, orig_sr=sampling_rate_old, target_sr=sampling_rate, res_type='polyphase') + # clipping maximum duration + if max_duration is not False: + max_duration = np.minimum(int(sampling_rate*max_duration), audio_raw.shape[0]) + audio_raw = audio_raw[:max_duration] + # convert to float32 and scale audio_raw = audio_raw.astype(np.float32) if scale: diff --git a/batdetect2_notebook.ipynb b/batdetect2_notebook.ipynb index 2ec28f5..035affd 100644 --- a/batdetect2_notebook.ipynb +++ b/batdetect2_notebook.ipynb @@ -58,7 +58,8 @@ "args = du.get_default_bd_args()\n", "args['detection_threshold'] = 0.3\n", "args['time_expansion_factor'] = 1\n", - "args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'" + "args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'\n", + "max_duration = 2.0" ] }, { @@ -101,7 +102,7 @@ "outputs": [], "source": [ "# run the model\n", - "results = du.process_file(audio_file, model, params, args, max_duration=5.0)" + "results = du.process_file(audio_file, model, params, args, max_duration=max_duration)" ] }, { @@ -174,7 +175,7 @@ ], "source": [ "# read the audio file \n", - "sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate'], params['scale_raw_audio'])\n", + "sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration)\n", "duration = audio.shape[0] / sampling_rate\n", "print('File duration: {} seconds'.format(duration))" ]