update run_batdetect command

This commit is contained in:
Santiago Martinez 2023-03-20 11:09:41 +00:00
parent a4f99dd26a
commit 74e8283576
4 changed files with 118 additions and 16 deletions

View File

@ -1,4 +1,5 @@
import copy import copy
from typing import Tuple
import librosa import librosa
import numpy as np import numpy as np
@ -7,9 +8,47 @@ import torch.nn.functional as F
import torchaudio import torchaudio
import bat_detect.utils.audio_utils as au import bat_detect.utils.audio_utils as au
from bat_detect.types import AnnotationGroup, HeatmapParameters
def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params): def generate_gt_heatmaps(
spec_op_shape: Tuple[int, int],
sampling_rate: int,
ann: AnnotationGroup,
params: HeatmapParameters,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]:
"""Generate ground truth heatmaps from annotations.
Parameters
----------
spec_op_shape : Tuple[int, int]
Shape of the input spectrogram.
sampling_rate : int
Sampling rate of the input audio in Hz.
ann : AnnotationGroup
Dictionary containing the annotation information.
params : HeatmapParameters
Parameters controlling the generation of the heatmaps.
Returns
-------
y_2d_det : np.ndarray
2D heatmap of the presence of an event.
y_2d_size : np.ndarray
2D heatmap of the size of the bounding box associated to event.
y_2d_classes : np.ndarray
3D array containing the ground-truth class probabilities for each
pixel.
ann_aug : AnnotationGroup
A dictionary containing the annotation information of the
annotations that are within the input spectrogram, augmented with
the x and y indices of their pixel location in the input spectrogram.
"""
# spec may be resized on input into the network # spec may be resized on input into the network
num_classes = len(params["class_names"]) num_classes = len(params["class_names"])
op_height = spec_op_shape[0] op_height = spec_op_shape[0]
@ -40,6 +79,7 @@ def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
bb_widths = x_pos_end - x_pos_start bb_widths = x_pos_end - x_pos_start
bb_heights = y_pos_low - y_pos_high bb_heights = y_pos_low - y_pos_high
# Only include annotations that are within the input spectrogram
valid_inds = np.where( valid_inds = np.where(
(x_pos_start >= 0) (x_pos_start >= 0)
& (x_pos_start < op_width) & (x_pos_start < op_width)
@ -47,19 +87,26 @@ def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
& (y_pos_low < (op_height - 1)) & (y_pos_low < (op_height - 1))
)[0] )[0]
ann_aug = {} ann_aug: AnnotationGroup = {
"start_times": ann["start_times"][valid_inds],
"end_times": ann["end_times"][valid_inds],
"high_freqs": ann["high_freqs"][valid_inds],
"low_freqs": ann["low_freqs"][valid_inds],
"class_ids": ann["class_ids"][valid_inds],
"individual_ids": ann["individual_ids"][valid_inds],
}
ann_aug["x_inds"] = x_pos_start[valid_inds] ann_aug["x_inds"] = x_pos_start[valid_inds]
ann_aug["y_inds"] = y_pos_low[valid_inds] ann_aug["y_inds"] = y_pos_low[valid_inds]
keys = [ # keys = [
"start_times", # "start_times",
"end_times", # "end_times",
"high_freqs", # "high_freqs",
"low_freqs", # "low_freqs",
"class_ids", # "class_ids",
"individual_ids", # "individual_ids",
] # ]
for kk in keys: # for kk in keys:
ann_aug[kk] = ann[kk][valid_inds] # ann_aug[kk] = ann[kk][valid_inds]
# if the number of calls is only 1, then it is unique # if the number of calls is only 1, then it is unique
# TODO would be better if we found these unique calls at the merging stage # TODO would be better if we found these unique calls at the merging stage
@ -69,7 +116,7 @@ def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32) y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32) y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
# num classes and "background" class # num classes and "background" class
y_2d_classes = np.zeros( y_2d_classes: np.ndarray = np.zeros(
(num_classes + 1, op_height, op_width), dtype=np.float32 (num_classes + 1, op_height, op_width), dtype=np.float32
) )

View File

@ -416,3 +416,60 @@ class NonMaximumSuppressionConfig(TypedDict):
detection_threshold: float detection_threshold: float
"""Threshold for detection probability.""" """Threshold for detection probability."""
class HeatmapParameters(TypedDict):
"""Parameters that control the heatmap generation function."""
class_names: List[str]
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Percentage of the FFT windows overlap."""
resize_factor: float
"""Factor by which the input was resized."""
min_freq: int
"""Minimum frequency to consider in Hz."""
max_freq: int
"""Maximum frequency to consider in Hz."""
target_sigma: float
"""Sigma for the Gaussian kernel. Controls the width of the points in
the heatmap."""
class AnnotationGroup(TypedDict):
"""Group of annotations.
Each key is a numpy array of length `num_annotations` containing the
corresponding values for each annotation.
"""
start_times: np.ndarray
"""Start times of the annotations in seconds."""
end_times: np.ndarray
"""End times of the annotations in seconds."""
low_freqs: np.ndarray
"""Low frequencies of the annotations in Hz."""
high_freqs: np.ndarray
"""High frequencies of the annotations in Hz."""
class_ids: np.ndarray
"""Class IDs of the annotations."""
individual_ids: np.ndarray
"""Individual IDs of the annotations."""
x_inds: NotRequired[np.ndarray]
"""X coordinate of the annotations in the spectrogram."""
y_inds: NotRequired[np.ndarray]
"""Y coordinate of the annotations in the spectrogram."""

View File

@ -1,5 +1,5 @@
"""Run bat_detect.command.main() from the command line.""" """Run bat_detect.command.main() from the command line."""
from bat_detect.command import main from bat_detect.cli import main
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -28,8 +28,6 @@ def test_load_model_with_default_params():
assert "num_filters" in params assert "num_filters" in params
assert "emb_dim" in params assert "emb_dim" in params
assert "ip_height" in params assert "ip_height" in params
assert "resize_factor" in params
assert "class_names" in params
assert params["model_name"] == "Net2DFast" assert params["model_name"] == "Net2DFast"
assert params["num_filters"] == 128 assert params["num_filters"] == 128