Merge pull request #16 from macaodha/fix/GH-15-spectrogram-features-computation

Fix/gh 15 spectrogram features computation
This commit is contained in:
Oisin Mac Aodha 2023-08-03 12:36:19 +01:00 committed by GitHub
commit 70877495d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 659 additions and 94 deletions

2
.gitignore vendored
View File

@ -65,7 +65,7 @@ ipython_config.py
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
.pdm-python
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

View File

@ -1,22 +1,27 @@
"""Functions to compute features from predictions."""
from typing import Dict, Optional
import numpy as np
from batdetect2 import types
from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ
def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
    """Map a spectrogram row index to a frequency in Hz.

    Row 0 is the top of the spectrogram and maps to ``max_freq``; row
    ``spec_height`` maps to ``min_freq``. The result is rounded to two
    decimal places.
    """
    # Rows are counted from the top, frequencies grow from the bottom.
    rows_from_bottom = spec_height - spec_ind
    fraction = rows_from_bottom / float(spec_height)
    return round(fraction * (max_freq - min_freq) + min_freq, 2)
def extract_spec_slices(spec, pred_nms, params):
"""
Extracts spectrogram slices from spectrogram based on detected call locations.
"""
def extract_spec_slices(spec, pred_nms):
"""Extract spectrogram slices from spectrogram.
The slices are extracted based on detected call locations.
"""
x_pos = pred_nms["x_pos"]
y_pos = pred_nms["y_pos"]
bb_width = pred_nms["bb_width"]
bb_height = pred_nms["bb_height"]
slices = []
# add 20% padding either side of call
@ -35,100 +40,273 @@ def extract_spec_slices(spec, pred_nms, params):
return slices
def get_feature_names():
    """Return the feature names in the order of the feature columns."""
    return [
        "duration",
        "low_freq_bb",
        "high_freq_bb",
        "bandwidth",
        "max_power_bb",
        "max_power",
        "max_power_first",
        "max_power_second",
        "call_interval",
    ]
def compute_duration(
    prediction: types.Prediction,
    **_,
) -> float:
    """Return the call duration in seconds, rounded to 5 decimal places.

    Extra keyword arguments are accepted and ignored.
    """
    elapsed = prediction["end_time"] - prediction["start_time"]
    return round(elapsed, 5)
def get_feats(spec, pred_nms, params):
def compute_low_freq(
    prediction: types.Prediction,
    **_,
) -> float:
    """Return the lowest frequency of the call in Hz, truncated to an int.

    Extra keyword arguments are accepted and ignored.
    """
    lowest = prediction["low_freq"]
    return int(lowest)
def compute_high_freq(
    prediction: types.Prediction,
    **_,
) -> float:
    """Return the highest frequency of the call in Hz, truncated to an int.

    Extra keyword arguments are accepted and ignored.
    """
    highest = prediction["high_freq"]
    return int(highest)
def compute_bandwidth(
    prediction: types.Prediction,
    **_,
) -> float:
    """Return the bandwidth of the call in Hz, truncated to an int.

    The bandwidth is the difference between the highest and the lowest
    frequency of the call. Extra keyword arguments are accepted and ignored.
    """
    span = prediction["high_freq"] - prediction["low_freq"]
    return int(span)
