fix: implement a cleaning step to remove detections above the Nyquist limit

Santiago Martinez 2023-11-24 15:40:58 +00:00
parent 986cfc463c
commit 860e63dddf


@@ -2,6 +2,7 @@ import json
import os
from typing import Any, Iterator, List, Optional, Tuple, Union
+import librosa
import numpy as np
import pandas as pd
import torch
@@ -66,7 +67,6 @@ def list_audio_files(ip_dir: str) -> List[str]:
    Raises:
        FileNotFoundError: Input directory not found.
    """
    matches = []
    for root, _, filenames in os.walk(ip_dir):
@@ -269,6 +269,7 @@ def convert_results(
    spec_feats,
    cnn_feats,
    spec_slices,
+    nyquist_freq: Optional[float] = None,
) -> RunResults:
    """Convert results to dictionary as expected by the annotation tool.
@@ -284,8 +285,8 @@ def convert_results(
    Returns:
        dict: Dictionary with results.
    """
    pred_dict = format_single_result(
        file_id,
        time_exp,
@@ -294,6 +295,14 @@ def convert_results(
        params["class_names"],
    )
+    # Remove high frequency detections
+    if nyquist_freq is not None:
+        pred_dict["annotation"] = [
+            pred
+            for pred in pred_dict["annotation"]
+            if pred["high_freq"] <= nyquist_freq
+        ]
    # combine into final results dictionary
    results: RunResults = {
        "pred_dict": pred_dict,
@@ -326,7 +335,6 @@ def save_results_to_file(results, op_path: str) -> None:
    Args:
        results (dict): Results.
        op_path (str): Output path.
    """
    # make directory if it does not exist
    if not os.path.isdir(os.path.dirname(op_path)):
@@ -488,7 +496,6 @@ def iterate_over_chunks(
    chunk_start : float
        Start time of chunk in seconds.
    chunk : np.ndarray
    """
    nsamples = audio.shape[0]
    duration_full = nsamples / samplerate
@@ -694,7 +701,6 @@ def process_audio_array(
        The array is of shape (num_detections, num_features).
    spec : torch.Tensor
        Spectrogram of the audio used as input.
    """
    pred_nms, features, spec = _process_audio_array(
        audio,
@@ -746,6 +752,10 @@ def process_file(
    cnn_feats = []
    spec_slices = []
+    # Get original sampling rate
+    file_samp_rate = librosa.get_samplerate(audio_file)
+    orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1
    # load audio file
    sampling_rate, audio_full = au.load_audio(
        audio_file,
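
The Nyquist threshold comes from the recording's true sample rate: librosa.get_samplerate reads the rate stored in the file header without decoding the audio, and multiplying by the time-expansion factor recovers the real-world rate of time-expanded recordings. A standalone sketch of the same idea, with a hypothetical path and expansion factor (the "or 1" fallback mirrors the diff's handling of a missing time_expansion setting):

import librosa

audio_file = "recording.wav"   # hypothetical path
time_expansion = 10            # e.g. a 10x time-expanded bat recording

file_samp_rate = librosa.get_samplerate(audio_file)      # header rate, no decoding
orig_samp_rate = file_samp_rate * (time_expansion or 1)  # real-world rate
nyquist_freq = orig_samp_rate / 2                        # highest representable frequency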
@@ -793,9 +803,7 @@ def process_file(
        if config["spec_slices"]:
            # FIX: This is not currently working. Returns empty slices
-            spec_slices.extend(
-                feats.extract_spec_slices(spec_np, pred_nms)
-            )
+            spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))
    # Merge results from chunks
    predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
@@ -815,6 +823,7 @@ def process_file(
        spec_feats=spec_feats,
        cnn_feats=cnn_feats,
        spec_slices=spec_slices,
+        nyquist_freq=orig_samp_rate / 2,
    )
    # summarize results
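
Taken together, the two changes tie the detection ceiling to the recording's actual bandwidth. A worked example of the arithmetic, with hypothetical numbers:

file_samp_rate = 38_400    # rate stored in the WAV header (Hz)
time_expansion = 10        # file was made with 10x time expansion
orig_samp_rate = file_samp_rate * time_expansion   # 384 000 Hz real-world rate
nyquist_freq = orig_samp_rate / 2                  # 192 000 Hz ceiling
# Any prediction whose high_freq exceeds 192 kHz is now discarded by convert_results.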