# batdetect2/scripts/viz_helpers.py
# 2023-02-02 19:13:00 +00:00
#
# 216 lines
# 6.8 KiB
# Python

import os
import sys
import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage
sys.path.append(os.path.join(".."))
import bat_detect.utils.audio_utils as au
def generate_spectrogram_data(
    audio, sampling_rate, params, norm_type="log", smooth_spec=False
):
    """Build a band-limited, optionally normalised magnitude spectrogram.

    The magnitude spectrogram comes from ``au.gen_mag_spectrogram``. If it has
    fewer rows than the ``max_freq`` bin, zero rows are prepended. It is then
    cropped to the rows between the ``min_freq`` and ``max_freq`` bins
    (bin index = frequency * ``params["fft_win_length"]``).

    Args:
        audio: audio samples passed straight to ``au.gen_mag_spectrogram``.
        sampling_rate: sampling rate of ``audio`` in Hz.
        params: config dict; reads ``max_freq``, ``min_freq``,
            ``fft_win_length`` and ``fft_overlap``.
        norm_type: ``"log"`` applies ``log(1 + scale * spec)`` with a
            Hann-window power scale; ``"pcen"`` delegates to ``au.pcen``;
            any other value leaves the magnitudes untouched.
        smooth_spec: when true, apply a Gaussian blur with sigma 1.

    Returns:
        The processed spectrogram.
    """
    win_len = params["fft_win_length"]
    top_bin = round(params["max_freq"] * win_len)
    bottom_bin = round(params["min_freq"] * win_len)

    spec = au.gen_mag_spectrogram(
        audio, sampling_rate, win_len, params["fft_overlap"]
    )

    # Make sure there are at least `top_bin` rows before cropping from the end.
    n_rows, n_cols = spec.shape[0], spec.shape[1]
    if n_rows < top_bin:
        padding = np.zeros((top_bin - n_rows, n_cols), dtype=np.float32)
        spec = np.vstack((padding, spec))

    spec = spec[-top_bin : spec.shape[0] - bottom_bin, :]

    if norm_type == "log":
        window = np.hanning(int(win_len * sampling_rate))
        window_power = (np.abs(window) ** 2).sum()
        log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0 / window_power)
        spec = np.log(1.0 + log_scaling * spec).astype(np.float32)
    elif norm_type == "pcen":
        spec = au.pcen(spec, sampling_rate)
    # any other norm_type: keep raw magnitudes

    if smooth_spec:
        spec = ndimage.gaussian_filter(spec, 1)
    return spec
def load_data(
    anns,
    params,
    class_names,
    smooth_spec=False,
    norm_type="log",
    extract_bg=False,
    *,
    return_metadata=False,
):
    """Extract a fixed-width spectrogram crop around every selected call.

    An annotation is kept when its class is in ``class_names`` and not in
    ``params["classes_to_ignore"]``. Each crop holds ``max_samps`` audio
    samples, starting ``params["aud_pad"]`` seconds before the call start.
    Crops are zero-padded wherever the window runs past either end of the
    recording. Each crop goes through ``au.pad_audio`` and
    ``generate_spectrogram_data`` and is truncated to ``params["spec_width"]``
    columns. The caller's annotation dicts are not modified.

    Args:
        anns: iterable of per-file dicts with ``file_path``, ``time_exp`` and
            an ``annotation`` list whose items provide ``class``,
            ``start_time`` and ``low_freq``.
        params: config dict; reads ``target_samp_rate``, ``scale_raw_audio``,
            ``classes_to_ignore``, ``min_freq``, ``aud_pad``,
            ``fft_win_length``, ``fft_overlap``, ``spec_width``,
            ``resize_factor``, ``spec_divide_factor`` plus whatever
            ``generate_spectrogram_data`` reads.
        class_names: list of class names; a label is the index of its class
            in this list.
        smooth_spec: forwarded to ``generate_spectrogram_data``.
        norm_type: forwarded to ``generate_spectrogram_data``.
        extract_bg: not used by this function; kept for compatibility.
        return_metadata: keyword-only. If true, also return the per-crop
            metadata (see Returns).

    Returns:
        ``(specs, labels)``. ``specs`` stacks one ``(1, F, T)`` crop per kept
        annotation along axis 0. ``labels`` is an int array of indices into
        ``class_names``. With ``return_metadata=True`` the result is
        ``(specs, labels, coords, audios, sampling_rates, file_ids,
        file_names)``, where ``coords`` holds ``(y1, x1)`` call-start
        positions inside each crop. ``file_ids`` indexes each crop's file
        among the sorted unique ``file_names``.

    Raises:
        ValueError: if no annotation passes the class filters.
    """
    specs = []
    labels = []
    coords = []
    audios = []
    sampling_rates = []
    file_names = []
    ignored = params["classes_to_ignore"]
    pad_secs = params["aud_pad"]

    for cur_file in anns:
        sampling_rate, audio_orig = au.load_audio_file(
            cur_file["file_path"],
            cur_file["time_exp"],
            params["target_samp_rate"],
            params["scale_raw_audio"],
        )

        # Crop geometry depends only on the sampling rate: compute once per file.
        nfft = int(params["fft_win_length"] * sampling_rate)
        noverlap = int(params["fft_overlap"] * nfft)
        max_samps = params["spec_width"] * (nfft - noverlap) + noverlap
        pad_samps = int(sampling_rate * pad_secs)
        # Column of the call start inside each crop (crops begin aud_pad early).
        x1 = int(
            au.time_to_x_coords(
                np.array(pad_secs),
                sampling_rate,
                params["fft_win_length"],
                params["fft_overlap"],
            )
        )

        for ann in cur_file["annotation"]:
            if ann["class"] in ignored or ann["class"] not in class_names:
                continue

            # Clip to the analysed band locally instead of mutating `ann`.
            low_freq = max(ann["low_freq"], params["min_freq"])

            # The window end is derived from its start so every crop holds
            # max_samps real samples wherever the recording allows. An end based
            # on absolute end_time would zero-fill crops near the file start.
            start_samp_diff = int(sampling_rate * ann["start_time"]) - pad_samps
            start_samp = max(0, start_samp_diff)
            end_samp = min(
                audio_orig.shape[0], max(start_samp, start_samp_diff + max_samps)
            )
            audio = audio_orig[start_samp:end_samp]
            if start_samp_diff < 0:
                # need to pad at start if the call is at the very beginning
                audio = np.hstack(
                    (np.zeros(-start_samp_diff, dtype=np.float32), audio)
                )
            if audio.shape[0] < max_samps:
                # window runs past the end of the recording
                audio = np.hstack(
                    (audio, np.zeros(max_samps - audio.shape[0], dtype=np.float32))
                )
            audio = audio[:max_samps].astype(np.float32)
            audio = au.pad_audio(
                audio,
                sampling_rate,
                params["fft_win_length"],
                params["fft_overlap"],
                params["resize_factor"],
                params["spec_divide_factor"],
            )

            spec = generate_spectrogram_data(
                audio, sampling_rate, params, norm_type, smooth_spec
            )[:, : params["spec_width"]]
            specs.append(spec[np.newaxis, ...])
            labels.append(ann["class"])
            audios.append(audio)
            sampling_rates.append(sampling_rate)
            file_names.append(cur_file["file_path"])

            y1 = (low_freq - params["min_freq"]) * params["fft_win_length"]
            coords.append((y1, x1))

    if not specs:
        raise ValueError(
            "load_data: no annotations matched class_names; nothing to stack"
        )

    label_ids = np.array([class_names.index(name) for name in labels])
    stacked = np.vstack(specs)
    if return_metadata:
        _, file_ids = np.unique(file_names, return_inverse=True)
        return (
            stacked,
            label_ids,
            coords,
            audios,
            sampling_rates,
            file_ids,
            file_names,
        )
    return stacked, label_ids
