Added an API file with tests to check basic functionality

Santiago Martinez 2023-02-25 19:40:54 +00:00
parent 40222d8233
commit 0eecf54a94
15 changed files with 822 additions and 244 deletions

app.py (2 lines changed)

@@ -77,7 +77,7 @@ def make_prediction(file_name=None, detection_threshold=0.3):
 def generate_results_image(audio_file, anns):
     # load audio
-    sampling_rate, audio = au.load_audio_file(
+    sampling_rate, audio = au.load_audio(
         audio_file,
         args["time_expansion_factor"],
         params["target_samp_rate"],

bat_detect/api.py (new file, 215 lines)

@@ -0,0 +1,215 @@
from typing import List, Optional, Tuple

import numpy as np
import torch

import bat_detect.detector.models as md
import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
from bat_detect.detector.parameters import TARGET_SAMPLERATE_HZ
from bat_detect.utils.detector_utils import list_audio_files, load_model

# Use GPU if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

__all__ = [
    "load_model",
    "load_audio",
    "list_audio_files",
    "generate_spectrogram",
    "get_config",
    "process_file",
    "process_spectrogram",
    "process_audio",
]


def get_config(**kwargs) -> du.ProcessingConfiguration:
    """Get default processing configuration.

    Can be used to override default parameters by passing keyword arguments.
    """
    return {**du.DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}


def load_audio(
    path: str,
    time_exp_fact: float = 1,
    target_samp_rate: int = TARGET_SAMPLERATE_HZ,
    scale: bool = False,
    max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray]:
    """Load audio from file.

    Parameters
    ----------
    path : str
        Path to audio file.
    time_exp_fact : float, optional
        Time expansion factor, by default 1
    target_samp_rate : int, optional
        Target sample rate, by default 256000
    scale : bool, optional
        Scale audio to [-1, 1], by default False
    max_duration : Optional[float], optional
        Maximum duration of audio in seconds, by default None

    Returns
    -------
    int
        Sample rate.
    np.ndarray
        Audio data.
    """
    return au.load_audio(
        path,
        time_exp_fact,
        target_samp_rate,
        scale,
        max_duration,
    )


def generate_spectrogram(
    audio: np.ndarray,
    samp_rate: int,
    config: Optional[au.SpectrogramParameters] = None,
    device: torch.device = DEVICE,
) -> torch.Tensor:
    """Generate spectrogram from audio array.

    Parameters
    ----------
    audio : np.ndarray
        Audio data.
    samp_rate : int
        Sample rate.
    config : Optional[SpectrogramParameters], optional
        Spectrogram parameters, by default None (uses default parameters).
    device : torch.device, optional
        Device to use, by default tries to use GPU if available.

    Returns
    -------
    torch.Tensor
        Spectrogram.
    """
    if config is None:
        config = au.DEFAULT_SPECTROGRAM_PARAMETERS

    _, spec, _ = du.compute_spectrogram(
        audio,
        samp_rate,
        config,
        return_np=False,
        device=device,
    )

    return spec


def process_file(
    audio_file: str,
    model: md.DetectionModel,
    config: Optional[du.ProcessingConfiguration] = None,
    device: torch.device = DEVICE,
) -> du.RunResults:
    """Process audio file with model.

    Parameters
    ----------
    audio_file : str
        Path to audio file.
    model : DetectionModel
        Detection model.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration, by default None (uses default parameters).
    device : torch.device, optional
        Device to use, by default tries to use GPU if available.
    """
    if config is None:
        config = du.DEFAULT_PROCESSING_CONFIGURATIONS

    return du.process_file(
        audio_file,
        model,
        config,
        device,
    )


def process_spectrogram(
    spec: torch.Tensor,
    samp_rate: int,
    model: md.DetectionModel,
    config: Optional[du.ProcessingConfiguration] = None,
) -> Tuple[List[du.Annotation], List[np.ndarray]]:
    """Process spectrogram with model.

    Parameters
    ----------
    spec : torch.Tensor
        Spectrogram.
    samp_rate : int
        Sample rate of the audio from which the spectrogram was generated.
    model : DetectionModel
        Detection model.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration, by default None (uses default parameters).

    Returns
    -------
    annotations : List[Annotation]
        List of predicted annotations.
    features : List[np.ndarray]
        List of extracted features for each annotation.
    """
    if config is None:
        config = du.DEFAULT_PROCESSING_CONFIGURATIONS

    return du.process_spectrogram(
        spec,
        samp_rate,
        model,
        config,
    )


def process_audio(
    audio: np.ndarray,
    samp_rate: int,
    model: md.DetectionModel,
    config: Optional[du.ProcessingConfiguration] = None,
    device: torch.device = DEVICE,
) -> Tuple[List[du.Annotation], List[np.ndarray], torch.Tensor]:
    """Process audio array with model.

    Parameters
    ----------
    audio : np.ndarray
        Audio data.
    samp_rate : int
        Sample rate.
    model : DetectionModel
        Detection model.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration, by default None (uses default parameters).
    device : torch.device, optional
        Device to use, by default tries to use GPU if available.

    Returns
    -------
    annotations : List[Annotation]
        List of predicted annotations.
    features : List[np.ndarray]
        List of extracted features for each annotation.
    spec : torch.Tensor
        Spectrogram of the audio used for prediction.
    """
    if config is None:
        config = du.DEFAULT_PROCESSING_CONFIGURATIONS

    return du.process_audio_array(
        audio,
        samp_rate,
        model,
        config,
        device,
    )
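
For orientation, a minimal usage sketch of this new API surface; the calls mirror those exercised in tests/test_bat_detect.py below, and the .wav path is a placeholder, not a file shipped by this commit:

# Minimal sketch of the new bat_detect.api surface (audio path is illustrative only).
from bat_detect import api

model, params = api.load_model()                 # default bundled checkpoint
config = api.get_config(**params)                # defaults merged with model parameters
samplerate, audio = api.load_audio("example_data/audio/some_recording.wav")
spec = api.generate_spectrogram(audio, samplerate)

# Either run straight from the file...
results = api.process_file("example_data/audio/some_recording.wav", model, config=config)

# ...or step through spectrogram processing explicitly.
annotations, features = api.process_spectrogram(spec, samplerate, model, config=config)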

View File

@@ -92,7 +92,7 @@ def main():
     model, params = du.load_model(args["model_path"])
     print("\nInput directory: " + args["audio_dir"])
-    files = du.get_audio_files(args["audio_dir"])
+    files = du.list_audio_files(args["audio_dir"])
     print(f"Number of audio files: {len(files)}")
     print("\nSaving results to: " + args["ann_dir"])

View File

@@ -1,9 +1,11 @@
+from typing import NamedTuple, Optional
+
 import torch
 import torch.fft
 import torch.nn.functional as F
 from torch import nn
-from .model_helpers import (
+from bat_detect.detector.model_helpers import (
     ConvBlockDownCoordF,
     ConvBlockDownStandard,
     ConvBlockUpF,
@@ -11,13 +13,88 @@ from .model_helpers import (
     SelfAttention,
 )

+try:
+    from typing import Protocol
+except ImportError:
+    from typing_extensions import Protocol
+
 __all__ = [
     "Net2DFast",
     "Net2DFastNoAttn",
     "Net2DFastNoCoordConv",
+    "ModelOutput",
+    "DetectionModel",
 ]

+
+class ModelOutput(NamedTuple):
+    """Output of the detection model."""
+
+    pred_det: torch.Tensor
+    """Tensor with predicted detection probabilities."""
+
+    pred_size: torch.Tensor
+    """Tensor with predicted bounding box sizes."""
+
+    pred_class: torch.Tensor
+    """Tensor with predicted class probabilities."""
+
+    pred_class_un_norm: torch.Tensor
+    """Tensor with predicted class probabilities before softmax."""
+
+    pred_emb: Optional[torch.Tensor]
+    """Tensor with embeddings."""
+
+    features: Optional[torch.Tensor]
+    """Tensor with intermediate features."""
+
+
+class DetectionModel(Protocol):
+    """Protocol for detection models.
+
+    This protocol is used to define the interface for the detection models.
+    This allows us to use the same code for training and inference, even
+    though the models are different.
+    """
+
+    num_classes: int
+    """Number of classes the model can classify."""
+
+    emb_dim: int
+    """Dimension of the embedding vector."""
+
+    num_filts: int
+    """Number of filters in the model."""
+
+    resize_factor: float
+    """Factor by which the input is resized."""
+
+    ip_height: int
+    """Height of the input image."""
+
+    def forward(
+        self,
+        ip: torch.Tensor,
+        return_feats: bool = False,
+    ) -> ModelOutput:
+        """Forward pass of the model.
+
+        When `return_feats` is `True`, the model should return the
+        intermediate features of the model.
+        """
+
+    def __call__(
+        self,
+        ip: torch.Tensor,
+        return_feats: bool = False,
+    ) -> ModelOutput:
+        """Forward pass of the model.
+
+        When `return_feats` is `True`, the model should return the
+        intermediate features of the model.
+        """

 class Net2DFast(nn.Module):
     def __init__(
         self,
@@ -27,7 +104,7 @@ class Net2DFast(nn.Module):
         ip_height=128,
         resize_factor=0.5,
     ):
-        super(Net2DFast, self).__init__()
+        super().__init__()
         self.num_classes = num_classes
         self.emb_dim = emb_dim
         self.num_filts = num_filts
@@ -102,7 +179,7 @@ class Net2DFast(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )

-    def forward(self, ip, return_feats=False):
+    def forward(self, ip, return_feats=False) -> ModelOutput:
         # encoder
         x1 = self.conv_dn_0(ip)
@@ -125,17 +202,14 @@ class Net2DFast(nn.Module):
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)

-        op = {}
-        op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
-        op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
-        op["pred_class"] = comb
-        op["pred_class_un_norm"] = cls
-        if self.emb_dim > 0:
-            op["pred_emb"] = self.conv_emb(x)
-        if return_feats:
-            op["features"] = x
-        return op
+        return ModelOutput(
+            pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
+            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_class=comb,
+            pred_class_un_norm=cls,
+            pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
+            features=x if return_feats else None,
+        )

 class Net2DFastNoAttn(nn.Module):
@@ -147,7 +221,7 @@ class Net2DFastNoAttn(nn.Module):
         ip_height=128,
         resize_factor=0.5,
     ):
-        super(Net2DFastNoAttn, self).__init__()
+        super().__init__()
         self.num_classes = num_classes
         self.emb_dim = emb_dim
@@ -219,8 +293,7 @@ class Net2DFastNoAttn(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )

-    def forward(self, ip, return_feats=False):
+    def forward(self, ip, return_feats=False) -> ModelOutput:
         x1 = self.conv_dn_0(ip)
         x2 = self.conv_dn_1(x1)
         x3 = self.conv_dn_2(x2)
@@ -237,17 +310,14 @@ class Net2DFastNoAttn(nn.Module):
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)

-        op = {}
-        op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
-        op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
-        op["pred_class"] = comb
-        op["pred_class_un_norm"] = cls
-        if self.emb_dim > 0:
-            op["pred_emb"] = self.conv_emb(x)
-        if return_feats:
-            op["features"] = x
-        return op
+        return ModelOutput(
+            pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
+            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_class=comb,
+            pred_class_un_norm=cls,
+            pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
+            features=x if return_feats else None,
+        )

 class Net2DFastNoCoordConv(nn.Module):
@@ -259,7 +329,7 @@ class Net2DFastNoCoordConv(nn.Module):
         ip_height=128,
         resize_factor=0.5,
     ):
-        super(Net2DFastNoCoordConv, self).__init__()
+        super().__init__()
         self.num_classes = num_classes
         self.emb_dim = emb_dim
@@ -333,7 +403,7 @@ class Net2DFastNoCoordConv(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )

-    def forward(self, ip, return_feats=False):
+    def forward(self, ip, return_feats=False) -> ModelOutput:
         x1 = self.conv_dn_0(ip)
         x2 = self.conv_dn_1(x1)
@@ -352,14 +422,11 @@ class Net2DFastNoCoordConv(nn.Module):
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)

-        op = {}
-        op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
-        op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
-        op["pred_class"] = comb
-        op["pred_class_un_norm"] = cls
-        if self.emb_dim > 0:
-            op["pred_emb"] = self.conv_emb(x)
-        if return_feats:
-            op["features"] = x
-        return op
+        return ModelOutput(
+            pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
+            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_class=comb,
+            pred_class_un_norm=cls,
+            pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
+            features=x if return_feats else None,
+        )

View File

@@ -5,6 +5,8 @@ import numpy as np
 import torch
 from torch import nn

+from bat_detect.detector.models import ModelOutput
+
 try:
     from typing import TypedDict
 except ImportError:
@@ -106,24 +108,8 @@ class PredictionResults(TypedDict):
     """Class probabilities."""

-class ModelOutputs(TypedDict):
-    """Outputs of the model."""
-    pred_det: torch.Tensor
-    """Detection probabilities."""
-    pred_size: torch.Tensor
-    """Box sizes."""
-    pred_class: Optional[torch.Tensor]
-    """Class probabilities."""
-    features: Optional[torch.Tensor]
-    """Features extracted by the model."""
-
 def run_nms(
-    outputs: ModelOutputs,
+    outputs: ModelOutput,
     params: NonMaximumSuppressionConfig,
     sampling_rate: np.ndarray,
 ) -> Tuple[List[PredictionResults], List[np.ndarray]]:
@@ -135,16 +121,14 @@ def run_nms(
     the features. Each element of the lists corresponds to one
     element of the batch.
     """
+    pred_det, pred_size, pred_class, _, _, features = outputs
-    pred_det = outputs["pred_det"]  # probability of box
-    pred_size = outputs["pred_size"]  # box size
     pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
     freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
         -2
     ]

-    # NOTE there will be small differences depending on which sampling rate is chosen
+    # NOTE: there will be small differences depending on which sampling rate is chosen
     # as we are choosing the same sampling rate for the entire batch
     duration = x_coords_to_time(
         pred_det.shape[-1],
@@ -172,10 +156,16 @@ def run_nms(
         pred["x_pos"] = x_pos[num_detection, valid_inds]
         pred["y_pos"] = y_pos[num_detection, valid_inds]
         pred["bb_width"] = pred_size[
-            num_detection, 0, pred["y_pos"], pred["x_pos"]
+            num_detection,
+            0,
+            pred["y_pos"],
+            pred["x_pos"],
         ]
         pred["bb_height"] = pred_size[
-            num_detection, 1, pred["y_pos"], pred["x_pos"]
+            num_detection,
+            1,
+            pred["y_pos"],
+            pred["x_pos"],
         ]
         pred["start_times"] = x_coords_to_time(
             pred["x_pos"].float() / params["resize_factor"],
@@ -198,7 +188,6 @@ def run_nms(
         )

         # extract the per class votes
-        pred_class = outputs.get("pred_class")
         if pred_class is not None:
             pred["class_probs"] = pred_class[
                 num_detection,
@@ -208,7 +197,6 @@ def run_nms(
             ]

         # extract the model features
-        features = outputs.get("features")
         if features is not None:
             feat = features[
                 num_detection,

View File

@@ -373,7 +373,7 @@ class AudioLoader(torch.utils.data.Dataset):
         index = np.random.randint(0, len(self.data_anns))
         audio_file = self.data_anns[index]["file_path"]
-        sampling_rate, audio_raw = au.load_audio_file(
+        sampling_rate, audio_raw = au.load_audio(
             audio_file,
             self.data_anns[index]["time_exp"],
             self.params["target_samp_rate"],

View File

@@ -5,13 +5,87 @@ import librosa
 import numpy as np
 import torch

+from bat_detect.detector.parameters import (
+    DENOISE_SPEC_AVG,
+    DETECTION_THRESHOLD,
+    FFT_OVERLAP,
+    FFT_WIN_LENGTH_S,
+    MAX_FREQ_HZ,
+    MAX_SCALE_SPEC,
+    MIN_FREQ_HZ,
+    NMS_KERNEL_SIZE,
+    NMS_TOP_K_PER_SEC,
+    RESIZE_FACTOR,
+    SCALE_RAW_AUDIO,
+    SPEC_DIVIDE_FACTOR,
+    SPEC_HEIGHT,
+    SPEC_SCALE,
+)
 from . import wavfile

+try:
+    from typing import TypedDict
+except ImportError:
+    from typing_extensions import TypedDict
+
 __all__ = [
-    "load_audio_file",
+    "load_audio",
+    "generate_spectrogram",
+    "pad_audio",
+    "SpectrogramParameters",
+    "DEFAULT_SPECTROGRAM_PARAMETERS",
 ]

+
+class SpectrogramParameters(TypedDict):
+    """Parameters for generating spectrograms."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Percentage of overlap between FFT windows."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    resize_factor: float
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: int
+    """Factor to divide the spectrogram by."""
+
+    max_freq: int
+    """Maximum frequency to display in the spectrogram."""
+
+    min_freq: int
+    """Minimum frequency to display in the spectrogram."""
+
+    spec_scale: str
+    """Scale to use for the spectrogram."""
+
+    denoise_spec_avg: bool
+    """Whether to denoise the spectrogram by averaging."""
+
+    max_scale_spec: bool
+    """Whether to scale the spectrogram so that its max is 1."""
+
+
+DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
+    "fft_win_length": FFT_WIN_LENGTH_S,
+    "fft_overlap": FFT_OVERLAP,
+    "spec_height": SPEC_HEIGHT,
+    "resize_factor": RESIZE_FACTOR,
+    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
+    "max_freq": MAX_FREQ_HZ,
+    "min_freq": MIN_FREQ_HZ,
+    "spec_scale": SPEC_SCALE,
+    "denoise_spec_avg": DENOISE_SPEC_AVG,
+    "max_scale_spec": MAX_SCALE_SPEC,
+}
+
+
 def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
     nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
     noverlap = np.floor(fft_overlap * nfft)
@@ -36,7 +110,10 @@ def generate_spectrogram(
     # generate spectrogram
     spec = gen_mag_spectrogram(
-        audio, sampling_rate, params["fft_win_length"], params["fft_overlap"]
+        audio,
+        sampling_rate,
+        params["fft_win_length"],
+        params["fft_overlap"],
     )

     # crop to min/max freq
@@ -70,6 +147,7 @@ def generate_spectrogram(
         spec = np.log1p(log_scaling * spec_cropped)
     elif params["spec_scale"] == "pcen":
         spec = pcen(spec_cropped, sampling_rate)
+
     elif params["spec_scale"] == "none":
         pass
@@ -109,13 +187,13 @@ def generate_spectrogram(
     return spec, spec_for_viz

-def load_audio_file(
+def load_audio(
     audio_file: str,
     time_exp_fact: float,
     target_samp_rate: int,
     scale: bool = False,
     max_duration: Optional[float] = None,
-):
+) -> Tuple[int, np.ndarray]:
     """Load an audio file and resample it to the target sampling rate.

     The audio is also scaled to [-1, 1] and clipped to the maximum duration.
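
The spectrogram settings previously embedded in detector_utils now live here as a TypedDict, so callers can tweak individual fields while keeping the rest of the defaults. A brief sketch; the overridden max_freq value and the audio path are illustrative only:

# Sketch: override one field of the default spectrogram parameters.
import bat_detect.utils.audio_utils as au
from bat_detect import api

params: au.SpectrogramParameters = {
    **au.DEFAULT_SPECTROGRAM_PARAMETERS,
    "max_freq": 100_000,  # hypothetical cap; the default comes from MAX_FREQ_HZ
}

samplerate, audio = api.load_audio("example_data/audio/some_recording.wav")
spec = api.generate_spectrogram(audio, samplerate, config=params)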

View File

@@ -43,19 +43,19 @@ DEFAULT_MODEL_PATH = os.path.join(
 __all__ = [
     "load_model",
-    "get_audio_files",
-    "get_default_config",
-    "format_results",
+    "list_audio_files",
+    "format_single_result",
     "save_results_to_file",
     "iterate_over_chunks",
     "process_spectrogram",
     "process_audio_array",
     "process_file",
     "DEFAULT_MODEL_PATH",
+    "DEFAULT_PROCESSING_CONFIGURATIONS",
 ]

-def get_audio_files(ip_dir: str) -> List[str]:
+def list_audio_files(ip_dir: str) -> List[str]:
     """Get all audio files in directory.

     Args:
@@ -98,13 +98,12 @@ class ModelParameters(TypedDict):
     class_names: List[str]
     """Class names. The model is trained to detect these classes."""

-    device: torch.device

 def load_model(
     model_path: str = DEFAULT_MODEL_PATH,
     load_weights: bool = True,
-) -> Tuple[torch.nn.Module, ModelParameters]:
+    device: Optional[torch.device] = None,
+) -> Tuple[models.DetectionModel, ModelParameters]:
     """Load model from file.

     Args:
@@ -120,6 +119,7 @@ def load_model(
     """
     # load model
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     if not os.path.isfile(model_path):
@@ -128,9 +128,8 @@ def load_model(
     net_params = torch.load(model_path, map_location=device)
     params = net_params["params"]
-    params["device"] = device

-    model: torch.nn.Module
+    model: models.DetectionModel
     if params["model_name"] == "Net2DFast":
         model = models.Net2DFast(
@@ -162,7 +161,7 @@ def load_model(
     if load_weights:
         model.load_state_dict(net_params["state_dict"])

-    model = model.to(params["device"])
+    model = model.to(device)
     model.eval()

     return model, params
@@ -285,30 +284,11 @@ class ResultParams(TypedDict):
     """Class names."""

-def format_results(
-    file_id: str,
-    time_exp: float,
-    duration: float,
+def get_annotations_from_preds(
     predictions,
     class_names: List[str],
-) -> FileAnnotations:
-    """Format results into the format expected by the annotation tool.
-
-    Args:
-        file_id (str): File ID.
-        time_exp (float): Time expansion factor.
-        duration (float): Duration of audio file.
-        predictions (dict): Predictions.
-
-    Returns:
-        dict: Results in the format expected by the annotation tool.
-    """
-    # Get a single class prediction for the file
-    class_overall = pp.overall_class_pred(
-        predictions["det_probs"],
-        predictions["class_probs"],
-    )
+) -> List[Annotation]:
+    """Get list of annotations from predictions."""
     # Get the best class prediction probability and index for each detection
     class_prob_best = predictions["class_probs"].max(0)
     class_ind_best = predictions["class_probs"].argmax(0)
@@ -344,6 +324,32 @@ def format_results(
             predictions["det_probs"],
         )
     ]
+    return annotations
+
+
+def format_single_result(
+    file_id: str,
+    time_exp: float,
+    duration: float,
+    predictions,
+    class_names: List[str],
+) -> FileAnnotations:
+    """Format results into the format expected by the annotation tool.
+
+    Args:
+        file_id (str): File ID.
+        time_exp (float): Time expansion factor.
+        duration (float): Duration of audio file.
+        predictions (dict): Predictions.
+
+    Returns:
+        dict: Results in the format expected by the annotation tool.
+    """
+    # Get a single class prediction for the file
+    class_overall = pp.overall_class_pred(
+        predictions["det_probs"],
+        predictions["class_probs"],
+    )

     return {
         "id": file_id,
@@ -352,7 +358,7 @@ def format_results(
         "notes": "Automatically generated.",
         "time_exp": time_exp,
         "duration": round(float(duration), 4),
-        "annotation": annotations,
+        "annotation": get_annotations_from_preds(predictions, class_names),
         "class_name": class_names[np.argmax(class_overall)],
     }
@@ -383,7 +389,7 @@ def convert_results(
         dict: Dictionary with results.
     """
-    pred_dict = format_results(
+    pred_dict = format_single_result(
         file_id,
         time_exp,
         duration,
@@ -490,47 +496,11 @@ def save_results_to_file(results, op_path: str) -> None:
         json.dump(results["pred_dict"], jsonfile, indent=2, sort_keys=True)

-class SpectrogramParameters(TypedDict):
-    """Parameters for generating spectrograms."""
-    fft_win_length: float
-    """Length of the FFT window in seconds."""
-    fft_overlap: float
-    """Percentage of overlap between FFT windows."""
-    spec_height: int
-    """Height of the spectrogram in pixels."""
-    resize_factor: float
-    """Factor to resize the spectrogram by."""
-    spec_divide_factor: int
-    """Factor to divide the spectrogram by."""
-    device: torch.device
-    """Device to store the spectrogram on."""
-    max_freq: int
-    """Maximum frequency to display in the spectrogram."""
-    min_freq: int
-    """Minimum frequency to display in the spectrogram."""
-    spec_scale: str
-    """Scale to use for the spectrogram."""
-    denoise_spec_avg: bool
-    """Whether to denoise the spectrogram by averaging."""
-    max_scale_spec: bool
-    """Whether to scale the spectrogram so that its max is 1."""
-
 def compute_spectrogram(
     audio: np.ndarray,
     sampling_rate: int,
-    params: SpectrogramParameters,
+    params: au.SpectrogramParameters,
+    device: torch.device,
     return_np: bool = False,
 ) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
     """Compute a spectrogram from an audio array.
@@ -578,7 +548,7 @@ def compute_spectrogram(
     spec, _ = au.generate_spectrogram(audio, sampling_rate, params)

     # convert to pytorch
-    spec = torch.from_numpy(spec).to(params["device"])
+    spec = torch.from_numpy(spec).to(device)
     # add batch and channel dimensions
     spec = spec.unsqueeze(0).unsqueeze(0)
@@ -672,9 +642,6 @@ class ProcessingConfiguration(TypedDict):
     scale_raw_audio: bool
     """Whether to scale the raw audio to be between -1 and 1."""

-    device: torch.device
-    """Device to run the model on."""
-
     class_names: List[str]
     """Names of the classes the model can detect."""
@@ -721,33 +688,12 @@ class ProcessingConfiguration(TypedDict):
     """Whether to return spectrogram slices."""

-def process_spectrogram(
+def _process_spectrogram(
     spec: torch.Tensor,
     samplerate: int,
-    model: torch.nn.Module,
+    model: models.DetectionModel,
     config: ProcessingConfiguration,
-):
-    """Process a spectrogram with detection model.
-
-    Will run non-maximum suppression on the output of the model.
-
-    Parameters
-    ----------
-    spec : torch.Tensor
-    samplerate : int
-    model : torch.nn.Module
-        Detection model.
-    config : pp.NonMaximumSuppressionConfig
-        Parameters for non-maximum suppression.
-
-    Returns
-    -------
-    pred_nms : Dict[str, np.ndarray]
-    features : Dict[str, np.ndarray]
-    """
+) -> Tuple[List[Annotation], List[np.ndarray]]:
     # evaluate model
     with torch.no_grad():
         outputs = model(spec, return_feats=config["cnn_features"])
@@ -781,12 +727,96 @@ def process_spectrogram(
     return pred_nms, features

+def process_spectrogram(
+    spec: torch.Tensor,
+    samplerate: int,
+    model: models.DetectionModel,
+    config: ProcessingConfiguration,
+) -> Tuple[List[Annotation], List[np.ndarray]]:
+    """Process a spectrogram with detection model.
+
+    Will run non-maximum suppression on the output of the model.
+
+    Parameters
+    ----------
+    spec : torch.Tensor
+    samplerate : int
+    model : torch.nn.Module
+        Detection model.
+    config : pp.NonMaximumSuppressionConfig
+        Parameters for non-maximum suppression.
+
+    Returns
+    -------
+    annotations : List[Annotation]
+        List of annotations predicted by the model.
+    features : List[np.ndarray]
+        List of CNN features associated with each annotation.
+        Is empty if `config["cnn_features"]` is False.
+    """
+    pred_nms, features = _process_spectrogram(
+        spec,
+        samplerate,
+        model,
+        config,
+    )
+
+    annotations = get_annotations_from_preds(
+        pred_nms,
+        config["class_names"],
+    )
+
+    return annotations, features
+
+
+def _process_audio_array(
+    audio: np.ndarray,
+    sampling_rate: int,
+    model: torch.nn.Module,
+    config: ProcessingConfiguration,
+    device: torch.device,
+) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
+    # load audio file and compute spectrogram
+    _, spec, _ = compute_spectrogram(
+        audio,
+        sampling_rate,
+        {
+            "fft_win_length": config["fft_win_length"],
+            "fft_overlap": config["fft_overlap"],
+            "spec_height": config["spec_height"],
+            "resize_factor": config["resize_factor"],
+            "spec_divide_factor": config["spec_divide_factor"],
+            "max_freq": config["max_freq"],
+            "min_freq": config["min_freq"],
+            "spec_scale": config["spec_scale"],
+            "denoise_spec_avg": config["denoise_spec_avg"],
+            "max_scale_spec": config["max_scale_spec"],
+        },
+        device,
+        return_np=False,
+    )
+
+    # process spectrogram with model
+    pred_nms, features = _process_spectrogram(
+        spec,
+        sampling_rate,
+        model,
+        config,
+    )
+
+    return pred_nms, features, spec
+
+
 def process_audio_array(
     audio: np.ndarray,
     sampling_rate: int,
     model: torch.nn.Module,
     config: ProcessingConfiguration,
-):
+    device: torch.device,
+) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
     """Process a single audio array with detection model.

     Parameters
@@ -801,47 +831,42 @@ def process_audio_array(
     config : ProcessingConfiguration
         Configuration for processing.

+    device : torch.device
+        Device to use for processing.
+
     Returns
     -------
-    pred_nms : Dict[str, np.ndarray]
-    features : Dict[str, np.ndarray]
-    spec_np : np.ndarray
-    """
-    # load audio file and compute spectrogram
-    _, spec, spec_np = compute_spectrogram(
-        audio,
-        sampling_rate,
-        {
-            "fft_win_length": config["fft_win_length"],
-            "fft_overlap": config["fft_overlap"],
-            "spec_height": config["spec_height"],
-            "resize_factor": config["resize_factor"],
-            "spec_divide_factor": config["spec_divide_factor"],
-            "device": config["device"],
-            "max_freq": config["max_freq"],
-            "min_freq": config["min_freq"],
-            "spec_scale": config["spec_scale"],
-            "denoise_spec_avg": config["denoise_spec_avg"],
-            "max_scale_spec": config["max_scale_spec"],
-        },
-        return_np=config["spec_features"] or config["spec_slices"],
-    )
-
-    # process spectrogram with model
-    pred_nms, features = process_spectrogram(
-        spec,
+    annotations : List[Annotation]
+        List of annotations predicted by the model.
+
+    features : List[np.ndarray]
+        List of CNN features associated with each annotation.
+
+    spec : torch.Tensor
+        Spectrogram of the audio used as input.
+    """
+    pred_nms, features, spec = _process_audio_array(
+        audio,
         sampling_rate,
         model,
         config,
+        device,
     )
-    return pred_nms, features, spec_np
+
+    annotations = get_annotations_from_preds(
+        pred_nms,
+        config["class_names"],
+    )
+
+    return annotations, features, spec


 def process_file(
     audio_file: str,
     model: torch.nn.Module,
     config: ProcessingConfiguration,
+    device: torch.device,
 ) -> Union[RunResults, Any]:
     """Process a single audio file with detection model.
@@ -872,7 +897,7 @@ def process_file(
     spec_slices = []

     # load audio file
-    sampling_rate, audio_full = au.load_audio_file(
+    sampling_rate, audio_full = au.load_audio(
         audio_file,
         time_exp_fact=config.get("time_expansion", 1) or 1,
         target_samp_rate=config["target_samp_rate"],
@@ -881,7 +906,7 @@ def process_file(
     )

     # loop through larger file and split into chunks
-    # TODO fix so that it overlaps correctly and takes care of
+    # TODO: fix so that it overlaps correctly and takes care of
     # duplicate detections at borders
     for chunk_time, audio in iterate_over_chunks(
         audio_full,
@@ -889,11 +914,12 @@ def process_file(
         config["chunk_size"],
     ):
         # Run detection model on chunk
-        pred_nms, features, spec_np = process_audio_array(
+        pred_nms, features, spec_np = _process_audio_array(
             audio,
             sampling_rate,
             model,
             config,
+            device,
         )

         # add chunk time to start and end times
@@ -965,11 +991,7 @@ def summarize_results(results, predictions, config):
     )

-def get_default_config(**kwargs) -> ProcessingConfiguration:
-    """Get default configuration for running detection model."""
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    args: ProcessingConfiguration = {
+DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
     "detection_threshold": DETECTION_THRESHOLD,
     "spec_slices": False,
     "chunk_size": 3,
@@ -983,7 +1005,6 @@ def get_default_config(**kwargs) -> ProcessingConfiguration:
     "spec_divide_factor": SPEC_DIVIDE_FACTOR,
     "spec_height": SPEC_HEIGHT,
     "scale_raw_audio": SCALE_RAW_AUDIO,
-    "device": device,
     "class_names": [],
     "time_expansion": 1,
     "top_n": 3,
@@ -996,8 +1017,4 @@ def get_default_config(**kwargs) -> ProcessingConfiguration:
     "spec_scale": SPEC_SCALE,
     "denoise_spec_avg": DENOISE_SPEC_AVG,
     "max_scale_spec": MAX_SCALE_SPEC,
 }
-
-    return {
-        **args,
-        **kwargs,
-    }
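
Because this diff removes get_default_config and the "device" entry from the configuration, callers now merge the module-level defaults themselves and pass a device explicitly. A plausible migration sketch; the file path, threshold value, and CPU device are illustrative only:

# Sketch of migrating a caller of the old helpers (paths and values are placeholders).
import torch
import bat_detect.utils.detector_utils as du

model, params = du.load_model()  # device is now an optional keyword here as well

# Old style (removed in this commit):
#   config = du.get_default_config(detection_threshold=0.5)
#   results = du.process_file("recording.wav", model, config)

# New style: merge the module-level defaults and pass the device explicitly.
config = {**du.DEFAULT_PROCESSING_CONFIGURATIONS, **params, "detection_threshold": 0.5}
results = du.process_file("recording.wav", model, config, torch.device("cpu"))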

View File

@@ -114,7 +114,7 @@ if __name__ == "__main__":
     # load audio and crop
     print("\nProcessing: " + os.path.basename(args_cmd["audio_file"]))
     print("\nOutput directory: " + args_cmd["op_dir"])
-    sampling_rate, audio = au.load_audio_file(
+    sampling_rate, audio = au.load_audio(
         args_cmd["audio_file"],
         args_cmd["time_exp"],
         params_bd["target_samp_rate"],

View File

@@ -96,7 +96,7 @@ if __name__ == "__main__":
     # load audio file
    print("\nProcessing: " + os.path.basename(audio_file))
    print("\nOutput directory: " + op_dir)
-    sampling_rate, audio = au.load_audio_file(
+    sampling_rate, audio = au.load_audio(
         audio_file, args["time_expansion_factor"], params["target_samp_rate"]
     )
     audio = audio[

View File

@@ -72,7 +72,7 @@ def load_data(
     sampling_rates = []
     file_names = []
     for cur_file in anns:
-        sampling_rate, audio_orig = au.load_audio_file(
+        sampling_rate, audio_orig = au.load_audio(
             cur_file["file_path"],
             cur_file["time_exp"],
             params["target_samp_rate"],

tests/__init__.py (new file, 0 lines)

tests/test_api.py (new file, 0 lines)

tests/test_bat_detect.py (new file, 213 lines)

@@ -0,0 +1,213 @@
"""Test bat detect module API."""
import os
from glob import glob

import numpy as np
import torch
from torch import nn

from bat_detect.api import (
    generate_spectrogram,
    get_config,
    list_audio_files,
    load_audio,
    load_model,
    process_audio,
    process_file,
    process_spectrogram,
)

PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))


def test_load_model_with_default_params():
    """Test loading model with default parameters."""
    model, params = load_model()

    assert model is not None
    assert isinstance(model, nn.Module)

    assert params is not None
    assert isinstance(params, dict)

    assert "model_name" in params
    assert "num_filters" in params
    assert "emb_dim" in params
    assert "ip_height" in params
    assert "resize_factor" in params
    assert "class_names" in params

    assert params["model_name"] == "Net2DFast"
    assert params["num_filters"] == 128
    assert params["emb_dim"] == 0
    assert params["ip_height"] == 128
    assert params["resize_factor"] == 0.5
    assert len(params["class_names"]) == 17


def test_list_audio_files():
    """Test listing audio files."""
    audio_files = list_audio_files(TEST_DATA_DIR)

    assert len(audio_files) == 3
    assert all(path.endswith((".wav", ".WAV")) for path in audio_files)


def test_load_audio():
    """Test loading audio."""
    samplerate, audio = load_audio(TEST_DATA[0])

    assert audio is not None
    assert samplerate == 256000
    assert isinstance(audio, np.ndarray)
    assert audio.shape == (128000,)


def test_generate_spectrogram():
    """Test generating spectrogram."""
    samplerate, audio = load_audio(TEST_DATA[0])
    spectrogram = generate_spectrogram(audio, samplerate)

    assert spectrogram is not None
    assert isinstance(spectrogram, torch.Tensor)
    assert spectrogram.shape == (1, 1, 128, 512)


def test_get_default_config():
    """Test getting default configuration."""
    config = get_config()

    assert config is not None
    assert isinstance(config, dict)

    assert config["target_samp_rate"] == 256000
    assert config["fft_win_length"] == 0.002
    assert config["fft_overlap"] == 0.75
    assert config["resize_factor"] == 0.5
    assert config["spec_divide_factor"] == 32
    assert config["spec_height"] == 256
    assert config["spec_scale"] == "pcen"
    assert config["denoise_spec_avg"] is True
    assert config["max_scale_spec"] is False
    assert config["scale_raw_audio"] is False
    assert len(config["class_names"]) == 0
    assert config["detection_threshold"] == 0.01
    assert config["time_expansion"] == 1
    assert config["top_n"] == 3
    assert config["return_raw_preds"] is False
    assert config["max_duration"] is None
    assert config["nms_kernel_size"] == 9
    assert config["max_freq"] == 120000
    assert config["min_freq"] == 10000
    assert config["nms_top_k_per_sec"] == 200
    assert config["quiet"] is True
    assert config["chunk_size"] == 3
    assert config["cnn_features"] is False
    assert config["spec_features"] is False
    assert config["spec_slices"] is False


def test_process_file_with_model():
    """Test processing file with model."""
    model, params = load_model()
    config = get_config(**params)
    predictions = process_file(TEST_DATA[0], model, config=config)

    assert predictions is not None
    assert isinstance(predictions, dict)

    assert "pred_dict" in predictions
    assert "spec_feats" in predictions
    assert "spec_feat_names" in predictions
    assert "cnn_feats" in predictions
    assert "cnn_feat_names" in predictions
    assert "spec_slices" in predictions

    # By default will not return spectrogram features
    assert predictions["spec_feats"] is None
    assert predictions["spec_feat_names"] is None
    assert predictions["cnn_feats"] is None
    assert predictions["cnn_feat_names"] is None
    assert predictions["spec_slices"] is None

    # Check that predictions are returned
    assert isinstance(predictions["pred_dict"], dict)
    pred_dict = predictions["pred_dict"]
    assert pred_dict["id"] == os.path.basename(TEST_DATA[0])
    assert pred_dict["annotated"] is False
    assert pred_dict["issues"] is False
    assert pred_dict["notes"] == "Automatically generated."
    assert pred_dict["time_exp"] == 1
    assert pred_dict["duration"] == 0.5
    assert pred_dict["class_name"] is not None
    assert len(pred_dict["annotation"]) > 0


def test_process_spectrogram_with_model():
    """Test processing spectrogram with model."""
    model, params = load_model()
    config = get_config(**params)
    samplerate, audio = load_audio(TEST_DATA[0])
    spectrogram = generate_spectrogram(audio, samplerate)
    predictions, features = process_spectrogram(
        spectrogram,
        samplerate,
        model,
        config=config,
    )

    assert predictions is not None
    assert isinstance(predictions, list)
    assert len(predictions) > 0
    sample_pred = predictions[0]
    assert isinstance(sample_pred, dict)
    assert "class" in sample_pred
    assert "class_prob" in sample_pred
    assert "det_prob" in sample_pred
    assert "start_time" in sample_pred
    assert "end_time" in sample_pred
    assert "low_freq" in sample_pred
    assert "high_freq" in sample_pred

    assert features is not None
    assert isinstance(features, list)
    # By default will not return cnn features
    assert len(features) == 0


def test_process_audio_with_model():
    """Test processing audio with model."""
    model, params = load_model()
    config = get_config(**params)
    samplerate, audio = load_audio(TEST_DATA[0])
    predictions, features, spec = process_audio(
        audio,
        samplerate,
        model,
        config=config,
    )

    assert predictions is not None
    assert isinstance(predictions, list)
    assert len(predictions) > 0
    sample_pred = predictions[0]
    assert isinstance(sample_pred, dict)
    assert "class" in sample_pred
    assert "class_prob" in sample_pred
    assert "det_prob" in sample_pred
    assert "start_time" in sample_pred
    assert "end_time" in sample_pred
    assert "low_freq" in sample_pred
    assert "high_freq" in sample_pred

    assert features is not None
    assert isinstance(features, list)
    # By default will not return cnn features
    assert len(features) == 0

    assert spec is not None
    assert isinstance(spec, torch.Tensor)
    assert spec.shape == (1, 1, 128, 512)

tests/test_cli.py (new file, 0 lines)