mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
76 lines
2.3 KiB
Python
76 lines
2.3 KiB
Python
import gradio as gr
|
|
import os
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
import bat_detect.utils.detector_utils as du
|
|
import bat_detect.utils.audio_utils as au
|
|
import bat_detect.utils.plot_utils as viz
|
|
|
|
|
|
# setup the arguments
|
|
args = {}
|
|
args = du.get_default_bd_args()
|
|
args['detection_threshold'] = 0.3
|
|
args['time_expansion_factor'] = 1
|
|
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
|
|
|
|
# load the model
|
|
model, params = du.load_model(args['model_path'])
|
|
|
|
|
|
df = gr.Dataframe(
|
|
headers=["species", "time_in_file", "species_prob"],
|
|
datatype=["str", "str", "str"],
|
|
row_count=1,
|
|
col_count=(3, "fixed"),
|
|
)
|
|
|
|
|
|
examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
|
|
['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
|
|
['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]
|
|
|
|
|
|
def make_prediction(file_name=None, detection_threshold=0.3):
|
|
|
|
if file_name is not None:
|
|
audio_file = file_name
|
|
else:
|
|
return "You must provide an input audio file."
|
|
|
|
if detection_threshold != '':
|
|
args['detection_threshold'] = float(detection_threshold)
|
|
|
|
results = du.process_file(audio_file, model, params, args, max_duration=5.0)
|
|
|
|
clss = [aa['class'] for aa in results['pred_dict']['annotation']]
|
|
st_time = [aa['start_time'] for aa in results['pred_dict']['annotation']]
|
|
cls_prob = [aa['class_prob'] for aa in results['pred_dict']['annotation']]
|
|
|
|
data = {'species': clss, 'time_in_file': st_time, 'species_prob': cls_prob}
|
|
df = pd.DataFrame(data=data)
|
|
|
|
return df
|
|
|
|
|
|
descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
|
|
"<br>This model is only trained on bat species from the UK. If the input " \
|
|
"file is longer than 5 seconds, only the first 5 seconds will be processed." \
|
|
"<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
|
|
|
|
gr.Interface(
|
|
fn = make_prediction,
|
|
inputs = [gr.Audio(source="upload", type="filepath", optional=True),
|
|
gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
|
|
outputs = df,
|
|
theme = "huggingface",
|
|
title = "BatDetect2 Demo",
|
|
description = descr_txt,
|
|
examples = examples,
|
|
allow_flagging = 'never',
|
|
).launch()
|
|
|
|
|