mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
fix: implemented a cleaning step to remove detections above the nyquist limit
This commit is contained in:
parent
986cfc463c
commit
860e63dddf
@ -2,6 +2,7 @@ import json
|
||||
import os
|
||||
from typing import Any, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
@ -66,7 +67,6 @@ def list_audio_files(ip_dir: str) -> List[str]:
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: Input directory not found.
|
||||
|
||||
"""
|
||||
matches = []
|
||||
for root, _, filenames in os.walk(ip_dir):
|
||||
@ -269,6 +269,7 @@ def convert_results(
|
||||
spec_feats,
|
||||
cnn_feats,
|
||||
spec_slices,
|
||||
nyquist_freq: Optional[float] = None,
|
||||
) -> RunResults:
|
||||
"""Convert results to dictionary as expected by the annotation tool.
|
||||
|
||||
@ -284,8 +285,8 @@ def convert_results(
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with results.
|
||||
|
||||
"""
|
||||
|
||||
pred_dict = format_single_result(
|
||||
file_id,
|
||||
time_exp,
|
||||
@ -294,6 +295,14 @@ def convert_results(
|
||||
params["class_names"],
|
||||
)
|
||||
|
||||
# Remove high frequency detections
|
||||
if nyquist_freq is not None:
|
||||
pred_dict["annotation"] = [
|
||||
pred
|
||||
for pred in pred_dict["annotation"]
|
||||
if pred["high_freq"] <= nyquist_freq
|
||||
]
|
||||
|
||||
# combine into final results dictionary
|
||||
results: RunResults = {
|
||||
"pred_dict": pred_dict,
|
||||
@ -326,7 +335,6 @@ def save_results_to_file(results, op_path: str) -> None:
|
||||
Args:
|
||||
results (dict): Results.
|
||||
op_path (str): Output path.
|
||||
|
||||
"""
|
||||
# make directory if it does not exist
|
||||
if not os.path.isdir(os.path.dirname(op_path)):
|
||||
@ -488,7 +496,6 @@ def iterate_over_chunks(
|
||||
chunk_start : float
|
||||
Start time of chunk in seconds.
|
||||
chunk : np.ndarray
|
||||
|
||||
"""
|
||||
nsamples = audio.shape[0]
|
||||
duration_full = nsamples / samplerate
|
||||
@ -694,7 +701,6 @@ def process_audio_array(
|
||||
The array is of shape (num_detections, num_features).
|
||||
spec : torch.Tensor
|
||||
Spectrogram of the audio used as input.
|
||||
|
||||
"""
|
||||
pred_nms, features, spec = _process_audio_array(
|
||||
audio,
|
||||
@ -746,6 +752,10 @@ def process_file(
|
||||
cnn_feats = []
|
||||
spec_slices = []
|
||||
|
||||
# Get original sampling rate
|
||||
file_samp_rate = librosa.get_samplerate(audio_file)
|
||||
orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1
|
||||
|
||||
# load audio file
|
||||
sampling_rate, audio_full = au.load_audio(
|
||||
audio_file,
|
||||
@ -793,9 +803,7 @@ def process_file(
|
||||
|
||||
if config["spec_slices"]:
|
||||
# FIX: This is not currently working. Returns empty slices
|
||||
spec_slices.extend(
|
||||
feats.extract_spec_slices(spec_np, pred_nms)
|
||||
)
|
||||
spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))
|
||||
|
||||
# Merge results from chunks
|
||||
predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
|
||||
@ -815,6 +823,7 @@ def process_file(
|
||||
spec_feats=spec_feats,
|
||||
cnn_feats=cnn_feats,
|
||||
spec_slices=spec_slices,
|
||||
nyquist_freq=orig_samp_rate / 2,
|
||||
)
|
||||
|
||||
# summarize results
|
||||
|
Loading…
Reference in New Issue
Block a user