def compute_max_power_bb(
    prediction: types.Prediction,
    spec: Optional[np.ndarray] = None,
    min_freq: int = MIN_FREQ_HZ,
    max_freq: int = MAX_FREQ_HZ,
    **_,
) -> float:
    """Compute frequency with maximum power in call in Hz.

    This is the frequency with the maximum power in the bounding box of the
    call.

    Parameters
    ----------
    prediction : types.Prediction
        Detection whose bounding box (``x_pos``, ``y_pos``, ``bb_width`` and
        ``bb_height``) delimits the region of the spectrogram searched.
    spec : np.ndarray, optional
        Spectrogram indexed as ``spec[row, column]``, with row 0 at the top
        (highest frequency).
    min_freq : int
        Frequency in Hz mapped to the bottom of the spectrogram.
    max_freq : int
        Frequency in Hz mapped to the top of the spectrogram.

    Returns
    -------
    float
        Frequency in Hz (truncated to an integer) of the row of the bounding
        box with the largest summed power, or NaN if no spectrogram is given
        or the bounding box contains no spectrogram cells.
    """
    if spec is None:
        return np.nan

    x_start = max(0, prediction["x_pos"])
    x_end = min(
        spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
    )

    # y low is the lowest freq but it will have a higher value due to array
    # starting at 0 at top
    y_low = min(spec.shape[0] - 1, prediction["y_pos"])
    y_high = max(0, prediction["y_pos"] - prediction["bb_height"])

    spec_bb = spec[y_high:y_low, x_start:x_end]

    # If the call is too short or too narrow, the bounding box might be
    # empty. A box with rows but no columns would otherwise sum to all zeros
    # and argmax would silently report the top row of the box, so check
    # explicitly instead of relying on argmax raising.
    if spec_bb.size == 0:
        return np.nan

    power_per_freq_band = np.sum(spec_bb, axis=1)
    max_power_ind = np.argmax(power_per_freq_band)

    return int(
        convert_int_to_freq(
            y_high + max_power_ind,
            spec.shape[0],
            min_freq,
            max_freq,
        )
    )
def compute_max_power(
    prediction: types.Prediction,
    spec: Optional[np.ndarray] = None,
    min_freq: int = MIN_FREQ_HZ,
    max_freq: int = MAX_FREQ_HZ,
    **_,
) -> float:
    """Compute frequency with maximum power during the call in Hz.

    All frequency rows of the spectrogram are considered, restricted to the
    columns spanned by the call (``x_pos`` to ``x_pos + bb_width``).

    Parameters
    ----------
    prediction : types.Prediction
        Detection whose ``x_pos`` and ``bb_width`` delimit the columns used.
    spec : np.ndarray, optional
        Spectrogram indexed as ``spec[row, column]``.
    min_freq : int
        Frequency in Hz mapped to the bottom of the spectrogram.
    max_freq : int
        Frequency in Hz mapped to the top of the spectrogram.

    Returns
    -------
    float
        Frequency in Hz (truncated to an integer) of the row with the largest
        summed power, or NaN if no spectrogram is given or the call spans no
        spectrogram columns.
    """
    if spec is None:
        return np.nan

    x_start = max(0, prediction["x_pos"])
    x_end = min(
        spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
    )
    spec_call = spec[:, x_start:x_end]

    # Summing over zero columns yields all zeros, and argmax would then
    # report the top row (max_freq) as if it were a measurement.
    if spec_call.size == 0:
        return np.nan

    power_per_freq_band = np.sum(spec_call, axis=1)
    max_power_ind = np.argmax(power_per_freq_band)

    return int(
        convert_int_to_freq(
            max_power_ind,
            spec.shape[0],
            min_freq,
            max_freq,
        )
    )
def compute_max_power_first(
    prediction: types.Prediction,
    spec: Optional[np.ndarray] = None,
    min_freq: int = MIN_FREQ_HZ,
    max_freq: int = MAX_FREQ_HZ,
    **_,
) -> float:
    """Compute frequency with maximum power in first half of call in Hz.

    The call spans the columns ``x_pos`` to ``x_pos + bb_width``; only the
    first half of those columns is used, over all frequency rows.

    Parameters
    ----------
    prediction : types.Prediction
        Detection whose ``x_pos`` and ``bb_width`` delimit the columns used.
    spec : np.ndarray, optional
        Spectrogram indexed as ``spec[row, column]``.
    min_freq : int
        Frequency in Hz mapped to the bottom of the spectrogram.
    max_freq : int
        Frequency in Hz mapped to the top of the spectrogram.

    Returns
    -------
    float
        Frequency in Hz (truncated to an integer) of the row with the largest
        summed power, or NaN if no spectrogram is given or the first half of
        the call contains no spectrogram columns.
    """
    if spec is None:
        return np.nan

    x_start = max(0, prediction["x_pos"])
    x_end = min(
        spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
    )
    spec_call = spec[:, x_start:x_end]
    first_half = spec_call[:, : int(spec_call.shape[1] / 2)]

    # A call that is at most one column wide has an empty first half. Its
    # row sums would all be zero and argmax would report the top row
    # (max_freq) as if it were a measurement.
    if first_half.size == 0:
        return np.nan

    power_per_freq_band = np.sum(first_half, axis=1)
    max_power_ind = np.argmax(power_per_freq_band)

    return int(
        convert_int_to_freq(
            max_power_ind,
            spec.shape[0],
            min_freq,
            max_freq,
        )
    )
