batdetect2/app.py
2023-04-07 15:20:27 -06:00

115 lines
3.0 KiB
Python

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from batdetect2 import api, plot
MAX_DURATION = 2
DETECTION_THRESHOLD = 0.3
df = gr.Dataframe(
headers=["species", "time", "detection_prob", "species_prob"],
datatype=["str", "str", "str", "str"],
row_count=1,
col_count=(4, "fixed"),
label="Predictions",
)
examples = [
[
"example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav",
DETECTION_THRESHOLD,
],
[
"example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav",
DETECTION_THRESHOLD,
],
[
"example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav",
DETECTION_THRESHOLD,
],
]
def make_prediction(file_name, detection_threshold=DETECTION_THRESHOLD):
# configure the model run
run_config = api.get_config(
detection_threshold=detection_threshold,
max_duration=MAX_DURATION,
)
# process the file to generate predictions
results = api.process_file(file_name, config=run_config)
# extract the detections
detections = results["pred_dict"]["annotation"]
# create a dataframe of the predictions
df = pd.DataFrame(
[
{
"species": pred["class"],
"time": pred["start_time"],
"detection_prob": pred["class_prob"],
"species_prob": pred["class_prob"],
}
for pred in detections
]
)
im = generate_results_image(file_name, detections, run_config)
return [df, im]
def generate_results_image(file_name, detections, config):
audio = api.load_audio(
file_name,
max_duration=config["max_duration"],
time_exp_fact=config["time_expansion"],
target_samp_rate=config["target_samp_rate"],
)
spec = api.generate_spectrogram(audio, config=config)
# create fig
plt.close("all")
fig = plt.figure(
1,
figsize=(15, 4),
dpi=100,
frameon=False,
)
ax = fig.add_subplot(111)
plot.spectrogram_with_detections(spec, detections, ax=ax)
plt.tight_layout()
# convert fig to image
fig.canvas.draw()
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
w, h = fig.canvas.get_width_height()
im = data.reshape((int(h), int(w), -1))
return im
descr_txt = (
"Demo of BatDetect2 deep learning-based bat echolocation call detection. "
"<br>This model is only trained on bat species from the UK. If the input "
"file is longer than 2 seconds, only the first 2 seconds will be processed."
"<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
)
gr.Interface(
fn=make_prediction,
inputs=[
gr.Audio(source="upload", type="filepath"),
gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
],
outputs=[df, gr.Image(label="Visualisation")],
theme="huggingface",
title="BatDetect2 Demo",
description=descr_txt,
examples=examples,
allow_flagging="never",
).launch()