added viz scripts

macaodha 2022-12-14 18:19:52 +00:00
parent 20218e023c
commit 7b5b2be08f
11 changed files with 541 additions and 34 deletions

View File

@ -21,11 +21,13 @@ You can also run this notebook locally.
### Running the model on your own data
After following the above installation steps, you can run the model on your own data by opening a command line in the directory where the code is located and typing:
`python run_batdetect.py AUDIO_DIR ANN_DIR DETECTION_THRESHOLD`
e.g.
`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3`
`AUDIO_DIR` is the path on your computer to the audio wav files of interest.
`ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file.
`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes.
There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files:
`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3 --spec_features`
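For downstream analysis it can be useful to load the saved predictions programmatically. A minimal sketch (the exact output file name and the field names used below, e.g. `annotation`, `start_time`, `end_time`, and `class`, are assumptions based on how the prediction dictionaries are used elsewhere in this repository):

```python
import json

# load the predictions saved for one audio file (the file name here is an
# assumption - check the contents of ANN_DIR for the actual name)
with open('example_data/anns/20170701_213954-MYOMYS-LR_0_0.5.wav.json') as f:
    preds = json.load(f)

# print one line per detected call
for ann in preds['annotation']:
    print(ann['start_time'], ann['end_time'], ann['class'])
```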
@ -35,11 +37,12 @@ You can also specify which model to use by setting the `--model_path` argument.
### Data and annotations
The raw audio data and annotations used to train the models in the paper will be added soon.
The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI).
### Warning
The models developed and shared as part of this repository should be used with caution.
While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment.
Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted.

View File

@ -44,22 +44,6 @@ def get_blank_annotation(ip_str):
return copy.deepcopy(res), copy.deepcopy(ann)
def get_default_bd_args():
args = {}
args['detection_threshold'] = 0.001
args['time_expansion_factor'] = 1
args['audio_dir'] = ''
args['ann_dir'] = ''
args['spec_slices'] = False
args['chunk_size'] = 3
args['spec_features'] = False
args['cnn_features'] = False
args['quiet'] = True
args['save_preds_if_empty'] = True
args['ann_dir'] = os.path.join(args['ann_dir'], '')
return args
def create_genus_mapping(gt_test, preds, class_names):
# rolls the per class predictions and ground truth back up to genus level
class_names_genus, cls_to_genus = np.unique([cc.split(' ')[0] for cc in class_names], return_inverse=True)
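# Illustrative note (not part of the original code): np.unique with
# return_inverse=True maps each species name to its genus, e.g. for class_names
# ['Myotis daubentonii', 'Myotis nattereri', 'Pipistrellus pipistrellus'] it
# gives class_names_genus = ['Myotis', 'Pipistrellus'] and cls_to_genus = [0, 0, 1].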
@ -555,7 +539,7 @@ if __name__ == "__main__":
#
if args['bd_model_path'] != '':
# load model
bd_args = get_default_bd_args()
bd_args = du.get_default_bd_args()
model, params_bd = du.load_model(args['bd_model_path'])
# check if the class names are the same

View File

@ -12,6 +12,22 @@ import bat_detect.detector.post_process as pp
import bat_detect.utils.audio_utils as au
def get_default_bd_args():
args = {}
args['detection_threshold'] = 0.001
args['time_expansion_factor'] = 1
args['audio_dir'] = ''
args['ann_dir'] = ''
args['spec_slices'] = False
args['chunk_size'] = 3
args['spec_features'] = False
args['cnn_features'] = False
args['quiet'] = True
args['save_preds_if_empty'] = True
args['ann_dir'] = os.path.join(args['ann_dir'], '')
return args
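# Example usage (sketch, mirroring how the new scripts in this commit call it):
#
#   args = get_default_bd_args()
#   args['detection_threshold'] = 0.3
#   model, params = load_model('models/Net2DFast_UK_same.pth.tar')
#   results = process_file('path/to/recording.wav', model, params, args)
#
# The model path and audio file above are placeholders; load_model and
# process_file are the existing helpers in this module.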
def get_audio_files(ip_dir):
matches = []

View File

@ -56,17 +56,10 @@
"outputs": [],
"source": [
"# setup the arguments\n",
"args = {}\n",
"args = du.get_default_bd_args()\n",
"args['detection_threshold'] = 0.3\n",
"args['time_expansion_factor'] = 1\n",
"args['model_path'] = os.path.join('models', os.path.basename(config.MODEL_PATH))\n",
"\n",
"args['cnn_features'] = False\n",
"args['spec_features'] = False\n",
"args['quiet'] = True\n",
"args['save_preds_if_empty'] = False\n",
"args['spec_slices'] = False\n",
"args['chunk_size'] = 3"
"args['model_path'] = os.path.join('models', os.path.basename(config.MODEL_PATH))"
]
},
{

View File

@ -3,6 +3,7 @@ matplotlib==3.6.2
numpy==1.23.4
pandas==1.5.2
scikit_learn==1.2.0
scipy==1.9.3
torch==1.13.0
torchaudio==0.13.0
torchvision==0.14.0

17
scripts/README.md Normal file
View File

@ -0,0 +1,17 @@
This directory contains some scripts for visualizing the raw data and model outputs.
`gen_spec_image.py`: overlays the model predictions on a spectrogram of the input audio file and saves the resulting image.
e.g.
`python gen_spec_image.py ../example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav ../models/Net2DFast_UK_same.pth.tar`
`gen_spec_video.py`: generates a video showing the model predictions for a file.
e.g.
`python gen_spec_video.py ../example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav ../models/Net2DFast_UK_same.pth.tar`
`gen_dataset_summary_image.py`: generates an image displaying the mean spectrogram for each class in a specified dataset.
e.g.
`python gen_dataset_summary_image.py --ann_file PATH_TO_ANN/australia_TRAIN.json PATH_TO_AUDIO/audio/ ../plots/australia/`
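Each script also accepts optional arguments (see the `argparse` definitions in the scripts for the full list). For example, `gen_spec_image.py` supports `--op_dir`, `--file_type`, `--detection_threshold`, `--start_time`, `--stop_time`, and `--time_exp`, e.g.
`python gen_spec_image.py ../example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav ../models/Net2DFast_UK_same.pth.tar --detection_threshold 0.3 --file_type pdf`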

View File

@ -0,0 +1,64 @@
"""
Loads a set of annotations corresponding to a dataset and saves an image which
is the mean spectrogram for each class.
"""
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import sys
import viz_helpers as vz
sys.path.append(os.path.join('..'))
import bat_detect.train.train_utils as tu
import bat_detect.detector.parameters as parameters
import bat_detect.utils.audio_utils as au
import bat_detect.train.train_split as ts
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('audio_path', type=str, help='Input directory for audio')
parser.add_argument('op_dir', type=str,
help='Output directory for the saved summary image')
parser.add_argument('--ann_file', type=str,
help='Path to where single annotation json file is stored')
parser.add_argument('--uk_split', type=str, default='',
help='Set as: diff or same')
parser.add_argument('--file_type', type=str, default='png',
help='Type of image to save png or pdf')
args = vars(parser.parse_args())
if not os.path.isdir(args['op_dir']):
os.makedirs(args['op_dir'])
params = parameters.get_params(False)
params['smooth_spec'] = False
params['spec_width'] = 48
params['norm_type'] = 'log' # log, pcen
params['aud_pad'] = 0.005
classes_to_ignore = params['classes_to_ignore'] + params['generic_class']
# load train annotations
if args['uk_split'] == '':
print('\nLoading:', args['ann_file'], '\n')
dataset_name = os.path.basename(args['ann_file']).replace('.json', '')
datasets = []
datasets.append(tu.get_blank_dataset_dict(dataset_name, False, args['ann_file'], args['audio_path']))
else:
# load uk data - special case
print('\nLoading:', args['uk_split'], '\n')
dataset_name = 'uk_' + args['uk_split'] # should be uk_diff, or uk_same
datasets, _ = ts.get_train_test_data(args['ann_file'], args['audio_path'], args['uk_split'], load_extra=False)
anns, class_names, _ = tu.load_set_of_anns(datasets, classes_to_ignore, params['events_of_interest'])
class_names_order = range(len(class_names))
x_train, y_train = vz.load_data(anns, params, class_names, smooth_spec=params['smooth_spec'], norm_type=params['norm_type'])
op_file_name = os.path.join(args['op_dir'], dataset_name + '.' + args['file_type'])
vz.save_summary_image(x_train, y_train, class_names, params, op_file_name, class_names_order)
print('\nImage saved to:', op_file_name)

116
scripts/gen_spec_image.py Normal file
View File

@ -0,0 +1,116 @@
"""
Visualize predictions on top of spectrogram.
Will save images with:
1) raw spectrogram
2) spectrogram with GT boxes
3) spectrogram with predicted boxes
"""
import numpy as np
import sys
import os
import argparse
import matplotlib.pyplot as plt
import json
sys.path.append(os.path.join('..'))
import bat_detect.evaluate.evaluate_models as evlm
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as viz
import bat_detect.utils.audio_utils as au
def filter_anns(anns, start_time, stop_time):
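# keep only annotations whose call starts inside the cropped time window
# (the 0.02 s margin before stop_time presumably avoids boxes clipped at the right edge)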
anns_op = []
for aa in anns:
if (aa['start_time'] >= start_time) and (aa['start_time'] < stop_time-0.02):
anns_op.append(aa)
return anns_op
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('audio_file', type=str, help='Path to audio file')
parser.add_argument('model_path', type=str, help='Path to BatDetect model')
parser.add_argument('--ann_file', type=str, default='', help='Path to annotation file')
parser.add_argument('--op_dir', type=str, default='plots/',
help='Output directory for plots')
parser.add_argument('--file_type', type=str, default='png',
help='Type of image to save png or pdf')
parser.add_argument('--title_text', type=str, default='',
help='Text to add as title of plots')
parser.add_argument('--detection_threshold', type=float, default=0.2,
help='Threshold for output detections')
parser.add_argument('--start_time', type=float, default=0.0,
help='Start time for cropped file')
parser.add_argument('--stop_time', type=float, default=0.5,
help='End time for cropped file')
parser.add_argument('--time_exp', type=int, default=1,
help='Time expansion factor')
args_cmd = vars(parser.parse_args())
# load the model
bd_args = du.get_default_bd_args()
model, params_bd = du.load_model(args_cmd['model_path'])
bd_args['detection_threshold'] = args_cmd['detection_threshold']
bd_args['time_expansion_factor'] = args_cmd['time_exp']
# load the annotation if it exists
gt_present = False
if args_cmd['ann_file'] != '':
if os.path.isfile(args_cmd['ann_file']):
with open(args_cmd['ann_file']) as da:
gt_anns = json.load(da)
gt_anns = filter_anns(gt_anns['annotation'], args_cmd['start_time'], args_cmd['stop_time'])
gt_present = True
else:
print('Annotation file not found: ', args_cmd['ann_file'])
# load the audio file
if not os.path.isfile(args_cmd['audio_file']):
print('Audio file not found: ', args_cmd['audio_file'])
sys.exit()
# load audio and crop
print('\nProcessing: ' + os.path.basename(args_cmd['audio_file']))
print('\nOutput directory: ' + args_cmd['op_dir'])
sampling_rate, audio = au.load_audio_file(args_cmd['audio_file'], args_cmd['time_exp'],
params_bd['target_samp_rate'], params_bd['scale_raw_audio'])
st_samp = int(sampling_rate*args_cmd['start_time'])
en_samp = int(sampling_rate*args_cmd['stop_time'])
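# pad with zeros if the requested stop_time runs past the end of the recording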
if en_samp > audio.shape[0]:
audio = np.hstack((audio, np.zeros((en_samp) - audio.shape[0], dtype=audio.dtype)))
audio = audio[st_samp:en_samp]
duration = audio.shape[0] / sampling_rate
print('File duration: {} seconds'.format(duration))
# create spec for viz
spec, _ = au.generate_spectrogram(audio, sampling_rate, params_bd, True, False)
# run model and filter detections so only keep ones in relevant time range
results = du.process_file(args_cmd['audio_file'], model, params_bd, bd_args)
pred_anns = filter_anns(results['pred_dict']['annotation'], args_cmd['start_time'], args_cmd['stop_time'])
print(len(pred_anns), 'Detections')
# save output
if not os.path.isdir(args_cmd['op_dir']):
os.makedirs(args_cmd['op_dir'])
# create output file names
op_path_clean = os.path.basename(args_cmd['audio_file'])[:-4] + '_clean.' + args_cmd['file_type']
op_path_clean = os.path.join(args_cmd['op_dir'], op_path_clean)
op_path_pred = os.path.basename(args_cmd['audio_file'])[:-4] + '_pred.' + args_cmd['file_type']
op_path_pred = os.path.join(args_cmd['op_dir'], op_path_pred)
# create and save images
viz.save_ann_spec(op_path_clean, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', None)
viz.save_ann_spec(op_path_pred, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', pred_anns)
if gt_present:
op_path_gt = os.path.basename(args_cmd['audio_file'])[:-4] + '_gt.' + args_cmd['file_type']
op_path_gt = os.path.join(args_cmd['op_dir'], op_path_gt)
viz.save_ann_spec(op_path_gt, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', gt_anns)

171
scripts/gen_spec_video.py Normal file
View File

@ -0,0 +1,171 @@
"""
This script takes an audio file as input, runs the detector, and saves a video of the predictions.
Notes:
It needs ffmpeg installed to make the videos.
Sometimes conda can overwrite the default ffmpeg path; set ffmpeg_path below to use the system one.
Check which one is being used with `which ffmpeg` - the conda version can throw an error, so it is
best to use the system one - see ffmpeg_path.
"""
from scipy.io import wavfile
import os
import shutil
import matplotlib.pyplot as plt
import numpy as np
import argparse
import sys
sys.path.append(os.path.join('..'))
import bat_detect.detector.parameters as parameters
import bat_detect.utils.audio_utils as au
import bat_detect.utils.plot_utils as viz
import bat_detect.utils.detector_utils as du
import config
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('audio_file', type=str, help='Path to input audio file')
parser.add_argument('model_path', type=str, help='Path to trained BatDetect model')
parser.add_argument('--op_dir', type=str, default='generated_vids/', help='Path to output directory')
parser.add_argument('--no_detector', action='store_true', help='Do not run detector')
parser.add_argument('--plot_class_names_off', action='store_true', help='Do not plot class names')
parser.add_argument('--disable_axis', action='store_true', help='Do not plot axis')
parser.add_argument('--detection_threshold', type=float, default=0.2, help='Cut-off probability for detector')
parser.add_argument('--time_expansion_factor', type=int, default=1, dest='time_expansion_factor',
help='The time expansion factor used for all files (default is 1)')
args_cmd = vars(parser.parse_args())
# file of interest
audio_file = args_cmd['audio_file']
op_dir = args_cmd['op_dir']
op_str = '_output'
ffmpeg_path = '/usr/bin/'
if not os.path.isfile(audio_file):
print('Audio file not found: ', audio_file)
sys.exit()
if not os.path.isfile(args_cmd['model_path']):
print('Model not found: ', args_cmd['model_path'])
sys.exit()
start_time = 0.0
duration = 0.5
reveal_boxes = True # makes the boxes appear one at a time
fps = 24
dpi = 100
op_dir_tmp = os.path.join(op_dir, 'op_tmp_vids', '')
if not os.path.isdir(op_dir_tmp):
os.makedirs(op_dir_tmp)
if not os.path.isdir(op_dir):
os.makedirs(op_dir)
params = parameters.get_params(False)
args = du.get_default_bd_args()
args['time_expansion_factor'] = args_cmd['time_expansion_factor']
args['detection_threshold'] = args_cmd['detection_threshold']
# load audio file
print('\nProcessing: ' + os.path.basename(audio_file))
print('\nOutput directory: ' + op_dir)
sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate'])
audio = audio[int(sampling_rate*start_time):int(sampling_rate*start_time + sampling_rate*duration)]
audio_orig = audio.copy()
audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
params['fft_overlap'], params['resize_factor'],
params['spec_divide_factor'])
# generate spectrogram
spec, _ = au.generate_spectrogram(audio, sampling_rate, params, True)
max_val = spec.max()*1.1
if not args_cmd['no_detector']:
print(' Loading model and running detector on entire file ...')
model, det_params = du.load_model(args_cmd['model_path'])
det_params['detection_threshold'] = args['detection_threshold']
results = du.process_file(audio_file, model, det_params, args)
print(' Processing detections and plotting ...')
detections = []
for bb in results['pred_dict']['annotation']:
if (bb['start_time'] >= start_time) and (bb['end_time'] < start_time+duration):
detections.append(bb)
# plot boxes
fig = plt.figure(1, figsize=(spec.shape[1]/dpi, spec.shape[0]/dpi), dpi=dpi)
duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
viz.create_box_image(spec, fig, detections, start_time, start_time+duration, duration, params, max_val,
plot_class_names=not args_cmd['plot_class_names_off'])
op_im_file_boxes = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '_boxes.png')
fig.savefig(op_im_file_boxes, dpi=dpi)
plt.close(1)
spec_with_boxes = plt.imread(op_im_file_boxes)
print(' Saving audio file ...')
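# note: when no time expansion is applied the audio is written out at a tenth
# of the sampling rate, i.e. slowed down 10x (presumably so the ultrasonic
# calls are audible in the video)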
if args['time_expansion_factor']==1:
sampling_rate_op = int(sampling_rate/10.0)
else:
sampling_rate_op = sampling_rate
op_audio_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.wav')
wavfile.write(op_audio_file, sampling_rate_op, audio_orig)
print(' Saving image ...')
op_im_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.png')
plt.imsave(op_im_file, spec, vmin=0, vmax=max_val, cmap='plasma')
spec_blank = plt.imread(op_im_file)
# create figure
freq_scale = 1000 # turn Hz to kHz
min_freq = params['min_freq']//freq_scale
max_freq = params['max_freq']//freq_scale
y_extent = [0, duration, min_freq, max_freq]
print(' Saving video frames ...')
# save images that will be combined into video
# will either plot with or without boxes
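# note: the frame count of fps*duration*10 matches the 10x slowed-down audio
# written above, i.e. fps frames for every second of the resulting video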
for ii, col in enumerate(np.linspace(0, spec.shape[1]-1, int(fps*duration*10))):
if not args_cmd['no_detector']:
spec_op = spec_with_boxes.copy()
if ii > 0:
spec_op[:, int(col), :] = 1.0
if reveal_boxes:
spec_op[:, int(col)+1:, :] = spec_blank[:, int(col)+1:, :]
elif ii == 0 and reveal_boxes:
spec_op = spec_blank
if not args_cmd['disable_axis']:
plt.close('all')
fig = plt.figure(ii, figsize=(1.2*(spec_op.shape[1]/dpi), 1.5*(spec_op.shape[0]/dpi)), dpi=dpi)
plt.xlabel('Time - seconds')
plt.ylabel('Frequency - kHz')
plt.imshow(spec_op, vmin=0, vmax=1.0, cmap='plasma', extent=y_extent, aspect='auto')
plt.tight_layout()
fig.savefig(op_dir_tmp + str(ii).zfill(4) + '.png', dpi=dpi)
else:
plt.imsave(op_dir_tmp + str(ii).zfill(4) + '.png', spec_op, vmin=0, vmax=1.0, cmap='plasma')
else:
spec_op = spec.copy()
if ii > 0:
spec_op[:, int(col)] = max_val
plt.imsave(op_dir_tmp + str(ii).zfill(4) + '.png', spec_op, vmin=0, vmax=max_val, cmap='plasma')
print(' Creating video ...')
op_vid_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.avi')
ffmpeg_cmd = 'ffmpeg -hide_banner -loglevel panic -y -r {} -f image2 -s {}x{} -i {}%04d.png -i {} -vcodec libx264 ' \
'-crf 25 -pix_fmt yuv420p -acodec copy {}'.format(fps, spec.shape[1], spec.shape[0], op_dir_tmp, op_audio_file, op_vid_file)
ffmpeg_cmd = ffmpeg_path + ffmpeg_cmd
os.system(ffmpeg_cmd)
print(' Deleting temporary files ...')
if os.path.isdir(op_dir_tmp):
shutil.rmtree(op_dir_tmp)

142
scripts/viz_helpers.py Normal file
View File

@ -0,0 +1,142 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import os
import sys
sys.path.append(os.path.join('..'))
import bat_detect.utils.audio_utils as au
def generate_spectrogram_data(audio, sampling_rate, params, norm_type='log', smooth_spec=False):
max_freq = round(params['max_freq']*params['fft_win_length'])
min_freq = round(params['min_freq']*params['fft_win_length'])
# create spectrogram - numpy
spec = au.gen_mag_spectrogram(audio, sampling_rate, params['fft_win_length'], params['fft_overlap'])
#spec = au.gen_mag_spectrogram_pt(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']).numpy()
if spec.shape[0] < max_freq:
freq_pad = max_freq - spec.shape[0]
spec = np.vstack((np.zeros((freq_pad, spec.shape[1]), dtype=np.float32), spec))
spec = spec[-max_freq:spec.shape[0]-min_freq, :]
if norm_type == 'log':
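# the log scaling below follows the standard one-sided periodogram
# normalization for a Hann window, i.e. 2 / (fs * sum(window**2))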
log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum())
##log_scaling = 0.01
spec = np.log(1.0 + log_scaling*spec).astype(np.float32)
elif norm_type == 'pcen':
spec = au.pcen(spec, sampling_rate)
else:
pass
if smooth_spec:
spec = ndimage.gaussian_filter(spec, 1)
return spec
def load_data(anns, params, class_names, smooth_spec=False, norm_type='log', extract_bg=False):
specs = []
labels = []
coords = []
audios = []
sampling_rates = []
file_names = []
for cur_file in anns:
sampling_rate, audio_orig = au.load_audio_file(cur_file['file_path'], cur_file['time_exp'],
params['target_samp_rate'], params['scale_raw_audio'])
for ann in cur_file['annotation']:
if ann['class'] not in params['classes_to_ignore'] and ann['class'] in class_names:
# clip out of bounds
if ann['low_freq'] < params['min_freq']:
ann['low_freq'] = params['min_freq']
if ann['high_freq'] > params['max_freq']:
ann['high_freq'] = params['max_freq']
# load cropped audio
start_samp_diff = int(sampling_rate*ann['start_time']) - int(sampling_rate*params['aud_pad'])
start_samp = np.maximum(0, start_samp_diff)
end_samp = np.minimum(audio_orig.shape[0], int(sampling_rate*ann['end_time'])*2 + int(sampling_rate*params['aud_pad']))
audio = audio_orig[start_samp:end_samp]
if start_samp_diff < 0:
# need to pad at start if the call is at the very beginning
audio = np.hstack((np.zeros(-start_samp_diff, dtype=np.float32), audio))
nfft = int(params['fft_win_length']*sampling_rate)
noverlap = int(params['fft_overlap']*nfft)
max_samps = params['spec_width']*(nfft - noverlap) + noverlap
if max_samps > audio.shape[0]:
audio = np.hstack((audio, np.zeros(max_samps - audio.shape[0])))
audio = audio[:max_samps].astype(np.float32)
audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
params['fft_overlap'], params['resize_factor'],
params['spec_divide_factor'])
# generate spectrogram
spec = generate_spectrogram_data(audio, sampling_rate, params, norm_type, smooth_spec)[:, :params['spec_width']]
specs.append(spec[np.newaxis, ...])
labels.append(ann['class'])
audios.append(audio)
sampling_rates.append(sampling_rate)
file_names.append(cur_file['file_path'])
# position in crop
x1 = int(au.time_to_x_coords(np.array(params['aud_pad']), sampling_rate, params['fft_win_length'], params['fft_overlap']))
y1 = (ann['low_freq'] - params['min_freq']) * params['fft_win_length']
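# multiplying by fft_win_length (given in seconds) converts the frequency offset
# in Hz to a spectrogram row index, since the bin spacing is 1/fft_win_length Hz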
coords.append((y1, x1))
_, file_ids = np.unique(file_names, return_inverse=True)
labels = np.array([class_names.index(ll) for ll in labels])
#return np.vstack(specs), labels, coords, audios, sampling_rates, file_ids, file_names
return np.vstack(specs), labels
def save_summary_image(specs, labels, species_names, params, op_file_name='plots/all_species.png', order=None):
# takes the mean for each class and plots it on a grid
mean_specs = []
max_band = []
for ii in range(len(species_names)):
inds = np.where(labels==ii)[0]
mu = specs[inds, :].mean(0)
max_band.append(np.argmax(mu.sum(1)))
mean_specs.append(mu)
# control the order in which classes are printed
if order is None:
order = np.arange(len(species_names))
max_cols = 6
nrows = int(np.ceil(len(species_names)/max_cols))
ncols = np.minimum(len(species_names), max_cols)
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3.3, nrows*6), gridspec_kw = {'wspace':0, 'hspace':0.2})
spec_min_max = (0, mean_specs[0].shape[1], params['min_freq']/1000, params['max_freq']/1000)
ii = 0
for row in ax:
if type(row) != np.ndarray:
row = np.array([row])
for col in row:
if ii >= len(species_names):
col.axis('off')
else:
inds = np.where(labels==order[ii])[0]
col.imshow(mean_specs[order[ii]], extent=spec_min_max, cmap='plasma', aspect='equal')
col.grid(color='w', alpha=0.3, linewidth=0.3)
col.set_xticks([])
col.title.set_text(str(ii+1) + ' ' + species_names[order[ii]])
col.tick_params(axis='both', which='major', labelsize=7)
ii += 1
#plt.tight_layout()
#plt.show()
plt.savefig(op_file_name)
plt.close('all')