Finished refactoring detector_utils

commit e6a6ad4696 (parent 8da98b5258)
Santiago Martinez, 2023-02-22 22:45:26 +00:00
13 changed files with 334 additions and 115 deletions

.pylintrc (new file, 5 lines)

@@ -0,0 +1,5 @@
[TYPECHECK]
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=torch.*
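# Context (not part of the commit): without this override Pylint's inference
# has historically flagged dynamically created torch members, e.g.:
#
#   import torch
#   t = torch.tensor([1.0])   # may be reported as E1101 (no-member)
#
# Whitelisting generated-members=torch.* silences these false positives
# without disabling E1101 globally.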

app.py (2 changed lines)

@@ -9,7 +9,7 @@ import bat_detect.utils.plot_utils as viz
# setup the arguments
args = {}
args = du.get_default_bd_args()
args = du.get_default_run_config()
args["detection_threshold"] = 0.3
args["time_expansion_factor"] = 1
args["model_path"] = "models/Net2DFast_UK_same.pth.tar"

bat_detect/detector/parameters.py

@@ -1,7 +1,18 @@
import datetime
import os
import numpy as np
TARGET_SAMPLERATE_HZ = 256000
FFT_WIN_LENGTH_S = 512 / 256000.0
FFT_OVERLAP = 0.75
MAX_FREQ_HZ = 120000
MIN_FREQ_HZ = 10000
RESIZE_FACTOR = 0.5
SPEC_DIVIDE_FACTOR = 32
SPEC_HEIGHT = 256
SCALE_RAW_AUDIO = False
DETECTION_THRESHOLD = 0.01
NMS_KERNEL_SIZE = 9
NMS_TOP_K_PER_SEC = 200
def mk_dir(path):
@@ -30,35 +41,39 @@ def get_params(make_dirs=False, exps_dir="../../experiments/"):
# spec parameters
params[
"target_samp_rate"
] = 256000 # resamples all audio so that it is at this rate
params["fft_win_length"] = (
512 / 256000.0
) # in milliseconds, amount of time per stft time step
params["fft_overlap"] = 0.75 # stft window overlap
] = TARGET_SAMPLERATE_HZ # resamples all audio so that it is at this rate
params[
"fft_win_length"
] = FFT_WIN_LENGTH_S # in seconds, amount of time per stft time step
params["fft_overlap"] = FFT_OVERLAP # stft window overlap
params[
"max_freq"
] = 120000 # in Hz, everything above this will be discarded
params["min_freq"] = 10000 # in Hz, everything below this will be discarded
] = MAX_FREQ_HZ # in Hz, everything above this will be discarded
params[
"min_freq"
] = MIN_FREQ_HZ # in Hz, everything below this will be discarded
params[
"resize_factor"
] = 0.5 # resize so the spectrogram at the input of the network
] = RESIZE_FACTOR # factor by which the spectrogram is resized at the input of the network
params[
"spec_height"
] = 256 # units are number of frequency bins (before resizing is performed)
] = SPEC_HEIGHT # units are number of frequency bins (before resizing is performed)
params[
"spec_train_width"
] = 512 # units are number of time steps (before resizing is performed)
params[
"spec_divide_factor"
] = 32 # spectrogram should be divisible by this amount in width and height
] = SPEC_DIVIDE_FACTOR # spectrogram should be divisible by this amount in width and height
# spec processing params
params[
"denoise_spec_avg"
] = True # removes the mean for each frequency band
params["scale_raw_audio"] = False # scales the raw audio to [-1, 1]
params[
"scale_raw_audio"
] = SCALE_RAW_AUDIO # scales the raw audio to [-1, 1]
params[
"max_scale_spec"
] = False # scales the spectrogram so that it is max 1
@@ -73,11 +88,13 @@ def get_params(make_dirs=False, exps_dir="../../experiments/"):
] = 0.01 # ignore GT calls that start within this time of the file start/end
params[
"detection_threshold"
] = 0.01 # the smaller this is the better the recall will be
params["nms_kernel_size"] = 9
] = DETECTION_THRESHOLD # the smaller this is the better the recall will be
params[
"nms_kernel_size"
] = NMS_KERNEL_SIZE # size of the kernel for non-max suppression
params[
"nms_top_k_per_sec"
] = 200 # keep top K highest predictions per second of audio
] = NMS_TOP_K_PER_SEC # keep top K highest predictions per second of audio
params["target_sigma"] = 2.0
# augmentation params

bat_detect/detector/post_process.py

@@ -1,3 +1,6 @@
"""Post-processing of the output of the model."""
from typing import List, Optional, Tuple
import numpy as np
import torch
from torch import nn
@@ -10,11 +13,26 @@ except ImportError:
np.seterr(divide="ignore", invalid="ignore")
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
def x_coords_to_time(
x_pos: float,
sampling_rate: int,
fft_win_length: float,
fft_overlap: float,
) -> float:
"""Convert x coordinates of spectrogram to time in seconds.
Args:
x_pos: X position of the detection in pixels.
sampling_rate: Sampling rate of the audio in Hz.
fft_win_length: Length of the FFT window in seconds.
fft_overlap: Overlap of the FFT windows, as a fraction of the window length.
Returns:
Time in seconds.
"""
nfft = int(fft_win_length * sampling_rate)
noverlap = int(fft_overlap * nfft)
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
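# Worked example (illustrative, not part of the commit), using the default UK
# parameters from parameters.py: nfft = int((512 / 256000.0) * 256000) = 512
# samples and noverlap = int(0.75 * 512) = 384 samples, so the hop is 128
# samples. Column 0 then maps to 384 / 256000 = 0.0015 s, and every further
# column adds 128 / 256000 = 0.0005 s:
#
#   assert x_coords_to_time(0, 256000, 512 / 256000.0, 0.75) == 384 / 256000
#   assert x_coords_to_time(1, 256000, 512 / 256000.0, 0.75) == 512 / 256000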
def overall_class_pred(det_prob, class_prob):
@@ -28,10 +46,10 @@ class NonMaximumSuppressionConfig(TypedDict):
nms_kernel_size: int
"""Size of the kernel for non-maximum suppression."""
max_freq: float
max_freq: int
"""Maximum frequency to consider in Hz."""
min_freq: float
min_freq: int
"""Minimum frequency to consider in Hz."""
fft_win_length: float
@@ -40,6 +58,9 @@ class NonMaximumSuppressionConfig(TypedDict):
fft_overlap: float
"""Overlap of the FFT windows in seconds."""
resize_factor: float
"""Factor by which the input was resized."""
nms_top_k_per_sec: float
"""Number of top detections to keep per second."""
@@ -47,8 +68,73 @@ class NonMaximumSuppressionConfig(TypedDict):
"""Threshold for detection probability."""
def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
"""Run non-maximum suppression on the output of the model."""
class PredictionResults(TypedDict):
"""Results of the prediction.
Each key maps to an array of length `num_detections` holding the
corresponding value for each detection.
"""
det_probs: np.ndarray
"""Detection probabilities."""
x_pos: np.ndarray
"""X position of the detection in pixels."""
y_pos: np.ndarray
"""Y position of the detection in pixels."""
bb_width: np.ndarray
"""Width of the detection in pixels."""
bb_height: np.ndarray
"""Height of the detection in pixels."""
start_times: np.ndarray
"""Start times of the detections in seconds."""
end_times: np.ndarray
"""End times of the detections in seconds."""
low_freqs: np.ndarray
"""Low frequencies of the detections in Hz."""
high_freqs: np.ndarray
"""High frequencies of the detections in Hz."""
class_probs: Optional[np.ndarray]
"""Class probabilities."""
class ModelOutputs(TypedDict):
"""Outputs of the model."""
pred_det: torch.Tensor
"""Detection probabilities."""
pred_size: torch.Tensor
"""Box sizes."""
pred_class: Optional[torch.Tensor]
"""Class probabilities."""
features: Optional[torch.Tensor]
"""Features extracted by the model."""
def run_nms(
outputs: ModelOutputs,
params: NonMaximumSuppressionConfig,
sampling_rate: np.ndarray,
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
"""Run non-maximum suppression on the output of the model.
Model outputs processed are expected to have a batch dimension.
Each element of the batch is processed independently. The
result is a pair of lists, one for the predictions and one for
the features. Each element of the lists corresponds to one
element of the batch.
"""
pred_det = outputs["pred_det"] # probability of box
pred_size = outputs["pred_size"] # box size
@@ -62,7 +148,7 @@ def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
# as we are choosing the same sampling rate for the entire batch
duration = x_coords_to_time(
pred_det.shape[-1],
sampling_rate[0].item(),
int(sampling_rate[0].item()),
params["fft_win_length"],
params["fft_overlap"],
)
@@ -70,58 +156,72 @@ def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
# loop over batch to save outputs
preds = []
feats = []
for ii in range(pred_det_nms.shape[0]):
preds: List[PredictionResults] = []
feats: List[np.ndarray] = []
for num_detection in range(pred_det_nms.shape[0]):
# get valid indices
inds_ord = torch.argsort(x_pos[ii, :])
valid_inds = scores[ii, inds_ord] > params["detection_threshold"]
inds_ord = torch.argsort(x_pos[num_detection, :])
valid_inds = (
scores[num_detection, inds_ord] > params["detection_threshold"]
)
valid_inds = inds_ord[valid_inds]
# create result dictionary
pred = {}
pred["det_probs"] = scores[ii, valid_inds]
pred["x_pos"] = x_pos[ii, valid_inds]
pred["y_pos"] = y_pos[ii, valid_inds]
pred["bb_width"] = pred_size[ii, 0, pred["y_pos"], pred["x_pos"]]
pred["bb_height"] = pred_size[ii, 1, pred["y_pos"], pred["x_pos"]]
pred["det_probs"] = scores[num_detection, valid_inds]
pred["x_pos"] = x_pos[num_detection, valid_inds]
pred["y_pos"] = y_pos[num_detection, valid_inds]
pred["bb_width"] = pred_size[
num_detection, 0, pred["y_pos"], pred["x_pos"]
]
pred["bb_height"] = pred_size[
num_detection, 1, pred["y_pos"], pred["x_pos"]
]
pred["start_times"] = x_coords_to_time(
pred["x_pos"].float() / params["resize_factor"],
sampling_rate[ii].item(),
int(sampling_rate[num_detection].item()),
params["fft_win_length"],
params["fft_overlap"],
)
pred["end_times"] = x_coords_to_time(
(pred["x_pos"].float() + pred["bb_width"])
/ params["resize_factor"],
sampling_rate[ii].item(),
int(sampling_rate[num_detection].item()),
params["fft_win_length"],
params["fft_overlap"],
)
pred["low_freqs"] = (
pred_size[ii].shape[1] - pred["y_pos"].float()
pred_size[num_detection].shape[1] - pred["y_pos"].float()
) * freq_rescale + params["min_freq"]
pred["high_freqs"] = (
pred["low_freqs"] + pred["bb_height"] * freq_rescale
)
# extract the per class votes
if "pred_class" in outputs:
pred["class_probs"] = outputs["pred_class"][
ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]
pred_class = outputs.get("pred_class")
if pred_class is not None:
pred["class_probs"] = pred_class[
num_detection,
:,
y_pos[num_detection, valid_inds],
x_pos[num_detection, valid_inds],
]
# extract the model features
if "features" in outputs:
feat = outputs["features"][
ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]
features = outputs.get("features")
if features is not None:
feat = features[
num_detection,
:,
y_pos[num_detection, valid_inds],
x_pos[num_detection, valid_inds],
].transpose(0, 1)
feat = feat.cpu().numpy().astype(np.float32)
feats.append(feat)
# convert to numpy
for kk in pred.keys():
pred[kk] = pred[kk].cpu().numpy().astype(np.float32)
for key, value in pred.items():
pred[key] = value.cpu().numpy().astype(np.float32)
preds.append(pred)
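# A hedged usage sketch (`model`, `spec` and `nms_config` are placeholders,
# not defined in this commit); `run_nms` takes the raw model outputs, the NMS
# config, and one sampling rate per batch element, mirroring the call site in
# detector_utils:
#
#   outputs: ModelOutputs = model(spec)
#   preds, feats = run_nms(outputs, nms_config, np.array([256000.0]))
#   first = preds[0]   # PredictionResults for the first batch element
#   print(first["start_times"], first["low_freqs"])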
@@ -130,7 +230,7 @@ def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
def non_max_suppression(heat, kernel_size):
# kernel can be an int or list/tuple
if type(kernel_size) is int:
if isinstance(kernel_size, int):
kernel_size_h = kernel_size
kernel_size_w = kernel_size

View File

@@ -739,7 +739,7 @@ if __name__ == "__main__":
#
if args["bd_model_path"] != "":
# load model
bd_args = du.get_default_bd_args()
bd_args = du.get_default_run_config()
model, params_bd = du.load_model(args["bd_model_path"])
# check if the class names are the same

View File

@@ -1,7 +1,4 @@
import copy
import os
import random
import sys
import librosa
import numpy as np
@@ -9,7 +6,6 @@ import torch
import torch.nn.functional as F
import torchaudio
sys.path.append(os.path.join("..", ".."))
import bat_detect.utils.audio_utils as au
@@ -218,7 +214,10 @@ def resample_aug(audio, sampling_rate, params):
sampling_rate_old = sampling_rate
sampling_rate = np.random.choice(params["aug_sampling_rates"])
audio = librosa.resample(
audio, sampling_rate_old, sampling_rate, res_type="polyphase"
audio,
orig_sr=sampling_rate_old,
target_sr=sampling_rate,
res_type="polyphase",
)
audio = au.pad_audio(
@@ -237,7 +236,10 @@ def resample_aug(audio, sampling_rate, params):
def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
if sampling_rate != sampling_rate2:
audio2 = librosa.resample(
audio2, sampling_rate2, sampling_rate, res_type="polyphase"
audio2,
orig_sr=sampling_rate2,
target_sr=sampling_rate,
res_type="polyphase",
)
sampling_rate2 = sampling_rate
if audio2.shape[0] < num_samples:
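# Library note (an assumption about librosa, not this repo): librosa 0.10 made
# everything after the signal keyword-only in librosa.resample, which is why
# the positional calls above were rewritten. Minimal sketch:
#
#   import librosa
#   import numpy as np
#
#   audio = np.random.randn(256000).astype(np.float32)  # 1 s of noise at 256 kHz
#   out = librosa.resample(
#       audio, orig_sr=256000, target_sr=128000, res_type="polyphase"
#   )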

View File

@@ -553,5 +553,6 @@ if __name__ == "__main__":
torch.save(op_state, params["model_file_name"])
# save an image with associated prediction for each batch in the test set
if not args["do_not_save_images"]:
save_images_batch(model, test_loader, params)
# TODO: args variable does not exist
# if not args["do_not_save_images"]:
# save_images_batch(model, test_loader, params)

bat_detect/utils/audio_utils.py

@@ -7,7 +7,6 @@ import torch
from . import wavfile
__all__ = [
"load_audio_file",
]
@@ -163,9 +162,11 @@ def load_audio_file(
# clipping maximum duration
if max_duration is not None:
max_duration = np.minimum(
int(sampling_rate * max_duration),
audio_raw.shape[0],
max_duration = int(
np.minimum(
int(sampling_rate * max_duration),
audio_raw.shape[0],
)
)
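# np.minimum returns a NumPy scalar here; the added int() cast keeps the
# slice index below a plain Python int.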
audio_raw = audio_raw[:max_duration]

bat_detect/utils/detector_utils.py

@@ -11,6 +11,20 @@ import bat_detect.detector.compute_features as feats
import bat_detect.detector.post_process as pp
import bat_detect.utils.audio_utils as au
from bat_detect.detector import models
from bat_detect.detector.parameters import (
DETECTION_THRESHOLD,
FFT_OVERLAP,
FFT_WIN_LENGTH_S,
MAX_FREQ_HZ,
MIN_FREQ_HZ,
NMS_KERNEL_SIZE,
NMS_TOP_K_PER_SEC,
RESIZE_FACTOR,
SCALE_RAW_AUDIO,
SPEC_DIVIDE_FACTOR,
SPEC_HEIGHT,
TARGET_SAMPLERATE_HZ,
)
try:
from typing import TypedDict
@@ -24,23 +38,17 @@ DEFAULT_MODEL_PATH = os.path.join(
"model.pth",
)
__all__ = ["load_model", "get_audio_files", "DEFAULT_MODEL_PATH"]
def get_default_bd_args():
args = {}
args["detection_threshold"] = 0.001
args["time_expansion_factor"] = 1
args["audio_dir"] = ""
args["ann_dir"] = ""
args["spec_slices"] = False
args["chunk_size"] = 3
args["spec_features"] = False
args["cnn_features"] = False
args["quiet"] = True
args["save_preds_if_empty"] = True
args["ann_dir"] = os.path.join(args["ann_dir"], "")
return args
__all__ = [
"load_model",
"get_audio_files",
"format_results",
"save_results_to_file",
"iterate_over_chunks",
"process_spectrogram",
"process_audio_array",
"process_file",
"DEFAULT_MODEL_PATH",
]
def get_audio_files(ip_dir: str) -> List[str]:
@@ -80,7 +88,7 @@ class ModelParameters(TypedDict):
ip_height: int
"""Input height in pixels."""
resize_factor: int
resize_factor: float
"""Resize factor."""
class_names: List[str]
@@ -118,6 +126,8 @@ def load_model(
params = net_params["params"]
params["device"] = device
model: torch.nn.Module
if params["model_name"] == "Net2DFast":
model = models.Net2DFast(
params["num_filters"],
@@ -159,9 +169,9 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
if num_preds > 0:
for kk in predictions[0].keys():
predictions_m[kk] = np.hstack(
[pp[kk] for pp in predictions if pp["det_probs"].shape[0] > 0]
for key in predictions[0].keys():
predictions_m[key] = np.hstack(
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
)
else:
# hack in case where no detected calls as we need some of the key names in dict
@@ -176,7 +186,10 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
return predictions_m, spec_feats, cnn_feats, spec_slices
class Annotation(TypedDict("WithClass", {"class": str})):
DictWithClass = TypedDict("DictWithClass", {"class": str})
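# "class" is a reserved keyword in Python, so this key can only be declared
# via the functional TypedDict syntax.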
class Annotation(DictWithClass):
"""Format of annotations.
This is the format of a single annotation as expected by the annotation
@@ -214,7 +227,7 @@ class FileAnnotations(TypedDict):
This is the format of the results expected by the annotation tool.
"""
file_id: str
id: str
"""File ID."""
annotated: bool
@@ -232,26 +245,32 @@ class FileAnnotations(TypedDict):
class_name: str
"""Class predicted at file level"""
notes: str
"""Notes of file."""
annotation: List[Annotation]
"""List of annotations."""
class Results(TypedDict):
class RunResults(TypedDict):
"""Run results."""
pred_dict: FileAnnotations
"""Predictions in the format expected by the annotation tool."""
spec_feats: Optional[np.ndarray]
spec_feats: Optional[List[np.ndarray]]
"""Spectrogram features."""
spec_feat_names: Optional[List[str]]
"""Spectrogram feature names."""
cnn_feats: Optional[np.ndarray]
cnn_feats: Optional[List[np.ndarray]]
"""CNN features."""
cnn_feat_names: Optional[List[str]]
"""CNN feature names."""
spec_slices: Optional[np.ndarray]
spec_slices: Optional[List[np.ndarray]]
"""Spectrogram slices."""
@@ -343,7 +362,7 @@ def convert_results(
spec_feats,
cnn_feats,
spec_slices,
) -> Results:
) -> RunResults:
"""Convert results to dictionary as expected by the annotation tool.
Args:
@@ -369,8 +388,14 @@
)
# combine into final results dictionary
results = {}
results["pred_dict"] = pred_dict
results: RunResults = {
"pred_dict": pred_dict,
"spec_feats": None,
"spec_feat_names": None,
"cnn_feats": None,
"cnn_feat_names": None,
"spec_slices": None,
}
# add spectrogram features if they exist
if len(spec_feats) > 0:
@@ -463,19 +488,16 @@ def save_results_to_file(results, op_path: str) -> None:
class SpectrogramParameters(TypedDict):
"""Parameters for generating spectrograms."""
fft_win_length: int
"""Length of the FFT window in samples."""
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: int
"""Number of samples to overlap between FFT windows."""
fft_overlap: float
"""Percentage of overlap between FFT windows."""
spec_height: int
"""Height of the spectrogram in pixels."""
spec_width: int
"""Width of the spectrogram in pixels."""
resize_factor: int
resize_factor: float
"""Factor to resize the spectrogram by."""
spec_divide_factor: int
@@ -605,13 +627,14 @@ class ProcessingConfiguration(TypedDict):
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Length of the FFT window in samples."""
resize_factor: float
"""Factor to resize the spectrogram by."""
spec_divide_factor: float
spec_divide_factor: int
"""Factor to divide the spectrogram by."""
spec_height: int
@@ -644,27 +667,36 @@ class ProcessingConfiguration(TypedDict):
nms_kernel_size: int
"""Size of the kernel for non-maximum suppression."""
max_freq: float
max_freq: int
"""Maximum frequency to consider in Hz."""
min_freq: float
min_freq: int
"""Minimum frequency to consider in Hz."""
nms_top_k_per_sec: float
"""Number of top detections to keep per second."""
detection_threshold: float
"""Threshold for detection probability."""
quiet: bool
"""Whether to suppress output."""
chunk_size: float
"""Size of chunks to process in seconds."""
cnn_features: bool
"""Whether to return CNN features."""
spec_features: bool
"""Whether to return spectrogram features."""
spec_slices: bool
"""Whether to return spectrogram slices."""
def process_spectrogram(
spec: torch.Tensor,
samplerate: int,
model: torch.nn.Module,
config: pp.NonMaximumSuppressionConfig,
config: ProcessingConfiguration,
):
"""Process a spectrogram with detection model.
@@ -692,17 +724,29 @@ def process_spectrogram(
outputs = model(spec, return_feats=config["cnn_features"])
# run non-max suppression
pred_nms, features = pp.run_nms(
pred_nms_list, features = pp.run_nms(
outputs,
config,
{
"nms_kernel_size": config["nms_kernel_size"],
"max_freq": config["max_freq"],
"min_freq": config["min_freq"],
"fft_win_length": config["fft_win_length"],
"fft_overlap": config["fft_overlap"],
"resize_factor": config["resize_factor"],
"nms_top_k_per_sec": config["nms_top_k_per_sec"],
"detection_threshold": config["detection_threshold"],
},
np.array([float(samplerate)]),
)
pred_nms = pred_nms[0]
pred_nms = pred_nms_list[0]
# if we have a background class
if pred_nms["class_probs"].shape[0] > len(config["class_names"]):
pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]
class_probs = pred_nms.get("class_probs")
if (class_probs is not None) and (
class_probs.shape[0] > len(config["class_names"])
):
pred_nms["class_probs"] = class_probs[:-1, :]
return pred_nms, features
@@ -737,7 +781,14 @@ def process_audio_array(
_, spec, spec_np = compute_spectrogram(
audio,
sampling_rate,
config,
{
"fft_win_length": config["fft_win_length"],
"fft_overlap": config["fft_overlap"],
"spec_height": config["spec_height"],
"resize_factor": config["resize_factor"],
"spec_divide_factor": config["spec_divide_factor"],
"device": config["device"],
},
return_np=config["spec_features"] or config["spec_slices"],
)
@@ -756,7 +807,7 @@ def process_file(
audio_file: str,
model: torch.nn.Module,
config: ProcessingConfiguration,
) -> Union[Results, Any]:
) -> Union[RunResults, Any]:
"""Process a single audio file with detection model.
Will split the audio file into chunks if it is too long and
@@ -788,7 +839,7 @@ def process_file(
# load audio file
sampling_rate, audio_full = au.load_audio_file(
audio_file,
time_exp_fact=config["time_expansion"],
time_exp_fact=config.get("time_expansion", 1) or 1,
target_samp_rate=config["target_samp_rate"],
scale=config["scale_raw_audio"],
max_duration=config["max_duration"],
@@ -840,7 +891,7 @@ def process_file(
# convert results to a dictionary in the right format
results = convert_results(
file_id=os.path.basename(audio_file),
time_exp=config["time_expansion"],
time_exp=config.get("time_expansion", 1) or 1,
duration=audio_full.shape[0] / float(sampling_rate),
params=config,
predictions=predictions,
@@ -877,3 +928,38 @@ def summarize_results(results, predictions, config):
config["class_names"][class_index].ljust(30)
+ str(round(class_overall[class_index], 3))
)
def get_default_run_config(**kwargs) -> ProcessingConfiguration:
"""Get default configuration for running detection model."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args: ProcessingConfiguration = {
"detection_threshold": DETECTION_THRESHOLD,
"spec_slices": False,
"chunk_size": 3,
"spec_features": False,
"cnn_features": False,
"quiet": True,
"target_samp_rate": TARGET_SAMPLERATE_HZ,
"fft_win_length": FFT_WIN_LENGTH_S,
"fft_overlap": FFT_OVERLAP,
"resize_factor": RESIZE_FACTOR,
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
"spec_height": SPEC_HEIGHT,
"scale_raw_audio": SCALE_RAW_AUDIO,
"device": device,
"class_names": [],
"time_expansion": 1,
"top_n": 3,
"return_raw_preds": False,
"max_duration": None,
"nms_kernel_size": NMS_KERNEL_SIZE,
"max_freq": MAX_FREQ_HZ,
"min_freq": MIN_FREQ_HZ,
"nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
}
return {
**args,
**kwargs,
}
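# A hedged usage sketch mirroring the updated call sites ("example.wav" is a
# placeholder): overrides are passed as keyword arguments and merged over the
# defaults.
#
#   config = get_default_run_config(detection_threshold=0.3, time_expansion=1)
#   model, params = load_model("models/Net2DFast_UK_same.pth.tar")
#   results = process_file("example.wav", model, config)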

bat_detect/utils/plot_utils.py

@@ -523,7 +523,7 @@ class LossPlotter(object):
def save_confusion_matrix(self, gt, pred):
plt.figure(0)
cm = confusion_matrix(
gt, pred, np.arange(len(self.class_names))
gt, pred, labels=np.arange(len(self.class_names))
).astype(np.float32)
cm_norm = cm.sum(1)
valid_inds = np.where(cm_norm > 0)[0]
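# Library note (an assumption about scikit-learn, not this repo): recent
# scikit-learn versions made everything after y_pred keyword-only in
# confusion_matrix, hence the labels= keyword above. Minimal sketch:
#
#   import numpy as np
#   from sklearn.metrics import confusion_matrix
#
#   gt = np.array([0, 1, 2, 1])
#   pred = np.array([0, 1, 1, 1])
#   cm = confusion_matrix(gt, pred, labels=np.arange(3))  # rows: true, cols: pred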

pyproject.toml

@@ -49,3 +49,10 @@ batdetect2 = "bat_detect.command:main"
[tool.black]
line-length = 80
[[tool.mypy.overrides]]
module = [
"librosa",
"pandas",
]
ignore_missing_imports = true

View File

@@ -86,7 +86,7 @@ if __name__ == "__main__":
args_cmd = vars(parser.parse_args())
# load the model
bd_args = du.get_default_bd_args()
bd_args = du.get_default_run_config()
model, params_bd = du.load_model(args_cmd["model_path"])
bd_args["detection_threshold"] = args_cmd["detection_threshold"]
bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"]

View File

@@ -89,7 +89,7 @@ if __name__ == "__main__":
os.makedirs(op_dir)
params = parameters.get_params(False)
args = du.get_default_bd_args()
args = du.get_default_run_config()
args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
args["detection_threshold"] = args_cmd["detection_threshold"]