mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
update run_batdetect command
This commit is contained in:
parent
a4f99dd26a
commit
74e8283576
@ -1,4 +1,5 @@
|
||||
import copy
|
||||
from typing import Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
@ -7,9 +8,47 @@ import torch.nn.functional as F
|
||||
import torchaudio
|
||||
|
||||
import bat_detect.utils.audio_utils as au
|
||||
from bat_detect.types import AnnotationGroup, HeatmapParameters
|
||||
|
||||
|
||||
def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
|
||||
def generate_gt_heatmaps(
|
||||
spec_op_shape: Tuple[int, int],
|
||||
sampling_rate: int,
|
||||
ann: AnnotationGroup,
|
||||
params: HeatmapParameters,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]:
|
||||
"""Generate ground truth heatmaps from annotations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec_op_shape : Tuple[int, int]
|
||||
Shape of the input spectrogram.
|
||||
sampling_rate : int
|
||||
Sampling rate of the input audio in Hz.
|
||||
ann : AnnotationGroup
|
||||
Dictionary containing the annotation information.
|
||||
params : HeatmapParameters
|
||||
Parameters controlling the generation of the heatmaps.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
y_2d_det : np.ndarray
|
||||
2D heatmap of the presence of an event.
|
||||
|
||||
y_2d_size : np.ndarray
|
||||
2D heatmap of the size of the bounding box associated to event.
|
||||
|
||||
y_2d_classes : np.ndarray
|
||||
3D array containing the ground-truth class probabilities for each
|
||||
pixel.
|
||||
|
||||
ann_aug : AnnotationGroup
|
||||
A dictionary containing the annotation information of the
|
||||
annotations that are within the input spectrogram, augmented with
|
||||
the x and y indices of their pixel location in the input spectrogram.
|
||||
|
||||
"""
|
||||
# spec may be resized on input into the network
|
||||
num_classes = len(params["class_names"])
|
||||
op_height = spec_op_shape[0]
|
||||
@ -40,6 +79,7 @@ def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
|
||||
bb_widths = x_pos_end - x_pos_start
|
||||
bb_heights = y_pos_low - y_pos_high
|
||||
|
||||
# Only include annotations that are within the input spectrogram
|
||||
valid_inds = np.where(
|
||||
(x_pos_start >= 0)
|
||||
& (x_pos_start < op_width)
|
||||
@ -47,19 +87,26 @@ def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
|
||||
& (y_pos_low < (op_height - 1))
|
||||
)[0]
|
||||
|
||||
ann_aug = {}
|
||||
ann_aug: AnnotationGroup = {
|
||||
"start_times": ann["start_times"][valid_inds],
|
||||
"end_times": ann["end_times"][valid_inds],
|
||||
"high_freqs": ann["high_freqs"][valid_inds],
|
||||
"low_freqs": ann["low_freqs"][valid_inds],
|
||||
"class_ids": ann["class_ids"][valid_inds],
|
||||
"individual_ids": ann["individual_ids"][valid_inds],
|
||||
}
|
||||
ann_aug["x_inds"] = x_pos_start[valid_inds]
|
||||
ann_aug["y_inds"] = y_pos_low[valid_inds]
|
||||
keys = [
|
||||
"start_times",
|
||||
"end_times",
|
||||
"high_freqs",
|
||||
"low_freqs",
|
||||
"class_ids",
|
||||
"individual_ids",
|
||||
]
|
||||
for kk in keys:
|
||||
ann_aug[kk] = ann[kk][valid_inds]
|
||||
# keys = [
|
||||
# "start_times",
|
||||
# "end_times",
|
||||
# "high_freqs",
|
||||
# "low_freqs",
|
||||
# "class_ids",
|
||||
# "individual_ids",
|
||||
# ]
|
||||
# for kk in keys:
|
||||
# ann_aug[kk] = ann[kk][valid_inds]
|
||||
|
||||
# if the number of calls is only 1, then it is unique
|
||||
# TODO would be better if we found these unique calls at the merging stage
|
||||
@ -69,7 +116,7 @@ def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
|
||||
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
|
||||
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
|
||||
# num classes and "background" class
|
||||
y_2d_classes = np.zeros(
|
||||
y_2d_classes: np.ndarray = np.zeros(
|
||||
(num_classes + 1, op_height, op_width), dtype=np.float32
|
||||
)
|
||||
|
||||
|
@ -416,3 +416,60 @@ class NonMaximumSuppressionConfig(TypedDict):
|
||||
|
||||
detection_threshold: float
|
||||
"""Threshold for detection probability."""
|
||||
|
||||
|
||||
class HeatmapParameters(TypedDict):
|
||||
"""Parameters that control the heatmap generation function."""
|
||||
|
||||
class_names: List[str]
|
||||
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
|
||||
fft_overlap: float
|
||||
"""Percentage of the FFT windows overlap."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor by which the input was resized."""
|
||||
|
||||
min_freq: int
|
||||
"""Minimum frequency to consider in Hz."""
|
||||
|
||||
max_freq: int
|
||||
"""Maximum frequency to consider in Hz."""
|
||||
|
||||
target_sigma: float
|
||||
"""Sigma for the Gaussian kernel. Controls the width of the points in
|
||||
the heatmap."""
|
||||
|
||||
|
||||
class AnnotationGroup(TypedDict):
|
||||
"""Group of annotations.
|
||||
|
||||
Each key is a numpy array of length `num_annotations` containing the
|
||||
corresponding values for each annotation.
|
||||
"""
|
||||
|
||||
start_times: np.ndarray
|
||||
"""Start times of the annotations in seconds."""
|
||||
|
||||
end_times: np.ndarray
|
||||
"""End times of the annotations in seconds."""
|
||||
|
||||
low_freqs: np.ndarray
|
||||
"""Low frequencies of the annotations in Hz."""
|
||||
|
||||
high_freqs: np.ndarray
|
||||
"""High frequencies of the annotations in Hz."""
|
||||
|
||||
class_ids: np.ndarray
|
||||
"""Class IDs of the annotations."""
|
||||
|
||||
individual_ids: np.ndarray
|
||||
"""Individual IDs of the annotations."""
|
||||
|
||||
x_inds: NotRequired[np.ndarray]
|
||||
"""X coordinate of the annotations in the spectrogram."""
|
||||
|
||||
y_inds: NotRequired[np.ndarray]
|
||||
"""Y coordinate of the annotations in the spectrogram."""
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Run bat_detect.command.main() from the command line."""
|
||||
from bat_detect.command import main
|
||||
from bat_detect.cli import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -28,8 +28,6 @@ def test_load_model_with_default_params():
|
||||
assert "num_filters" in params
|
||||
assert "emb_dim" in params
|
||||
assert "ip_height" in params
|
||||
assert "resize_factor" in params
|
||||
assert "class_names" in params
|
||||
|
||||
assert params["model_name"] == "Net2DFast"
|
||||
assert params["num_filters"] == 128
|
||||
|
Loading…
Reference in New Issue
Block a user