mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00

Refactored detector_utils module

This commit is contained in:
parent 7550c6faf1
commit 8da98b5258
app.py
@@ -1,5 +1,3 @@
-import os
-
 import gradio as gr
 import matplotlib.pyplot as plt
 import numpy as np

@@ -37,7 +35,6 @@ examples = [
 
 
 def make_prediction(file_name=None, detection_threshold=0.3):
-
     if file_name is not None:
         audio_file = file_name
     else:

@@ -46,9 +43,17 @@ def make_prediction(file_name=None, detection_threshold=0.3):
     if detection_threshold is not None and detection_threshold != "":
         args["detection_threshold"] = float(detection_threshold)
 
+    run_config = {
+        **params,
+        **args,
+        "max_duration": max_duration,
+    }
+
     # process the file to generate predictions
     results = du.process_file(
-        audio_file, model, params, args, max_duration=max_duration
+        audio_file,
+        model,
+        run_config,
     )
 
     anns = [ann for ann in results["pred_dict"]["annotation"]]
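
For context, a minimal sketch of how app.py drives the detector after this change. The file path, threshold, and duration values are illustrative, and `params`/`args` stand in for the dictionaries the app builds elsewhere:

import bat_detect.utils.detector_utils as du

model, params = du.load_model()        # default model checkpoint
args = {"detection_threshold": 0.3}    # user-adjustable overrides
max_duration = 2.0                     # illustrative clip limit, in seconds

# later keys win on conflict, so the user args override model defaults
run_config = {
    **params,
    **args,
    "max_duration": max_duration,
}

results = du.process_file("example.wav", model, run_config)  # hypothetical file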
@@ -1,3 +1,9 @@
+"""Main script for running BatDetect2 on audio files.
+
+Example usage:
+    python command.py /path/to/audio/ /path/to/ann/ 0.1
+
+"""
 import argparse
 import os
 

@@ -7,7 +13,6 @@ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 def parse_args():
-
     info_str = (
         "\nBatDetect2 - Detection and Classification\n"
         + " Assumes audio files are mono, not stereo.\n"

@@ -88,14 +93,22 @@ def main():
 
     print("\nInput directory: " + args["audio_dir"])
     files = du.get_audio_files(args["audio_dir"])
-    print("Number of audio files: {}".format(len(files)))
+    print(f"Number of audio files: {len(files)}")
+
     print("\nSaving results to: " + args["ann_dir"])
+
+    # set up run config
+    run_config = {
+        **args,
+        **params,
+    }
+
     # process files
     error_files = []
-    for ii, audio_file in enumerate(files):
+    for audio_file in files:
         try:
-            results = du.process_file(audio_file, model, params, args)
+            results = du.process_file(audio_file, model, run_config)
 
             if args["save_preds_if_empty"] or (
                 len(results["pred_dict"]["annotation"]) > 0
             ):

@@ -103,9 +116,9 @@ def main():
                     args["audio_dir"], args["ann_dir"]
                 )
                 du.save_results_to_file(results, results_path)
-        except:
+        except (RuntimeError, ValueError, LookupError) as err:
            error_files.append(audio_file)
-            print("Error processing file!")
+            print(f"Error processing file!: {err}")
 
    print("\nResults saved to: " + args["ann_dir"])
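
Note that this script merges in the opposite order to app.py (`{**args, **params}` rather than `{**params, **args}`). With dict unpacking the right-most mapping wins, as this small self-contained sketch shows:

defaults = {"detection_threshold": 0.3, "quiet": True}
cli_args = {"detection_threshold": 0.5}

merged = {**defaults, **cli_args}
assert merged["detection_threshold"] == 0.5  # right-most mapping wins
assert merged["quiet"] is True               # untouched keys pass through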
@@ -1,7 +1,11 @@
 import numpy as np
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
+from torch import nn
+
+try:
+    from typing import TypedDict
+except ImportError:
+    from typing_extensions import TypedDict
 
 np.seterr(divide="ignore", invalid="ignore")
 

@@ -18,7 +22,33 @@ def overall_class_pred(det_prob, class_prob):
     return weighted_pred / weighted_pred.sum()
 
 
-def run_nms(outputs, params, sampling_rate):
+class NonMaximumSuppressionConfig(TypedDict):
+    """Configuration for non-maximum suppression."""
+
+    nms_kernel_size: int
+    """Size of the kernel for non-maximum suppression."""
+
+    max_freq: float
+    """Maximum frequency to consider in Hz."""
+
+    min_freq: float
+    """Minimum frequency to consider in Hz."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Overlap of the FFT windows in seconds."""
+
+    nms_top_k_per_sec: float
+    """Number of top detections to keep per second."""
+
+    detection_threshold: float
+    """Threshold for detection probability."""
+
+
+def run_nms(outputs, params: NonMaximumSuppressionConfig, sampling_rate: int):
+    """Run non-maximum suppression on the output of the model."""
 
     pred_det = outputs["pred_det"]  # probability of box
     pred_size = outputs["pred_size"]  # box size

@@ -92,6 +122,7 @@ def run_nms(outputs, params, sampling_rate):
         # convert to numpy
         for kk in pred.keys():
             pred[kk] = pred[kk].cpu().numpy().astype(np.float32)
+
         preds.append(pred)
 
     return preds, feats
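
`NonMaximumSuppressionConfig` is a `TypedDict`, so at runtime it is an ordinary dict and only static type checkers see the schema. A hypothetical call might look like this (all numeric values are illustrative, not the project's defaults, and `outputs` would come from the detection model's forward pass):

import numpy as np

nms_config: NonMaximumSuppressionConfig = {
    "nms_kernel_size": 9,
    "max_freq": 120000,
    "min_freq": 10000,
    "fft_win_length": 0.02,      # seconds
    "fft_overlap": 0.75,
    "nms_top_k_per_sec": 200,
    "detection_threshold": 0.3,
}

# outputs: dict of model tensors ("pred_det", "pred_size", ...)
preds, feats = run_nms(outputs, nms_config, np.array([256000.0]))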
@@ -6,14 +6,12 @@ import argparse
 import copy
 import json
 import os
 import sys
 
 import numpy as np
 import pandas as pd
 from sklearn.ensemble import RandomForestClassifier
 
 sys.path.append("../../")
-import bat_detect.detector.parameters as parameters
+from bat_detect.detector import parameters
 import bat_detect.train.evaluate as evl
 import bat_detect.train.train_utils as tu
 import bat_detect.utils.detector_utils as du

@@ -749,14 +747,18 @@ if __name__ == "__main__":
         print("Warning: Class names are not the same as the trained model")
         assert False
 
+    run_config = {
+        **bd_args,
+        **params_bd,
+        "return_raw_preds": True,
+    }
+
     preds_bd = []
     for ii, gg in enumerate(gt_test):
         pred = du.process_file(
             gg["file_path"],
             model,
-            params_bd,
-            bd_args,
-            return_raw_preds=True,
+            run_config,
         )
         preds_bd.append(pred)
@@ -1,4 +1,5 @@
 import warnings
+from typing import Optional, Tuple
 
 import librosa
 import numpy as np

@@ -7,6 +8,11 @@ import torch
 from . import wavfile
 
 
+__all__ = [
+    "load_audio_file",
+]
+
+
 def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
     nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
     noverlap = np.floor(fft_overlap * nfft)

@@ -105,24 +111,49 @@ def generate_spectrogram(
 
 
 def load_audio_file(
-    audio_file,
-    time_exp_fact,
-    target_samp_rate,
-    scale=False,
-    max_duration=False,
+    audio_file: str,
+    time_exp_fact: float,
+    target_samp_rate: int,
+    scale: bool = False,
+    max_duration: Optional[float] = None,
 ):
+    """Load an audio file and resample it to the target sampling rate.
+
+    The audio is also scaled to [-1, 1] and clipped to the maximum duration.
+    Only mono files are supported.
+
+    Args:
+        audio_file (str): Path to the audio file.
+        target_samp_rate (int): Target sampling rate.
+        scale (bool): Whether to scale the audio to [-1, 1].
+        max_duration (float): Maximum duration of the audio in seconds.
+
+    Returns:
+        sampling_rate: The sampling rate of the audio.
+        audio_raw: The audio signal in a numpy array.
+
+    Raises:
+        ValueError: If the audio file is stereo.
+
+    """
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
         # sampling_rate, audio_raw = wavfile.read(audio_file)
-        audio_raw, sampling_rate = librosa.load(audio_file, sr=None)
+        audio_raw, sampling_rate = librosa.load(
+            audio_file,
+            sr=None,
+            dtype=np.float32,
+        )
 
     if len(audio_raw.shape) > 1:
-        raise Exception("Currently does not handle stereo files")
+        raise ValueError("Currently does not handle stereo files")
 
     sampling_rate = sampling_rate * time_exp_fact
 
     # resample - need to do this after correcting for time expansion
     sampling_rate_old = sampling_rate
     sampling_rate = target_samp_rate
     if sampling_rate_old != sampling_rate:
         audio_raw = librosa.resample(
             audio_raw,
             orig_sr=sampling_rate_old,

@@ -131,14 +162,14 @@ def load_audio_file(
         )
 
     # clipping maximum duration
-    if max_duration is not False:
+    if max_duration is not None:
         max_duration = np.minimum(
-            int(sampling_rate * max_duration), audio_raw.shape[0]
+            int(sampling_rate * max_duration),
+            audio_raw.shape[0],
         )
         audio_raw = audio_raw[:max_duration]
 
-    # convert to float32 and scale
-    audio_raw = audio_raw.astype(np.float32)
+    # scale to [-1, 1]
     if scale:
         audio_raw = audio_raw - audio_raw.mean()
         audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
 
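
With the new signature, `max_duration=None` (rather than `False`) means "do not clip". A hypothetical call, assuming a mono file at `recording.wav`:

sampling_rate, audio = load_audio_file(
    "recording.wav",          # hypothetical path
    time_exp_fact=1.0,        # recording was not time-expanded
    target_samp_rate=256000,  # illustrative target rate
    scale=True,               # normalise to [-1, 1]
    max_duration=None,        # keep the full recording
)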
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import List, Tuple
+from typing import Any, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd

@@ -24,7 +24,7 @@ DEFAULT_MODEL_PATH = os.path.join(
     "model.pth",
 )
 
-__all__ = ["load_model", "DEFAULT_MODEL_PATH"]
+__all__ = ["load_model", "get_audio_files", "DEFAULT_MODEL_PATH"]
 
 
 def get_default_bd_args():

@@ -69,16 +69,29 @@ class ModelParameters(TypedDict):
     """Model parameters."""
 
     model_name: str
+    """Model name."""
+
     num_filters: int
+    """Number of filters."""
+
     emb_dim: int
+    """Embedding dimension."""
+
     ip_height: int
+    """Input height in pixels."""
+
     resize_factor: int
+    """Resize factor."""
+
     class_names: List[str]
+    """Class names. The model is trained to detect these classes."""
+
     device: torch.device
 
 
 def load_model(
-    model_path: str = DEFAULT_MODEL_PATH, load_weights: bool = True
+    model_path: str = DEFAULT_MODEL_PATH,
+    load_weights: bool = True,
 ) -> Tuple[torch.nn.Module, ModelParameters]:
     """Load model from file.
 

@@ -141,8 +154,7 @@ def load_model(
     return model, params
 
 
-def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
-
+def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
     predictions_m = {}
     num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
 

@@ -157,82 +169,255 @@ def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
 
     if len(spec_feats) > 0:
         spec_feats = np.vstack(spec_feats)
 
     if len(cnn_feats) > 0:
         cnn_feats = np.vstack(cnn_feats)
 
     return predictions_m, spec_feats, cnn_feats, spec_slices
 
 
+class Annotation(TypedDict("WithClass", {"class": str})):
+    """Format of annotations.
+
+    This is the format of a single annotation as expected by the annotation
+    tool.
+    """
+
+    start_time: float
+    """Start time in seconds."""
+
+    end_time: float
+    """End time in seconds."""
+
+    low_freq: int
+    """Low frequency in Hz."""
+
+    high_freq: int
+    """High frequency in Hz."""
+
+    class_prob: float
+    """Probability of class assignment."""
+
+    det_prob: float
+    """Probability of detection."""
+
+    individual: str
+    """Individual ID."""
+
+    event: str
+    """Type of detected event."""
+
+
+class FileAnnotations(TypedDict):
+    """Format of results.
+
+    This is the format of the results expected by the annotation tool.
+    """
+
+    file_id: str
+    """File ID."""
+
+    annotated: bool
+    """Whether file has been annotated."""
+
+    duration: float
+    """Duration of audio file."""
+
+    issues: bool
+    """Whether file has issues."""
+
+    time_exp: float
+    """Time expansion factor."""
+
+    class_name: str
+    """Class predicted at file level"""
+
+    annotation: List[Annotation]
+
+
+class Results(TypedDict):
+    pred_dict: FileAnnotations
+    """Predictions in the format expected by the annotation tool."""
+
+    spec_feats: Optional[np.ndarray]
+    """Spectrogram features."""
+
+    spec_feat_names: Optional[List[str]]
+    """Spectrogram feature names."""
+
+    cnn_feats: Optional[np.ndarray]
+    """CNN features."""
+
+    cnn_feat_names: Optional[List[str]]
+    """CNN feature names."""
+
+    spec_slices: Optional[np.ndarray]
+    """Spectrogram slices."""
+
+
+class ResultParams(TypedDict):
+    """Result parameters."""
+
+    class_names: List[str]
+    """Class names."""
+
+
+def format_results(
+    file_id: str,
+    time_exp: float,
+    duration: float,
+    predictions,
+    class_names: List[str],
+) -> FileAnnotations:
+    """Format results into the format expected by the annotation tool.
+
+    Args:
+        file_id (str): File ID.
+        time_exp (float): Time expansion factor.
+        duration (float): Duration of audio file.
+        predictions (dict): Predictions.
+
+    Returns:
+        dict: Results in the format expected by the annotation tool.
+    """
+    # Get a single class prediction for the file
+    class_overall = pp.overall_class_pred(
+        predictions["det_probs"],
+        predictions["class_probs"],
+    )
+
+    # Get the best class prediction probability and index for each detection
+    class_prob_best = predictions["class_probs"].max(0)
+    class_ind_best = predictions["class_probs"].argmax(0)
+
+    # Pack the results into a list of dictionaries
+    annotations: List[Annotation] = [
+        {
+            "start_time": round(float(start_time), 4),
+            "end_time": round(end_time, 4),
+            "low_freq": int(low_freq),
+            "high_freq": int(high_freq),
+            "class": str(class_names[class_index]),
+            "class_prob": round(float(class_prob), 3),
+            "det_prob": round(float(det_prob), 3),
+            "individual": "-1",
+            "event": "Echolocation",
+        }
+        for (
+            start_time,
+            end_time,
+            low_freq,
+            high_freq,
+            class_index,
+            class_prob,
+            det_prob,
+        ) in zip(
+            predictions["start_times"],
+            predictions["end_times"],
+            predictions["low_freqs"],
+            predictions["high_freqs"],
+            class_ind_best,
+            class_prob_best,
+            predictions["det_probs"],
+        )
+    ]
+
+    return {
+        "id": file_id,
+        "annotated": False,
+        "issues": False,
+        "notes": "Automatically generated.",
+        "time_exp": time_exp,
+        "duration": round(duration, 4),
+        "annotation": annotations,
+        "class_name": class_names[np.argmax(class_overall)],
+    }
+
+
 def convert_results(
-    file_id,
-    time_exp,
-    duration,
-    params,
+    file_id: str,
+    time_exp: float,
+    duration: float,
+    params: ResultParams,
     predictions,
     spec_feats,
     cnn_feats,
     spec_slices,
-):
+) -> Results:
+    """Convert results to dictionary as expected by the annotation tool.
 
-    # create a single dictionary - this is the format used by the annotation tool
-    pred_dict = {}
-    pred_dict["id"] = file_id
-    pred_dict["annotated"] = False
-    pred_dict["issues"] = False
-    pred_dict["notes"] = "Automatically generated."
-    pred_dict["time_exp"] = time_exp
-    pred_dict["duration"] = round(duration, 4)
-    pred_dict["annotation"] = []
+    Args:
+        file_id (str): File ID.
+        time_exp (float): Time expansion factor.
+        duration (float): Duration of audio file.
+        params (dict): Model parameters.
+        predictions (dict): Predictions.
+        spec_feats (np.ndarray): Spectral features.
+        cnn_feats (np.ndarray): CNN features.
+        spec_slices (list): Spectrogram slices.
 
-    class_prob_best = predictions["class_probs"].max(0)
-    class_ind_best = predictions["class_probs"].argmax(0)
-    class_overall = pp.overall_class_pred(
-        predictions["det_probs"], predictions["class_probs"]
+    Returns:
+        dict: Dictionary with results.
+
+    """
+    pred_dict = format_results(
+        file_id,
+        time_exp,
+        duration,
+        predictions,
+        params["class_names"],
     )
-    pred_dict["class_name"] = params["class_names"][np.argmax(class_overall)]
 
-    for ii in range(predictions["det_probs"].shape[0]):
-        res = {}
-        res["start_time"] = round(float(predictions["start_times"][ii]), 4)
-        res["end_time"] = round(float(predictions["end_times"][ii]), 4)
-        res["low_freq"] = int(predictions["low_freqs"][ii])
-        res["high_freq"] = int(predictions["high_freqs"][ii])
-        res["class"] = str(params["class_names"][int(class_ind_best[ii])])
-        res["class_prob"] = round(float(class_prob_best[ii]), 3)
-        res["det_prob"] = round(float(predictions["det_probs"][ii]), 3)
-        res["individual"] = "-1"
-        res["event"] = "Echolocation"
-        pred_dict["annotation"].append(res)
 
     # combine into final results dictionary
     results = {}
     results["pred_dict"] = pred_dict
 
     # add spectrogram features if they exist
     if len(spec_feats) > 0:
         results["spec_feats"] = spec_feats
         results["spec_feat_names"] = feats.get_feature_names()
 
     # add CNN features if they exist
     if len(cnn_feats) > 0:
         results["cnn_feats"] = cnn_feats
         results["cnn_feat_names"] = [
             str(ii) for ii in range(cnn_feats.shape[1])
         ]
 
     # add spectrogram slices if they exist
     if len(spec_slices) > 0:
         results["spec_slices"] = spec_slices
 
     return results
 
 
-def save_results_to_file(results, op_path):
+def save_results_to_file(results, op_path: str) -> None:
+    """Save results to file.
+
+    Args:
+        results (dict): Results.
+        op_path (str): Output path.
+
+    """
 
     # make directory if it does not exist
     if not os.path.isdir(os.path.dirname(op_path)):
         os.makedirs(os.path.dirname(op_path))
 
-    # save csv file - if there are predictions
-    result_list = [res for res in results["pred_dict"]["annotation"]]
-    df = pd.DataFrame(result_list)
-    df["file_name"] = [results["pred_dict"]["id"]] * len(result_list)
-    df.index.name = "id"
-    if "class_prob" in df.columns:
-        df = df[
+    result_list = results["pred_dict"]["annotation"]
+
+    results_df = pd.DataFrame(result_list)
+
+    # add file name as a column
+    results_df["file_name"] = results["pred_dict"]["id"]
+
+    # rename index column
+    results_df.index.name = "id"
+
+    # create a csv file with predicted events
+    if "class_prob" in results_df.columns:
+        preds_df = results_df[
             [
                 "det_prob",
                 "start_time",
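
To make the target format concrete, this is roughly what a `FileAnnotations` dictionary produced by `format_results` looks like (all values invented for illustration):

{
    "id": "20200101_000000.wav",
    "annotated": False,
    "issues": False,
    "notes": "Automatically generated.",
    "time_exp": 1.0,
    "duration": 3.0,
    "class_name": "Pipistrellus pipistrellus",
    "annotation": [
        {
            "start_time": 0.5121,
            "end_time": 0.5347,
            "low_freq": 45000,
            "high_freq": 67000,
            "class": "Pipistrellus pipistrellus",
            "class_prob": 0.912,
            "det_prob": 0.956,
            "individual": "-1",
            "event": "Echolocation",
        },
    ],
}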
@@ -243,14 +428,14 @@ def save_results_to_file(results, op_path):
                 "class_prob",
             ]
         ]
-    df.to_csv(op_path + ".csv", sep=",")
+    preds_df.to_csv(op_path + ".csv", sep=",")
 
     # save features
     if "spec_feats" in results.keys():
-        df = pd.DataFrame(
+        # create csv file with spectrogram features
+        spec_feats_df = pd.DataFrame(
             results["spec_feats"], columns=results["spec_feat_names"]
         )
-        df.to_csv(
+        spec_feats_df.to_csv(
             op_path + "_spec_features.csv",
             sep=",",
             index=False,

@@ -258,10 +443,12 @@ def save_results_to_file(results, op_path):
         )
 
     if "cnn_feats" in results.keys():
-        df = pd.DataFrame(
-            results["cnn_feats"], columns=results["cnn_feat_names"]
+        # create csv file with cnn extracted features
+        cnn_feats_df = pd.DataFrame(
+            results["cnn_feats"],
+            columns=results["cnn_feat_names"],
         )
-        df.to_csv(
+        cnn_feats_df.to_csv(
             op_path + "_cnn_features.csv",
             sep=",",
             index=False,

@@ -269,11 +456,71 @@ def save_results_to_file(results, op_path):
         )
 
     # save json file
-    with open(op_path + ".json", "w") as da:
-        json.dump(results["pred_dict"], da, indent=2, sort_keys=True)
+    with open(op_path + ".json", "w", encoding="utf-8") as jsonfile:
+        json.dump(results["pred_dict"], jsonfile, indent=2, sort_keys=True)
 
 
-def compute_spectrogram(audio, sampling_rate, params, return_np=False):
+class SpectrogramParameters(TypedDict):
+    """Parameters for generating spectrograms."""
+
+    fft_win_length: int
+    """Length of the FFT window in samples."""
+
+    fft_overlap: int
+    """Number of samples to overlap between FFT windows."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    spec_width: int
+    """Width of the spectrogram in pixels."""
+
+    resize_factor: int
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: int
+    """Factor to divide the spectrogram by."""
+
+    device: torch.device
+    """Device to store the spectrogram on."""
+
+
+def compute_spectrogram(
+    audio: np.ndarray,
+    sampling_rate: int,
+    params: SpectrogramParameters,
+    return_np: bool = False,
+) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
+    """Compute a spectrogram from an audio array.
+
+    Will pad the audio array so that it is evenly divisible by the
+    downsampling factors.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+
+    sampling_rate : int
+
+    params : SpectrogramParameters
+        The parameters to use for generating the spectrogram.
+
+    return_np : bool, optional
+        Whether to return the spectrogram as a numpy array as well as a
+        torch tensor. The default is False.
+
+    Returns
+    -------
+    duration : float
+        The duration of the spectrgram in seconds.
+
+    spec : torch.Tensor
+        The spectrogram as a torch tensor.
+
+    spec_np : np.ndarray, optional
+        The spectrogram as a numpy array. Only returned if `return_np` is
+        True, otherwise None.
+    """
     # pad audio so it is evenly divisible by downsampling factors
     duration = audio.shape[0] / float(sampling_rate)
     audio = au.pad_audio(

@@ -290,13 +537,21 @@ def compute_spectrogram(audio, sampling_rate, params, return_np=False):
 
     # convert to pytorch
     spec = torch.from_numpy(spec).to(params["device"])
 
     # add batch and channel dimensions
     spec = spec.unsqueeze(0).unsqueeze(0)
 
     # resize the spec
-    rs = params["resize_factor"]
-    spec_op_shape = (int(params["spec_height"] * rs), int(spec.shape[-1] * rs))
+    resize_factor = params["resize_factor"]
+    spec_op_shape = (
+        int(params["spec_height"] * resize_factor),
+        int(spec.shape[-1] * resize_factor),
+    )
     spec = F.interpolate(
-        spec, size=spec_op_shape, mode="bilinear", align_corners=False
+        spec,
+        size=spec_op_shape,
+        mode="bilinear",
+        align_corners=False,
     )
 
     if return_np:
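
A sketch of how `compute_spectrogram` might be called with an explicit `SpectrogramParameters` dict. All field values here are illustrative; the real defaults live in `bat_detect.detector.parameters` and may include more keys than the TypedDict lists:

import numpy as np
import torch

spec_params: SpectrogramParameters = {
    "fft_win_length": 512,        # illustrative, in samples per the docstring
    "fft_overlap": 384,
    "spec_height": 256,
    "spec_width": 256,
    "resize_factor": 1,
    "spec_divide_factor": 32,
    "device": torch.device("cpu"),
}

audio = np.zeros(256000, dtype=np.float32)  # one second of silence
duration, spec, spec_np = compute_spectrogram(
    audio, 256000, spec_params, return_np=True
)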
@@ -307,135 +562,318 @@ def compute_spectrogram(audio, sampling_rate, params, return_np=False):
     return duration, spec, spec_np
 
 
-def process_file(
-    audio_file,
-    model,
-    params,
-    args,
-    time_exp=None,
-    top_n=5,
-    return_raw_preds=False,
-    max_duration=False,
-):
+def iterate_over_chunks(
+    audio: np.ndarray,
+    samplerate: int,
+    chunk_size: float,
+) -> Iterator[Tuple[float, np.ndarray]]:
+    """Iterate over audio in chunks of size chunk_size.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+
+    samplerate : int
+
+    chunk_size : float
+        Size of chunks in seconds.
+
+    Yields
+    ------
+    chunk_start : float
+        Start time of chunk in seconds.
+    chunk : np.ndarray
+
+    """
+    nsamples = audio.shape[0]
+    duration_full = nsamples / samplerate
+    num_chunks = int(np.ceil(duration_full / chunk_size))
+    for chunk_id in range(num_chunks):
+        chunk_start = chunk_size * chunk_id
+        chunk_length = int(samplerate * chunk_size)
+        start_sample = chunk_id * chunk_length
+        end_sample = np.minimum((chunk_id + 1) * chunk_length, nsamples)
+        yield chunk_start, audio[start_sample:end_sample]
+
+
+class ProcessingConfiguration(TypedDict):
+    """Parameters for processing audio files."""
+
+    # audio parameters
+    target_samp_rate: int
+    """Target sampling rate of the audio."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Length of the FFT window in samples."""
+
+    resize_factor: float
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: float
+    """Factor to divide the spectrogram by."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    scale_raw_audio: bool
+    """Whether to scale the raw audio to be between -1 and 1."""
+
+    device: torch.device
+    """Device to run the model on."""
+
+    class_names: List[str]
+    """Names of the classes the model can detect."""
+
+    detection_threshold: float
+    """Threshold for detection probability."""
+
+    time_expansion: Optional[float]
+    """Time expansion factor of the processed recordings."""
+
+    top_n: int
+    """Number of top detections to keep."""
+
+    return_raw_preds: bool
+    """Whether to return raw predictions."""
+
+    max_duration: Optional[float]
+    """Maximum duration of audio file to process in seconds."""
+
+    nms_kernel_size: int
+    """Size of the kernel for non-maximum suppression."""
+
+    max_freq: float
+    """Maximum frequency to consider in Hz."""
+
+    min_freq: float
+    """Minimum frequency to consider in Hz."""
+
+    nms_top_k_per_sec: float
+    """Number of top detections to keep per second."""
+
+    detection_threshold: float
+    """Threshold for detection probability."""
+
+    quiet: bool
+    """Whether to suppress output."""
+
+
+def process_spectrogram(
+    spec: torch.Tensor,
+    samplerate: int,
+    model: torch.nn.Module,
+    config: pp.NonMaximumSuppressionConfig,
+):
+    """Process a spectrogram with detection model.
+
+    Will run non-maximum suppression on the output of the model.
+
+    Parameters
+    ----------
+    spec : torch.Tensor
+
+    samplerate : int
+
+    model : torch.nn.Module
+        Detection model.
+
+    config : pp.NonMaximumSuppressionConfig
+        Parameters for non-maximum suppression.
+
+    Returns
+    -------
+    pred_nms : Dict[str, np.ndarray]
+    features : Dict[str, np.ndarray]
+    """
+    # evaluate model
+    with torch.no_grad():
+        outputs = model(spec, return_feats=config["cnn_features"])
+
+    # run non-max suppression
+    pred_nms, features = pp.run_nms(
+        outputs,
+        config,
+        np.array([float(samplerate)]),
+    )
+
+    pred_nms = pred_nms[0]
+
+    # if we have a background class
+    if pred_nms["class_probs"].shape[0] > len(config["class_names"]):
+        pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]
+
+    return pred_nms, features
+
+
+def process_audio_array(
+    audio: np.ndarray,
+    sampling_rate: int,
+    model: torch.nn.Module,
+    config: ProcessingConfiguration,
+):
+    """Process a single audio array with detection model.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+
+    sampling_rate : int
+
+    model : torch.nn.Module
+        Detection model.
+
+    config : ProcessingConfiguration
+        Configuration for processing.
+
+    Returns
+    -------
+    pred_nms : Dict[str, np.ndarray]
+    features : Dict[str, np.ndarray]
+    spec_np : np.ndarray
+    """
+    # load audio file and compute spectrogram
+    _, spec, spec_np = compute_spectrogram(
+        audio,
+        sampling_rate,
+        config,
+        return_np=config["spec_features"] or config["spec_slices"],
+    )
+
+    # process spectrogram with model
+    pred_nms, features = process_spectrogram(
+        spec,
+        sampling_rate,
+        model,
+        config,
+    )
+
+    return pred_nms, features, spec_np
+
+
+def process_file(
+    audio_file: str,
+    model: torch.nn.Module,
+    config: ProcessingConfiguration,
+) -> Union[Results, Any]:
+    """Process a single audio file with detection model.
+
+    Will split the audio file into chunks if it is too long and
+    process each chunk separately.
+
+    Parameters
+    ----------
+    audio_file : str
+        Path to audio file.
+
+    model : torch.nn.Module
+        Detection model.
+
+    config : ProcessingConfiguration
+        Configuration for processing.
+
+    Returns
+    -------
+    results : Results or Any
+        Results of processing audio file with the given detection model.
+        Will be a dictionary if `config["return_raw_preds"]` is `True`,
+    """
     # store temporary results here
     predictions = []
     spec_feats = []
     cnn_feats = []
     spec_slices = []
 
-    # get time expansion factor
-    if time_exp is None:
-        time_exp = args["time_expansion_factor"]
-
-    params["detection_threshold"] = args["detection_threshold"]
-
     # load audio file
     sampling_rate, audio_full = au.load_audio_file(
         audio_file,
-        time_exp,
-        params["target_samp_rate"],
-        params["scale_raw_audio"],
+        time_exp_fact=config["time_expansion"],
+        target_samp_rate=config["target_samp_rate"],
+        scale=config["scale_raw_audio"],
+        max_duration=config["max_duration"],
     )
 
-    # clipping maximum duration
-    if max_duration is not False:
-        max_duration = np.minimum(
-            int(sampling_rate * max_duration), audio_full.shape[0]
-        )
-        audio_full = audio_full[:max_duration]
-
-    duration_full = audio_full.shape[0] / float(sampling_rate)
-
-    return_np_spec = args["spec_features"] or args["spec_slices"]
-
     # loop through larger file and split into chunks
-    # TODO fix so that it overlaps correctly and takes care of duplicate detections at borders
-    num_chunks = int(np.ceil(duration_full / args["chunk_size"]))
-    for chunk_id in range(num_chunks):
-
-        # chunk
-        chunk_time = args["chunk_size"] * chunk_id
-        chunk_length = int(sampling_rate * args["chunk_size"])
-        start_sample = chunk_id * chunk_length
-        end_sample = np.minimum(
-            (chunk_id + 1) * chunk_length, audio_full.shape[0]
-        )
-        audio = audio_full[start_sample:end_sample]
-
-        # load audio file and compute spectrogram
-        duration, spec, spec_np = compute_spectrogram(
-            audio, sampling_rate, params, return_np_spec
+    # TODO fix so that it overlaps correctly and takes care of
+    # duplicate detections at borders
+    for chunk_time, audio in iterate_over_chunks(
+        audio_full,
+        sampling_rate,
+        config["chunk_size"],
+    ):
+        # Run detection model on chunk
+        pred_nms, features, spec_np = process_audio_array(
+            audio,
+            sampling_rate,
+            model,
+            config,
         )
 
-        # evaluate model
-        with torch.no_grad():
-            outputs = model(spec, return_feats=args["cnn_features"])
-
-        # run non-max suppression
-        pred_nms, features = pp.run_nms(
-            outputs, params, np.array([float(sampling_rate)])
-        )
-        pred_nms = pred_nms[0]
+        # add chunk time to start and end times
         pred_nms["start_times"] += chunk_time
         pred_nms["end_times"] += chunk_time
 
-        # if we have a background class
-        if pred_nms["class_probs"].shape[0] > len(params["class_names"]):
-            pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]
-
         predictions.append(pred_nms)
 
         # extract features - if there are any calls detected
         if pred_nms["det_probs"].shape[0] > 0:
-            if args["spec_features"]:
-                spec_feats.append(feats.get_feats(spec_np, pred_nms, params))
+            if config["spec_features"]:
+                spec_feats.append(feats.get_feats(spec_np, pred_nms, config))
 
-            if args["cnn_features"]:
+            if config["cnn_features"]:
                 cnn_feats.append(features[0])
 
-            if args["spec_slices"]:
+            if config["spec_slices"]:
                 spec_slices.extend(
-                    feats.extract_spec_slices(spec_np, pred_nms, params)
+                    feats.extract_spec_slices(spec_np, pred_nms, config)
                 )
 
-    # convert the predictions into output dictionary
-    file_id = os.path.basename(audio_file)
-    predictions, spec_feats, cnn_feats, spec_slices = merge_results(
-        predictions, spec_feats, cnn_feats, spec_slices
-    )
-    results = convert_results(
-        file_id,
-        time_exp,
-        duration_full,
-        params,
+    # Merge results from chunks
+    predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
+        predictions,
+        spec_feats,
+        cnn_feats,
+        spec_slices,
+    )
+
+    # convert results to a dictionary in the right format
+    results = convert_results(
+        file_id=os.path.basename(audio_file),
+        time_exp=config["time_expansion"],
+        duration=audio_full.shape[0] / float(sampling_rate),
+        params=config,
         predictions=predictions,
         spec_feats=spec_feats,
         cnn_feats=cnn_feats,
         spec_slices=spec_slices,
     )
 
     # summarize results
-    if not args["quiet"]:
-        num_detections = len(results["pred_dict"]["annotation"])
-        print(
-            "{}".format(num_detections)
-            + " call(s) detected above the threshold."
-        )
+    if not config["quiet"]:
+        summarize_results(results, predictions, config)
+
+    if config["return_raw_preds"]:
+        return predictions
+
+    return results
+
+
+def summarize_results(results, predictions, config):
+    """Print summary of results."""
+    num_detections = len(results["pred_dict"]["annotation"])
+    print(f"{num_detections} call(s) detected above the threshold.")
 
     # print results for top n classes
-    if not args["quiet"] and (num_detections > 0):
+    if num_detections > 0:
         class_overall = pp.overall_class_pred(
-            predictions["det_probs"], predictions["class_probs"]
+            predictions["det_probs"],
+            predictions["class_probs"],
         )
         print("species name".ljust(30) + "probablity present")
-        for cc in np.argsort(class_overall)[::-1][:top_n]:
-            print(
-                params["class_names"][cc].ljust(30)
-                + str(round(class_overall[cc], 3))
-            )
-
-    if return_raw_preds:
-        return predictions
-    else:
-        return results
+        for class_index in np.argsort(class_overall)[::-1][: config["top_n"]]:
+            print(
+                config["class_names"][class_index].ljust(30)
+                + str(round(class_overall[class_index], 3))
+            )
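
`iterate_over_chunks` is self-contained, so its behaviour is easy to check in isolation. A small sketch splitting ten seconds of noise into three-second chunks (the final chunk comes out shorter):

import numpy as np

samplerate = 256000
audio = np.random.randn(10 * samplerate).astype(np.float32)

for chunk_start, chunk in iterate_over_chunks(audio, samplerate, 3.0):
    print(f"chunk at {chunk_start:.1f}s: {chunk.shape[0]} samples")
# -> chunks start at 0.0s, 3.0s, 6.0s, 9.0s; the last holds 1 s of audio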
@@ -5,7 +5,6 @@ import bat_detect.utils.detector_utils as du
 
 
 def main(args):
-
     print("Loading model: " + args["model_path"])
     model, params = du.load_model(args["model_path"])
 
@@ -136,8 +136,13 @@ if __name__ == "__main__":
         audio, sampling_rate, params_bd, True, False
     )
 
+    run_config = {
+        **params_bd,
+        **bd_args,
+    }
+
     # run model and filter detections so only keep ones in relevant time range
-    results = du.process_file(args_cmd["audio_file"], model, params_bd, bd_args)
+    results = du.process_file(args_cmd["audio_file"], model, run_config)
     pred_anns = filter_anns(
         results["pred_dict"]["annotation"],
         args_cmd["start_time"],
@@ -122,7 +122,12 @@ if __name__ == "__main__":
     print(" Loading model and running detector on entire file ...")
     model, det_params = du.load_model(args_cmd["model_path"])
     det_params["detection_threshold"] = args["detection_threshold"]
-    results = du.process_file(audio_file, model, det_params, args)
+
+    run_config = {
+        **det_params,
+        **args,
+    }
+    results = du.process_file(audio_file, model, run_config)
 
     print(" Processing detections and plotting ...")
     detections = []