diff --git a/bat_detect/train/audio_dataloader.py b/bat_detect/train/audio_dataloader.py index cce8255..6d4d9d8 100644 --- a/bat_detect/train/audio_dataloader.py +++ b/bat_detect/train/audio_dataloader.py @@ -1,4 +1,5 @@ import copy +from typing import Tuple import librosa import numpy as np @@ -7,9 +8,47 @@ import torch.nn.functional as F import torchaudio import bat_detect.utils.audio_utils as au +from bat_detect.types import AnnotationGroup, HeatmapParameters -def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params): +def generate_gt_heatmaps( + spec_op_shape: Tuple[int, int], + sampling_rate: int, + ann: AnnotationGroup, + params: HeatmapParameters, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]: + """Generate ground truth heatmaps from annotations. + + Parameters + ---------- + spec_op_shape : Tuple[int, int] + Shape of the input spectrogram. + sampling_rate : int + Sampling rate of the input audio in Hz. + ann : AnnotationGroup + Dictionary containing the annotation information. + params : HeatmapParameters + Parameters controlling the generation of the heatmaps. + + Returns + ------- + + y_2d_det : np.ndarray + 2D heatmap of the presence of an event. + + y_2d_size : np.ndarray + 2D heatmap of the size of the bounding box associated with each event. + + y_2d_classes : np.ndarray + 3D array containing the ground-truth class probabilities for each + pixel. + + ann_aug : AnnotationGroup + A dictionary containing the annotation information of the + annotations that are within the input spectrogram, augmented with + the x and y indices of their pixel location in the input spectrogram. 
+ + """ # spec may be resized on input into the network num_classes = len(params["class_names"]) op_height = spec_op_shape[0] @@ -40,6 +79,7 @@ def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params): bb_widths = x_pos_end - x_pos_start bb_heights = y_pos_low - y_pos_high + # Only include annotations that are within the input spectrogram valid_inds = np.where( (x_pos_start >= 0) & (x_pos_start < op_width) @@ -47,19 +87,26 @@ def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params): & (y_pos_low < (op_height - 1)) )[0] - ann_aug = {} + ann_aug: AnnotationGroup = { + "start_times": ann["start_times"][valid_inds], + "end_times": ann["end_times"][valid_inds], + "high_freqs": ann["high_freqs"][valid_inds], + "low_freqs": ann["low_freqs"][valid_inds], + "class_ids": ann["class_ids"][valid_inds], + "individual_ids": ann["individual_ids"][valid_inds], + } ann_aug["x_inds"] = x_pos_start[valid_inds] ann_aug["y_inds"] = y_pos_low[valid_inds] - keys = [ - "start_times", - "end_times", - "high_freqs", - "low_freqs", - "class_ids", - "individual_ids", - ] - for kk in keys: - ann_aug[kk] = ann[kk][valid_inds] + # keys = [ + # "start_times", + # "end_times", + # "high_freqs", + # "low_freqs", + # "class_ids", + # "individual_ids", + # ] + # for kk in keys: + # ann_aug[kk] = ann[kk][valid_inds] # if the number of calls is only 1, then it is unique # TODO would be better if we found these unique calls at the merging stage @@ -69,7 +116,7 @@ def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params): y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32) y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32) # num classes and "background" class - y_2d_classes = np.zeros( + y_2d_classes: np.ndarray = np.zeros( (num_classes + 1, op_height, op_width), dtype=np.float32 ) diff --git a/bat_detect/types.py b/bat_detect/types.py index 5e20c48..3bc810b 100644 --- a/bat_detect/types.py +++ b/bat_detect/types.py @@ -416,3 +416,60 @@ class 
NonMaximumSuppressionConfig(TypedDict): detection_threshold: float """Threshold for detection probability.""" + + +class HeatmapParameters(TypedDict): + """Parameters that control the heatmap generation function.""" + + class_names: List[str] + + fft_win_length: float + """Length of the FFT window in seconds.""" + + fft_overlap: float + """Percentage of overlap between consecutive FFT windows.""" + + resize_factor: float + """Factor by which the input was resized.""" + + min_freq: int + """Minimum frequency to consider in Hz.""" + + max_freq: int + """Maximum frequency to consider in Hz.""" + + target_sigma: float + """Sigma for the Gaussian kernel. Controls the width of the points in + the heatmap.""" + + +class AnnotationGroup(TypedDict): + """Group of annotations. + + Each key maps to a numpy array of length `num_annotations` containing the + corresponding values for each annotation. + """ + + start_times: np.ndarray + """Start times of the annotations in seconds.""" + + end_times: np.ndarray + """End times of the annotations in seconds.""" + + low_freqs: np.ndarray + """Low frequencies of the annotations in Hz.""" + + high_freqs: np.ndarray + """High frequencies of the annotations in Hz.""" + + class_ids: np.ndarray + """Class IDs of the annotations.""" + + individual_ids: np.ndarray + """Individual IDs of the annotations.""" + + x_inds: NotRequired[np.ndarray] + """X coordinate of the annotations in the spectrogram.""" + + y_inds: NotRequired[np.ndarray] + """Y coordinate of the annotations in the spectrogram.""" diff --git a/run_batdetect.py b/run_batdetect.py index b2c5230..54bac02 100644 --- a/run_batdetect.py +++ b/run_batdetect.py @@ -1,5 +1,5 @@ """Run bat_detect.command.main() from the command line.""" -from bat_detect.command import main +from bat_detect.cli import main if __name__ == "__main__": main() diff --git a/tests/test_api.py b/tests/test_api.py index 927d7be..8158a1f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -28,8 +28,6 @@ def 
test_load_model_with_default_params(): assert "num_filters" in params assert "emb_dim" in params assert "ip_height" in params - assert "resize_factor" in params - assert "class_names" in params assert params["model_name"] == "Net2DFast" assert params["num_filters"] == 128