mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Merge branch 'main' into train
This commit is contained in:
commit
a8f64d172b
39
.github/workflows/python-publish.yml
vendored
Normal file
39
.github/workflows/python-publish.yml
vendored
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
# This workflow will upload a Python Package using Twine when a release is created
|
||||||
|
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
|
||||||
|
|
||||||
|
# This workflow uses actions that are not certified by GitHub.
|
||||||
|
# They are provided by a third-party and are governed by
|
||||||
|
# separate terms of service, privacy policy, and support
|
||||||
|
# documentation.
|
||||||
|
|
||||||
|
name: Upload Python Package
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types: [published]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
deploy:
|
||||||
|
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v3
|
||||||
|
with:
|
||||||
|
python-version: '3.x'
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install build
|
||||||
|
- name: Build package
|
||||||
|
run: python -m build
|
||||||
|
- name: Publish package
|
||||||
|
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
||||||
|
with:
|
||||||
|
user: __token__
|
||||||
|
password: ${{ secrets.PYPI_API_TOKEN }}
|
10
.gitignore
vendored
10
.gitignore
vendored
@ -65,7 +65,7 @@ ipython_config.py
|
|||||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
# in version control.
|
# in version control.
|
||||||
# https://pdm.fming.dev/#use-with-ide
|
# https://pdm.fming.dev/#use-with-ide
|
||||||
.pdm.toml
|
.pdm-python
|
||||||
|
|
||||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
__pypackages__/
|
__pypackages__/
|
||||||
@ -102,7 +102,11 @@ experiments/*
|
|||||||
.virtual_documents
|
.virtual_documents
|
||||||
.ipynb_checkpoints
|
.ipynb_checkpoints
|
||||||
*.ipynb
|
*.ipynb
|
||||||
!batdetect2_notebook.ipynb
|
|
||||||
|
|
||||||
# Batdetect Models [Include]
|
# Bump2version
|
||||||
|
.bumpversion.cfg
|
||||||
|
|
||||||
|
# DO Include
|
||||||
|
!batdetect2_notebook.ipynb
|
||||||
!batdetect2/models/*.pth.tar
|
!batdetect2/models/*.pth.tar
|
||||||
|
!tests/data/*.wav
|
||||||
|
@ -29,7 +29,7 @@ pip install batdetect2
|
|||||||
```
|
```
|
||||||
|
|
||||||
Alternatively, download this code from the repository (by clicking on the green button on top right) and unzip it.
|
Alternatively, download this code from the repository (by clicking on the green button on top right) and unzip it.
|
||||||
Once unziped, run this from extracted folder.
|
Once unzipped, run this from the extracted folder.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install .
|
pip install .
|
||||||
@ -98,7 +98,7 @@ You can integrate the detections or the extracted features to your custom analys
|
|||||||
|
|
||||||
|
|
||||||
## Training the model on your own data
|
## Training the model on your own data
|
||||||
Take a look at the steps outlined in fintuning readme [here](bat_detect/finetune/readme.md) for a description of how to train your own model.
|
Take a look at the steps outlined in the finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model.
|
||||||
|
|
||||||
|
|
||||||
## Data and annotations
|
## Data and annotations
|
||||||
|
@ -1 +1 @@
|
|||||||
__version__ = '1.0.0'
|
__version__ = '1.0.7'
|
||||||
|
@ -5,6 +5,7 @@ import click
|
|||||||
|
|
||||||
from batdetect2 import api
|
from batdetect2 import api
|
||||||
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
|
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
|
||||||
|
from batdetect2.types import ProcessingConfiguration
|
||||||
from batdetect2.utils.detector_utils import save_results_to_file
|
from batdetect2.utils.detector_utils import save_results_to_file
|
||||||
|
|
||||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
@ -77,6 +78,7 @@ def detect(
|
|||||||
audio_dir: str,
|
audio_dir: str,
|
||||||
ann_dir: str,
|
ann_dir: str,
|
||||||
detection_threshold: float,
|
detection_threshold: float,
|
||||||
|
time_expansion_factor: int,
|
||||||
**args,
|
**args,
|
||||||
):
|
):
|
||||||
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
|
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
|
||||||
@ -103,16 +105,23 @@ def detect(
|
|||||||
**{
|
**{
|
||||||
**params,
|
**params,
|
||||||
**args,
|
**args,
|
||||||
|
"time_expansion": time_expansion_factor,
|
||||||
"spec_slices": False,
|
"spec_slices": False,
|
||||||
"chunk_size": 2,
|
"chunk_size": 2,
|
||||||
"detection_threshold": detection_threshold,
|
"detection_threshold": detection_threshold,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not args["quiet"]:
|
||||||
|
print_config(config)
|
||||||
|
|
||||||
# process files
|
# process files
|
||||||
error_files = []
|
error_files = []
|
||||||
for audio_file in files:
|
for index, audio_file in enumerate(files):
|
||||||
try:
|
try:
|
||||||
|
if not args["quiet"]:
|
||||||
|
click.echo(f"\n{index} {audio_file}")
|
||||||
|
|
||||||
results = api.process_file(audio_file, model, config=config)
|
results = api.process_file(audio_file, model, config=config)
|
||||||
|
|
||||||
if args["save_preds_if_empty"] or (
|
if args["save_preds_if_empty"] or (
|
||||||
@ -133,5 +142,12 @@ def detect(
|
|||||||
click.echo(f" {err}")
|
click.echo(f" {err}")
|
||||||
|
|
||||||
|
|
||||||
|
def print_config(config: ProcessingConfiguration):
|
||||||
|
"""Print the processing configuration."""
|
||||||
|
click.echo("\nProcessing Configuration:")
|
||||||
|
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
|
||||||
|
click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
|
@ -1,22 +1,27 @@
|
|||||||
|
"""Functions to compute features from predictions."""
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from batdetect2 import types
|
||||||
|
from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ
|
||||||
|
|
||||||
|
|
||||||
def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
|
def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
|
||||||
|
"""Convert spectrogram index to frequency in Hz."""
|
||||||
spec_ind = spec_height - spec_ind
|
spec_ind = spec_height - spec_ind
|
||||||
return round(
|
return round(
|
||||||
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
|
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_spec_slices(spec, pred_nms, params):
|
def extract_spec_slices(spec, pred_nms):
|
||||||
"""
|
"""Extract spectrogram slices from spectrogram.
|
||||||
Extracts spectrogram slices from spectrogram based on detected call locations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
The slices are extracted based on detected call locations.
|
||||||
|
"""
|
||||||
x_pos = pred_nms["x_pos"]
|
x_pos = pred_nms["x_pos"]
|
||||||
y_pos = pred_nms["y_pos"]
|
|
||||||
bb_width = pred_nms["bb_width"]
|
bb_width = pred_nms["bb_width"]
|
||||||
bb_height = pred_nms["bb_height"]
|
|
||||||
slices = []
|
slices = []
|
||||||
|
|
||||||
# add 20% padding either side of call
|
# add 20% padding either side of call
|
||||||
@ -35,100 +40,273 @@ def extract_spec_slices(spec, pred_nms, params):
|
|||||||
return slices
|
return slices
|
||||||
|
|
||||||
|
|
||||||
def get_feature_names():
|
def compute_duration(
|
||||||
feature_names = [
|
prediction: types.Prediction,
|
||||||
"duration",
|
**_,
|
||||||
"low_freq_bb",
|
) -> float:
|
||||||
"high_freq_bb",
|
"""Compute duration of call in seconds."""
|
||||||
"bandwidth",
|
return round(prediction["end_time"] - prediction["start_time"], 5)
|
||||||
"max_power_bb",
|
|
||||||
"max_power",
|
|
||||||
"max_power_first",
|
|
||||||
"max_power_second",
|
|
||||||
"call_interval",
|
|
||||||
]
|
|
||||||
return feature_names
|
|
||||||
|
|
||||||
|
|
||||||
def get_feats(spec, pred_nms, params):
|
def compute_low_freq(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute lowest frequency in call in Hz."""
|
||||||
|
return int(prediction["low_freq"])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_high_freq(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute highest frequency in call in Hz."""
|
||||||
|
return int(prediction["high_freq"])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_bandwidth(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute bandwidth of call in Hz."""
|
||||||
|
return int(prediction["high_freq"] - prediction["low_freq"])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_max_power_bb(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
spec: Optional[np.ndarray] = None,
|
||||||
|
min_freq: int = MIN_FREQ_HZ,
|
||||||
|
max_freq: int = MAX_FREQ_HZ,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute frequency with maximum power in call in Hz.
|
||||||
|
|
||||||
|
This is the frequency with the maximum power in the bounding box of the
|
||||||
|
call.
|
||||||
"""
|
"""
|
||||||
Extracts features from spectrogram based on detected call locations.
|
if spec is None:
|
||||||
Condsider re-extracting spectrogram for this to get better temporal resolution.
|
return np.nan
|
||||||
|
|
||||||
|
x_start = max(0, prediction["x_pos"])
|
||||||
|
x_end = min(
|
||||||
|
spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# y low is the lowest freq but it will have a higher value due to array
|
||||||
|
# starting at 0 at top
|
||||||
|
y_low = min(spec.shape[0] - 1, prediction["y_pos"])
|
||||||
|
y_high = max(0, prediction["y_pos"] - prediction["bb_height"])
|
||||||
|
|
||||||
|
spec_bb = spec[y_high:y_low, x_start:x_end]
|
||||||
|
power_per_freq_band = np.sum(spec_bb, axis=1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
max_power_ind = np.argmax(power_per_freq_band)
|
||||||
|
except ValueError:
|
||||||
|
# If the call is too short, the bounding box might be empty.
|
||||||
|
# In this case, return NaN.
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
return int(
|
||||||
|
convert_int_to_freq(
|
||||||
|
y_high + max_power_ind,
|
||||||
|
spec.shape[0],
|
||||||
|
min_freq,
|
||||||
|
max_freq,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_max_power(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
spec: Optional[np.ndarray] = None,
|
||||||
|
min_freq: int = MIN_FREQ_HZ,
|
||||||
|
max_freq: int = MAX_FREQ_HZ,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute frequency with maximum power during the call in Hz."""
|
||||||
|
if spec is None:
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
x_start = max(0, prediction["x_pos"])
|
||||||
|
x_end = min(
|
||||||
|
spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
|
||||||
|
)
|
||||||
|
spec_call = spec[:, x_start:x_end]
|
||||||
|
power_per_freq_band = np.sum(spec_call, axis=1)
|
||||||
|
max_power_ind = np.argmax(power_per_freq_band)
|
||||||
|
return int(
|
||||||
|
convert_int_to_freq(
|
||||||
|
max_power_ind,
|
||||||
|
spec.shape[0],
|
||||||
|
min_freq,
|
||||||
|
max_freq,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_max_power_first(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
spec: Optional[np.ndarray] = None,
|
||||||
|
min_freq: int = MIN_FREQ_HZ,
|
||||||
|
max_freq: int = MAX_FREQ_HZ,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute frequency with maximum power in first half of call in Hz."""
|
||||||
|
if spec is None:
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
x_start = max(0, prediction["x_pos"])
|
||||||
|
x_end = min(
|
||||||
|
spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
|
||||||
|
)
|
||||||
|
spec_call = spec[:, x_start:x_end]
|
||||||
|
first_half = spec_call[:, : int(spec_call.shape[1] / 2)]
|
||||||
|
power_per_freq_band = np.sum(first_half, axis=1)
|
||||||
|
max_power_ind = np.argmax(power_per_freq_band)
|
||||||
|
return int(
|
||||||
|
convert_int_to_freq(
|
||||||
|
max_power_ind,
|
||||||
|
spec.shape[0],
|
||||||
|
min_freq,
|
||||||
|
max_freq,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_max_power_second(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
spec: Optional[np.ndarray] = None,
|
||||||
|
min_freq: int = MIN_FREQ_HZ,
|
||||||
|
max_freq: int = MAX_FREQ_HZ,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute frequency with maximum power in second half of call in Hz."""
|
||||||
|
if spec is None:
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
x_start = max(0, prediction["x_pos"])
|
||||||
|
x_end = min(
|
||||||
|
spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
|
||||||
|
)
|
||||||
|
spec_call = spec[:, x_start:x_end]
|
||||||
|
second_half = spec_call[:, int(spec_call.shape[1] / 2) :]
|
||||||
|
power_per_freq_band = np.sum(second_half, axis=1)
|
||||||
|
max_power_ind = np.argmax(power_per_freq_band)
|
||||||
|
return int(
|
||||||
|
convert_int_to_freq(
|
||||||
|
max_power_ind,
|
||||||
|
spec.shape[0],
|
||||||
|
min_freq,
|
||||||
|
max_freq,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_call_interval(
|
||||||
|
prediction: types.Prediction,
|
||||||
|
previous: Optional[types.Prediction] = None,
|
||||||
|
**_,
|
||||||
|
) -> float:
|
||||||
|
"""Compute time between this call and the previous call in seconds."""
|
||||||
|
if previous is None:
|
||||||
|
return np.nan
|
||||||
|
return round(prediction["start_time"] - previous["end_time"], 5)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: The order of the features in this dictionary is important. The
|
||||||
|
# features are extracted in this order and the order of the columns in the
|
||||||
|
# output csv file is determined by this order. In order to avoid breaking
|
||||||
|
# changes in the output csv file, new features should be added to the end of
|
||||||
|
# this dictionary.
|
||||||
|
FEATURES: Dict[str, types.FeatureExtractor] = {
|
||||||
|
"duration": compute_duration,
|
||||||
|
"low_freq_bb": compute_low_freq,
|
||||||
|
"high_freq_bb": compute_high_freq,
|
||||||
|
"bandwidth": compute_bandwidth,
|
||||||
|
"max_power_bb": compute_max_power_bb,
|
||||||
|
"max_power": compute_max_power,
|
||||||
|
"max_power_first": compute_max_power_first,
|
||||||
|
"max_power_second": compute_max_power_second,
|
||||||
|
"call_interval": compute_call_interval,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_feats(
|
||||||
|
spec: np.ndarray,
|
||||||
|
pred_nms: types.PredictionResults,
|
||||||
|
params: types.FeatureExtractionParameters,
|
||||||
|
):
|
||||||
|
"""Extract features from spectrogram based on detected call locations.
|
||||||
|
|
||||||
|
The features extracted are:
|
||||||
|
|
||||||
|
- duration: duration of call in seconds
|
||||||
|
- low_freq_bb: lowest frequency in call in Hz
|
||||||
|
- high_freq_bb: highest frequency in call in Hz
|
||||||
|
- bandwidth: high_freq - low_freq
|
||||||
|
- max_power_bb: frequency with maximum power in the bounding box of the call in Hz
|
||||||
|
- max_power: frequency with maximum power during the call (all frequency bands) in Hz
|
||||||
|
- max_power_first: frequency with maximum power in first half of call in
|
||||||
|
Hz.
|
||||||
|
- max_power_second: frequency with maximum power in second half of call in
|
||||||
|
Hz.
|
||||||
|
- call_interval: time between this call and the previous call in seconds
|
||||||
|
|
||||||
|
Consider re-extracting spectrogram for this to get better temporal
|
||||||
|
resolution.
|
||||||
|
|
||||||
For more possible features check out:
|
For more possible features check out:
|
||||||
https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt
|
https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec : np.ndarray
|
||||||
|
Spectrogram from which to extract features.
|
||||||
|
|
||||||
|
pred_nms : types.PredictionResults
|
||||||
|
Information about detected calls from which to extract features.
|
||||||
|
|
||||||
|
params : types.FeatureExtractionParameters
|
||||||
|
Parameters for feature extraction.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
features : np.ndarray
|
||||||
|
Extracted features for each detected call. Shape is
|
||||||
|
(num_detections, num_features).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
x_pos = pred_nms["x_pos"]
|
|
||||||
y_pos = pred_nms["y_pos"]
|
|
||||||
bb_width = pred_nms["bb_width"]
|
|
||||||
bb_height = pred_nms["bb_height"]
|
|
||||||
|
|
||||||
feature_names = get_feature_names()
|
|
||||||
num_detections = len(pred_nms["det_probs"])
|
num_detections = len(pred_nms["det_probs"])
|
||||||
features = (
|
features = np.empty((num_detections, len(FEATURES)), dtype=np.float32)
|
||||||
np.ones((num_detections, len(feature_names)), dtype=np.float32) * -1
|
previous = None
|
||||||
)
|
|
||||||
|
|
||||||
for ff in range(num_detections):
|
for row in range(num_detections):
|
||||||
x_start = int(np.maximum(0, x_pos[ff]))
|
prediction: types.Prediction = {
|
||||||
x_end = int(
|
"det_prob": float(pred_nms["det_probs"][row]),
|
||||||
np.minimum(spec.shape[1] - 1, np.round(x_pos[ff] + bb_width[ff]))
|
"class_prob": pred_nms["class_probs"][:, row],
|
||||||
)
|
"start_time": float(pred_nms["start_times"][row]),
|
||||||
# y low is the lowest freq but it will have a higher value due to array starting at 0 at top
|
"end_time": float(pred_nms["end_times"][row]),
|
||||||
y_low = int(np.minimum(spec.shape[0] - 1, y_pos[ff]))
|
"low_freq": float(pred_nms["low_freqs"][row]),
|
||||||
y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff])))
|
"high_freq": float(pred_nms["high_freqs"][row]),
|
||||||
spec_slice = spec[:, x_start:x_end]
|
"x_pos": int(pred_nms["x_pos"][row]),
|
||||||
|
"y_pos": int(pred_nms["y_pos"][row]),
|
||||||
|
"bb_width": int(pred_nms["bb_width"][row]),
|
||||||
|
"bb_height": int(pred_nms["bb_height"][row]),
|
||||||
|
}
|
||||||
|
|
||||||
if spec_slice.shape[1] > 1:
|
for col, feature in enumerate(FEATURES.values()):
|
||||||
features[ff, 0] = round(
|
features[row, col] = feature(
|
||||||
pred_nms["end_times"][ff] - pred_nms["start_times"][ff], 5
|
prediction,
|
||||||
)
|
previous=previous,
|
||||||
features[ff, 1] = int(pred_nms["low_freqs"][ff])
|
spec=spec,
|
||||||
features[ff, 2] = int(pred_nms["high_freqs"][ff])
|
**params,
|
||||||
features[ff, 3] = int(
|
|
||||||
pred_nms["high_freqs"][ff] - pred_nms["low_freqs"][ff]
|
|
||||||
)
|
|
||||||
features[ff, 4] = int(
|
|
||||||
convert_int_to_freq(
|
|
||||||
y_high + spec_slice[y_high:y_low, :].sum(1).argmax(),
|
|
||||||
spec.shape[0],
|
|
||||||
params["min_freq"],
|
|
||||||
params["max_freq"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
features[ff, 5] = int(
|
|
||||||
convert_int_to_freq(
|
|
||||||
spec_slice.sum(1).argmax(),
|
|
||||||
spec.shape[0],
|
|
||||||
params["min_freq"],
|
|
||||||
params["max_freq"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
hlf_val = spec_slice.shape[1] // 2
|
|
||||||
|
|
||||||
features[ff, 6] = int(
|
|
||||||
convert_int_to_freq(
|
|
||||||
spec_slice[:, :hlf_val].sum(1).argmax(),
|
|
||||||
spec.shape[0],
|
|
||||||
params["min_freq"],
|
|
||||||
params["max_freq"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
features[ff, 7] = int(
|
|
||||||
convert_int_to_freq(
|
|
||||||
spec_slice[:, hlf_val:].sum(1).argmax(),
|
|
||||||
spec.shape[0],
|
|
||||||
params["min_freq"],
|
|
||||||
params["max_freq"],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if ff > 0:
|
previous = prediction
|
||||||
features[ff, 8] = round(
|
|
||||||
pred_nms["start_times"][ff]
|
|
||||||
- pred_nms["start_times"][ff - 1],
|
|
||||||
5,
|
|
||||||
)
|
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def get_feature_names():
|
||||||
|
"""Get names of features in the order they are extracted."""
|
||||||
|
return list(FEATURES.keys())
|
||||||
|
@ -19,13 +19,15 @@ def x_coords_to_time(
|
|||||||
) -> float:
|
) -> float:
|
||||||
"""Convert x coordinates of spectrogram to time in seconds.
|
"""Convert x coordinates of spectrogram to time in seconds.
|
||||||
|
|
||||||
Args:
|
Parameters
|
||||||
|
----------
|
||||||
x_pos: X position of the detection in pixels.
|
x_pos: X position of the detection in pixels.
|
||||||
sampling_rate: Sampling rate of the audio in Hz.
|
sampling_rate: Sampling rate of the audio in Hz.
|
||||||
fft_win_length: Length of the FFT window in seconds.
|
fft_win_length: Length of the FFT window in seconds.
|
||||||
fft_overlap: Overlap of the FFT windows in seconds.
|
fft_overlap: Overlap of the FFT windows in seconds.
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
|
-------
|
||||||
Time in seconds.
|
Time in seconds.
|
||||||
"""
|
"""
|
||||||
nfft = int(fft_win_length * sampling_rate)
|
nfft = int(fft_win_length * sampling_rate)
|
||||||
@ -134,12 +136,12 @@ def run_nms(
|
|||||||
y_pos[num_detection, valid_inds],
|
y_pos[num_detection, valid_inds],
|
||||||
x_pos[num_detection, valid_inds],
|
x_pos[num_detection, valid_inds],
|
||||||
].transpose(0, 1)
|
].transpose(0, 1)
|
||||||
feat = feat.detach().numpy().astype(np.float32)
|
feat = feat.detach().cpu().numpy().astype(np.float32)
|
||||||
feats.append(feat)
|
feats.append(feat)
|
||||||
|
|
||||||
# convert to numpy
|
# convert to numpy
|
||||||
for key, value in pred.items():
|
for key, value in pred.items():
|
||||||
pred[key] = value.detach().numpy().astype(np.float32)
|
pred[key] = value.detach().cpu().numpy().astype(np.float32)
|
||||||
|
|
||||||
preds.append(pred) # type: ignore
|
preds.append(pred) # type: ignore
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ def spectrogram(
|
|||||||
"""
|
"""
|
||||||
# Convert to numpy array if needed
|
# Convert to numpy array if needed
|
||||||
if isinstance(spec, torch.Tensor):
|
if isinstance(spec, torch.Tensor):
|
||||||
spec = spec.numpy()
|
spec = spec.detach().cpu().numpy()
|
||||||
|
|
||||||
# Remove batch and channel dimensions if present
|
# Remove batch and channel dimensions if present
|
||||||
spec = spec.squeeze()
|
spec = spec.squeeze()
|
||||||
@ -265,7 +265,7 @@ def detection(
|
|||||||
# Add class label
|
# Add class label
|
||||||
txt = " ".join([sp[:3] for sp in det["class"].split(" ")])
|
txt = " ".join([sp[:3] for sp in det["class"].split(" ")])
|
||||||
font_info = {
|
font_info = {
|
||||||
"color": "white",
|
"color": edgecolor,
|
||||||
"size": 10,
|
"size": 10,
|
||||||
"weight": "bold",
|
"weight": "bold",
|
||||||
"alpha": rect.get_alpha(),
|
"alpha": rect.get_alpha(),
|
||||||
|
@ -7,15 +7,14 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
from batdetect2.detector import models
|
|
||||||
from batdetect2.detector import parameters
|
|
||||||
from batdetect2.train import losses
|
|
||||||
import batdetect2.detector.post_process as pp
|
import batdetect2.detector.post_process as pp
|
||||||
import batdetect2.train.audio_dataloader as adl
|
import batdetect2.train.audio_dataloader as adl
|
||||||
import batdetect2.train.evaluate as evl
|
import batdetect2.train.evaluate as evl
|
||||||
import batdetect2.train.train_split as ts
|
import batdetect2.train.train_split as ts
|
||||||
import batdetect2.train.train_utils as tu
|
import batdetect2.train.train_utils as tu
|
||||||
import batdetect2.utils.plot_utils as pu
|
import batdetect2.utils.plot_utils as pu
|
||||||
|
from batdetect2.detector import models, parameters
|
||||||
|
from batdetect2.train import losses
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
|
||||||
@ -84,7 +83,6 @@ def save_image(
|
|||||||
def loss_fun(
|
def loss_fun(
|
||||||
outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
|
outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
|
||||||
):
|
):
|
||||||
|
|
||||||
# detection loss
|
# detection loss
|
||||||
loss = params["det_loss_weight"] * det_criterion(
|
loss = params["det_loss_weight"] * det_criterion(
|
||||||
outputs["pred_det"], gt_det
|
outputs["pred_det"], gt_det
|
||||||
@ -108,7 +106,6 @@ def loss_fun(
|
|||||||
def train(
|
def train(
|
||||||
model, epoch, data_loader, det_criterion, optimizer, scheduler, params
|
model, epoch, data_loader, det_criterion, optimizer, scheduler, params
|
||||||
):
|
):
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
train_loss = tu.AverageMeter()
|
train_loss = tu.AverageMeter()
|
||||||
@ -119,7 +116,6 @@ def train(
|
|||||||
|
|
||||||
print("\nEpoch", epoch)
|
print("\nEpoch", epoch)
|
||||||
for batch_idx, inputs in enumerate(data_loader):
|
for batch_idx, inputs in enumerate(data_loader):
|
||||||
|
|
||||||
data = inputs["spec"].to(params["device"])
|
data = inputs["spec"].to(params["device"])
|
||||||
gt_det = inputs["y_2d_det"].to(params["device"])
|
gt_det = inputs["y_2d_det"].to(params["device"])
|
||||||
gt_size = inputs["y_2d_size"].to(params["device"])
|
gt_size = inputs["y_2d_size"].to(params["device"])
|
||||||
@ -172,7 +168,6 @@ def test(model, epoch, data_loader, det_criterion, params):
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_idx, inputs in enumerate(data_loader):
|
for batch_idx, inputs in enumerate(data_loader):
|
||||||
|
|
||||||
data = inputs["spec"].to(params["device"])
|
data = inputs["spec"].to(params["device"])
|
||||||
gt_det = inputs["y_2d_det"].to(params["device"])
|
gt_det = inputs["y_2d_det"].to(params["device"])
|
||||||
gt_size = inputs["y_2d_size"].to(params["device"])
|
gt_size = inputs["y_2d_size"].to(params["device"])
|
||||||
@ -279,7 +274,7 @@ def parse_gt_data(inputs):
|
|||||||
is_valid = inputs["is_valid"][ind] == 1
|
is_valid = inputs["is_valid"][ind] == 1
|
||||||
gt = {}
|
gt = {}
|
||||||
for kk in keys:
|
for kk in keys:
|
||||||
gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32)
|
gt[kk] = inputs[kk][ind][is_valid].cpu().numpy().astype(np.float32)
|
||||||
gt["duration"] = inputs["duration"][ind].item()
|
gt["duration"] = inputs["duration"][ind].item()
|
||||||
gt["file_id"] = inputs["file_id"][ind].item()
|
gt["file_id"] = inputs["file_id"][ind].item()
|
||||||
gt["class_id_file"] = inputs["class_id_file"][ind].item()
|
gt["class_id_file"] = inputs["class_id_file"][ind].item()
|
||||||
@ -318,8 +313,7 @@ def select_model(params):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
|
|
||||||
plt.close("all")
|
plt.close("all")
|
||||||
|
|
||||||
params = parameters.get_params(True)
|
params = parameters.get_params(True)
|
||||||
@ -501,7 +495,6 @@ if __name__ == "__main__":
|
|||||||
#
|
#
|
||||||
# main train loop
|
# main train loop
|
||||||
for epoch in range(0, params["num_epochs"] + 1):
|
for epoch in range(0, params["num_epochs"] + 1):
|
||||||
|
|
||||||
train_loss = train(
|
train_loss = train(
|
||||||
model,
|
model,
|
||||||
epoch,
|
epoch,
|
||||||
@ -550,3 +543,7 @@ if __name__ == "__main__":
|
|||||||
# TODO: args variable does not exist
|
# TODO: args variable does not exist
|
||||||
# if not args["do_not_save_images"]:
|
# if not args["do_not_save_images"]:
|
||||||
# save_images_batch(model, test_loader, params)
|
# save_images_batch(model, test_loader, params)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Types used in the code base."""
|
"""Types used in the code base."""
|
||||||
from typing import List, NamedTuple, Optional
|
from typing import List, NamedTuple, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -25,10 +25,13 @@ except ImportError:
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"Annotation",
|
"Annotation",
|
||||||
"DetectionModel",
|
"DetectionModel",
|
||||||
|
"FeatureExtractionParameters",
|
||||||
|
"FeatureExtractor",
|
||||||
"FileAnnotations",
|
"FileAnnotations",
|
||||||
"ModelOutput",
|
"ModelOutput",
|
||||||
"ModelParameters",
|
"ModelParameters",
|
||||||
"NonMaximumSuppressionConfig",
|
"NonMaximumSuppressionConfig",
|
||||||
|
"Prediction",
|
||||||
"PredictionResults",
|
"PredictionResults",
|
||||||
"ProcessingConfiguration",
|
"ProcessingConfiguration",
|
||||||
"ResultParams",
|
"ResultParams",
|
||||||
@ -316,6 +319,40 @@ class ModelOutput(NamedTuple):
|
|||||||
"""Tensor with intermediate features."""
|
"""Tensor with intermediate features."""
|
||||||
|
|
||||||
|
|
||||||
|
class Prediction(TypedDict):
|
||||||
|
"""Single prediction."""
|
||||||
|
|
||||||
|
det_prob: float
|
||||||
|
"""Detection probability."""
|
||||||
|
|
||||||
|
x_pos: int
|
||||||
|
"""X position of the detection in pixels."""
|
||||||
|
|
||||||
|
y_pos: int
|
||||||
|
"""Y position of the detection in pixels."""
|
||||||
|
|
||||||
|
bb_width: int
|
||||||
|
"""Width of the detection in pixels."""
|
||||||
|
|
||||||
|
bb_height: int
|
||||||
|
"""Height of the detection in pixels."""
|
||||||
|
|
||||||
|
start_time: float
|
||||||
|
"""Start time of the detection in seconds."""
|
||||||
|
|
||||||
|
end_time: float
|
||||||
|
"""End time of the detection in seconds."""
|
||||||
|
|
||||||
|
low_freq: float
|
||||||
|
"""Low frequency of the detection in Hz."""
|
||||||
|
|
||||||
|
high_freq: float
|
||||||
|
"""High frequency of the detection in Hz."""
|
||||||
|
|
||||||
|
class_prob: np.ndarray
|
||||||
|
"""Vector holding the probability of each class."""
|
||||||
|
|
||||||
|
|
||||||
class PredictionResults(TypedDict):
|
class PredictionResults(TypedDict):
|
||||||
"""Results of the prediction.
|
"""Results of the prediction.
|
||||||
|
|
||||||
@ -422,6 +459,16 @@ class NonMaximumSuppressionConfig(TypedDict):
|
|||||||
"""Threshold for detection probability."""
|
"""Threshold for detection probability."""
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureExtractionParameters(TypedDict):
|
||||||
|
"""Parameters that control the feature extraction function."""
|
||||||
|
|
||||||
|
min_freq: int
|
||||||
|
"""Minimum frequency to consider in Hz."""
|
||||||
|
|
||||||
|
max_freq: int
|
||||||
|
"""Maximum frequency to consider in Hz."""
|
||||||
|
|
||||||
|
|
||||||
class HeatmapParameters(TypedDict):
|
class HeatmapParameters(TypedDict):
|
||||||
"""Parameters that control the heatmap generation function."""
|
"""Parameters that control the heatmap generation function."""
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Iterator, List, Optional, Tuple, Union
|
from typing import Any, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@ -66,7 +67,6 @@ def list_audio_files(ip_dir: str) -> List[str]:
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: Input directory not found.
|
FileNotFoundError: Input directory not found.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
matches = []
|
matches = []
|
||||||
for root, _, filenames in os.walk(ip_dir):
|
for root, _, filenames in os.walk(ip_dir):
|
||||||
@ -143,7 +143,19 @@ def load_model(
|
|||||||
|
|
||||||
|
|
||||||
def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
||||||
predictions_m = {}
|
predictions_m = {
|
||||||
|
"det_probs": np.array([]),
|
||||||
|
"x_pos": np.array([]),
|
||||||
|
"y_pos": np.array([]),
|
||||||
|
"bb_widths": np.array([]),
|
||||||
|
"bb_heights": np.array([]),
|
||||||
|
"start_times": np.array([]),
|
||||||
|
"end_times": np.array([]),
|
||||||
|
"low_freqs": np.array([]),
|
||||||
|
"high_freqs": np.array([]),
|
||||||
|
"class_probs": np.array([]),
|
||||||
|
}
|
||||||
|
|
||||||
num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
|
num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
|
||||||
|
|
||||||
if num_preds > 0:
|
if num_preds > 0:
|
||||||
@ -151,10 +163,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
|||||||
predictions_m[key] = np.hstack(
|
predictions_m[key] = np.hstack(
|
||||||
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
|
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# hack in case where no detected calls as we need some of the key
|
|
||||||
# names in dict
|
|
||||||
predictions_m = predictions[0]
|
|
||||||
|
|
||||||
if len(spec_feats) > 0:
|
if len(spec_feats) > 0:
|
||||||
spec_feats = np.vstack(spec_feats)
|
spec_feats = np.vstack(spec_feats)
|
||||||
@ -226,11 +234,19 @@ def format_single_result(
|
|||||||
Returns:
|
Returns:
|
||||||
dict: Results in the format expected by the annotation tool.
|
dict: Results in the format expected by the annotation tool.
|
||||||
"""
|
"""
|
||||||
# Get a single class prediction for the file
|
try:
|
||||||
class_overall = pp.overall_class_pred(
|
# Get a single class prediction for the file
|
||||||
predictions["det_probs"],
|
class_overall = pp.overall_class_pred(
|
||||||
predictions["class_probs"],
|
predictions["det_probs"],
|
||||||
)
|
predictions["class_probs"],
|
||||||
|
)
|
||||||
|
class_name = class_names[np.argmax(class_overall)]
|
||||||
|
annotations = get_annotations_from_preds(predictions, class_names)
|
||||||
|
except (np.AxisError, ValueError):
|
||||||
|
# No detections
|
||||||
|
class_overall = np.zeros(len(class_names))
|
||||||
|
class_name = "None"
|
||||||
|
annotations = []
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": file_id,
|
"id": file_id,
|
||||||
@ -239,8 +255,8 @@ def format_single_result(
|
|||||||
"notes": "Automatically generated.",
|
"notes": "Automatically generated.",
|
||||||
"time_exp": time_exp,
|
"time_exp": time_exp,
|
||||||
"duration": round(float(duration), 4),
|
"duration": round(float(duration), 4),
|
||||||
"annotation": get_annotations_from_preds(predictions, class_names),
|
"annotation": annotations,
|
||||||
"class_name": class_names[np.argmax(class_overall)],
|
"class_name": class_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -253,6 +269,7 @@ def convert_results(
|
|||||||
spec_feats,
|
spec_feats,
|
||||||
cnn_feats,
|
cnn_feats,
|
||||||
spec_slices,
|
spec_slices,
|
||||||
|
nyquist_freq: Optional[float] = None,
|
||||||
) -> RunResults:
|
) -> RunResults:
|
||||||
"""Convert results to dictionary as expected by the annotation tool.
|
"""Convert results to dictionary as expected by the annotation tool.
|
||||||
|
|
||||||
@ -268,8 +285,8 @@ def convert_results(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Dictionary with results.
|
dict: Dictionary with results.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pred_dict = format_single_result(
|
pred_dict = format_single_result(
|
||||||
file_id,
|
file_id,
|
||||||
time_exp,
|
time_exp,
|
||||||
@ -278,6 +295,14 @@ def convert_results(
|
|||||||
params["class_names"],
|
params["class_names"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Remove high frequency detections
|
||||||
|
if nyquist_freq is not None:
|
||||||
|
pred_dict["annotation"] = [
|
||||||
|
pred
|
||||||
|
for pred in pred_dict["annotation"]
|
||||||
|
if pred["high_freq"] <= nyquist_freq
|
||||||
|
]
|
||||||
|
|
||||||
# combine into final results dictionary
|
# combine into final results dictionary
|
||||||
results: RunResults = {
|
results: RunResults = {
|
||||||
"pred_dict": pred_dict,
|
"pred_dict": pred_dict,
|
||||||
@ -310,7 +335,6 @@ def save_results_to_file(results, op_path: str) -> None:
|
|||||||
Args:
|
Args:
|
||||||
results (dict): Results.
|
results (dict): Results.
|
||||||
op_path (str): Output path.
|
op_path (str): Output path.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# make directory if it does not exist
|
# make directory if it does not exist
|
||||||
if not os.path.isdir(os.path.dirname(op_path)):
|
if not os.path.isdir(os.path.dirname(op_path)):
|
||||||
@ -472,7 +496,6 @@ def iterate_over_chunks(
|
|||||||
chunk_start : float
|
chunk_start : float
|
||||||
Start time of chunk in seconds.
|
Start time of chunk in seconds.
|
||||||
chunk : np.ndarray
|
chunk : np.ndarray
|
||||||
|
|
||||||
"""
|
"""
|
||||||
nsamples = audio.shape[0]
|
nsamples = audio.shape[0]
|
||||||
duration_full = nsamples / samplerate
|
duration_full = nsamples / samplerate
|
||||||
@ -678,7 +701,6 @@ def process_audio_array(
|
|||||||
The array is of shape (num_detections, num_features).
|
The array is of shape (num_detections, num_features).
|
||||||
spec : torch.Tensor
|
spec : torch.Tensor
|
||||||
Spectrogram of the audio used as input.
|
Spectrogram of the audio used as input.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
pred_nms, features, spec = _process_audio_array(
|
pred_nms, features, spec = _process_audio_array(
|
||||||
audio,
|
audio,
|
||||||
@ -730,6 +752,10 @@ def process_file(
|
|||||||
cnn_feats = []
|
cnn_feats = []
|
||||||
spec_slices = []
|
spec_slices = []
|
||||||
|
|
||||||
|
# Get original sampling rate
|
||||||
|
file_samp_rate = librosa.get_samplerate(audio_file)
|
||||||
|
orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1
|
||||||
|
|
||||||
# load audio file
|
# load audio file
|
||||||
sampling_rate, audio_full = au.load_audio(
|
sampling_rate, audio_full = au.load_audio(
|
||||||
audio_file,
|
audio_file,
|
||||||
@ -757,7 +783,7 @@ def process_file(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# convert to numpy
|
# convert to numpy
|
||||||
spec_np = spec.detach().cpu().numpy()
|
spec_np = spec.detach().cpu().numpy().squeeze()
|
||||||
|
|
||||||
# add chunk time to start and end times
|
# add chunk time to start and end times
|
||||||
pred_nms["start_times"] += chunk_time
|
pred_nms["start_times"] += chunk_time
|
||||||
@ -777,9 +803,7 @@ def process_file(
|
|||||||
|
|
||||||
if config["spec_slices"]:
|
if config["spec_slices"]:
|
||||||
# FIX: This is not currently working. Returns empty slices
|
# FIX: This is not currently working. Returns empty slices
|
||||||
spec_slices.extend(
|
spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))
|
||||||
feats.extract_spec_slices(spec_np, pred_nms, config)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Merge results from chunks
|
# Merge results from chunks
|
||||||
predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
|
predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
|
||||||
@ -799,6 +823,7 @@ def process_file(
|
|||||||
spec_feats=spec_feats,
|
spec_feats=spec_feats,
|
||||||
cnn_feats=cnn_feats,
|
cnn_feats=cnn_feats,
|
||||||
spec_slices=spec_slices,
|
spec_slices=spec_slices,
|
||||||
|
nyquist_freq=orig_samp_rate / 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# summarize results
|
# summarize results
|
||||||
|
2
pdm.lock
generated
2
pdm.lock
generated
@ -453,7 +453,7 @@ summary = "Backport of pathlib-compatible object wrapper for zip files"
|
|||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
lock_version = "4.1"
|
lock_version = "4.1"
|
||||||
content_hash = "sha256:2401b930c14b3b7e107372f0103cccebff74691b6bcd54148d832ce847df5673"
|
content_hash = "sha256:667d4d2891fb85565cb04d84d0970eaac799bf272e3c4d7e4e6fea0b33c241fb"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
"appdirs 1.4.4" = [
|
"appdirs 1.4.4" = [
|
||||||
|
@ -6,7 +6,7 @@ dev = [
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "batdetect2"
|
name = "batdetect2"
|
||||||
version = "1.0.0"
|
version = "1.0.7"
|
||||||
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
|
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
|
||||||
authors = [
|
authors = [
|
||||||
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
|
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
|
||||||
@ -19,7 +19,7 @@ dependencies = [
|
|||||||
"pandas",
|
"pandas",
|
||||||
"scikit-learn",
|
"scikit-learn",
|
||||||
"scipy",
|
"scipy",
|
||||||
"torch<2",
|
"torch>=1.13.1,<2",
|
||||||
"torchaudio",
|
"torchaudio",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"click",
|
"click",
|
||||||
@ -53,10 +53,10 @@ requires = ["pdm-pep517>=1.0.0"]
|
|||||||
build-backend = "pdm.pep517.api"
|
build-backend = "pdm.pep517.api"
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
batdetect2 = "bat_detect.cli:cli"
|
batdetect2 = "batdetect2.cli:cli"
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 80
|
line-length = 79
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
module = [
|
module = [
|
||||||
|
BIN
tests/data/20230322_172000_selec2.wav
Normal file
BIN
tests/data/20230322_172000_selec2.wav
Normal file
Binary file not shown.
@ -1,11 +1,14 @@
|
|||||||
"""Test bat detect module API."""
|
"""Test bat detect module API."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
from batdetect2 import api
|
from batdetect2 import api
|
||||||
|
|
||||||
@ -262,3 +265,20 @@ def test_process_file_with_spec_slices():
|
|||||||
assert "spec_slices" in results
|
assert "spec_slices" in results
|
||||||
assert isinstance(results["spec_slices"], list)
|
assert isinstance(results["spec_slices"], list)
|
||||||
assert len(results["spec_slices"]) == len(detections)
|
assert len(results["spec_slices"]) == len(detections)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_file_with_empty_predictions_does_not_fail(
    tmp_path: Path,
):
    """Test process file with empty predictions does not fail."""
    # Create an audio file that contains zero samples (single channel).
    empty_file = tmp_path / "empty.wav"
    empty_wav = np.zeros((0, 1), dtype=np.float32)
    sf.write(empty_file, empty_wav, 256_000)

    # Processing must return a result structure with no annotations rather
    # than raising.
    results = api.process_file(str(empty_file))

    assert results is not None
    assert len(results["pred_dict"]["annotation"]) == 0
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""Test the command line interface."""
|
"""Test the command line interface."""
|
||||||
|
from pathlib import Path
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
from batdetect2.cli import cli
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
@ -42,3 +44,67 @@ def test_cli_detect_command_on_test_audio(tmp_path):
|
|||||||
assert results_dir.exists()
|
assert results_dir.exists()
|
||||||
assert len(list(results_dir.glob("*.csv"))) == 3
|
assert len(list(results_dir.glob("*.csv"))) == 3
|
||||||
assert len(list(results_dir.glob("*.json"))) == 3
|
assert len(list(results_dir.glob("*.json"))) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path: Path):
    """Test the detect command with a non-trivial time expansion factor."""
    results_dir = tmp_path / "results"

    # Remove results dir if it exists
    if results_dir.exists():
        results_dir.rmdir()

    runner = CliRunner()
    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
            "--time_expansion_factor",
            "10",
        ],
    )

    assert result.exit_code == 0
    # The requested factor must be echoed back in the command output.
    assert "Time Expansion Factor: 10" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
    """Test the detect command with the spec feature flag."""
    results_dir = tmp_path / "results"

    # Remove results dir if it exists
    if results_dir.exists():
        results_dir.rmdir()

    runner = CliRunner()
    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
            "--spec_features",
        ],
    )
    assert result.exit_code == 0
    assert results_dir.exists()

    # Names of every CSV file written to the results directory.
    csv_files = {path.name for path in results_dir.glob("*.csv")}

    expected_files = [
        "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
        "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
        "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
    ]

    for expected_file in expected_files:
        assert expected_file in csv_files

        # No row of the feature table may carry -1 in its duration column.
        df = pd.read_csv(results_dir / expected_file)
        assert not (df.duration == -1).any()
|
||||||
|
23
tests/test_detections.py
Normal file
23
tests/test_detections.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
"""Test suite to ensure that model detections are not incorrect."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from batdetect2 import api
|
||||||
|
|
||||||
|
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_detections_above_nyquist():
    """Test that no detections are made above the nyquist frequency."""
    # Recording donated by @kdarras
    path = os.path.join(DATA_DIR, "20230322_172000_selec2.wav")

    # This recording has a sampling rate of 192 kHz
    nyquist = 192_000 / 2

    output = api.process_file(path)
    predictions = output["pred_dict"]

    # The recording is expected to yield at least one detection; otherwise
    # the check below would pass vacuously.
    assert len(predictions["annotation"]) != 0

    # A detection reaching exactly the nyquist frequency is not "above" it,
    # so the bound is inclusive.
    assert all(
        pred["high_freq"] <= nyquist for pred in predictions["annotation"]
    )
|
291
tests/test_features.py
Normal file
291
tests/test_features.py
Normal file
@ -0,0 +1,291 @@
|
|||||||
|
"""Test suite for feature extraction functions."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import batdetect2.detector.compute_features as feats
|
||||||
|
from batdetect2 import api, types
|
||||||
|
from batdetect2.utils import audio_utils as au
|
||||||
|
|
||||||
|
numba_logger = logging.getLogger("numba")
|
||||||
|
numba_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
def index_to_freq(
    index: int,
    spec_height: int,
    min_freq: int,
    max_freq: int,
) -> float:
    """Convert spectrogram index to frequency in Hz.

    The index is counted from the opposite end of the frequency axis:
    index 0 maps to ``max_freq`` and index ``spec_height`` maps to
    ``min_freq``, with linear interpolation in between.

    Args:
        index: Row index into the spectrogram.
        spec_height: Number of rows (frequency bins) of the spectrogram.
        min_freq: Frequency in Hz mapped to ``index == spec_height``.
        max_freq: Frequency in Hz mapped to ``index == 0``.

    Returns:
        The frequency in Hz, rounded to two decimals.

    Raises:
        ZeroDivisionError: If ``spec_height`` is 0.
    """
    # Flip the axis without rebinding the ``index`` argument.
    flipped = spec_height - index
    return round(
        (flipped / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
    )
|
||||||
|
|
||||||
|
|
||||||
|
def index_to_time(
    index: int,
    spec_width: int,
    spec_duration: float,
) -> float:
    """Convert spectrogram index to time in seconds.

    Args:
        index: Column index into the spectrogram.
        spec_width: Number of columns (time bins) of the spectrogram.
        spec_duration: Duration in seconds spanned by ``spec_width`` columns.

    Returns:
        The time in seconds, rounded to two decimals.

    Raises:
        ZeroDivisionError: If ``spec_width`` is 0.
    """
    return round((index / float(spec_width)) * spec_duration, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_feats_function_with_empty_spectrogram():
    """Test get_feats function with empty spectrogram.

    This tests that the overall flow of the function works, even if the
    spectrogram is empty.
    """
    spec_duration = 3
    spec_width = 100
    spec_height = 100
    min_freq = 10_000
    max_freq = 120_000
    spectrogram = np.zeros((spec_height, spec_width))

    # Bounding box of the single synthetic detection, in spectrogram indices.
    x_pos = 20
    y_pos = 80
    bb_width = 20
    bb_height = 20

    # The same box expressed in seconds and Hz. ``index_to_freq`` maps
    # smaller indices to higher frequencies, so ``y_pos - bb_height`` yields
    # the high frequency edge.
    start_time = index_to_time(x_pos, spec_width, spec_duration)
    end_time = index_to_time(x_pos + bb_width, spec_width, spec_duration)
    low_freq = index_to_freq(y_pos, spec_height, min_freq, max_freq)
    high_freq = index_to_freq(
        y_pos - bb_height, spec_height, min_freq, max_freq
    )

    # Sanity check on the test setup itself.
    assert low_freq < high_freq

    pred_nms: types.PredictionResults = {
        "det_probs": np.array([1]),
        "class_probs": np.array([[1]]),
        "x_pos": np.array([x_pos]),
        "y_pos": np.array([y_pos]),
        "bb_width": np.array([bb_width]),
        "bb_height": np.array([bb_height]),
        "start_times": np.array([start_time]),
        "end_times": np.array([end_time]),
        "low_freqs": np.array([low_freq]),
        "high_freqs": np.array([high_freq]),
    }

    params: types.FeatureExtractionParameters = {
        "min_freq": min_freq,
        "max_freq": max_freq,
    }

    features = feats.get_feats(spectrogram, pred_nms, params)

    assert isinstance(features, np.ndarray)
    assert features.shape == (len(pred_nms["det_probs"]), 9)

    # Expected feature row for the all-zero spectrogram.
    # NOTE(review): the meaning of each column is defined by ``get_feats``;
    # confirm against its implementation when changing this vector.
    expected = np.array(
        [
            end_time - start_time,
            low_freq,
            high_freq,
            high_freq - low_freq,
            high_freq,
            max_freq,
            max_freq,
            max_freq,
            np.nan,
        ]
    )
    # equal_nan=True so the NaN entry compares equal to a NaN feature.
    assert np.allclose(features[0], expected, equal_nan=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "max_power",
    # 30 kHz to 40 kHz inclusive, in 1 kHz steps.
    list(range(30_000, 40_001, 1_000)),
)
def test_compute_max_power_bb(max_power: int):
    """Test compute_max_power_bb function."""
    duration = 1
    samplerate = 256_000
    min_freq = 0
    max_freq = 128_000

    start_time = 0.3
    end_time = 0.6
    low_freq = 30_000
    high_freq = 40_000

    audio = np.zeros((int(duration * samplerate),))

    # Add a signal during the time and frequency range of interest
    audio[
        int(start_time * samplerate) : int(end_time * samplerate)
    ] = 0.5 * librosa.tone(
        max_power, sr=samplerate, duration=end_time - start_time
    )

    # Add a more powerful signal outside frequency range of interest
    audio[
        int(start_time * samplerate) : int(end_time * samplerate)
    ] += 2 * librosa.tone(
        80_000, sr=samplerate, duration=end_time - start_time
    )

    params = api.get_config(
        min_freq=min_freq,
        max_freq=max_freq,
        target_samp_rate=samplerate,
    )

    spec, _ = au.generate_spectrogram(
        audio,
        samplerate,
        params,
    )

    # Horizontal extent of the box in spectrogram columns.
    x_start = int(
        au.time_to_x_coords(
            start_time,
            samplerate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
    )

    x_end = int(
        au.time_to_x_coords(
            end_time,
            samplerate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
    )

    # Vertical extent of the box; row indices are counted down from
    # num_freq_bins, so the higher frequency gets the smaller index.
    num_freq_bins = spec.shape[0]
    y_low = num_freq_bins - int(num_freq_bins * low_freq / max_freq)
    y_high = num_freq_bins - int(num_freq_bins * high_freq / max_freq)

    prediction: types.Prediction = {
        "det_prob": 1,
        "class_prob": np.ones((1,)),
        "x_pos": x_start,
        "y_pos": int(y_low),
        "bb_width": int(x_end - x_start),
        "bb_height": int(y_low - y_high),
        "start_time": start_time,
        "end_time": end_time,
        "low_freq": low_freq,
        "high_freq": high_freq,
    }

    max_power_bb = feats.compute_max_power_bb(
        prediction,
        spec,
        min_freq=min_freq,
        max_freq=max_freq,
    )

    # The weaker in-box tone must win over the louder 80 kHz tone outside
    # the box; allow 500 Hz of slack.
    assert abs(max_power_bb - max_power) <= 500
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_max_power():
    """Test compute_max_power function."""
    duration = 3
    samplerate = 16_000
    min_freq = 0
    max_freq = 8_000

    start_time = 1
    end_time = 2
    low_freq = 3_000
    high_freq = 4_000
    max_power = 5_000

    audio = np.zeros((int(duration * samplerate),))

    # Add a signal during the time and frequency range of interest
    audio[
        int(start_time * samplerate) : int(end_time * samplerate)
    ] = 0.5 * librosa.tone(
        3_500, sr=samplerate, duration=end_time - start_time
    )

    # Add a more powerful signal outside frequency range of interest
    audio[
        int(start_time * samplerate) : int(end_time * samplerate)
    ] += 2 * librosa.tone(
        max_power, sr=samplerate, duration=end_time - start_time
    )

    params = api.get_config(
        min_freq=min_freq,
        max_freq=max_freq,
        target_samp_rate=samplerate,
    )

    spec, _ = au.generate_spectrogram(
        audio,
        samplerate,
        params,
    )

    # Horizontal extent of the box in spectrogram columns.
    x_start = int(
        au.time_to_x_coords(
            start_time,
            samplerate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
    )

    x_end = int(
        au.time_to_x_coords(
            end_time,
            samplerate,
            params["fft_win_length"],
            params["fft_overlap"],
        )
    )

    num_freq_bins = spec.shape[0]
    y_low = int(num_freq_bins * low_freq / max_freq)
    y_high = int(num_freq_bins * high_freq / max_freq)

    prediction: types.Prediction = {
        "det_prob": 1,
        "class_prob": np.ones((1,)),
        "x_pos": x_start,
        "y_pos": int(y_high),
        "bb_width": int(x_end - x_start),
        "bb_height": int(y_high - y_low),
        "start_time": start_time,
        "end_time": end_time,
        "low_freq": low_freq,
        "high_freq": high_freq,
    }

    computed_max_power = feats.compute_max_power(
        prediction,
        spec,
        min_freq=min_freq,
        max_freq=max_freq,
    )

    # The expected value is the frequency of the louder tone, even though
    # it lies outside the 3-4 kHz band of the prediction.
    assert abs(computed_max_power - max_power) < 100
|
Loading…
Reference in New Issue
Block a user