mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
update run_batdetect command
This commit is contained in:
parent
a4f99dd26a
commit
74e8283576
@ -1,4 +1,5 @@
|
|||||||
import copy
|
import copy
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -7,9 +8,47 @@ import torch.nn.functional as F
|
|||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
import bat_detect.utils.audio_utils as au
|
import bat_detect.utils.audio_utils as au
|
||||||
|
from bat_detect.types import AnnotationGroup, HeatmapParameters
|
||||||
|
|
||||||
|
|
||||||
def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
|
def generate_gt_heatmaps(
|
||||||
|
spec_op_shape: Tuple[int, int],
|
||||||
|
sampling_rate: int,
|
||||||
|
ann: AnnotationGroup,
|
||||||
|
params: HeatmapParameters,
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]:
|
||||||
|
"""Generate ground truth heatmaps from annotations.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec_op_shape : Tuple[int, int]
|
||||||
|
Shape of the input spectrogram.
|
||||||
|
sampling_rate : int
|
||||||
|
Sampling rate of the input audio in Hz.
|
||||||
|
ann : AnnotationGroup
|
||||||
|
Dictionary containing the annotation information.
|
||||||
|
params : HeatmapParameters
|
||||||
|
Parameters controlling the generation of the heatmaps.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
|
||||||
|
y_2d_det : np.ndarray
|
||||||
|
2D heatmap of the presence of an event.
|
||||||
|
|
||||||
|
y_2d_size : np.ndarray
|
||||||
|
2D heatmap of the size of the bounding box associated to event.
|
||||||
|
|
||||||
|
y_2d_classes : np.ndarray
|
||||||
|
3D array containing the ground-truth class probabilities for each
|
||||||
|
pixel.
|
||||||
|
|
||||||
|
ann_aug : AnnotationGroup
|
||||||
|
A dictionary containing the annotation information of the
|
||||||
|
annotations that are within the input spectrogram, augmented with
|
||||||
|
the x and y indices of their pixel location in the input spectrogram.
|
||||||
|
|
||||||
|
"""
|
||||||
# spec may be resized on input into the network
|
# spec may be resized on input into the network
|
||||||
num_classes = len(params["class_names"])
|
num_classes = len(params["class_names"])
|
||||||
op_height = spec_op_shape[0]
|
op_height = spec_op_shape[0]
|
||||||
@ -40,6 +79,7 @@ def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
|
|||||||
bb_widths = x_pos_end - x_pos_start
|
bb_widths = x_pos_end - x_pos_start
|
||||||
bb_heights = y_pos_low - y_pos_high
|
bb_heights = y_pos_low - y_pos_high
|
||||||
|
|
||||||
|
# Only include annotations that are within the input spectrogram
|
||||||
valid_inds = np.where(
|
valid_inds = np.where(
|
||||||
(x_pos_start >= 0)
|
(x_pos_start >= 0)
|
||||||
& (x_pos_start < op_width)
|
& (x_pos_start < op_width)
|
||||||
@ -47,19 +87,26 @@ def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
|
|||||||
& (y_pos_low < (op_height - 1))
|
& (y_pos_low < (op_height - 1))
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
ann_aug = {}
|
ann_aug: AnnotationGroup = {
|
||||||
|
"start_times": ann["start_times"][valid_inds],
|
||||||
|
"end_times": ann["end_times"][valid_inds],
|
||||||
|
"high_freqs": ann["high_freqs"][valid_inds],
|
||||||
|
"low_freqs": ann["low_freqs"][valid_inds],
|
||||||
|
"class_ids": ann["class_ids"][valid_inds],
|
||||||
|
"individual_ids": ann["individual_ids"][valid_inds],
|
||||||
|
}
|
||||||
ann_aug["x_inds"] = x_pos_start[valid_inds]
|
ann_aug["x_inds"] = x_pos_start[valid_inds]
|
||||||
ann_aug["y_inds"] = y_pos_low[valid_inds]
|
ann_aug["y_inds"] = y_pos_low[valid_inds]
|
||||||
keys = [
|
# keys = [
|
||||||
"start_times",
|
# "start_times",
|
||||||
"end_times",
|
# "end_times",
|
||||||
"high_freqs",
|
# "high_freqs",
|
||||||
"low_freqs",
|
# "low_freqs",
|
||||||
"class_ids",
|
# "class_ids",
|
||||||
"individual_ids",
|
# "individual_ids",
|
||||||
]
|
# ]
|
||||||
for kk in keys:
|
# for kk in keys:
|
||||||
ann_aug[kk] = ann[kk][valid_inds]
|
# ann_aug[kk] = ann[kk][valid_inds]
|
||||||
|
|
||||||
# if the number of calls is only 1, then it is unique
|
# if the number of calls is only 1, then it is unique
|
||||||
# TODO would be better if we found these unique calls at the merging stage
|
# TODO would be better if we found these unique calls at the merging stage
|
||||||
@ -69,7 +116,7 @@ def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
|
|||||||
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
|
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
|
||||||
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
|
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
|
||||||
# num classes and "background" class
|
# num classes and "background" class
|
||||||
y_2d_classes = np.zeros(
|
y_2d_classes: np.ndarray = np.zeros(
|
||||||
(num_classes + 1, op_height, op_width), dtype=np.float32
|
(num_classes + 1, op_height, op_width), dtype=np.float32
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -416,3 +416,60 @@ class NonMaximumSuppressionConfig(TypedDict):
|
|||||||
|
|
||||||
detection_threshold: float
|
detection_threshold: float
|
||||||
"""Threshold for detection probability."""
|
"""Threshold for detection probability."""
|
||||||
|
|
||||||
|
|
||||||
|
class HeatmapParameters(TypedDict):
|
||||||
|
"""Parameters that control the heatmap generation function."""
|
||||||
|
|
||||||
|
class_names: List[str]
|
||||||
|
|
||||||
|
fft_win_length: float
|
||||||
|
"""Length of the FFT window in seconds."""
|
||||||
|
|
||||||
|
fft_overlap: float
|
||||||
|
"""Percentage of the FFT windows overlap."""
|
||||||
|
|
||||||
|
resize_factor: float
|
||||||
|
"""Factor by which the input was resized."""
|
||||||
|
|
||||||
|
min_freq: int
|
||||||
|
"""Minimum frequency to consider in Hz."""
|
||||||
|
|
||||||
|
max_freq: int
|
||||||
|
"""Maximum frequency to consider in Hz."""
|
||||||
|
|
||||||
|
target_sigma: float
|
||||||
|
"""Sigma for the Gaussian kernel. Controls the width of the points in
|
||||||
|
the heatmap."""
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationGroup(TypedDict):
|
||||||
|
"""Group of annotations.
|
||||||
|
|
||||||
|
Each key is a numpy array of length `num_annotations` containing the
|
||||||
|
corresponding values for each annotation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
start_times: np.ndarray
|
||||||
|
"""Start times of the annotations in seconds."""
|
||||||
|
|
||||||
|
end_times: np.ndarray
|
||||||
|
"""End times of the annotations in seconds."""
|
||||||
|
|
||||||
|
low_freqs: np.ndarray
|
||||||
|
"""Low frequencies of the annotations in Hz."""
|
||||||
|
|
||||||
|
high_freqs: np.ndarray
|
||||||
|
"""High frequencies of the annotations in Hz."""
|
||||||
|
|
||||||
|
class_ids: np.ndarray
|
||||||
|
"""Class IDs of the annotations."""
|
||||||
|
|
||||||
|
individual_ids: np.ndarray
|
||||||
|
"""Individual IDs of the annotations."""
|
||||||
|
|
||||||
|
x_inds: NotRequired[np.ndarray]
|
||||||
|
"""X coordinate of the annotations in the spectrogram."""
|
||||||
|
|
||||||
|
y_inds: NotRequired[np.ndarray]
|
||||||
|
"""Y coordinate of the annotations in the spectrogram."""
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Run bat_detect.command.main() from the command line."""
|
"""Run bat_detect.command.main() from the command line."""
|
||||||
from bat_detect.command import main
|
from bat_detect.cli import main
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -28,8 +28,6 @@ def test_load_model_with_default_params():
|
|||||||
assert "num_filters" in params
|
assert "num_filters" in params
|
||||||
assert "emb_dim" in params
|
assert "emb_dim" in params
|
||||||
assert "ip_height" in params
|
assert "ip_height" in params
|
||||||
assert "resize_factor" in params
|
|
||||||
assert "class_names" in params
|
|
||||||
|
|
||||||
assert params["model_name"] == "Net2DFast"
|
assert params["model_name"] == "Net2DFast"
|
||||||
assert params["num_filters"] == 128
|
assert params["num_filters"] == 128
|
||||||
|
Loading…
Reference in New Issue
Block a user