# batdetect2/bat_detect/utils/detector_utils.py
import json
import os
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

import bat_detect.detector.compute_features as feats
import bat_detect.detector.post_process as pp
import bat_detect.utils.audio_utils as au
from bat_detect.detector import models

def get_default_bd_args():
    args = {}
    args["detection_threshold"] = 0.001
    args["time_expansion_factor"] = 1
    args["audio_dir"] = ""
    args["ann_dir"] = ""
    args["spec_slices"] = False
    args["chunk_size"] = 3
    args["spec_features"] = False
    args["cnn_features"] = False
    args["quiet"] = True
    args["save_preds_if_empty"] = True
    args["ann_dir"] = os.path.join(args["ann_dir"], "")
    return args

def get_audio_files(ip_dir):
    matches = []
    for root, dirnames, filenames in os.walk(ip_dir):
        for filename in filenames:
            if filename.lower().endswith(".wav"):
                matches.append(os.path.join(root, filename))
    return matches

def load_model(model_path, load_weights=True):
    """Load a trained detector checkpoint and return (model, params)."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if os.path.isfile(model_path):
        net_params = torch.load(model_path, map_location=device)
    else:
        print("Error: model not found.")
        sys.exit(1)

    params = net_params["params"]
    params["device"] = device

    if params["model_name"] == "Net2DFast":
        model = models.Net2DFast(
            params["num_filters"],
            num_classes=len(params["class_names"]),
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    elif params["model_name"] == "Net2DFastNoAttn":
        model = models.Net2DFastNoAttn(
            params["num_filters"],
            num_classes=len(params["class_names"]),
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    elif params["model_name"] == "Net2DFastNoCoordConv":
        model = models.Net2DFastNoCoordConv(
            params["num_filters"],
            num_classes=len(params["class_names"]),
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    else:
        print("Error: unknown model.")
        sys.exit(1)

    if load_weights:
        model.load_state_dict(net_params["state_dict"])

    model = model.to(params["device"])
    model.eval()
    return model, params
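
# Example usage of load_model (illustrative sketch; the checkpoint path below is a
# placeholder, not a file that ships with this module):
#
#   model, params = load_model("path/to/checkpoint.pth.tar")
#   print(params["model_name"], len(params["class_names"]))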

def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
    predictions_m = {}
    num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])

    if num_preds > 0:
        for kk in predictions[0].keys():
            predictions_m[kk] = np.hstack(
                [pp[kk] for pp in predictions if pp["det_probs"].shape[0] > 0]
            )
    else:
        # hack for the case where no calls were detected: keep the first (empty)
        # prediction dictionary so that the expected keys are still present
        predictions_m = predictions[0]

    if len(spec_feats) > 0:
        spec_feats = np.vstack(spec_feats)
    if len(cnn_feats) > 0:
        cnn_feats = np.vstack(cnn_feats)
    return predictions_m, spec_feats, cnn_feats, spec_slices

def convert_results(
    file_id,
    time_exp,
    duration,
    params,
    predictions,
    spec_feats,
    cnn_feats,
    spec_slices,
):
    # create a single dictionary - this is the format used by the annotation tool
    pred_dict = {}
    pred_dict["id"] = file_id
    pred_dict["annotated"] = False
    pred_dict["issues"] = False
    pred_dict["notes"] = "Automatically generated."
    pred_dict["time_exp"] = time_exp
    pred_dict["duration"] = round(duration, 4)
    pred_dict["annotation"] = []

    class_prob_best = predictions["class_probs"].max(0)
    class_ind_best = predictions["class_probs"].argmax(0)
    class_overall = pp.overall_class_pred(
        predictions["det_probs"], predictions["class_probs"]
    )
    pred_dict["class_name"] = params["class_names"][np.argmax(class_overall)]

    for ii in range(predictions["det_probs"].shape[0]):
        res = {}
        res["start_time"] = round(float(predictions["start_times"][ii]), 4)
        res["end_time"] = round(float(predictions["end_times"][ii]), 4)
        res["low_freq"] = int(predictions["low_freqs"][ii])
        res["high_freq"] = int(predictions["high_freqs"][ii])
        res["class"] = str(params["class_names"][int(class_ind_best[ii])])
        res["class_prob"] = round(float(class_prob_best[ii]), 3)
        res["det_prob"] = round(float(predictions["det_probs"][ii]), 3)
        res["individual"] = "-1"
        res["event"] = "Echolocation"
        pred_dict["annotation"].append(res)

    # combine into final results dictionary
    results = {}
    results["pred_dict"] = pred_dict
    if len(spec_feats) > 0:
        results["spec_feats"] = spec_feats
        results["spec_feat_names"] = feats.get_feature_names()
    if len(cnn_feats) > 0:
        results["cnn_feats"] = cnn_feats
        results["cnn_feat_names"] = [str(ii) for ii in range(cnn_feats.shape[1])]
    if len(spec_slices) > 0:
        results["spec_slices"] = spec_slices
    return results

def save_results_to_file(results, op_path):
    # make the output directory if it does not exist
    op_dir = os.path.dirname(op_path)
    if op_dir != "" and not os.path.isdir(op_dir):
        os.makedirs(op_dir)

    # save csv file - if there are predictions
    result_list = [res for res in results["pred_dict"]["annotation"]]
    df = pd.DataFrame(result_list)
    df["file_name"] = [results["pred_dict"]["id"]] * len(result_list)
    df.index.name = "id"
    if "class_prob" in df.columns:
        df = df[
            [
                "det_prob",
                "start_time",
                "end_time",
                "high_freq",
                "low_freq",
                "class",
                "class_prob",
            ]
        ]
    df.to_csv(op_path + ".csv", sep=",")

    # save features
    if "spec_feats" in results.keys():
        df = pd.DataFrame(
            results["spec_feats"], columns=results["spec_feat_names"]
        )
        df.to_csv(
            op_path + "_spec_features.csv",
            sep=",",
            index=False,
            float_format="%.5f",
        )
    if "cnn_feats" in results.keys():
        df = pd.DataFrame(
            results["cnn_feats"], columns=results["cnn_feat_names"]
        )
        df.to_csv(
            op_path + "_cnn_features.csv",
            sep=",",
            index=False,
            float_format="%.5f",
        )

    # save json file
    with open(op_path + ".json", "w") as da:
        json.dump(results["pred_dict"], da, indent=2, sort_keys=True)
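
# save_results_to_file writes up to four files next to op_path: op_path + ".csv"
# (one row per detection), op_path + ".json" (the annotation-tool dictionary), and,
# when the corresponding features are present in `results`, op_path +
# "_spec_features.csv" and op_path + "_cnn_features.csv".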

def compute_spectrogram(audio, sampling_rate, params, return_np=False):
    # pad audio so it is evenly divisible by downsampling factors
    duration = audio.shape[0] / float(sampling_rate)
    audio = au.pad_audio(
        audio,
        sampling_rate,
        params["fft_win_length"],
        params["fft_overlap"],
        params["resize_factor"],
        params["spec_divide_factor"],
    )

    # generate spectrogram
    spec, _ = au.generate_spectrogram(audio, sampling_rate, params)

    # convert to pytorch
    spec = torch.from_numpy(spec).to(params["device"])
    spec = spec.unsqueeze(0).unsqueeze(0)

    # resize the spec
    rs = params["resize_factor"]
    spec_op_shape = (int(params["spec_height"] * rs), int(spec.shape[-1] * rs))
    spec = F.interpolate(
        spec, size=spec_op_shape, mode="bilinear", align_corners=False
    )

    if return_np:
        spec_np = spec[0, 0, :].cpu().data.numpy()
    else:
        spec_np = None

    return duration, spec, spec_np
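
# Note: compute_spectrogram returns the duration of the *unpadded* audio, the
# resized spectrogram as a (1, 1, H, W) tensor on params["device"], and, when
# return_np=True, a NumPy copy of that tensor (otherwise None).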

def process_file(
    audio_file,
    model,
    params,
    args,
    time_exp=None,
    top_n=5,
    return_raw_preds=False,
    max_duration=False,
):
    """Run the detector on a single audio file and return the results dictionary."""
    # store temporary results here
    predictions = []
    spec_feats = []
    cnn_feats = []
    spec_slices = []

    # get time expansion factor
    if time_exp is None:
        time_exp = args["time_expansion_factor"]

    params["detection_threshold"] = args["detection_threshold"]

    # load audio file
    sampling_rate, audio_full = au.load_audio_file(
        audio_file,
        time_exp,
        params["target_samp_rate"],
        params["scale_raw_audio"],
    )

    # clip to the maximum duration, if one is specified
    if max_duration is not False:
        max_duration = np.minimum(
            int(sampling_rate * max_duration), audio_full.shape[0]
        )
        audio_full = audio_full[:max_duration]

    duration_full = audio_full.shape[0] / float(sampling_rate)
    return_np_spec = args["spec_features"] or args["spec_slices"]

    # loop through the file and process it in chunks
    # TODO fix so that chunks overlap correctly and duplicate detections at
    # chunk borders are handled
    num_chunks = int(np.ceil(duration_full / args["chunk_size"]))
    for chunk_id in range(num_chunks):
        # chunk start time and sample range
        chunk_time = args["chunk_size"] * chunk_id
        chunk_length = int(sampling_rate * args["chunk_size"])
        start_sample = chunk_id * chunk_length
        end_sample = np.minimum(
            (chunk_id + 1) * chunk_length, audio_full.shape[0]
        )
        audio = audio_full[start_sample:end_sample]

        # compute spectrogram for this chunk
        duration, spec, spec_np = compute_spectrogram(
            audio, sampling_rate, params, return_np_spec
        )

        # evaluate model
        with torch.no_grad():
            outputs = model(spec, return_feats=args["cnn_features"])

        # run non-max suppression
        pred_nms, features = pp.run_nms(
            outputs, params, np.array([float(sampling_rate)])
        )
        pred_nms = pred_nms[0]
        pred_nms["start_times"] += chunk_time
        pred_nms["end_times"] += chunk_time

        # if we have a background class, drop it from the class probabilities
        if pred_nms["class_probs"].shape[0] > len(params["class_names"]):
            pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]

        predictions.append(pred_nms)

        # extract features - if there are any calls detected
        if pred_nms["det_probs"].shape[0] > 0:
            if args["spec_features"]:
                spec_feats.append(feats.get_feats(spec_np, pred_nms, params))
            if args["cnn_features"]:
                cnn_feats.append(features[0])
            if args["spec_slices"]:
                spec_slices.extend(
                    feats.extract_spec_slices(spec_np, pred_nms, params)
                )

    # convert the predictions into the output dictionary
    file_id = os.path.basename(audio_file)
    predictions, spec_feats, cnn_feats, spec_slices = merge_results(
        predictions, spec_feats, cnn_feats, spec_slices
    )
    results = convert_results(
        file_id,
        time_exp,
        duration_full,
        params,
        predictions,
        spec_feats,
        cnn_feats,
        spec_slices,
    )

    # summarize results
    if not args["quiet"]:
        num_detections = len(results["pred_dict"]["annotation"])
        print("{} call(s) detected above the threshold.".format(num_detections))

    # print results for top n classes
    if not args["quiet"] and (num_detections > 0):
        class_overall = pp.overall_class_pred(
            predictions["det_probs"], predictions["class_probs"]
        )
        print("species name".ljust(30) + "probability present")
        for cc in np.argsort(class_overall)[::-1][:top_n]:
            print(
                params["class_names"][cc].ljust(30)
                + str(round(class_overall[cc], 3))
            )

    if return_raw_preds:
        return predictions
    return results
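
if __name__ == "__main__":
    # Illustrative end-to-end sketch: run this module directly with a checkpoint
    # path and a wav path, e.g. `python detector_utils.py model.pth.tar rec.wav`.
    # Both paths are placeholders supplied by the user; nothing here is a file
    # that ships with the package.
    model_path, audio_path = sys.argv[1], sys.argv[2]

    run_args = get_default_bd_args()
    run_args["quiet"] = False
    # example threshold only; the default in get_default_bd_args() is 0.001
    run_args["detection_threshold"] = 0.3

    detector, detector_params = load_model(model_path)
    file_results = process_file(audio_path, detector, detector_params, run_args)

    # write <recording>.csv and <recording>.json next to the input file
    op_path = os.path.splitext(os.path.abspath(audio_path))[0]
    save_results_to_file(file_results, op_path)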