import json
import os
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

import bat_detect.detector.compute_features as feats
import bat_detect.detector.post_process as pp
import bat_detect.utils.audio_utils as au
from bat_detect.detector import models


def get_default_bd_args():
    args = {}
    args["detection_threshold"] = 0.001
    args["time_expansion_factor"] = 1
    args["audio_dir"] = ""
    args["ann_dir"] = ""
    args["spec_slices"] = False
    args["chunk_size"] = 3
    args["spec_features"] = False
    args["cnn_features"] = False
    args["quiet"] = True
    args["save_preds_if_empty"] = True
    # joining with "" ensures a trailing path separator on a non-empty ann_dir
    args["ann_dir"] = os.path.join(args["ann_dir"], "")
    return args


def get_audio_files(ip_dir):
    # recursively collect the paths of all .wav files under ip_dir
    matches = []
    for root, _, filenames in os.walk(ip_dir):
        for filename in filenames:
            if filename.lower().endswith(".wav"):
                matches.append(os.path.join(root, filename))
    return matches


def load_model(model_path, load_weights=True):
    # load the checkpoint onto whichever device is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if os.path.isfile(model_path):
        net_params = torch.load(model_path, map_location=device)
    else:
        print("Error: model not found.")
        sys.exit(1)

    params = net_params["params"]
    params["device"] = device

    # all three architectures share the same constructor signature
    model_classes = {
        "Net2DFast": models.Net2DFast,
        "Net2DFastNoAttn": models.Net2DFastNoAttn,
        "Net2DFastNoCoordConv": models.Net2DFastNoCoordConv,
    }
    if params["model_name"] not in model_classes:
        print("Error: unknown model.")
        sys.exit(1)
    model = model_classes[params["model_name"]](
        params["num_filters"],
        num_classes=len(params["class_names"]),
        emb_dim=params["emb_dim"],
        ip_height=params["ip_height"],
        resize_factor=params["resize_factor"],
    )

    if load_weights:
        model.load_state_dict(net_params["state_dict"])

    model = model.to(params["device"])
    model.eval()

    return model, params


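# Example usage (a minimal sketch; the checkpoint path is hypothetical and
# should be replaced with the path to a real trained model file):
#
#   model, params = load_model("models/Net2DFast_UK_same.pth.tar")
#   print(params["model_name"], len(params["class_names"]))

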
def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
    predictions_m = {}
    num_preds = np.sum([len(pred["det_probs"]) for pred in predictions])

    if num_preds > 0:
        for kk in predictions[0].keys():
            predictions_m[kk] = np.hstack(
                [pred[kk] for pred in predictions if pred["det_probs"].shape[0] > 0]
            )
    else:
        # hack for the case where no calls are detected, as we still need
        # some of the key names in the dict
        predictions_m = predictions[0]

    if len(spec_feats) > 0:
        spec_feats = np.vstack(spec_feats)
    if len(cnn_feats) > 0:
        cnn_feats = np.vstack(cnn_feats)
    return predictions_m, spec_feats, cnn_feats, spec_slices


def convert_results(
    file_id,
    time_exp,
    duration,
    params,
    predictions,
    spec_feats,
    cnn_feats,
    spec_slices,
):
    # create a single dictionary - this is the format used by the annotation tool
    pred_dict = {}
    pred_dict["id"] = file_id
    pred_dict["annotated"] = False
    pred_dict["issues"] = False
    pred_dict["notes"] = "Automatically generated."
    pred_dict["time_exp"] = time_exp
    pred_dict["duration"] = round(duration, 4)
    pred_dict["annotation"] = []

    class_prob_best = predictions["class_probs"].max(0)
    class_ind_best = predictions["class_probs"].argmax(0)
    class_overall = pp.overall_class_pred(
        predictions["det_probs"], predictions["class_probs"]
    )
    pred_dict["class_name"] = params["class_names"][np.argmax(class_overall)]

    for ii in range(predictions["det_probs"].shape[0]):
        res = {}
        res["start_time"] = round(float(predictions["start_times"][ii]), 4)
        res["end_time"] = round(float(predictions["end_times"][ii]), 4)
        res["low_freq"] = int(predictions["low_freqs"][ii])
        res["high_freq"] = int(predictions["high_freqs"][ii])
        res["class"] = str(params["class_names"][int(class_ind_best[ii])])
        res["class_prob"] = round(float(class_prob_best[ii]), 3)
        res["det_prob"] = round(float(predictions["det_probs"][ii]), 3)
        res["individual"] = "-1"
        res["event"] = "Echolocation"
        pred_dict["annotation"].append(res)

    # combine into final results dictionary
    results = {}
    results["pred_dict"] = pred_dict
    if len(spec_feats) > 0:
        results["spec_feats"] = spec_feats
        results["spec_feat_names"] = feats.get_feature_names()
    if len(cnn_feats) > 0:
        results["cnn_feats"] = cnn_feats
        results["cnn_feat_names"] = [str(ii) for ii in range(cnn_feats.shape[1])]
    if len(spec_slices) > 0:
        results["spec_slices"] = spec_slices

    return results


def save_results_to_file(results, op_path):
    # make the output directory if it does not exist
    # (guard against op_path having no directory component)
    op_dir = os.path.dirname(op_path)
    if op_dir != "" and not os.path.isdir(op_dir):
        os.makedirs(op_dir)

    # save csv file - if there are predictions
    result_list = results["pred_dict"]["annotation"]
    df = pd.DataFrame(result_list)
    df["file_name"] = [results["pred_dict"]["id"]] * len(result_list)
    df.index.name = "id"
    if "class_prob" in df.columns:
        df = df[
            [
                "det_prob",
                "start_time",
                "end_time",
                "high_freq",
                "low_freq",
                "class",
                "class_prob",
            ]
        ]
    df.to_csv(op_path + ".csv", sep=",")

    # save spectrogram features
    if "spec_feats" in results.keys():
        df = pd.DataFrame(
            results["spec_feats"], columns=results["spec_feat_names"]
        )
        df.to_csv(
            op_path + "_spec_features.csv",
            sep=",",
            index=False,
            float_format="%.5f",
        )

    # save CNN features
    if "cnn_feats" in results.keys():
        df = pd.DataFrame(
            results["cnn_feats"], columns=results["cnn_feat_names"]
        )
        df.to_csv(
            op_path + "_cnn_features.csv",
            sep=",",
            index=False,
            float_format="%.5f",
        )

    # save json file
    with open(op_path + ".json", "w") as da:
        json.dump(results["pred_dict"], da, indent=2, sort_keys=True)


def compute_spectrogram(audio, sampling_rate, params, return_np=False):
    # pad audio so it is evenly divisible by downsampling factors
    duration = audio.shape[0] / float(sampling_rate)
    audio = au.pad_audio(
        audio,
        sampling_rate,
        params["fft_win_length"],
        params["fft_overlap"],
        params["resize_factor"],
        params["spec_divide_factor"],
    )

    # generate spectrogram
    spec, _ = au.generate_spectrogram(audio, sampling_rate, params)

    # convert to pytorch and add batch and channel dimensions
    spec = torch.from_numpy(spec).to(params["device"])
    spec = spec.unsqueeze(0).unsqueeze(0)

    # resize the spec
    rs = params["resize_factor"]
    spec_op_shape = (int(params["spec_height"] * rs), int(spec.shape[-1] * rs))
    spec = F.interpolate(
        spec, size=spec_op_shape, mode="bilinear", align_corners=False
    )

    if return_np:
        spec_np = spec[0, 0, :].cpu().data.numpy()
    else:
        spec_np = None

    return duration, spec, spec_np


def process_file(
    audio_file,
    model,
    params,
    args,
    time_exp=None,
    top_n=5,
    return_raw_preds=False,
    max_duration=False,
):
    # store temporary results here
    predictions = []
    spec_feats = []
    cnn_feats = []
    spec_slices = []

    # get time expansion factor
    if time_exp is None:
        time_exp = args["time_expansion_factor"]

    params["detection_threshold"] = args["detection_threshold"]

    # load audio file
    sampling_rate, audio_full = au.load_audio_file(
        audio_file,
        time_exp,
        params["target_samp_rate"],
        params["scale_raw_audio"],
    )

    # clip to the maximum duration, if one is specified
    if max_duration is not False:
        max_duration = np.minimum(
            int(sampling_rate * max_duration), audio_full.shape[0]
        )
        audio_full = audio_full[:max_duration]

    duration_full = audio_full.shape[0] / float(sampling_rate)

    return_np_spec = args["spec_features"] or args["spec_slices"]

    # loop through the file and process it in chunks
    # TODO fix so that it overlaps correctly and takes care of duplicate detections at chunk borders
    num_chunks = int(np.ceil(duration_full / args["chunk_size"]))
    for chunk_id in range(num_chunks):

        # select the current chunk of audio
        chunk_time = args["chunk_size"] * chunk_id
        chunk_length = int(sampling_rate * args["chunk_size"])
        start_sample = chunk_id * chunk_length
        end_sample = np.minimum(
            (chunk_id + 1) * chunk_length, audio_full.shape[0]
        )
        audio = audio_full[start_sample:end_sample]

        # compute the spectrogram for the chunk
        duration, spec, spec_np = compute_spectrogram(
            audio, sampling_rate, params, return_np_spec
        )

        # evaluate model
        with torch.no_grad():
            outputs = model(spec, return_feats=args["cnn_features"])

        # run non-max suppression
        pred_nms, features = pp.run_nms(
            outputs, params, np.array([float(sampling_rate)])
        )
        pred_nms = pred_nms[0]
        pred_nms["start_times"] += chunk_time
        pred_nms["end_times"] += chunk_time

        # strip the background class if the model predicts one
        if pred_nms["class_probs"].shape[0] > len(params["class_names"]):
            pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]

        predictions.append(pred_nms)

        # extract features - if there are any calls detected
        if pred_nms["det_probs"].shape[0] > 0:
            if args["spec_features"]:
                spec_feats.append(feats.get_feats(spec_np, pred_nms, params))

            if args["cnn_features"]:
                cnn_feats.append(features[0])

            if args["spec_slices"]:
                spec_slices.extend(
                    feats.extract_spec_slices(spec_np, pred_nms, params)
                )

    # convert the predictions into the output dictionary
    file_id = os.path.basename(audio_file)
    predictions, spec_feats, cnn_feats, spec_slices = merge_results(
        predictions, spec_feats, cnn_feats, spec_slices
    )
    results = convert_results(
        file_id,
        time_exp,
        duration_full,
        params,
        predictions,
        spec_feats,
        cnn_feats,
        spec_slices,
    )

    # summarize results
    if not args["quiet"]:
        num_detections = len(results["pred_dict"]["annotation"])
        print("{} call(s) detected above the threshold.".format(num_detections))

        # print results for the top n classes
        if num_detections > 0:
            class_overall = pp.overall_class_pred(
                predictions["det_probs"], predictions["class_probs"]
            )
            print("species name".ljust(30) + "probability present")
            for cc in np.argsort(class_overall)[::-1][:top_n]:
                print(
                    params["class_names"][cc].ljust(30)
                    + str(round(class_overall[cc], 3))
                )

    if return_raw_preds:
        return predictions
    else:
        return results
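

# Example end-to-end usage (a minimal sketch; the checkpoint and directory
# paths below are hypothetical and should be replaced with real ones):
#
#   args = get_default_bd_args()
#   args["detection_threshold"] = 0.3
#   model, params = load_model("models/Net2DFast_UK_same.pth.tar")
#   for wav_path in get_audio_files("audio/"):
#       results = process_file(wav_path, model, params, args)
#       op_name = os.path.splitext(os.path.basename(wav_path))[0]
#       save_results_to_file(results, os.path.join("results", op_name))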