mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Updated gradio app to new api
This commit is contained in:
parent
b8bbfe8ad4
commit
0a7ad18193
136
app.py
136
app.py
@ -3,21 +3,10 @@ import matplotlib.pyplot as plt
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
import batdetect2.utils.audio_utils as au
|
from batdetect2 import api, plot
|
||||||
import batdetect2.utils.detector_utils as du
|
|
||||||
import batdetect2.utils.plot_utils as viz
|
|
||||||
|
|
||||||
# setup the arguments
|
|
||||||
args = {}
|
|
||||||
args = du.get_default_run_config()
|
|
||||||
args["detection_threshold"] = 0.3
|
|
||||||
args["time_expansion_factor"] = 1
|
|
||||||
args["model_path"] = "models/Net2DFast_UK_same.pth.tar"
|
|
||||||
max_duration = 2.0
|
|
||||||
|
|
||||||
# load the model
|
|
||||||
model, params = du.load_model(args["model_path"])
|
|
||||||
|
|
||||||
|
MAX_DURATION = 2
|
||||||
|
DETECTION_THRESHOLD = 0.3
|
||||||
|
|
||||||
df = gr.Dataframe(
|
df = gr.Dataframe(
|
||||||
headers=["species", "time", "detection_prob", "species_prob"],
|
headers=["species", "time", "detection_prob", "species_prob"],
|
||||||
@ -28,97 +17,71 @@ df = gr.Dataframe(
|
|||||||
)
|
)
|
||||||
|
|
||||||
examples = [
|
examples = [
|
||||||
["example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav", 0.3],
|
[
|
||||||
["example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav", 0.3],
|
"example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav",
|
||||||
["example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav", 0.3],
|
DETECTION_THRESHOLD,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav",
|
||||||
|
DETECTION_THRESHOLD,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
"example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav",
|
||||||
|
DETECTION_THRESHOLD,
|
||||||
|
],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def make_prediction(file_name=None, detection_threshold=0.3):
|
def make_prediction(file_name, detection_threshold=DETECTION_THRESHOLD):
|
||||||
if file_name is not None:
|
# configure the model run
|
||||||
audio_file = file_name
|
run_config = api.get_config(
|
||||||
else:
|
detection_threshold=detection_threshold,
|
||||||
return "You must provide an input audio file."
|
max_duration=MAX_DURATION,
|
||||||
|
|
||||||
if detection_threshold is not None and detection_threshold != "":
|
|
||||||
args["detection_threshold"] = float(detection_threshold)
|
|
||||||
|
|
||||||
run_config = {
|
|
||||||
**params,
|
|
||||||
**args,
|
|
||||||
"max_duration": max_duration,
|
|
||||||
}
|
|
||||||
|
|
||||||
# process the file to generate predictions
|
|
||||||
results = du.process_file(
|
|
||||||
audio_file,
|
|
||||||
model,
|
|
||||||
run_config,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
anns = [ann for ann in results["pred_dict"]["annotation"]]
|
# process the file to generate predictions
|
||||||
clss = [aa["class"] for aa in anns]
|
results = api.process_file(file_name, config=run_config)
|
||||||
st_time = [aa["start_time"] for aa in anns]
|
|
||||||
cls_prob = [aa["class_prob"] for aa in anns]
|
|
||||||
det_prob = [aa["det_prob"] for aa in anns]
|
|
||||||
data = {
|
|
||||||
"species": clss,
|
|
||||||
"time": st_time,
|
|
||||||
"detection_prob": det_prob,
|
|
||||||
"species_prob": cls_prob,
|
|
||||||
}
|
|
||||||
|
|
||||||
df = pd.DataFrame(data=data)
|
# extract the detections
|
||||||
im = generate_results_image(audio_file, anns)
|
detections = results["pred_dict"]["annotation"]
|
||||||
|
|
||||||
|
# create a dataframe of the predictions
|
||||||
|
df = pd.DataFrame(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"species": pred["class"],
|
||||||
|
"time": pred["start_time"],
|
||||||
|
"detection_prob": pred["det_prob"],
|
||||||
|
"species_prob": pred["class_prob"],
|
||||||
|
}
|
||||||
|
for pred in detections
|
||||||
|
]
|
||||||
|
)
|
||||||
|
im = generate_results_image(file_name, detections, run_config)
|
||||||
|
|
||||||
return [df, im]
|
return [df, im]
|
||||||
|
|
||||||
|
|
||||||
def generate_results_image(audio_file, anns):
|
def generate_results_image(file_name, detections, config):
|
||||||
|
audio = api.load_audio(
|
||||||
# load audio
|
file_name,
|
||||||
sampling_rate, audio = au.load_audio(
|
max_duration=config["max_duration"],
|
||||||
audio_file,
|
time_exp_fact=config["time_expansion"],
|
||||||
args["time_expansion_factor"],
|
target_samp_rate=config["target_samp_rate"],
|
||||||
params["target_samp_rate"],
|
|
||||||
params["scale_raw_audio"],
|
|
||||||
max_duration=max_duration,
|
|
||||||
)
|
)
|
||||||
duration = audio.shape[0] / sampling_rate
|
|
||||||
|
|
||||||
# generate spec
|
spec = api.generate_spectrogram(audio, config=config)
|
||||||
spec, spec_viz = au.generate_spectrogram(
|
|
||||||
audio, sampling_rate, params, True, False
|
|
||||||
)
|
|
||||||
|
|
||||||
# create fig
|
# create fig
|
||||||
plt.close("all")
|
plt.close("all")
|
||||||
fig = plt.figure(
|
fig = plt.figure(
|
||||||
1,
|
1,
|
||||||
figsize=(spec.shape[1] / 100, spec.shape[0] / 100),
|
figsize=(15, 4),
|
||||||
dpi=100,
|
dpi=100,
|
||||||
frameon=False,
|
frameon=False,
|
||||||
)
|
)
|
||||||
spec_duration = au.x_coords_to_time(
|
ax = fig.add_subplot(111)
|
||||||
spec.shape[1],
|
plot.spectrogram_with_detections(spec, detections, ax=ax)
|
||||||
sampling_rate,
|
|
||||||
params["fft_win_length"],
|
|
||||||
params["fft_overlap"],
|
|
||||||
)
|
|
||||||
viz.create_box_image(
|
|
||||||
spec,
|
|
||||||
fig,
|
|
||||||
anns,
|
|
||||||
0,
|
|
||||||
spec_duration,
|
|
||||||
spec_duration,
|
|
||||||
params,
|
|
||||||
spec.max() * 1.1,
|
|
||||||
False,
|
|
||||||
True,
|
|
||||||
)
|
|
||||||
plt.ylabel("Freq - kHz")
|
|
||||||
plt.xlabel("Time - secs")
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
# convert fig to image
|
# convert fig to image
|
||||||
@ -126,7 +89,6 @@ def generate_results_image(audio_file, anns):
|
|||||||
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
||||||
w, h = fig.canvas.get_width_height()
|
w, h = fig.canvas.get_width_height()
|
||||||
im = data.reshape((int(h), int(w), -1))
|
im = data.reshape((int(h), int(w), -1))
|
||||||
|
|
||||||
return im
|
return im
|
||||||
|
|
||||||
|
|
||||||
@ -140,7 +102,7 @@ descr_txt = (
|
|||||||
gr.Interface(
|
gr.Interface(
|
||||||
fn=make_prediction,
|
fn=make_prediction,
|
||||||
inputs=[
|
inputs=[
|
||||||
gr.Audio(source="upload", type="filepath", optional=True),
|
gr.Audio(source="upload", type="filepath"),
|
||||||
gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
|
gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
|
||||||
],
|
],
|
||||||
outputs=[df, gr.Image(label="Visualisation")],
|
outputs=[df, gr.Image(label="Visualisation")],
|
||||||
|
Loading…
Reference in New Issue
Block a user