diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py
index 74193b3..8d6ca7f 100644
--- a/batdetect2/utils/detector_utils.py
+++ b/batdetect2/utils/detector_utils.py
@@ -2,6 +2,7 @@ import json
 import os
 from typing import Any, Iterator, List, Optional, Tuple, Union
 
+import librosa
 import numpy as np
 import pandas as pd
 import torch
@@ -66,7 +67,6 @@ def list_audio_files(ip_dir: str) -> List[str]:
 
     Raises:
         FileNotFoundError: Input directory not found.
-
     """
     matches = []
     for root, _, filenames in os.walk(ip_dir):
@@ -269,6 +269,7 @@ def convert_results(
     spec_feats,
     cnn_feats,
     spec_slices,
+    nyquist_freq: Optional[float] = None,
 ) -> RunResults:
     """Convert results to dictionary as expected by the annotation tool.
 
@@ -284,8 +285,8 @@
 
     Returns:
         dict: Dictionary with results.
-
     """
+
     pred_dict = format_single_result(
         file_id,
         time_exp,
@@ -294,6 +295,14 @@
         duration,
         predictions,
         params["class_names"],
     )
 
+    # Remove high frequency detections
+    if nyquist_freq is not None:
+        pred_dict["annotation"] = [
+            pred
+            for pred in pred_dict["annotation"]
+            if pred["high_freq"] <= nyquist_freq
+        ]
+
     # combine into final results dictionary
     results: RunResults = {
         "pred_dict": pred_dict,
@@ -326,7 +335,6 @@ def save_results_to_file(results, op_path: str) -> None:
     Args:
         results (dict): Results.
         op_path (str): Output path.
-
     """
     # make directory if it does not exist
     if not os.path.isdir(os.path.dirname(op_path)):
@@ -488,7 +496,6 @@ def iterate_over_chunks(
     chunk_start : float
         Start time of chunk in seconds.
     chunk : np.ndarray
-
     """
     nsamples = audio.shape[0]
     duration_full = nsamples / samplerate
@@ -694,7 +701,6 @@ def process_audio_array(
         The array is of shape (num_detections, num_features).
     spec : torch.Tensor
         Spectrogram of the audio used as input.
-
     """
     pred_nms, features, spec = _process_audio_array(
         audio,
@@ -746,6 +752,10 @@ def process_file(
     cnn_feats = []
     spec_slices = []
 
+    # Get original sampling rate
+    file_samp_rate = librosa.get_samplerate(audio_file)
+    orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1
+
     # load audio file
     sampling_rate, audio_full = au.load_audio(
         audio_file,
@@ -793,9 +803,7 @@
 
         if config["spec_slices"]:
             # FIX: This is not currently working. Returns empty slices
-            spec_slices.extend(
-                feats.extract_spec_slices(spec_np, pred_nms)
-            )
+            spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))
 
     # Merge results from chunks
     predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
@@ -815,6 +823,7 @@
         spec_feats=spec_feats,
         cnn_feats=cnn_feats,
         spec_slices=spec_slices,
+        nyquist_freq=orig_samp_rate / 2,
     )
 
     # summarize results
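
Note: the core of this change is the Nyquist-frequency filter added to convert_results(). The sketch below is not part of the repository; the helper name and sample values are illustrative. It shows the rule in isolation: the original recording's sample rate is the file's sample rate multiplied by the time-expansion factor, and a detection whose high_freq exceeds half that rate cannot correspond to real signal, so it is dropped.

    from typing import Dict, List

    def drop_detections_above_nyquist(
        annotations: List[Dict],
        file_samp_rate: float,
        time_expansion: float = 1.0,
    ) -> List[Dict]:
        # Nyquist frequency of the original (non-expanded) recording
        nyquist_freq = file_samp_rate * time_expansion / 2
        return [ann for ann in annotations if ann["high_freq"] <= nyquist_freq]

    # A 10x time-expanded file recorded at 44.1 kHz covers an original
    # bandwidth up to 220.5 kHz, so the 250 kHz detection is removed.
    preds = [
        {"class": "Myotis", "high_freq": 80_000},
        {"class": "noise", "high_freq": 250_000},
    ]
    print(drop_detections_above_nyquist(preds, 44_100, time_expansion=10))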