import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import os
import sys
# Adds '..' (resolved against the current working directory) to the import
# path; this must run before the bat_detect import below.
sys.path.append(os.path.join('..'))
import bat_detect.utils.audio_utils as au


def generate_spectrogram_data(audio, sampling_rate, params, norm_type='log', smooth_spec=False):
    """Compute a frequency-cropped, optionally normalised spectrogram.

    Args:
        audio: Audio samples, passed straight to ``au.gen_mag_spectrogram``.
        sampling_rate: Sampling rate of ``audio``.
        params: Dict providing ``'max_freq'``, ``'min_freq'``,
            ``'fft_win_length'`` and ``'fft_overlap'``.
        norm_type: ``'log'`` applies ``log(1 + scale * spec)``; ``'pcen'``
            calls ``au.pcen``; any other value leaves the magnitudes as they
            are.
        smooth_spec: If true, apply a Gaussian filter (sigma 1) to the result.

    Returns:
        The 2-D spectrogram array after cropping, normalisation and optional
        smoothing.
    """
    # Frequency limits expressed as row counts (frequency * window length).
    max_freq = round(params['max_freq']*params['fft_win_length'])
    min_freq = round(params['min_freq']*params['fft_win_length'])

    # create spectrogram - numpy
    spec = au.gen_mag_spectrogram(audio, sampling_rate, params['fft_win_length'], params['fft_overlap'])

    # Zero-pad at the top when the spectrogram has fewer rows than max_freq.
    # NOTE(review): this presumably means row 0 is the highest frequency -
    # confirm against au.gen_mag_spectrogram.
    if spec.shape[0] < max_freq:
        freq_pad = max_freq - spec.shape[0]
        spec = np.vstack((np.zeros((freq_pad, spec.shape[1]), dtype=np.float32), spec))

    # Keep the last max_freq rows, minus the final min_freq rows.
    spec = spec[-max_freq:spec.shape[0]-min_freq, :]

    if norm_type == 'log':
        window = np.hanning(int(params['fft_win_length']*sampling_rate))
        log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(window)**2).sum())
        spec = np.log(1.0 + log_scaling*spec).astype(np.float32)
    elif norm_type == 'pcen':
        spec = au.pcen(spec, sampling_rate)

    if smooth_spec:
        spec = ndimage.gaussian_filter(spec, 1)

    return spec


def load_data(anns, params, class_names, smooth_spec=False, norm_type='log', extract_bg=False):
    """Cut a fixed-width spectrogram around every usable annotation.

    Args:
        anns: Iterable of per-file dicts with keys ``'file_path'``,
            ``'time_exp'`` and ``'annotation'``. Each annotation is a dict
            with ``'class'``, ``'low_freq'``, ``'high_freq'``,
            ``'start_time'`` and ``'end_time'``.
        params: Dict providing ``'target_samp_rate'``, ``'scale_raw_audio'``,
            ``'classes_to_ignore'``, ``'min_freq'``, ``'max_freq'``,
            ``'aud_pad'``, ``'fft_win_length'``, ``'fft_overlap'``,
            ``'spec_width'``, ``'resize_factor'`` and
            ``'spec_divide_factor'``.
        class_names: List of class names; an annotation is used only when its
            class is in this list and not in ``params['classes_to_ignore']``.
        smooth_spec: Forwarded to ``generate_spectrogram_data``.
        norm_type: Forwarded to ``generate_spectrogram_data``.
        extract_bg: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(specs, labels)``: the per-annotation spectrograms stacked
        along a new first axis, and an array holding each annotation's index
        in ``class_names``.

    Raises:
        ValueError: If no annotation passes the class filter.

    Note:
        ``'low_freq'`` / ``'high_freq'`` of the used annotations are clipped
        in place to ``params['min_freq']`` / ``params['max_freq']``, so the
        caller's annotation dicts are modified.
    """
    specs = []
    labels = []

    for cur_file in anns:
        sampling_rate, audio_orig = au.load_audio_file(cur_file['file_path'], cur_file['time_exp'],
                                                       params['target_samp_rate'], params['scale_raw_audio'])

        # These depend only on the file's sampling rate, so compute them once
        # per file rather than once per annotation.
        pad_samps = int(sampling_rate*params['aud_pad'])
        nfft = int(params['fft_win_length']*sampling_rate)
        noverlap = int(params['fft_overlap']*nfft)
        max_samps = params['spec_width']*(nfft - noverlap) + noverlap

        for ann in cur_file['annotation']:
            if ann['class'] in params['classes_to_ignore'] or ann['class'] not in class_names:
                continue

            # clip out of bounds (writes back into the caller's annotation)
            if ann['low_freq'] < params['min_freq']:
                ann['low_freq'] = params['min_freq']
            if ann['high_freq'] > params['max_freq']:
                ann['high_freq'] = params['max_freq']

            # load cropped audio
            start_samp_diff = int(sampling_rate*ann['start_time']) - pad_samps
            start_samp = np.maximum(0, start_samp_diff)
            # The end index deliberately over-reads (end sample * 2 + pad);
            # the clip is truncated to max_samps below.
            end_samp = np.minimum(audio_orig.shape[0],
                                  int(sampling_rate*ann['end_time'])*2 + pad_samps)
            audio = audio_orig[start_samp:end_samp]

            if start_samp_diff < 0:
                # need to pad at start if the call is at the very beginning
                audio = np.hstack((np.zeros(-start_samp_diff, dtype=np.float32), audio))

            # Force every clip to exactly max_samps samples: zero-pad short
            # clips at the end, truncate long ones.
            if max_samps > audio.shape[0]:
                audio = np.hstack((audio, np.zeros(max_samps - audio.shape[0])))
            audio = audio[:max_samps].astype(np.float32)

            audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
                                 params['fft_overlap'], params['resize_factor'],
                                 params['spec_divide_factor'])

            # generate spectrogram
            spec = generate_spectrogram_data(audio, sampling_rate, params, norm_type, smooth_spec)
            spec = spec[:, :params['spec_width']]

            specs.append(spec[np.newaxis, ...])
            labels.append(ann['class'])

    if not specs:
        # np.vstack would raise a ValueError here anyway; say why.
        raise ValueError('No annotations matched class_names; nothing to load.')

    labels = np.array([class_names.index(ll) for ll in labels])
    return np.vstack(specs), labels


