Added an API file with tests to check basic functionality

This commit is contained in:
Santiago Martinez 2023-02-25 19:40:54 +00:00
parent 40222d8233
commit 0eecf54a94
15 changed files with 822 additions and 244 deletions

app.py (2 changed lines)

@@ -77,7 +77,7 @@ def make_prediction(file_name=None, detection_threshold=0.3):
def generate_results_image(audio_file, anns):
# load audio
sampling_rate, audio = au.load_audio_file(
sampling_rate, audio = au.load_audio(
audio_file,
args["time_expansion_factor"],
params["target_samp_rate"],

bat_detect/api.py (new file, 215 lines)

@@ -0,0 +1,215 @@
from typing import List, Optional, Tuple
import numpy as np
import torch
import bat_detect.detector.models as md
import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
from bat_detect.detector.parameters import TARGET_SAMPLERATE_HZ
from bat_detect.utils.detector_utils import list_audio_files, load_model
# Use GPU if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
__all__ = [
"load_model",
"load_audio",
"list_audio_files",
"generate_spectrogram",
"get_config",
"process_file",
"process_spectrogram",
"process_audio",
]
def get_config(**kwargs) -> du.ProcessingConfiguration:
"""Get default processing configuration.
Can be used to override default parameters by passing keyword arguments.
"""
return {**du.DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}
def load_audio(
path: str,
time_exp_fact: float = 1,
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
scale: bool = False,
max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray]:
"""Load audio from file.
Parameters
----------
path : str
Path to audio file.
time_exp_fact : float, optional
Time expansion factor, by default 1
target_samp_rate : int, optional
Target sample rate, by default 256000
scale : bool, optional
Scale audio to [-1, 1], by default False
max_duration : Optional[float], optional
Maximum duration of audio in seconds, by default None
Returns
-------
int
Sample rate.
np.ndarray
Audio data.
"""
return au.load_audio(
path,
time_exp_fact,
target_samp_rate,
scale,
max_duration,
)
def generate_spectrogram(
audio: np.ndarray,
samp_rate: int,
config: Optional[au.SpectrogramParameters] = None,
device: torch.device = DEVICE,
) -> torch.Tensor:
"""Generate spectrogram from audio array.
Parameters
----------
audio : np.ndarray
Audio data.
samp_rate : int
Sample rate.
config : Optional[SpectrogramParameters], optional
Spectrogram parameters, by default None (uses default parameters).
device : torch.device, optional
Device to use, by default tries to use GPU if available.
Returns
-------
torch.Tensor
Spectrogram.
"""
if config is None:
config = au.DEFAULT_SPECTROGRAM_PARAMETERS
_, spec, _ = du.compute_spectrogram(
audio,
samp_rate,
config,
return_np=False,
device=device,
)
return spec
def process_file(
audio_file: str,
model: md.DetectionModel,
config: Optional[du.ProcessingConfiguration] = None,
device: torch.device = DEVICE,
) -> du.RunResults:
"""Process audio file with model.
Parameters
----------
audio_file : str
Path to audio file.
model : DetectionModel
Detection model.
config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters).
device : torch.device, optional
Device to use, by default tries to use GPU if available.
"""
if config is None:
config = du.DEFAULT_PROCESSING_CONFIGURATIONS
return du.process_file(
audio_file,
model,
config,
device,
)
def process_spectrogram(
spec: torch.Tensor,
samp_rate: int,
model: md.DetectionModel,
config: Optional[du.ProcessingConfiguration] = None,
) -> Tuple[List[du.Annotation], List[np.ndarray]]:
"""Process spectrogram with model.
Parameters
----------
spec : torch.Tensor
Spectrogram.
samp_rate : int
Sample rate of the audio from which the spectrogram was generated.
model : DetectionModel
Detection model.
config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters).
Returns
-------
annotations : List[Annotation]
List of predicted annotations.
features : List[np.ndarray]
List of extracted features for each annotation.
"""
if config is None:
config = du.DEFAULT_PROCESSING_CONFIGURATIONS
return du.process_spectrogram(
spec,
samp_rate,
model,
config,
)
def process_audio(
audio: np.ndarray,
samp_rate: int,
model: md.DetectionModel,
config: Optional[du.ProcessingConfiguration] = None,
device: torch.device = DEVICE,
) -> Tuple[List[du.Annotation], List[np.ndarray], torch.Tensor]:
"""Process audio array with model.
Parameters
----------
audio : np.ndarray
Audio data.
samp_rate : int
Sample rate.
model : DetectionModel
Detection model.
config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters).
device : torch.device, optional
Device to use, by default tries to use GPU if available.
Returns
-------
annotations : List[Annotation]
List of predicted annotations.
features : List[np.ndarray]
List of extracted features for each annotation.
spec : torch.Tensor
Spectrogram of the audio used for prediction.
"""
if config is None:
config = du.DEFAULT_PROCESSING_CONFIGURATIONS
return du.process_audio_array(
audio,
samp_rate,
model,
config,
device,
)
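Taken together, the new bat_detect/api.py wraps the whole pipeline behind a handful of functions. A minimal usage sketch (the .wav path is illustrative, not a file shipped with this commit; load_model falls back to the bundled default weights):

# Minimal sketch of the new API; the audio path is a placeholder.
from bat_detect import api

model, params = api.load_model()
config = api.get_config(**params, detection_threshold=0.5)

# One-call wrapper over the whole pipeline:
results = api.process_file("path/to/recording.wav", model, config=config)

# Or step by step:
samplerate, audio = api.load_audio("path/to/recording.wav")
spec = api.generate_spectrogram(audio, samplerate)
annotations, features = api.process_spectrogram(
    spec,
    samplerate,
    model,
    config=config,
)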


@@ -92,7 +92,7 @@ def main():
model, params = du.load_model(args["model_path"])
print("\nInput directory: " + args["audio_dir"])
files = du.get_audio_files(args["audio_dir"])
files = du.list_audio_files(args["audio_dir"])
print(f"Number of audio files: {len(files)}")
print("\nSaving results to: " + args["ann_dir"])


@@ -1,9 +1,11 @@
from typing import NamedTuple, Optional
import torch
import torch.fft
import torch.nn.functional as F
from torch import nn
from .model_helpers import (
from bat_detect.detector.model_helpers import (
ConvBlockDownCoordF,
ConvBlockDownStandard,
ConvBlockUpF,
@@ -11,13 +13,88 @@ from .model_helpers import (
SelfAttention,
)
try:
from typing import Protocol
except ImportError:
from typing_extensions import Protocol
__all__ = [
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
"ModelOutput",
"DetectionModel",
]
class ModelOutput(NamedTuple):
"""Output of the detection model."""
pred_det: torch.Tensor
"""Tensor with predict detection probabilities."""
pred_size: torch.Tensor
"""Tensor with predicted bounding box sizes."""
pred_class: torch.Tensor
"""Tensor with predicted class probabilities."""
pred_class_un_norm: torch.Tensor
"""Tensor with predicted class probabilities before softmax."""
pred_emb: Optional[torch.Tensor]
"""Tensor with embeddings."""
features: Optional[torch.Tensor]
"""Tensor with intermediate features."""
class DetectionModel(Protocol):
"""Protocol for detection models.
This protocol is used to define the interface for the detection models.
This allows us to use the same code for training and inference, even
though the models are different.
"""
num_classes: int
"""Number of classes the model can classify."""
emb_dim: int
"""Dimension of the embedding vector."""
num_filts: int
"""Number of filters in the model."""
resize_factor: float
"""Factor by which the input is resized."""
ip_height: int
"""Height of the input image."""
def forward(
self,
ip: torch.Tensor,
return_feats: bool = False,
) -> ModelOutput:
"""Forward pass of the model.
When `return_feats` is `True`, the model should return the
intermediate features of the model.
"""
def __call__(
self,
ip: torch.Tensor,
return_feats: bool = False,
) -> ModelOutput:
"""Forward pass of the model.
When `return_feats` is `True`, the model should return the
intermediate features of the model.
"""
class Net2DFast(nn.Module):
def __init__(
self,
@@ -27,7 +104,7 @@ class Net2DFast(nn.Module):
ip_height=128,
resize_factor=0.5,
):
super(Net2DFast, self).__init__()
super().__init__()
self.num_classes = num_classes
self.emb_dim = emb_dim
self.num_filts = num_filts
@@ -102,7 +179,7 @@ class Net2DFast(nn.Module):
num_filts, self.emb_dim, kernel_size=1, padding=0
)
def forward(self, ip, return_feats=False):
def forward(self, ip, return_feats=False) -> ModelOutput:
# encoder
x1 = self.conv_dn_0(ip)
@@ -125,17 +202,14 @@ class Net2DFast(nn.Module):
cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1)
op = {}
op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
op["pred_class"] = comb
op["pred_class_un_norm"] = cls
if self.emb_dim > 0:
op["pred_emb"] = self.conv_emb(x)
if return_feats:
op["features"] = x
return op
return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_class=comb,
pred_class_un_norm=cls,
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
features=x if return_feats else None,
)
class Net2DFastNoAttn(nn.Module):
@@ -147,7 +221,7 @@ class Net2DFastNoAttn(nn.Module):
ip_height=128,
resize_factor=0.5,
):
super(Net2DFastNoAttn, self).__init__()
super().__init__()
self.num_classes = num_classes
self.emb_dim = emb_dim
@@ -219,8 +293,7 @@ class Net2DFastNoAttn(nn.Module):
num_filts, self.emb_dim, kernel_size=1, padding=0
)
def forward(self, ip, return_feats=False):
def forward(self, ip, return_feats=False) -> ModelOutput:
x1 = self.conv_dn_0(ip)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
@@ -237,17 +310,14 @@ class Net2DFastNoAttn(nn.Module):
cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1)
op = {}
op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
op["pred_class"] = comb
op["pred_class_un_norm"] = cls
if self.emb_dim > 0:
op["pred_emb"] = self.conv_emb(x)
if return_feats:
op["features"] = x
return op
return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_class=comb,
pred_class_un_norm=cls,
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
features=x if return_feats else None,
)
class Net2DFastNoCoordConv(nn.Module):
@@ -259,7 +329,7 @@ class Net2DFastNoCoordConv(nn.Module):
ip_height=128,
resize_factor=0.5,
):
super(Net2DFastNoCoordConv, self).__init__()
super().__init__()
self.num_classes = num_classes
self.emb_dim = emb_dim
@@ -333,7 +403,7 @@ class Net2DFastNoCoordConv(nn.Module):
num_filts, self.emb_dim, kernel_size=1, padding=0
)
def forward(self, ip, return_feats=False):
def forward(self, ip, return_feats=False) -> ModelOutput:
x1 = self.conv_dn_0(ip)
x2 = self.conv_dn_1(x1)
@@ -352,14 +422,11 @@ class Net2DFastNoCoordConv(nn.Module):
cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1)
op = {}
op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
op["pred_class"] = comb
op["pred_class_un_norm"] = cls
if self.emb_dim > 0:
op["pred_emb"] = self.conv_emb(x)
if return_feats:
op["features"] = x
return op
return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_class=comb,
pred_class_un_norm=cls,
pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
features=x if return_feats else None,
)
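Returning a ModelOutput NamedTuple instead of a dict gives callers named fields and positional unpacking, and the DetectionModel protocol lets any of the three Net2DFast variants pass type checks without a shared base class. A short sketch, assuming model is a loaded DetectionModel and spec is a spectrogram tensor of the shape the network expects:

import torch

with torch.no_grad():
    output = model(spec, return_feats=True)

# Access fields by name...
probs = output.pred_det  # detection probabilities
sizes = output.pred_size  # predicted bounding box sizes

# ...or unpack positionally, as run_nms now does:
pred_det, pred_size, pred_class, _, _, features = output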


@@ -5,6 +5,8 @@ import numpy as np
import torch
from torch import nn
from bat_detect.detector.models import ModelOutput
try:
from typing import TypedDict
except ImportError:
@@ -106,24 +108,8 @@ class PredictionResults(TypedDict):
"""Class probabilities."""
class ModelOutputs(TypedDict):
"""Outputs of the model."""
pred_det: torch.Tensor
"""Detection probabilities."""
pred_size: torch.Tensor
"""Box sizes."""
pred_class: Optional[torch.Tensor]
"""Class probabilities."""
features: Optional[torch.Tensor]
"""Features extracted by the model."""
def run_nms(
outputs: ModelOutputs,
outputs: ModelOutput,
params: NonMaximumSuppressionConfig,
sampling_rate: np.ndarray,
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
@@ -135,16 +121,14 @@ def run_nms(
the features. Each element of the lists corresponds to one
element of the batch.
"""
pred_det = outputs["pred_det"] # probability of box
pred_size = outputs["pred_size"] # box size
pred_det, pred_size, pred_class, _, _, features = outputs
pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
-2
]
# NOTE there will be small differences depending on which sampling rate is chosen
# NOTE: there will be small differences depending on which sampling rate is chosen
# as we are choosing the same sampling rate for the entire batch
duration = x_coords_to_time(
pred_det.shape[-1],
@@ -172,10 +156,16 @@ def run_nms(
pred["x_pos"] = x_pos[num_detection, valid_inds]
pred["y_pos"] = y_pos[num_detection, valid_inds]
pred["bb_width"] = pred_size[
num_detection, 0, pred["y_pos"], pred["x_pos"]
num_detection,
0,
pred["y_pos"],
pred["x_pos"],
]
pred["bb_height"] = pred_size[
num_detection, 1, pred["y_pos"], pred["x_pos"]
num_detection,
1,
pred["y_pos"],
pred["x_pos"],
]
pred["start_times"] = x_coords_to_time(
pred["x_pos"].float() / params["resize_factor"],
@@ -198,7 +188,6 @@ def run_nms(
)
# extract the per class votes
pred_class = outputs.get("pred_class")
if pred_class is not None:
pred["class_probs"] = pred_class[
num_detection,
@@ -208,7 +197,6 @@
]
# extract the model features
features = outputs.get("features")
if features is not None:
feat = features[
num_detection,


@@ -373,7 +373,7 @@ class AudioLoader(torch.utils.data.Dataset):
index = np.random.randint(0, len(self.data_anns))
audio_file = self.data_anns[index]["file_path"]
sampling_rate, audio_raw = au.load_audio_file(
sampling_rate, audio_raw = au.load_audio(
audio_file,
self.data_anns[index]["time_exp"],
self.params["target_samp_rate"],


@@ -5,13 +5,87 @@ import librosa
import numpy as np
import torch
from bat_detect.detector.parameters import (
DENOISE_SPEC_AVG,
DETECTION_THRESHOLD,
FFT_OVERLAP,
FFT_WIN_LENGTH_S,
MAX_FREQ_HZ,
MAX_SCALE_SPEC,
MIN_FREQ_HZ,
NMS_KERNEL_SIZE,
NMS_TOP_K_PER_SEC,
RESIZE_FACTOR,
SCALE_RAW_AUDIO,
SPEC_DIVIDE_FACTOR,
SPEC_HEIGHT,
SPEC_SCALE,
)
from . import wavfile
try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict
__all__ = [
"load_audio_file",
"load_audio",
"generate_spectrogram",
"pad_audio",
"SpectrogramParameters",
"DEFAULT_SPECTROGRAM_PARAMETERS",
]
class SpectrogramParameters(TypedDict):
"""Parameters for generating spectrograms."""
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Percentage of overlap between FFT windows."""
spec_height: int
"""Height of the spectrogram in pixels."""
resize_factor: float
"""Factor to resize the spectrogram by."""
spec_divide_factor: int
"""Factor to divide the spectrogram by."""
max_freq: int
"""Maximum frequency to display in the spectrogram."""
min_freq: int
"""Minimum frequency to display in the spectrogram."""
spec_scale: str
"""Scale to use for the spectrogram."""
denoise_spec_avg: bool
"""Whether to denoise the spectrogram by averaging."""
max_scale_spec: bool
"""Whether to scale the spectrogram so that its max is 1."""
DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
"fft_win_length": FFT_WIN_LENGTH_S,
"fft_overlap": FFT_OVERLAP,
"spec_height": SPEC_HEIGHT,
"resize_factor": RESIZE_FACTOR,
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
"max_freq": MAX_FREQ_HZ,
"min_freq": MIN_FREQ_HZ,
"spec_scale": SPEC_SCALE,
"denoise_spec_avg": DENOISE_SPEC_AVG,
"max_scale_spec": MAX_SCALE_SPEC,
}
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
noverlap = np.floor(fft_overlap * nfft)
@@ -36,7 +110,10 @@ def generate_spectrogram(
# generate spectrogram
spec = gen_mag_spectrogram(
audio, sampling_rate, params["fft_win_length"], params["fft_overlap"]
audio,
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
)
# crop to min/max freq
@@ -70,6 +147,7 @@ def generate_spectrogram(
spec = np.log1p(log_scaling * spec_cropped)
elif params["spec_scale"] == "pcen":
spec = pcen(spec_cropped, sampling_rate)
elif params["spec_scale"] == "none":
pass
@@ -109,13 +187,13 @@ def generate_spectrogram(
return spec, spec_for_viz
def load_audio_file(
def load_audio(
audio_file: str,
time_exp_fact: float,
target_samp_rate: int,
scale: bool = False,
max_duration: Optional[float] = None,
):
) -> Tuple[int, np.ndarray]:
"""Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
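With SpectrogramParameters and DEFAULT_SPECTROGRAM_PARAMETERS now part of the module's public surface, a caller can override one setting without restating the rest. A sketch under the same conventions (illustrative file path; generate_spectrogram returns the spectrogram plus a visualisation variant, per the return statement above):

import bat_detect.utils.audio_utils as au

# Keep the defaults, switch only the scaling method.
params: au.SpectrogramParameters = {
    **au.DEFAULT_SPECTROGRAM_PARAMETERS,
    "spec_scale": "log",
}

samplerate, audio = au.load_audio(
    "path/to/recording.wav",
    time_exp_fact=1,
    target_samp_rate=256000,
)
spec, _ = au.generate_spectrogram(audio, samplerate, params)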


@@ -43,19 +43,19 @@ DEFAULT_MODEL_PATH = os.path.join(
__all__ = [
"load_model",
"get_audio_files",
"get_default_config",
"format_results",
"list_audio_files",
"format_single_result",
"save_results_to_file",
"iterate_over_chunks",
"process_spectrogram",
"process_audio_array",
"process_file",
"DEFAULT_MODEL_PATH",
"DEFAULT_PROCESSING_CONFIGURATIONS",
]
def get_audio_files(ip_dir: str) -> List[str]:
def list_audio_files(ip_dir: str) -> List[str]:
"""Get all audio files in directory.
Args:
@@ -98,13 +98,12 @@ class ModelParameters(TypedDict):
class_names: List[str]
"""Class names. The model is trained to detect these classes."""
device: torch.device
def load_model(
model_path: str = DEFAULT_MODEL_PATH,
load_weights: bool = True,
) -> Tuple[torch.nn.Module, ModelParameters]:
device: Optional[torch.device] = None,
) -> Tuple[models.DetectionModel, ModelParameters]:
"""Load model from file.
Args:
@@ -120,7 +119,8 @@ def load_model(
"""
# load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not os.path.isfile(model_path):
raise FileNotFoundError("Model file not found.")
@@ -128,9 +128,8 @@
net_params = torch.load(model_path, map_location=device)
params = net_params["params"]
params["device"] = device
model: torch.nn.Module
model: models.DetectionModel
if params["model_name"] == "Net2DFast":
model = models.Net2DFast(
@@ -162,7 +161,7 @@ def load_model(
if load_weights:
model.load_state_dict(net_params["state_dict"])
model = model.to(params["device"])
model = model.to(device)
model.eval()
return model, params
@@ -285,30 +284,11 @@ class ResultParams(TypedDict):
"""Class names."""
def format_results(
file_id: str,
time_exp: float,
duration: float,
def get_annotations_from_preds(
predictions,
class_names: List[str],
) -> FileAnnotations:
"""Format results into the format expected by the annotation tool.
Args:
file_id (str): File ID.
time_exp (float): Time expansion factor.
duration (float): Duration of audio file.
predictions (dict): Predictions.
Returns:
dict: Results in the format expected by the annotation tool.
"""
# Get a single class prediction for the file
class_overall = pp.overall_class_pred(
predictions["det_probs"],
predictions["class_probs"],
)
) -> List[Annotation]:
"""Get list of annotations from predictions."""
# Get the best class prediction probability and index for each detection
class_prob_best = predictions["class_probs"].max(0)
class_ind_best = predictions["class_probs"].argmax(0)
@@ -344,6 +324,32 @@ def format_results(
predictions["det_probs"],
)
]
return annotations
def format_single_result(
file_id: str,
time_exp: float,
duration: float,
predictions,
class_names: List[str],
) -> FileAnnotations:
"""Format results into the format expected by the annotation tool.
Args:
file_id (str): File ID.
time_exp (float): Time expansion factor.
duration (float): Duration of audio file.
predictions (dict): Predictions.
Returns:
dict: Results in the format expected by the annotation tool.
"""
# Get a single class prediction for the file
class_overall = pp.overall_class_pred(
predictions["det_probs"],
predictions["class_probs"],
)
return {
"id": file_id,
@@ -352,7 +358,7 @@
"notes": "Automatically generated.",
"time_exp": time_exp,
"duration": round(float(duration), 4),
"annotation": annotations,
"annotation": get_annotations_from_preds(predictions, class_names),
"class_name": class_names[np.argmax(class_overall)],
}
@@ -383,7 +389,7 @@ def convert_results(
dict: Dictionary with results.
"""
pred_dict = format_results(
pred_dict = format_single_result(
file_id,
time_exp,
duration,
@@ -490,47 +496,11 @@ def save_results_to_file(results, op_path: str) -> None:
json.dump(results["pred_dict"], jsonfile, indent=2, sort_keys=True)
class SpectrogramParameters(TypedDict):
"""Parameters for generating spectrograms."""
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Percentage of overlap between FFT windows."""
spec_height: int
"""Height of the spectrogram in pixels."""
resize_factor: float
"""Factor to resize the spectrogram by."""
spec_divide_factor: int
"""Factor to divide the spectrogram by."""
device: torch.device
"""Device to store the spectrogram on."""
max_freq: int
"""Maximum frequency to display in the spectrogram."""
min_freq: int
"""Minimum frequency to display in the spectrogram."""
spec_scale: str
"""Scale to use for the spectrogram."""
denoise_spec_avg: bool
"""Whether to denoise the spectrogram by averaging."""
max_scale_spec: bool
"""Whether to scale the spectrogram so that its max is 1."""
def compute_spectrogram(
audio: np.ndarray,
sampling_rate: int,
params: SpectrogramParameters,
params: au.SpectrogramParameters,
device: torch.device,
return_np: bool = False,
) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
"""Compute a spectrogram from an audio array.
@@ -578,7 +548,7 @@ def compute_spectrogram(
spec, _ = au.generate_spectrogram(audio, sampling_rate, params)
# convert to pytorch
spec = torch.from_numpy(spec).to(params["device"])
spec = torch.from_numpy(spec).to(device)
# add batch and channel dimensions
spec = spec.unsqueeze(0).unsqueeze(0)
@@ -672,9 +642,6 @@ class ProcessingConfiguration(TypedDict):
scale_raw_audio: bool
"""Whether to scale the raw audio to be between -1 and 1."""
device: torch.device
"""Device to run the model on."""
class_names: List[str]
"""Names of the classes the model can detect."""
@@ -721,33 +688,12 @@ class ProcessingConfiguration(TypedDict):
"""Whether to return spectrogram slices."""
def process_spectrogram(
def _process_spectrogram(
spec: torch.Tensor,
samplerate: int,
model: torch.nn.Module,
model: models.DetectionModel,
config: ProcessingConfiguration,
):
"""Process a spectrogram with detection model.
Will run non-maximum suppression on the output of the model.
Parameters
----------
spec : torch.Tensor
samplerate : int
model : torch.nn.Module
Detection model.
config : pp.NonMaximumSuppressionConfig
Parameters for non-maximum suppression.
Returns
-------
pred_nms : Dict[str, np.ndarray]
features : Dict[str, np.ndarray]
"""
) -> Tuple[List[Annotation], List[np.ndarray]]:
# evaluate model
with torch.no_grad():
outputs = model(spec, return_feats=config["cnn_features"])
@@ -781,12 +727,96 @@ def process_spectrogram(
return pred_nms, features
def process_spectrogram(
spec: torch.Tensor,
samplerate: int,
model: models.DetectionModel,
config: ProcessingConfiguration,
) -> Tuple[List[Annotation], List[np.ndarray]]:
"""Process a spectrogram with detection model.
Will run non-maximum suppression on the output of the model.
Parameters
----------
spec : torch.Tensor
samplerate : int
model : torch.nn.Module
Detection model.
config : pp.NonMaximumSuppressionConfig
Parameters for non-maximum suppression.
Returns
-------
annotations : List[Annotation]
List of annotations predicted by the model.
features : List[np.ndarray]
List of CNN features associated with each annotation.
Is empty if `config["cnn_features"]` is False.
"""
pred_nms, features = _process_spectrogram(
spec,
samplerate,
model,
config,
)
annotations = get_annotations_from_preds(
pred_nms,
config["class_names"],
)
return annotations, features
def _process_audio_array(
audio: np.ndarray,
sampling_rate: int,
model: torch.nn.Module,
config: ProcessingConfiguration,
device: torch.device,
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
# load audio file and compute spectrogram
_, spec, _ = compute_spectrogram(
audio,
sampling_rate,
{
"fft_win_length": config["fft_win_length"],
"fft_overlap": config["fft_overlap"],
"spec_height": config["spec_height"],
"resize_factor": config["resize_factor"],
"spec_divide_factor": config["spec_divide_factor"],
"max_freq": config["max_freq"],
"min_freq": config["min_freq"],
"spec_scale": config["spec_scale"],
"denoise_spec_avg": config["denoise_spec_avg"],
"max_scale_spec": config["max_scale_spec"],
},
device,
return_np=False,
)
# process spectrogram with model
pred_nms, features = _process_spectrogram(
spec,
sampling_rate,
model,
config,
)
return pred_nms, features, spec
def process_audio_array(
audio: np.ndarray,
sampling_rate: int,
model: torch.nn.Module,
config: ProcessingConfiguration,
):
device: torch.device,
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
"""Process a single audio array with detection model.
Parameters
@@ -801,47 +831,42 @@ def process_audio_array(
config : ProcessingConfiguration
Configuration for processing.
device : torch.device
Device to use for processing.
Returns
-------
pred_nms : Dict[str, np.ndarray]
features : Dict[str, np.ndarray]
spec_np : np.ndarray
"""
# load audio file and compute spectrogram
_, spec, spec_np = compute_spectrogram(
audio,
sampling_rate,
{
"fft_win_length": config["fft_win_length"],
"fft_overlap": config["fft_overlap"],
"spec_height": config["spec_height"],
"resize_factor": config["resize_factor"],
"spec_divide_factor": config["spec_divide_factor"],
"device": config["device"],
"max_freq": config["max_freq"],
"min_freq": config["min_freq"],
"spec_scale": config["spec_scale"],
"denoise_spec_avg": config["denoise_spec_avg"],
"max_scale_spec": config["max_scale_spec"],
},
return_np=config["spec_features"] or config["spec_slices"],
)
annotations : List[Annotation]
List of annotations predicted by the model.
# process spectrogram with model
pred_nms, features = process_spectrogram(
spec,
features : List[np.ndarray]
List of CNN features associated with each annotation.
spec : torch.Tensor
Spectrogram of the audio used as input.
"""
pred_nms, features, spec = _process_audio_array(
audio,
sampling_rate,
model,
config,
device,
)
return pred_nms, features, spec_np
annotations = get_annotations_from_preds(
pred_nms,
config["class_names"],
)
return annotations, features, spec
def process_file(
audio_file: str,
model: torch.nn.Module,
config: ProcessingConfiguration,
device: torch.device,
) -> Union[RunResults, Any]:
"""Process a single audio file with detection model.
@@ -872,7 +897,7 @@ def process_file(
spec_slices = []
# load audio file
sampling_rate, audio_full = au.load_audio_file(
sampling_rate, audio_full = au.load_audio(
audio_file,
time_exp_fact=config.get("time_expansion", 1) or 1,
target_samp_rate=config["target_samp_rate"],
@@ -881,7 +906,7 @@
)
# loop through larger file and split into chunks
# TODO fix so that it overlaps correctly and takes care of
# TODO: fix so that it overlaps correctly and takes care of
# duplicate detections at borders
for chunk_time, audio in iterate_over_chunks(
audio_full,
@@ -889,11 +914,12 @@
config["chunk_size"],
):
# Run detection model on chunk
pred_nms, features, spec_np = process_audio_array(
pred_nms, features, spec_np = _process_audio_array(
audio,
sampling_rate,
model,
config,
device,
)
# add chunk time to start and end times
@@ -965,39 +991,30 @@ def summarize_results(results, predictions, config):
)
def get_default_config(**kwargs) -> ProcessingConfiguration:
"""Get default configuration for running detection model."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args: ProcessingConfiguration = {
"detection_threshold": DETECTION_THRESHOLD,
"spec_slices": False,
"chunk_size": 3,
"spec_features": False,
"cnn_features": False,
"quiet": True,
"target_samp_rate": TARGET_SAMPLERATE_HZ,
"fft_win_length": FFT_WIN_LENGTH_S,
"fft_overlap": FFT_OVERLAP,
"resize_factor": RESIZE_FACTOR,
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
"spec_height": SPEC_HEIGHT,
"scale_raw_audio": SCALE_RAW_AUDIO,
"device": device,
"class_names": [],
"time_expansion": 1,
"top_n": 3,
"return_raw_preds": False,
"max_duration": None,
"nms_kernel_size": NMS_KERNEL_SIZE,
"max_freq": MAX_FREQ_HZ,
"min_freq": MIN_FREQ_HZ,
"nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
"spec_scale": SPEC_SCALE,
"denoise_spec_avg": DENOISE_SPEC_AVG,
"max_scale_spec": MAX_SCALE_SPEC,
}
return {
**args,
**kwargs,
}
DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
"detection_threshold": DETECTION_THRESHOLD,
"spec_slices": False,
"chunk_size": 3,
"spec_features": False,
"cnn_features": False,
"quiet": True,
"target_samp_rate": TARGET_SAMPLERATE_HZ,
"fft_win_length": FFT_WIN_LENGTH_S,
"fft_overlap": FFT_OVERLAP,
"resize_factor": RESIZE_FACTOR,
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
"spec_height": SPEC_HEIGHT,
"scale_raw_audio": SCALE_RAW_AUDIO,
"class_names": [],
"time_expansion": 1,
"top_n": 3,
"return_raw_preds": False,
"max_duration": None,
"nms_kernel_size": NMS_KERNEL_SIZE,
"max_freq": MAX_FREQ_HZ,
"min_freq": MIN_FREQ_HZ,
"nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
"spec_scale": SPEC_SCALE,
"denoise_spec_avg": DENOISE_SPEC_AVG,
"max_scale_spec": MAX_SCALE_SPEC,
}
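The get_default_config helper is replaced by a module-level constant; note that device has been dropped from ProcessingConfiguration and is now passed explicitly to the processing functions. A sketch of the override pattern that bat_detect.api.get_config wraps:

import bat_detect.utils.detector_utils as du

config: du.ProcessingConfiguration = {
    **du.DEFAULT_PROCESSING_CONFIGURATIONS,
    "detection_threshold": 0.5,
    "cnn_features": True,  # also return per-detection CNN features
}
# "device" is no longer a config key; functions such as process_file
# now take a torch.device argument instead.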


@@ -114,7 +114,7 @@ if __name__ == "__main__":
# load audio and crop
print("\nProcessing: " + os.path.basename(args_cmd["audio_file"]))
print("\nOutput directory: " + args_cmd["op_dir"])
sampling_rate, audio = au.load_audio_file(
sampling_rate, audio = au.load_audio(
args_cmd["audio_file"],
args_cmd["time_exp"],
params_bd["target_samp_rate"],


@@ -96,7 +96,7 @@ if __name__ == "__main__":
# load audio file
print("\nProcessing: " + os.path.basename(audio_file))
print("\nOutput directory: " + op_dir)
sampling_rate, audio = au.load_audio_file(
sampling_rate, audio = au.load_audio(
audio_file, args["time_expansion_factor"], params["target_samp_rate"]
)
audio = audio[


@@ -72,7 +72,7 @@ def load_data(
sampling_rates = []
file_names = []
for cur_file in anns:
sampling_rate, audio_orig = au.load_audio_file(
sampling_rate, audio_orig = au.load_audio(
cur_file["file_path"],
cur_file["time_exp"],
params["target_samp_rate"],

tests/__init__.py (new file, empty)

tests/test_api.py (new file, empty)

tests/test_bat_detect.py (new file, 213 lines)

@@ -0,0 +1,213 @@
"""Test bat detect module API."""
import os
from glob import glob
import numpy as np
import torch
from torch import nn
from bat_detect.api import (
generate_spectrogram,
get_config,
list_audio_files,
load_audio,
load_model,
process_audio,
process_file,
process_spectrogram,
)
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
def test_load_model_with_default_params():
"""Test loading model with default parameters."""
model, params = load_model()
assert model is not None
assert isinstance(model, nn.Module)
assert params is not None
assert isinstance(params, dict)
assert "model_name" in params
assert "num_filters" in params
assert "emb_dim" in params
assert "ip_height" in params
assert "resize_factor" in params
assert "class_names" in params
assert params["model_name"] == "Net2DFast"
assert params["num_filters"] == 128
assert params["emb_dim"] == 0
assert params["ip_height"] == 128
assert params["resize_factor"] == 0.5
assert len(params["class_names"]) == 17
def test_list_audio_files():
"""Test listing audio files."""
audio_files = list_audio_files(TEST_DATA_DIR)
assert len(audio_files) == 3
assert all(path.endswith((".wav", ".WAV")) for path in audio_files)
def test_load_audio():
"""Test loading audio."""
samplerate, audio = load_audio(TEST_DATA[0])
assert audio is not None
assert samplerate == 256000
assert isinstance(audio, np.ndarray)
assert audio.shape == (128000,)
def test_generate_spectrogram():
"""Test generating spectrogram."""
samplerate, audio = load_audio(TEST_DATA[0])
spectrogram = generate_spectrogram(audio, samplerate)
assert spectrogram is not None
assert isinstance(spectrogram, torch.Tensor)
assert spectrogram.shape == (1, 1, 128, 512)
def test_get_default_config():
"""Test getting default configuration."""
config = get_config()
assert config is not None
assert isinstance(config, dict)
assert config["target_samp_rate"] == 256000
assert config["fft_win_length"] == 0.002
assert config["fft_overlap"] == 0.75
assert config["resize_factor"] == 0.5
assert config["spec_divide_factor"] == 32
assert config["spec_height"] == 256
assert config["spec_scale"] == "pcen"
assert config["denoise_spec_avg"] is True
assert config["max_scale_spec"] is False
assert config["scale_raw_audio"] is False
assert len(config["class_names"]) == 0
assert config["detection_threshold"] == 0.01
assert config["time_expansion"] == 1
assert config["top_n"] == 3
assert config["return_raw_preds"] is False
assert config["max_duration"] is None
assert config["nms_kernel_size"] == 9
assert config["max_freq"] == 120000
assert config["min_freq"] == 10000
assert config["nms_top_k_per_sec"] == 200
assert config["quiet"] is True
assert config["chunk_size"] == 3
assert config["cnn_features"] is False
assert config["spec_features"] is False
assert config["spec_slices"] is False
def test_process_file_with_model():
"""Test processing file with model."""
model, params = load_model()
config = get_config(**params)
predictions = process_file(TEST_DATA[0], model, config=config)
assert predictions is not None
assert isinstance(predictions, dict)
assert "pred_dict" in predictions
assert "spec_feats" in predictions
assert "spec_feat_names" in predictions
assert "cnn_feats" in predictions
assert "cnn_feat_names" in predictions
assert "spec_slices" in predictions
# By default will not return spectrogram features
assert predictions["spec_feats"] is None
assert predictions["spec_feat_names"] is None
assert predictions["cnn_feats"] is None
assert predictions["cnn_feat_names"] is None
assert predictions["spec_slices"] is None
# Check that predictions are returned
assert isinstance(predictions["pred_dict"], dict)
pred_dict = predictions["pred_dict"]
assert pred_dict["id"] == os.path.basename(TEST_DATA[0])
assert pred_dict["annotated"] is False
assert pred_dict["issues"] is False
assert pred_dict["notes"] == "Automatically generated."
assert pred_dict["time_exp"] == 1
assert pred_dict["duration"] == 0.5
assert pred_dict["class_name"] is not None
assert len(pred_dict["annotation"]) > 0
def test_process_spectrogram_with_model():
"""Test processing spectrogram with model."""
model, params = load_model()
config = get_config(**params)
samplerate, audio = load_audio(TEST_DATA[0])
spectrogram = generate_spectrogram(audio, samplerate)
predictions, features = process_spectrogram(
spectrogram,
samplerate,
model,
config=config,
)
assert predictions is not None
assert isinstance(predictions, list)
assert len(predictions) > 0
sample_pred = predictions[0]
assert isinstance(sample_pred, dict)
assert "class" in sample_pred
assert "class_prob" in sample_pred
assert "det_prob" in sample_pred
assert "start_time" in sample_pred
assert "end_time" in sample_pred
assert "low_freq" in sample_pred
assert "high_freq" in sample_pred
assert features is not None
assert isinstance(features, list)
# By default will not return cnn features
assert len(features) == 0
def test_process_audio_with_model():
"""Test processing audio with model."""
model, params = load_model()
config = get_config(**params)
samplerate, audio = load_audio(TEST_DATA[0])
predictions, features, spec = process_audio(
audio,
samplerate,
model,
config=config,
)
assert predictions is not None
assert isinstance(predictions, list)
assert len(predictions) > 0
sample_pred = predictions[0]
assert isinstance(sample_pred, dict)
assert "class" in sample_pred
assert "class_prob" in sample_pred
assert "det_prob" in sample_pred
assert "start_time" in sample_pred
assert "end_time" in sample_pred
assert "low_freq" in sample_pred
assert "high_freq" in sample_pred
assert features is not None
assert isinstance(features, list)
# By default will not return cnn features
assert len(features) == 0
assert spec is not None
assert isinstance(spec, torch.Tensor)
assert spec.shape == (1, 1, 128, 512)

tests/test_cli.py (new file, empty)