Mirror of https://github.com/macaodha/batdetect2.git, synced 2025-06-29 14:41:58 +02:00
Updated gradio app to new api
commit 0a7ad18193 (parent b8bbfe8ad4)
app.py
@@ -3,21 +3,10 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 
-import batdetect2.utils.audio_utils as au
-import batdetect2.utils.detector_utils as du
-import batdetect2.utils.plot_utils as viz
-
-# setup the arguments
-args = {}
-args = du.get_default_run_config()
-args["detection_threshold"] = 0.3
-args["time_expansion_factor"] = 1
-args["model_path"] = "models/Net2DFast_UK_same.pth.tar"
-max_duration = 2.0
-
-# load the model
-model, params = du.load_model(args["model_path"])
+from batdetect2 import api, plot
+
+MAX_DURATION = 2
+DETECTION_THRESHOLD = 0.3
 
 df = gr.Dataframe(
     headers=["species", "time", "detection_prob", "species_prob"],
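
Note: the two constants above replace the old args/params/model bookkeeping; the model and its defaults now live behind the top-level batdetect2.api module. A minimal sketch of the equivalent standalone usage, built only from the calls visible in this diff and one of the bundled example recordings; the exact result keys are assumptions taken from the code further down:

    from batdetect2 import api

    # build a run configuration from the package defaults, overriding only what the app needs
    config = api.get_config(detection_threshold=0.3, max_duration=2)

    # run detection on one of the example recordings shipped with the repo
    results = api.process_file(
        "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav",
        config=config,
    )

    # per-call detections are read from results["pred_dict"]["annotation"], as in make_prediction below
    for pred in results["pred_dict"]["annotation"]:
        print(pred["class"], pred["start_time"], pred["class_prob"])
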
@@ -28,97 +17,71 @@ df = gr.Dataframe(
 )
 
 examples = [
-    ["example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav", 0.3],
-    ["example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav", 0.3],
-    ["example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav", 0.3],
+    [
+        "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav",
+        DETECTION_THRESHOLD,
+    ],
+    [
+        "example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav",
+        DETECTION_THRESHOLD,
+    ],
+    [
+        "example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav",
+        DETECTION_THRESHOLD,
+    ],
 ]
 
 
-def make_prediction(file_name=None, detection_threshold=0.3):
-    if file_name is not None:
-        audio_file = file_name
-    else:
-        return "You must provide an input audio file."
-
-    if detection_threshold is not None and detection_threshold != "":
-        args["detection_threshold"] = float(detection_threshold)
-
-    run_config = {
-        **params,
-        **args,
-        "max_duration": max_duration,
-    }
-
-    # process the file to generate predictions
-    results = du.process_file(
-        audio_file,
-        model,
-        run_config,
+def make_prediction(file_name, detection_threshold=DETECTION_THRESHOLD):
+    # configure the model run
+    run_config = api.get_config(
+        detection_threshold=detection_threshold,
+        max_duration=MAX_DURATION,
     )
 
-    anns = [ann for ann in results["pred_dict"]["annotation"]]
-    clss = [aa["class"] for aa in anns]
-    st_time = [aa["start_time"] for aa in anns]
-    cls_prob = [aa["class_prob"] for aa in anns]
-    det_prob = [aa["det_prob"] for aa in anns]
-    data = {
-        "species": clss,
-        "time": st_time,
-        "detection_prob": det_prob,
-        "species_prob": cls_prob,
-    }
+    # process the file to generate predictions
+    results = api.process_file(file_name, config=run_config)
 
-    df = pd.DataFrame(data=data)
-    im = generate_results_image(audio_file, anns)
+    # extract the detections
+    detections = results["pred_dict"]["annotation"]
+
+    # create a dataframe of the predictions
+    df = pd.DataFrame(
+        [
+            {
+                "species": pred["class"],
+                "time": pred["start_time"],
+                "detection_prob": pred["class_prob"],
+                "species_prob": pred["class_prob"],
+            }
+            for pred in detections
+        ]
+    )
+    im = generate_results_image(file_name, detections, run_config)
 
     return [df, im]
 
 
-def generate_results_image(audio_file, anns):
-
-    # load audio
-    sampling_rate, audio = au.load_audio(
-        audio_file,
-        args["time_expansion_factor"],
-        params["target_samp_rate"],
-        params["scale_raw_audio"],
-        max_duration=max_duration,
+def generate_results_image(file_name, detections, config):
+    audio = api.load_audio(
+        file_name,
+        max_duration=config["max_duration"],
+        time_exp_fact=config["time_expansion"],
+        target_samp_rate=config["target_samp_rate"],
     )
-    duration = audio.shape[0] / sampling_rate
 
-    # generate spec
-    spec, spec_viz = au.generate_spectrogram(
-        audio, sampling_rate, params, True, False
-    )
+    spec = api.generate_spectrogram(audio, config=config)
 
     # create fig
    plt.close("all")
    fig = plt.figure(
        1,
-        figsize=(spec.shape[1] / 100, spec.shape[0] / 100),
+        figsize=(15, 4),
        dpi=100,
        frameon=False,
    )
-    spec_duration = au.x_coords_to_time(
-        spec.shape[1],
-        sampling_rate,
-        params["fft_win_length"],
-        params["fft_overlap"],
-    )
-    viz.create_box_image(
-        spec,
-        fig,
-        anns,
-        0,
-        spec_duration,
-        spec_duration,
-        params,
-        spec.max() * 1.1,
-        False,
-        True,
-    )
-    plt.ylabel("Freq - kHz")
-    plt.xlabel("Time - secs")
+    ax = fig.add_subplot(111)
+    plot.spectrogram_with_detections(spec, detections, ax=ax)
     plt.tight_layout()
 
     # convert fig to image
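
The new plotting path composes api.load_audio, api.generate_spectrogram and plot.spectrogram_with_detections. A sketch of the same pipeline outside the app, assuming the object returned by api.get_config can be indexed like a dict with the keys used above ("max_duration", "time_expansion", "target_samp_rate"):

    import matplotlib.pyplot as plt
    from batdetect2 import api, plot

    wav = "example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav"
    config = api.get_config(detection_threshold=0.3, max_duration=2)
    detections = api.process_file(wav, config=config)["pred_dict"]["annotation"]

    # reload the audio with the same settings the detector used
    audio = api.load_audio(
        wav,
        max_duration=config["max_duration"],
        time_exp_fact=config["time_expansion"],
        target_samp_rate=config["target_samp_rate"],
    )
    spec = api.generate_spectrogram(audio, config=config)

    # draw the spectrogram with detection boxes on a fresh axis and save it
    fig, ax = plt.subplots(figsize=(15, 4))
    plot.spectrogram_with_detections(spec, detections, ax=ax)
    fig.savefig("detections.png")
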
@@ -126,7 +89,6 @@ def generate_results_image(audio_file, anns):
     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
     w, h = fig.canvas.get_width_height()
     im = data.reshape((int(h), int(w), -1))
-
     return im
 
 
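
This unchanged tail of generate_results_image turns the matplotlib figure into a NumPy array so gr.Image can render it. The same steps written as a small helper (a sketch, not code from the repo; the explicit draw() call is an addition so the RGB buffer exists before it is read):

    import numpy as np

    def fig_to_image(fig):
        # render the figure, then copy its RGB buffer into an (H, W, 3) uint8 array
        fig.canvas.draw()
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        w, h = fig.canvas.get_width_height()
        return data.reshape((int(h), int(w), -1))
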
@@ -140,7 +102,7 @@ descr_txt = (
 gr.Interface(
     fn=make_prediction,
     inputs=[
-        gr.Audio(source="upload", type="filepath", optional=True),
+        gr.Audio(source="upload", type="filepath"),
         gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
     ],
     outputs=[df, gr.Image(label="Visualisation")],
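
Only the gr.Audio input changes in this hunk (the optional flag is dropped); the rest of the interface wiring is unchanged context. For orientation, a sketch of how the pieces plug together, where passing examples and descr_txt and the final launch() call are assumptions about the surrounding, unshown parts of app.py:

    import gradio as gr

    gr.Interface(
        fn=make_prediction,
        inputs=[
            gr.Audio(source="upload", type="filepath"),
            gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
        ],
        outputs=[df, gr.Image(label="Visualisation")],
        examples=examples,      # assumed: the examples list defined earlier in app.py
        description=descr_txt,  # assumed: the description text referenced in the hunk header
    ).launch()
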