mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Merge pull request #16 from macaodha/fix/GH-15-spectrogram-features-computation
Fix/gh 15 spectrogram features computation
This commit is contained in:
commit
70877495d4
2
.gitignore
vendored
2
.gitignore
vendored
@ -65,7 +65,7 @@ ipython_config.py
|
|||||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
# in version control.
|
# in version control.
|
||||||
# https://pdm.fming.dev/#use-with-ide
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
.pdm.toml
|
.pdm-python
|
||||||
|
|
||||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
__pypackages__/
|
__pypackages__/
|
||||||
|
@ -1,22 +1,27 @@
|
|||||||
|
"""Functions to compute features from predictions."""
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from batdetect2 import types
|
||||||
|
from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ
|
||||||
|
|
||||||
|
|
||||||
def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
|
def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
|
||||||
|
"""Convert spectrogram index to frequency in Hz.""" ""
|
||||||
spec_ind = spec_height - spec_ind
|
spec_ind = spec_height - spec_ind
|
||||||
return round(
|
return round(
|
||||||
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
|
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_spec_slices(spec, pred_nms, params):
|
def extract_spec_slices(spec, pred_nms):
|
||||||
"""
|
"""Extract spectrogram slices from spectrogram.
|
||||||
Extracts spectrogram slices from spectrogram based on detected call locations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
The slices are extracted based on detected call locations.
|
||||||
|
"""
|
||||||
x_pos = pred_nms["x_pos"]
|
x_pos = pred_nms["x_pos"]
|
||||||
y_pos = pred_nms["y_pos"]
|
|
||||||
bb_width = pred_nms["bb_width"]
|
bb_width = pred_nms["bb_width"]
|
||||||
bb_height = pred_nms["bb_height"]
|
|
||||||
slices = []
|
slices = []
|
||||||
|
|
||||||
# add 20% padding either side of call
|
# add 20% padding either side of call
|
||||||
@ -35,100 +40,273 @@ def extract_spec_slices(spec, pred_nms, params):
|
|||||||
return slices
|
return slices
|
||||||
|
|
||||||
|
|
||||||
def get_feature_names():
|
def compute_duration(
|
||||||
feature_names = [
|
prediction: types.Prediction,
|
||||||
"duration",
|
**_,
|
||||||
"low_freq_bb",
|
) -> float:
|
||||||
"high_freq_bb",
|
"""Compute duration of call in seconds."""
|
||||||
"bandwidth",
|
return round(prediction["end_time"] - prediction["start_time"], 5)
|
||||||
"max_power_bb",
|
|
||||||
"max_power",
|
|
||||||
"max_power_first",
|
|
||||||
"max_power_second",
|
|
||||||
"call_interval",
|
|
||||||
]
|
|
||||||
return feature_names
|
|
||||||
|
|
||||||
|
|
||||||
def get_feats(spec, pred_nms, params):
|
def compute_low_freq(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute lowest frequency in call in Hz."""
|
||||||
|
return int(prediction["low_freq"])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_high_freq(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute highest frequency in call in Hz."""
|
||||||
|
return int(prediction["high_freq"])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_bandwidth(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute bandwidth of call in Hz."""
|
||||||
|
return int(prediction["high_freq"] - prediction["low_freq"])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_max_power_bb(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
spec: Optional[np.ndarray] = None,
|
||||||
|
min_freq: int = MIN_FREQ_HZ,
|
||||||
|
max_freq: int = MAX_FREQ_HZ,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute frequency with maximum power in call in Hz.
|
||||||
|
|
||||||
|
This is the frequency with the maximum power in the bounding box of the
|
||||||
|
call.
|
||||||
"""
|
"""
|
||||||
Extracts features from spectrogram based on detected call locations.
|
if spec is None:
|
||||||
Condsider re-extracting spectrogram for this to get better temporal resolution.
|
return np.nan
|
||||||
|
|
||||||
|
x_start = max(0, prediction["x_pos"])
|
||||||
|
x_end = min(
|
||||||
|
spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# y low is the lowest freq but it will have a higher value due to array
|
||||||
|
# starting at 0 at top
|
||||||
|
y_low = min(spec.shape[0] - 1, prediction["y_pos"])
|
||||||
|
y_high = max(0, prediction["y_pos"] - prediction["bb_height"])
|
||||||
|
|
||||||
|
spec_bb = spec[y_high:y_low, x_start:x_end]
|
||||||
|
power_per_freq_band = np.sum(spec_bb, axis=1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
max_power_ind = np.argmax(power_per_freq_band)
|
||||||
|
except ValueError:
|
||||||
|
# If the call is too short, the bounding box might be empty.
|
||||||
|
# In this case, return NaN.
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
return int(
|
||||||
|
convert_int_to_freq(
|
||||||
|
y_high + max_power_ind,
|
||||||
|
spec.shape[0],
|
||||||
|
min_freq,
|
||||||
|
max_freq,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_max_power(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
spec: Optional[np.ndarray] = None,
|
||||||
|
min_freq: int = MIN_FREQ_HZ,
|
||||||
|
max_freq: int = MAX_FREQ_HZ,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute frequency with maximum power in during the call in Hz."""
|
||||||
|
if spec is None:
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
x_start = max(0, prediction["x_pos"])
|
||||||
|
x_end = min(
|
||||||
|
spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
|
||||||
|
)
|
||||||
|
spec_call = spec[:, x_start:x_end]
|
||||||
|
power_per_freq_band = np.sum(spec_call, axis=1)
|
||||||
|
max_power_ind = np.argmax(power_per_freq_band)
|
||||||
|
return int(
|
||||||
|
convert_int_to_freq(
|
||||||
|
max_power_ind,
|
||||||
|
spec.shape[0],
|
||||||
|
min_freq,
|
||||||
|
max_freq,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_max_power_first(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
spec: Optional[np.ndarray] = None,
|
||||||
|
min_freq: int = MIN_FREQ_HZ,
|
||||||
|
max_freq: int = MAX_FREQ_HZ,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute frequency with maximum power in first half of call in Hz."""
|
||||||
|
if spec is None:
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
x_start = max(0, prediction["x_pos"])
|
||||||
|
x_end = min(
|
||||||
|
spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
|
||||||
|
)
|
||||||
|
spec_call = spec[:, x_start:x_end]
|
||||||
|
first_half = spec_call[:, : int(spec_call.shape[1] / 2)]
|
||||||
|
power_per_freq_band = np.sum(first_half, axis=1)
|
||||||
|
max_power_ind = np.argmax(power_per_freq_band)
|
||||||
|
return int(
|
||||||
|
convert_int_to_freq(
|
||||||
|
max_power_ind,
|
||||||
|
spec.shape[0],
|
||||||
|
min_freq,
|
||||||
|
max_freq,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_max_power_second(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
spec: Optional[np.ndarray] = None,
|
||||||
|
min_freq: int = MIN_FREQ_HZ,
|
||||||
|
max_freq: int = MAX_FREQ_HZ,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute frequency with maximum power in second half of call in Hz."""
|
||||||
|
if spec is None:
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
x_start = max(0, prediction["x_pos"])
|
||||||
|
x_end = min(
|
||||||
|
spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
|
||||||
|
)
|
||||||
|
spec_call = spec[:, x_start:x_end]
|
||||||
|
second_half = spec_call[:, int(spec_call.shape[1] / 2) :]
|
||||||
|
power_per_freq_band = np.sum(second_half, axis=1)
|
||||||
|
max_power_ind = np.argmax(power_per_freq_band)
|
||||||
|
return int(
|
||||||
|
convert_int_to_freq(
|
||||||
|
max_power_ind,
|
||||||
|
spec.shape[0],
|
||||||
|
min_freq,
|
||||||
|
max_freq,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_call_interval(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
previous: Optional[types.Prediction] = None,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute time between this call and the previous call in seconds."""
|
||||||
|
if previous is None:
|
||||||
|
return np.nan
|
||||||
|
return round(prediction["start_time"] - previous["end_time"], 5)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: The order of the features in this dictionary is important. The
|
||||||
|
# features are extracted in this order and the order of the columns in the
|
||||||
|
# output csv file is determined by this order. In order to avoid breaking
|
||||||
|
# changes in the output csv file, new features should be added to the end of
|
||||||
|
# this dictionary.
|
||||||
|
FEATURES: Dict[str, types.FeatureExtractor] = {
|
||||||
|
"duration": compute_duration,
|
||||||
|
"low_freq_bb": compute_low_freq,
|
||||||
|
"high_freq_bb": compute_high_freq,
|
||||||
|
"bandwidth": compute_bandwidth,
|
||||||
|
"max_power_bb": compute_max_power_bb,
|
||||||
|
"max_power": compute_max_power,
|
||||||
|
"max_power_first": compute_max_power_first,
|
||||||
|
"max_power_second": compute_max_power_second,
|
||||||
|
"call_interval": compute_call_interval,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_feats(
|
||||||
|
spec: np.ndarray,
|
||||||
|
pred_nms: types.PredictionResults,
|
||||||
|
params: types.FeatureExtractionParameters,
|
||||||
|
):
|
||||||
|
"""Extract features from spectrogram based on detected call locations.
|
||||||
|
|
||||||
|
The features extracted are:
|
||||||
|
|
||||||
|
- duration: duration of call in seconds
|
||||||
|
- low_freq: lowest frequency in call in kHz
|
||||||
|
- high_freq: highest frequency in call in kHz
|
||||||
|
- bandwidth: high_freq - low_freq
|
||||||
|
- max_power_bb: frequency with maximum power in call in kHz
|
||||||
|
- max_power: frequency with maximum power in spectrogram in kHz
|
||||||
|
- max_power_first: frequency with maximum power in first half of call in
|
||||||
|
kHz.
|
||||||
|
- max_power_second: frequency with maximum power in second half of call in
|
||||||
|
kHz.
|
||||||
|
- call_interval: time between this call and the previous call in seconds
|
||||||
|
|
||||||
|
Consider re-extracting spectrogram for this to get better temporal
|
||||||
|
resolution.
|
||||||
|
|
||||||
For more possible features check out:
|
For more possible features check out:
|
||||||
https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt
|
https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec : np.ndarray
|
||||||
|
Spectrogram from which to extract features.
|
||||||
|
|
||||||
|
pred_nms : types.PredictionResults
|
||||||
|
Information about detected calls from which to extract features.
|
||||||
|
|
||||||
|
params : types.FeatureExtractionParameters
|
||||||
|
Parameters for feature extraction.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
features : np.ndarray
|
||||||
|
Extracted features for each detected call. Shape is
|
||||||
|
(num_detections, num_features).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
x_pos = pred_nms["x_pos"]
|
|
||||||
y_pos = pred_nms["y_pos"]
|
|
||||||
bb_width = pred_nms["bb_width"]
|
|
||||||
bb_height = pred_nms["bb_height"]
|
|
||||||
|
|
||||||
feature_names = get_feature_names()
|
|
||||||
num_detections = len(pred_nms["det_probs"])
|
num_detections = len(pred_nms["det_probs"])
|
||||||
features = (
|
features = np.empty((num_detections, len(FEATURES)), dtype=np.float32)
|
||||||
np.ones((num_detections, len(feature_names)), dtype=np.float32) * -1
|
previous = None
|
||||||
)
|
|
||||||
|
|
||||||
for ff in range(num_detections):
|
for row in range(num_detections):
|
||||||
x_start = int(np.maximum(0, x_pos[ff]))
|
prediction: types.Prediction = {
|
||||||
x_end = int(
|
"det_prob": float(pred_nms["det_probs"][row]),
|
||||||
np.minimum(spec.shape[1] - 1, np.round(x_pos[ff] + bb_width[ff]))
|
"class_prob": pred_nms["class_probs"][:, row],
|
||||||
)
|
"start_time": float(pred_nms["start_times"][row]),
|
||||||
# y low is the lowest freq but it will have a higher value due to array starting at 0 at top
|
"end_time": float(pred_nms["end_times"][row]),
|
||||||
y_low = int(np.minimum(spec.shape[0] - 1, y_pos[ff]))
|
"low_freq": float(pred_nms["low_freqs"][row]),
|
||||||
y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff])))
|
"high_freq": float(pred_nms["high_freqs"][row]),
|
||||||
spec_slice = spec[:, x_start:x_end]
|
"x_pos": int(pred_nms["x_pos"][row]),
|
||||||
|
"y_pos": int(pred_nms["y_pos"][row]),
|
||||||
|
"bb_width": int(pred_nms["bb_width"][row]),
|
||||||
|
"bb_height": int(pred_nms["bb_height"][row]),
|
||||||
|
}
|
||||||
|
|
||||||
if spec_slice.shape[1] > 1:
|
for col, feature in enumerate(FEATURES.values()):
|
||||||
features[ff, 0] = round(
|
features[row, col] = feature(
|
||||||
pred_nms["end_times"][ff] - pred_nms["start_times"][ff], 5
|
prediction,
|
||||||
)
|
previous=previous,
|
||||||
features[ff, 1] = int(pred_nms["low_freqs"][ff])
|
spec=spec,
|
||||||
features[ff, 2] = int(pred_nms["high_freqs"][ff])
|
**params,
|
||||||
features[ff, 3] = int(
|
|
||||||
pred_nms["high_freqs"][ff] - pred_nms["low_freqs"][ff]
|
|
||||||
)
|
|
||||||
features[ff, 4] = int(
|
|
||||||
convert_int_to_freq(
|
|
||||||
y_high + spec_slice[y_high:y_low, :].sum(1).argmax(),
|
|
||||||
spec.shape[0],
|
|
||||||
params["min_freq"],
|
|
||||||
params["max_freq"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
features[ff, 5] = int(
|
|
||||||
convert_int_to_freq(
|
|
||||||
spec_slice.sum(1).argmax(),
|
|
||||||
spec.shape[0],
|
|
||||||
params["min_freq"],
|
|
||||||
params["max_freq"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
hlf_val = spec_slice.shape[1] // 2
|
|
||||||
|
|
||||||
features[ff, 6] = int(
|
|
||||||
convert_int_to_freq(
|
|
||||||
spec_slice[:, :hlf_val].sum(1).argmax(),
|
|
||||||
spec.shape[0],
|
|
||||||
params["min_freq"],
|
|
||||||
params["max_freq"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
features[ff, 7] = int(
|
|
||||||
convert_int_to_freq(
|
|
||||||
spec_slice[:, hlf_val:].sum(1).argmax(),
|
|
||||||
spec.shape[0],
|
|
||||||
params["min_freq"],
|
|
||||||
params["max_freq"],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if ff > 0:
|
previous = prediction
|
||||||
features[ff, 8] = round(
|
|
||||||
pred_nms["start_times"][ff]
|
|
||||||
- pred_nms["start_times"][ff - 1],
|
|
||||||
5,
|
|
||||||
)
|
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def get_feature_names():
|
||||||
|
"""Get names of features in the order they are extracted."""
|
||||||
|
return list(FEATURES.keys())
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Types used in the code base."""
|
"""Types used in the code base."""
|
||||||
from typing import List, NamedTuple, Optional
|
from typing import List, NamedTuple, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -25,10 +25,13 @@ except ImportError:
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"Annotation",
|
"Annotation",
|
||||||
"DetectionModel",
|
"DetectionModel",
|
||||||
|
"FeatureExtractionParameters",
|
||||||
|
"FeatureExtractor",
|
||||||
"FileAnnotations",
|
"FileAnnotations",
|
||||||
"ModelOutput",
|
"ModelOutput",
|
||||||
"ModelParameters",
|
"ModelParameters",
|
||||||
"NonMaximumSuppressionConfig",
|
"NonMaximumSuppressionConfig",
|
||||||
|
"Prediction",
|
||||||
"PredictionResults",
|
"PredictionResults",
|
||||||
"ProcessingConfiguration",
|
"ProcessingConfiguration",
|
||||||
"ResultParams",
|
"ResultParams",
|
||||||
@ -312,6 +315,40 @@ class ModelOutput(NamedTuple):
|
|||||||
"""Tensor with intermediate features."""
|
"""Tensor with intermediate features."""
|
||||||
|
|
||||||
|
|
||||||
|
class Prediction(TypedDict):
|
||||||
|
"""Singe prediction."""
|
||||||
|
|
||||||
|
det_prob: float
|
||||||
|
"""Detection probability."""
|
||||||
|
|
||||||
|
x_pos: int
|
||||||
|
"""X position of the detection in pixels."""
|
||||||
|
|
||||||
|
y_pos: int
|
||||||
|
"""Y position of the detection in pixels."""
|
||||||
|
|
||||||
|
bb_width: int
|
||||||
|
"""Width of the detection in pixels."""
|
||||||
|
|
||||||
|
bb_height: int
|
||||||
|
"""Height of the detection in pixels."""
|
||||||
|
|
||||||
|
start_time: float
|
||||||
|
"""Start time of the detection in seconds."""
|
||||||
|
|
||||||
|
end_time: float
|
||||||
|
"""End time of the detection in seconds."""
|
||||||
|
|
||||||
|
low_freq: float
|
||||||
|
"""Low frequency of the detection in Hz."""
|
||||||
|
|
||||||
|
high_freq: float
|
||||||
|
"""High frequency of the detection in Hz."""
|
||||||
|
|
||||||
|
class_prob: np.ndarray
|
||||||
|
"""Vector holding the probability of each class."""
|
||||||
|
|
||||||
|
|
||||||
class PredictionResults(TypedDict):
|
class PredictionResults(TypedDict):
|
||||||
"""Results of the prediction.
|
"""Results of the prediction.
|
||||||
|
|
||||||
@ -418,6 +455,16 @@ class NonMaximumSuppressionConfig(TypedDict):
|
|||||||
"""Threshold for detection probability."""
|
"""Threshold for detection probability."""
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureExtractionParameters(TypedDict):
|
||||||
|
"""Parameters that control the feature extraction function."""
|
||||||
|
|
||||||
|
min_freq: int
|
||||||
|
"""Minimum frequency to consider in Hz."""
|
||||||
|
|
||||||
|
max_freq: int
|
||||||
|
"""Maximum frequency to consider in Hz."""
|
||||||
|
|
||||||
|
|
||||||
class HeatmapParameters(TypedDict):
|
class HeatmapParameters(TypedDict):
|
||||||
"""Parameters that control the heatmap generation function."""
|
"""Parameters that control the heatmap generation function."""
|
||||||
|
|
||||||
@ -473,3 +520,11 @@ class AnnotationGroup(TypedDict):
|
|||||||
|
|
||||||
y_inds: NotRequired[np.ndarray]
|
y_inds: NotRequired[np.ndarray]
|
||||||
"""Y coordinate of the annotations in the spectrogram."""
|
"""Y coordinate of the annotations in the spectrogram."""
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureExtractor(Protocol):
|
||||||
|
"""Protocol for feature extractors."""
|
||||||
|
|
||||||
|
def __call__(self, prediction: Prediction, **kwargs) -> Union[float, int]:
|
||||||
|
"""Extract features from a prediction."""
|
||||||
|
...
|
||||||
|
@ -773,7 +773,7 @@ def process_file(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# convert to numpy
|
# convert to numpy
|
||||||
spec_np = spec.detach().cpu().numpy()
|
spec_np = spec.detach().cpu().numpy().squeeze()
|
||||||
|
|
||||||
# add chunk time to start and end times
|
# add chunk time to start and end times
|
||||||
pred_nms["start_times"] += chunk_time
|
pred_nms["start_times"] += chunk_time
|
||||||
@ -794,7 +794,7 @@ def process_file(
|
|||||||
if config["spec_slices"]:
|
if config["spec_slices"]:
|
||||||
# FIX: This is not currently working. Returns empty slices
|
# FIX: This is not currently working. Returns empty slices
|
||||||
spec_slices.extend(
|
spec_slices.extend(
|
||||||
feats.extract_spec_slices(spec_np, pred_nms, config)
|
feats.extract_spec_slices(spec_np, pred_nms)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Merge results from chunks
|
# Merge results from chunks
|
||||||
|
@ -56,7 +56,7 @@ build-backend = "pdm.pep517.api"
|
|||||||
batdetect2 = "batdetect2.cli:cli"
|
batdetect2 = "batdetect2.cli:cli"
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 80
|
line-length = 79
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
module = [
|
module = [
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""Test the command line interface."""
|
"""Test the command line interface."""
|
||||||
|
from pathlib import Path
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
from batdetect2.cli import cli
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
@ -67,3 +69,42 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
|
|||||||
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert 'Time Expansion Factor: 10' in result.stdout
|
assert 'Time Expansion Factor: 10' in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
|
||||||
|
"""Test the detect command with the spec feature flag."""
|
||||||
|
results_dir = tmp_path / "results"
|
||||||
|
|
||||||
|
# Remove results dir if it exists
|
||||||
|
if results_dir.exists():
|
||||||
|
results_dir.rmdir()
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"detect",
|
||||||
|
"example_data/audio",
|
||||||
|
str(results_dir),
|
||||||
|
"0.3",
|
||||||
|
"--spec_features",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert results_dir.exists()
|
||||||
|
|
||||||
|
|
||||||
|
csv_files = [path.name for path in results_dir.glob("*.csv")]
|
||||||
|
|
||||||
|
expected_files = [
|
||||||
|
"20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
|
||||||
|
"20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
|
||||||
|
"20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv"
|
||||||
|
]
|
||||||
|
|
||||||
|
for expected_file in expected_files:
|
||||||
|
assert expected_file in csv_files
|
||||||
|
|
||||||
|
df = pd.read_csv(results_dir / expected_file)
|
||||||
|
assert not (df.duration == -1).any()
|
||||||
|
291
tests/test_features.py
Normal file
291
tests/test_features.py
Normal file
@ -0,0 +1,291 @@
|
|||||||
|
"""Test suite for feature extraction functions."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import batdetect2.detector.compute_features as feats
|
||||||
|
from batdetect2 import api, types
|
||||||
|
from batdetect2.utils import audio_utils as au
|
||||||
|
|
||||||
|
numba_logger = logging.getLogger("numba")
|
||||||
|
numba_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
def index_to_freq(
|
||||||
|
index: int,
|
||||||
|
spec_height: int,
|
||||||
|
min_freq: int,
|
||||||
|
max_freq: int,
|
||||||
|
) -> float:
|
||||||
|
"""Convert spectrogram index to frequency in Hz."""
|
||||||
|
index = spec_height - index
|
||||||
|
return round(
|
||||||
|
(index / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def index_to_time(
|
||||||
|
index: int,
|
||||||
|
spec_width: int,
|
||||||
|
spec_duration: float,
|
||||||
|
) -> float:
|
||||||
|
"""Convert spectrogram index to time in seconds."""
|
||||||
|
return round((index / float(spec_width)) * spec_duration, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_feats_function_with_empty_spectrogram():
|
||||||
|
"""Test get_feats function with empty spectrogram.
|
||||||
|
|
||||||
|
This tests that the overall flow of the function works, even if the
|
||||||
|
spectrogram is empty.
|
||||||
|
"""
|
||||||
|
spec_duration = 3
|
||||||
|
spec_width = 100
|
||||||
|
spec_height = 100
|
||||||
|
min_freq = 10_000
|
||||||
|
max_freq = 120_000
|
||||||
|
spectrogram = np.zeros((spec_height, spec_width))
|
||||||
|
|
||||||
|
x_pos = 20
|
||||||
|
y_pos = 80
|
||||||
|
bb_width = 20
|
||||||
|
bb_height = 20
|
||||||
|
|
||||||
|
start_time = index_to_time(x_pos, spec_width, spec_duration)
|
||||||
|
end_time = index_to_time(x_pos + bb_width, spec_width, spec_duration)
|
||||||
|
low_freq = index_to_freq(y_pos, spec_height, min_freq, max_freq)
|
||||||
|
high_freq = index_to_freq(
|
||||||
|
y_pos - bb_height, spec_height, min_freq, max_freq
|
||||||
|
)
|
||||||
|
|
||||||
|
pred_nms: types.PredictionResults = {
|
||||||
|
"det_probs": np.array([1]),
|
||||||
|
"class_probs": np.array([[1]]),
|
||||||
|
"x_pos": np.array([x_pos]),
|
||||||
|
"y_pos": np.array([y_pos]),
|
||||||
|
"bb_width": np.array([bb_width]),
|
||||||
|
"bb_height": np.array([bb_height]),
|
||||||
|
"start_times": np.array([start_time]),
|
||||||
|
"end_times": np.array([end_time]),
|
||||||
|
"low_freqs": np.array([low_freq]),
|
||||||
|
"high_freqs": np.array([high_freq]),
|
||||||
|
}
|
||||||
|
|
||||||
|
params: types.FeatureExtractionParameters = {
|
||||||
|
"min_freq": min_freq,
|
||||||
|
"max_freq": max_freq,
|
||||||
|
}
|
||||||
|
|
||||||
|
features = feats.get_feats(spectrogram, pred_nms, params)
|
||||||
|
assert low_freq < high_freq
|
||||||
|
assert isinstance(features, np.ndarray)
|
||||||
|
assert features.shape == (len(pred_nms["det_probs"]), 9)
|
||||||
|
assert np.isclose(
|
||||||
|
features[0],
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
end_time - start_time,
|
||||||
|
low_freq,
|
||||||
|
high_freq,
|
||||||
|
high_freq - low_freq,
|
||||||
|
high_freq,
|
||||||
|
max_freq,
|
||||||
|
max_freq,
|
||||||
|
max_freq,
|
||||||
|
np.nan,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
equal_nan=True,
|
||||||
|
).all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"max_power",
|
||||||
|
[
|
||||||
|
30_000,
|
||||||
|
31_000,
|
||||||
|
32_000,
|
||||||
|
33_000,
|
||||||
|
34_000,
|
||||||
|
35_000,
|
||||||
|
36_000,
|
||||||
|
37_000,
|
||||||
|
38_000,
|
||||||
|
39_000,
|
||||||
|
40_000,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_compute_max_power_bb(max_power: int):
|
||||||
|
"""Test compute_max_power_bb function."""
|
||||||
|
duration = 1
|
||||||
|
samplerate = 256_000
|
||||||
|
min_freq = 0
|
||||||
|
max_freq = 128_000
|
||||||
|
|
||||||
|
start_time = 0.3
|
||||||
|
end_time = 0.6
|
||||||
|
low_freq = 30_000
|
||||||
|
high_freq = 40_000
|
||||||
|
|
||||||
|
audio = np.zeros((int(duration * samplerate),))
|
||||||
|
|
||||||
|
# Add a signal during the time and frequency range of interest
|
||||||
|
audio[
|
||||||
|
int(start_time * samplerate) : int(end_time * samplerate)
|
||||||
|
] = 0.5 * librosa.tone(
|
||||||
|
max_power, sr=samplerate, duration=end_time - start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a more powerful signal outside frequency range of interest
|
||||||
|
audio[
|
||||||
|
int(start_time * samplerate) : int(end_time * samplerate)
|
||||||
|
] += 2 * librosa.tone(
|
||||||
|
80_000, sr=samplerate, duration=end_time - start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
params = api.get_config(
|
||||||
|
min_freq=min_freq,
|
||||||
|
max_freq=max_freq,
|
||||||
|
target_samp_rate=samplerate,
|
||||||
|
)
|
||||||
|
|
||||||
|
spec, _ = au.generate_spectrogram(
|
||||||
|
audio,
|
||||||
|
samplerate,
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
|
||||||
|
x_start = int(
|
||||||
|
au.time_to_x_coords(
|
||||||
|
start_time,
|
||||||
|
samplerate,
|
||||||
|
params["fft_win_length"],
|
||||||
|
params["fft_overlap"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
x_end = int(
|
||||||
|
au.time_to_x_coords(
|
||||||
|
end_time,
|
||||||
|
samplerate,
|
||||||
|
params["fft_win_length"],
|
||||||
|
params["fft_overlap"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
num_freq_bins = spec.shape[0]
|
||||||
|
y_low = num_freq_bins - int(num_freq_bins * low_freq / max_freq)
|
||||||
|
y_high = num_freq_bins - int(num_freq_bins * high_freq / max_freq)
|
||||||
|
|
||||||
|
prediction: types.Prediction = {
|
||||||
|
"det_prob": 1,
|
||||||
|
"class_prob": np.ones((1,)),
|
||||||
|
"x_pos": x_start,
|
||||||
|
"y_pos": int(y_low),
|
||||||
|
"bb_width": int(x_end - x_start),
|
||||||
|
"bb_height": int(y_low - y_high),
|
||||||
|
"start_time": start_time,
|
||||||
|
"end_time": end_time,
|
||||||
|
"low_freq": low_freq,
|
||||||
|
"high_freq": high_freq,
|
||||||
|
}
|
||||||
|
|
||||||
|
print(prediction)
|
||||||
|
|
||||||
|
max_power_bb = feats.compute_max_power_bb(
|
||||||
|
prediction,
|
||||||
|
spec,
|
||||||
|
min_freq=min_freq,
|
||||||
|
max_freq=max_freq,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert abs(max_power_bb - max_power) <= 500
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_max_power():
|
||||||
|
"""Test compute_max_power_bb function."""
|
||||||
|
duration = 3
|
||||||
|
samplerate = 16_000
|
||||||
|
min_freq = 0
|
||||||
|
max_freq = 8_000
|
||||||
|
|
||||||
|
start_time = 1
|
||||||
|
end_time = 2
|
||||||
|
low_freq = 3_000
|
||||||
|
high_freq = 4_000
|
||||||
|
max_power = 5_000
|
||||||
|
|
||||||
|
audio = np.zeros((int(duration * samplerate),))
|
||||||
|
|
||||||
|
# Add a signal during the time and frequency range of interest
|
||||||
|
audio[
|
||||||
|
int(start_time * samplerate) : int(end_time * samplerate)
|
||||||
|
] = 0.5 * librosa.tone(
|
||||||
|
3_500, sr=samplerate, duration=end_time - start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a more powerful signal outside frequency range of interest
|
||||||
|
audio[
|
||||||
|
int(start_time * samplerate) : int(end_time * samplerate)
|
||||||
|
] += 2 * librosa.tone(
|
||||||
|
max_power, sr=samplerate, duration=end_time - start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
params = api.get_config(
|
||||||
|
min_freq=min_freq,
|
||||||
|
max_freq=max_freq,
|
||||||
|
target_samp_rate=samplerate,
|
||||||
|
)
|
||||||
|
|
||||||
|
spec, _ = au.generate_spectrogram(
|
||||||
|
audio,
|
||||||
|
samplerate,
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
|
||||||
|
x_start = int(
|
||||||
|
au.time_to_x_coords(
|
||||||
|
start_time,
|
||||||
|
samplerate,
|
||||||
|
params["fft_win_length"],
|
||||||
|
params["fft_overlap"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
x_end = int(
|
||||||
|
au.time_to_x_coords(
|
||||||
|
end_time,
|
||||||
|
samplerate,
|
||||||
|
params["fft_win_length"],
|
||||||
|
params["fft_overlap"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
num_freq_bins = spec.shape[0]
|
||||||
|
y_low = int(num_freq_bins * low_freq / max_freq)
|
||||||
|
y_high = int(num_freq_bins * high_freq / max_freq)
|
||||||
|
|
||||||
|
prediction: types.Prediction = {
|
||||||
|
"det_prob": 1,
|
||||||
|
"class_prob": np.ones((1,)),
|
||||||
|
"x_pos": x_start,
|
||||||
|
"y_pos": int(y_high),
|
||||||
|
"bb_width": int(x_end - x_start),
|
||||||
|
"bb_height": int(y_high - y_low),
|
||||||
|
"start_time": start_time,
|
||||||
|
"end_time": end_time,
|
||||||
|
"low_freq": low_freq,
|
||||||
|
"high_freq": high_freq,
|
||||||
|
}
|
||||||
|
|
||||||
|
computed_max_power = feats.compute_max_power(
|
||||||
|
prediction,
|
||||||
|
spec,
|
||||||
|
min_freq=min_freq,
|
||||||
|
max_freq=max_freq,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert abs(computed_max_power - max_power) < 100
|
Loading…
Reference in New Issue
Block a user