Refactored detector_utils module

Santiago Martinez 2023-02-22 20:24:43 +00:00
parent 7550c6faf1
commit 8da98b5258
9 changed files with 716 additions and 187 deletions
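The central change: du.process_file now takes a single merged configuration dictionary instead of separate params and args arguments. A minimal sketch (not part of the commit) of how a caller builds that dictionary after this change; the audio path and threshold value are placeholders, and the helper names follow the diffs below:

    import bat_detect.utils.detector_utils as du

    model, params = du.load_model()      # bundled model and its parameters
    args = du.get_default_bd_args()      # default run-time arguments
    run_config = {**params, **args, "detection_threshold": 0.3}

    results = du.process_file("example.wav", model, run_config)
    print(len(results["pred_dict"]["annotation"]), "call(s) detected")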

app.py

@@ -1,5 +1,3 @@
-import os
 import gradio as gr
 import matplotlib.pyplot as plt
 import numpy as np
@@ -37,7 +35,6 @@ examples = [
 def make_prediction(file_name=None, detection_threshold=0.3):
     if file_name is not None:
         audio_file = file_name
     else:
@@ -46,9 +43,17 @@ def make_prediction(file_name=None, detection_threshold=0.3):
     if detection_threshold is not None and detection_threshold != "":
         args["detection_threshold"] = float(detection_threshold)

+    run_config = {
+        **params,
+        **args,
+        "max_duration": max_duration,
+    }
+
     # process the file to generate predictions
     results = du.process_file(
-        audio_file, model, params, args, max_duration=max_duration
+        audio_file,
+        model,
+        run_config,
     )
     anns = [ann for ann in results["pred_dict"]["annotation"]]


@@ -1,3 +1,9 @@
+"""Main script for running BatDetect2 on audio files.
+
+Example usage:
+    python command.py /path/to/audio/ /path/to/ann/ 0.1
+
+"""
 import argparse
 import os
@@ -7,7 +13,6 @@ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
 def parse_args():
     info_str = (
         "\nBatDetect2 - Detection and Classification\n"
         + " Assumes audio files are mono, not stereo.\n"
@@ -88,14 +93,22 @@ def main():
     print("\nInput directory: " + args["audio_dir"])
     files = du.get_audio_files(args["audio_dir"])
-    print("Number of audio files: {}".format(len(files)))
+    print(f"Number of audio files: {len(files)}")
     print("\nSaving results to: " + args["ann_dir"])

+    # set up run config
+    run_config = {
+        **args,
+        **params,
+    }
+
     # process files
     error_files = []
-    for ii, audio_file in enumerate(files):
+    for audio_file in files:
         try:
-            results = du.process_file(audio_file, model, params, args)
+            results = du.process_file(audio_file, model, run_config)
             if args["save_preds_if_empty"] or (
                 len(results["pred_dict"]["annotation"]) > 0
             ):
@@ -103,9 +116,9 @@ def main():
                     args["audio_dir"], args["ann_dir"]
                 )
                 du.save_results_to_file(results, results_path)
-        except:
+        except (RuntimeError, ValueError, LookupError) as err:
             error_files.append(audio_file)
-            print("Error processing file!")
+            print(f"Error processing file!: {err}")
     print("\nResults saved to: " + args["ann_dir"])


@@ -1,7 +1,11 @@
 import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
+from torch import nn
+
+try:
+    from typing import TypedDict
+except ImportError:
+    from typing_extensions import TypedDict

 np.seterr(divide="ignore", invalid="ignore")
@@ -18,7 +22,33 @@ def overall_class_pred(det_prob, class_prob):
     return weighted_pred / weighted_pred.sum()


-def run_nms(outputs, params, sampling_rate):
+class NonMaximumSuppressionConfig(TypedDict):
+    """Configuration for non-maximum suppression."""
+
+    nms_kernel_size: int
+    """Size of the kernel for non-maximum suppression."""
+
+    max_freq: float
+    """Maximum frequency to consider in Hz."""
+
+    min_freq: float
+    """Minimum frequency to consider in Hz."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Overlap of the FFT windows in seconds."""
+
+    nms_top_k_per_sec: float
+    """Number of top detections to keep per second."""
+
+    detection_threshold: float
+    """Threshold for detection probability."""
+
+
+def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
+    """Run non-maximum suppression on the output of the model."""
     pred_det = outputs["pred_det"]  # probability of box
     pred_size = outputs["pred_size"]  # box size
@@ -92,6 +122,7 @@ def run_nms(outputs, params, sampling_rate):
         # convert to numpy
         for kk in pred.keys():
             pred[kk] = pred[kk].cpu().numpy().astype(np.float32)
+
         preds.append(pred)

     return preds, feats
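The new NonMaximumSuppressionConfig spells out exactly which keys run_nms reads from params. A hedged sketch (not part of the commit) of building such a dictionary; the numeric values are illustrative only, not the project defaults:

    nms_config: NonMaximumSuppressionConfig = {
        "nms_kernel_size": 9,
        "max_freq": 120000,
        "min_freq": 10000,
        "fft_win_length": 0.02,
        "fft_overlap": 0.75,
        "nms_top_k_per_sec": 200,
        "detection_threshold": 0.3,
    }
    # preds, feats = run_nms(model_outputs, nms_config, sampling_rate)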


@@ -6,14 +6,12 @@ import argparse
 import copy
 import json
 import os
-import sys

 import numpy as np
 import pandas as pd
 from sklearn.ensemble import RandomForestClassifier

-sys.path.append("../../")
-import bat_detect.detector.parameters as parameters
+from bat_detect.detector import parameters
 import bat_detect.train.evaluate as evl
 import bat_detect.train.train_utils as tu
 import bat_detect.utils.detector_utils as du
@@ -749,14 +747,18 @@ if __name__ == "__main__":
         print("Warning: Class names are not the same as the trained model")
         assert False

+    run_config = {
+        **bd_args,
+        **params_bd,
+        "return_raw_preds": True,
+    }
+
     preds_bd = []
     for ii, gg in enumerate(gt_test):
         pred = du.process_file(
             gg["file_path"],
             model,
-            params_bd,
-            bd_args,
-            return_raw_preds=True,
+            run_config,
         )
         preds_bd.append(pred)


@@ -1,4 +1,5 @@
 import warnings
+from typing import Optional, Tuple

 import librosa
 import numpy as np
@@ -7,6 +8,11 @@ import torch

 from . import wavfile

+__all__ = [
+    "load_audio_file",
+]
+

 def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
     nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
     noverlap = np.floor(fft_overlap * nfft)
@@ -105,40 +111,65 @@ def generate_spectrogram(
 def load_audio_file(
-    audio_file,
-    time_exp_fact,
-    target_samp_rate,
-    scale=False,
-    max_duration=False,
+    audio_file: str,
+    time_exp_fact: float,
+    target_samp_rate: int,
+    scale: bool = False,
+    max_duration: Optional[float] = None,
 ):
+    """Load an audio file and resample it to the target sampling rate.
+
+    The audio is also scaled to [-1, 1] and clipped to the maximum duration.
+    Only mono files are supported.
+
+    Args:
+        audio_file (str): Path to the audio file.
+        target_samp_rate (int): Target sampling rate.
+        scale (bool): Whether to scale the audio to [-1, 1].
+        max_duration (float): Maximum duration of the audio in seconds.
+
+    Returns:
+        sampling_rate: The sampling rate of the audio.
+        audio_raw: The audio signal in a numpy array.
+
+    Raises:
+        ValueError: If the audio file is stereo.
+    """
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)

         # sampling_rate, audio_raw = wavfile.read(audio_file)
-        audio_raw, sampling_rate = librosa.load(audio_file, sr=None)
+        audio_raw, sampling_rate = librosa.load(
+            audio_file,
+            sr=None,
+            dtype=np.float32,
+        )

     if len(audio_raw.shape) > 1:
-        raise Exception("Currently does not handle stereo files")
+        raise ValueError("Currently does not handle stereo files")
+
     sampling_rate = sampling_rate * time_exp_fact

     # resample - need to do this after correcting for time expansion
     sampling_rate_old = sampling_rate
     sampling_rate = target_samp_rate
-    audio_raw = librosa.resample(
-        audio_raw,
-        orig_sr=sampling_rate_old,
-        target_sr=sampling_rate,
-        res_type="polyphase",
-    )
+    if sampling_rate_old != sampling_rate:
+        audio_raw = librosa.resample(
+            audio_raw,
+            orig_sr=sampling_rate_old,
+            target_sr=sampling_rate,
+            res_type="polyphase",
+        )

     # clipping maximum duration
-    if max_duration is not False:
+    if max_duration is not None:
         max_duration = np.minimum(
-            int(sampling_rate * max_duration), audio_raw.shape[0]
+            int(sampling_rate * max_duration),
+            audio_raw.shape[0],
         )
         audio_raw = audio_raw[:max_duration]

-    # convert to float32 and scale
-    audio_raw = audio_raw.astype(np.float32)
+    # scale to [-1, 1]
     if scale:
         audio_raw = audio_raw - audio_raw.mean()
         audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
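With the new signature, max_duration defaults to None rather than False, and resampling is skipped when the file is already at the target rate. A small usage sketch (not part of the commit); the path and target rate are placeholders:

    sampling_rate, audio = load_audio_file(
        "recording.wav",
        time_exp_fact=1.0,
        target_samp_rate=256000,
        scale=False,
        max_duration=None,  # keep the whole file
    )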


@@ -1,6 +1,6 @@
 import json
 import os
-from typing import List, Tuple
+from typing import Any, Iterator, List, Optional, Tuple, Union

 import numpy as np
 import pandas as pd
@@ -24,7 +24,7 @@ DEFAULT_MODEL_PATH = os.path.join(
     "model.pth",
 )

-__all__ = ["load_model", "DEFAULT_MODEL_PATH"]
+__all__ = ["load_model", "get_audio_files", "DEFAULT_MODEL_PATH"]


 def get_default_bd_args():
@@ -69,16 +69,29 @@ class ModelParameters(TypedDict):
     """Model parameters."""

     model_name: str
+    """Model name."""
+
     num_filters: int
+    """Number of filters."""
+
     emb_dim: int
+    """Embedding dimension."""
+
     ip_height: int
+    """Input height in pixels."""
+
     resize_factor: int
+    """Resize factor."""
+
     class_names: List[str]
+    """Class names. The model is trained to detect these classes."""
+
     device: torch.device


 def load_model(
-    model_path: str = DEFAULT_MODEL_PATH, load_weights: bool = True
+    model_path: str = DEFAULT_MODEL_PATH,
+    load_weights: bool = True,
 ) -> Tuple[torch.nn.Module, ModelParameters]:
     """Load model from file.
@@ -141,8 +154,7 @@ def load_model(
     return model, params


-def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
+def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
     predictions_m = {}

     num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
@@ -157,82 +169,255 @@ def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
     if len(spec_feats) > 0:
         spec_feats = np.vstack(spec_feats)

     if len(cnn_feats) > 0:
         cnn_feats = np.vstack(cnn_feats)

     return predictions_m, spec_feats, cnn_feats, spec_slices


+class Annotation(TypedDict("WithClass", {"class": str})):
+    """Format of annotations.
+
+    This is the format of a single annotation as expected by the annotation
+    tool.
+    """
+
+    start_time: float
+    """Start time in seconds."""
+
+    end_time: float
+    """End time in seconds."""
+
+    low_freq: int
+    """Low frequency in Hz."""
+
+    high_freq: int
+    """High frequency in Hz."""
+
+    class_prob: float
+    """Probability of class assignment."""
+
+    det_prob: float
+    """Probability of detection."""
+
+    individual: str
+    """Individual ID."""
+
+    event: str
+    """Type of detected event."""
+
+
+class FileAnnotations(TypedDict):
+    """Format of results.
+
+    This is the format of the results expected by the annotation tool.
+    """
+
+    file_id: str
+    """File ID."""
+
+    annotated: bool
+    """Whether file has been annotated."""
+
+    duration: float
+    """Duration of audio file."""
+
+    issues: bool
+    """Whether file has issues."""
+
+    time_exp: float
+    """Time expansion factor."""
+
+    class_name: str
+    """Class predicted at file level"""
+
+    annotation: List[Annotation]
+
+
+class Results(TypedDict):
+    pred_dict: FileAnnotations
+    """Predictions in the format expected by the annotation tool."""
+
+    spec_feats: Optional[np.ndarray]
+    """Spectrogram features."""
+
+    spec_feat_names: Optional[List[str]]
+    """Spectrogram feature names."""
+
+    cnn_feats: Optional[np.ndarray]
+    """CNN features."""
+
+    cnn_feat_names: Optional[List[str]]
+    """CNN feature names."""
+
+    spec_slices: Optional[np.ndarray]
+    """Spectrogram slices."""
+
+
+class ResultParams(TypedDict):
+    """Result parameters."""
+
+    class_names: List[str]
+    """Class names."""
+
+
+def format_results(
+    file_id: str,
+    time_exp: float,
+    duration: float,
+    predictions,
+    class_names: List[str],
+) -> FileAnnotations:
+    """Format results into the format expected by the annotation tool.
+
+    Args:
+        file_id (str): File ID.
+        time_exp (float): Time expansion factor.
+        duration (float): Duration of audio file.
+        predictions (dict): Predictions.
+
+    Returns:
+        dict: Results in the format expected by the annotation tool.
+    """
+    # Get a single class prediction for the file
+    class_overall = pp.overall_class_pred(
+        predictions["det_probs"],
+        predictions["class_probs"],
+    )
+
+    # Get the best class prediction probability and index for each detection
+    class_prob_best = predictions["class_probs"].max(0)
+    class_ind_best = predictions["class_probs"].argmax(0)
+
+    # Pack the results into a list of dictionaries
+    annotations: List[Annotation] = [
+        {
+            "start_time": round(float(start_time), 4),
+            "end_time": round(end_time, 4),
+            "low_freq": int(low_freq),
+            "high_freq": int(high_freq),
+            "class": str(class_names[class_index]),
+            "class_prob": round(float(class_prob), 3),
+            "det_prob": round(float(det_prob), 3),
+            "individual": "-1",
+            "event": "Echolocation",
+        }
+        for (
+            start_time,
+            end_time,
+            low_freq,
+            high_freq,
+            class_index,
+            class_prob,
+            det_prob,
+        ) in zip(
+            predictions["start_time"],
+            predictions["end_times"],
+            predictions["low_freqs"],
+            predictions["high_freqs"],
+            class_ind_best,
+            class_prob_best,
+            predictions["det_probs"],
+        )
+    ]
+
+    return {
+        "id": file_id,
+        "annotated": False,
+        "issues": False,
+        "notes": "Automatically generated.",
+        "time_exp": time_exp,
+        "duration": round(duration, 4),
+        "annotation": annotations,
+        "class_name": class_names[np.argmax(class_overall)],
+    }
+
+
 def convert_results(
-    file_id,
-    time_exp,
-    duration,
-    params,
+    file_id: str,
+    time_exp: float,
+    duration: float,
+    params: ResultParams,
     predictions,
     spec_feats,
     cnn_feats,
     spec_slices,
-):
-    # create a single dictionary - this is the format used by the annotation tool
-    pred_dict = {}
-    pred_dict["id"] = file_id
-    pred_dict["annotated"] = False
-    pred_dict["issues"] = False
-    pred_dict["notes"] = "Automatically generated."
-    pred_dict["time_exp"] = time_exp
-    pred_dict["duration"] = round(duration, 4)
-    pred_dict["annotation"] = []
-
-    class_prob_best = predictions["class_probs"].max(0)
-    class_ind_best = predictions["class_probs"].argmax(0)
-    class_overall = pp.overall_class_pred(
-        predictions["det_probs"], predictions["class_probs"]
-    )
-    pred_dict["class_name"] = params["class_names"][np.argmax(class_overall)]
-
-    for ii in range(predictions["det_probs"].shape[0]):
-        res = {}
-        res["start_time"] = round(float(predictions["start_times"][ii]), 4)
-        res["end_time"] = round(float(predictions["end_times"][ii]), 4)
-        res["low_freq"] = int(predictions["low_freqs"][ii])
-        res["high_freq"] = int(predictions["high_freqs"][ii])
-        res["class"] = str(params["class_names"][int(class_ind_best[ii])])
-        res["class_prob"] = round(float(class_prob_best[ii]), 3)
-        res["det_prob"] = round(float(predictions["det_probs"][ii]), 3)
-        res["individual"] = "-1"
-        res["event"] = "Echolocation"
-        pred_dict["annotation"].append(res)
+) -> Results:
+    """Convert results to dictionary as expected by the annotation tool.
+
+    Args:
+        file_id (str): File ID.
+        time_exp (float): Time expansion factor.
+        duration (float): Duration of audio file.
+        params (dict): Model parameters.
+        predictions (dict): Predictions.
+        spec_feats (np.ndarray): Spectral features.
+        cnn_feats (np.ndarray): CNN features.
+        spec_slices (list): Spectrogram slices.
+
+    Returns:
+        dict: Dictionary with results.
+    """
+    pred_dict = format_results(
+        file_id,
+        time_exp,
+        duration,
+        predictions,
+        params["class_names"],
+    )

     # combine into final results dictionary
     results = {}
     results["pred_dict"] = pred_dict

+    # add spectrogram features if they exist
     if len(spec_feats) > 0:
         results["spec_feats"] = spec_feats
         results["spec_feat_names"] = feats.get_feature_names()

+    # add CNN features if they exist
     if len(cnn_feats) > 0:
         results["cnn_feats"] = cnn_feats
         results["cnn_feat_names"] = [
             str(ii) for ii in range(cnn_feats.shape[1])
         ]

+    # add spectrogram slices if they exist
     if len(spec_slices) > 0:
         results["spec_slices"] = spec_slices

     return results


-def save_results_to_file(results, op_path):
+def save_results_to_file(results, op_path: str) -> None:
+    """Save results to file.
+
+    Args:
+        results (dict): Results.
+        op_path (str): Output path.
+    """
     # make directory if it does not exist
     if not os.path.isdir(os.path.dirname(op_path)):
         os.makedirs(os.path.dirname(op_path))

     # save csv file - if there are predictions
-    result_list = [res for res in results["pred_dict"]["annotation"]]
-    df = pd.DataFrame(result_list)
-    df["file_name"] = [results["pred_dict"]["id"]] * len(result_list)
-    df.index.name = "id"
-    if "class_prob" in df.columns:
-        df = df[
+    result_list = results["pred_dict"]["annotation"]
+    results_df = pd.DataFrame(result_list)
+
+    # add file name as a column
+    results_df["file_name"] = results["pred_dict"]["id"]
+
+    # rename index column
+    results_df.index.name = "id"
+
+    # create a csv file with predicted events
+    if "class_prob" in results_df.columns:
+        preds_df = results_df[
             [
                 "det_prob",
                 "start_time",
@@ -243,14 +428,14 @@ def save_results_to_file(results, op_path):
                 "class_prob",
             ]
         ]
-        df.to_csv(op_path + ".csv", sep=",")
+        preds_df.to_csv(op_path + ".csv", sep=",")

-    # save features
     if "spec_feats" in results.keys():
-        df = pd.DataFrame(
+        # create csv file with spectrogram features
+        spec_feats_df = pd.DataFrame(
             results["spec_feats"], columns=results["spec_feat_names"]
         )
-        df.to_csv(
+        spec_feats_df.to_csv(
             op_path + "_spec_features.csv",
             sep=",",
             index=False,
@@ -258,10 +443,12 @@ def save_results_to_file(results, op_path):
         )

     if "cnn_feats" in results.keys():
-        df = pd.DataFrame(
-            results["cnn_feats"], columns=results["cnn_feat_names"]
+        # create csv file with cnn extracted features
+        cnn_feats_df = pd.DataFrame(
+            results["cnn_feats"],
+            columns=results["cnn_feat_names"],
         )
-        df.to_csv(
+        cnn_feats_df.to_csv(
             op_path + "_cnn_features.csv",
             sep=",",
             index=False,
@@ -269,11 +456,71 @@ def save_results_to_file(results, op_path):
         )

     # save json file
-    with open(op_path + ".json", "w") as da:
-        json.dump(results["pred_dict"], da, indent=2, sort_keys=True)
+    with open(op_path + ".json", "w", encoding="utf-8") as jsonfile:
+        json.dump(results["pred_dict"], jsonfile, indent=2, sort_keys=True)


-def compute_spectrogram(audio, sampling_rate, params, return_np=False):
+class SpectrogramParameters(TypedDict):
+    """Parameters for generating spectrograms."""
+
+    fft_win_length: int
+    """Length of the FFT window in samples."""
+
+    fft_overlap: int
+    """Number of samples to overlap between FFT windows."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    spec_width: int
+    """Width of the spectrogram in pixels."""
+
+    resize_factor: int
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: int
+    """Factor to divide the spectrogram by."""
+
+    device: torch.device
+    """Device to store the spectrogram on."""
+
+
+def compute_spectrogram(
+    audio: np.ndarray,
+    sampling_rate: int,
+    params: SpectrogramParameters,
+    return_np: bool = False,
+) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
+    """Compute a spectrogram from an audio array.
+
+    Will pad the audio array so that it is evenly divisible by the
+    downsampling factors.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+    sampling_rate : int
+    params : SpectrogramParameters
+        The parameters to use for generating the spectrogram.
+    return_np : bool, optional
+        Whether to return the spectrogram as a numpy array as well as a
+        torch tensor. The default is False.
+
+    Returns
+    -------
+    duration : float
+        The duration of the spectrgram in seconds.
+    spec : torch.Tensor
+        The spectrogram as a torch tensor.
+    spec_np : np.ndarray, optional
+        The spectrogram as a numpy array. Only returned if `return_np` is
+        True, otherwise None.
+    """
     # pad audio so it is evenly divisible by downsampling factors
     duration = audio.shape[0] / float(sampling_rate)
     audio = au.pad_audio(
@@ -290,13 +537,21 @@ def compute_spectrogram(audio, sampling_rate, params, return_np=False):
     # convert to pytorch
     spec = torch.from_numpy(spec).to(params["device"])

+    # add batch and channel dimensions
     spec = spec.unsqueeze(0).unsqueeze(0)

     # resize the spec
-    rs = params["resize_factor"]
-    spec_op_shape = (int(params["spec_height"] * rs), int(spec.shape[-1] * rs))
+    resize_factor = params["resize_factor"]
+    spec_op_shape = (
+        int(params["spec_height"] * resize_factor),
+        int(spec.shape[-1] * resize_factor),
+    )
+
     spec = F.interpolate(
-        spec, size=spec_op_shape, mode="bilinear", align_corners=False
+        spec,
+        size=spec_op_shape,
+        mode="bilinear",
+        align_corners=False,
     )

     if return_np:
@@ -307,135 +562,318 @@ def compute_spectrogram(audio, sampling_rate, params, return_np=False):
     return duration, spec, spec_np


-def process_file(
-    audio_file,
-    model,
-    params,
-    args,
-    time_exp=None,
-    top_n=5,
-    return_raw_preds=False,
-    max_duration=False,
-):
+def iterate_over_chunks(
+    audio: np.ndarray,
+    samplerate: int,
+    chunk_size: float,
+) -> Iterator[Tuple[float, np.ndarray]]:
+    """Iterate over audio in chunks of size chunk_size.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+    samplerate : int
+    chunk_size : float
+        Size of chunks in seconds.
+
+    Yields
+    ------
+    chunk_start : float
+        Start time of chunk in seconds.
+    chunk : np.ndarray
+    """
+    nsamples = audio.shape[0]
+    duration_full = nsamples / samplerate
+    num_chunks = int(np.ceil(duration_full / chunk_size))
+    for chunk_id in range(num_chunks):
+        chunk_start = chunk_size * chunk_id
+        chunk_length = int(samplerate * chunk_size)
+        start_sample = chunk_id * chunk_length
+        end_sample = np.minimum((chunk_id + 1) * chunk_length, nsamples)
+        yield chunk_start, audio[start_sample:end_sample]
+
+
+class ProcessingConfiguration(TypedDict):
+    """Parameters for processing audio files."""
+
+    # audio parameters
+    target_samp_rate: int
+    """Target sampling rate of the audio."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Length of the FFT window in samples."""
+
+    resize_factor: float
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: float
+    """Factor to divide the spectrogram by."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    scale_raw_audio: bool
+    """Whether to scale the raw audio to be between -1 and 1."""
+
+    device: torch.device
+    """Device to run the model on."""
+
+    class_names: List[str]
+    """Names of the classes the model can detect."""
+
+    detection_threshold: float
+    """Threshold for detection probability."""
+
+    time_expansion: Optional[float]
+    """Time expansion factor of the processed recordings."""
+
+    top_n: int
+    """Number of top detections to keep."""
+
+    return_raw_preds: bool
+    """Whether to return raw predictions."""
+
+    max_duration: Optional[float]
+    """Maximum duration of audio file to process in seconds."""
+
+    nms_kernel_size: int
+    """Size of the kernel for non-maximum suppression."""
+
+    max_freq: float
+    """Maximum frequency to consider in Hz."""
+
+    min_freq: float
+    """Minimum frequency to consider in Hz."""
+
+    nms_top_k_per_sec: float
+    """Number of top detections to keep per second."""
+
+    detection_threshold: float
+    """Threshold for detection probability."""
+
+    quiet: bool
+    """Whether to suppress output."""
+
+
+def process_spectrogram(
+    spec: torch.Tensor,
+    samplerate: int,
+    model: torch.nn.Module,
+    config: pp.NonMaximumSuppressionConfig,
+):
+    """Process a spectrogram with detection model.
+
+    Will run non-maximum suppression on the output of the model.
+
+    Parameters
+    ----------
+    spec : torch.Tensor
+    samplerate : int
+    model : torch.nn.Module
+        Detection model.
+    config : pp.NonMaximumSuppressionConfig
+        Parameters for non-maximum suppression.
+
+    Returns
+    -------
+    pred_nms : Dict[str, np.ndarray]
+    features : Dict[str, np.ndarray]
+    """
+    # evaluate model
+    with torch.no_grad():
+        outputs = model(spec, return_feats=config["cnn_features"])
+
+    # run non-max suppression
+    pred_nms, features = pp.run_nms(
+        outputs,
+        config,
+        np.array([float(samplerate)]),
+    )
+    pred_nms = pred_nms[0]
+
+    # if we have a background class
+    if pred_nms["class_probs"].shape[0] > len(config["class_names"]):
+        pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]
+
+    return pred_nms, features
+
+
+def process_audio_array(
+    audio: np.ndarray,
+    sampling_rate: int,
+    model: torch.nn.Module,
+    config: ProcessingConfiguration,
+):
+    """Process a single audio array with detection model.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+    sampling_rate : int
+    model : torch.nn.Module
+        Detection model.
+    config : ProcessingConfiguration
+        Configuration for processing.
+
+    Returns
+    -------
+    pred_nms : Dict[str, np.ndarray]
+    features : Dict[str, np.ndarray]
+    spec_np : np.ndarray
+    """
+    # load audio file and compute spectrogram
+    _, spec, spec_np = compute_spectrogram(
+        audio,
+        sampling_rate,
+        config,
+        return_np=config["spec_features"] or config["spec_slices"],
+    )
+
+    # process spectrogram with model
+    pred_nms, features = process_spectrogram(
+        spec,
+        sampling_rate,
+        model,
+        config,
+    )
+
+    return pred_nms, features, spec_np
+
+
+def process_file(
+    audio_file: str,
+    model: torch.nn.Module,
+    config: ProcessingConfiguration,
+) -> Union[Results, Any]:
+    """Process a single audio file with detection model.
+
+    Will split the audio file into chunks if it is too long and
+    process each chunk separately.
+
+    Parameters
+    ----------
+    audio_file : str
+        Path to audio file.
+    model : torch.nn.Module
+        Detection model.
+    config : ProcessingConfiguration
+        Configuration for processing.
+
+    Returns
+    -------
+    results : Results or Any
+        Results of processing audio file with the given detection model.
+        Will be a dictionary if `config["return_raw_preds"]` is `True`,
+    """
     # store temporary results here
     predictions = []
     spec_feats = []
     cnn_feats = []
     spec_slices = []

-    # get time expansion factor
-    if time_exp is None:
-        time_exp = args["time_expansion_factor"]
-
-    params["detection_threshold"] = args["detection_threshold"]
-
     # load audio file
     sampling_rate, audio_full = au.load_audio_file(
         audio_file,
-        time_exp,
-        params["target_samp_rate"],
-        params["scale_raw_audio"],
+        time_exp_fact=config["time_expansion"],
+        target_samp_rate=config["target_samp_rate"],
+        scale=config["scale_raw_audio"],
+        max_duration=config["max_duration"],
     )

-    # clipping maximum duration
-    if max_duration is not False:
-        max_duration = np.minimum(
-            int(sampling_rate * max_duration), audio_full.shape[0]
-        )
-        audio_full = audio_full[:max_duration]
-
-    duration_full = audio_full.shape[0] / float(sampling_rate)
-
-    return_np_spec = args["spec_features"] or args["spec_slices"]
-
     # loop through larger file and split into chunks
-    # TODO fix so that it overlaps correctly and takes care of duplicate detections at borders
-    num_chunks = int(np.ceil(duration_full / args["chunk_size"]))
-    for chunk_id in range(num_chunks):
-
-        # chunk
-        chunk_time = args["chunk_size"] * chunk_id
-        chunk_length = int(sampling_rate * args["chunk_size"])
-        start_sample = chunk_id * chunk_length
-        end_sample = np.minimum(
-            (chunk_id + 1) * chunk_length, audio_full.shape[0]
-        )
-        audio = audio_full[start_sample:end_sample]
-
-        # load audio file and compute spectrogram
-        duration, spec, spec_np = compute_spectrogram(
-            audio, sampling_rate, params, return_np_spec
-        )
-
-        # evaluate model
-        with torch.no_grad():
-            outputs = model(spec, return_feats=args["cnn_features"])
-
-        # run non-max suppression
-        pred_nms, features = pp.run_nms(
-            outputs, params, np.array([float(sampling_rate)])
-        )
-        pred_nms = pred_nms[0]
+    # TODO fix so that it overlaps correctly and takes care of
+    # duplicate detections at borders
+    for chunk_time, audio in iterate_over_chunks(
+        audio_full,
+        sampling_rate,
+        config["chunk_size"],
+    ):
+        # Run detection model on chunk
+        pred_nms, features, spec_np = process_audio_array(
+            audio,
+            sampling_rate,
+            model,
+            config,
+        )

+        # add chunk time to start and end times
         pred_nms["start_times"] += chunk_time
         pred_nms["end_times"] += chunk_time

-        # if we have a background class
-        if pred_nms["class_probs"].shape[0] > len(params["class_names"]):
-            pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]
-
         predictions.append(pred_nms)

         # extract features - if there are any calls detected
         if pred_nms["det_probs"].shape[0] > 0:
-            if args["spec_features"]:
-                spec_feats.append(feats.get_feats(spec_np, pred_nms, params))
+            if config["spec_features"]:
+                spec_feats.append(feats.get_feats(spec_np, pred_nms, config))

-            if args["cnn_features"]:
+            if config["cnn_features"]:
                 cnn_feats.append(features[0])

-            if args["spec_slices"]:
+            if config["spec_slices"]:
                 spec_slices.extend(
-                    feats.extract_spec_slices(spec_np, pred_nms, params)
+                    feats.extract_spec_slices(spec_np, pred_nms, config)
                 )

-    # convert the predictions into output dictionary
-    file_id = os.path.basename(audio_file)
-    predictions, spec_feats, cnn_feats, spec_slices = merge_results(
-        predictions, spec_feats, cnn_feats, spec_slices
-    )
-
-    results = convert_results(
-        file_id,
-        time_exp,
-        duration_full,
-        params,
+    # Merge results from chunks
+    predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
         predictions,
         spec_feats,
         cnn_feats,
         spec_slices,
     )

+    # convert results to a dictionary in the right format
+    results = convert_results(
+        file_id=os.path.basename(audio_file),
+        time_exp=config["time_expansion"],
+        duration=audio_full.shape[0] / float(sampling_rate),
+        params=config,
+        predictions=predictions,
+        spec_feats=spec_feats,
+        cnn_feats=cnn_feats,
+        spec_slices=spec_slices,
+    )
+
     # summarize results
-    if not args["quiet"]:
-        num_detections = len(results["pred_dict"]["annotation"])
-        print(
-            "{}".format(num_detections)
-            + " call(s) detected above the threshold."
-        )
+    if not config["quiet"]:
+        summarize_results(results, predictions, config)
+
+    if config["return_raw_preds"]:
+        return predictions
+
+    return results
+
+
+def summarize_results(results, predictions, config):
+    """Print summary of results."""
+    num_detections = len(results["pred_dict"]["annotation"])
+    print(f"{num_detections} call(s) detected above the threshold.")

     # print results for top n classes
-    if not args["quiet"] and (num_detections > 0):
+    if num_detections > 0:
         class_overall = pp.overall_class_pred(
-            predictions["det_probs"], predictions["class_probs"]
+            predictions["det_probs"],
+            predictions["class_probs"],
         )
         print("species name".ljust(30) + "probablity present")

-        for cc in np.argsort(class_overall)[::-1][:top_n]:
-            print(
-                params["class_names"][cc].ljust(30)
-                + str(round(class_overall[cc], 3))
-            )
-
-    if return_raw_preds:
-        return predictions
-    else:
-        return results
+        for class_index in np.argsort(class_overall)[::-1][: config["top_n"]]:
+            print(
+                config["class_names"][class_index].ljust(30)
+                + str(round(class_overall[class_index], 3))
+            )
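The refactor splits the old monolithic process_file into composable pieces. A sketch (not part of the commit) of how the new helpers fit together for an in-memory signal, mirroring the loop above; variable names are illustrative:

    predictions = []
    for chunk_start, chunk in iterate_over_chunks(audio, sampling_rate, config["chunk_size"]):
        pred_nms, features, spec_np = process_audio_array(chunk, sampling_rate, model, config)
        pred_nms["start_times"] += chunk_start
        pred_nms["end_times"] += chunk_start
        predictions.append(pred_nms)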


@@ -5,7 +5,6 @@ import bat_detect.utils.detector_utils as du
 def main(args):
     print("Loading model: " + args["model_path"])
     model, params = du.load_model(args["model_path"])


@@ -136,8 +136,13 @@ if __name__ == "__main__":
         audio, sampling_rate, params_bd, True, False
     )

+    run_config = {
+        **params_bd,
+        **bd_args,
+    }
+
     # run model and filter detections so only keep ones in relevant time range
-    results = du.process_file(args_cmd["audio_file"], model, params_bd, bd_args)
+    results = du.process_file(args_cmd["audio_file"], model, run_config)

     pred_anns = filter_anns(
         results["pred_dict"]["annotation"],
         args_cmd["start_time"],


@@ -122,7 +122,12 @@ if __name__ == "__main__":
     print(" Loading model and running detector on entire file ...")
     model, det_params = du.load_model(args_cmd["model_path"])
     det_params["detection_threshold"] = args["detection_threshold"]
-    results = du.process_file(audio_file, model, det_params, args)
+
+    run_config = {
+        **det_params,
+        **args,
+    }
+
+    results = du.process_file(audio_file, model, run_config)

     print(" Processing detections and plotting ...")
     detections = []