def save_summary_image(
    specs,
    labels,
    species_names,
    params,
    op_file_name="plots/all_species.png",
    order=None,
):
    """Plot the mean spectrogram of each class on a grid and save it to disk.

    Args:
        specs: array of crops indexed along axis 0 (e.g. from ``load_data``).
        labels: class index per crop, aligned with ``specs`` (list or array).
        species_names: display name for each class index.
        params: config dict; ``min_freq`` / ``max_freq`` (divided by 1000 for
            the axis) set the y extent of every panel.
        op_file_name: output image path. Its parent directory is created if
            missing.
        order: optional sequence of class indices giving the panel order;
            defaults to ``0 .. len(species_names) - 1``.

    A class with no crops gets the numpy mean of an empty slice (NaN values,
    with a RuntimeWarning).

    Raises:
        ValueError: if ``species_names`` is empty.
    """
    n_panels = len(species_names)
    if n_panels == 0:
        raise ValueError("save_summary_image: species_names is empty")

    # asarray so `labels == k` is elementwise even when labels is a plain list
    labels = np.asarray(labels)
    mean_specs = [
        specs[np.flatnonzero(labels == cls_idx)].mean(0)
        for cls_idx in range(n_panels)
    ]

    # control the order in which classes are printed
    if order is None:
        order = np.arange(n_panels)

    max_cols = 6
    nrows = int(np.ceil(n_panels / max_cols))
    ncols = min(n_panels, max_cols)
    # squeeze=False always yields a 2-D array of Axes, so 1x1 and 1xN grids
    # iterate exactly like larger ones.
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * 3.3, nrows * 6),
        gridspec_kw={"wspace": 0, "hspace": 0.2},
        squeeze=False,
    )
    # imshow extent (left, right, bottom, top): time columns x band in kHz
    spec_extent = (
        0,
        mean_specs[0].shape[1],
        params["min_freq"] / 1000,
        params["max_freq"] / 1000,
    )

    for panel_idx, panel in enumerate(axes.flat):
        if panel_idx >= n_panels:
            # unused cells in the last row
            panel.axis("off")
            continue
        cls_idx = order[panel_idx]
        panel.imshow(
            mean_specs[cls_idx],
            extent=spec_extent,
            cmap="plasma",
            aspect="equal",
        )
        panel.grid(color="w", alpha=0.3, linewidth=0.3)
        panel.set_xticks([])
        panel.title.set_text(f"{panel_idx + 1} {species_names[cls_idx]}")
        panel.tick_params(axis="both", which="major", labelsize=7)

    out_dir = os.path.dirname(op_file_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(op_file_name)
    # close only this figure so unrelated figures the caller has open survive
    plt.close(fig)