Mirror of https://github.com/macaodha/batdetect2.git
Added an API file with tests to check basic functionality
This commit is contained in:
parent 40222d8233
commit 0eecf54a94
app.py (2 lines changed)
@@ -77,7 +77,7 @@ def make_prediction(file_name=None, detection_threshold=0.3):
 def generate_results_image(audio_file, anns):
 
     # load audio
-    sampling_rate, audio = au.load_audio_file(
+    sampling_rate, audio = au.load_audio(
         audio_file,
         args["time_expansion_factor"],
         params["target_samp_rate"],
bat_detect/api.py (new file, 215 lines)
@@ -0,0 +1,215 @@
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+
+import bat_detect.detector.models as md
+import bat_detect.utils.audio_utils as au
+import bat_detect.utils.detector_utils as du
+from bat_detect.detector.parameters import TARGET_SAMPLERATE_HZ
+from bat_detect.utils.detector_utils import list_audio_files, load_model
+
+# Use GPU if available
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+__all__ = [
+    "load_model",
+    "load_audio",
+    "list_audio_files",
+    "generate_spectrogram",
+    "get_config",
+    "process_file",
+    "process_spectrogram",
+    "process_audio",
+]
+
+
+def get_config(**kwargs) -> du.ProcessingConfiguration:
+    """Get default processing configuration.
+
+    Can be used to override default parameters by passing keyword arguments.
+    """
+    return {**du.DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}
+
+
+def load_audio(
+    path: str,
+    time_exp_fact: float = 1,
+    target_samp_rate: int = TARGET_SAMPLERATE_HZ,
+    scale: bool = False,
+    max_duration: Optional[float] = None,
+) -> Tuple[int, np.ndarray]:
+    """Load audio from file.
+
+    Parameters
+    ----------
+    path : str
+        Path to audio file.
+    time_exp_fact : float, optional
+        Time expansion factor, by default 1.
+    target_samp_rate : int, optional
+        Target sample rate, by default 256000.
+    scale : bool, optional
+        Scale audio to [-1, 1], by default False.
+    max_duration : Optional[float], optional
+        Maximum duration of audio in seconds, by default None.
+
+    Returns
+    -------
+    int
+        Sample rate.
+    np.ndarray
+        Audio data.
+    """
+    return au.load_audio(
+        path,
+        time_exp_fact,
+        target_samp_rate,
+        scale,
+        max_duration,
+    )
+
+
+def generate_spectrogram(
+    audio: np.ndarray,
+    samp_rate: int,
+    config: Optional[au.SpectrogramParameters] = None,
+    device: torch.device = DEVICE,
+) -> torch.Tensor:
+    """Generate spectrogram from audio array.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+        Audio data.
+    samp_rate : int
+        Sample rate.
+    config : Optional[SpectrogramParameters], optional
+        Spectrogram parameters, by default None (uses default parameters).
+    device : torch.device, optional
+        Device to use, by default tries to use GPU if available.
+
+    Returns
+    -------
+    torch.Tensor
+        Spectrogram.
+    """
+    if config is None:
+        config = au.DEFAULT_SPECTROGRAM_PARAMETERS
+
+    _, spec, _ = du.compute_spectrogram(
+        audio,
+        samp_rate,
+        config,
+        return_np=False,
+        device=device,
+    )
+
+    return spec
+
+
+def process_file(
+    audio_file: str,
+    model: md.DetectionModel,
+    config: Optional[du.ProcessingConfiguration] = None,
+    device: torch.device = DEVICE,
+) -> du.RunResults:
+    """Process audio file with model.
+
+    Parameters
+    ----------
+    audio_file : str
+        Path to audio file.
+    model : DetectionModel
+        Detection model.
+    config : Optional[ProcessingConfiguration], optional
+        Processing configuration, by default None (uses default parameters).
+    device : torch.device, optional
+        Device to use, by default tries to use GPU if available.
+    """
+    if config is None:
+        config = du.DEFAULT_PROCESSING_CONFIGURATIONS
+
+    return du.process_file(
+        audio_file,
+        model,
+        config,
+        device,
+    )
+
+
+def process_spectrogram(
+    spec: torch.Tensor,
+    samp_rate: int,
+    model: md.DetectionModel,
+    config: Optional[du.ProcessingConfiguration] = None,
+) -> Tuple[List[du.Annotation], List[np.ndarray]]:
+    """Process spectrogram with model.
+
+    Parameters
+    ----------
+    spec : torch.Tensor
+        Spectrogram.
+    samp_rate : int
+        Sample rate of the audio from which the spectrogram was generated.
+    model : DetectionModel
+        Detection model.
+    config : Optional[ProcessingConfiguration], optional
+        Processing configuration, by default None (uses default parameters).
+
+    Returns
+    -------
+    annotations : List[Annotation]
+        List of predicted annotations.
+    features : List[np.ndarray]
+        List of extracted features for each annotation.
+    """
+    if config is None:
+        config = du.DEFAULT_PROCESSING_CONFIGURATIONS
+
+    return du.process_spectrogram(
+        spec,
+        samp_rate,
+        model,
+        config,
+    )
+
+
+def process_audio(
+    audio: np.ndarray,
+    samp_rate: int,
+    model: md.DetectionModel,
+    config: Optional[du.ProcessingConfiguration] = None,
+    device: torch.device = DEVICE,
+) -> Tuple[List[du.Annotation], List[np.ndarray], torch.Tensor]:
+    """Process audio array with model.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+        Audio data.
+    samp_rate : int
+        Sample rate.
+    model : DetectionModel
+        Detection model.
+    config : Optional[ProcessingConfiguration], optional
+        Processing configuration, by default None (uses default parameters).
+    device : torch.device, optional
+        Device to use, by default tries to use GPU if available.
+
+    Returns
+    -------
+    annotations : List[Annotation]
+        List of predicted annotations.
+    features : List[np.ndarray]
+        List of extracted features for each annotation.
+    spec : torch.Tensor
+        Spectrogram of the audio used for prediction.
+    """
+    if config is None:
+        config = du.DEFAULT_PROCESSING_CONFIGURATIONS
+
+    return du.process_audio_array(
+        audio,
+        samp_rate,
+        model,
+        config,
+        device,
+    )
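
A minimal usage sketch of the new `bat_detect.api` module (mirroring what the tests added below exercise; the file path is hypothetical):

    import bat_detect.api as api

    # Load the bundled model and derive a processing config from its parameters
    model, params = api.load_model()
    config = api.get_config(**params)

    # One-shot: run the full pipeline on a file
    results = api.process_file("example_data/audio/recording.wav", model, config=config)

    # Step by step: load audio, build the spectrogram, then run the model
    samplerate, audio = api.load_audio("example_data/audio/recording.wav")
    spec = api.generate_spectrogram(audio, samplerate)
    annotations, features = api.process_spectrogram(spec, samplerate, model, config=config)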
@@ -92,7 +92,7 @@ def main():
     model, params = du.load_model(args["model_path"])
 
     print("\nInput directory: " + args["audio_dir"])
-    files = du.get_audio_files(args["audio_dir"])
+    files = du.list_audio_files(args["audio_dir"])
 
     print(f"Number of audio files: {len(files)}")
     print("\nSaving results to: " + args["ann_dir"])
@@ -1,9 +1,11 @@
+from typing import NamedTuple, Optional
+
 import torch
 import torch.fft
 import torch.nn.functional as F
 from torch import nn
 
-from .model_helpers import (
+from bat_detect.detector.model_helpers import (
     ConvBlockDownCoordF,
     ConvBlockDownStandard,
     ConvBlockUpF,
@@ -11,13 +13,88 @@ from .model_helpers import (
     SelfAttention,
 )
 
+try:
+    from typing import Protocol
+except ImportError:
+    from typing_extensions import Protocol
+
+__all__ = [
+    "Net2DFast",
+    "Net2DFastNoAttn",
+    "Net2DFastNoCoordConv",
+    "ModelOutput",
+    "DetectionModel",
+]
+
+
+class ModelOutput(NamedTuple):
+    """Output of the detection model."""
+
+    pred_det: torch.Tensor
+    """Tensor with predicted detection probabilities."""
+
+    pred_size: torch.Tensor
+    """Tensor with predicted bounding box sizes."""
+
+    pred_class: torch.Tensor
+    """Tensor with predicted class probabilities."""
+
+    pred_class_un_norm: torch.Tensor
+    """Tensor with predicted class probabilities before softmax."""
+
+    pred_emb: Optional[torch.Tensor]
+    """Tensor with embeddings."""
+
+    features: Optional[torch.Tensor]
+    """Tensor with intermediate features."""
+
+
+class DetectionModel(Protocol):
+    """Protocol for detection models.
+
+    This protocol is used to define the interface for the detection models.
+    This allows us to use the same code for training and inference, even
+    though the models are different.
+    """
+
+    num_classes: int
+    """Number of classes the model can classify."""
+
+    emb_dim: int
+    """Dimension of the embedding vector."""
+
+    num_filts: int
+    """Number of filters in the model."""
+
+    resize_factor: float
+    """Factor by which the input is resized."""
+
+    ip_height: int
+    """Height of the input image."""
+
+    def forward(
+        self,
+        ip: torch.Tensor,
+        return_feats: bool = False,
+    ) -> ModelOutput:
+        """Forward pass of the model.
+
+        When `return_feats` is `True`, the model should return the
+        intermediate features of the model.
+        """
+
+    def __call__(
+        self,
+        ip: torch.Tensor,
+        return_feats: bool = False,
+    ) -> ModelOutput:
+        """Forward pass of the model.
+
+        When `return_feats` is `True`, the model should return the
+        intermediate features of the model.
+        """
+
+
 class Net2DFast(nn.Module):
     def __init__(
         self,
@@ -27,7 +104,7 @@ class Net2DFast(nn.Module):
         ip_height=128,
         resize_factor=0.5,
     ):
-        super(Net2DFast, self).__init__()
+        super().__init__()
         self.num_classes = num_classes
         self.emb_dim = emb_dim
         self.num_filts = num_filts
@@ -102,7 +179,7 @@ class Net2DFast(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )
 
-    def forward(self, ip, return_feats=False):
+    def forward(self, ip, return_feats=False) -> ModelOutput:
 
         # encoder
         x1 = self.conv_dn_0(ip)
@@ -125,17 +202,14 @@ class Net2DFast(nn.Module):
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)
 
-        op = {}
-        op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
-        op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
-        op["pred_class"] = comb
-        op["pred_class_un_norm"] = cls
-        if self.emb_dim > 0:
-            op["pred_emb"] = self.conv_emb(x)
-        if return_feats:
-            op["features"] = x
-
-        return op
+        return ModelOutput(
+            pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
+            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_class=comb,
+            pred_class_un_norm=cls,
+            pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
+            features=x if return_feats else None,
+        )
 
 
 class Net2DFastNoAttn(nn.Module):
@@ -147,7 +221,7 @@ class Net2DFastNoAttn(nn.Module):
         ip_height=128,
         resize_factor=0.5,
     ):
-        super(Net2DFastNoAttn, self).__init__()
+        super().__init__()
 
         self.num_classes = num_classes
         self.emb_dim = emb_dim
@@ -219,8 +293,7 @@ class Net2DFastNoAttn(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )
 
-    def forward(self, ip, return_feats=False):
-
+    def forward(self, ip, return_feats=False) -> ModelOutput:
         x1 = self.conv_dn_0(ip)
         x2 = self.conv_dn_1(x1)
         x3 = self.conv_dn_2(x2)
@@ -237,17 +310,14 @@ class Net2DFastNoAttn(nn.Module):
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)
 
-        op = {}
-        op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
-        op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
-        op["pred_class"] = comb
-        op["pred_class_un_norm"] = cls
-        if self.emb_dim > 0:
-            op["pred_emb"] = self.conv_emb(x)
-        if return_feats:
-            op["features"] = x
-
-        return op
+        return ModelOutput(
+            pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
+            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_class=comb,
+            pred_class_un_norm=cls,
+            pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
+            features=x if return_feats else None,
+        )
 
 
 class Net2DFastNoCoordConv(nn.Module):
@@ -259,7 +329,7 @@ class Net2DFastNoCoordConv(nn.Module):
         ip_height=128,
         resize_factor=0.5,
     ):
-        super(Net2DFastNoCoordConv, self).__init__()
+        super().__init__()
 
         self.num_classes = num_classes
         self.emb_dim = emb_dim
@@ -333,7 +403,7 @@ class Net2DFastNoCoordConv(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )
 
-    def forward(self, ip, return_feats=False):
+    def forward(self, ip, return_feats=False) -> ModelOutput:
 
         x1 = self.conv_dn_0(ip)
         x2 = self.conv_dn_1(x1)
@@ -352,14 +422,11 @@
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)
 
-        op = {}
-        op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
-        op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
-        op["pred_class"] = comb
-        op["pred_class_un_norm"] = cls
-        if self.emb_dim > 0:
-            op["pred_emb"] = self.conv_emb(x)
-        if return_feats:
-            op["features"] = x
-
-        return op
+        return ModelOutput(
+            pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
+            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_class=comb,
+            pred_class_un_norm=cls,
+            pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
+            features=x if return_feats else None,
+        )
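
Because `DetectionModel` is a `typing.Protocol`, conformance is structural: any `nn.Module` exposing the listed attributes and a `forward` that returns a `ModelOutput` satisfies it without inheriting from it. A minimal sketch (the class name, attribute values, and zero tensors are illustrative only):

    class DummyDetector(nn.Module):
        num_classes = 2
        emb_dim = 0
        num_filts = 32
        resize_factor = 0.5
        ip_height = 128

        def forward(self, ip: torch.Tensor, return_feats: bool = False) -> ModelOutput:
            batch, _, height, width = ip.shape
            zeros = torch.zeros(batch, 1, height, width)  # placeholder predictions
            return ModelOutput(
                pred_det=zeros,
                pred_size=torch.zeros(batch, 2, height, width),
                pred_class=zeros,
                pred_class_un_norm=zeros,
                pred_emb=None,
                features=ip if return_feats else None,
            )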
@@ -5,6 +5,8 @@ import numpy as np
 import torch
 from torch import nn
 
+from bat_detect.detector.models import ModelOutput
+
 try:
     from typing import TypedDict
 except ImportError:
@@ -106,24 +108,8 @@ class PredictionResults(TypedDict):
     """Class probabilities."""
 
 
-class ModelOutputs(TypedDict):
-    """Outputs of the model."""
-
-    pred_det: torch.Tensor
-    """Detection probabilities."""
-
-    pred_size: torch.Tensor
-    """Box sizes."""
-
-    pred_class: Optional[torch.Tensor]
-    """Class probabilities."""
-
-    features: Optional[torch.Tensor]
-    """Features extracted by the model."""
-
-
 def run_nms(
-    outputs: ModelOutputs,
+    outputs: ModelOutput,
     params: NonMaximumSuppressionConfig,
     sampling_rate: np.ndarray,
 ) -> Tuple[List[PredictionResults], List[np.ndarray]]:
@@ -135,16 +121,14 @@ def run_nms(
     the features. Each element of the lists corresponds to one
     element of the batch.
     """
-
-    pred_det = outputs["pred_det"]  # probability of box
-    pred_size = outputs["pred_size"]  # box size
+    pred_det, pred_size, pred_class, _, _, features = outputs
 
     pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
     freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
         -2
     ]
 
-    # NOTE there will be small differences depending on which sampling rate is chosen
+    # NOTE: there will be small differences depending on which sampling rate is chosen
     # as we are choosing the same sampling rate for the entire batch
     duration = x_coords_to_time(
         pred_det.shape[-1],
@@ -172,10 +156,16 @@
         pred["x_pos"] = x_pos[num_detection, valid_inds]
         pred["y_pos"] = y_pos[num_detection, valid_inds]
         pred["bb_width"] = pred_size[
-            num_detection, 0, pred["y_pos"], pred["x_pos"]
+            num_detection,
+            0,
+            pred["y_pos"],
+            pred["x_pos"],
         ]
         pred["bb_height"] = pred_size[
-            num_detection, 1, pred["y_pos"], pred["x_pos"]
+            num_detection,
+            1,
+            pred["y_pos"],
+            pred["x_pos"],
         ]
         pred["start_times"] = x_coords_to_time(
             pred["x_pos"].float() / params["resize_factor"],
@@ -198,7 +188,6 @@
         )
 
         # extract the per class votes
-        pred_class = outputs.get("pred_class")
         if pred_class is not None:
             pred["class_probs"] = pred_class[
                 num_detection,
@@ -208,7 +197,6 @@
         ]
 
         # extract the model features
-        features = outputs.get("features")
         if features is not None:
             feat = features[
                 num_detection,
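
Note that the new unpack in `run_nms` relies on the field order of the `ModelOutput` NamedTuple (`pred_det`, `pred_size`, `pred_class`, `pred_class_un_norm`, `pred_emb`, `features`). Order-independent attribute access would be equivalent, e.g.:

    pred_det = outputs.pred_det
    pred_size = outputs.pred_size
    pred_class = outputs.pred_class
    features = outputs.features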
@@ -373,7 +373,7 @@ class AudioLoader(torch.utils.data.Dataset):
             index = np.random.randint(0, len(self.data_anns))
 
         audio_file = self.data_anns[index]["file_path"]
-        sampling_rate, audio_raw = au.load_audio_file(
+        sampling_rate, audio_raw = au.load_audio(
             audio_file,
             self.data_anns[index]["time_exp"],
             self.params["target_samp_rate"],
@@ -5,13 +5,87 @@ import librosa
 import numpy as np
 import torch
 
+from bat_detect.detector.parameters import (
+    DENOISE_SPEC_AVG,
+    DETECTION_THRESHOLD,
+    FFT_OVERLAP,
+    FFT_WIN_LENGTH_S,
+    MAX_FREQ_HZ,
+    MAX_SCALE_SPEC,
+    MIN_FREQ_HZ,
+    NMS_KERNEL_SIZE,
+    NMS_TOP_K_PER_SEC,
+    RESIZE_FACTOR,
+    SCALE_RAW_AUDIO,
+    SPEC_DIVIDE_FACTOR,
+    SPEC_HEIGHT,
+    SPEC_SCALE,
+)
+
 from . import wavfile
 
 try:
     from typing import TypedDict
 except ImportError:
     from typing_extensions import TypedDict
 
+__all__ = [
+    "load_audio_file",
+    "load_audio",
+    "generate_spectrogram",
+    "pad_audio",
+    "SpectrogramParameters",
+    "DEFAULT_SPECTROGRAM_PARAMETERS",
+]
+
+
+class SpectrogramParameters(TypedDict):
+    """Parameters for generating spectrograms."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Percentage of overlap between FFT windows."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    resize_factor: float
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: int
+    """Factor to divide the spectrogram by."""
+
+    max_freq: int
+    """Maximum frequency to display in the spectrogram."""
+
+    min_freq: int
+    """Minimum frequency to display in the spectrogram."""
+
+    spec_scale: str
+    """Scale to use for the spectrogram."""
+
+    denoise_spec_avg: bool
+    """Whether to denoise the spectrogram by averaging."""
+
+    max_scale_spec: bool
+    """Whether to scale the spectrogram so that its max is 1."""
+
+
+DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
+    "fft_win_length": FFT_WIN_LENGTH_S,
+    "fft_overlap": FFT_OVERLAP,
+    "spec_height": SPEC_HEIGHT,
+    "resize_factor": RESIZE_FACTOR,
+    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
+    "max_freq": MAX_FREQ_HZ,
+    "min_freq": MIN_FREQ_HZ,
+    "spec_scale": SPEC_SCALE,
+    "denoise_spec_avg": DENOISE_SPEC_AVG,
+    "max_scale_spec": MAX_SCALE_SPEC,
+}
+
+
 def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
     nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
     noverlap = np.floor(fft_overlap * nfft)
@@ -36,7 +110,10 @@ def generate_spectrogram(
 
     # generate spectrogram
     spec = gen_mag_spectrogram(
-        audio, sampling_rate, params["fft_win_length"], params["fft_overlap"]
+        audio,
+        sampling_rate,
+        params["fft_win_length"],
+        params["fft_overlap"],
     )
 
     # crop to min/max freq
@@ -70,6 +147,7 @@ def generate_spectrogram(
         spec = np.log1p(log_scaling * spec_cropped)
     elif params["spec_scale"] == "pcen":
         spec = pcen(spec_cropped, sampling_rate)
+
     elif params["spec_scale"] == "none":
         pass
 
@@ -109,13 +187,13 @@ def generate_spectrogram(
     return spec, spec_for_viz
 
 
-def load_audio_file(
+def load_audio(
     audio_file: str,
     time_exp_fact: float,
     target_samp_rate: int,
     scale: bool = False,
     max_duration: Optional[float] = None,
-):
+) -> Tuple[int, np.ndarray]:
     """Load an audio file and resample it to the target sampling rate.
 
     The audio is also scaled to [-1, 1] and clipped to the maximum duration.
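
With the spectrogram settings now captured in a `SpectrogramParameters` TypedDict plus a module-level default, callers can override individual keys by merging dicts. A sketch (the "log" scale value is an assumption inferred from the `np.log1p` branch above):

    params: SpectrogramParameters = {
        **DEFAULT_SPECTROGRAM_PARAMETERS,
        "spec_scale": "log",
        "spec_height": 128,
    }
    spec, _ = generate_spectrogram(audio, sampling_rate, params)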
@@ -43,19 +43,19 @@ DEFAULT_MODEL_PATH = os.path.join(
 
 __all__ = [
     "load_model",
-    "get_audio_files",
-    "get_default_config",
-    "format_results",
+    "list_audio_files",
+    "format_single_result",
     "save_results_to_file",
     "iterate_over_chunks",
     "process_spectrogram",
     "process_audio_array",
     "process_file",
     "DEFAULT_MODEL_PATH",
+    "DEFAULT_PROCESSING_CONFIGURATIONS",
 ]
 
 
-def get_audio_files(ip_dir: str) -> List[str]:
+def list_audio_files(ip_dir: str) -> List[str]:
     """Get all audio files in directory.
 
     Args:
@@ -98,13 +98,12 @@ class ModelParameters(TypedDict):
     class_names: List[str]
     """Class names. The model is trained to detect these classes."""
 
-    device: torch.device
-
 
 def load_model(
     model_path: str = DEFAULT_MODEL_PATH,
     load_weights: bool = True,
-) -> Tuple[torch.nn.Module, ModelParameters]:
+    device: Optional[torch.device] = None,
+) -> Tuple[models.DetectionModel, ModelParameters]:
     """Load model from file.
 
     Args:
@@ -120,7 +119,8 @@ def load_model(
     """
 
     # load model
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     if not os.path.isfile(model_path):
         raise FileNotFoundError("Model file not found.")
@@ -128,9 +128,8 @@
     net_params = torch.load(model_path, map_location=device)
 
     params = net_params["params"]
-    params["device"] = device
 
-    model: torch.nn.Module
+    model: models.DetectionModel
 
     if params["model_name"] == "Net2DFast":
         model = models.Net2DFast(
@@ -162,7 +161,7 @@
     if load_weights:
         model.load_state_dict(net_params["state_dict"])
 
-    model = model.to(params["device"])
+    model = model.to(device)
     model.eval()
 
     return model, params
@@ -285,30 +284,11 @@ class ResultParams(TypedDict):
     """Class names."""
 
 
-def format_results(
-    file_id: str,
-    time_exp: float,
-    duration: float,
+def get_annotations_from_preds(
     predictions,
     class_names: List[str],
-) -> FileAnnotations:
-    """Format results into the format expected by the annotation tool.
-
-    Args:
-        file_id (str): File ID.
-        time_exp (float): Time expansion factor.
-        duration (float): Duration of audio file.
-        predictions (dict): Predictions.
-
-    Returns:
-        dict: Results in the format expected by the annotation tool.
-    """
-    # Get a single class prediction for the file
-    class_overall = pp.overall_class_pred(
-        predictions["det_probs"],
-        predictions["class_probs"],
-    )
-
+) -> List[Annotation]:
+    """Get list of annotations from predictions."""
     # Get the best class prediction probability and index for each detection
     class_prob_best = predictions["class_probs"].max(0)
     class_ind_best = predictions["class_probs"].argmax(0)
@@ -344,6 +324,32 @@
             predictions["det_probs"],
         )
     ]
+    return annotations
+
+
+def format_single_result(
+    file_id: str,
+    time_exp: float,
+    duration: float,
+    predictions,
+    class_names: List[str],
+) -> FileAnnotations:
+    """Format results into the format expected by the annotation tool.
+
+    Args:
+        file_id (str): File ID.
+        time_exp (float): Time expansion factor.
+        duration (float): Duration of audio file.
+        predictions (dict): Predictions.
+
+    Returns:
+        dict: Results in the format expected by the annotation tool.
+    """
+    # Get a single class prediction for the file
+    class_overall = pp.overall_class_pred(
+        predictions["det_probs"],
+        predictions["class_probs"],
+    )
+
     return {
         "id": file_id,
@@ -352,7 +358,7 @@
         "notes": "Automatically generated.",
         "time_exp": time_exp,
         "duration": round(float(duration), 4),
-        "annotation": annotations,
+        "annotation": get_annotations_from_preds(predictions, class_names),
         "class_name": class_names[np.argmax(class_overall)],
     }
 
@@ -383,7 +389,7 @@ def convert_results(
        dict: Dictionary with results.

    """
-    pred_dict = format_results(
+    pred_dict = format_single_result(
         file_id,
         time_exp,
         duration,
@@ -490,47 +496,11 @@ def save_results_to_file(results, op_path: str) -> None:
         json.dump(results["pred_dict"], jsonfile, indent=2, sort_keys=True)
 
 
-class SpectrogramParameters(TypedDict):
-    """Parameters for generating spectrograms."""
-
-    fft_win_length: float
-    """Length of the FFT window in seconds."""
-
-    fft_overlap: float
-    """Percentage of overlap between FFT windows."""
-
-    spec_height: int
-    """Height of the spectrogram in pixels."""
-
-    resize_factor: float
-    """Factor to resize the spectrogram by."""
-
-    spec_divide_factor: int
-    """Factor to divide the spectrogram by."""
-
-    device: torch.device
-    """Device to store the spectrogram on."""
-
-    max_freq: int
-    """Maximum frequency to display in the spectrogram."""
-
-    min_freq: int
-    """Minimum frequency to display in the spectrogram."""
-
-    spec_scale: str
-    """Scale to use for the spectrogram."""
-
-    denoise_spec_avg: bool
-    """Whether to denoise the spectrogram by averaging."""
-
-    max_scale_spec: bool
-    """Whether to scale the spectrogram so that its max is 1."""
-
-
 def compute_spectrogram(
     audio: np.ndarray,
     sampling_rate: int,
-    params: SpectrogramParameters,
+    params: au.SpectrogramParameters,
+    device: torch.device,
     return_np: bool = False,
 ) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
     """Compute a spectrogram from an audio array.
@@ -578,7 +548,7 @@ def compute_spectrogram(
     spec, _ = au.generate_spectrogram(audio, sampling_rate, params)
 
     # convert to pytorch
-    spec = torch.from_numpy(spec).to(params["device"])
+    spec = torch.from_numpy(spec).to(device)
 
     # add batch and channel dimensions
     spec = spec.unsqueeze(0).unsqueeze(0)
@@ -672,9 +642,6 @@ class ProcessingConfiguration(TypedDict):
     scale_raw_audio: bool
     """Whether to scale the raw audio to be between -1 and 1."""
 
-    device: torch.device
-    """Device to run the model on."""
-
     class_names: List[str]
     """Names of the classes the model can detect."""
 
@@ -721,33 +688,12 @@ class ProcessingConfiguration(TypedDict):
     """Whether to return spectrogram slices."""
 
 
-def process_spectrogram(
+def _process_spectrogram(
     spec: torch.Tensor,
     samplerate: int,
-    model: torch.nn.Module,
+    model: models.DetectionModel,
     config: ProcessingConfiguration,
-):
-    """Process a spectrogram with detection model.
-
-    Will run non-maximum suppression on the output of the model.
-
-    Parameters
-    ----------
-    spec : torch.Tensor
-
-    samplerate : int
-
-    model : torch.nn.Module
-        Detection model.
-
-    config : pp.NonMaximumSuppressionConfig
-        Parameters for non-maximum suppression.
-
-    Returns
-    -------
-    pred_nms : Dict[str, np.ndarray]
-    features : Dict[str, np.ndarray]
-    """
+) -> Tuple[List[Annotation], List[np.ndarray]]:
     # evaluate model
     with torch.no_grad():
         outputs = model(spec, return_feats=config["cnn_features"])
@@ -781,12 +727,96 @@
     return pred_nms, features
 
 
+def process_spectrogram(
+    spec: torch.Tensor,
+    samplerate: int,
+    model: models.DetectionModel,
+    config: ProcessingConfiguration,
+) -> Tuple[List[Annotation], List[np.ndarray]]:
+    """Process a spectrogram with detection model.
+
+    Will run non-maximum suppression on the output of the model.
+
+    Parameters
+    ----------
+    spec : torch.Tensor
+
+    samplerate : int
+
+    model : torch.nn.Module
+        Detection model.
+
+    config : pp.NonMaximumSuppressionConfig
+        Parameters for non-maximum suppression.
+
+    Returns
+    -------
+    annotations : List[Annotation]
+        List of annotations predicted by the model.
+    features : List[np.ndarray]
+        List of CNN features associated with each annotation.
+        Is empty if `config["cnn_features"]` is False.
+    """
+    pred_nms, features = _process_spectrogram(
+        spec,
+        samplerate,
+        model,
+        config,
+    )
+
+    annotations = get_annotations_from_preds(
+        pred_nms,
+        config["class_names"],
+    )
+
+    return annotations, features
+
+
+def _process_audio_array(
+    audio: np.ndarray,
+    sampling_rate: int,
+    model: torch.nn.Module,
+    config: ProcessingConfiguration,
+    device: torch.device,
+) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
+    # load audio file and compute spectrogram
+    _, spec, _ = compute_spectrogram(
+        audio,
+        sampling_rate,
+        {
+            "fft_win_length": config["fft_win_length"],
+            "fft_overlap": config["fft_overlap"],
+            "spec_height": config["spec_height"],
+            "resize_factor": config["resize_factor"],
+            "spec_divide_factor": config["spec_divide_factor"],
+            "max_freq": config["max_freq"],
+            "min_freq": config["min_freq"],
+            "spec_scale": config["spec_scale"],
+            "denoise_spec_avg": config["denoise_spec_avg"],
+            "max_scale_spec": config["max_scale_spec"],
+        },
+        device,
+        return_np=False,
+    )
+
+    # process spectrogram with model
+    pred_nms, features = _process_spectrogram(
+        spec,
+        sampling_rate,
+        model,
+        config,
+    )
+
+    return pred_nms, features, spec
+
+
 def process_audio_array(
     audio: np.ndarray,
     sampling_rate: int,
     model: torch.nn.Module,
     config: ProcessingConfiguration,
-):
+    device: torch.device,
+) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
     """Process a single audio array with detection model.
 
     Parameters
@@ -801,47 +831,42 @@ def process_audio_array(
     config : ProcessingConfiguration
         Configuration for processing.
 
+    device : torch.device
+        Device to use for processing.
+
     Returns
     -------
-    pred_nms : Dict[str, np.ndarray]
-    features : Dict[str, np.ndarray]
-    spec_np : np.ndarray
-    """
-    # load audio file and compute spectrogram
-    _, spec, spec_np = compute_spectrogram(
-        audio,
-        sampling_rate,
-        {
-            "fft_win_length": config["fft_win_length"],
-            "fft_overlap": config["fft_overlap"],
-            "spec_height": config["spec_height"],
-            "resize_factor": config["resize_factor"],
-            "spec_divide_factor": config["spec_divide_factor"],
-            "device": config["device"],
-            "max_freq": config["max_freq"],
-            "min_freq": config["min_freq"],
-            "spec_scale": config["spec_scale"],
-            "denoise_spec_avg": config["denoise_spec_avg"],
-            "max_scale_spec": config["max_scale_spec"],
-        },
-        return_np=config["spec_features"] or config["spec_slices"],
-    )
-
-    # process spectrogram with model
-    pred_nms, features = process_spectrogram(
-        spec,
-        sampling_rate,
-        model,
-        config,
-    )
-
-    return pred_nms, features, spec_np
+    annotations : List[Annotation]
+        List of annotations predicted by the model.
+
+    features : List[np.ndarray]
+        List of CNN features associated with each annotation.
+
+    spec : torch.Tensor
+        Spectrogram of the audio used as input.
+
+    """
+    pred_nms, features, spec = _process_audio_array(
+        audio,
+        sampling_rate,
+        model,
+        config,
+        device,
+    )
+
+    annotations = get_annotations_from_preds(
+        pred_nms,
+        config["class_names"],
+    )
+
+    return annotations, features, spec
 
 
 def process_file(
     audio_file: str,
     model: torch.nn.Module,
     config: ProcessingConfiguration,
+    device: torch.device,
 ) -> Union[RunResults, Any]:
     """Process a single audio file with detection model.
 
@@ -872,7 +897,7 @@ def process_file(
     spec_slices = []
 
     # load audio file
-    sampling_rate, audio_full = au.load_audio_file(
+    sampling_rate, audio_full = au.load_audio(
         audio_file,
         time_exp_fact=config.get("time_expansion", 1) or 1,
         target_samp_rate=config["target_samp_rate"],
@@ -881,7 +906,7 @@
     )
 
     # loop through larger file and split into chunks
-    # TODO fix so that it overlaps correctly and takes care of
+    # TODO: fix so that it overlaps correctly and takes care of
     # duplicate detections at borders
     for chunk_time, audio in iterate_over_chunks(
         audio_full,
@@ -889,11 +914,12 @@
         config["chunk_size"],
     ):
         # Run detection model on chunk
-        pred_nms, features, spec_np = process_audio_array(
+        pred_nms, features, spec_np = _process_audio_array(
            audio,
            sampling_rate,
            model,
            config,
+           device,
        )
 
         # add chunk time to start and end times
@@ -965,39 +991,30 @@ def summarize_results(results, predictions, config):
     )
 
 
-def get_default_config(**kwargs) -> ProcessingConfiguration:
-    """Get default configuration for running detection model."""
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    args: ProcessingConfiguration = {
-        "detection_threshold": DETECTION_THRESHOLD,
-        "spec_slices": False,
-        "chunk_size": 3,
-        "spec_features": False,
-        "cnn_features": False,
-        "quiet": True,
-        "target_samp_rate": TARGET_SAMPLERATE_HZ,
-        "fft_win_length": FFT_WIN_LENGTH_S,
-        "fft_overlap": FFT_OVERLAP,
-        "resize_factor": RESIZE_FACTOR,
-        "spec_divide_factor": SPEC_DIVIDE_FACTOR,
-        "spec_height": SPEC_HEIGHT,
-        "scale_raw_audio": SCALE_RAW_AUDIO,
-        "device": device,
-        "class_names": [],
-        "time_expansion": 1,
-        "top_n": 3,
-        "return_raw_preds": False,
-        "max_duration": None,
-        "nms_kernel_size": NMS_KERNEL_SIZE,
-        "max_freq": MAX_FREQ_HZ,
-        "min_freq": MIN_FREQ_HZ,
-        "nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
-        "spec_scale": SPEC_SCALE,
-        "denoise_spec_avg": DENOISE_SPEC_AVG,
-        "max_scale_spec": MAX_SCALE_SPEC,
-    }
-    return {
-        **args,
-        **kwargs,
-    }
+DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
+    "detection_threshold": DETECTION_THRESHOLD,
+    "spec_slices": False,
+    "chunk_size": 3,
+    "spec_features": False,
+    "cnn_features": False,
+    "quiet": True,
+    "target_samp_rate": TARGET_SAMPLERATE_HZ,
+    "fft_win_length": FFT_WIN_LENGTH_S,
+    "fft_overlap": FFT_OVERLAP,
+    "resize_factor": RESIZE_FACTOR,
+    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
+    "spec_height": SPEC_HEIGHT,
+    "scale_raw_audio": SCALE_RAW_AUDIO,
+    "class_names": [],
+    "time_expansion": 1,
+    "top_n": 3,
+    "return_raw_preds": False,
+    "max_duration": None,
+    "nms_kernel_size": NMS_KERNEL_SIZE,
+    "max_freq": MAX_FREQ_HZ,
+    "min_freq": MIN_FREQ_HZ,
+    "nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
+    "spec_scale": SPEC_SCALE,
+    "denoise_spec_avg": DENOISE_SPEC_AVG,
+    "max_scale_spec": MAX_SCALE_SPEC,
+}
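
With `get_default_config()` replaced by the module-level `DEFAULT_PROCESSING_CONFIGURATIONS` constant, per-call overrides now go through `api.get_config`, which merges keyword arguments over the defaults; for instance (the threshold value here is arbitrary):

    from bat_detect.api import get_config

    config = get_config(detection_threshold=0.5, cnn_features=True)
    assert config["chunk_size"] == 3  # untouched keys keep their defaults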
@@ -114,7 +114,7 @@ if __name__ == "__main__":
     # load audio and crop
     print("\nProcessing: " + os.path.basename(args_cmd["audio_file"]))
     print("\nOutput directory: " + args_cmd["op_dir"])
-    sampling_rate, audio = au.load_audio_file(
+    sampling_rate, audio = au.load_audio(
         args_cmd["audio_file"],
         args_cmd["time_exp"],
         params_bd["target_samp_rate"],
@@ -96,7 +96,7 @@ if __name__ == "__main__":
     # load audio file
     print("\nProcessing: " + os.path.basename(audio_file))
     print("\nOutput directory: " + op_dir)
-    sampling_rate, audio = au.load_audio_file(
+    sampling_rate, audio = au.load_audio(
         audio_file, args["time_expansion_factor"], params["target_samp_rate"]
     )
     audio = audio[
@@ -72,7 +72,7 @@ def load_data(
     sampling_rates = []
     file_names = []
     for cur_file in anns:
-        sampling_rate, audio_orig = au.load_audio_file(
+        sampling_rate, audio_orig = au.load_audio(
             cur_file["file_path"],
             cur_file["time_exp"],
             params["target_samp_rate"],
tests/__init__.py (new file, 0 lines)
tests/test_api.py (new file, 0 lines)
tests/test_bat_detect.py (new file, 213 lines)
@@ -0,0 +1,213 @@
+"""Test bat detect module API."""
+
+import os
+from glob import glob
+
+import numpy as np
+import torch
+from torch import nn
+
+from bat_detect.api import (
+    generate_spectrogram,
+    get_config,
+    list_audio_files,
+    load_audio,
+    load_model,
+    process_audio,
+    process_file,
+    process_spectrogram,
+)
+
+PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
+TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
+
+
+def test_load_model_with_default_params():
+    """Test loading model with default parameters."""
+    model, params = load_model()
+
+    assert model is not None
+    assert isinstance(model, nn.Module)
+
+    assert params is not None
+    assert isinstance(params, dict)
+
+    assert "model_name" in params
+    assert "num_filters" in params
+    assert "emb_dim" in params
+    assert "ip_height" in params
+    assert "resize_factor" in params
+    assert "class_names" in params
+
+    assert params["model_name"] == "Net2DFast"
+    assert params["num_filters"] == 128
+    assert params["emb_dim"] == 0
+    assert params["ip_height"] == 128
+    assert params["resize_factor"] == 0.5
+    assert len(params["class_names"]) == 17
+
+
+def test_list_audio_files():
+    """Test listing audio files."""
+    audio_files = list_audio_files(TEST_DATA_DIR)
+
+    assert len(audio_files) == 3
+    assert all(path.endswith((".wav", ".WAV")) for path in audio_files)
+
+
+def test_load_audio():
+    """Test loading audio."""
+    samplerate, audio = load_audio(TEST_DATA[0])
+
+    assert audio is not None
+    assert samplerate == 256000
+    assert isinstance(audio, np.ndarray)
+    assert audio.shape == (128000,)
+
+
+def test_generate_spectrogram():
+    """Test generating spectrogram."""
+    samplerate, audio = load_audio(TEST_DATA[0])
+    spectrogram = generate_spectrogram(audio, samplerate)
+
+    assert spectrogram is not None
+    assert isinstance(spectrogram, torch.Tensor)
+    assert spectrogram.shape == (1, 1, 128, 512)
+
+
+def test_get_default_config():
+    """Test getting default configuration."""
+    config = get_config()
+
+    assert config is not None
+    assert isinstance(config, dict)
+
+    assert config["target_samp_rate"] == 256000
+    assert config["fft_win_length"] == 0.002
+    assert config["fft_overlap"] == 0.75
+    assert config["resize_factor"] == 0.5
+    assert config["spec_divide_factor"] == 32
+    assert config["spec_height"] == 256
+    assert config["spec_scale"] == "pcen"
+    assert config["denoise_spec_avg"] is True
+    assert config["max_scale_spec"] is False
+    assert config["scale_raw_audio"] is False
+    assert len(config["class_names"]) == 0
+    assert config["detection_threshold"] == 0.01
+    assert config["time_expansion"] == 1
+    assert config["top_n"] == 3
+    assert config["return_raw_preds"] is False
+    assert config["max_duration"] is None
+    assert config["nms_kernel_size"] == 9
+    assert config["max_freq"] == 120000
+    assert config["min_freq"] == 10000
+    assert config["nms_top_k_per_sec"] == 200
+    assert config["quiet"] is True
+    assert config["chunk_size"] == 3
+    assert config["cnn_features"] is False
+    assert config["spec_features"] is False
+    assert config["spec_slices"] is False
+
+
+def test_process_file_with_model():
+    """Test processing file with model."""
+    model, params = load_model()
+    config = get_config(**params)
+    predictions = process_file(TEST_DATA[0], model, config=config)
+
+    assert predictions is not None
+    assert isinstance(predictions, dict)
+
+    assert "pred_dict" in predictions
+    assert "spec_feats" in predictions
+    assert "spec_feat_names" in predictions
+    assert "cnn_feats" in predictions
+    assert "cnn_feat_names" in predictions
+    assert "spec_slices" in predictions
+
+    # By default will not return spectrogram features
+    assert predictions["spec_feats"] is None
+    assert predictions["spec_feat_names"] is None
+    assert predictions["cnn_feats"] is None
+    assert predictions["cnn_feat_names"] is None
+    assert predictions["spec_slices"] is None
+
+    # Check that predictions are returned
+    assert isinstance(predictions["pred_dict"], dict)
+    pred_dict = predictions["pred_dict"]
+    assert pred_dict["id"] == os.path.basename(TEST_DATA[0])
+    assert pred_dict["annotated"] is False
+    assert pred_dict["issues"] is False
+    assert pred_dict["notes"] == "Automatically generated."
+    assert pred_dict["time_exp"] == 1
+    assert pred_dict["duration"] == 0.5
+    assert pred_dict["class_name"] is not None
+    assert len(pred_dict["annotation"]) > 0
+
+
+def test_process_spectrogram_with_model():
+    """Test processing spectrogram with model."""
+    model, params = load_model()
+    config = get_config(**params)
+    samplerate, audio = load_audio(TEST_DATA[0])
+    spectrogram = generate_spectrogram(audio, samplerate)
+    predictions, features = process_spectrogram(
+        spectrogram,
+        samplerate,
+        model,
+        config=config,
+    )
+
+    assert predictions is not None
+    assert isinstance(predictions, list)
+    assert len(predictions) > 0
+    sample_pred = predictions[0]
+    assert isinstance(sample_pred, dict)
+    assert "class" in sample_pred
+    assert "class_prob" in sample_pred
+    assert "det_prob" in sample_pred
+    assert "start_time" in sample_pred
+    assert "end_time" in sample_pred
+    assert "low_freq" in sample_pred
+    assert "high_freq" in sample_pred
+
+    assert features is not None
+    assert isinstance(features, list)
+    # By default will not return cnn features
+    assert len(features) == 0
+
+
+def test_process_audio_with_model():
+    """Test processing audio with model."""
+    model, params = load_model()
+    config = get_config(**params)
+    samplerate, audio = load_audio(TEST_DATA[0])
+    predictions, features, spec = process_audio(
+        audio,
+        samplerate,
+        model,
+        config=config,
+    )
+
+    assert predictions is not None
+    assert isinstance(predictions, list)
+    assert len(predictions) > 0
+    sample_pred = predictions[0]
+    assert isinstance(sample_pred, dict)
+    assert "class" in sample_pred
+    assert "class_prob" in sample_pred
+    assert "det_prob" in sample_pred
+    assert "start_time" in sample_pred
+    assert "end_time" in sample_pred
+    assert "low_freq" in sample_pred
+    assert "high_freq" in sample_pred
+
+    assert features is not None
+    assert isinstance(features, list)
+    # By default will not return cnn features
+    assert len(features) == 0
+
+    assert spec is not None
+    assert isinstance(spec, torch.Tensor)
+    assert spec.shape == (1, 1, 128, 512)
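
The new tests use plain `assert` statements with module-level test-data discovery, so they can presumably be run with a standard runner such as pytest against the recordings in `example_data/audio`.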
tests/test_cli.py (new file, 0 lines)