mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
fix: implement a cleaning step to remove detections above the Nyquist limit
This commit is contained in:
parent
986cfc463c
commit
860e63dddf
@ -2,6 +2,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Iterator, List, Optional, Tuple, Union
|
from typing import Any, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@ -66,7 +67,6 @@ def list_audio_files(ip_dir: str) -> List[str]:
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: Input directory not found.
|
FileNotFoundError: Input directory not found.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
matches = []
|
matches = []
|
||||||
for root, _, filenames in os.walk(ip_dir):
|
for root, _, filenames in os.walk(ip_dir):
|
||||||
@ -269,6 +269,7 @@ def convert_results(
|
|||||||
spec_feats,
|
spec_feats,
|
||||||
cnn_feats,
|
cnn_feats,
|
||||||
spec_slices,
|
spec_slices,
|
||||||
|
nyquist_freq: Optional[float] = None,
|
||||||
) -> RunResults:
|
) -> RunResults:
|
||||||
"""Convert results to dictionary as expected by the annotation tool.
|
"""Convert results to dictionary as expected by the annotation tool.
|
||||||
|
|
||||||
@ -284,8 +285,8 @@ def convert_results(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Dictionary with results.
|
dict: Dictionary with results.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pred_dict = format_single_result(
|
pred_dict = format_single_result(
|
||||||
file_id,
|
file_id,
|
||||||
time_exp,
|
time_exp,
|
||||||
@ -294,6 +295,14 @@ def convert_results(
|
|||||||
params["class_names"],
|
params["class_names"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Remove high frequency detections
|
||||||
|
if nyquist_freq is not None:
|
||||||
|
pred_dict["annotation"] = [
|
||||||
|
pred
|
||||||
|
for pred in pred_dict["annotation"]
|
||||||
|
if pred["high_freq"] <= nyquist_freq
|
||||||
|
]
|
||||||
|
|
||||||
# combine into final results dictionary
|
# combine into final results dictionary
|
||||||
results: RunResults = {
|
results: RunResults = {
|
||||||
"pred_dict": pred_dict,
|
"pred_dict": pred_dict,
|
||||||
@ -326,7 +335,6 @@ def save_results_to_file(results, op_path: str) -> None:
|
|||||||
Args:
|
Args:
|
||||||
results (dict): Results.
|
results (dict): Results.
|
||||||
op_path (str): Output path.
|
op_path (str): Output path.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# make directory if it does not exist
|
# make directory if it does not exist
|
||||||
if not os.path.isdir(os.path.dirname(op_path)):
|
if not os.path.isdir(os.path.dirname(op_path)):
|
||||||
@ -488,7 +496,6 @@ def iterate_over_chunks(
|
|||||||
chunk_start : float
|
chunk_start : float
|
||||||
Start time of chunk in seconds.
|
Start time of chunk in seconds.
|
||||||
chunk : np.ndarray
|
chunk : np.ndarray
|
||||||
|
|
||||||
"""
|
"""
|
||||||
nsamples = audio.shape[0]
|
nsamples = audio.shape[0]
|
||||||
duration_full = nsamples / samplerate
|
duration_full = nsamples / samplerate
|
||||||
@ -694,7 +701,6 @@ def process_audio_array(
|
|||||||
The array is of shape (num_detections, num_features).
|
The array is of shape (num_detections, num_features).
|
||||||
spec : torch.Tensor
|
spec : torch.Tensor
|
||||||
Spectrogram of the audio used as input.
|
Spectrogram of the audio used as input.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
pred_nms, features, spec = _process_audio_array(
|
pred_nms, features, spec = _process_audio_array(
|
||||||
audio,
|
audio,
|
||||||
@ -746,6 +752,10 @@ def process_file(
|
|||||||
cnn_feats = []
|
cnn_feats = []
|
||||||
spec_slices = []
|
spec_slices = []
|
||||||
|
|
||||||
|
# Get original sampling rate
|
||||||
|
file_samp_rate = librosa.get_samplerate(audio_file)
|
||||||
|
orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1
|
||||||
|
|
||||||
# load audio file
|
# load audio file
|
||||||
sampling_rate, audio_full = au.load_audio(
|
sampling_rate, audio_full = au.load_audio(
|
||||||
audio_file,
|
audio_file,
|
||||||
@ -793,9 +803,7 @@ def process_file(
|
|||||||
|
|
||||||
if config["spec_slices"]:
|
if config["spec_slices"]:
|
||||||
# FIX: This is not currently working. Returns empty slices
|
# FIX: This is not currently working. Returns empty slices
|
||||||
spec_slices.extend(
|
spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))
|
||||||
feats.extract_spec_slices(spec_np, pred_nms)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Merge results from chunks
|
# Merge results from chunks
|
||||||
predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
|
predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
|
||||||
@ -815,6 +823,7 @@ def process_file(
|
|||||||
spec_feats=spec_feats,
|
spec_feats=spec_feats,
|
||||||
cnn_feats=cnn_feats,
|
cnn_feats=cnn_feats,
|
||||||
spec_slices=spec_slices,
|
spec_slices=spec_slices,
|
||||||
|
nyquist_freq=orig_samp_rate / 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# summarize results
|
# summarize results
|
||||||
|
Loading…
Reference in New Issue
Block a user