def compute_max_power_second(
    prediction: types.Prediction,
    spec: Optional[np.ndarray] = None,
    min_freq: int = MIN_FREQ_HZ,
    max_freq: int = MAX_FREQ_HZ,
    **_,
) -> float:
    """Compute frequency with maximum power in second half of call in Hz.

    The call spans the columns ``x_pos`` to ``x_pos + bb_width``; only the
    second half of those columns is used, over all frequency rows.

    Parameters
    ----------
    prediction : types.Prediction
        Detection whose ``x_pos`` and ``bb_width`` delimit the columns used.
    spec : np.ndarray, optional
        Spectrogram indexed as ``spec[row, column]``.
    min_freq : int
        Frequency in Hz mapped to the bottom of the spectrogram.
    max_freq : int
        Frequency in Hz mapped to the top of the spectrogram.

    Returns
    -------
    float
        Frequency in Hz (truncated to an integer) of the row with the largest
        summed power, or NaN if no spectrogram is given or the second half of
        the call contains no spectrogram columns.
    """
    if spec is None:
        return np.nan

    x_start = max(0, prediction["x_pos"])
    x_end = min(
        spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
    )
    spec_call = spec[:, x_start:x_end]
    second_half = spec_call[:, int(spec_call.shape[1] / 2) :]

    # With no columns to sum, every row total is zero and argmax would
    # report the top row (max_freq) as if it were a measurement.
    if second_half.size == 0:
        return np.nan

    power_per_freq_band = np.sum(second_half, axis=1)
    max_power_ind = np.argmax(power_per_freq_band)

    return int(
        convert_int_to_freq(
            max_power_ind,
            spec.shape[0],
            min_freq,
            max_freq,
        )
    )
def compute_call_interval(
    prediction: types.Prediction,
    previous: Optional[types.Prediction] = None,
    **_,
) -> float:
    """Return the gap in seconds between the previous call and this one.

    The gap is measured from the end of the previous call to the start of
    this call and rounded to 5 decimal places. NaN is returned when there is
    no previous call.
    """
    if previous is not None:
        gap = prediction["start_time"] - previous["end_time"]
        return round(gap, 5)

    return np.nan
