Refactored detector_utils module

Santiago Martinez 2023-02-22 20:24:43 +00:00
parent 7550c6faf1
commit 8da98b5258
9 changed files with 716 additions and 187 deletions

app.py
View File

@@ -1,5 +1,3 @@
import os
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
@@ -37,7 +35,6 @@ examples = [
def make_prediction(file_name=None, detection_threshold=0.3):
if file_name is not None:
audio_file = file_name
else:
@@ -46,9 +43,17 @@ def make_prediction(file_name=None, detection_threshold=0.3):
if detection_threshold is not None and detection_threshold != "":
args["detection_threshold"] = float(detection_threshold)
run_config = {
**params,
**args,
"max_duration": max_duration,
}
# process the file to generate predictions
results = du.process_file(
audio_file, model, params, args, max_duration=max_duration
audio_file,
model,
run_config,
)
anns = [ann for ann in results["pred_dict"]["annotation"]]
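A note on the merge semantics this change relies on: with dict unpacking, later entries override earlier ones, so the runtime arguments in args take precedence over the model defaults in params. A minimal, self-contained illustration (keys and values are made up):

# Illustrative only: later unpacked dicts win on key collisions.
params = {"detection_threshold": 0.5, "target_samp_rate": 256000}
args = {"detection_threshold": 0.3}
run_config = {**params, **args, "max_duration": 2.0}
assert run_config["detection_threshold"] == 0.3  # args wins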

View File

@@ -1,3 +1,9 @@
"""Main script for running BatDetect2 on audio files.
Example usage:
python command.py /path/to/audio/ /path/to/ann/ 0.1
"""
import argparse
import os
@@ -7,7 +13,6 @@ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
def parse_args():
info_str = (
"\nBatDetect2 - Detection and Classification\n"
+ " Assumes audio files are mono, not stereo.\n"
@@ -88,14 +93,22 @@ def main():
print("\nInput directory: " + args["audio_dir"])
files = du.get_audio_files(args["audio_dir"])
print("Number of audio files: {}".format(len(files)))
print(f"Number of audio files: {len(files)}")
print("\nSaving results to: " + args["ann_dir"])
# set up run config
run_config = {
**args,
**params,
}
# process files
error_files = []
for ii, audio_file in enumerate(files):
for audio_file in files:
try:
results = du.process_file(audio_file, model, params, args)
results = du.process_file(audio_file, model, run_config)
if args["save_preds_if_empty"] or (
len(results["pred_dict"]["annotation"]) > 0
):
@@ -103,9 +116,9 @@ def main():
args["audio_dir"], args["ann_dir"]
)
du.save_results_to_file(results, results_path)
except:
except (RuntimeError, ValueError, LookupError) as err:
error_files.append(audio_file)
print("Error processing file!")
print(f"Error processing file!: {err}")
print("\nResults saved to: " + args["ann_dir"])

View File

@@ -1,7 +1,11 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict
np.seterr(divide="ignore", invalid="ignore")
@@ -18,7 +22,33 @@ def overall_class_pred(det_prob, class_prob):
return weighted_pred / weighted_pred.sum()
def run_nms(outputs, params, sampling_rate):
class NonMaximumSuppressionConfig(TypedDict):
"""Configuration for non-maximum suppression."""
nms_kernel_size: int
"""Size of the kernel for non-maximum suppression."""
max_freq: float
"""Maximum frequency to consider in Hz."""
min_freq: float
"""Minimum frequency to consider in Hz."""
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Overlap of the FFT windows in seconds."""
nms_top_k_per_sec: float
"""Number of top detections to keep per second."""
detection_threshold: float
"""Threshold for detection probability."""
def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
"""Run non-maximum suppression on the output of the model."""
pred_det = outputs["pred_det"] # probability of box
pred_size = outputs["pred_size"] # box size
@@ -92,6 +122,7 @@ def run_nms(outputs, params, sampling_rate):
# convert to numpy
for kk in pred.keys():
pred[kk] = pred[kk].cpu().numpy().astype(np.float32)
preds.append(pred)
return preds, feats
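The core of run_nms is a local-maximum filter over the detection heatmap. A self-contained sketch of that idea (not the actual implementation) using max pooling:

import torch
import torch.nn.functional as F

def nms_heatmap_sketch(heatmap: torch.Tensor, kernel_size: int = 9) -> torch.Tensor:
    # heatmap is assumed to be (batch, channel, height, width).
    # A cell survives only if it equals the maximum of its
    # (kernel_size x kernel_size) neighbourhood; all else is zeroed.
    pad = (kernel_size - 1) // 2
    local_max = F.max_pool2d(heatmap, kernel_size, stride=1, padding=pad)
    return heatmap * (heatmap == local_max).float()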

View File

@@ -6,14 +6,12 @@ import argparse
import copy
import json
import os
import sys
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
sys.path.append("../../")
import bat_detect.detector.parameters as parameters
from bat_detect.detector import parameters
import bat_detect.train.evaluate as evl
import bat_detect.train.train_utils as tu
import bat_detect.utils.detector_utils as du
@@ -749,14 +747,18 @@ if __name__ == "__main__":
print("Warning: Class names are not the same as the trained model")
assert False
run_config = {
**bd_args,
**params_bd,
"return_raw_preds": True,
}
preds_bd = []
for ii, gg in enumerate(gt_test):
pred = du.process_file(
gg["file_path"],
model,
params_bd,
bd_args,
return_raw_preds=True,
run_config,
)
preds_bd.append(pred)

View File

@@ -1,4 +1,5 @@
import warnings
from typing import Optional, Tuple
import librosa
import numpy as np
@@ -7,6 +8,11 @@ import torch
from . import wavfile
__all__ = [
"load_audio_file",
]
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
noverlap = np.floor(fft_overlap * nfft)
@@ -105,24 +111,49 @@ def generate_spectrogram(
def load_audio_file(
audio_file,
time_exp_fact,
target_samp_rate,
scale=False,
max_duration=False,
audio_file: str,
time_exp_fact: float,
target_samp_rate: int,
scale: bool = False,
max_duration: Optional[float] = None,
):
"""Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
Only mono files are supported.
Args:
audio_file (str): Path to the audio file.
time_exp_fact (float): Time expansion factor of the recording.
target_samp_rate (int): Target sampling rate.
scale (bool): Whether to scale the audio to [-1, 1].
max_duration (float, optional): Maximum duration of the audio in seconds. If None, the audio is not clipped.
Returns:
sampling_rate: The sampling rate of the audio.
audio_raw: The audio signal in a numpy array.
Raises:
ValueError: If the audio file is stereo.
"""
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
# sampling_rate, audio_raw = wavfile.read(audio_file)
audio_raw, sampling_rate = librosa.load(audio_file, sr=None)
audio_raw, sampling_rate = librosa.load(
audio_file,
sr=None,
dtype=np.float32,
)
if len(audio_raw.shape) > 1:
raise Exception("Currently does not handle stereo files")
raise ValueError("Currently does not handle stereo files")
sampling_rate = sampling_rate * time_exp_fact
# resample - need to do this after correcting for time expansion
sampling_rate_old = sampling_rate
sampling_rate = target_samp_rate
if sampling_rate_old != sampling_rate:
audio_raw = librosa.resample(
audio_raw,
orig_sr=sampling_rate_old,
@@ -131,14 +162,14 @@ def load_audio_file(
)
# clipping maximum duration
if max_duration is not False:
if max_duration is not None:
max_duration = np.minimum(
int(sampling_rate * max_duration), audio_raw.shape[0]
int(sampling_rate * max_duration),
audio_raw.shape[0],
)
audio_raw = audio_raw[:max_duration]
# convert to float32 and scale
audio_raw = audio_raw.astype(np.float32)
# scale to [-1, 1]
if scale:
audio_raw = audio_raw - audio_raw.mean()
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
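A hedged usage sketch of the updated loader; the file path is a placeholder, and note that the "no clipping" sentinel changed from False to None:

sampling_rate, audio = load_audio_file(
    "example.wav",            # placeholder path
    time_exp_fact=1.0,        # recording is not time-expanded
    target_samp_rate=256000,
    scale=True,               # normalise to [-1, 1]
    max_duration=None,        # None now means "do not clip" (was False)
)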

View File

@@ -1,6 +1,6 @@
import json
import os
from typing import List, Tuple
from typing import Any, Iterator, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
@@ -24,7 +24,7 @@ DEFAULT_MODEL_PATH = os.path.join(
"model.pth",
)
__all__ = ["load_model", "DEFAULT_MODEL_PATH"]
__all__ = ["load_model", "get_audio_files", "DEFAULT_MODEL_PATH"]
def get_default_bd_args():
@@ -69,16 +69,29 @@ class ModelParameters(TypedDict):
"""Model parameters."""
model_name: str
"""Model name."""
num_filters: int
"""Number of filters."""
emb_dim: int
"""Embedding dimension."""
ip_height: int
"""Input height in pixels."""
resize_factor: int
"""Resize factor."""
class_names: List[str]
"""Class names. The model is trained to detect these classes."""
device: torch.device
def load_model(
model_path: str = DEFAULT_MODEL_PATH, load_weights: bool = True
model_path: str = DEFAULT_MODEL_PATH,
load_weights: bool = True,
) -> Tuple[torch.nn.Module, ModelParameters]:
"""Load model from file.
@@ -141,8 +154,7 @@ def load_model(
return model, params
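A minimal usage sketch of the loader (the printed fields are illustrative):

model, params = load_model()    # defaults to DEFAULT_MODEL_PATH
print(params["class_names"])    # classes the model was trained to detect
print(params["device"])         # device the model was loaded onto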
def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
predictions_m = {}
num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
@@ -157,82 +169,255 @@ def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
if len(spec_feats) > 0:
spec_feats = np.vstack(spec_feats)
if len(cnn_feats) > 0:
cnn_feats = np.vstack(cnn_feats)
return predictions_m, spec_feats, cnn_feats, spec_slices
class Annotation(TypedDict("WithClass", {"class": str})):
"""Format of annotations.
This is the format of a single annotation as expected by the annotation
tool.
"""
start_time: float
"""Start time in seconds."""
end_time: float
"""End time in seconds."""
low_freq: int
"""Low frequency in Hz."""
high_freq: int
"""High frequency in Hz."""
class_prob: float
"""Probability of class assignment."""
det_prob: float
"""Probability of detection."""
individual: str
"""Individual ID."""
event: str
"""Type of detected event."""
class FileAnnotations(TypedDict):
"""Format of results.
This is the format of the results expected by the annotation tool.
"""
id: str
"""File ID."""
annotated: bool
"""Whether file has been annotated."""
duration: float
"""Duration of audio file."""
issues: bool
"""Whether file has issues."""
time_exp: float
"""Time expansion factor."""
class_name: str
"""Class predicted at file level"""
annotation: List[Annotation]
class Results(TypedDict):
pred_dict: FileAnnotations
"""Predictions in the format expected by the annotation tool."""
spec_feats: Optional[np.ndarray]
"""Spectrogram features."""
spec_feat_names: Optional[List[str]]
"""Spectrogram feature names."""
cnn_feats: Optional[np.ndarray]
"""CNN features."""
cnn_feat_names: Optional[List[str]]
"""CNN feature names."""
spec_slices: Optional[np.ndarray]
"""Spectrogram slices."""
class ResultParams(TypedDict):
"""Result parameters."""
class_names: List[str]
"""Class names."""
def format_results(
file_id: str,
time_exp: float,
duration: float,
predictions,
class_names: List[str],
) -> FileAnnotations:
"""Format results into the format expected by the annotation tool.
Args:
file_id (str): File ID.
time_exp (float): Time expansion factor.
duration (float): Duration of audio file.
predictions (dict): Predictions.
class_names (List[str]): Names of the classes the model can detect.
Returns:
dict: Results in the format expected by the annotation tool.
"""
# Get a single class prediction for the file
class_overall = pp.overall_class_pred(
predictions["det_probs"],
predictions["class_probs"],
)
# Get the best class prediction probability and index for each detection
class_prob_best = predictions["class_probs"].max(0)
class_ind_best = predictions["class_probs"].argmax(0)
# Pack the results into a list of dictionaries
annotations: List[Annotation] = [
{
"start_time": round(float(start_time), 4),
"end_time": round(end_time, 4),
"low_freq": int(low_freq),
"high_freq": int(high_freq),
"class": str(class_names[class_index]),
"class_prob": round(float(class_prob), 3),
"det_prob": round(float(det_prob), 3),
"individual": "-1",
"event": "Echolocation",
}
for (
start_time,
end_time,
low_freq,
high_freq,
class_index,
class_prob,
det_prob,
) in zip(
predictions["start_time"],
predictions["end_times"],
predictions["low_freqs"],
predictions["high_freqs"],
class_ind_best,
class_prob_best,
predictions["det_probs"],
)
]
return {
"id": file_id,
"annotated": False,
"issues": False,
"notes": "Automatically generated.",
"time_exp": time_exp,
"duration": round(duration, 4),
"annotation": annotations,
"class_name": class_names[np.argmax(class_overall)],
}
def convert_results(
file_id,
time_exp,
duration,
params,
file_id: str,
time_exp: float,
duration: float,
params: ResultParams,
predictions,
spec_feats,
cnn_feats,
spec_slices,
):
) -> Results:
"""Convert results to dictionary as expected by the annotation tool.
# create a single dictionary - this is the format used by the annotation tool
pred_dict = {}
pred_dict["id"] = file_id
pred_dict["annotated"] = False
pred_dict["issues"] = False
pred_dict["notes"] = "Automatically generated."
pred_dict["time_exp"] = time_exp
pred_dict["duration"] = round(duration, 4)
pred_dict["annotation"] = []
Args:
file_id (str): File ID.
time_exp (float): Time expansion factor.
duration (float): Duration of audio file.
params (dict): Model parameters.
predictions (dict): Predictions.
spec_feats (np.ndarray): Spectral features.
cnn_feats (np.ndarray): CNN features.
spec_slices (list): Spectrogram slices.
class_prob_best = predictions["class_probs"].max(0)
class_ind_best = predictions["class_probs"].argmax(0)
class_overall = pp.overall_class_pred(
predictions["det_probs"], predictions["class_probs"]
Returns:
dict: Dictionary with results.
"""
pred_dict = format_results(
file_id,
time_exp,
duration,
predictions,
params["class_names"],
)
pred_dict["class_name"] = params["class_names"][np.argmax(class_overall)]
for ii in range(predictions["det_probs"].shape[0]):
res = {}
res["start_time"] = round(float(predictions["start_times"][ii]), 4)
res["end_time"] = round(float(predictions["end_times"][ii]), 4)
res["low_freq"] = int(predictions["low_freqs"][ii])
res["high_freq"] = int(predictions["high_freqs"][ii])
res["class"] = str(params["class_names"][int(class_ind_best[ii])])
res["class_prob"] = round(float(class_prob_best[ii]), 3)
res["det_prob"] = round(float(predictions["det_probs"][ii]), 3)
res["individual"] = "-1"
res["event"] = "Echolocation"
pred_dict["annotation"].append(res)
# combine into final results dictionary
results = {}
results["pred_dict"] = pred_dict
# add spectrogram features if they exist
if len(spec_feats) > 0:
results["spec_feats"] = spec_feats
results["spec_feat_names"] = feats.get_feature_names()
# add CNN features if they exist
if len(cnn_feats) > 0:
results["cnn_feats"] = cnn_feats
results["cnn_feat_names"] = [
str(ii) for ii in range(cnn_feats.shape[1])
]
# add spectrogram slices if they exist
if len(spec_slices) > 0:
results["spec_slices"] = spec_slices
return results
def save_results_to_file(results, op_path):
def save_results_to_file(results, op_path: str) -> None:
"""Save results to file.
Args:
results (dict): Results.
op_path (str): Output path.
"""
# make directory if it does not exist
if not os.path.isdir(os.path.dirname(op_path)):
os.makedirs(os.path.dirname(op_path))
# save csv file - if there are predictions
result_list = [res for res in results["pred_dict"]["annotation"]]
df = pd.DataFrame(result_list)
df["file_name"] = [results["pred_dict"]["id"]] * len(result_list)
df.index.name = "id"
if "class_prob" in df.columns:
df = df[
result_list = results["pred_dict"]["annotation"]
results_df = pd.DataFrame(result_list)
# add file name as a column
results_df["file_name"] = results["pred_dict"]["id"]
# rename index column
results_df.index.name = "id"
# create a csv file with predicted events
if "class_prob" in results_df.columns:
preds_df = results_df[
[
"det_prob",
"start_time",
@@ -243,14 +428,14 @@ def save_results_to_file(results, op_path):
"class_prob",
]
]
df.to_csv(op_path + ".csv", sep=",")
preds_df.to_csv(op_path + ".csv", sep=",")
# save features
if "spec_feats" in results.keys():
df = pd.DataFrame(
# create csv file with spectrogram features
spec_feats_df = pd.DataFrame(
results["spec_feats"], columns=results["spec_feat_names"]
)
df.to_csv(
spec_feats_df.to_csv(
op_path + "_spec_features.csv",
sep=",",
index=False,
@@ -258,10 +443,12 @@ def save_results_to_file(results, op_path):
)
if "cnn_feats" in results.keys():
df = pd.DataFrame(
results["cnn_feats"], columns=results["cnn_feat_names"]
# create csv file with cnn extracted features
cnn_feats_df = pd.DataFrame(
results["cnn_feats"],
columns=results["cnn_feat_names"],
)
df.to_csv(
cnn_feats_df.to_csv(
op_path + "_cnn_features.csv",
sep=",",
index=False,
@@ -269,11 +456,71 @@ def save_results_to_file(results, op_path):
)
# save json file
with open(op_path + ".json", "w") as da:
json.dump(results["pred_dict"], da, indent=2, sort_keys=True)
with open(op_path + ".json", "w", encoding="utf-8") as jsonfile:
json.dump(results["pred_dict"], jsonfile, indent=2, sort_keys=True)
def compute_spectrogram(audio, sampling_rate, params, return_np=False):
class SpectrogramParameters(TypedDict):
"""Parameters for generating spectrograms."""
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Overlap between FFT windows as a fraction of the window length."""
spec_height: int
"""Height of the spectrogram in pixels."""
spec_width: int
"""Width of the spectrogram in pixels."""
resize_factor: int
"""Factor to resize the spectrogram by."""
spec_divide_factor: int
"""Factor to divide the spectrogram by."""
device: torch.device
"""Device to store the spectrogram on."""
def compute_spectrogram(
audio: np.ndarray,
sampling_rate: int,
params: SpectrogramParameters,
return_np: bool = False,
) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
"""Compute a spectrogram from an audio array.
Will pad the audio array so that it is evenly divisible by the
downsampling factors.
Parameters
----------
audio : np.ndarray
sampling_rate : int
params : SpectrogramParameters
The parameters to use for generating the spectrogram.
return_np : bool, optional
Whether to return the spectrogram as a numpy array as well as a
torch tensor. The default is False.
Returns
-------
duration : float
The duration of the spectrogram in seconds.
spec : torch.Tensor
The spectrogram as a torch tensor.
spec_np : np.ndarray, optional
The spectrogram as a numpy array. Only returned if `return_np` is
True, otherwise None.
"""
# pad audio so it is evenly divisible by downsampling factors
duration = audio.shape[0] / float(sampling_rate)
audio = au.pad_audio(
@@ -290,13 +537,21 @@ def compute_spectrogram(audio, sampling_rate, params, return_np=False):
# convert to pytorch
spec = torch.from_numpy(spec).to(params["device"])
# add batch and channel dimensions
spec = spec.unsqueeze(0).unsqueeze(0)
# resize the spec
rs = params["resize_factor"]
spec_op_shape = (int(params["spec_height"] * rs), int(spec.shape[-1] * rs))
resize_factor = params["resize_factor"]
spec_op_shape = (
int(params["spec_height"] * resize_factor),
int(spec.shape[-1] * resize_factor),
)
spec = F.interpolate(
spec, size=spec_op_shape, mode="bilinear", align_corners=False
spec,
size=spec_op_shape,
mode="bilinear",
align_corners=False,
)
if return_np:
@@ -307,135 +562,318 @@ def compute_spectrogram(audio, sampling_rate, params, return_np=False):
return duration, spec, spec_np
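A hedged usage sketch; audio and sampling_rate are assumed to come from au.load_audio_file, and params is assumed to satisfy SpectrogramParameters:

duration, spec, spec_np = compute_spectrogram(
    audio,
    sampling_rate,
    params,
    return_np=True,
)
print(spec.shape)  # (1, 1, height, width): batch and channel dims were added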
def process_file(
audio_file,
model,
params,
args,
time_exp=None,
top_n=5,
return_raw_preds=False,
max_duration=False,
):
def iterate_over_chunks(
audio: np.ndarray,
samplerate: int,
chunk_size: float,
) -> Iterator[Tuple[float, np.ndarray]]:
"""Iterate over audio in chunks of size chunk_size.
Parameters
----------
audio : np.ndarray
samplerate : int
chunk_size : float
Size of chunks in seconds.
Yields
------
chunk_start : float
Start time of chunk in seconds.
chunk : np.ndarray
"""
nsamples = audio.shape[0]
duration_full = nsamples / samplerate
num_chunks = int(np.ceil(duration_full / chunk_size))
for chunk_id in range(num_chunks):
chunk_start = chunk_size * chunk_id
chunk_length = int(samplerate * chunk_size)
start_sample = chunk_id * chunk_length
end_sample = np.minimum((chunk_id + 1) * chunk_length, nsamples)
yield chunk_start, audio[start_sample:end_sample]
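A small usage sketch (audio is assumed to be a 1-D numpy array; the sampling rate and chunk size are illustrative). Note the final chunk may be shorter than chunk_size:

for chunk_start, chunk in iterate_over_chunks(audio, 256000, chunk_size=3.0):
    print(f"chunk at {chunk_start:6.1f}s: {chunk.shape[0]} samples")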
class ProcessingConfiguration(TypedDict):
"""Parameters for processing audio files."""
# audio parameters
target_samp_rate: int
"""Target sampling rate of the audio."""
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Length of the FFT window in samples."""
resize_factor: float
"""Factor to resize the spectrogram by."""
spec_divide_factor: float
"""Factor to divide the spectrogram by."""
spec_height: int
"""Height of the spectrogram in pixels."""
scale_raw_audio: bool
"""Whether to scale the raw audio to be between -1 and 1."""
device: torch.device
"""Device to run the model on."""
class_names: List[str]
"""Names of the classes the model can detect."""
detection_threshold: float
"""Threshold for detection probability."""
time_expansion: Optional[float]
"""Time expansion factor of the processed recordings."""
top_n: int
"""Number of top detections to keep."""
return_raw_preds: bool
"""Whether to return raw predictions."""
max_duration: Optional[float]
"""Maximum duration of audio file to process in seconds."""
nms_kernel_size: int
"""Size of the kernel for non-maximum suppression."""
max_freq: float
"""Maximum frequency to consider in Hz."""
min_freq: float
"""Minimum frequency to consider in Hz."""
nms_top_k_per_sec: float
"""Number of top detections to keep per second."""
quiet: bool
"""Whether to suppress output."""
def process_spectrogram(
spec: torch.Tensor,
samplerate: int,
model: torch.nn.Module,
config: pp.NonMaximumSuppressionConfig,
):
"""Process a spectrogram with detection model.
Will run non-maximum suppression on the output of the model.
Parameters
----------
spec : torch.Tensor
samplerate : int
model : torch.nn.Module
Detection model.
config : pp.NonMaximumSuppressionConfig
Parameters for non-maximum suppression.
Returns
-------
pred_nms : Dict[str, np.ndarray]
features : Dict[str, np.ndarray]
"""
# evaluate model
with torch.no_grad():
outputs = model(spec, return_feats=config["cnn_features"])
# run non-max suppression
pred_nms, features = pp.run_nms(
outputs,
config,
np.array([float(samplerate)]),
)
pred_nms = pred_nms[0]
# if we have a background class
if pred_nms["class_probs"].shape[0] > len(config["class_names"]):
pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]
return pred_nms, features
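A hedged usage sketch; spec is assumed to be the (1, 1, H, W) tensor produced by compute_spectrogram, and the config passed in must also carry the cnn_features and class_names keys this function reads:

pred_nms, features = process_spectrogram(spec, 256000, model, config)
print(pred_nms["det_probs"].shape[0], "detections in this spectrogram")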
def process_audio_array(
audio: np.ndarray,
sampling_rate: int,
model: torch.nn.Module,
config: ProcessingConfiguration,
):
"""Process a single audio array with detection model.
Parameters
----------
audio : np.ndarray
sampling_rate : int
model : torch.nn.Module
Detection model.
config : ProcessingConfiguration
Configuration for processing.
Returns
-------
pred_nms : Dict[str, np.ndarray]
features : Dict[str, np.ndarray]
spec_np : np.ndarray
"""
# load audio file and compute spectrogram
_, spec, spec_np = compute_spectrogram(
audio,
sampling_rate,
config,
return_np=config["spec_features"] or config["spec_slices"],
)
# process spectrogram with model
pred_nms, features = process_spectrogram(
spec,
sampling_rate,
model,
config,
)
return pred_nms, features, spec_np
def process_file(
audio_file: str,
model: torch.nn.Module,
config: ProcessingConfiguration,
) -> Union[Results, Any]:
"""Process a single audio file with detection model.
Will split the audio file into chunks if it is too long and
process each chunk separately.
Parameters
----------
audio_file : str
Path to audio file.
model : torch.nn.Module
Detection model.
config : ProcessingConfiguration
Configuration for processing.
Returns
-------
results : Results or Any
Results of processing audio file with the given detection model.
Will be the raw merged predictions if `config["return_raw_preds"]` is `True`, otherwise a `Results` dictionary.
"""
# store temporary results here
predictions = []
spec_feats = []
cnn_feats = []
spec_slices = []
# get time expansion factor
if time_exp is None:
time_exp = args["time_expansion_factor"]
params["detection_threshold"] = args["detection_threshold"]
# load audio file
sampling_rate, audio_full = au.load_audio_file(
audio_file,
time_exp,
params["target_samp_rate"],
params["scale_raw_audio"],
time_exp_fact=config["time_expansion"],
target_samp_rate=config["target_samp_rate"],
scale=config["scale_raw_audio"],
max_duration=config["max_duration"],
)
# clipping maximum duration
if max_duration is not False:
max_duration = np.minimum(
int(sampling_rate * max_duration), audio_full.shape[0]
)
audio_full = audio_full[:max_duration]
duration_full = audio_full.shape[0] / float(sampling_rate)
return_np_spec = args["spec_features"] or args["spec_slices"]
# loop through larger file and split into chunks
# TODO fix so that it overlaps correctly and takes care of duplicate detections at borders
num_chunks = int(np.ceil(duration_full / args["chunk_size"]))
for chunk_id in range(num_chunks):
# chunk
chunk_time = args["chunk_size"] * chunk_id
chunk_length = int(sampling_rate * args["chunk_size"])
start_sample = chunk_id * chunk_length
end_sample = np.minimum(
(chunk_id + 1) * chunk_length, audio_full.shape[0]
)
audio = audio_full[start_sample:end_sample]
# load audio file and compute spectrogram
duration, spec, spec_np = compute_spectrogram(
audio, sampling_rate, params, return_np_spec
# TODO fix so that it overlaps correctly and takes care of
# duplicate detections at borders
for chunk_time, audio in iterate_over_chunks(
audio_full,
sampling_rate,
config["chunk_size"],
):
# Run detection model on chunk
pred_nms, features, spec_np = process_audio_array(
audio,
sampling_rate,
model,
config,
)
# evaluate model
with torch.no_grad():
outputs = model(spec, return_feats=args["cnn_features"])
# run non-max suppression
pred_nms, features = pp.run_nms(
outputs, params, np.array([float(sampling_rate)])
)
pred_nms = pred_nms[0]
# add chunk time to start and end times
pred_nms["start_times"] += chunk_time
pred_nms["end_times"] += chunk_time
# if we have a background class
if pred_nms["class_probs"].shape[0] > len(params["class_names"]):
pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]
predictions.append(pred_nms)
# extract features - if there are any calls detected
if pred_nms["det_probs"].shape[0] > 0:
if args["spec_features"]:
spec_feats.append(feats.get_feats(spec_np, pred_nms, params))
if config["spec_features"]:
spec_feats.append(feats.get_feats(spec_np, pred_nms, config))
if args["cnn_features"]:
if config["cnn_features"]:
cnn_feats.append(features[0])
if args["spec_slices"]:
if config["spec_slices"]:
spec_slices.extend(
feats.extract_spec_slices(spec_np, pred_nms, params)
feats.extract_spec_slices(spec_np, pred_nms, config)
)
# convert the predictions into output dictionary
file_id = os.path.basename(audio_file)
predictions, spec_feats, cnn_feats, spec_slices = merge_results(
predictions, spec_feats, cnn_feats, spec_slices
)
results = convert_results(
file_id,
time_exp,
duration_full,
params,
# Merge results from chunks
predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
predictions,
spec_feats,
cnn_feats,
spec_slices,
)
# summarize results
if not args["quiet"]:
num_detections = len(results["pred_dict"]["annotation"])
print(
"{}".format(num_detections)
+ " call(s) detected above the threshold."
# convert results to a dictionary in the right format
results = convert_results(
file_id=os.path.basename(audio_file),
time_exp=config["time_expansion"],
duration=audio_full.shape[0] / float(sampling_rate),
params=config,
predictions=predictions,
spec_feats=spec_feats,
cnn_feats=cnn_feats,
spec_slices=spec_slices,
)
# summarize results
if not config["quiet"]:
summarize_results(results, predictions, config)
if config["return_raw_preds"]:
return predictions
return results
def summarize_results(results, predictions, config):
"""Print summary of results."""
num_detections = len(results["pred_dict"]["annotation"])
print(f"{num_detections} call(s) detected above the threshold.")
# print results for top n classes
if not args["quiet"] and (num_detections > 0):
if num_detections > 0:
class_overall = pp.overall_class_pred(
predictions["det_probs"], predictions["class_probs"]
predictions["det_probs"],
predictions["class_probs"],
)
print("species name".ljust(30) + "probablity present")
for cc in np.argsort(class_overall)[::-1][:top_n]:
print(
params["class_names"][cc].ljust(30)
+ str(round(class_overall[cc], 3))
)
if return_raw_preds:
return predictions
else:
return results
for class_index in np.argsort(class_overall)[::-1][: config["top_n"]]:
print(
config["class_names"][class_index].ljust(30)
+ str(round(class_overall[class_index], 3))
)
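Putting the pieces together, a hedged end-to-end sketch of the refactored API. The audio path is a placeholder, and it is an assumption that get_default_bd_args() supplies every runtime key (chunk_size, quiet, cnn_features, and so on) that ProcessingConfiguration expects:

import bat_detect.utils.detector_utils as du

model, params = du.load_model()
config = {**du.get_default_bd_args(), **params}  # model params take precedence
results = du.process_file("example.wav", model, config)  # placeholder path
for ann in results["pred_dict"]["annotation"]:
    print(ann["class"], ann["det_prob"])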

View File

@@ -5,7 +5,6 @@ import bat_detect.utils.detector_utils as du
def main(args):
print("Loading model: " + args["model_path"])
model, params = du.load_model(args["model_path"])

View File

@@ -136,8 +136,13 @@ if __name__ == "__main__":
audio, sampling_rate, params_bd, True, False
)
run_config = {
**params_bd,
**bd_args,
}
# run model and filter detections to keep only those in the relevant time range
results = du.process_file(args_cmd["audio_file"], model, params_bd, bd_args)
results = du.process_file(args_cmd["audio_file"], model, run_config)
pred_anns = filter_anns(
results["pred_dict"]["annotation"],
args_cmd["start_time"],

View File

@@ -122,7 +122,12 @@ if __name__ == "__main__":
print(" Loading model and running detector on entire file ...")
model, det_params = du.load_model(args_cmd["model_path"])
det_params["detection_threshold"] = args["detection_threshold"]
results = du.process_file(audio_file, model, det_params, args)
run_config = {
**det_params,
**args,
}
results = du.process_file(audio_file, model, run_config)
print(" Processing detections and plotting ...")
detections = []