mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Finished refactoring detector_utils
This commit is contained in:
parent
8da98b5258
commit
e6a6ad4696
5
.pylintrc
Normal file
5
.pylintrc
Normal file
@ -0,0 +1,5 @@
|
||||
[TYPECHECK]
|
||||
|
||||
# List of members which are set dynamically and missed by Pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed.
|
||||
generated-members=torch.*
|
2
app.py
2
app.py
@ -9,7 +9,7 @@ import bat_detect.utils.plot_utils as viz
|
||||
|
||||
# setup the arguments
|
||||
args = {}
|
||||
args = du.get_default_bd_args()
|
||||
args = du.get_default_run_config()
|
||||
args["detection_threshold"] = 0.3
|
||||
args["time_expansion_factor"] = 1
|
||||
args["model_path"] = "models/Net2DFast_UK_same.pth.tar"
|
||||
|
@ -1,7 +1,18 @@
|
||||
import datetime
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
TARGET_SAMPLERATE_HZ = 256000
|
||||
FFT_WIN_LENGTH_S = 512 / 256000.0
|
||||
FFT_OVERLAP = 0.75
|
||||
MAX_FREQ_HZ = 120000
|
||||
MIN_FREQ_HZ = 10000
|
||||
RESIZE_FACTOR = 0.5
|
||||
SPEC_DIVIDE_FACTOR = 32
|
||||
SPEC_HEIGHT = 256
|
||||
SCALE_RAW_AUDIO = False
|
||||
DETECTION_THRESHOLD = 0.01
|
||||
NMS_KERNEL_SIZE = 9
|
||||
NMS_TOP_K_PER_SEC = 200
|
||||
|
||||
|
||||
def mk_dir(path):
|
||||
@ -30,35 +41,39 @@ def get_params(make_dirs=False, exps_dir="../../experiments/"):
|
||||
# spec parameters
|
||||
params[
|
||||
"target_samp_rate"
|
||||
] = 256000 # resamples all audio so that it is at this rate
|
||||
params["fft_win_length"] = (
|
||||
512 / 256000.0
|
||||
) # in milliseconds, amount of time per stft time step
|
||||
params["fft_overlap"] = 0.75 # stft window overlap
|
||||
] = TARGET_SAMPLERATE_HZ # resamples all audio so that it is at this rate
|
||||
params[
|
||||
"fft_win_length"
|
||||
] = FFT_WIN_LENGTH_S # in milliseconds, amount of time per stft time step
|
||||
params["fft_overlap"] = FFT_OVERLAP # stft window overlap
|
||||
|
||||
params[
|
||||
"max_freq"
|
||||
] = 120000 # in Hz, everything above this will be discarded
|
||||
params["min_freq"] = 10000 # in Hz, everything below this will be discarded
|
||||
] = MAX_FREQ_HZ # in Hz, everything above this will be discarded
|
||||
params[
|
||||
"min_freq"
|
||||
] = MIN_FREQ_HZ # in Hz, everything below this will be discarded
|
||||
|
||||
params[
|
||||
"resize_factor"
|
||||
] = 0.5 # resize so the spectrogram at the input of the network
|
||||
] = RESIZE_FACTOR # resize so the spectrogram at the input of the network
|
||||
params[
|
||||
"spec_height"
|
||||
] = 256 # units are number of frequency bins (before resizing is performed)
|
||||
] = SPEC_HEIGHT # units are number of frequency bins (before resizing is performed)
|
||||
params[
|
||||
"spec_train_width"
|
||||
] = 512 # units are number of time steps (before resizing is performed)
|
||||
params[
|
||||
"spec_divide_factor"
|
||||
] = 32 # spectrogram should be divisible by this amount in width and height
|
||||
] = SPEC_DIVIDE_FACTOR # spectrogram should be divisible by this amount in width and height
|
||||
|
||||
# spec processing params
|
||||
params[
|
||||
"denoise_spec_avg"
|
||||
] = True # removes the mean for each frequency band
|
||||
params["scale_raw_audio"] = False # scales the raw audio to [-1, 1]
|
||||
params[
|
||||
"scale_raw_audio"
|
||||
] = SCALE_RAW_AUDIO # scales the raw audio to [-1, 1]
|
||||
params[
|
||||
"max_scale_spec"
|
||||
] = False # scales the spectrogram so that it is max 1
|
||||
@ -73,11 +88,13 @@ def get_params(make_dirs=False, exps_dir="../../experiments/"):
|
||||
] = 0.01 # if start of GT calls are within this time from the start/end of file ignore
|
||||
params[
|
||||
"detection_threshold"
|
||||
] = 0.01 # the smaller this is the better the recall will be
|
||||
params["nms_kernel_size"] = 9
|
||||
] = DETECTION_THRESHOLD # the smaller this is the better the recall will be
|
||||
params[
|
||||
"nms_kernel_size"
|
||||
] = NMS_KERNEL_SIZE # size of the kernel for non-max suppression
|
||||
params[
|
||||
"nms_top_k_per_sec"
|
||||
] = 200 # keep top K highest predictions per second of audio
|
||||
] = NMS_TOP_K_PER_SEC # keep top K highest predictions per second of audio
|
||||
params["target_sigma"] = 2.0
|
||||
|
||||
# augmentation params
|
||||
|
@ -1,3 +1,6 @@
|
||||
"""Post-processing of the output of the model."""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -10,11 +13,26 @@ except ImportError:
|
||||
np.seterr(divide="ignore", invalid="ignore")
|
||||
|
||||
|
||||
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
|
||||
def x_coords_to_time(
|
||||
x_pos: float,
|
||||
sampling_rate: int,
|
||||
fft_win_length: float,
|
||||
fft_overlap: float,
|
||||
) -> float:
|
||||
"""Convert x coordinates of spectrogram to time in seconds.
|
||||
|
||||
Args:
|
||||
x_pos: X position of the detection in pixels.
|
||||
sampling_rate: Sampling rate of the audio in Hz.
|
||||
fft_win_length: Length of the FFT window in seconds.
|
||||
fft_overlap: Overlap of the FFT windows in seconds.
|
||||
|
||||
Returns:
|
||||
Time in seconds.
|
||||
"""
|
||||
nfft = int(fft_win_length * sampling_rate)
|
||||
noverlap = int(fft_overlap * nfft)
|
||||
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
|
||||
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
|
||||
|
||||
|
||||
def overall_class_pred(det_prob, class_prob):
|
||||
@ -28,10 +46,10 @@ class NonMaximumSuppressionConfig(TypedDict):
|
||||
nms_kernel_size: int
|
||||
"""Size of the kernel for non-maximum suppression."""
|
||||
|
||||
max_freq: float
|
||||
max_freq: int
|
||||
"""Maximum frequency to consider in Hz."""
|
||||
|
||||
min_freq: float
|
||||
min_freq: int
|
||||
"""Minimum frequency to consider in Hz."""
|
||||
|
||||
fft_win_length: float
|
||||
@ -40,6 +58,9 @@ class NonMaximumSuppressionConfig(TypedDict):
|
||||
fft_overlap: float
|
||||
"""Overlap of the FFT windows in seconds."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor by which the input was resized."""
|
||||
|
||||
nms_top_k_per_sec: float
|
||||
"""Number of top detections to keep per second."""
|
||||
|
||||
@ -47,8 +68,73 @@ class NonMaximumSuppressionConfig(TypedDict):
|
||||
"""Threshold for detection probability."""
|
||||
|
||||
|
||||
def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
|
||||
"""Run non-maximum suppression on the output of the model."""
|
||||
class PredictionResults(TypedDict):
|
||||
"""Results of the prediction.
|
||||
|
||||
Each key is a list of length `num_detections` containing the
|
||||
corresponding values for each detection.
|
||||
"""
|
||||
|
||||
det_probs: np.ndarray
|
||||
"""Detection probabilities."""
|
||||
|
||||
x_pos: np.ndarray
|
||||
"""X position of the detection in pixels."""
|
||||
|
||||
y_pos: np.ndarray
|
||||
"""Y position of the detection in pixels."""
|
||||
|
||||
bb_width: np.ndarray
|
||||
"""Width of the detection in pixels."""
|
||||
|
||||
bb_height: np.ndarray
|
||||
"""Height of the detection in pixels."""
|
||||
|
||||
start_times: np.ndarray
|
||||
"""Start times of the detections in seconds."""
|
||||
|
||||
end_times: np.ndarray
|
||||
"""End times of the detections in seconds."""
|
||||
|
||||
low_freqs: np.ndarray
|
||||
"""Low frequencies of the detections in Hz."""
|
||||
|
||||
high_freqs: np.ndarray
|
||||
"""High frequencies of the detections in Hz."""
|
||||
|
||||
class_probs: Optional[np.ndarray]
|
||||
"""Class probabilities."""
|
||||
|
||||
|
||||
class ModelOutputs(TypedDict):
|
||||
"""Outputs of the model."""
|
||||
|
||||
pred_det: torch.Tensor
|
||||
"""Detection probabilities."""
|
||||
|
||||
pred_size: torch.Tensor
|
||||
"""Box sizes."""
|
||||
|
||||
pred_class: Optional[torch.Tensor]
|
||||
"""Class probabilities."""
|
||||
|
||||
features: Optional[torch.Tensor]
|
||||
"""Features extracted by the model."""
|
||||
|
||||
|
||||
def run_nms(
|
||||
outputs: ModelOutputs,
|
||||
params: NonMaximumSuppressionConfig,
|
||||
sampling_rate: np.ndarray,
|
||||
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
|
||||
"""Run non-maximum suppression on the output of the model.
|
||||
|
||||
Model outputs processed are expected to have a batch dimension.
|
||||
Each element of the batch is processed independently. The
|
||||
result is a pair of lists, one for the predictions and one for
|
||||
the features. Each element of the lists corresponds to one
|
||||
element of the batch.
|
||||
"""
|
||||
|
||||
pred_det = outputs["pred_det"] # probability of box
|
||||
pred_size = outputs["pred_size"] # box size
|
||||
@ -62,7 +148,7 @@ def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
|
||||
# as we are choosing the same sampling rate for the entire batch
|
||||
duration = x_coords_to_time(
|
||||
pred_det.shape[-1],
|
||||
sampling_rate[0].item(),
|
||||
int(sampling_rate[0].item()),
|
||||
params["fft_win_length"],
|
||||
params["fft_overlap"],
|
||||
)
|
||||
@ -70,58 +156,72 @@ def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
|
||||
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
|
||||
|
||||
# loop over batch to save outputs
|
||||
preds = []
|
||||
feats = []
|
||||
for ii in range(pred_det_nms.shape[0]):
|
||||
preds: List[PredictionResults] = []
|
||||
feats: List[np.ndarray] = []
|
||||
for num_detection in range(pred_det_nms.shape[0]):
|
||||
# get valid indices
|
||||
inds_ord = torch.argsort(x_pos[ii, :])
|
||||
valid_inds = scores[ii, inds_ord] > params["detection_threshold"]
|
||||
inds_ord = torch.argsort(x_pos[num_detection, :])
|
||||
valid_inds = (
|
||||
scores[num_detection, inds_ord] > params["detection_threshold"]
|
||||
)
|
||||
valid_inds = inds_ord[valid_inds]
|
||||
|
||||
# create result dictionary
|
||||
pred = {}
|
||||
pred["det_probs"] = scores[ii, valid_inds]
|
||||
pred["x_pos"] = x_pos[ii, valid_inds]
|
||||
pred["y_pos"] = y_pos[ii, valid_inds]
|
||||
pred["bb_width"] = pred_size[ii, 0, pred["y_pos"], pred["x_pos"]]
|
||||
pred["bb_height"] = pred_size[ii, 1, pred["y_pos"], pred["x_pos"]]
|
||||
pred["det_probs"] = scores[num_detection, valid_inds]
|
||||
pred["x_pos"] = x_pos[num_detection, valid_inds]
|
||||
pred["y_pos"] = y_pos[num_detection, valid_inds]
|
||||
pred["bb_width"] = pred_size[
|
||||
num_detection, 0, pred["y_pos"], pred["x_pos"]
|
||||
]
|
||||
pred["bb_height"] = pred_size[
|
||||
num_detection, 1, pred["y_pos"], pred["x_pos"]
|
||||
]
|
||||
pred["start_times"] = x_coords_to_time(
|
||||
pred["x_pos"].float() / params["resize_factor"],
|
||||
sampling_rate[ii].item(),
|
||||
int(sampling_rate[num_detection].item()),
|
||||
params["fft_win_length"],
|
||||
params["fft_overlap"],
|
||||
)
|
||||
pred["end_times"] = x_coords_to_time(
|
||||
(pred["x_pos"].float() + pred["bb_width"])
|
||||
/ params["resize_factor"],
|
||||
sampling_rate[ii].item(),
|
||||
int(sampling_rate[num_detection].item()),
|
||||
params["fft_win_length"],
|
||||
params["fft_overlap"],
|
||||
)
|
||||
pred["low_freqs"] = (
|
||||
pred_size[ii].shape[1] - pred["y_pos"].float()
|
||||
pred_size[num_detection].shape[1] - pred["y_pos"].float()
|
||||
) * freq_rescale + params["min_freq"]
|
||||
pred["high_freqs"] = (
|
||||
pred["low_freqs"] + pred["bb_height"] * freq_rescale
|
||||
)
|
||||
|
||||
# extract the per class votes
|
||||
if "pred_class" in outputs:
|
||||
pred["class_probs"] = outputs["pred_class"][
|
||||
ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]
|
||||
pred_class = outputs.get("pred_class")
|
||||
if pred_class is not None:
|
||||
pred["class_probs"] = pred_class[
|
||||
num_detection,
|
||||
:,
|
||||
y_pos[num_detection, valid_inds],
|
||||
x_pos[num_detection, valid_inds],
|
||||
]
|
||||
|
||||
# extract the model features
|
||||
if "features" in outputs:
|
||||
feat = outputs["features"][
|
||||
ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]
|
||||
features = outputs.get("features")
|
||||
if features is not None:
|
||||
feat = features[
|
||||
num_detection,
|
||||
:,
|
||||
y_pos[num_detection, valid_inds],
|
||||
x_pos[num_detection, valid_inds],
|
||||
].transpose(0, 1)
|
||||
feat = feat.cpu().numpy().astype(np.float32)
|
||||
feats.append(feat)
|
||||
|
||||
# convert to numpy
|
||||
for kk in pred.keys():
|
||||
pred[kk] = pred[kk].cpu().numpy().astype(np.float32)
|
||||
for key, value in pred.items():
|
||||
pred[key] = value.cpu().numpy().astype(np.float32)
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
@ -130,7 +230,7 @@ def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
|
||||
|
||||
def non_max_suppression(heat, kernel_size):
|
||||
# kernel can be an int or list/tuple
|
||||
if type(kernel_size) is int:
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size_h = kernel_size
|
||||
kernel_size_w = kernel_size
|
||||
|
||||
|
@ -739,7 +739,7 @@ if __name__ == "__main__":
|
||||
#
|
||||
if args["bd_model_path"] != "":
|
||||
# load model
|
||||
bd_args = du.get_default_bd_args()
|
||||
bd_args = du.get_default_run_config()
|
||||
model, params_bd = du.load_model(args["bd_model_path"])
|
||||
|
||||
# check if the class names are the same
|
||||
|
@ -1,7 +1,4 @@
|
||||
import copy
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
@ -9,7 +6,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
|
||||
sys.path.append(os.path.join("..", ".."))
|
||||
import bat_detect.utils.audio_utils as au
|
||||
|
||||
|
||||
@ -218,7 +214,10 @@ def resample_aug(audio, sampling_rate, params):
|
||||
sampling_rate_old = sampling_rate
|
||||
sampling_rate = np.random.choice(params["aug_sampling_rates"])
|
||||
audio = librosa.resample(
|
||||
audio, sampling_rate_old, sampling_rate, res_type="polyphase"
|
||||
audio,
|
||||
orig_sr=sampling_rate_old,
|
||||
target_sr=sampling_rate,
|
||||
res_type="polyphase",
|
||||
)
|
||||
|
||||
audio = au.pad_audio(
|
||||
@ -237,7 +236,10 @@ def resample_aug(audio, sampling_rate, params):
|
||||
def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
|
||||
if sampling_rate != sampling_rate2:
|
||||
audio2 = librosa.resample(
|
||||
audio2, sampling_rate2, sampling_rate, res_type="polyphase"
|
||||
audio2,
|
||||
orig_sr=sampling_rate2,
|
||||
target_sr=sampling_rate,
|
||||
res_type="polyphase",
|
||||
)
|
||||
sampling_rate2 = sampling_rate
|
||||
if audio2.shape[0] < num_samples:
|
||||
|
@ -553,5 +553,6 @@ if __name__ == "__main__":
|
||||
torch.save(op_state, params["model_file_name"])
|
||||
|
||||
# save an image with associated prediction for each batch in the test set
|
||||
if not args["do_not_save_images"]:
|
||||
save_images_batch(model, test_loader, params)
|
||||
# TODO: args variable does not exist
|
||||
# if not args["do_not_save_images"]:
|
||||
# save_images_batch(model, test_loader, params)
|
||||
|
@ -7,7 +7,6 @@ import torch
|
||||
|
||||
from . import wavfile
|
||||
|
||||
|
||||
__all__ = [
|
||||
"load_audio_file",
|
||||
]
|
||||
@ -163,9 +162,11 @@ def load_audio_file(
|
||||
|
||||
# clipping maximum duration
|
||||
if max_duration is not None:
|
||||
max_duration = np.minimum(
|
||||
int(sampling_rate * max_duration),
|
||||
audio_raw.shape[0],
|
||||
max_duration = int(
|
||||
np.minimum(
|
||||
int(sampling_rate * max_duration),
|
||||
audio_raw.shape[0],
|
||||
)
|
||||
)
|
||||
audio_raw = audio_raw[:max_duration]
|
||||
|
||||
|
@ -11,6 +11,20 @@ import bat_detect.detector.compute_features as feats
|
||||
import bat_detect.detector.post_process as pp
|
||||
import bat_detect.utils.audio_utils as au
|
||||
from bat_detect.detector import models
|
||||
from bat_detect.detector.parameters import (
|
||||
DETECTION_THRESHOLD,
|
||||
FFT_OVERLAP,
|
||||
FFT_WIN_LENGTH_S,
|
||||
MAX_FREQ_HZ,
|
||||
MIN_FREQ_HZ,
|
||||
NMS_KERNEL_SIZE,
|
||||
NMS_TOP_K_PER_SEC,
|
||||
RESIZE_FACTOR,
|
||||
SCALE_RAW_AUDIO,
|
||||
SPEC_DIVIDE_FACTOR,
|
||||
SPEC_HEIGHT,
|
||||
TARGET_SAMPLERATE_HZ,
|
||||
)
|
||||
|
||||
try:
|
||||
from typing import TypedDict
|
||||
@ -24,23 +38,17 @@ DEFAULT_MODEL_PATH = os.path.join(
|
||||
"model.pth",
|
||||
)
|
||||
|
||||
__all__ = ["load_model", "get_audio_files", "DEFAULT_MODEL_PATH"]
|
||||
|
||||
|
||||
def get_default_bd_args():
|
||||
args = {}
|
||||
args["detection_threshold"] = 0.001
|
||||
args["time_expansion_factor"] = 1
|
||||
args["audio_dir"] = ""
|
||||
args["ann_dir"] = ""
|
||||
args["spec_slices"] = False
|
||||
args["chunk_size"] = 3
|
||||
args["spec_features"] = False
|
||||
args["cnn_features"] = False
|
||||
args["quiet"] = True
|
||||
args["save_preds_if_empty"] = True
|
||||
args["ann_dir"] = os.path.join(args["ann_dir"], "")
|
||||
return args
|
||||
__all__ = [
|
||||
"load_model",
|
||||
"get_audio_files",
|
||||
"format_results",
|
||||
"save_results_to_file",
|
||||
"iterate_over_chunks",
|
||||
"process_spectrogram",
|
||||
"process_audio_array",
|
||||
"process_file",
|
||||
"DEFAULT_MODEL_PATH",
|
||||
]
|
||||
|
||||
|
||||
def get_audio_files(ip_dir: str) -> List[str]:
|
||||
@ -80,7 +88,7 @@ class ModelParameters(TypedDict):
|
||||
ip_height: int
|
||||
"""Input height in pixels."""
|
||||
|
||||
resize_factor: int
|
||||
resize_factor: float
|
||||
"""Resize factor."""
|
||||
|
||||
class_names: List[str]
|
||||
@ -118,6 +126,8 @@ def load_model(
|
||||
params = net_params["params"]
|
||||
params["device"] = device
|
||||
|
||||
model: torch.nn.Module
|
||||
|
||||
if params["model_name"] == "Net2DFast":
|
||||
model = models.Net2DFast(
|
||||
params["num_filters"],
|
||||
@ -159,9 +169,9 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
||||
num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
|
||||
|
||||
if num_preds > 0:
|
||||
for kk in predictions[0].keys():
|
||||
predictions_m[kk] = np.hstack(
|
||||
[pp[kk] for pp in predictions if pp["det_probs"].shape[0] > 0]
|
||||
for key in predictions[0].keys():
|
||||
predictions_m[key] = np.hstack(
|
||||
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
|
||||
)
|
||||
else:
|
||||
# hack in case where no detected calls as we need some of the key names in dict
|
||||
@ -176,7 +186,10 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
||||
return predictions_m, spec_feats, cnn_feats, spec_slices
|
||||
|
||||
|
||||
class Annotation(TypedDict("WithClass", {"class": str})):
|
||||
DictWithClass = TypedDict("DictWithClass", {"class": str})
|
||||
|
||||
|
||||
class Annotation(DictWithClass):
|
||||
"""Format of annotations.
|
||||
|
||||
This is the format of a single annotation as expected by the annotation
|
||||
@ -214,7 +227,7 @@ class FileAnnotations(TypedDict):
|
||||
This is the format of the results expected by the annotation tool.
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
id: str
|
||||
"""File ID."""
|
||||
|
||||
annotated: bool
|
||||
@ -232,26 +245,32 @@ class FileAnnotations(TypedDict):
|
||||
class_name: str
|
||||
"""Class predicted at file level"""
|
||||
|
||||
notes: str
|
||||
"""Notes of file."""
|
||||
|
||||
annotation: List[Annotation]
|
||||
"""List of annotations."""
|
||||
|
||||
|
||||
class Results(TypedDict):
|
||||
class RunResults(TypedDict):
|
||||
"""Run results."""
|
||||
|
||||
pred_dict: FileAnnotations
|
||||
"""Predictions in the format expected by the annotation tool."""
|
||||
|
||||
spec_feats: Optional[np.ndarray]
|
||||
spec_feats: Optional[List[np.ndarray]]
|
||||
"""Spectrogram features."""
|
||||
|
||||
spec_feat_names: Optional[List[str]]
|
||||
"""Spectrogram feature names."""
|
||||
|
||||
cnn_feats: Optional[np.ndarray]
|
||||
cnn_feats: Optional[List[np.ndarray]]
|
||||
"""CNN features."""
|
||||
|
||||
cnn_feat_names: Optional[List[str]]
|
||||
"""CNN feature names."""
|
||||
|
||||
spec_slices: Optional[np.ndarray]
|
||||
spec_slices: Optional[List[np.ndarray]]
|
||||
"""Spectrogram slices."""
|
||||
|
||||
|
||||
@ -343,7 +362,7 @@ def convert_results(
|
||||
spec_feats,
|
||||
cnn_feats,
|
||||
spec_slices,
|
||||
) -> Results:
|
||||
) -> RunResults:
|
||||
"""Convert results to dictionary as expected by the annotation tool.
|
||||
|
||||
Args:
|
||||
@ -369,8 +388,14 @@ def convert_results(
|
||||
)
|
||||
|
||||
# combine into final results dictionary
|
||||
results = {}
|
||||
results["pred_dict"] = pred_dict
|
||||
results: RunResults = {
|
||||
"pred_dict": pred_dict,
|
||||
"spec_feats": None,
|
||||
"spec_feat_names": None,
|
||||
"cnn_feats": None,
|
||||
"cnn_feat_names": None,
|
||||
"spec_slices": None,
|
||||
}
|
||||
|
||||
# add spectrogram features if they exist
|
||||
if len(spec_feats) > 0:
|
||||
@ -463,19 +488,16 @@ def save_results_to_file(results, op_path: str) -> None:
|
||||
class SpectrogramParameters(TypedDict):
|
||||
"""Parameters for generating spectrograms."""
|
||||
|
||||
fft_win_length: int
|
||||
"""Length of the FFT window in samples."""
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
|
||||
fft_overlap: int
|
||||
"""Number of samples to overlap between FFT windows."""
|
||||
fft_overlap: float
|
||||
"""Percentage of overlap between FFT windows."""
|
||||
|
||||
spec_height: int
|
||||
"""Height of the spectrogram in pixels."""
|
||||
|
||||
spec_width: int
|
||||
"""Width of the spectrogram in pixels."""
|
||||
|
||||
resize_factor: int
|
||||
resize_factor: float
|
||||
"""Factor to resize the spectrogram by."""
|
||||
|
||||
spec_divide_factor: int
|
||||
@ -605,13 +627,14 @@ class ProcessingConfiguration(TypedDict):
|
||||
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
|
||||
fft_overlap: float
|
||||
"""Length of the FFT window in samples."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor to resize the spectrogram by."""
|
||||
|
||||
spec_divide_factor: float
|
||||
spec_divide_factor: int
|
||||
"""Factor to divide the spectrogram by."""
|
||||
|
||||
spec_height: int
|
||||
@ -644,27 +667,36 @@ class ProcessingConfiguration(TypedDict):
|
||||
nms_kernel_size: int
|
||||
"""Size of the kernel for non-maximum suppression."""
|
||||
|
||||
max_freq: float
|
||||
max_freq: int
|
||||
"""Maximum frequency to consider in Hz."""
|
||||
|
||||
min_freq: float
|
||||
min_freq: int
|
||||
"""Minimum frequency to consider in Hz."""
|
||||
|
||||
nms_top_k_per_sec: float
|
||||
"""Number of top detections to keep per second."""
|
||||
|
||||
detection_threshold: float
|
||||
"""Threshold for detection probability."""
|
||||
|
||||
quiet: bool
|
||||
"""Whether to suppress output."""
|
||||
|
||||
chunk_size: float
|
||||
"""Size of chunks to process in seconds."""
|
||||
|
||||
cnn_features: bool
|
||||
"""Whether to return CNN features."""
|
||||
|
||||
spec_features: bool
|
||||
"""Whether to return spectrogram features."""
|
||||
|
||||
spec_slices: bool
|
||||
"""Whether to return spectrogram slices."""
|
||||
|
||||
|
||||
def process_spectrogram(
|
||||
spec: torch.Tensor,
|
||||
samplerate: int,
|
||||
model: torch.nn.Module,
|
||||
config: pp.NonMaximumSuppressionConfig,
|
||||
config: ProcessingConfiguration,
|
||||
):
|
||||
"""Process a spectrogram with detection model.
|
||||
|
||||
@ -692,17 +724,29 @@ def process_spectrogram(
|
||||
outputs = model(spec, return_feats=config["cnn_features"])
|
||||
|
||||
# run non-max suppression
|
||||
pred_nms, features = pp.run_nms(
|
||||
pred_nms_list, features = pp.run_nms(
|
||||
outputs,
|
||||
config,
|
||||
{
|
||||
"nms_kernel_size": config["nms_kernel_size"],
|
||||
"max_freq": config["max_freq"],
|
||||
"min_freq": config["min_freq"],
|
||||
"fft_win_length": config["fft_win_length"],
|
||||
"fft_overlap": config["fft_overlap"],
|
||||
"resize_factor": config["resize_factor"],
|
||||
"nms_top_k_per_sec": config["nms_top_k_per_sec"],
|
||||
"detection_threshold": config["detection_threshold"],
|
||||
},
|
||||
np.array([float(samplerate)]),
|
||||
)
|
||||
|
||||
pred_nms = pred_nms[0]
|
||||
pred_nms = pred_nms_list[0]
|
||||
|
||||
# if we have a background class
|
||||
if pred_nms["class_probs"].shape[0] > len(config["class_names"]):
|
||||
pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]
|
||||
class_probs = pred_nms.get("class_probs")
|
||||
if (class_probs is not None) and (
|
||||
class_probs.shape[0] > len(config["class_names"])
|
||||
):
|
||||
pred_nms["class_probs"] = class_probs[:-1, :]
|
||||
|
||||
return pred_nms, features
|
||||
|
||||
@ -737,7 +781,14 @@ def process_audio_array(
|
||||
_, spec, spec_np = compute_spectrogram(
|
||||
audio,
|
||||
sampling_rate,
|
||||
config,
|
||||
{
|
||||
"fft_win_length": config["fft_win_length"],
|
||||
"fft_overlap": config["fft_overlap"],
|
||||
"spec_height": config["spec_height"],
|
||||
"resize_factor": config["resize_factor"],
|
||||
"spec_divide_factor": config["spec_divide_factor"],
|
||||
"device": config["device"],
|
||||
},
|
||||
return_np=config["spec_features"] or config["spec_slices"],
|
||||
)
|
||||
|
||||
@ -756,7 +807,7 @@ def process_file(
|
||||
audio_file: str,
|
||||
model: torch.nn.Module,
|
||||
config: ProcessingConfiguration,
|
||||
) -> Union[Results, Any]:
|
||||
) -> Union[RunResults, Any]:
|
||||
"""Process a single audio file with detection model.
|
||||
|
||||
Will split the audio file into chunks if it is too long and
|
||||
@ -788,7 +839,7 @@ def process_file(
|
||||
# load audio file
|
||||
sampling_rate, audio_full = au.load_audio_file(
|
||||
audio_file,
|
||||
time_exp_fact=config["time_expansion"],
|
||||
time_exp_fact=config.get("time_expansion", 1) or 1,
|
||||
target_samp_rate=config["target_samp_rate"],
|
||||
scale=config["scale_raw_audio"],
|
||||
max_duration=config["max_duration"],
|
||||
@ -840,7 +891,7 @@ def process_file(
|
||||
# convert results to a dictionary in the right format
|
||||
results = convert_results(
|
||||
file_id=os.path.basename(audio_file),
|
||||
time_exp=config["time_expansion"],
|
||||
time_exp=config.get("time_expansion", 1) or 1,
|
||||
duration=audio_full.shape[0] / float(sampling_rate),
|
||||
params=config,
|
||||
predictions=predictions,
|
||||
@ -877,3 +928,38 @@ def summarize_results(results, predictions, config):
|
||||
config["class_names"][class_index].ljust(30)
|
||||
+ str(round(class_overall[class_index], 3))
|
||||
)
|
||||
|
||||
|
||||
def get_default_run_config(**kwargs) -> ProcessingConfiguration:
|
||||
"""Get default configuration for running detection model."""
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
args: ProcessingConfiguration = {
|
||||
"detection_threshold": DETECTION_THRESHOLD,
|
||||
"spec_slices": False,
|
||||
"chunk_size": 3,
|
||||
"spec_features": False,
|
||||
"cnn_features": False,
|
||||
"quiet": True,
|
||||
"target_samp_rate": TARGET_SAMPLERATE_HZ,
|
||||
"fft_win_length": FFT_WIN_LENGTH_S,
|
||||
"fft_overlap": FFT_OVERLAP,
|
||||
"resize_factor": RESIZE_FACTOR,
|
||||
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
|
||||
"spec_height": SPEC_HEIGHT,
|
||||
"scale_raw_audio": SCALE_RAW_AUDIO,
|
||||
"device": device,
|
||||
"class_names": [],
|
||||
"time_expansion": 1,
|
||||
"top_n": 3,
|
||||
"return_raw_preds": False,
|
||||
"max_duration": None,
|
||||
"nms_kernel_size": NMS_KERNEL_SIZE,
|
||||
"max_freq": MAX_FREQ_HZ,
|
||||
"min_freq": MIN_FREQ_HZ,
|
||||
"nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
|
||||
}
|
||||
return {
|
||||
**args,
|
||||
**kwargs,
|
||||
}
|
||||
|
@ -523,7 +523,7 @@ class LossPlotter(object):
|
||||
def save_confusion_matrix(self, gt, pred):
|
||||
plt.figure(0)
|
||||
cm = confusion_matrix(
|
||||
gt, pred, np.arange(len(self.class_names))
|
||||
gt, pred, labels=np.arange(len(self.class_names))
|
||||
).astype(np.float32)
|
||||
cm_norm = cm.sum(1)
|
||||
valid_inds = np.where(cm_norm > 0)[0]
|
||||
|
@ -49,3 +49,10 @@ batdetect2 = "bat_detect.command:main"
|
||||
|
||||
[tool.black]
|
||||
line-length = 80
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"librosa",
|
||||
"pandas",
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
|
@ -86,7 +86,7 @@ if __name__ == "__main__":
|
||||
args_cmd = vars(parser.parse_args())
|
||||
|
||||
# load the model
|
||||
bd_args = du.get_default_bd_args()
|
||||
bd_args = du.get_default_run_config()
|
||||
model, params_bd = du.load_model(args_cmd["model_path"])
|
||||
bd_args["detection_threshold"] = args_cmd["detection_threshold"]
|
||||
bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
|
||||
|
@ -89,7 +89,7 @@ if __name__ == "__main__":
|
||||
os.makedirs(op_dir)
|
||||
|
||||
params = parameters.get_params(False)
|
||||
args = du.get_default_bd_args()
|
||||
args = du.get_default_run_config()
|
||||
args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
|
||||
args["detection_threshold"] = args_cmd["detection_threshold"]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user