Finished refactoring detector_utils

Santiago Martinez 2023-02-22 22:45:26 +00:00
parent 8da98b5258
commit e6a6ad4696
13 changed files with 334 additions and 115 deletions

.pylintrc (new file)

@@ -0,0 +1,5 @@
+[TYPECHECK]
+# List of members which are set dynamically and missed by Pylint inference
+# system, and so shouldn't trigger E1101 when accessed.
+generated-members=torch.*
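Context for the new .pylintrc: Pylint's static inference misses members that torch defines in C extensions, producing spurious E1101 errors. A snippet of the kind that used to be flagged (illustrative, not from this repository):

    import numpy as np
    import torch

    # Without generated-members=torch.*, Pylint may report
    # "E1101: Module 'torch' has no 'from_numpy' member" here.
    spec = torch.from_numpy(np.zeros((256, 512), dtype=np.float32))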

app.py

@@ -9,7 +9,7 @@ import bat_detect.utils.plot_utils as viz

 # setup the arguments
 args = {}
-args = du.get_default_bd_args()
+args = du.get_default_run_config()
 args["detection_threshold"] = 0.3
 args["time_expansion_factor"] = 1
 args["model_path"] = "models/Net2DFast_UK_same.pth.tar"

@@ -1,7 +1,18 @@
 import datetime
 import os

-import numpy as np
+TARGET_SAMPLERATE_HZ = 256000
+FFT_WIN_LENGTH_S = 512 / 256000.0
+FFT_OVERLAP = 0.75
+MAX_FREQ_HZ = 120000
+MIN_FREQ_HZ = 10000
+RESIZE_FACTOR = 0.5
+SPEC_DIVIDE_FACTOR = 32
+SPEC_HEIGHT = 256
+SCALE_RAW_AUDIO = False
+DETECTION_THRESHOLD = 0.01
+NMS_KERNEL_SIZE = 9
+NMS_TOP_K_PER_SEC = 200


 def mk_dir(path):
@@ -30,35 +41,39 @@ def get_params(make_dirs=False, exps_dir="../../experiments/"):

     # spec parameters
     params[
         "target_samp_rate"
-    ] = 256000  # resamples all audio so that it is at this rate
+    ] = TARGET_SAMPLERATE_HZ  # resamples all audio so that it is at this rate
-    params["fft_win_length"] = (
-        512 / 256000.0
-    )  # in milliseconds, amount of time per stft time step
+    params[
+        "fft_win_length"
+    ] = FFT_WIN_LENGTH_S  # in seconds, amount of time per stft time step
-    params["fft_overlap"] = 0.75  # stft window overlap
+    params["fft_overlap"] = FFT_OVERLAP  # stft window overlap
     params[
         "max_freq"
-    ] = 120000  # in Hz, everything above this will be discarded
+    ] = MAX_FREQ_HZ  # in Hz, everything above this will be discarded
-    params["min_freq"] = 10000  # in Hz, everything below this will be discarded
+    params[
+        "min_freq"
+    ] = MIN_FREQ_HZ  # in Hz, everything below this will be discarded
     params[
         "resize_factor"
-    ] = 0.5  # resize so the spectrogram at the input of the network
+    ] = RESIZE_FACTOR  # resize so the spectrogram at the input of the network
     params[
         "spec_height"
-    ] = 256  # units are number of frequency bins (before resizing is performed)
+    ] = SPEC_HEIGHT  # units are number of frequency bins (before resizing is performed)
     params[
         "spec_train_width"
     ] = 512  # units are number of time steps (before resizing is performed)
     params[
         "spec_divide_factor"
-    ] = 32  # spectrogram should be divisible by this amount in width and height
+    ] = SPEC_DIVIDE_FACTOR  # spectrogram should be divisible by this amount in width and height

     # spec processing params
     params[
         "denoise_spec_avg"
     ] = True  # removes the mean for each frequency band
-    params["scale_raw_audio"] = False  # scales the raw audio to [-1, 1]
+    params[
+        "scale_raw_audio"
+    ] = SCALE_RAW_AUDIO  # scales the raw audio to [-1, 1]
     params[
         "max_scale_spec"
     ] = False  # scales the spectrogram so that it is max 1
@@ -73,11 +88,13 @@ def get_params(make_dirs=False, exps_dir="../../experiments/"):
     ] = 0.01  # if start of GT calls are within this time from the start/end of file ignore
     params[
         "detection_threshold"
-    ] = 0.01  # the smaller this is the better the recall will be
+    ] = DETECTION_THRESHOLD  # the smaller this is the better the recall will be
-    params["nms_kernel_size"] = 9
+    params[
+        "nms_kernel_size"
+    ] = NMS_KERNEL_SIZE  # size of the kernel for non-max suppression
     params[
         "nms_top_k_per_sec"
-    ] = 200  # keep top K highest predictions per second of audio
+    ] = NMS_TOP_K_PER_SEC  # keep top K highest predictions per second of audio
     params["target_sigma"] = 2.0

     # augmentation params

@@ -1,3 +1,6 @@
+"""Post-processing of the output of the model."""
+from typing import List, Optional, Tuple
+
 import numpy as np
 import torch
 from torch import nn
@@ -10,11 +13,26 @@ except ImportError:
 np.seterr(divide="ignore", invalid="ignore")


-def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
+def x_coords_to_time(
+    x_pos: float,
+    sampling_rate: int,
+    fft_win_length: float,
+    fft_overlap: float,
+) -> float:
+    """Convert x coordinates of spectrogram to time in seconds.
+
+    Args:
+        x_pos: X position of the detection in pixels.
+        sampling_rate: Sampling rate of the audio in Hz.
+        fft_win_length: Length of the FFT window in seconds.
+        fft_overlap: Overlap of the FFT windows as a fraction of the window.
+
+    Returns:
+        Time in seconds.
+    """
     nfft = int(fft_win_length * sampling_rate)
     noverlap = int(fft_overlap * nfft)
     return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
-    # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5)  # 0.5 is for center of temporal window


 def overall_class_pred(det_prob, class_prob):
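A quick sanity check of the conversion above, using the default spectrogram settings from the parameters module (256 kHz sampling rate, 512-sample window, 0.75 overlap):

    # nfft = int((512 / 256000.0) * 256000) = 512 samples
    nfft = 512
    noverlap = int(0.75 * nfft)  # = 384, so the hop is 128 samples
    x_pos = 100
    t = ((x_pos * (nfft - noverlap)) + noverlap) / 256000  # = 0.0515 s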
@@ -28,10 +46,10 @@ class NonMaximumSuppressionConfig(TypedDict):
     nms_kernel_size: int
     """Size of the kernel for non-maximum suppression."""

-    max_freq: float
+    max_freq: int
     """Maximum frequency to consider in Hz."""

-    min_freq: float
+    min_freq: int
     """Minimum frequency to consider in Hz."""

     fft_win_length: float
@@ -40,6 +58,9 @@ class NonMaximumSuppressionConfig(TypedDict):
     fft_overlap: float
     """Overlap of the FFT windows in seconds."""

+    resize_factor: float
+    """Factor by which the input was resized."""
+
     nms_top_k_per_sec: float
     """Number of top detections to keep per second."""
     """Threshold for detection probability."""


-def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
-    """Run non-maximum suppression on the output of the model."""
+class PredictionResults(TypedDict):
+    """Results of the prediction.
+
+    Each key is a list of length `num_detections` containing the
+    corresponding values for each detection.
+    """
+
+    det_probs: np.ndarray
+    """Detection probabilities."""
+
+    x_pos: np.ndarray
+    """X position of the detection in pixels."""
+
+    y_pos: np.ndarray
+    """Y position of the detection in pixels."""
+
+    bb_width: np.ndarray
+    """Width of the detection in pixels."""
+
+    bb_height: np.ndarray
+    """Height of the detection in pixels."""
+
+    start_times: np.ndarray
+    """Start times of the detections in seconds."""
+
+    end_times: np.ndarray
+    """End times of the detections in seconds."""
+
+    low_freqs: np.ndarray
+    """Low frequencies of the detections in Hz."""
+
+    high_freqs: np.ndarray
+    """High frequencies of the detections in Hz."""
+
+    class_probs: Optional[np.ndarray]
+    """Class probabilities."""
+
+
+class ModelOutputs(TypedDict):
+    """Outputs of the model."""
+
+    pred_det: torch.Tensor
+    """Detection probabilities."""
+
+    pred_size: torch.Tensor
+    """Box sizes."""
+
+    pred_class: Optional[torch.Tensor]
+    """Class probabilities."""
+
+    features: Optional[torch.Tensor]
+    """Features extracted by the model."""
+
+
+def run_nms(
+    outputs: ModelOutputs,
+    params: NonMaximumSuppressionConfig,
+    sampling_rate: np.ndarray,
+) -> Tuple[List[PredictionResults], List[np.ndarray]]:
+    """Run non-maximum suppression on the output of the model.
+
+    Model outputs are expected to have a batch dimension. Each element
+    of the batch is processed independently. The result is a pair of
+    lists, one for the predictions and one for the features; each
+    element of the lists corresponds to one element of the batch.
+    """
     pred_det = outputs["pred_det"]  # probability of box
     pred_size = outputs["pred_size"]  # box size
@@ -62,7 +148,7 @@ def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):

     # as we are choosing the same sampling rate for the entire batch
     duration = x_coords_to_time(
         pred_det.shape[-1],
-        sampling_rate[0].item(),
+        int(sampling_rate[0].item()),
         params["fft_win_length"],
         params["fft_overlap"],
     )
@@ -70,58 +156,72 @@ def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
     scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)

     # loop over batch to save outputs
-    preds = []
-    feats = []
+    preds: List[PredictionResults] = []
+    feats: List[np.ndarray] = []
-    for ii in range(pred_det_nms.shape[0]):
+    for num_detection in range(pred_det_nms.shape[0]):
         # get valid indices
-        inds_ord = torch.argsort(x_pos[ii, :])
+        inds_ord = torch.argsort(x_pos[num_detection, :])
-        valid_inds = scores[ii, inds_ord] > params["detection_threshold"]
+        valid_inds = (
+            scores[num_detection, inds_ord] > params["detection_threshold"]
+        )
         valid_inds = inds_ord[valid_inds]

         # create result dictionary
         pred = {}
-        pred["det_probs"] = scores[ii, valid_inds]
-        pred["x_pos"] = x_pos[ii, valid_inds]
-        pred["y_pos"] = y_pos[ii, valid_inds]
-        pred["bb_width"] = pred_size[ii, 0, pred["y_pos"], pred["x_pos"]]
-        pred["bb_height"] = pred_size[ii, 1, pred["y_pos"], pred["x_pos"]]
+        pred["det_probs"] = scores[num_detection, valid_inds]
+        pred["x_pos"] = x_pos[num_detection, valid_inds]
+        pred["y_pos"] = y_pos[num_detection, valid_inds]
+        pred["bb_width"] = pred_size[
+            num_detection, 0, pred["y_pos"], pred["x_pos"]
+        ]
+        pred["bb_height"] = pred_size[
+            num_detection, 1, pred["y_pos"], pred["x_pos"]
+        ]
         pred["start_times"] = x_coords_to_time(
             pred["x_pos"].float() / params["resize_factor"],
-            sampling_rate[ii].item(),
+            int(sampling_rate[num_detection].item()),
             params["fft_win_length"],
             params["fft_overlap"],
         )
         pred["end_times"] = x_coords_to_time(
             (pred["x_pos"].float() + pred["bb_width"])
             / params["resize_factor"],
-            sampling_rate[ii].item(),
+            int(sampling_rate[num_detection].item()),
             params["fft_win_length"],
             params["fft_overlap"],
         )
         pred["low_freqs"] = (
-            pred_size[ii].shape[1] - pred["y_pos"].float()
+            pred_size[num_detection].shape[1] - pred["y_pos"].float()
         ) * freq_rescale + params["min_freq"]
         pred["high_freqs"] = (
             pred["low_freqs"] + pred["bb_height"] * freq_rescale
         )

         # extract the per class votes
-        if "pred_class" in outputs:
-            pred["class_probs"] = outputs["pred_class"][
-                ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]
+        pred_class = outputs.get("pred_class")
+        if pred_class is not None:
+            pred["class_probs"] = pred_class[
+                num_detection,
+                :,
+                y_pos[num_detection, valid_inds],
+                x_pos[num_detection, valid_inds],
             ]

         # extract the model features
-        if "features" in outputs:
-            feat = outputs["features"][
-                ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]
+        features = outputs.get("features")
+        if features is not None:
+            feat = features[
+                num_detection,
+                :,
+                y_pos[num_detection, valid_inds],
+                x_pos[num_detection, valid_inds],
             ].transpose(0, 1)
             feat = feat.cpu().numpy().astype(np.float32)
             feats.append(feat)

         # convert to numpy
-        for kk in pred.keys():
-            pred[kk] = pred[kk].cpu().numpy().astype(np.float32)
+        for key, value in pred.items():
+            pred[key] = value.cpu().numpy().astype(np.float32)
         preds.append(pred)
@@ -130,7 +230,7 @@ def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):

 def non_max_suppression(heat, kernel_size):
     # kernel can be an int or list/tuple
-    if type(kernel_size) is int:
+    if isinstance(kernel_size, int):
         kernel_size_h = kernel_size
         kernel_size_w = kernel_size
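For orientation, the call pattern of the refactored run_nms, using the same parameter keys that process_spectrogram builds later in this commit; model, spec, and the resulting model_outputs are assumed to be in scope:

    import numpy as np

    model_outputs = model(spec, return_feats=True)  # forward pass (assumed in scope)
    nms_params = {
        "nms_kernel_size": 9,
        "max_freq": 120000,
        "min_freq": 10000,
        "fft_win_length": 512 / 256000.0,
        "fft_overlap": 0.75,
        "resize_factor": 0.5,
        "nms_top_k_per_sec": 200,
        "detection_threshold": 0.3,
    }
    preds, feats = run_nms(model_outputs, nms_params, np.array([256000.0]))
    # preds[0]["start_times"], preds[0]["low_freqs"], ... are per-detection arrays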

@@ -739,7 +739,7 @@ if __name__ == "__main__":

     #
     if args["bd_model_path"] != "":
         # load model
-        bd_args = du.get_default_bd_args()
+        bd_args = du.get_default_run_config()
         model, params_bd = du.load_model(args["bd_model_path"])

         # check if the class names are the same

@@ -1,7 +1,4 @@
 import copy
-import os
-import random
-import sys

 import librosa
 import numpy as np
@@ -9,7 +6,6 @@ import torch
 import torch.nn.functional as F
 import torchaudio

-sys.path.append(os.path.join("..", ".."))
 import bat_detect.utils.audio_utils as au
@@ -218,7 +214,10 @@ def resample_aug(audio, sampling_rate, params):
     sampling_rate_old = sampling_rate
     sampling_rate = np.random.choice(params["aug_sampling_rates"])
     audio = librosa.resample(
-        audio, sampling_rate_old, sampling_rate, res_type="polyphase"
+        audio,
+        orig_sr=sampling_rate_old,
+        target_sr=sampling_rate,
+        res_type="polyphase",
     )

     audio = au.pad_audio(
@@ -237,7 +236,10 @@ def resample_aug(audio, sampling_rate, params):
 def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
     if sampling_rate != sampling_rate2:
         audio2 = librosa.resample(
-            audio2, sampling_rate2, sampling_rate, res_type="polyphase"
+            audio2,
+            orig_sr=sampling_rate2,
+            target_sr=sampling_rate,
+            res_type="polyphase",
         )
         sampling_rate2 = sampling_rate
     if audio2.shape[0] < num_samples:
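The resample changes track librosa's API: positional orig_sr/target_sr were deprecated in librosa 0.9 and became keyword-only in 0.10, so the old positional calls fail on current releases. A standalone sketch:

    import librosa
    import numpy as np

    audio = np.zeros(256000, dtype=np.float32)
    # Keyword arguments are required on librosa >= 0.10; "polyphase" also
    # needs both sample rates to be integers.
    audio_lo = librosa.resample(
        audio, orig_sr=256000, target_sr=128000, res_type="polyphase"
    )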

@@ -553,5 +553,6 @@ if __name__ == "__main__":
         torch.save(op_state, params["model_file_name"])

     # save an image with associated prediction for each batch in the test set
-    if not args["do_not_save_images"]:
-        save_images_batch(model, test_loader, params)
+    # TODO: args variable does not exist
+    # if not args["do_not_save_images"]:
+    #     save_images_batch(model, test_loader, params)

@@ -7,7 +7,6 @@ import torch

 from . import wavfile

 __all__ = [
     "load_audio_file",
 ]
@@ -163,9 +162,11 @@ def load_audio_file(
     # clipping maximum duration
     if max_duration is not None:
-        max_duration = np.minimum(
-            int(sampling_rate * max_duration),
-            audio_raw.shape[0],
+        max_duration = int(
+            np.minimum(
+                int(sampling_rate * max_duration),
+                audio_raw.shape[0],
+            )
         )

         audio_raw = audio_raw[:max_duration]
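The extra int() here looks like a typing fix rather than a behavior change: np.minimum returns a NumPy scalar, which works as a slice bound at runtime but is not a plain int, so reusing max_duration (previously a duration in seconds) as a sample count trips strict type checking. A tiny illustration, not from the repository:

    import numpy as np

    limit = np.minimum(int(256000 * 3.0), 500_000)  # NumPy scalar (np.int64)
    assert isinstance(limit, np.integer)            # not a plain int
    limit = int(limit)                              # plain int for slicing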

@@ -11,6 +11,20 @@ import bat_detect.detector.compute_features as feats
 import bat_detect.detector.post_process as pp
 import bat_detect.utils.audio_utils as au
 from bat_detect.detector import models
+from bat_detect.detector.parameters import (
+    DETECTION_THRESHOLD,
+    FFT_OVERLAP,
+    FFT_WIN_LENGTH_S,
+    MAX_FREQ_HZ,
+    MIN_FREQ_HZ,
+    NMS_KERNEL_SIZE,
+    NMS_TOP_K_PER_SEC,
+    RESIZE_FACTOR,
+    SCALE_RAW_AUDIO,
+    SPEC_DIVIDE_FACTOR,
+    SPEC_HEIGHT,
+    TARGET_SAMPLERATE_HZ,
+)

 try:
     from typing import TypedDict
@@ -24,23 +38,17 @@ DEFAULT_MODEL_PATH = os.path.join(
     "model.pth",
 )

-__all__ = ["load_model", "get_audio_files", "DEFAULT_MODEL_PATH"]
-
-
-def get_default_bd_args():
-    args = {}
-    args["detection_threshold"] = 0.001
-    args["time_expansion_factor"] = 1
-    args["audio_dir"] = ""
-    args["ann_dir"] = ""
-    args["spec_slices"] = False
-    args["chunk_size"] = 3
-    args["spec_features"] = False
-    args["cnn_features"] = False
-    args["quiet"] = True
-    args["save_preds_if_empty"] = True
-    args["ann_dir"] = os.path.join(args["ann_dir"], "")
-    return args
+__all__ = [
+    "load_model",
+    "get_audio_files",
+    "format_results",
+    "save_results_to_file",
+    "iterate_over_chunks",
+    "process_spectrogram",
+    "process_audio_array",
+    "process_file",
+    "DEFAULT_MODEL_PATH",
+]


 def get_audio_files(ip_dir: str) -> List[str]:
@@ -80,7 +88,7 @@ class ModelParameters(TypedDict):
     ip_height: int
     """Input height in pixels."""

-    resize_factor: int
+    resize_factor: float
     """Resize factor."""

     class_names: List[str]
@@ -118,6 +126,8 @@ def load_model(
     params = net_params["params"]
     params["device"] = device

+    model: torch.nn.Module
+
     if params["model_name"] == "Net2DFast":
         model = models.Net2DFast(
             params["num_filters"],
@@ -159,9 +169,9 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
     num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])

     if num_preds > 0:
-        for kk in predictions[0].keys():
-            predictions_m[kk] = np.hstack(
-                [pp[kk] for pp in predictions if pp["det_probs"].shape[0] > 0]
+        for key in predictions[0].keys():
+            predictions_m[key] = np.hstack(
+                [pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
             )
     else:
         # hack in case where no detected calls as we need some of the key names in dict
@@ -176,7 +186,10 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
     return predictions_m, spec_feats, cnn_feats, spec_slices


-class Annotation(TypedDict("WithClass", {"class": str})):
+DictWithClass = TypedDict("DictWithClass", {"class": str})
+
+
+class Annotation(DictWithClass):
     """Format of annotations.

     This is the format of a single annotation as expected by the annotation
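Background on DictWithClass: `class` is a Python keyword, so it cannot be declared as a key in the class-based TypedDict syntax, and type checkers reject inheriting from an inline TypedDict(...) call as a dynamic base class. Binding the functional form to a name first satisfies both; a minimal sketch (the extra keys are illustrative, not the real Annotation fields):

    from typing import TypedDict

    # The functional form is the only way to declare a reserved word as a key.
    DictWithClass = TypedDict("DictWithClass", {"class": str})


    class Annotation(DictWithClass):
        start_time: float  # illustrative keys only
        end_time: float


    ann: Annotation = {"class": "Myotis", "start_time": 0.1, "end_time": 0.15}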
@@ -214,7 +227,7 @@ class FileAnnotations(TypedDict):

     This is the format of the results expected by the annotation tool.
     """

-    file_id: str
+    id: str
     """File ID."""

     annotated: bool
@@ -232,26 +245,32 @@ class FileAnnotations(TypedDict):
     class_name: str
     """Class predicted at file level"""

+    notes: str
+    """Notes of file."""
+
     annotation: List[Annotation]
+    """List of annotations."""


-class Results(TypedDict):
+class RunResults(TypedDict):
+    """Run results."""

     pred_dict: FileAnnotations
     """Predictions in the format expected by the annotation tool."""

-    spec_feats: Optional[np.ndarray]
+    spec_feats: Optional[List[np.ndarray]]
     """Spectrogram features."""

     spec_feat_names: Optional[List[str]]
     """Spectrogram feature names."""

-    cnn_feats: Optional[np.ndarray]
+    cnn_feats: Optional[List[np.ndarray]]
     """CNN features."""

     cnn_feat_names: Optional[List[str]]
     """CNN feature names."""

-    spec_slices: Optional[np.ndarray]
+    spec_slices: Optional[List[np.ndarray]]
     """Spectrogram slices."""
@@ -343,7 +362,7 @@ def convert_results(
     spec_feats,
     cnn_feats,
     spec_slices,
-) -> Results:
+) -> RunResults:
     """Convert results to dictionary as expected by the annotation tool.

     Args:
@@ -369,8 +388,14 @@ def convert_results(
     )

     # combine into final results dictionary
-    results = {}
-    results["pred_dict"] = pred_dict
+    results: RunResults = {
+        "pred_dict": pred_dict,
+        "spec_feats": None,
+        "spec_feat_names": None,
+        "cnn_feats": None,
+        "cnn_feat_names": None,
+        "spec_slices": None,
+    }

     # add spectrogram features if they exist
     if len(spec_feats) > 0:
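Initializing every key up front matters because a TypedDict is total by default: annotating results as RunResults and then building it key by key from an empty dict would fail type checking. A trimmed sketch of the rule:

    from typing import Optional, TypedDict


    class Totals(TypedDict):  # stand-in for RunResults
        pred_dict: dict
        spec_feats: Optional[list]


    ok: Totals = {"pred_dict": {}, "spec_feats": None}  # all keys present
    # bad: Totals = {"pred_dict": {}}  # error: missing key "spec_feats"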
@@ -463,19 +488,16 @@ def save_results_to_file(results, op_path: str) -> None:

 class SpectrogramParameters(TypedDict):
     """Parameters for generating spectrograms."""

-    fft_win_length: int
-    """Length of the FFT window in samples."""
+    fft_win_length: float
+    """Length of the FFT window in seconds."""

-    fft_overlap: int
-    """Number of samples to overlap between FFT windows."""
+    fft_overlap: float
+    """Percentage of overlap between FFT windows."""

     spec_height: int
     """Height of the spectrogram in pixels."""

-    spec_width: int
-    """Width of the spectrogram in pixels."""
-
-    resize_factor: int
+    resize_factor: float
     """Factor to resize the spectrogram by."""

     spec_divide_factor: int
@@ -605,13 +627,14 @@ class ProcessingConfiguration(TypedDict):
     fft_win_length: float
     """Length of the FFT window in seconds."""

     fft_overlap: float
     """Length of the FFT window in samples."""

     resize_factor: float
     """Factor to resize the spectrogram by."""

-    spec_divide_factor: float
+    spec_divide_factor: int
     """Factor to divide the spectrogram by."""

     spec_height: int
@@ -644,27 +667,36 @@ class ProcessingConfiguration(TypedDict):
     nms_kernel_size: int
     """Size of the kernel for non-maximum suppression."""

-    max_freq: float
+    max_freq: int
     """Maximum frequency to consider in Hz."""

-    min_freq: float
+    min_freq: int
     """Minimum frequency to consider in Hz."""

     nms_top_k_per_sec: float
     """Number of top detections to keep per second."""

+    detection_threshold: float
+    """Threshold for detection probability."""
+
     quiet: bool
     """Whether to suppress output."""

+    chunk_size: float
+    """Size of chunks to process in seconds."""
+
+    cnn_features: bool
+    """Whether to return CNN features."""
+
+    spec_features: bool
+    """Whether to return spectrogram features."""
+
+    spec_slices: bool
+    """Whether to return spectrogram slices."""


 def process_spectrogram(
     spec: torch.Tensor,
     samplerate: int,
     model: torch.nn.Module,
-    config: pp.NonMaximumSuppressionConfig,
+    config: ProcessingConfiguration,
 ):
     """Process a spectrogram with detection model.
@@ -692,17 +724,29 @@ def process_spectrogram(
     outputs = model(spec, return_feats=config["cnn_features"])

     # run non-max suppression
-    pred_nms, features = pp.run_nms(
+    pred_nms_list, features = pp.run_nms(
         outputs,
-        config,
+        {
+            "nms_kernel_size": config["nms_kernel_size"],
+            "max_freq": config["max_freq"],
+            "min_freq": config["min_freq"],
+            "fft_win_length": config["fft_win_length"],
+            "fft_overlap": config["fft_overlap"],
+            "resize_factor": config["resize_factor"],
+            "nms_top_k_per_sec": config["nms_top_k_per_sec"],
+            "detection_threshold": config["detection_threshold"],
+        },
         np.array([float(samplerate)]),
     )

-    pred_nms = pred_nms[0]
+    pred_nms = pred_nms_list[0]

     # if we have a background class
-    if pred_nms["class_probs"].shape[0] > len(config["class_names"]):
-        pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]
+    class_probs = pred_nms.get("class_probs")
+    if (class_probs is not None) and (
+        class_probs.shape[0] > len(config["class_names"])
+    ):
+        pred_nms["class_probs"] = class_probs[:-1, :]

     return pred_nms, features
@@ -737,7 +781,14 @@ def process_audio_array(
     _, spec, spec_np = compute_spectrogram(
         audio,
         sampling_rate,
-        config,
+        {
+            "fft_win_length": config["fft_win_length"],
+            "fft_overlap": config["fft_overlap"],
+            "spec_height": config["spec_height"],
+            "resize_factor": config["resize_factor"],
+            "spec_divide_factor": config["spec_divide_factor"],
+            "device": config["device"],
+        },
         return_np=config["spec_features"] or config["spec_slices"],
     )
@@ -756,7 +807,7 @@ def process_file(
     audio_file: str,
     model: torch.nn.Module,
     config: ProcessingConfiguration,
-) -> Union[Results, Any]:
+) -> Union[RunResults, Any]:
     """Process a single audio file with detection model.

     Will split the audio file into chunks if it is too long and
@@ -788,7 +839,7 @@ def process_file(
     # load audio file
     sampling_rate, audio_full = au.load_audio_file(
         audio_file,
-        time_exp_fact=config["time_expansion"],
+        time_exp_fact=config.get("time_expansion", 1) or 1,
         target_samp_rate=config["target_samp_rate"],
         scale=config["scale_raw_audio"],
         max_duration=config["max_duration"],
@@ -840,7 +891,7 @@ def process_file(
     # convert results to a dictionary in the right format
     results = convert_results(
         file_id=os.path.basename(audio_file),
-        time_exp=config["time_expansion"],
+        time_exp=config.get("time_expansion", 1) or 1,
         duration=audio_full.shape[0] / float(sampling_rate),
         params=config,
         predictions=predictions,
@@ -877,3 +928,38 @@ def summarize_results(results, predictions, config):
                 config["class_names"][class_index].ljust(30)
                 + str(round(class_overall[class_index], 3))
             )
+
+
+def get_default_run_config(**kwargs) -> ProcessingConfiguration:
+    """Get default configuration for running detection model."""
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    args: ProcessingConfiguration = {
+        "detection_threshold": DETECTION_THRESHOLD,
+        "spec_slices": False,
+        "chunk_size": 3,
+        "spec_features": False,
+        "cnn_features": False,
+        "quiet": True,
+        "target_samp_rate": TARGET_SAMPLERATE_HZ,
+        "fft_win_length": FFT_WIN_LENGTH_S,
+        "fft_overlap": FFT_OVERLAP,
+        "resize_factor": RESIZE_FACTOR,
+        "spec_divide_factor": SPEC_DIVIDE_FACTOR,
+        "spec_height": SPEC_HEIGHT,
+        "scale_raw_audio": SCALE_RAW_AUDIO,
+        "device": device,
+        "class_names": [],
+        "time_expansion": 1,
+        "top_n": 3,
+        "return_raw_preds": False,
+        "max_duration": None,
+        "nms_kernel_size": NMS_KERNEL_SIZE,
+        "max_freq": MAX_FREQ_HZ,
+        "min_freq": MIN_FREQ_HZ,
+        "nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
+    }
+    return {
+        **args,
+        **kwargs,
+    }
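The **kwargs merge lets callers override any default at construction time instead of mutating the returned dict afterwards, e.g. (values here are illustrative):

    config = du.get_default_run_config(
        detection_threshold=0.3,  # overrides the 0.01 default
        quiet=False,
    )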

@@ -523,7 +523,7 @@ class LossPlotter(object):
     def save_confusion_matrix(self, gt, pred):
         plt.figure(0)
         cm = confusion_matrix(
-            gt, pred, np.arange(len(self.class_names))
+            gt, pred, labels=np.arange(len(self.class_names))
         ).astype(np.float32)
         cm_norm = cm.sum(1)
         valid_inds = np.where(cm_norm > 0)[0]
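The labels= keyword matters on recent scikit-learn releases, where confusion_matrix's optional arguments are keyword-only and a positional third argument raises a TypeError; passing the full class range also keeps rows for classes absent from a batch. A standalone sketch:

    import numpy as np
    from sklearn.metrics import confusion_matrix

    gt = np.array([0, 1, 1, 2])
    pred = np.array([0, 1, 2, 2])
    cm = confusion_matrix(gt, pred, labels=np.arange(4))
    print(cm.shape)  # (4, 4) even though class 3 never occurs here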

@@ -49,3 +49,10 @@ batdetect2 = "bat_detect.command:main"

 [tool.black]
 line-length = 80

+[[tool.mypy.overrides]]
+module = [
+    "librosa",
+    "pandas",
+]
+ignore_missing_imports = true
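These overrides silence mypy's missing-stub errors for untyped third-party packages in one central place; the per-import alternative would be an inline ignore at every use site, which is noisier:

    import librosa  # type: ignore  # equivalent, but repeated per import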

@@ -86,7 +86,7 @@ if __name__ == "__main__":
     args_cmd = vars(parser.parse_args())

     # load the model
-    bd_args = du.get_default_bd_args()
+    bd_args = du.get_default_run_config()
     model, params_bd = du.load_model(args_cmd["model_path"])

     bd_args["detection_threshold"] = args_cmd["detection_threshold"]
     bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"]

@@ -89,7 +89,7 @@ if __name__ == "__main__":
         os.makedirs(op_dir)

     params = parameters.get_params(False)
-    args = du.get_default_bd_args()
+    args = du.get_default_run_config()
     args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
     args["detection_threshold"] = args_cmd["detection_threshold"]