fix: implemented a cleaning step to remove detections above the nyquist limit

This commit is contained in:
Santiago Martinez 2023-11-24 15:40:58 +00:00
parent 986cfc463c
commit 860e63dddf

View File

@ -2,6 +2,7 @@ import json
import os
from typing import Any, Iterator, List, Optional, Tuple, Union
import librosa
import numpy as np
import pandas as pd
import torch
@ -66,7 +67,6 @@ def list_audio_files(ip_dir: str) -> List[str]:
Raises:
FileNotFoundError: Input directory not found.
"""
matches = []
for root, _, filenames in os.walk(ip_dir):
@ -269,6 +269,7 @@ def convert_results(
spec_feats,
cnn_feats,
spec_slices,
nyquist_freq: Optional[float] = None,
) -> RunResults:
"""Convert results to dictionary as expected by the annotation tool.
@ -284,8 +285,8 @@ def convert_results(
Returns:
dict: Dictionary with results.
"""
pred_dict = format_single_result(
file_id,
time_exp,
@ -294,6 +295,14 @@ def convert_results(
params["class_names"],
)
# Remove high frequency detections
if nyquist_freq is not None:
pred_dict["annotation"] = [
pred
for pred in pred_dict["annotation"]
if pred["high_freq"] <= nyquist_freq
]
# combine into final results dictionary
results: RunResults = {
"pred_dict": pred_dict,
@ -326,7 +335,6 @@ def save_results_to_file(results, op_path: str) -> None:
Args:
results (dict): Results.
op_path (str): Output path.
"""
# make directory if it does not exist
if not os.path.isdir(os.path.dirname(op_path)):
@ -488,7 +496,6 @@ def iterate_over_chunks(
chunk_start : float
Start time of chunk in seconds.
chunk : np.ndarray
"""
nsamples = audio.shape[0]
duration_full = nsamples / samplerate
@ -694,7 +701,6 @@ def process_audio_array(
The array is of shape (num_detections, num_features).
spec : torch.Tensor
Spectrogram of the audio used as input.
"""
pred_nms, features, spec = _process_audio_array(
audio,
@ -746,6 +752,10 @@ def process_file(
cnn_feats = []
spec_slices = []
# Get original sampling rate
file_samp_rate = librosa.get_samplerate(audio_file)
orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1
# load audio file
sampling_rate, audio_full = au.load_audio(
audio_file,
@ -793,9 +803,7 @@ def process_file(
if config["spec_slices"]:
# FIX: This is not currently working. Returns empty slices
spec_slices.extend(
feats.extract_spec_slices(spec_np, pred_nms)
)
spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))
# Merge results from chunks
predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
@ -815,6 +823,7 @@ def process_file(
spec_feats=spec_feats,
cnn_feats=cnn_feats,
spec_slices=spec_slices,
nyquist_freq=orig_samp_rate / 2,
)
# summarize results