def save_summary_image(specs, labels, species_names, params, op_file_name='plots/all_species.png', order=None):
    """Plot the mean spectrogram of each class on a grid and save it.

    Args:
        specs: Array of spectrograms indexed by example along the first axis.
        labels: NumPy array of integer class indices, one per example.
        species_names: List of class names; index ``i`` names class ``i``.
        params: Dict providing ``'min_freq'`` and ``'max_freq'``, used
            (divided by 1000) as the vertical extent of each image.
        op_file_name: Destination passed to ``savefig``. When it is a path,
            its parent directory is created if missing.
        order: Optional sequence of class indices giving the order in which
            classes are drawn. Defaults to ``0 .. len(species_names) - 1``.

    Raises:
        ValueError: If ``species_names`` is empty.

    Note:
        A class with no examples produces an all-NaN mean spectrogram.
        All open matplotlib figures are closed before returning.
    """
    num_classes = len(species_names)
    if num_classes == 0:
        raise ValueError('species_names must contain at least one class.')

    # takes the mean for each class
    mean_specs = [specs[np.where(labels == ii)[0], :].mean(0) for ii in range(num_classes)]

    # control the order in which classes are printed
    if order is None:
        order = np.arange(num_classes)

    max_cols = 6
    nrows = int(np.ceil(num_classes/max_cols))
    ncols = min(num_classes, max_cols)

    # squeeze=False always returns a 2-D array of axes. Without it a single
    # class yields a bare Axes object, which cannot be iterated.
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3.3, nrows*6),
                           gridspec_kw={'wspace': 0, 'hspace': 0.2}, squeeze=False)
    spec_min_max = (0, mean_specs[0].shape[1], params['min_freq']/1000, params['max_freq']/1000)

    # ravel() walks the grid row by row, left to right.
    for ii, col in enumerate(ax.ravel()):
        if ii >= num_classes:
            # unused cell in the last row
            col.axis('off')
            continue

        cls_ind = order[ii]
        col.imshow(mean_specs[cls_ind], extent=spec_min_max, cmap='plasma', aspect='equal')
        col.grid(color='w', alpha=0.3, linewidth=0.3)
        col.set_xticks([])
        col.title.set_text(str(ii+1) + ' ' + species_names[cls_ind])
        col.tick_params(axis='both', which='major', labelsize=7)

    # Create the output directory when given a filesystem path; file-like
    # targets are passed to savefig untouched.
    if isinstance(op_file_name, (str, os.PathLike)):
        op_dir = os.path.dirname(op_file_name)
        if op_dir:
            os.makedirs(op_dir, exist_ok=True)

    fig.savefig(op_file_name)
    plt.close('all')