import json
import os
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from bat_detect.detector import models
import bat_detect.detector.compute_features as feats
import bat_detect.detector.post_process as pp
import bat_detect.utils.audio_utils as au


def get_audio_files(ip_dir):
    """Recursively collect the paths of all .wav files under ip_dir."""
    matches = []
    for root, dirnames, filenames in os.walk(ip_dir):
        for filename in filenames:
            if filename.lower().endswith('.wav'):
                matches.append(os.path.join(root, filename))
    return matches


def load_model(model_path, load_weights=True):
    """Load a detector checkpoint from disk and return the network and its parameters."""

    # load the checkpoint onto whichever device is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if os.path.isfile(model_path):
        net_params = torch.load(model_path, map_location=device)
    else:
        print('Error: model not found.')
        sys.exit(1)

    params = net_params['params']
    params['device'] = device

    # build the architecture named in the checkpoint
    if params['model_name'] == 'Net2DFast':
        model = models.Net2DFast(params['num_filters'], num_classes=len(params['class_names']),
                                 emb_dim=params['emb_dim'], ip_height=params['ip_height'],
                                 resize_factor=params['resize_factor'])
    elif params['model_name'] == 'Net2DFastNoAttn':
        model = models.Net2DFastNoAttn(params['num_filters'], num_classes=len(params['class_names']),
                                       emb_dim=params['emb_dim'], ip_height=params['ip_height'],
                                       resize_factor=params['resize_factor'])
    elif params['model_name'] == 'Net2DFastNoCoordConv':
        model = models.Net2DFastNoCoordConv(params['num_filters'], num_classes=len(params['class_names']),
                                            emb_dim=params['emb_dim'], ip_height=params['ip_height'],
                                            resize_factor=params['resize_factor'])
    else:
        print('Error: unknown model.')
        sys.exit(1)

    if load_weights:
        model.load_state_dict(net_params['state_dict'])

    model = model.to(params['device'])
    model.eval()

    return model, params
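
# A minimal usage sketch for load_model (the checkpoint path below is only an
# illustrative assumption; point it at whichever .pth.tar file ships with your install):
#
#   model, params = load_model('models/Net2DFast_UK_same.pth.tar')
#   print(params['model_name'], len(params['class_names']), 'classes')
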
def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
    """Concatenate the per-chunk predictions and features into single arrays."""

    predictions_m = {}
    num_preds = np.sum([len(p['det_probs']) for p in predictions])

    if num_preds > 0:
        for kk in predictions[0].keys():
            predictions_m[kk] = np.hstack([p[kk] for p in predictions if p['det_probs'].shape[0] > 0])
    else:
        # hack for the case where no calls are detected, as we still need the key names in the dict
        predictions_m = predictions[0]

    if len(spec_feats) > 0:
        spec_feats = np.vstack(spec_feats)
    if len(cnn_feats) > 0:
        cnn_feats = np.vstack(cnn_feats)
    return predictions_m, spec_feats, cnn_feats, spec_slices


def convert_results(file_id, time_exp, duration, params, predictions, spec_feats, cnn_feats, spec_slices):
    """Package the raw predictions into the single dictionary format used by the annotation tool."""

    pred_dict = {}
    pred_dict['id'] = file_id
    pred_dict['annotated'] = False
    pred_dict['issues'] = False
    pred_dict['notes'] = 'Automatically generated.'
    pred_dict['time_exp'] = time_exp
    pred_dict['duration'] = round(duration, 4)
    pred_dict['annotation'] = []

    # per-detection best class, plus an overall class prediction for the whole file
    class_prob_best = predictions['class_probs'].max(0)
    class_ind_best = predictions['class_probs'].argmax(0)
    class_overall = pp.overall_class_pred(predictions['det_probs'], predictions['class_probs'])
    pred_dict['class_name'] = params['class_names'][np.argmax(class_overall)]

    for ii in range(predictions['det_probs'].shape[0]):
        res = {}
        res['start_time'] = round(float(predictions['start_times'][ii]), 4)
        res['end_time'] = round(float(predictions['end_times'][ii]), 4)
        res['low_freq'] = int(predictions['low_freqs'][ii])
        res['high_freq'] = int(predictions['high_freqs'][ii])
        res['class'] = str(params['class_names'][int(class_ind_best[ii])])
        res['class_prob'] = round(float(class_prob_best[ii]), 3)
        res['det_prob'] = round(float(predictions['det_probs'][ii]), 3)
        res['individual'] = '-1'
        res['event'] = 'Echolocation'
        pred_dict['annotation'].append(res)

    # combine everything into the final results dictionary
    results = {}
    results['pred_dict'] = pred_dict
    if len(spec_feats) > 0:
        results['spec_feats'] = spec_feats
        results['spec_feat_names'] = feats.get_feature_names()
    if len(cnn_feats) > 0:
        results['cnn_feats'] = cnn_feats
        results['cnn_feat_names'] = [str(ii) for ii in range(cnn_feats.shape[1])]
    if len(spec_slices) > 0:
        results['spec_slices'] = spec_slices

    return results
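
# For reference, each entry of pred_dict['annotation'] built above looks roughly like
# the following (the values shown are illustrative assumptions, not taken from a real file):
#
#   {'start_time': 0.0123, 'end_time': 0.0187,
#    'low_freq': 35000, 'high_freq': 72000,
#    'class': 'Pipistrellus pipistrellus', 'class_prob': 0.812,
#    'det_prob': 0.934, 'individual': '-1', 'event': 'Echolocation'}
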
def save_results_to_file(results, op_path):
    """Write the detections (csv + json) and any extracted features to disk."""

    # make the output directory if it does not exist
    if not os.path.isdir(os.path.dirname(op_path)):
        os.makedirs(os.path.dirname(op_path))

    # save csv file - if there are predictions
    result_list = [res for res in results['pred_dict']['annotation']]
    df = pd.DataFrame(result_list)
    df['file_name'] = [results['pred_dict']['id']]*len(result_list)
    df.index.name = 'id'
    if 'class_prob' in df.columns:
        df = df[['det_prob', 'start_time', 'end_time', 'high_freq',
                 'low_freq', 'class', 'class_prob']]
    df.to_csv(op_path + '.csv', sep=',')

    # save the spectrogram and/or cnn features, if they were extracted
    if 'spec_feats' in results.keys():
        df = pd.DataFrame(results['spec_feats'], columns=results['spec_feat_names'])
        df.to_csv(op_path + '_spec_features.csv', sep=',', index=False, float_format='%.5f')

    if 'cnn_feats' in results.keys():
        df = pd.DataFrame(results['cnn_feats'], columns=results['cnn_feat_names'])
        df.to_csv(op_path + '_cnn_features.csv', sep=',', index=False, float_format='%.5f')

    # save the full prediction dictionary as json
    with open(op_path + '.json', 'w') as da:
        json.dump(results['pred_dict'], da, indent=2, sort_keys=True)
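
# Depending on which features were extracted, the function above writes up to four
# files: <op_path>.csv, <op_path>.json, <op_path>_spec_features.csv and
# <op_path>_cnn_features.csv. A minimal call might look like (paths are illustrative):
#
#   save_results_to_file(results, os.path.join('results', 'my_recording'))
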
def compute_spectrogram(audio, sampling_rate, params, return_np=False):
    """Compute the (optionally resized) spectrogram that is fed to the network."""

    # pad audio so it is evenly divisible by the downsampling factors
    duration = audio.shape[0] / float(sampling_rate)
    audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
                         params['fft_overlap'], params['resize_factor'],
                         params['spec_divide_factor'])

    # generate spectrogram
    spec, _ = au.generate_spectrogram(audio, sampling_rate, params)

    # convert to pytorch and add batch and channel dimensions
    spec = torch.from_numpy(spec).to(params['device'])
    spec = spec.unsqueeze(0).unsqueeze(0)

    # resize the spectrogram to the input size expected by the model
    rs = params['resize_factor']
    spec_op_shape = (int(params['spec_height']*rs), int(spec.shape[-1]*rs))
    spec = F.interpolate(spec, size=spec_op_shape, mode='bilinear', align_corners=False)

    if return_np:
        spec_np = spec[0, 0, :].cpu().data.numpy()
    else:
        spec_np = None

    return duration, spec, spec_np
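
# A minimal sketch of calling compute_spectrogram on its own (assumes `params` came
# from load_model and `sampling_rate`, `audio` from au.load_audio_file):
#
#   duration, spec, spec_np = compute_spectrogram(audio, sampling_rate, params, return_np=True)
#   # spec is a (1, 1, H, W) tensor on params['device'], already rescaled by
#   # params['resize_factor']; spec_np holds the same spectrogram as a 2D numpy array.
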
def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return_raw_preds=False):
    """Run the detector over a single audio file and return the formatted results."""

    # store temporary results here
    predictions = []
    spec_feats = []
    cnn_feats = []
    spec_slices = []

    # get time expansion factor
    if time_exp is None:
        time_exp = args['time_expansion_factor']

    params['detection_threshold'] = args['detection_threshold']

    # load audio file
    sampling_rate, audio_full = au.load_audio_file(audio_file, time_exp,
                                                   params['target_samp_rate'], params['scale_raw_audio'])
    duration_full = audio_full.shape[0] / float(sampling_rate)

    return_np_spec = args['spec_features'] or args['spec_slices']

    # loop through the file in fixed-size chunks
    # TODO fix so that chunks overlap correctly and duplicate detections at the borders are handled
    num_chunks = int(np.ceil(duration_full/args['chunk_size']))
    for chunk_id in range(num_chunks):

        # extract the current chunk of audio
        chunk_time = args['chunk_size']*chunk_id
        chunk_length = int(sampling_rate*args['chunk_size'])
        start_sample = chunk_id*chunk_length
        end_sample = np.minimum((chunk_id+1)*chunk_length, audio_full.shape[0])
        audio = audio_full[start_sample:end_sample]

        # compute the spectrogram for this chunk
        duration, spec, spec_np = compute_spectrogram(audio, sampling_rate, params, return_np_spec)

        # evaluate model
        with torch.no_grad():
            outputs = model(spec, return_feats=args['cnn_features'])

        # run non-maximum suppression and shift times back into whole-file coordinates
        pred_nms, features = pp.run_nms(outputs, params, np.array([float(sampling_rate)]))
        pred_nms = pred_nms[0]
        pred_nms['start_times'] += chunk_time
        pred_nms['end_times'] += chunk_time

        # if there is a background class, drop it from the class probabilities
        if pred_nms['class_probs'].shape[0] > len(params['class_names']):
            pred_nms['class_probs'] = pred_nms['class_probs'][:-1, :]

        predictions.append(pred_nms)

        # extract features - only if any calls were detected
        if pred_nms['det_probs'].shape[0] > 0:
            if args['spec_features']:
                spec_feats.append(feats.get_feats(spec_np, pred_nms, params))

            if args['cnn_features']:
                cnn_feats.append(features[0])

            if args['spec_slices']:
                spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms, params))

    # convert the predictions into the output dictionary
    file_id = os.path.basename(audio_file)
    predictions, spec_feats, cnn_feats, spec_slices =\
        merge_results(predictions, spec_feats, cnn_feats, spec_slices)
    results = convert_results(file_id, time_exp, duration_full, params,
                              predictions, spec_feats, cnn_feats, spec_slices)

    # summarize results
    if not args['quiet']:
        num_detections = len(results['pred_dict']['annotation'])
        print('{} call(s) detected above the threshold.'.format(num_detections))

    # print results for the top_n most likely classes
    if not args['quiet'] and (num_detections > 0):
        class_overall = pp.overall_class_pred(predictions['det_probs'], predictions['class_probs'])
        print('species name'.ljust(30) + 'probability present')
        for cc in np.argsort(class_overall)[::-1][:top_n]:
            print(params['class_names'][cc].ljust(30) + str(round(class_overall[cc], 3)))

    if return_raw_preds:
        return predictions
    else:
        return results
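
# A rough end-to-end sketch of how these helpers fit together. The argument names and
# values below are illustrative assumptions (they mirror the keys read in process_file),
# not the exact CLI of the package:
#
#   args = {'time_expansion_factor': 1, 'detection_threshold': 0.3, 'chunk_size': 2,
#           'spec_features': False, 'cnn_features': False, 'spec_slices': False,
#           'quiet': False}
#   model, params = load_model('models/Net2DFast_UK_same.pth.tar')
#   for audio_file in get_audio_files('example_data/audio'):
#       results = process_file(audio_file, model, params, args)
#       save_results_to_file(results, os.path.join('results', os.path.basename(audio_file)))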