diff --git a/README.md b/README.md index 5627ab9..a1858c6 100644 --- a/README.md +++ b/README.md @@ -21,11 +21,13 @@ You can also run this notebook locally. ### Running the model on your own data After following the above steps to install the code you can run the model on your own data by opening the command line where the code is located and typing: `python run_batdetect.py AUDIO_DIR ANN_DIR DETECTION_THRESHOLD` +e.g. +`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3` + `AUDIO_DIR` is the path on your computer to the audio wav files of interest. `ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file. -`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes: -`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3` +`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistake. There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files: `python run_batdetect.py example_data/audio/ example_data/anns/ 0.3 --spec_features` @@ -34,13 +36,14 @@ You can also specify which model to use by setting the `--model_path` argument. ### Data and annotations -The raw audio data and annotations used to train the models in the paper will be added soon. +The raw audio data and annotations used to train the models in the paper will be added soon. +The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI). ### Warning -Note the models developed and shared as part of this repository should be used with caution. -While they have been evaluated on held out audio data, great care should be taken when using the models for any form of biodiversity assessment. -Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted. +The models developed and shared as part of this repository should be used with caution. +While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment. +Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted. ### FAQ diff --git a/bat_detect/detector/parameters.py b/bat_detect/detector/parameters.py index 7b1be46..10276eb 100644 --- a/bat_detect/detector/parameters.py +++ b/bat_detect/detector/parameters.py @@ -6,8 +6,8 @@ import datetime def mk_dir(path): if not os.path.isdir(path): os.makedirs(path) - - + + def get_params(make_dirs=False, exps_dir='../../experiments/'): params = {} diff --git a/bat_detect/evaluate/evaluate_models.py b/bat_detect/evaluate/evaluate_models.py index 876c76b..0fc8ae9 100644 --- a/bat_detect/evaluate/evaluate_models.py +++ b/bat_detect/evaluate/evaluate_models.py @@ -44,22 +44,6 @@ def get_blank_annotation(ip_str): return copy.deepcopy(res), copy.deepcopy(ann) -def get_default_bd_args(): - args = {} - args['detection_threshold'] = 0.001 - args['time_expansion_factor'] = 1 - args['audio_dir'] = '' - args['ann_dir'] = '' - args['spec_slices'] = False - args['chunk_size'] = 3 - args['spec_features'] = False - args['cnn_features'] = False - args['quiet'] = True - args['save_preds_if_empty'] = True - args['ann_dir'] = os.path.join(args['ann_dir'], '') - return args - - def create_genus_mapping(gt_test, preds, class_names): # rolls the per class predictions and ground truth back up to genus level class_names_genus, cls_to_genus = np.unique([cc.split(' ')[0] for cc in class_names], return_inverse=True) @@ -555,7 +539,7 @@ if __name__ == "__main__": # if args['bd_model_path'] != '': # load model - bd_args = get_default_bd_args() + bd_args = du.get_default_bd_args() model, params_bd = du.load_model(args['bd_model_path']) # check if the class names are the same diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py index a4e63ad..ebfedd2 100644 --- a/bat_detect/utils/detector_utils.py +++ b/bat_detect/utils/detector_utils.py @@ -12,6 +12,22 @@ import bat_detect.detector.post_process as pp import bat_detect.utils.audio_utils as au +def get_default_bd_args(): + args = {} + args['detection_threshold'] = 0.001 + args['time_expansion_factor'] = 1 + args['audio_dir'] = '' + args['ann_dir'] = '' + args['spec_slices'] = False + args['chunk_size'] = 3 + args['spec_features'] = False + args['cnn_features'] = False + args['quiet'] = True + args['save_preds_if_empty'] = True + args['ann_dir'] = os.path.join(args['ann_dir'], '') + return args + + def get_audio_files(ip_dir): matches = [] diff --git a/batdetect2_notebook.ipynb b/batdetect2_notebook.ipynb index b8ae9cc..e7bf4de 100644 --- a/batdetect2_notebook.ipynb +++ b/batdetect2_notebook.ipynb @@ -56,17 +56,10 @@ "outputs": [], "source": [ "# setup the arguments\n", - "args = {}\n", + "args = du.get_default_bd_args()\n", "args['detection_threshold'] = 0.3\n", "args['time_expansion_factor'] = 1\n", - "args['model_path'] = os.path.join('models', os.path.basename(config.MODEL_PATH))\n", - "\n", - "args['cnn_features'] = False\n", - "args['spec_features'] = False\n", - "args['quiet'] = True\n", - "args['save_preds_if_empty'] = False\n", - "args['spec_slices'] = False\n", - "args['chunk_size'] = 3" + "args['model_path'] = os.path.join('models', os.path.basename(config.MODEL_PATH))" ] }, { diff --git a/requirements.txt b/requirements.txt index 427233c..5bb8e16 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ matplotlib==3.6.2 numpy==1.23.4 pandas==1.5.2 scikit_learn==1.2.0 +scipy==1.9.3 torch==1.13.0 torchaudio==0.13.0 torchvision==0.14.0 diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000..bcc4692 --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,17 @@ +This directory contains some scripts for visualizing the raw data and model outputs. + + +`gen_spec_image.py`: saves the model predictions on a spectrogram of the input audio file. +e.g. +`python gen_spec_image.py ../example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav ../models/Net2DFast_UK_same.pth.tar` + + +`gen_spec_video.py`: generates a video showing the model predictions for a file. +e.g. +`python gen_spec_video.py ../example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav ../models/Net2DFast_UK_same.pth.tar` + + + +`gen_dataset_summary_image.py`: generates an image displaying the mean spectrogram for each class in a specified dataset. +e.g. +`python gen_dataset_summary_image.py --ann_file PATH_TO_ANN/australia_TRAIN.json PATH_TO_AUDIO/audio/ ../plots/australia/` diff --git a/scripts/gen_dataset_summary_image.py b/scripts/gen_dataset_summary_image.py new file mode 100644 index 0000000..b789584 --- /dev/null +++ b/scripts/gen_dataset_summary_image.py @@ -0,0 +1,64 @@ +""" +Loads a set of annotations corresponding to a dataset and saves an image which +is the mean spectrogram for each class. +""" + +import matplotlib.pyplot as plt +import numpy as np +import os +import argparse +import sys +import viz_helpers as vz + +sys.path.append(os.path.join('..')) +import bat_detect.train.train_utils as tu +import bat_detect.detector.parameters as parameters +import bat_detect.utils.audio_utils as au +import bat_detect.train.train_split as ts + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('audio_path', type=str, help='Input directory for audio') + parser.add_argument('op_dir', type=str, + help='Path to where single annotation json file is stored') + parser.add_argument('--ann_file', type=str, + help='Path to where single annotation json file is stored') + parser.add_argument('--uk_split', type=str, default='', + help='Set as: diff or same') + parser.add_argument('--file_type', type=str, default='png', + help='Type of image to save png or pdf') + args = vars(parser.parse_args()) + + if not os.path.isdir(args['op_dir']): + os.makedirs(args['op_dir']) + + params = parameters.get_params(False) + params['smooth_spec'] = False + params['spec_width'] = 48 + params['norm_type'] = 'log' # log, pcen + params['aud_pad'] = 0.005 + classes_to_ignore = params['classes_to_ignore'] + params['generic_class'] + + + # load train annotations + if args['uk_split'] == '': + print('\nLoading:', args['ann_file'], '\n') + dataset_name = os.path.basename(args['ann_file']).replace('.json', '') + datasets = [] + datasets.append(tu.get_blank_dataset_dict(dataset_name, False, args['ann_file'], args['audio_path'])) + else: + # load uk data - special case + print('\nLoading:', args['uk_split'], '\n') + dataset_name = 'uk_' + args['uk_split'] # should be uk_diff, or uk_same + datasets, _ = ts.get_train_test_data(args['ann_file'], args['audio_path'], args['uk_split'], load_extra=False) + + anns, class_names, _ = tu.load_set_of_anns(datasets, classes_to_ignore, params['events_of_interest']) + class_names_order = range(len(class_names)) + + x_train, y_train = vz.load_data(anns, params, class_names, smooth_spec=params['smooth_spec'], norm_type=params['norm_type']) + + op_file_name = os.path.join(args['op_dir'], dataset_name + '.' + args['file_type']) + vz.save_summary_image(x_train, y_train, class_names, params, op_file_name, class_names_order) + print('\nImage saved to:', op_file_name) diff --git a/scripts/gen_spec_image.py b/scripts/gen_spec_image.py new file mode 100644 index 0000000..182f5bf --- /dev/null +++ b/scripts/gen_spec_image.py @@ -0,0 +1,116 @@ +""" +Visualize predctions on top of spectrogram. + +Will save images with: +1) raw spectrogram +2) spectrogram with GT boxes +3) spectrogram with predicted boxes +""" + +import numpy as np +import sys +import os +import argparse +import matplotlib.pyplot as plt +import json + +sys.path.append(os.path.join('..')) +import bat_detect.evaluate.evaluate_models as evlm +import bat_detect.utils.detector_utils as du +import bat_detect.utils.plot_utils as viz +import bat_detect.utils.audio_utils as au + + +def filter_anns(anns, start_time, stop_time): + anns_op = [] + for aa in anns: + if (aa['start_time'] >= start_time) and (aa['start_time'] < stop_time-0.02): + anns_op.append(aa) + return anns_op + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('audio_file', type=str, help='Path to audio file') + parser.add_argument('model_path', type=str, help='Path to BatDetect model') + parser.add_argument('--ann_file', type=str, default='', help='Path to annotation file') + parser.add_argument('--op_dir', type=str, default='plots/', + help='Output directory for plots') + parser.add_argument('--file_type', type=str, default='png', + help='Type of image to save png or pdf') + parser.add_argument('--title_text', type=str, default='', + help='Text to add as title of plots') + parser.add_argument('--detection_threshold', type=float, default=0.2, + help='Threshold for output detections') + parser.add_argument('--start_time', type=float, default=0.0, + help='Start time for cropped file') + parser.add_argument('--stop_time', type=float, default=0.5, + help='End time for cropped file') + parser.add_argument('--time_exp', type=int, default=1, + help='Time expansion factor') + + args_cmd = vars(parser.parse_args()) + + # load the model + bd_args = du.get_default_bd_args() + model, params_bd = du.load_model(args_cmd['model_path']) + bd_args['detection_threshold'] = args_cmd['detection_threshold'] + bd_args['time_expansion_factor'] = args_cmd['time_exp'] + + # load the annotation if it exists + gt_present = False + if args_cmd['ann_file'] != '': + if os.path.isfile(args_cmd['ann_file']): + with open(args_cmd['ann_file']) as da: + gt_anns = json.load(da) + gt_anns = filter_anns(gt_anns['annotation'], args_cmd['start_time'], args_cmd['stop_time']) + gt_present = True + else: + print('Annotation file not found: ', args_cmd['ann_file']) + + # load the audio file + if not os.path.isfile(args_cmd['audio_file']): + print('Audio file not found: ', args_cmd['audio_file']) + sys.exit() + + # load audio and crop + print('\nProcessing: ' + os.path.basename(args_cmd['audio_file'])) + print('\nOutput directory: ' + args_cmd['op_dir']) + sampling_rate, audio = au.load_audio_file(args_cmd['audio_file'], args_cmd['time_exp'], + params_bd['target_samp_rate'], params_bd['scale_raw_audio']) + st_samp = int(sampling_rate*args_cmd['start_time']) + en_samp = int(sampling_rate*args_cmd['stop_time']) + if en_samp > audio.shape[0]: + audio = np.hstack((audio, np.zeros((en_samp) - audio.shape[0], dtype=audio.dtype))) + audio = audio[st_samp:en_samp] + + duration = audio.shape[0] / sampling_rate + print('File duration: {} seconds'.format(duration)) + + # create spec for viz + spec, _ = au.generate_spectrogram(audio, sampling_rate, params_bd, True, False) + + # run model and filter detections so only keep ones in relevant time range + results = du.process_file(args_cmd['audio_file'], model, params_bd, bd_args) + pred_anns = filter_anns(results['pred_dict']['annotation'], args_cmd['start_time'], args_cmd['stop_time']) + print(len(pred_anns), 'Detections') + + # save output + if not os.path.isdir(args_cmd['op_dir']): + os.makedirs(args_cmd['op_dir']) + + # create output file names + op_path_clean = os.path.basename(args_cmd['audio_file'])[:-4] + '_clean.' + args_cmd['file_type'] + op_path_clean = os.path.join(args_cmd['op_dir'], op_path_clean) + op_path_pred = os.path.basename(args_cmd['audio_file'])[:-4] + '_pred.' + args_cmd['file_type'] + op_path_pred = os.path.join(args_cmd['op_dir'], op_path_pred) + + # create and save iamges + viz.save_ann_spec(op_path_clean, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', None) + viz.save_ann_spec(op_path_pred, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', pred_anns) + + if gt_present: + op_path_gt = os.path.basename(args_cmd['audio_file'])[:-4] + '_gt.' + args_cmd['file_type'] + op_path_gt = os.path.join(args_cmd['op_dir'], op_path_gt) + viz.save_ann_spec(op_path_gt, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', gt_anns) diff --git a/scripts/gen_spec_video.py b/scripts/gen_spec_video.py new file mode 100644 index 0000000..25e7319 --- /dev/null +++ b/scripts/gen_spec_video.py @@ -0,0 +1,171 @@ +""" +This script takes an audio file as input, runs the detector, and makes a video output + +Notes: + It needs ffmpeg installed to make the videos + Sometimes conda can overwrite the default ffmpeg path set this to use system one. + Check which one is being used with `which ffmpeg`. If conda version, can thow an error. + Best to use system one - see ffmpeg_path. +""" + +from scipy.io import wavfile +import os +import shutil +import matplotlib.pyplot as plt +import numpy as np +import argparse + +import sys +sys.path.append(os.path.join('..')) +import bat_detect.detector.parameters as parameters +import bat_detect.utils.audio_utils as au +import bat_detect.utils.plot_utils as viz +import bat_detect.utils.detector_utils as du +import config + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('audio_file', type=str, help='Path to input audio file') + parser.add_argument('model_path', type=str, help='Path to trained BatDetect model') + parser.add_argument('--op_dir', type=str, default='generated_vids/', help='Path to output directory') + parser.add_argument('--no_detector', action='store_true', help='Do not run detector') + parser.add_argument('--plot_class_names_off', action='store_true', help='Do not plot class names') + parser.add_argument('--disable_axis', action='store_true', help='Do not plot axis') + parser.add_argument('--detection_threshold', type=float, default=0.2, help='Cut-off probability for detector') + parser.add_argument('--time_expansion_factor', type=int, default=1, dest='time_expansion_factor', + help='The time expansion factor used for all files (default is 1)') + args_cmd = vars(parser.parse_args()) + + # file of interest + audio_file = args_cmd['audio_file'] + op_dir = args_cmd['op_dir'] + op_str = '_output' + ffmpeg_path = '/usr/bin/' + + if not os.path.isfile(audio_file): + print('Audio file not found: ', audio_file) + sys.exit() + + if not os.path.isfile(args_cmd['model_path']): + print('Model not found: ', model_path) + sys.exit() + + + start_time = 0.0 + duration = 0.5 + reveal_boxes = True # makes the boxes appear one at a time + fps = 24 + dpi = 100 + + op_dir_tmp = os.path.join(op_dir, 'op_tmp_vids', '') + if not os.path.isdir(op_dir_tmp): + os.makedirs(op_dir_tmp) + if not os.path.isdir(op_dir): + os.makedirs(op_dir) + + params = parameters.get_params(False) + args = du.get_default_bd_args() + args['time_expansion_factor'] = args_cmd['time_expansion_factor'] + args['detection_threshold'] = args_cmd['detection_threshold'] + + + # load audio file + print('\nProcessing: ' + os.path.basename(audio_file)) + print('\nOutput directory: ' + op_dir) + sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate']) + audio = audio[int(sampling_rate*start_time):int(sampling_rate*start_time + sampling_rate*duration)] + audio_orig = audio.copy() + audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'], + params['fft_overlap'], params['resize_factor'], + params['spec_divide_factor']) + + # generate spectrogram + spec, _ = au.generate_spectrogram(audio, sampling_rate, params, True) + max_val = spec.max()*1.1 + + + if not args_cmd['no_detector']: + print(' Loading model and running detector on entire file ...') + model, det_params = du.load_model(args_cmd['model_path']) + det_params['detection_threshold'] = args['detection_threshold'] + results = du.process_file(audio_file, model, det_params, args) + + print(' Processing detections and plotting ...') + detections = [] + for bb in results['pred_dict']['annotation']: + if (bb['start_time'] >= start_time) and (bb['end_time'] < start_time+duration): + detections.append(bb) + + # plot boxes + fig = plt.figure(1, figsize=(spec.shape[1]/dpi, spec.shape[0]/dpi), dpi=dpi) + duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap']) + viz.create_box_image(spec, fig, detections, start_time, start_time+duration, duration, params, max_val, + plot_class_names=not args_cmd['plot_class_names_off']) + op_im_file_boxes = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '_boxes.png') + fig.savefig(op_im_file_boxes, dpi=dpi) + plt.close(1) + spec_with_boxes = plt.imread(op_im_file_boxes) + + + print(' Saving audio file ...') + if args['time_expansion_factor']==1: + sampling_rate_op = int(sampling_rate/10.0) + else: + sampling_rate_op = sampling_rate + op_audio_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.wav') + wavfile.write(op_audio_file, sampling_rate_op, audio_orig) + + + print(' Saving image ...') + op_im_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.png') + plt.imsave(op_im_file, spec, vmin=0, vmax=max_val, cmap='plasma') + spec_blank = plt.imread(op_im_file) + + # create figure + freq_scale = 1000 # turn Hz to kHz + min_freq = params['min_freq']//freq_scale + max_freq = params['max_freq']//freq_scale + y_extent = [0, duration, min_freq, max_freq] + + print(' Saving video frames ...') + # save images that will be combined into video + # will either plot with or without boxes + for ii, col in enumerate(np.linspace(0, spec.shape[1]-1, int(fps*duration*10))): + if not args_cmd['no_detector']: + spec_op = spec_with_boxes.copy() + if ii > 0: + spec_op[:, int(col), :] = 1.0 + if reveal_boxes: + spec_op[:, int(col)+1:, :] = spec_blank[:, int(col)+1:, :] + elif ii == 0 and reveal_boxes: + spec_op = spec_blank + + if not args_cmd['disable_axis']: + plt.close('all') + fig = plt.figure(ii, figsize=(1.2*(spec_op.shape[1]/dpi), 1.5*(spec_op.shape[0]/dpi)), dpi=dpi) + plt.xlabel('Time - seconds') + plt.ylabel('Frequency - kHz') + plt.imshow(spec_op, vmin=0, vmax=1.0, cmap='plasma', extent=y_extent, aspect='auto') + plt.tight_layout() + fig.savefig(op_dir_tmp + str(ii).zfill(4) + '.png', dpi=dpi) + else: + plt.imsave(op_dir_tmp + str(ii).zfill(4) + '.png', spec_op, vmin=0, vmax=1.0, cmap='plasma') + else: + spec_op = spec.copy() + if ii > 0: + spec_op[:, int(col)] = max_val + plt.imsave(op_dir_tmp + str(ii).zfill(4) + '.png', spec_op, vmin=0, vmax=max_val, cmap='plasma') + + + print(' Creating video ...') + op_vid_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.avi') + ffmpeg_cmd = 'ffmpeg -hide_banner -loglevel panic -y -r {} -f image2 -s {}x{} -i {}%04d.png -i {} -vcodec libx264 ' \ + '-crf 25 -pix_fmt yuv420p -acodec copy {}'.format(fps, spec.shape[1], spec.shape[0], op_dir_tmp, op_audio_file, op_vid_file) + ffmpeg_cmd = ffmpeg_path + ffmpeg_cmd + os.system(ffmpeg_cmd) + + print(' Deleting temporary files ...') + if os.path.isdir(op_dir_tmp): + shutil.rmtree(op_dir_tmp) diff --git a/scripts/viz_helpers.py b/scripts/viz_helpers.py new file mode 100644 index 0000000..2f55836 --- /dev/null +++ b/scripts/viz_helpers.py @@ -0,0 +1,142 @@ +import numpy as np +import matplotlib.pyplot as plt +from scipy import ndimage +import os +import sys +sys.path.append(os.path.join('..')) + +import bat_detect.utils.audio_utils as au + + +def generate_spectrogram_data(audio, sampling_rate, params, norm_type='log', smooth_spec=False): + max_freq = round(params['max_freq']*params['fft_win_length']) + min_freq = round(params['min_freq']*params['fft_win_length']) + + # create spectrogram - numpy + spec = au.gen_mag_spectrogram(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']) + #spec = au.gen_mag_spectrogram_pt(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']).numpy() + if spec.shape[0] < max_freq: + freq_pad = max_freq - spec.shape[0] + spec = np.vstack((np.zeros((freq_pad, spec.shape[1]), dtype=np.float32), spec)) + spec = spec[-max_freq:spec.shape[0]-min_freq, :] + + if norm_type == 'log': + log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum()) + ##log_scaling = 0.01 + spec = np.log(1.0 + log_scaling*spec).astype(np.float32) + elif norm_type == 'pcen': + spec = au.pcen(spec, sampling_rate) + else: + pass + + if smooth_spec: + spec = ndimage.gaussian_filter(spec, 1) + + return spec + + +def load_data(anns, params, class_names, smooth_spec=False, norm_type='log', extract_bg=False): + specs = [] + labels = [] + coords = [] + audios = [] + sampling_rates = [] + file_names = [] + for cur_file in anns: + sampling_rate, audio_orig = au.load_audio_file(cur_file['file_path'], cur_file['time_exp'], + params['target_samp_rate'], params['scale_raw_audio']) + + for ann in cur_file['annotation']: + if ann['class'] not in params['classes_to_ignore'] and ann['class'] in class_names: + # clip out of bounds + if ann['low_freq'] < params['min_freq']: + ann['low_freq'] = params['min_freq'] + if ann['high_freq'] > params['max_freq']: + ann['high_freq'] = params['max_freq'] + + # load cropped audio + start_samp_diff = int(sampling_rate*ann['start_time']) - int(sampling_rate*params['aud_pad']) + start_samp = np.maximum(0, start_samp_diff) + end_samp = np.minimum(audio_orig.shape[0], int(sampling_rate*ann['end_time'])*2 + int(sampling_rate*params['aud_pad'])) + audio = audio_orig[start_samp:end_samp] + if start_samp_diff < 0: + # need to pad at start if the call is at the very begining + audio = np.hstack((np.zeros(-start_samp_diff, dtype=np.float32), audio)) + + nfft = int(params['fft_win_length']*sampling_rate) + noverlap = int(params['fft_overlap']*nfft) + max_samps = params['spec_width']*(nfft - noverlap) + noverlap + + if max_samps > audio.shape[0]: + audio = np.hstack((audio, np.zeros(max_samps - audio.shape[0]))) + audio = audio[:max_samps].astype(np.float32) + + audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'], + params['fft_overlap'], params['resize_factor'], + params['spec_divide_factor']) + + # generate spectrogram + spec = generate_spectrogram_data(audio, sampling_rate, params, norm_type, smooth_spec)[:, :params['spec_width']] + + specs.append(spec[np.newaxis, ...]) + labels.append(ann['class']) + + audios.append(audio) + sampling_rates.append(sampling_rate) + file_names.append(cur_file['file_path']) + + # position in crop + x1 = int(au.time_to_x_coords(np.array(params['aud_pad']), sampling_rate, params['fft_win_length'], params['fft_overlap'])) + y1 = (ann['low_freq'] - params['min_freq']) * params['fft_win_length'] + coords.append((y1, x1)) + + + _, file_ids = np.unique(file_names, return_inverse=True) + labels = np.array([class_names.index(ll) for ll in labels]) + + #return np.vstack(specs), labels, coords, audios, sampling_rates, file_ids, file_names + return np.vstack(specs), labels + + +def save_summary_image(specs, labels, species_names, params, op_file_name='plots/all_species.png', order=None): + # takes the mean for each class and plots it on a grid + mean_specs = [] + max_band = [] + for ii in range(len(species_names)): + inds = np.where(labels==ii)[0] + mu = specs[inds, :].mean(0) + max_band.append(np.argmax(mu.sum(1))) + mean_specs.append(mu) + + # control the order in which classes are printed + if order is None: + order = np.arange(len(species_names)) + + max_cols = 6 + nrows = int(np.ceil(len(species_names)/max_cols)) + ncols = np.minimum(len(species_names), max_cols) + + fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3.3, nrows*6), gridspec_kw = {'wspace':0, 'hspace':0.2}) + spec_min_max = (0, mean_specs[0].shape[1], params['min_freq']/1000, params['max_freq']/1000) + ii = 0 + for row in ax: + + if type(row) != np.ndarray: + row = np.array([row]) + + for col in row: + if ii >= len(species_names): + col.axis('off') + else: + inds = np.where(labels==order[ii])[0] + col.imshow(mean_specs[order[ii]], extent=spec_min_max, cmap='plasma', aspect='equal') + col.grid(color='w', alpha=0.3, linewidth=0.3) + col.set_xticks([]) + col.title.set_text(str(ii+1) + ' ' + species_names[order[ii]]) + col.tick_params(axis='both', which='major', labelsize=7) + ii += 1 + + #plt.tight_layout() + #plt.show() + plt.savefig(op_file_name) + plt.close('all')