batdetect2/scripts/gen_dataset_summary_image.py
2025-02-25 11:25:16 +00:00

97 lines
2.8 KiB
Python

"""
Loads a set of annotations corresponding to a dataset and saves an image which
is the mean spectrogram for each class.
"""
import argparse
import os
import matplotlib.pyplot as plt
import numpy as np
import viz_helpers as vz
import batdetect2.detector.parameters as parameters
import batdetect2.train.train_split as ts
import batdetect2.train.train_utils as tu
import batdetect2.utils.audio_utils as au
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"audio_path", type=str, help="Input directory for audio"
)
parser.add_argument(
"op_dir",
type=str,
help="Path to where single annotation json file is stored",
)
parser.add_argument(
"--ann_file",
type=str,
help="Path to where single annotation json file is stored",
)
parser.add_argument(
"--uk_split", type=str, default="", help="Set as: diff or same"
)
parser.add_argument(
"--file_type",
type=str,
default="png",
help="Type of image to save png or pdf",
)
args = vars(parser.parse_args())
if not os.path.isdir(args["op_dir"]):
os.makedirs(args["op_dir"])
params = parameters.get_params(False)
params["smooth_spec"] = False
params["spec_width"] = 48
params["norm_type"] = "log" # log, pcen
params["aud_pad"] = 0.005
classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
# load train annotations
if args["uk_split"] == "":
print("\nLoading:", args["ann_file"], "\n")
dataset_name = os.path.basename(args["ann_file"]).replace(".json", "")
datasets = []
datasets.append(
tu.get_blank_dataset_dict(
dataset_name, False, args["ann_file"], args["audio_path"]
)
)
else:
# load uk data - special case
print("\nLoading:", args["uk_split"], "\n")
dataset_name = (
"uk_" + args["uk_split"]
) # should be uk_diff, or uk_same
datasets, _ = ts.get_train_test_data(
args["ann_file"],
args["audio_path"],
args["uk_split"],
load_extra=False,
)
anns, class_names, _ = tu.load_set_of_anns(
datasets, classes_to_ignore, params["events_of_interest"]
)
class_names_order = range(len(class_names))
x_train, y_train = vz.load_data(
anns,
params,
class_names,
smooth_spec=params["smooth_spec"],
norm_type=params["norm_type"],
)
op_file_name = os.path.join(
args["op_dir"], dataset_name + "." + args["file_type"]
)
vz.save_summary_image(
x_train, y_train, class_names, params, op_file_name, class_names_order
)
print("\nImage saved to:", op_file_name)