# NOTE: The order of the features in this dictionary is important. The
# features are extracted in this order and the order of the columns in the
# output csv file is determined by this order. In order to avoid breaking
# changes in the output csv file, new features should be added to the end of
# this dictionary.
#
# Every extractor is called by get_feats with the prediction plus the keyword
# arguments `previous`, `spec` and the unpacked feature extraction
# parameters, so an extractor must accept (and may ignore) keyword arguments
# it does not use.
FEATURES: Dict[str, types.FeatureExtractor] = {
    "duration": compute_duration,
    "low_freq_bb": compute_low_freq,
    "high_freq_bb": compute_high_freq,
    "bandwidth": compute_bandwidth,
    "max_power_bb": compute_max_power_bb,
    "max_power": compute_max_power,
    "max_power_first": compute_max_power_first,
    "max_power_second": compute_max_power_second,
    "call_interval": compute_call_interval,
}
def get_feats(
spec: np.ndarray,
pred_nms: types.PredictionResults,
params: types.FeatureExtractionParameters,
):
"""Extract features from spectrogram based on detected call locations.
The features extracted are:
- duration: duration of call in seconds
- low_freq_bb: lowest frequency in call in Hz
- high_freq_bb: highest frequency in call in Hz
- bandwidth: high_freq_bb - low_freq_bb, in Hz
- max_power_bb: frequency with maximum power within the bounding box of
the call in Hz
- max_power: frequency with maximum power over the time span of the call
in Hz
- max_power_first: frequency with maximum power in first half of call in
Hz.
- max_power_second: frequency with maximum power in second half of call in
Hz.
- call_interval: time between the end of the previous call and the start of
this call in seconds
Consider re-extracting spectrogram for this to get better temporal
resolution.
For more possible features check out:
https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt
Parameters
----------
spec : np.ndarray
Spectrogram from which to extract features.
pred_nms : types.PredictionResults
Information about detected calls from which to extract features.
params : types.FeatureExtractionParameters
Parameters for feature extraction.
Returns
-------
features : np.ndarray
Extracted features for each detected call. Shape is
(num_detections, num_features).
"""
x_pos = pred_nms["x_pos"]
y_pos = pred_nms["y_pos"]
bb_width = pred_nms["bb_width"]
bb_height = pred_nms["bb_height"]
feature_names = get_feature_names()
num_detections = len(pred_nms["det_probs"])
features = (
np.ones((num_detections, len(feature_names)), dtype=np.float32) * -1
)
features = np.empty((num_detections, len(FEATURES)), dtype=np.float32)
previous = None
for ff in range(num_detections):
x_start = int(np.maximum(0, x_pos[ff]))
x_end = int(
np.minimum(spec.shape[1] - 1, np.round(x_pos[ff] + bb_width[ff]))
)
# y low is the lowest freq but it will have a higher value due to array starting at 0 at top
y_low = int(np.minimum(spec.shape[0] - 1, y_pos[ff]))
y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff])))
spec_slice = spec[:, x_start:x_end]
for row in range(num_detections):
prediction: types.Prediction = {
"det_prob": float(pred_nms["det_probs"][row]),
"class_prob": pred_nms["class_probs"][:, row],
"start_time": float(pred_nms["start_times"][row]),
"end_time": float(pred_nms["end_times"][row]),
"low_freq": float(pred_nms["low_freqs"][row]),
"high_freq": float(pred_nms["high_freqs"][row]),
"x_pos": int(pred_nms["x_pos"][row]),
"y_pos": int(pred_nms["y_pos"][row]),
"bb_width": int(pred_nms["bb_width"][row]),
"bb_height": int(pred_nms["bb_height"][row]),
}
if spec_slice.shape[1] > 1:
features[ff, 0] = round(
pred_nms["end_times"][ff] - pred_nms["start_times"][ff], 5
)
features[ff, 1] = int(pred_nms["low_freqs"][ff])
features[ff, 2] = int(pred_nms["high_freqs"][ff])
features[ff, 3] = int(
pred_nms["high_freqs"][ff] - pred_nms["low_freqs"][ff]
)
features[ff, 4] = int(
convert_int_to_freq(
y_high + spec_slice[y_high:y_low, :].sum(1).argmax(),
spec.shape[0],
params["min_freq"],
params["max_freq"],
)
)
features[ff, 5] = int(
convert_int_to_freq(
spec_slice.sum(1).argmax(),
spec.shape[0],
params["min_freq"],
params["max_freq"],
)
)
hlf_val = spec_slice.shape[1] // 2
features[ff, 6] = int(
convert_int_to_freq(
spec_slice[:, :hlf_val].sum(1).argmax(),
spec.shape[0],
params["min_freq"],
params["max_freq"],
)
)
features[ff, 7] = int(
convert_int_to_freq(
spec_slice[:, hlf_val:].sum(1).argmax(),
spec.shape[0],
params["min_freq"],
params["max_freq"],
)
for col, feature in enumerate(FEATURES.values()):
features[row, col] = feature(
prediction,
previous=previous,
spec=spec,
**params,
)
if ff > 0:
features[ff, 8] = round(
pred_nms["start_times"][ff]
- pred_nms["start_times"][ff - 1],
5,
)
previous = prediction
return features
def get_feature_names():
    """Return the feature names in the order the features are extracted."""
    return list(FEATURES)

View File

@ -1,5 +1,5 @@
"""Types used in the code base."""
from typing import List, NamedTuple, Optional
from typing import List, NamedTuple, Optional, Union
import numpy as np
import torch
@ -25,10 +25,13 @@ except ImportError:
__all__ = [
"Annotation",
"DetectionModel",
"FeatureExtractionParameters",
"FeatureExtractor",
"FileAnnotations",
"ModelOutput",
"ModelParameters",
"NonMaximumSuppressionConfig",
"Prediction",
"PredictionResults",
"ProcessingConfiguration",
"ResultParams",
@ -312,6 +315,40 @@ class ModelOutput(NamedTuple):
"""Tensor with intermediate features."""
class Prediction(TypedDict):
    """Single prediction.

    Describes one detection: its probability, its bounding box in pixels,
    its extent in time and frequency, and its class probabilities.
    """

    det_prob: float
    """Detection probability."""

    x_pos: int
    """X position of the detection in pixels."""

    y_pos: int
    """Y position of the detection in pixels."""

    bb_width: int
    """Width of the detection in pixels."""

    bb_height: int
    """Height of the detection in pixels."""

    start_time: float
    """Start time of the detection in seconds."""

    end_time: float
    """End time of the detection in seconds."""

    low_freq: float
    """Low frequency of the detection in Hz."""

    high_freq: float
    """High frequency of the detection in Hz."""

    class_prob: np.ndarray
    """Vector holding the probability of each class."""
class PredictionResults(TypedDict):
"""Results of the prediction.
@ -418,6 +455,16 @@ class NonMaximumSuppressionConfig(TypedDict):
"""Threshold for detection probability."""
class FeatureExtractionParameters(TypedDict):
    """Parameters that control the feature extraction function.

    The two frequencies give the range that the rows of the spectrogram are
    mapped onto when a row index is converted to a frequency.
    """

    min_freq: int
    """Minimum frequency to consider in Hz."""

    max_freq: int
    """Maximum frequency to consider in Hz."""
class HeatmapParameters(TypedDict):
"""Parameters that control the heatmap generation function."""
@ -473,3 +520,11 @@ class AnnotationGroup(TypedDict):
y_inds: NotRequired[np.ndarray]
"""Y coordinate of the annotations in the spectrogram."""
class FeatureExtractor(Protocol):
    """Protocol for feature extractors.

    A feature extractor is a callable that receives a single prediction,
    plus arbitrary keyword arguments, and returns one numeric value.
    """

    def __call__(self, prediction: Prediction, **kwargs) -> Union[float, int]:
        """Extract a single feature value from a prediction."""
        ...

View File

@ -773,7 +773,7 @@ def process_file(
)
# convert to numpy
spec_np = spec.detach().cpu().numpy()
spec_np = spec.detach().cpu().numpy().squeeze()
# add chunk time to start and end times
pred_nms["start_times"] += chunk_time
@ -794,7 +794,7 @@ def process_file(
if config["spec_slices"]:
# FIX: This is not currently working. Returns empty slices
spec_slices.extend(
feats.extract_spec_slices(spec_np, pred_nms, config)
feats.extract_spec_slices(spec_np, pred_nms)
)
# Merge results from chunks

View File

@ -56,7 +56,7 @@ build-backend = "pdm.pep517.api"
batdetect2 = "batdetect2.cli:cli"
[tool.black]
line-length = 80
line-length = 79
[[tool.mypy.overrides]]
module = [

View File

@ -1,5 +1,7 @@
"""Test the command line interface."""
from pathlib import Path
from click.testing import CliRunner
import pandas as pd
from batdetect2.cli import cli
@ -67,3 +69,42 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
assert result.exit_code == 0
assert 'Time Expansion Factor: 10' in result.stdout
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
    """Check that ``detect --spec_features`` writes usable feature files.

    The command is run on the example audio and must produce one spectrogram
    feature csv per recording, none of which contains a duration of -1.
    """
    results_dir = tmp_path / "results"

    # Make sure the command itself has to create the output directory.
    if results_dir.exists():
        results_dir.rmdir()

    cli_args = [
        "detect",
        "example_data/audio",
        str(results_dir),
        "0.3",
        "--spec_features",
    ]
    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert results_dir.exists()

    produced = [csv_path.name for csv_path in results_dir.glob("*.csv")]

    expected = (
        "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
        "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
        "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
    )

    for filename in expected:
        assert filename in produced

        table = pd.read_csv(results_dir / filename)
        assert not (table.duration == -1).any()

291
tests/test_features.py Normal file
View File

@ -0,0 +1,291 @@
"""Test suite for feature extraction functions."""
import logging
import librosa
import numpy as np
import pytest
import batdetect2.detector.compute_features as feats
from batdetect2 import api, types
from batdetect2.utils import audio_utils as au
numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)
def index_to_freq(
    index: int,
    spec_height: int,
    min_freq: int,
    max_freq: int,
) -> float:
    """Map a spectrogram row index to a frequency in Hz.

    Row 0 maps to ``max_freq`` and row ``spec_height`` maps to ``min_freq``.
    The result is rounded to two decimal places.
    """
    rows_from_bottom = spec_height - index
    fraction = rows_from_bottom / float(spec_height)
    return round(fraction * (max_freq - min_freq) + min_freq, 2)
def index_to_time(
    index: int,
    spec_width: int,
    spec_duration: float,
) -> float:
    """Map a spectrogram column index to a time in seconds.

    Column 0 maps to 0 and column ``spec_width`` maps to ``spec_duration``.
    The result is rounded to two decimal places.
    """
    fraction = index / float(spec_width)
    return round(fraction * spec_duration, 2)
def test_get_feats_function_with_empty_spectrogram():
    """Test get_feats function with an all-zero spectrogram.

    This tests that the overall flow of the function works, even if the
    spectrogram contains no signal. A single detection is described and the
    whole feature row is compared against hand-computed values.
    """
    spec_duration = 3
    spec_width = 100
    spec_height = 100
    min_freq = 10_000
    max_freq = 120_000
    spectrogram = np.zeros((spec_height, spec_width))

    # Bounding box of the single detection, in spectrogram pixels.
    x_pos = 20
    y_pos = 80
    bb_width = 20
    bb_height = 20

    # Time and frequency extent matching the bounding box above.
    start_time = index_to_time(x_pos, spec_width, spec_duration)
    end_time = index_to_time(x_pos + bb_width, spec_width, spec_duration)
    low_freq = index_to_freq(y_pos, spec_height, min_freq, max_freq)
    high_freq = index_to_freq(
        y_pos - bb_height, spec_height, min_freq, max_freq
    )

    pred_nms: types.PredictionResults = {
        "det_probs": np.array([1]),
        "class_probs": np.array([[1]]),
        "x_pos": np.array([x_pos]),
        "y_pos": np.array([y_pos]),
        "bb_width": np.array([bb_width]),
        "bb_height": np.array([bb_height]),
        "start_times": np.array([start_time]),
        "end_times": np.array([end_time]),
        "low_freqs": np.array([low_freq]),
        "high_freqs": np.array([high_freq]),
    }

    params: types.FeatureExtractionParameters = {
        "min_freq": min_freq,
        "max_freq": max_freq,
    }

    features = feats.get_feats(spectrogram, pred_nms, params)

    assert low_freq < high_freq
    assert isinstance(features, np.ndarray)
    assert features.shape == (len(pred_nms["det_probs"]), 9)

    # With an all-zero spectrogram every row sum ties at zero, so argmax
    # picks the first row it looks at: the top row of the bounding box for
    # max_power_bb (high_freq) and the top row of the spectrogram for the
    # other max power features (max_freq). There is no previous call, so the
    # call interval is NaN.
    assert np.isclose(
        features[0],
        np.array(
            [
                end_time - start_time,
                low_freq,
                high_freq,
                high_freq - low_freq,
                high_freq,
                max_freq,
                max_freq,
                max_freq,
                np.nan,
            ]
        ),
        equal_nan=True,
    ).all()
@pytest.mark.parametrize(
    "max_power",
    [
        30_000,
        31_000,
        32_000,
        33_000,
        34_000,
        35_000,
        36_000,
        37_000,
        38_000,
        39_000,
        40_000,
    ],
)
def test_compute_max_power_bb(max_power: int):
    """Test compute_max_power_bb function.

    A tone at ``max_power`` Hz is placed inside the bounding box together
    with a louder tone at 80 kHz outside of its frequency range. The feature
    must report the tone inside the box, not the louder one outside.
    """
    duration = 1
    samplerate = 256_000
    min_freq = 0
    max_freq = 128_000

    start_time = 0.3
    end_time = 0.6
    low_freq = 30_000
    high_freq = 40_000

    audio = np.zeros((int(duration * samplerate),))

    # Add a signal during the time and frequency range of interest
    audio[
        int(start_time * samplerate) : int(end_time * samplerate)
    ] = 0.5 * librosa.tone(
        max_power, sr=samplerate, duration=end_time - start_time
    )

    # Add a more powerful signal outside frequency range of interest
    audio[
        int(start_time * samplerate) : int(end_time * samplerate)
    ] += 2 * librosa.tone(
        80_000, sr=samplerate, duration=end_time - start_time
    )

    params = api.get_config(
        min_freq=min_freq,
        max_freq=max_freq,
        target_samp_rate=samplerate,
    )

    spec, _ = au.generate_spectrogram(
        audio,
        samplerate,
        params,
    )

    x_start = int(
        au.time_to_x_coords(
            start_time,
            samplerate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
    )
    x_end = int(
        au.time_to_x_coords(
            end_time,
            samplerate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
    )

    # Row 0 is the top of the spectrogram, so frequencies are flipped when
    # converted to row indices.
    num_freq_bins = spec.shape[0]
    y_low = num_freq_bins - int(num_freq_bins * low_freq / max_freq)
    y_high = num_freq_bins - int(num_freq_bins * high_freq / max_freq)

    prediction: types.Prediction = {
        "det_prob": 1,
        "class_prob": np.ones((1,)),
        "x_pos": x_start,
        "y_pos": int(y_low),
        "bb_width": int(x_end - x_start),
        "bb_height": int(y_low - y_high),
        "start_time": start_time,
        "end_time": end_time,
        "low_freq": low_freq,
        "high_freq": high_freq,
    }

    max_power_bb = feats.compute_max_power_bb(
        prediction,
        spec,
        min_freq=min_freq,
        max_freq=max_freq,
    )

    assert abs(max_power_bb - max_power) <= 500
def test_compute_max_power():
    """Test compute_max_power function.

    A tone at 3.5 kHz is placed inside the frequency range of the detection
    together with a louder tone at ``max_power`` Hz outside of it. Since
    compute_max_power looks at all frequencies during the call, it must
    report the louder tone.
    """
    duration = 3
    samplerate = 16_000
    min_freq = 0
    max_freq = 8_000

    start_time = 1
    end_time = 2
    low_freq = 3_000
    high_freq = 4_000
    max_power = 5_000

    audio = np.zeros((int(duration * samplerate),))

    # Add a signal during the time and frequency range of interest
    audio[
        int(start_time * samplerate) : int(end_time * samplerate)
    ] = 0.5 * librosa.tone(
        3_500, sr=samplerate, duration=end_time - start_time
    )

    # Add a more powerful signal outside frequency range of interest
    audio[
        int(start_time * samplerate) : int(end_time * samplerate)
    ] += 2 * librosa.tone(
        max_power, sr=samplerate, duration=end_time - start_time
    )

    params = api.get_config(
        min_freq=min_freq,
        max_freq=max_freq,
        target_samp_rate=samplerate,
    )

    spec, _ = au.generate_spectrogram(
        audio,
        samplerate,
        params,
    )

    x_start = int(
        au.time_to_x_coords(
            start_time,
            samplerate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
    )
    x_end = int(
        au.time_to_x_coords(
            end_time,
            samplerate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
    )

    # NOTE(review): unlike test_compute_max_power_bb these row indices are
    # not flipped against the spectrogram height. compute_max_power only
    # reads x_pos and bb_width, so the y values do not affect this test.
    num_freq_bins = spec.shape[0]
    y_low = int(num_freq_bins * low_freq / max_freq)
    y_high = int(num_freq_bins * high_freq / max_freq)

    prediction: types.Prediction = {
        "det_prob": 1,
        "class_prob": np.ones((1,)),
        "x_pos": x_start,
        "y_pos": int(y_high),
        "bb_width": int(x_end - x_start),
        "bb_height": int(y_high - y_low),
        "start_time": start_time,
        "end_time": end_time,
        "low_freq": low_freq,
        "high_freq": high_freq,
    }

    computed_max_power = feats.compute_max_power(
        prediction,
        spec,
        min_freq=min_freq,
        max_freq=max_freq,
    )

    assert abs(computed_max_power - max_power) < 100