Updated gradio app to new api

commit 0a7ad18193
parent b8bbfe8ad4
Santiago Martinez, 2023-04-07 15:20:27 -06:00

app.py (136 changed lines)

@@ -3,21 +3,10 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-import batdetect2.utils.audio_utils as au
-import batdetect2.utils.detector_utils as du
-import batdetect2.utils.plot_utils as viz
-
-# setup the arguments
-args = {}
-args = du.get_default_run_config()
-args["detection_threshold"] = 0.3
-args["time_expansion_factor"] = 1
-args["model_path"] = "models/Net2DFast_UK_same.pth.tar"
-max_duration = 2.0
-
-# load the model
-model, params = du.load_model(args["model_path"])
+from batdetect2 import api, plot
+
+MAX_DURATION = 2
+DETECTION_THRESHOLD = 0.3

 df = gr.Dataframe(
     headers=["species", "time", "detection_prob", "species_prob"],
@@ -28,97 +17,71 @@ df = gr.Dataframe(
 )

 examples = [
-    ["example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav", 0.3],
-    ["example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav", 0.3],
-    ["example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav", 0.3],
+    [
+        "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav",
+        DETECTION_THRESHOLD,
+    ],
+    [
+        "example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav",
+        DETECTION_THRESHOLD,
+    ],
+    [
+        "example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav",
+        DETECTION_THRESHOLD,
+    ],
 ]

-def make_prediction(file_name=None, detection_threshold=0.3):
-    if file_name is not None:
-        audio_file = file_name
-    else:
-        return "You must provide an input audio file."
-
-    if detection_threshold is not None and detection_threshold != "":
-        args["detection_threshold"] = float(detection_threshold)
-
-    run_config = {
-        **params,
-        **args,
-        "max_duration": max_duration,
-    }
-
-    # process the file to generate predictions
-    results = du.process_file(
-        audio_file,
-        model,
-        run_config,
-    )
-
-    anns = [ann for ann in results["pred_dict"]["annotation"]]
-    clss = [aa["class"] for aa in anns]
-    st_time = [aa["start_time"] for aa in anns]
-    cls_prob = [aa["class_prob"] for aa in anns]
-    det_prob = [aa["det_prob"] for aa in anns]
-
-    data = {
-        "species": clss,
-        "time": st_time,
-        "detection_prob": det_prob,
-        "species_prob": cls_prob,
-    }
-
-    df = pd.DataFrame(data=data)
-    im = generate_results_image(audio_file, anns)
+def make_prediction(file_name, detection_threshold=DETECTION_THRESHOLD):
+    # configure the model run
+    run_config = api.get_config(
+        detection_threshold=detection_threshold,
+        max_duration=MAX_DURATION,
+    )
+
+    # process the file to generate predictions
+    results = api.process_file(file_name, config=run_config)
+
+    # extract the detections
+    detections = results["pred_dict"]["annotation"]
+
+    # create a dataframe of the predictions
+    df = pd.DataFrame(
+        [
+            {
+                "species": pred["class"],
+                "time": pred["start_time"],
+                "detection_prob": pred["det_prob"],
+                "species_prob": pred["class_prob"],
+            }
+            for pred in detections
+        ]
+    )
+
+    im = generate_results_image(file_name, detections, run_config)

     return [df, im]

-def generate_results_image(audio_file, anns):
-
-    # load audio
-    sampling_rate, audio = au.load_audio(
-        audio_file,
-        args["time_expansion_factor"],
-        params["target_samp_rate"],
-        params["scale_raw_audio"],
-        max_duration=max_duration,
-    )
-    duration = audio.shape[0] / sampling_rate
-
-    # generate spec
-    spec, spec_viz = au.generate_spectrogram(
-        audio, sampling_rate, params, True, False
-    )
+def generate_results_image(file_name, detections, config):
+    audio = api.load_audio(
+        file_name,
+        max_duration=config["max_duration"],
+        time_exp_fact=config["time_expansion"],
+        target_samp_rate=config["target_samp_rate"],
+    )
+
+    spec = api.generate_spectrogram(audio, config=config)

     # create fig
     plt.close("all")
     fig = plt.figure(
         1,
-        figsize=(spec.shape[1] / 100, spec.shape[0] / 100),
+        figsize=(15, 4),
         dpi=100,
         frameon=False,
     )
-    spec_duration = au.x_coords_to_time(
-        spec.shape[1],
-        sampling_rate,
-        params["fft_win_length"],
-        params["fft_overlap"],
-    )
-    viz.create_box_image(
-        spec,
-        fig,
-        anns,
-        0,
-        spec_duration,
-        spec_duration,
-        params,
-        spec.max() * 1.1,
-        False,
-        True,
-    )
-    plt.ylabel("Freq - kHz")
-    plt.xlabel("Time - secs")
+    ax = fig.add_subplot(111)
+    plot.spectrogram_with_detections(spec, detections, ax=ax)
     plt.tight_layout()

     # convert fig to image
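
The plotting path now condenses to three api/plot calls. A sketch of the same pipeline outside the app, assuming the config keys (max_duration, time_expansion, target_samp_rate) that the updated generate_results_image reads:

    import matplotlib.pyplot as plt
    from batdetect2 import api, plot

    wav = "example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav"
    config = api.get_config(detection_threshold=0.3, max_duration=2)
    detections = api.process_file(wav, config=config)["pred_dict"]["annotation"]

    # reload the audio with the same settings and compute its spectrogram
    audio = api.load_audio(
        wav,
        max_duration=config["max_duration"],
        time_exp_fact=config["time_expansion"],
        target_samp_rate=config["target_samp_rate"],
    )
    spec = api.generate_spectrogram(audio, config=config)

    # draw the spectrogram with the detection boxes overlaid
    fig, ax = plt.subplots(figsize=(15, 4), dpi=100)
    plot.spectrogram_with_detections(spec, detections, ax=ax)
    plt.show()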
@@ -126,7 +89,6 @@ def generate_results_image(audio_file, anns):
     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
     w, h = fig.canvas.get_width_height()
     im = data.reshape((int(h), int(w), -1))
-
     return im
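
The unchanged tail of generate_results_image is the standard Matplotlib figure-to-array conversion. Wrapped as a standalone helper (a sketch; the name fig_to_image is ours, and fig.canvas.draw() is the render step the buffer read depends on):

    import numpy as np

    def fig_to_image(fig):
        """Render a Matplotlib figure into an (H, W, 3) uint8 RGB array."""
        fig.canvas.draw()  # ensure the Agg buffer has been rendered
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        w, h = fig.canvas.get_width_height()
        return data.reshape((int(h), int(w), -1))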
@@ -140,7 +102,7 @@ descr_txt = (
 gr.Interface(
     fn=make_prediction,
     inputs=[
-        gr.Audio(source="upload", type="filepath", optional=True),
+        gr.Audio(source="upload", type="filepath"),
         gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
     ],
     outputs=[df, gr.Image(label="Visualisation")],
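
The optional=True flag disappears here because recent Gradio releases deprecated it on input components; an empty upload now simply reaches make_prediction as None. For completeness, a sketch of serving this interface in the Gradio 3.x style (demo is our name; it assumes make_prediction and the df component defined earlier in app.py):

    import gradio as gr

    demo = gr.Interface(
        fn=make_prediction,
        inputs=[
            gr.Audio(source="upload", type="filepath"),
            gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
        ],
        outputs=[df, gr.Image(label="Visualisation")],
    )

    if __name__ == "__main__":
        demo.launch()  # start the local web server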