mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Fixed errors with command
This commit is contained in:
parent
e6a6ad4696
commit
40222d8233
@ -71,7 +71,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default=os.path.join(CURRENT_DIR, "models/Net2DFast_UK_same.pth.tar"),
|
||||
default=du.DEFAULT_MODEL_PATH,
|
||||
help="Path to trained BatDetect2 model",
|
||||
)
|
||||
args = vars(parser.parse_args())
|
||||
@ -97,8 +97,11 @@ def main():
|
||||
print(f"Number of audio files: {len(files)}")
|
||||
print("\nSaving results to: " + args["ann_dir"])
|
||||
|
||||
default_config = du.get_default_config()
|
||||
|
||||
# set up run config
|
||||
run_config = {
|
||||
**default_config,
|
||||
**args,
|
||||
**params,
|
||||
}
|
||||
@ -119,6 +122,7 @@ def main():
|
||||
except (RuntimeError, ValueError, LookupError) as err:
|
||||
error_files.append(audio_file)
|
||||
print(f"Error processing file!: {err}")
|
||||
raise err
|
||||
|
||||
print("\nResults saved to: " + args["ann_dir"])
|
||||
|
||||
|
@ -13,6 +13,9 @@ SCALE_RAW_AUDIO = False
|
||||
DETECTION_THRESHOLD = 0.01
|
||||
NMS_KERNEL_SIZE = 9
|
||||
NMS_TOP_K_PER_SEC = 200
|
||||
SPEC_SCALE = "pcen"
|
||||
DENOISE_SPEC_AVG = True
|
||||
MAX_SCALE_SPEC = False
|
||||
|
||||
|
||||
def mk_dir(path):
|
||||
@ -70,14 +73,14 @@ def get_params(make_dirs=False, exps_dir="../../experiments/"):
|
||||
# spec processing params
|
||||
params[
|
||||
"denoise_spec_avg"
|
||||
] = True # removes the mean for each frequency band
|
||||
] = DENOISE_SPEC_AVG # removes the mean for each frequency band
|
||||
params[
|
||||
"scale_raw_audio"
|
||||
] = SCALE_RAW_AUDIO # scales the raw audio to [-1, 1]
|
||||
params[
|
||||
"max_scale_spec"
|
||||
] = False # scales the spectrogram so that it is max 1
|
||||
params["spec_scale"] = "pcen" # 'log', 'pcen', 'none'
|
||||
] = MAX_SCALE_SPEC # scales the spectrogram so that it is max 1
|
||||
params["spec_scale"] = SPEC_SCALE # 'log', 'pcen', 'none'
|
||||
|
||||
# detection params
|
||||
params[
|
||||
|
@ -12,10 +12,12 @@ import bat_detect.detector.post_process as pp
|
||||
import bat_detect.utils.audio_utils as au
|
||||
from bat_detect.detector import models
|
||||
from bat_detect.detector.parameters import (
|
||||
DENOISE_SPEC_AVG,
|
||||
DETECTION_THRESHOLD,
|
||||
FFT_OVERLAP,
|
||||
FFT_WIN_LENGTH_S,
|
||||
MAX_FREQ_HZ,
|
||||
MAX_SCALE_SPEC,
|
||||
MIN_FREQ_HZ,
|
||||
NMS_KERNEL_SIZE,
|
||||
NMS_TOP_K_PER_SEC,
|
||||
@ -23,6 +25,7 @@ from bat_detect.detector.parameters import (
|
||||
SCALE_RAW_AUDIO,
|
||||
SPEC_DIVIDE_FACTOR,
|
||||
SPEC_HEIGHT,
|
||||
SPEC_SCALE,
|
||||
TARGET_SAMPLERATE_HZ,
|
||||
)
|
||||
|
||||
@ -35,12 +38,13 @@ except ImportError:
|
||||
DEFAULT_MODEL_PATH = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"models",
|
||||
"model.pth",
|
||||
"Net2DFast_UK_same.pth.tar",
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"load_model",
|
||||
"get_audio_files",
|
||||
"get_default_config",
|
||||
"format_results",
|
||||
"save_results_to_file",
|
||||
"iterate_over_chunks",
|
||||
@ -313,7 +317,7 @@ def format_results(
|
||||
annotations: List[Annotation] = [
|
||||
{
|
||||
"start_time": round(float(start_time), 4),
|
||||
"end_time": round(end_time, 4),
|
||||
"end_time": round(float(end_time), 4),
|
||||
"low_freq": int(low_freq),
|
||||
"high_freq": int(high_freq),
|
||||
"class": str(class_names[class_index]),
|
||||
@ -331,7 +335,7 @@ def format_results(
|
||||
class_prob,
|
||||
det_prob,
|
||||
) in zip(
|
||||
predictions["start_time"],
|
||||
predictions["start_times"],
|
||||
predictions["end_times"],
|
||||
predictions["low_freqs"],
|
||||
predictions["high_freqs"],
|
||||
@ -347,7 +351,7 @@ def format_results(
|
||||
"issues": False,
|
||||
"notes": "Automatically generated.",
|
||||
"time_exp": time_exp,
|
||||
"duration": round(duration, 4),
|
||||
"duration": round(float(duration), 4),
|
||||
"annotation": annotations,
|
||||
"class_name": class_names[np.argmax(class_overall)],
|
||||
}
|
||||
@ -458,7 +462,8 @@ def save_results_to_file(results, op_path: str) -> None:
|
||||
if "spec_feats" in results.keys():
|
||||
# create csv file with spectrogram features
|
||||
spec_feats_df = pd.DataFrame(
|
||||
results["spec_feats"], columns=results["spec_feat_names"]
|
||||
results["spec_feats"],
|
||||
columns=results["spec_feat_names"],
|
||||
)
|
||||
spec_feats_df.to_csv(
|
||||
op_path + "_spec_features.csv",
|
||||
@ -506,6 +511,21 @@ class SpectrogramParameters(TypedDict):
|
||||
device: torch.device
|
||||
"""Device to store the spectrogram on."""
|
||||
|
||||
max_freq: int
|
||||
"""Maximum frequency to display in the spectrogram."""
|
||||
|
||||
min_freq: int
|
||||
"""Minimum frequency to display in the spectrogram."""
|
||||
|
||||
spec_scale: str
|
||||
"""Scale to use for the spectrogram."""
|
||||
|
||||
denoise_spec_avg: bool
|
||||
"""Whether to denoise the spectrogram by averaging."""
|
||||
|
||||
max_scale_spec: bool
|
||||
"""Whether to scale the spectrogram so that its max is 1."""
|
||||
|
||||
|
||||
def compute_spectrogram(
|
||||
audio: np.ndarray,
|
||||
@ -640,6 +660,15 @@ class ProcessingConfiguration(TypedDict):
|
||||
spec_height: int
|
||||
"""Height of the spectrogram in pixels."""
|
||||
|
||||
spec_scale: str
|
||||
"""Scale to use for the spectrogram."""
|
||||
|
||||
denoise_spec_avg: bool
|
||||
"""Whether to denoise the spectrogram by averaging."""
|
||||
|
||||
max_scale_spec: bool
|
||||
"""Whether to scale the spectrogram so that its max is 1."""
|
||||
|
||||
scale_raw_audio: bool
|
||||
"""Whether to scale the raw audio to be between -1 and 1."""
|
||||
|
||||
@ -735,6 +764,7 @@ def process_spectrogram(
|
||||
"resize_factor": config["resize_factor"],
|
||||
"nms_top_k_per_sec": config["nms_top_k_per_sec"],
|
||||
"detection_threshold": config["detection_threshold"],
|
||||
"max_scale_spec": config["max_scale_spec"],
|
||||
},
|
||||
np.array([float(samplerate)]),
|
||||
)
|
||||
@ -788,6 +818,11 @@ def process_audio_array(
|
||||
"resize_factor": config["resize_factor"],
|
||||
"spec_divide_factor": config["spec_divide_factor"],
|
||||
"device": config["device"],
|
||||
"max_freq": config["max_freq"],
|
||||
"min_freq": config["min_freq"],
|
||||
"spec_scale": config["spec_scale"],
|
||||
"denoise_spec_avg": config["denoise_spec_avg"],
|
||||
"max_scale_spec": config["max_scale_spec"],
|
||||
},
|
||||
return_np=config["spec_features"] or config["spec_slices"],
|
||||
)
|
||||
@ -842,7 +877,7 @@ def process_file(
|
||||
time_exp_fact=config.get("time_expansion", 1) or 1,
|
||||
target_samp_rate=config["target_samp_rate"],
|
||||
scale=config["scale_raw_audio"],
|
||||
max_duration=config["max_duration"],
|
||||
max_duration=config.get("max_duration"),
|
||||
)
|
||||
|
||||
# loop through larger file and split into chunks
|
||||
@ -930,7 +965,7 @@ def summarize_results(results, predictions, config):
|
||||
)
|
||||
|
||||
|
||||
def get_default_run_config(**kwargs) -> ProcessingConfiguration:
|
||||
def get_default_config(**kwargs) -> ProcessingConfiguration:
|
||||
"""Get default configuration for running detection model."""
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@ -958,6 +993,9 @@ def get_default_run_config(**kwargs) -> ProcessingConfiguration:
|
||||
"max_freq": MAX_FREQ_HZ,
|
||||
"min_freq": MIN_FREQ_HZ,
|
||||
"nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
|
||||
"spec_scale": SPEC_SCALE,
|
||||
"denoise_spec_avg": DENOISE_SPEC_AVG,
|
||||
"max_scale_spec": MAX_SCALE_SPEC,
|
||||
}
|
||||
return {
|
||||
**args,
|
||||
|
113
run_batdetect.py
113
run_batdetect.py
@ -1,112 +1,5 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import bat_detect.utils.detector_utils as du
|
||||
|
||||
|
||||
def main(args):
|
||||
print("Loading model: " + args["model_path"])
|
||||
model, params = du.load_model(args["model_path"])
|
||||
|
||||
print("\nInput directory: " + args["audio_dir"])
|
||||
files = du.get_audio_files(args["audio_dir"])
|
||||
print("Number of audio files: {}".format(len(files)))
|
||||
print("\nSaving results to: " + args["ann_dir"])
|
||||
|
||||
# process files
|
||||
error_files = []
|
||||
for ii, audio_file in enumerate(files):
|
||||
print("\n" + str(ii).ljust(6) + os.path.basename(audio_file))
|
||||
try:
|
||||
results = du.process_file(audio_file, model, params, args)
|
||||
if args["save_preds_if_empty"] or (
|
||||
len(results["pred_dict"]["annotation"]) > 0
|
||||
):
|
||||
results_path = audio_file.replace(
|
||||
args["audio_dir"], args["ann_dir"]
|
||||
)
|
||||
du.save_results_to_file(results, results_path)
|
||||
except:
|
||||
error_files.append(audio_file)
|
||||
print("Error processing file!")
|
||||
|
||||
print("\nResults saved to: " + args["ann_dir"])
|
||||
|
||||
if len(error_files) > 0:
|
||||
print("\nUnable to process the follow files:")
|
||||
for err in error_files:
|
||||
print(" " + err)
|
||||
|
||||
"""Run bat_detect.command.main() from the command line."""
|
||||
from bat_detect.command import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
info_str = (
|
||||
"\nBatDetect2 - Detection and Classification\n"
|
||||
+ " Assumes audio files are mono, not stereo.\n"
|
||||
+ ' Spaces in the input paths will throw an error. Wrap in quotes "".\n'
|
||||
+ " Input files should be short in duration e.g. < 30 seconds.\n"
|
||||
)
|
||||
|
||||
print(info_str)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("audio_dir", type=str, help="Input directory for audio")
|
||||
parser.add_argument(
|
||||
"ann_dir",
|
||||
type=str,
|
||||
help="Output directory for where the predictions will be stored",
|
||||
)
|
||||
parser.add_argument(
|
||||
"detection_threshold",
|
||||
type=float,
|
||||
help="Cut-off probability for detector e.g. 0.1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cnn_features",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="cnn_features",
|
||||
help="Extracts CNN call features",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--spec_features",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="spec_features",
|
||||
help="Extracts low level call features",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--time_expansion_factor",
|
||||
type=int,
|
||||
default=1,
|
||||
dest="time_expansion_factor",
|
||||
help="The time expansion factor used for all files (default is 1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quiet",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="quiet",
|
||||
help="Minimize output printing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_preds_if_empty",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="save_preds_if_empty",
|
||||
help="Save empty annotation file if no detections made.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="models/Net2DFast_UK_same.pth.tar",
|
||||
help="Path to trained BatDetect2 model",
|
||||
)
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
args["spec_slices"] = False # used for visualization
|
||||
args[
|
||||
"chunk_size"
|
||||
] = 2 # if files greater than this amount (seconds) they will be broken down into small chunks
|
||||
args["ann_dir"] = os.path.join(args["ann_dir"], "")
|
||||
|
||||
main(args)
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user