Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 14:41:58 +02:00)

Commit a8f64d172b: Merge branch 'main' into train
.github/workflows/python-publish.yml (vendored, new file, +39)
@@ -0,0 +1,39 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python
      uses: actions/setup-python@v3
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
.gitignore (vendored, 10 changes)
@@ -65,7 +65,7 @@ ipython_config.py
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
.pdm-python

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

@@ -102,7 +102,11 @@ experiments/*
.virtual_documents
.ipynb_checkpoints
*.ipynb
!batdetect2_notebook.ipynb

# Batdetect Models [Include]
# Bump2version
.bumpversion.cfg

# DO Include
!batdetect2_notebook.ipynb
!batdetect2/models/*.pth.tar
!tests/data/*.wav
@@ -29,7 +29,7 @@ pip install batdetect2
```

Alternatively, download this code from the repository (by clicking on the green button on top right) and unzip it.
Once unziped, run this from extracted folder.
Once unzipped, run this from extracted folder.

```bash
pip install .

@@ -98,7 +98,7 @@ You can integrate the detections or the extracted features to your custom analys


## Training the model on your own data
Take a look at the steps outlined in fintuning readme [here](bat_detect/finetune/readme.md) for a description of how to train your own model.
Take a look at the steps outlined in finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model.


## Data and annotations
@@ -1 +1 @@
__version__ = '1.0.0'
__version__ = '1.0.7'
@@ -5,6 +5,7 @@ import click

from batdetect2 import api
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import ProcessingConfiguration
from batdetect2.utils.detector_utils import save_results_to_file

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

@@ -77,6 +78,7 @@ def detect(
    audio_dir: str,
    ann_dir: str,
    detection_threshold: float,
    time_expansion_factor: int,
    **args,
):
    """Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.

@@ -103,16 +105,23 @@ def detect(
        **{
            **params,
            **args,
            "time_expansion": time_expansion_factor,
            "spec_slices": False,
            "chunk_size": 2,
            "detection_threshold": detection_threshold,
        }
    )

    if not args["quiet"]:
        print_config(config)

    # process files
    error_files = []
    for audio_file in files:
    for index, audio_file in enumerate(files):
        try:
            if not args["quiet"]:
                click.echo(f"\n{index} {audio_file}")

            results = api.process_file(audio_file, model, config=config)

            if args["save_preds_if_empty"] or (

@@ -133,5 +142,12 @@ def detect(
            click.echo(f" {err}")


def print_config(config: ProcessingConfiguration):
    """Print the processing configuration."""
    click.echo("\nProcessing Configuration:")
    click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
    click.echo(f"Detection Threshold: {config.get('detection_threshold')}")


if __name__ == "__main__":
    cli()
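For context, a minimal sketch of the Python-side flow the `detect` command wraps. This is not part of the diff: the audio and output paths are hypothetical, and it assumes batdetect2 is installed with `api.get_config`, `api.process_file`, and `save_results_to_file` accepting the arguments used in the CLI code above.

```python
# Hedged sketch: process one recording the way the CLI does for each file.
from batdetect2 import api
from batdetect2.utils.detector_utils import save_results_to_file

# Config keys mirror what the CLI passes (assumption: get_config merges kwargs
# over the default processing configuration).
config = api.get_config(
    detection_threshold=0.3,
    time_expansion=1,
    spec_slices=False,
    chunk_size=2,
)

# Hypothetical paths; api.process_file falls back to the bundled model here.
results = api.process_file("example_data/audio/recording.wav", config=config)
save_results_to_file(results, "results/recording.wav")
```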
@@ -1,22 +1,27 @@
"""Functions to compute features from predictions."""
from typing import Dict, Optional

import numpy as np

from batdetect2 import types
from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ


def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
    """Convert spectrogram index to frequency in Hz."""
    spec_ind = spec_height - spec_ind
    return round(
        (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
    )


def extract_spec_slices(spec, pred_nms, params):
    """
    Extracts spectrogram slices from spectrogram based on detected call locations.
    """
def extract_spec_slices(spec, pred_nms):
    """Extract spectrogram slices from spectrogram.

    The slices are extracted based on detected call locations.
    """
    x_pos = pred_nms["x_pos"]
    y_pos = pred_nms["y_pos"]
    bb_width = pred_nms["bb_width"]
    bb_height = pred_nms["bb_height"]
    slices = []

    # add 20% padding either side of call

@@ -35,100 +40,273 @@ def extract_spec_slices(spec, pred_nms, params):
    return slices


def get_feature_names():
    feature_names = [
        "duration",
        "low_freq_bb",
        "high_freq_bb",
        "bandwidth",
        "max_power_bb",
        "max_power",
        "max_power_first",
        "max_power_second",
        "call_interval",
    ]
    return feature_names
def compute_duration(
    prediction: types.Prediction,
    **_,
) -> float:
    """Compute duration of call in seconds."""
    return round(prediction["end_time"] - prediction["start_time"], 5)


def get_feats(spec, pred_nms, params):
def compute_low_freq(
    prediction: types.Prediction,
    **_,
) -> float:
    """Compute lowest frequency in call in Hz."""
    return int(prediction["low_freq"])


def compute_high_freq(
    prediction: types.Prediction,
    **_,
) -> float:
    """Compute highest frequency in call in Hz."""
    return int(prediction["high_freq"])


def compute_bandwidth(
    prediction: types.Prediction,
    **_,
) -> float:
    """Compute bandwidth of call in Hz."""
    return int(prediction["high_freq"] - prediction["low_freq"])


def compute_max_power_bb(
    prediction: types.Prediction,
    spec: Optional[np.ndarray] = None,
    min_freq: int = MIN_FREQ_HZ,
    max_freq: int = MAX_FREQ_HZ,
    **_,
) -> float:
    """Compute frequency with maximum power in call in Hz.

    This is the frequency with the maximum power in the bounding box of the
    call.
    """
    Extracts features from spectrogram based on detected call locations.
    Condsider re-extracting spectrogram for this to get better temporal resolution.
    if spec is None:
        return np.nan

    x_start = max(0, prediction["x_pos"])
    x_end = min(
        spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
    )

    # y low is the lowest freq but it will have a higher value due to array
    # starting at 0 at top
    y_low = min(spec.shape[0] - 1, prediction["y_pos"])
    y_high = max(0, prediction["y_pos"] - prediction["bb_height"])

    spec_bb = spec[y_high:y_low, x_start:x_end]
    power_per_freq_band = np.sum(spec_bb, axis=1)

    try:
        max_power_ind = np.argmax(power_per_freq_band)
    except ValueError:
        # If the call is too short, the bounding box might be empty.
        # In this case, return NaN.
        return np.nan

    return int(
        convert_int_to_freq(
            y_high + max_power_ind,
            spec.shape[0],
            min_freq,
            max_freq,
        )
    )


def compute_max_power(
    prediction: types.Prediction,
    spec: Optional[np.ndarray] = None,
    min_freq: int = MIN_FREQ_HZ,
    max_freq: int = MAX_FREQ_HZ,
    **_,
) -> float:
    """Compute frequency with maximum power in during the call in Hz."""
    if spec is None:
        return np.nan

    x_start = max(0, prediction["x_pos"])
    x_end = min(
        spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
    )
    spec_call = spec[:, x_start:x_end]
    power_per_freq_band = np.sum(spec_call, axis=1)
    max_power_ind = np.argmax(power_per_freq_band)
    return int(
        convert_int_to_freq(
            max_power_ind,
            spec.shape[0],
            min_freq,
            max_freq,
        )
    )


def compute_max_power_first(
    prediction: types.Prediction,
    spec: Optional[np.ndarray] = None,
    min_freq: int = MIN_FREQ_HZ,
    max_freq: int = MAX_FREQ_HZ,
    **_,
) -> float:
    """Compute frequency with maximum power in first half of call in Hz."""
    if spec is None:
        return np.nan

    x_start = max(0, prediction["x_pos"])
    x_end = min(
        spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
    )
    spec_call = spec[:, x_start:x_end]
    first_half = spec_call[:, : int(spec_call.shape[1] / 2)]
    power_per_freq_band = np.sum(first_half, axis=1)
    max_power_ind = np.argmax(power_per_freq_band)
    return int(
        convert_int_to_freq(
            max_power_ind,
            spec.shape[0],
            min_freq,
            max_freq,
        )
    )


def compute_max_power_second(
    prediction: types.Prediction,
    spec: Optional[np.ndarray] = None,
    min_freq: int = MIN_FREQ_HZ,
    max_freq: int = MAX_FREQ_HZ,
    **_,
) -> float:
    """Compute frequency with maximum power in second half of call in Hz."""
    if spec is None:
        return np.nan

    x_start = max(0, prediction["x_pos"])
    x_end = min(
        spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
    )
    spec_call = spec[:, x_start:x_end]
    second_half = spec_call[:, int(spec_call.shape[1] / 2) :]
    power_per_freq_band = np.sum(second_half, axis=1)
    max_power_ind = np.argmax(power_per_freq_band)
    return int(
        convert_int_to_freq(
            max_power_ind,
            spec.shape[0],
            min_freq,
            max_freq,
        )
    )


def compute_call_interval(
    prediction: types.Prediction,
    previous: Optional[types.Prediction] = None,
    **_,
) -> float:
    """Compute time between this call and the previous call in seconds."""
    if previous is None:
        return np.nan
    return round(prediction["start_time"] - previous["end_time"], 5)


# NOTE: The order of the features in this dictionary is important. The
# features are extracted in this order and the order of the columns in the
# output csv file is determined by this order. In order to avoid breaking
# changes in the output csv file, new features should be added to the end of
# this dictionary.
FEATURES: Dict[str, types.FeatureExtractor] = {
    "duration": compute_duration,
    "low_freq_bb": compute_low_freq,
    "high_freq_bb": compute_high_freq,
    "bandwidth": compute_bandwidth,
    "max_power_bb": compute_max_power_bb,
    "max_power": compute_max_power,
    "max_power_first": compute_max_power_first,
    "max_power_second": compute_max_power_second,
    "call_interval": compute_call_interval,
}


def get_feats(
    spec: np.ndarray,
    pred_nms: types.PredictionResults,
    params: types.FeatureExtractionParameters,
):
    """Extract features from spectrogram based on detected call locations.

    The features extracted are:

    - duration: duration of call in seconds
    - low_freq: lowest frequency in call in kHz
    - high_freq: highest frequency in call in kHz
    - bandwidth: high_freq - low_freq
    - max_power_bb: frequency with maximum power in call in kHz
    - max_power: frequency with maximum power in spectrogram in kHz
    - max_power_first: frequency with maximum power in first half of call in
      kHz.
    - max_power_second: frequency with maximum power in second half of call in
      kHz.
    - call_interval: time between this call and the previous call in seconds

    Consider re-extracting spectrogram for this to get better temporal
    resolution.

    For more possible features check out:
    https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt

    Parameters
    ----------
    spec : np.ndarray
        Spectrogram from which to extract features.

    pred_nms : types.PredictionResults
        Information about detected calls from which to extract features.

    params : types.FeatureExtractionParameters
        Parameters for feature extraction.

    Returns
    -------
    features : np.ndarray
        Extracted features for each detected call. Shape is
        (num_detections, num_features).
    """

    x_pos = pred_nms["x_pos"]
    y_pos = pred_nms["y_pos"]
    bb_width = pred_nms["bb_width"]
    bb_height = pred_nms["bb_height"]

    feature_names = get_feature_names()
    num_detections = len(pred_nms["det_probs"])
    features = (
        np.ones((num_detections, len(feature_names)), dtype=np.float32) * -1
    )
    features = np.empty((num_detections, len(FEATURES)), dtype=np.float32)
    previous = None

    for ff in range(num_detections):
        x_start = int(np.maximum(0, x_pos[ff]))
        x_end = int(
            np.minimum(spec.shape[1] - 1, np.round(x_pos[ff] + bb_width[ff]))
        )
        # y low is the lowest freq but it will have a higher value due to array starting at 0 at top
        y_low = int(np.minimum(spec.shape[0] - 1, y_pos[ff]))
        y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff])))
        spec_slice = spec[:, x_start:x_end]
    for row in range(num_detections):
        prediction: types.Prediction = {
            "det_prob": float(pred_nms["det_probs"][row]),
            "class_prob": pred_nms["class_probs"][:, row],
            "start_time": float(pred_nms["start_times"][row]),
            "end_time": float(pred_nms["end_times"][row]),
            "low_freq": float(pred_nms["low_freqs"][row]),
            "high_freq": float(pred_nms["high_freqs"][row]),
            "x_pos": int(pred_nms["x_pos"][row]),
            "y_pos": int(pred_nms["y_pos"][row]),
            "bb_width": int(pred_nms["bb_width"][row]),
            "bb_height": int(pred_nms["bb_height"][row]),
        }

        if spec_slice.shape[1] > 1:
            features[ff, 0] = round(
                pred_nms["end_times"][ff] - pred_nms["start_times"][ff], 5
            )
            features[ff, 1] = int(pred_nms["low_freqs"][ff])
            features[ff, 2] = int(pred_nms["high_freqs"][ff])
            features[ff, 3] = int(
                pred_nms["high_freqs"][ff] - pred_nms["low_freqs"][ff]
            )
            features[ff, 4] = int(
                convert_int_to_freq(
                    y_high + spec_slice[y_high:y_low, :].sum(1).argmax(),
                    spec.shape[0],
                    params["min_freq"],
                    params["max_freq"],
                )
            )
            features[ff, 5] = int(
                convert_int_to_freq(
                    spec_slice.sum(1).argmax(),
                    spec.shape[0],
                    params["min_freq"],
                    params["max_freq"],
                )
            )
            hlf_val = spec_slice.shape[1] // 2

            features[ff, 6] = int(
                convert_int_to_freq(
                    spec_slice[:, :hlf_val].sum(1).argmax(),
                    spec.shape[0],
                    params["min_freq"],
                    params["max_freq"],
                )
            )
            features[ff, 7] = int(
                convert_int_to_freq(
                    spec_slice[:, hlf_val:].sum(1).argmax(),
                    spec.shape[0],
                    params["min_freq"],
                    params["max_freq"],
                )
        for col, feature in enumerate(FEATURES.values()):
            features[row, col] = feature(
                prediction,
                previous=previous,
                spec=spec,
                **params,
            )

            if ff > 0:
                features[ff, 8] = round(
                    pred_nms["start_times"][ff]
                    - pred_nms["start_times"][ff - 1],
                    5,
                )
        previous = prediction

    return features


def get_feature_names():
    """Get names of features in the order they are extracted."""
    return list(FEATURES.keys())
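For context, a hedged usage sketch of the feature-extractor registry introduced above. It is not part of the diff; it assumes batdetect2 is installed and uses arbitrary example values for the prediction fields.

```python
# Hedged sketch: the extractor functions added above take a Prediction dict and
# ignore extra keyword arguments, and get_feature_names() returns the FEATURES
# keys in the column order used by get_feats().
import numpy as np
import batdetect2.detector.compute_features as feats

prediction = {
    "det_prob": 0.9,
    "class_prob": np.array([0.9]),
    "x_pos": 10,
    "y_pos": 80,
    "bb_width": 5,
    "bb_height": 20,
    "start_time": 0.10,
    "end_time": 0.12,
    "low_freq": 30_000.0,
    "high_freq": 60_000.0,
}

print(feats.get_feature_names())            # column order of get_feats output
print(feats.compute_duration(prediction))   # 0.02
print(feats.compute_bandwidth(prediction))  # 30000
# call_interval is NaN when there is no previous call
print(feats.compute_call_interval(prediction, previous=None))
```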
@@ -19,13 +19,15 @@ def x_coords_to_time(
) -> float:
    """Convert x coordinates of spectrogram to time in seconds.

    Args:
    Parameters
    ----------
        x_pos: X position of the detection in pixels.
        sampling_rate: Sampling rate of the audio in Hz.
        fft_win_length: Length of the FFT window in seconds.
        fft_overlap: Overlap of the FFT windows in seconds.

    Returns:
    Returns
    -------
        Time in seconds.
    """
    nfft = int(fft_win_length * sampling_rate)

@@ -134,12 +136,12 @@ def run_nms(
                y_pos[num_detection, valid_inds],
                x_pos[num_detection, valid_inds],
            ].transpose(0, 1)
            feat = feat.detach().numpy().astype(np.float32)
            feat = feat.detach().cpu().numpy().astype(np.float32)
            feats.append(feat)

        # convert to numpy
        for key, value in pred.items():
            pred[key] = value.detach().numpy().astype(np.float32)
            pred[key] = value.detach().cpu().numpy().astype(np.float32)

        preds.append(pred)  # type: ignore
@@ -61,7 +61,7 @@ def spectrogram(
    """
    # Convert to numpy array if needed
    if isinstance(spec, torch.Tensor):
        spec = spec.numpy()
        spec = spec.detach().cpu().numpy()

    # Remove batch and channel dimensions if present
    spec = spec.squeeze()

@@ -265,7 +265,7 @@ def detection(
    # Add class label
    txt = " ".join([sp[:3] for sp in det["class"].split(" ")])
    font_info = {
        "color": "white",
        "color": edgecolor,
        "size": 10,
        "weight": "bold",
        "alpha": rect.get_alpha(),
@@ -7,15 +7,14 @@ import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

from batdetect2.detector import models
from batdetect2.detector import parameters
from batdetect2.train import losses
import batdetect2.detector.post_process as pp
import batdetect2.train.audio_dataloader as adl
import batdetect2.train.evaluate as evl
import batdetect2.train.train_split as ts
import batdetect2.train.train_utils as tu
import batdetect2.utils.plot_utils as pu
from batdetect2.detector import models, parameters
from batdetect2.train import losses

warnings.filterwarnings("ignore", category=UserWarning)

@@ -84,7 +83,6 @@ def save_image(
def loss_fun(
    outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
):

    # detection loss
    loss = params["det_loss_weight"] * det_criterion(
        outputs["pred_det"], gt_det

@@ -108,7 +106,6 @@ def loss_fun(
def train(
    model, epoch, data_loader, det_criterion, optimizer, scheduler, params
):

    model.train()

    train_loss = tu.AverageMeter()

@@ -119,7 +116,6 @@ def train(

    print("\nEpoch", epoch)
    for batch_idx, inputs in enumerate(data_loader):

        data = inputs["spec"].to(params["device"])
        gt_det = inputs["y_2d_det"].to(params["device"])
        gt_size = inputs["y_2d_size"].to(params["device"])

@@ -172,7 +168,6 @@ def test(model, epoch, data_loader, det_criterion, params):

    with torch.no_grad():
        for batch_idx, inputs in enumerate(data_loader):

            data = inputs["spec"].to(params["device"])
            gt_det = inputs["y_2d_det"].to(params["device"])
            gt_size = inputs["y_2d_size"].to(params["device"])

@@ -279,7 +274,7 @@ def parse_gt_data(inputs):
        is_valid = inputs["is_valid"][ind] == 1
        gt = {}
        for kk in keys:
            gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32)
            gt[kk] = inputs[kk][ind][is_valid].cpu().numpy().astype(np.float32)
        gt["duration"] = inputs["duration"][ind].item()
        gt["file_id"] = inputs["file_id"][ind].item()
        gt["class_id_file"] = inputs["class_id_file"][ind].item()

@@ -318,8 +313,7 @@ def select_model(params):
    return model


if __name__ == "__main__":

def main():
    plt.close("all")

    params = parameters.get_params(True)

@@ -501,7 +495,6 @@ if __name__ == "__main__":
    #
    # main train loop
    for epoch in range(0, params["num_epochs"] + 1):

        train_loss = train(
            model,
            epoch,

@@ -550,3 +543,7 @@ if __name__ == "__main__":
    # TODO: args variable does not exist
    # if not args["do_not_save_images"]:
    #     save_images_batch(model, test_loader, params)


if __name__ == "__main__":
    main()
@@ -1,5 +1,5 @@
"""Types used in the code base."""
from typing import List, NamedTuple, Optional
from typing import List, NamedTuple, Optional, Union

import numpy as np
import torch

@@ -25,10 +25,13 @@ except ImportError:
__all__ = [
    "Annotation",
    "DetectionModel",
    "FeatureExtractionParameters",
    "FeatureExtractor",
    "FileAnnotations",
    "ModelOutput",
    "ModelParameters",
    "NonMaximumSuppressionConfig",
    "Prediction",
    "PredictionResults",
    "ProcessingConfiguration",
    "ResultParams",

@@ -316,6 +319,40 @@ class ModelOutput(NamedTuple):
    """Tensor with intermediate features."""


class Prediction(TypedDict):
    """Singe prediction."""

    det_prob: float
    """Detection probability."""

    x_pos: int
    """X position of the detection in pixels."""

    y_pos: int
    """Y position of the detection in pixels."""

    bb_width: int
    """Width of the detection in pixels."""

    bb_height: int
    """Height of the detection in pixels."""

    start_time: float
    """Start time of the detection in seconds."""

    end_time: float
    """End time of the detection in seconds."""

    low_freq: float
    """Low frequency of the detection in Hz."""

    high_freq: float
    """High frequency of the detection in Hz."""

    class_prob: np.ndarray
    """Vector holding the probability of each class."""


class PredictionResults(TypedDict):
    """Results of the prediction.

@@ -422,6 +459,16 @@ class NonMaximumSuppressionConfig(TypedDict):
    """Threshold for detection probability."""


class FeatureExtractionParameters(TypedDict):
    """Parameters that control the feature extraction function."""

    min_freq: int
    """Minimum frequency to consider in Hz."""

    max_freq: int
    """Maximum frequency to consider in Hz."""


class HeatmapParameters(TypedDict):
    """Parameters that control the heatmap generation function."""
@@ -2,6 +2,7 @@ import json
import os
from typing import Any, Iterator, List, Optional, Tuple, Union

import librosa
import numpy as np
import pandas as pd
import torch

@@ -66,7 +67,6 @@ def list_audio_files(ip_dir: str) -> List[str]:

    Raises:
        FileNotFoundError: Input directory not found.

    """
    matches = []
    for root, _, filenames in os.walk(ip_dir):

@@ -143,7 +143,19 @@ def load_model(


def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
    predictions_m = {}
    predictions_m = {
        "det_probs": np.array([]),
        "x_pos": np.array([]),
        "y_pos": np.array([]),
        "bb_widths": np.array([]),
        "bb_heights": np.array([]),
        "start_times": np.array([]),
        "end_times": np.array([]),
        "low_freqs": np.array([]),
        "high_freqs": np.array([]),
        "class_probs": np.array([]),
    }

    num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])

    if num_preds > 0:

@@ -151,10 +163,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
            predictions_m[key] = np.hstack(
                [pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
            )
    else:
        # hack in case where no detected calls as we need some of the key
        # names in dict
        predictions_m = predictions[0]

    if len(spec_feats) > 0:
        spec_feats = np.vstack(spec_feats)

@@ -226,11 +234,19 @@ def format_single_result(
    Returns:
        dict: Results in the format expected by the annotation tool.
    """
    # Get a single class prediction for the file
    class_overall = pp.overall_class_pred(
        predictions["det_probs"],
        predictions["class_probs"],
    )
    try:
        # Get a single class prediction for the file
        class_overall = pp.overall_class_pred(
            predictions["det_probs"],
            predictions["class_probs"],
        )
        class_name = class_names[np.argmax(class_overall)]
        annotations = get_annotations_from_preds(predictions, class_names)
    except (np.AxisError, ValueError):
        # No detections
        class_overall = np.zeros(len(class_names))
        class_name = "None"
        annotations = []

    return {
        "id": file_id,

@@ -239,8 +255,8 @@ def format_single_result(
        "notes": "Automatically generated.",
        "time_exp": time_exp,
        "duration": round(float(duration), 4),
        "annotation": get_annotations_from_preds(predictions, class_names),
        "class_name": class_names[np.argmax(class_overall)],
        "annotation": annotations,
        "class_name": class_name,
    }


@@ -253,6 +269,7 @@ def convert_results(
    spec_feats,
    cnn_feats,
    spec_slices,
    nyquist_freq: Optional[float] = None,
) -> RunResults:
    """Convert results to dictionary as expected by the annotation tool.

@@ -268,8 +285,8 @@ def convert_results(

    Returns:
        dict: Dictionary with results.

    """

    pred_dict = format_single_result(
        file_id,
        time_exp,

@@ -278,6 +295,14 @@ def convert_results(
        params["class_names"],
    )

    # Remove high frequency detections
    if nyquist_freq is not None:
        pred_dict["annotation"] = [
            pred
            for pred in pred_dict["annotation"]
            if pred["high_freq"] <= nyquist_freq
        ]

    # combine into final results dictionary
    results: RunResults = {
        "pred_dict": pred_dict,

@@ -310,7 +335,6 @@ def save_results_to_file(results, op_path: str) -> None:
    Args:
        results (dict): Results.
        op_path (str): Output path.

    """
    # make directory if it does not exist
    if not os.path.isdir(os.path.dirname(op_path)):

@@ -472,7 +496,6 @@ def iterate_over_chunks(
    chunk_start : float
        Start time of chunk in seconds.
    chunk : np.ndarray

    """
    nsamples = audio.shape[0]
    duration_full = nsamples / samplerate

@@ -678,7 +701,6 @@ def process_audio_array(
        The array is of shape (num_detections, num_features).
    spec : torch.Tensor
        Spectrogram of the audio used as input.

    """
    pred_nms, features, spec = _process_audio_array(
        audio,

@@ -730,6 +752,10 @@ def process_file(
    cnn_feats = []
    spec_slices = []

    # Get original sampling rate
    file_samp_rate = librosa.get_samplerate(audio_file)
    orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1

    # load audio file
    sampling_rate, audio_full = au.load_audio(
        audio_file,

@@ -757,7 +783,7 @@ def process_file(
        )

        # convert to numpy
        spec_np = spec.detach().cpu().numpy()
        spec_np = spec.detach().cpu().numpy().squeeze()

        # add chunk time to start and end times
        pred_nms["start_times"] += chunk_time

@@ -777,9 +803,7 @@ def process_file(

        if config["spec_slices"]:
            # FIX: This is not currently working. Returns empty slices
            spec_slices.extend(
                feats.extract_spec_slices(spec_np, pred_nms, config)
            )
            spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))

    # Merge results from chunks
    predictions, spec_feats, cnn_feats, spec_slices = _merge_results(

@@ -799,6 +823,7 @@ def process_file(
        spec_feats=spec_feats,
        cnn_feats=cnn_feats,
        spec_slices=spec_slices,
        nyquist_freq=orig_samp_rate / 2,
    )

    # summarize results
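For context, a small sketch of the Nyquist cap that the new `nyquist_freq` argument implements. It is not part of the diff; the audio path points at the test recording added in this commit, the annotation values are made up, and it assumes librosa is available.

```python
# Hedged sketch: detections cannot lie above half the original sampling rate,
# so annotations exceeding that limit are dropped, mirroring the added code.
import librosa

audio_file = "tests/data/20230322_172000_selec2.wav"  # recording added in this commit
time_expansion = 1

file_samp_rate = librosa.get_samplerate(audio_file)
orig_samp_rate = file_samp_rate * time_expansion or 1
nyquist_freq = orig_samp_rate / 2  # 96 kHz for a 192 kHz recording

annotations = [{"high_freq": 45_000}, {"high_freq": 120_000}]  # made-up detections
kept = [ann for ann in annotations if ann["high_freq"] <= nyquist_freq]
print(kept)  # only the 45 kHz detection survives
```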
pdm.lock (generated, 2 changes)

@@ -453,7 +453,7 @@ summary = "Backport of pathlib-compatible object wrapper for zip files"

[metadata]
lock_version = "4.1"
content_hash = "sha256:2401b930c14b3b7e107372f0103cccebff74691b6bcd54148d832ce847df5673"
content_hash = "sha256:667d4d2891fb85565cb04d84d0970eaac799bf272e3c4d7e4e6fea0b33c241fb"

[metadata.files]
"appdirs 1.4.4" = [
@@ -6,7 +6,7 @@ dev = [

[project]
name = "batdetect2"
version = "1.0.0"
version = "1.0.7"
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
authors = [
    { "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },

@@ -19,7 +19,7 @@ dependencies = [
    "pandas",
    "scikit-learn",
    "scipy",
    "torch<2",
    "torch>=1.13.1,<2",
    "torchaudio",
    "torchvision",
    "click",

@@ -53,10 +53,10 @@ requires = ["pdm-pep517>=1.0.0"]
build-backend = "pdm.pep517.api"

[project.scripts]
batdetect2 = "bat_detect.cli:cli"
batdetect2 = "batdetect2.cli:cli"

[tool.black]
line-length = 80
line-length = 79

[[tool.mypy.overrides]]
module = [
tests/data/20230322_172000_selec2.wav (new binary file, not shown)
@@ -1,11 +1,14 @@
"""Test bat detect module API."""

from pathlib import Path

import os
from glob import glob

import numpy as np
import torch
from torch import nn
import soundfile as sf

from batdetect2 import api

@@ -262,3 +265,20 @@ def test_process_file_with_spec_slices():
    assert "spec_slices" in results
    assert isinstance(results["spec_slices"], list)
    assert len(results["spec_slices"]) == len(detections)


def test_process_file_with_empty_predictions_does_not_fail(
    tmp_path: Path,
):
    """Test process file with empty predictions does not fail."""
    # Create empty file
    empty_file = tmp_path / "empty.wav"
    empty_wav = np.zeros((0, 1), dtype=np.float32)
    sf.write(empty_file, empty_wav, 256000)

    # Process file
    results = api.process_file(str(empty_file))

    assert results is not None
    assert len(results["pred_dict"]["annotation"]) == 0
@@ -1,5 +1,7 @@
"""Test the command line interface."""
from pathlib import Path
from click.testing import CliRunner
import pandas as pd

from batdetect2.cli import cli

@@ -42,3 +44,67 @@ def test_cli_detect_command_on_test_audio(tmp_path):
    assert results_dir.exists()
    assert len(list(results_dir.glob("*.csv"))) == 3
    assert len(list(results_dir.glob("*.json"))) == 3


def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
    """Test the detect command with a non-trivial time expansion factor."""
    results_dir = tmp_path / "results"

    # Remove results dir if it exists
    if results_dir.exists():
        results_dir.rmdir()

    runner = CliRunner()
    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
            "--time_expansion_factor",
            "10",
        ],
    )

    assert result.exit_code == 0
    assert 'Time Expansion Factor: 10' in result.stdout


def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
    """Test the detect command with the spec feature flag."""
    results_dir = tmp_path / "results"

    # Remove results dir if it exists
    if results_dir.exists():
        results_dir.rmdir()

    runner = CliRunner()
    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
            "--spec_features",
        ],
    )
    assert result.exit_code == 0
    assert results_dir.exists()

    csv_files = [path.name for path in results_dir.glob("*.csv")]

    expected_files = [
        "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
        "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
        "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv"
    ]

    for expected_file in expected_files:
        assert expected_file in csv_files

        df = pd.read_csv(results_dir / expected_file)
        assert not (df.duration == -1).any()
tests/test_detections.py (new file, +23)

@@ -0,0 +1,23 @@
"""Test suite to ensure that model detections are not incorrect."""

import os

from batdetect2 import api

DATA_DIR = os.path.join(os.path.dirname(__file__), "data")


def test_no_detections_above_nyquist():
    """Test that no detections are made above the nyquist frequency."""
    # Recording donated by @kdarras
    path = os.path.join(DATA_DIR, "20230322_172000_selec2.wav")

    # This recording has a sampling rate of 192 kHz
    nyquist = 192_000 / 2

    output = api.process_file(path)
    predictions = output["pred_dict"]
    assert len(predictions["annotation"]) != 0
    assert all(
        pred["high_freq"] < nyquist for pred in predictions["annotation"]
    )
tests/test_features.py (new file, +291)

@@ -0,0 +1,291 @@
"""Test suite for feature extraction functions."""

import logging

import librosa
import numpy as np
import pytest

import batdetect2.detector.compute_features as feats
from batdetect2 import api, types
from batdetect2.utils import audio_utils as au

numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)


def index_to_freq(
    index: int,
    spec_height: int,
    min_freq: int,
    max_freq: int,
) -> float:
    """Convert spectrogram index to frequency in Hz."""
    index = spec_height - index
    return round(
        (index / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
    )


def index_to_time(
    index: int,
    spec_width: int,
    spec_duration: float,
) -> float:
    """Convert spectrogram index to time in seconds."""
    return round((index / float(spec_width)) * spec_duration, 2)


def test_get_feats_function_with_empty_spectrogram():
    """Test get_feats function with empty spectrogram.

    This tests that the overall flow of the function works, even if the
    spectrogram is empty.
    """
    spec_duration = 3
    spec_width = 100
    spec_height = 100
    min_freq = 10_000
    max_freq = 120_000
    spectrogram = np.zeros((spec_height, spec_width))

    x_pos = 20
    y_pos = 80
    bb_width = 20
    bb_height = 20

    start_time = index_to_time(x_pos, spec_width, spec_duration)
    end_time = index_to_time(x_pos + bb_width, spec_width, spec_duration)
    low_freq = index_to_freq(y_pos, spec_height, min_freq, max_freq)
    high_freq = index_to_freq(
        y_pos - bb_height, spec_height, min_freq, max_freq
    )

    pred_nms: types.PredictionResults = {
        "det_probs": np.array([1]),
        "class_probs": np.array([[1]]),
        "x_pos": np.array([x_pos]),
        "y_pos": np.array([y_pos]),
        "bb_width": np.array([bb_width]),
        "bb_height": np.array([bb_height]),
        "start_times": np.array([start_time]),
        "end_times": np.array([end_time]),
        "low_freqs": np.array([low_freq]),
        "high_freqs": np.array([high_freq]),
    }

    params: types.FeatureExtractionParameters = {
        "min_freq": min_freq,
        "max_freq": max_freq,
    }

    features = feats.get_feats(spectrogram, pred_nms, params)
    assert low_freq < high_freq
    assert isinstance(features, np.ndarray)
    assert features.shape == (len(pred_nms["det_probs"]), 9)
    assert np.isclose(
        features[0],
        np.array(
            [
                end_time - start_time,
                low_freq,
                high_freq,
                high_freq - low_freq,
                high_freq,
                max_freq,
                max_freq,
                max_freq,
                np.nan,
            ]
        ),
        equal_nan=True,
    ).all()


@pytest.mark.parametrize(
    "max_power",
    [
        30_000,
        31_000,
        32_000,
        33_000,
        34_000,
        35_000,
        36_000,
        37_000,
        38_000,
        39_000,
        40_000,
    ],
)
def test_compute_max_power_bb(max_power: int):
    """Test compute_max_power_bb function."""
    duration = 1
    samplerate = 256_000
    min_freq = 0
    max_freq = 128_000

    start_time = 0.3
    end_time = 0.6
    low_freq = 30_000
    high_freq = 40_000

    audio = np.zeros((int(duration * samplerate),))

    # Add a signal during the time and frequency range of interest
    audio[
        int(start_time * samplerate) : int(end_time * samplerate)
    ] = 0.5 * librosa.tone(
        max_power, sr=samplerate, duration=end_time - start_time
    )

    # Add a more powerful signal outside frequency range of interest
    audio[
        int(start_time * samplerate) : int(end_time * samplerate)
    ] += 2 * librosa.tone(
        80_000, sr=samplerate, duration=end_time - start_time
    )

    params = api.get_config(
        min_freq=min_freq,
        max_freq=max_freq,
        target_samp_rate=samplerate,
    )

    spec, _ = au.generate_spectrogram(
        audio,
        samplerate,
        params,
    )

    x_start = int(
        au.time_to_x_coords(
            start_time,
            samplerate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
    )

    x_end = int(
        au.time_to_x_coords(
            end_time,
            samplerate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
    )

    num_freq_bins = spec.shape[0]
    y_low = num_freq_bins - int(num_freq_bins * low_freq / max_freq)
    y_high = num_freq_bins - int(num_freq_bins * high_freq / max_freq)

    prediction: types.Prediction = {
        "det_prob": 1,
        "class_prob": np.ones((1,)),
        "x_pos": x_start,
        "y_pos": int(y_low),
        "bb_width": int(x_end - x_start),
        "bb_height": int(y_low - y_high),
        "start_time": start_time,
        "end_time": end_time,
        "low_freq": low_freq,
        "high_freq": high_freq,
    }

    print(prediction)

    max_power_bb = feats.compute_max_power_bb(
        prediction,
        spec,
        min_freq=min_freq,
        max_freq=max_freq,
    )

    assert abs(max_power_bb - max_power) <= 500


def test_compute_max_power():
    """Test compute_max_power_bb function."""
    duration = 3
    samplerate = 16_000
    min_freq = 0
    max_freq = 8_000

    start_time = 1
    end_time = 2
    low_freq = 3_000
    high_freq = 4_000
    max_power = 5_000

    audio = np.zeros((int(duration * samplerate),))

    # Add a signal during the time and frequency range of interest
    audio[
        int(start_time * samplerate) : int(end_time * samplerate)
    ] = 0.5 * librosa.tone(
        3_500, sr=samplerate, duration=end_time - start_time
    )

    # Add a more powerful signal outside frequency range of interest
    audio[
        int(start_time * samplerate) : int(end_time * samplerate)
    ] += 2 * librosa.tone(
        max_power, sr=samplerate, duration=end_time - start_time
    )

    params = api.get_config(
        min_freq=min_freq,
        max_freq=max_freq,
        target_samp_rate=samplerate,
    )

    spec, _ = au.generate_spectrogram(
        audio,
        samplerate,
        params,
    )

    x_start = int(
        au.time_to_x_coords(
            start_time,
            samplerate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
    )

    x_end = int(
        au.time_to_x_coords(
            end_time,
            samplerate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
    )

    num_freq_bins = spec.shape[0]
    y_low = int(num_freq_bins * low_freq / max_freq)
    y_high = int(num_freq_bins * high_freq / max_freq)

    prediction: types.Prediction = {
        "det_prob": 1,
        "class_prob": np.ones((1,)),
        "x_pos": x_start,
        "y_pos": int(y_high),
        "bb_width": int(x_end - x_start),
        "bb_height": int(y_high - y_low),
        "start_time": start_time,
        "end_time": end_time,
        "low_freq": low_freq,
        "high_freq": high_freq,
    }

    computed_max_power = feats.compute_max_power(
        prediction,
        spec,
        min_freq=min_freq,
        max_freq=max_freq,
    )

    assert abs(computed_max_power - max_power) < 100