mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-09 09:40:19 +02:00
236 lines
6.9 KiB
Python
236 lines
6.9 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
from collections import Counter
|
|
from typing import List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
from sklearn.model_selection import StratifiedGroupKFold
|
|
|
|
import batdetect2.train.train_utils as tu
|
|
from batdetect2 import types
|
|
|
|
|
|
def print_dataset_stats(
|
|
data: List[types.FileAnnotation],
|
|
classes_to_ignore: Optional[List[str]] = None,
|
|
) -> Counter[str]:
|
|
print("Num files:", len(data))
|
|
counts, _ = tu.get_class_names(data, classes_to_ignore)
|
|
if len(counts) > 0:
|
|
tu.report_class_counts(counts)
|
|
return counts
|
|
|
|
|
|
def load_file_names(file_name: str) -> List[str]:
|
|
if not os.path.isfile(file_name):
|
|
raise FileNotFoundError(f"Input file not found - {file_name}")
|
|
|
|
with open(file_name) as da:
|
|
files = [line.rstrip() for line in da.readlines()]
|
|
|
|
for path in files:
|
|
if path.lower()[-3:] != "wav":
|
|
raise ValueError(
|
|
f"Invalid file name - {path}. Must be a .wav file"
|
|
)
|
|
|
|
return files
|
|
|
|
|
|
def parse_args():
|
|
info_str = "\nBatDetect - Prepare Data for Finetuning\n"
|
|
print(info_str)
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"dataset_name", type=str, help="Name to call your dataset"
|
|
)
|
|
parser.add_argument(
|
|
"audio_dir", type=str, help="Input directory for audio"
|
|
)
|
|
parser.add_argument(
|
|
"ann_dir",
|
|
type=str,
|
|
help="Input directory for where the audio annotations are stored",
|
|
)
|
|
parser.add_argument(
|
|
"op_dir",
|
|
type=str,
|
|
help="Path where the train and test splits will be stored",
|
|
)
|
|
parser.add_argument(
|
|
"--percent_val",
|
|
type=float,
|
|
default=0.20,
|
|
help="Hold out this much data for validation. Should be number between 0 and 1",
|
|
)
|
|
parser.add_argument(
|
|
"--rand_seed",
|
|
type=int,
|
|
default=2001,
|
|
help="Random seed used for creating the validation split",
|
|
)
|
|
parser.add_argument(
|
|
"--train_file",
|
|
type=str,
|
|
default="",
|
|
help="Text file where each line is a wav file in train split",
|
|
)
|
|
parser.add_argument(
|
|
"--test_file",
|
|
type=str,
|
|
default="",
|
|
help="Text file where each line is a wav file in test split",
|
|
)
|
|
parser.add_argument(
|
|
"--input_class_names",
|
|
type=str,
|
|
default="",
|
|
help='Specify names of classes that you want to change. Separate with ";"',
|
|
)
|
|
parser.add_argument(
|
|
"--output_class_names",
|
|
type=str,
|
|
default="",
|
|
help='New class names to use instead. One to one mapping with "--input_class_names". \
|
|
Separate with ";"',
|
|
)
|
|
return parser.parse_args()
|
|
|
|
|
|
def split_data(
|
|
data: List[types.FileAnnotation],
|
|
train_file: str,
|
|
test_file: str,
|
|
n_splits: int = 5,
|
|
random_state: int = 0,
|
|
) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]:
|
|
if train_file != "" and test_file != "":
|
|
# user has specifed the train / test split
|
|
mapping = {
|
|
file_annotation["id"]: file_annotation for file_annotation in data
|
|
}
|
|
train_files = load_file_names(train_file)
|
|
test_files = load_file_names(test_file)
|
|
data_train = [
|
|
mapping[file_id] for file_id in train_files if file_id in mapping
|
|
]
|
|
data_test = [
|
|
mapping[file_id] for file_id in test_files if file_id in mapping
|
|
]
|
|
return data_train, data_test
|
|
|
|
# NOTE: Using StratifiedGroupKFold to ensure that the same file does not
|
|
# appear in both the training and test sets and trying to keep the
|
|
# distribution of classes the same in both sets.
|
|
splitter = StratifiedGroupKFold(
|
|
n_splits=n_splits,
|
|
shuffle=True,
|
|
random_state=random_state,
|
|
)
|
|
anns = np.array(
|
|
[
|
|
[dd["id"], ann["class"], ann["event"]]
|
|
for dd in data
|
|
for ann in dd["annotation"]
|
|
]
|
|
)
|
|
y = anns[:, 1]
|
|
group = anns[:, 0]
|
|
|
|
train_idx, test_idx = next(splitter.split(X=anns, y=y, groups=group))
|
|
train_ids = set(anns[train_idx, 0])
|
|
test_ids = set(anns[test_idx, 0])
|
|
|
|
assert not (train_ids & test_ids)
|
|
data_train = [dd for dd in data if dd["id"] in train_ids]
|
|
data_test = [dd for dd in data if dd["id"] in test_ids]
|
|
return data_train, data_test
|
|
|
|
|
|
def main():
|
|
args = parse_args()
|
|
|
|
np.random.seed(args.rand_seed)
|
|
|
|
classes_to_ignore = ["", " ", "Unknown", "Not Bat"]
|
|
events_of_interest = ["Echolocation"]
|
|
|
|
name_dict = None
|
|
if args.input_class_names != "" and args.output_class_names != "":
|
|
# change the names of the classes
|
|
ip_names = args.input_class_names.split(";")
|
|
op_names = args.output_class_names.split(";")
|
|
name_dict = dict(zip(ip_names, op_names))
|
|
|
|
# load annotations
|
|
data_all = tu.load_set_of_anns(
|
|
[
|
|
{
|
|
"dataset_name": args.dataset_name,
|
|
"ann_path": args.ann_dir,
|
|
"wav_path": args.audio_dir,
|
|
"is_test": False,
|
|
"is_binary": False,
|
|
}
|
|
],
|
|
classes_to_ignore=classes_to_ignore,
|
|
events_of_interest=events_of_interest,
|
|
convert_to_genus=False,
|
|
filter_issues=True,
|
|
name_replace=name_dict,
|
|
)
|
|
|
|
print("Dataset name: " + args.dataset_name)
|
|
print("Audio directory: " + args.audio_dir)
|
|
print("Annotation directory: " + args.ann_dir)
|
|
print("Ouput directory: " + args.op_dir)
|
|
print("Num annotated files: " + str(len(data_all)))
|
|
|
|
data_train, data_test = split_data(
|
|
data=data_all,
|
|
train_file=args.train_file,
|
|
test_file=args.test_file,
|
|
n_splits=5,
|
|
random_state=args.rand_seed,
|
|
)
|
|
|
|
if not os.path.isdir(args.op_dir):
|
|
os.makedirs(args.op_dir)
|
|
op_name = os.path.join(args.op_dir, args.dataset_name)
|
|
op_name_train = op_name + "_TRAIN.json"
|
|
op_name_test = op_name + "_TEST.json"
|
|
|
|
print("\nSplit: Train")
|
|
class_un_train = print_dataset_stats(data_train, classes_to_ignore)
|
|
|
|
print("\nSplit: Test")
|
|
class_un_test = print_dataset_stats(data_test, classes_to_ignore)
|
|
|
|
if len(data_train) > 0 and len(data_test) > 0:
|
|
if set(class_un_train.keys()) != set(class_un_test.keys()):
|
|
raise RuntimeError(
|
|
"Error: some classes are not in both the training and test sets."
|
|
'Try a different random seed "--rand_seed".'
|
|
)
|
|
|
|
print("\n")
|
|
if len(data_train) == 0:
|
|
print("No train annotations to save")
|
|
else:
|
|
print("Saving: ", op_name_train)
|
|
with open(op_name_train, "w") as da:
|
|
json.dump(data_train, da, indent=2)
|
|
|
|
if len(data_test) == 0:
|
|
print("No test annotations to save")
|
|
else:
|
|
print("Saving: ", op_name_test)
|
|
with open(op_name_test, "w") as da:
|
|
json.dump(data_test, da, indent=2)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|