mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00

Added an API file with tests to check basic functionality

This commit is contained in:
parent 40222d8233
commit 0eecf54a94
app.py (2 changed lines)
@@ -77,7 +77,7 @@ def make_prediction(file_name=None, detection_threshold=0.3):
 def generate_results_image(audio_file, anns):

     # load audio
-    sampling_rate, audio = au.load_audio_file(
+    sampling_rate, audio = au.load_audio(
         audio_file,
         args["time_expansion_factor"],
         params["target_samp_rate"],
bat_detect/api.py (new file, 215 lines)
@@ -0,0 +1,215 @@
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+
+import bat_detect.detector.models as md
+import bat_detect.utils.audio_utils as au
+import bat_detect.utils.detector_utils as du
+from bat_detect.detector.parameters import TARGET_SAMPLERATE_HZ
+from bat_detect.utils.detector_utils import list_audio_files, load_model
+
+# Use GPU if available
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+__all__ = [
+    "load_model",
+    "load_audio",
+    "list_audio_files",
+    "generate_spectrogram",
+    "get_config",
+    "process_file",
+    "process_spectrogram",
+    "process_audio",
+]
+
+
+def get_config(**kwargs) -> du.ProcessingConfiguration:
+    """Get default processing configuration.
+
+    Can be used to override default parameters by passing keyword arguments.
+    """
+    return {**du.DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}
+
+
+def load_audio(
+    path: str,
+    time_exp_fact: float = 1,
+    target_samp_rate: int = TARGET_SAMPLERATE_HZ,
+    scale: bool = False,
+    max_duration: Optional[float] = None,
+) -> Tuple[int, np.ndarray]:
+    """Load audio from file.
+
+    Parameters
+    ----------
+    path : str
+        Path to audio file.
+    time_exp_fact : float, optional
+        Time expansion factor, by default 1
+    target_samp_rate : int, optional
+        Target sample rate, by default 256000
+    scale : bool, optional
+        Scale audio to [-1, 1], by default False
+    max_duration : Optional[float], optional
+        Maximum duration of audio in seconds, by default None
+
+    Returns
+    -------
+    int
+        Sample rate.
+    np.ndarray
+        Audio data.
+    """
+    return au.load_audio(
+        path,
+        time_exp_fact,
+        target_samp_rate,
+        scale,
+        max_duration,
+    )
+
+
+def generate_spectrogram(
+    audio: np.ndarray,
+    samp_rate: int,
+    config: Optional[au.SpectrogramParameters] = None,
+    device: torch.device = DEVICE,
+) -> torch.Tensor:
+    """Generate spectrogram from audio array.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+        Audio data.
+    samp_rate : int
+        Sample rate.
+    config : Optional[SpectrogramParameters], optional
+        Spectrogram parameters, by default None (uses default parameters).
+
+    Returns
+    -------
+    torch.Tensor
+        Spectrogram.
+    """
+    if config is None:
+        config = au.DEFAULT_SPECTROGRAM_PARAMETERS
+
+    _, spec, _ = du.compute_spectrogram(
+        audio,
+        samp_rate,
+        config,
+        return_np=False,
+        device=device,
+    )
+
+    return spec
+
+
+def process_file(
+    audio_file: str,
+    model: md.DetectionModel,
+    config: Optional[du.ProcessingConfiguration] = None,
+    device: torch.device = DEVICE,
+) -> du.RunResults:
+    """Process audio file with model.
+
+    Parameters
+    ----------
+    audio_file : str
+        Path to audio file.
+    model : DetectionModel
+        Detection model.
+    config : Optional[ProcessingConfiguration], optional
+        Processing configuration, by default None (uses default parameters).
+    device : torch.device, optional
+        Device to use, by default tries to use GPU if available.
+    """
+    if config is None:
+        config = du.DEFAULT_PROCESSING_CONFIGURATIONS
+
+    return du.process_file(
+        audio_file,
+        model,
+        config,
+        device,
+    )
+
+
+def process_spectrogram(
+    spec: torch.Tensor,
+    samp_rate: int,
+    model: md.DetectionModel,
+    config: Optional[du.ProcessingConfiguration] = None,
+) -> Tuple[List[du.Annotation], List[np.ndarray]]:
+    """Process spectrogram with model.
+
+    Parameters
+    ----------
+    spec : torch.Tensor
+        Spectrogram.
+    samp_rate : int
+        Sample rate of the audio from which the spectrogram was generated.
+    model : DetectionModel
+        Detection model.
+    config : Optional[ProcessingConfiguration], optional
+        Processing configuration, by default None (uses default parameters).
+
+    Returns
+    -------
+    annotations : List[Annotation]
+        List of predicted annotations.
+    features : List[np.ndarray]
+        List of extracted features for each annotation.
+    """
+    if config is None:
+        config = du.DEFAULT_PROCESSING_CONFIGURATIONS
+
+    return du.process_spectrogram(
+        spec,
+        samp_rate,
+        model,
+        config,
+    )
+
+
+def process_audio(
+    audio: np.ndarray,
+    samp_rate: int,
+    model: md.DetectionModel,
+    config: Optional[du.ProcessingConfiguration] = None,
+    device: torch.device = DEVICE,
+) -> Tuple[List[du.Annotation], List[np.ndarray], torch.Tensor]:
+    """Process audio array with model.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+        Audio data.
+    samp_rate : int
+        Sample rate.
+    model : DetectionModel
+        Detection model.
+    config : Optional[ProcessingConfiguration], optional
+        Processing configuration, by default None (uses default parameters).
+    device : torch.device, optional
+        Device to use, by default tries to use GPU if available.
+
+    Returns
+    -------
+    annotations : List[Annotation]
+        List of predicted annotations.
+    features : List[np.ndarray]
+        List of extracted features for each annotation.
+    spec : torch.Tensor
+        Spectrogram of the audio used for prediction.
+    """
+    if config is None:
+        config = du.DEFAULT_PROCESSING_CONFIGURATIONS
+
+    return du.process_audio_array(
+        audio,
+        samp_rate,
+        model,
+        config,
+        device,
+    )
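
For orientation, a minimal sketch of how the new api.py functions fit together (the calls mirror the tests added below; the audio path is a placeholder):

    from bat_detect import api

    # load the bundled default model and its parameters
    model, params = api.load_model()

    # merge the model parameters into a processing configuration
    config = api.get_config(**params)

    # run detection on a single recording (placeholder path)
    results = api.process_file("example_data/audio/recording.wav", model, config=config)
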
@@ -92,7 +92,7 @@ def main():
     model, params = du.load_model(args["model_path"])

     print("\nInput directory: " + args["audio_dir"])
-    files = du.get_audio_files(args["audio_dir"])
+    files = du.list_audio_files(args["audio_dir"])

     print(f"Number of audio files: {len(files)}")
     print("\nSaving results to: " + args["ann_dir"])
@@ -1,9 +1,11 @@
+from typing import NamedTuple, Optional
+
 import torch
 import torch.fft
 import torch.nn.functional as F
 from torch import nn

-from .model_helpers import (
+from bat_detect.detector.model_helpers import (
     ConvBlockDownCoordF,
     ConvBlockDownStandard,
     ConvBlockUpF,
@@ -11,13 +13,88 @@ from .model_helpers import (
     SelfAttention,
 )

+try:
+    from typing import Protocol
+except ImportError:
+    from typing_extensions import Protocol
+
 __all__ = [
     "Net2DFast",
     "Net2DFastNoAttn",
     "Net2DFastNoCoordConv",
+    "ModelOutput",
+    "DetectionModel",
 ]


+class ModelOutput(NamedTuple):
+    """Output of the detection model."""
+
+    pred_det: torch.Tensor
+    """Tensor with predicted detection probabilities."""
+
+    pred_size: torch.Tensor
+    """Tensor with predicted bounding box sizes."""
+
+    pred_class: torch.Tensor
+    """Tensor with predicted class probabilities."""
+
+    pred_class_un_norm: torch.Tensor
+    """Tensor with predicted class probabilities before softmax."""
+
+    pred_emb: Optional[torch.Tensor]
+    """Tensor with embeddings."""
+
+    features: Optional[torch.Tensor]
+    """Tensor with intermediate features."""
+
+
+class DetectionModel(Protocol):
+    """Protocol for detection models.
+
+    This protocol is used to define the interface for the detection models.
+    This allows us to use the same code for training and inference, even
+    though the models are different.
+    """
+
+    num_classes: int
+    """Number of classes the model can classify."""
+
+    emb_dim: int
+    """Dimension of the embedding vector."""
+
+    num_filts: int
+    """Number of filters in the model."""
+
+    resize_factor: float
+    """Factor by which the input is resized."""
+
+    ip_height: int
+    """Height of the input image."""
+
+    def forward(
+        self,
+        ip: torch.Tensor,
+        return_feats: bool = False,
+    ) -> ModelOutput:
+        """Forward pass of the model.
+
+        When `return_feats` is `True`, the model should return the
+        intermediate features of the model.
+        """
+
+    def __call__(
+        self,
+        ip: torch.Tensor,
+        return_feats: bool = False,
+    ) -> ModelOutput:
+        """Forward pass of the model.
+
+        When `return_feats` is `True`, the model should return the
+        intermediate features of the model.
+        """
+
+
 class Net2DFast(nn.Module):
     def __init__(
         self,
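
Since ModelOutput is a NamedTuple and DetectionModel is a structural Protocol, downstream code can be typed against the interface rather than a concrete class, and outputs can be unpacked positionally (as run_nms does below). A small illustrative sketch; the helper name and threshold are hypothetical:

    def count_detections(model: DetectionModel, spec: torch.Tensor, threshold: float = 0.5) -> int:
        # any of the Net2DFast* variants satisfies the DetectionModel protocol
        output: ModelOutput = model(spec)
        return int((output.pred_det > threshold).sum())
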
@@ -27,7 +104,7 @@ class Net2DFast(nn.Module):
         ip_height=128,
         resize_factor=0.5,
     ):
-        super(Net2DFast, self).__init__()
+        super().__init__()
         self.num_classes = num_classes
         self.emb_dim = emb_dim
         self.num_filts = num_filts
@@ -102,7 +179,7 @@ class Net2DFast(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )

-    def forward(self, ip, return_feats=False):
+    def forward(self, ip, return_feats=False) -> ModelOutput:

         # encoder
         x1 = self.conv_dn_0(ip)
@@ -125,17 +202,14 @@ class Net2DFast(nn.Module):
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)

-        op = {}
-        op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
-        op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
-        op["pred_class"] = comb
-        op["pred_class_un_norm"] = cls
-        if self.emb_dim > 0:
-            op["pred_emb"] = self.conv_emb(x)
-        if return_feats:
-            op["features"] = x
-
-        return op
+        return ModelOutput(
+            pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
+            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_class=comb,
+            pred_class_un_norm=cls,
+            pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
+            features=x if return_feats else None,
+        )


 class Net2DFastNoAttn(nn.Module):
@@ -147,7 +221,7 @@ class Net2DFastNoAttn(nn.Module):
         ip_height=128,
         resize_factor=0.5,
     ):
-        super(Net2DFastNoAttn, self).__init__()
+        super().__init__()

         self.num_classes = num_classes
         self.emb_dim = emb_dim
@@ -219,8 +293,7 @@ class Net2DFastNoAttn(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )

-    def forward(self, ip, return_feats=False):
-
+    def forward(self, ip, return_feats=False) -> ModelOutput:
         x1 = self.conv_dn_0(ip)
         x2 = self.conv_dn_1(x1)
         x3 = self.conv_dn_2(x2)
@@ -237,17 +310,14 @@ class Net2DFastNoAttn(nn.Module):
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)

-        op = {}
-        op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
-        op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
-        op["pred_class"] = comb
-        op["pred_class_un_norm"] = cls
-        if self.emb_dim > 0:
-            op["pred_emb"] = self.conv_emb(x)
-        if return_feats:
-            op["features"] = x
-
-        return op
+        return ModelOutput(
+            pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
+            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_class=comb,
+            pred_class_un_norm=cls,
+            pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
+            features=x if return_feats else None,
+        )


 class Net2DFastNoCoordConv(nn.Module):
@@ -259,7 +329,7 @@ class Net2DFastNoCoordConv(nn.Module):
         ip_height=128,
         resize_factor=0.5,
     ):
-        super(Net2DFastNoCoordConv, self).__init__()
+        super().__init__()

         self.num_classes = num_classes
         self.emb_dim = emb_dim
@@ -333,7 +403,7 @@ class Net2DFastNoCoordConv(nn.Module):
             num_filts, self.emb_dim, kernel_size=1, padding=0
         )

-    def forward(self, ip, return_feats=False):
+    def forward(self, ip, return_feats=False) -> ModelOutput:

         x1 = self.conv_dn_0(ip)
         x2 = self.conv_dn_1(x1)
@@ -352,14 +422,11 @@ class Net2DFastNoCoordConv(nn.Module):
         cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)

-        op = {}
-        op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
-        op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
-        op["pred_class"] = comb
-        op["pred_class_un_norm"] = cls
-        if self.emb_dim > 0:
-            op["pred_emb"] = self.conv_emb(x)
-        if return_feats:
-            op["features"] = x
-
-        return op
+        return ModelOutput(
+            pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
+            pred_size=F.relu(self.conv_size_op(x), inplace=True),
+            pred_class=comb,
+            pred_class_un_norm=cls,
+            pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None,
+            features=x if return_feats else None,
+        )
@@ -5,6 +5,8 @@ import numpy as np
 import torch
 from torch import nn

+from bat_detect.detector.models import ModelOutput
+
 try:
     from typing import TypedDict
 except ImportError:
@@ -106,24 +108,8 @@ class PredictionResults(TypedDict):
     """Class probabilities."""


-class ModelOutputs(TypedDict):
-    """Outputs of the model."""
-
-    pred_det: torch.Tensor
-    """Detection probabilities."""
-
-    pred_size: torch.Tensor
-    """Box sizes."""
-
-    pred_class: Optional[torch.Tensor]
-    """Class probabilities."""
-
-    features: Optional[torch.Tensor]
-    """Features extracted by the model."""
-
-
 def run_nms(
-    outputs: ModelOutputs,
+    outputs: ModelOutput,
     params: NonMaximumSuppressionConfig,
     sampling_rate: np.ndarray,
 ) -> Tuple[List[PredictionResults], List[np.ndarray]]:
@@ -135,16 +121,14 @@ def run_nms(
     the features. Each element of the lists corresponds to one
     element of the batch.
     """
-    pred_det = outputs["pred_det"]  # probability of box
-    pred_size = outputs["pred_size"]  # box size
-
+    pred_det, pred_size, pred_class, _, _, features = outputs

     pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
     freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
         -2
     ]

-    # NOTE there will be small differences depending on which sampling rate is chosen
+    # NOTE: there will be small differences depending on which sampling rate is chosen
     # as we are choosing the same sampling rate for the entire batch
     duration = x_coords_to_time(
         pred_det.shape[-1],
@@ -172,10 +156,16 @@ def run_nms(
         pred["x_pos"] = x_pos[num_detection, valid_inds]
         pred["y_pos"] = y_pos[num_detection, valid_inds]
         pred["bb_width"] = pred_size[
-            num_detection, 0, pred["y_pos"], pred["x_pos"]
+            num_detection,
+            0,
+            pred["y_pos"],
+            pred["x_pos"],
         ]
         pred["bb_height"] = pred_size[
-            num_detection, 1, pred["y_pos"], pred["x_pos"]
+            num_detection,
+            1,
+            pred["y_pos"],
+            pred["x_pos"],
         ]
         pred["start_times"] = x_coords_to_time(
             pred["x_pos"].float() / params["resize_factor"],
@@ -198,7 +188,6 @@ def run_nms(
         )

         # extract the per class votes
-        pred_class = outputs.get("pred_class")
         if pred_class is not None:
             pred["class_probs"] = pred_class[
                 num_detection,
@@ -208,7 +197,6 @@ def run_nms(
             ]

         # extract the model features
-        features = outputs.get("features")
         if features is not None:
             feat = features[
                 num_detection,
@@ -373,7 +373,7 @@ class AudioLoader(torch.utils.data.Dataset):
         index = np.random.randint(0, len(self.data_anns))

         audio_file = self.data_anns[index]["file_path"]
-        sampling_rate, audio_raw = au.load_audio_file(
+        sampling_rate, audio_raw = au.load_audio(
             audio_file,
             self.data_anns[index]["time_exp"],
             self.params["target_samp_rate"],
@@ -5,13 +5,87 @@ import librosa
 import numpy as np
 import torch

+from bat_detect.detector.parameters import (
+    DENOISE_SPEC_AVG,
+    DETECTION_THRESHOLD,
+    FFT_OVERLAP,
+    FFT_WIN_LENGTH_S,
+    MAX_FREQ_HZ,
+    MAX_SCALE_SPEC,
+    MIN_FREQ_HZ,
+    NMS_KERNEL_SIZE,
+    NMS_TOP_K_PER_SEC,
+    RESIZE_FACTOR,
+    SCALE_RAW_AUDIO,
+    SPEC_DIVIDE_FACTOR,
+    SPEC_HEIGHT,
+    SPEC_SCALE,
+)
+
 from . import wavfile

+try:
+    from typing import TypedDict
+except ImportError:
+    from typing_extensions import TypedDict
+
 __all__ = [
-    "load_audio_file",
+    "load_audio",
+    "generate_spectrogram",
+    "pad_audio",
+    "SpectrogramParameters",
+    "DEFAULT_SPECTROGRAM_PARAMETERS",
 ]
+
+
+class SpectrogramParameters(TypedDict):
+    """Parameters for generating spectrograms."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Percentage of overlap between FFT windows."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    resize_factor: float
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: int
+    """Factor to divide the spectrogram by."""
+
+    max_freq: int
+    """Maximum frequency to display in the spectrogram."""
+
+    min_freq: int
+    """Minimum frequency to display in the spectrogram."""
+
+    spec_scale: str
+    """Scale to use for the spectrogram."""
+
+    denoise_spec_avg: bool
+    """Whether to denoise the spectrogram by averaging."""
+
+    max_scale_spec: bool
+    """Whether to scale the spectrogram so that its max is 1."""
+
+
+DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
+    "fft_win_length": FFT_WIN_LENGTH_S,
+    "fft_overlap": FFT_OVERLAP,
+    "spec_height": SPEC_HEIGHT,
+    "resize_factor": RESIZE_FACTOR,
+    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
+    "max_freq": MAX_FREQ_HZ,
+    "min_freq": MIN_FREQ_HZ,
+    "spec_scale": SPEC_SCALE,
+    "denoise_spec_avg": DENOISE_SPEC_AVG,
+    "max_scale_spec": MAX_SCALE_SPEC,
+}
+
+
 def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
     nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
     noverlap = np.floor(fft_overlap * nfft)
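
The TypedDict plus module-level default dict added above makes per-call overrides a simple merge. A sketch of overriding one field while keeping the rest (key names taken from SpectrogramParameters; the "log" value is an assumption based on the scaling branches in generate_spectrogram):

    params: SpectrogramParameters = {
        **DEFAULT_SPECTROGRAM_PARAMETERS,
        "spec_scale": "log",  # e.g. swap the default scaling for log scaling
    }
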
@@ -36,7 +110,10 @@ def generate_spectrogram(

     # generate spectrogram
     spec = gen_mag_spectrogram(
-        audio, sampling_rate, params["fft_win_length"], params["fft_overlap"]
+        audio,
+        sampling_rate,
+        params["fft_win_length"],
+        params["fft_overlap"],
     )

     # crop to min/max freq
@@ -70,6 +147,7 @@ def generate_spectrogram(
         spec = np.log1p(log_scaling * spec_cropped)
     elif params["spec_scale"] == "pcen":
         spec = pcen(spec_cropped, sampling_rate)
+
     elif params["spec_scale"] == "none":
         pass

@@ -109,13 +187,13 @@ def generate_spectrogram(
     return spec, spec_for_viz


-def load_audio_file(
+def load_audio(
     audio_file: str,
     time_exp_fact: float,
     target_samp_rate: int,
     scale: bool = False,
     max_duration: Optional[float] = None,
-):
+) -> Tuple[int, np.ndarray]:
     """Load an audio file and resample it to the target sampling rate.

     The audio is also scaled to [-1, 1] and clipped to the maximum duration.
@@ -43,19 +43,19 @@ DEFAULT_MODEL_PATH = os.path.join(

 __all__ = [
     "load_model",
-    "get_audio_files",
-    "get_default_config",
-    "format_results",
+    "list_audio_files",
+    "format_single_result",
     "save_results_to_file",
     "iterate_over_chunks",
     "process_spectrogram",
     "process_audio_array",
     "process_file",
     "DEFAULT_MODEL_PATH",
+    "DEFAULT_PROCESSING_CONFIGURATIONS",
 ]


-def get_audio_files(ip_dir: str) -> List[str]:
+def list_audio_files(ip_dir: str) -> List[str]:
     """Get all audio files in directory.

     Args:
@@ -98,13 +98,12 @@ class ModelParameters(TypedDict):
     class_names: List[str]
     """Class names. The model is trained to detect these classes."""

-    device: torch.device
-

 def load_model(
     model_path: str = DEFAULT_MODEL_PATH,
     load_weights: bool = True,
-) -> Tuple[torch.nn.Module, ModelParameters]:
+    device: Optional[torch.device] = None,
+) -> Tuple[models.DetectionModel, ModelParameters]:
     """Load model from file.

     Args:
@@ -120,6 +119,7 @@ def load_model(
     """

     # load model
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     if not os.path.isfile(model_path):
@@ -128,9 +128,8 @@ def load_model(
     net_params = torch.load(model_path, map_location=device)

     params = net_params["params"]
-    params["device"] = device

-    model: torch.nn.Module
+    model: models.DetectionModel

     if params["model_name"] == "Net2DFast":
         model = models.Net2DFast(
@@ -162,7 +161,7 @@ def load_model(
     if load_weights:
         model.load_state_dict(net_params["state_dict"])

-    model = model.to(params["device"])
+    model = model.to(device)
     model.eval()

     return model, params
@@ -285,30 +284,11 @@ class ResultParams(TypedDict):
     """Class names."""


-def format_results(
-    file_id: str,
-    time_exp: float,
-    duration: float,
+def get_annotations_from_preds(
     predictions,
     class_names: List[str],
-) -> FileAnnotations:
-    """Format results into the format expected by the annotation tool.
-
-    Args:
-        file_id (str): File ID.
-        time_exp (float): Time expansion factor.
-        duration (float): Duration of audio file.
-        predictions (dict): Predictions.
-
-    Returns:
-        dict: Results in the format expected by the annotation tool.
-    """
-    # Get a single class prediction for the file
-    class_overall = pp.overall_class_pred(
-        predictions["det_probs"],
-        predictions["class_probs"],
-    )
-
+) -> List[Annotation]:
+    """Get list of annotations from predictions."""
     # Get the best class prediction probability and index for each detection
     class_prob_best = predictions["class_probs"].max(0)
     class_ind_best = predictions["class_probs"].argmax(0)
@@ -344,6 +324,32 @@ def format_results(
             predictions["det_probs"],
         )
     ]
+    return annotations
+
+
+def format_single_result(
+    file_id: str,
+    time_exp: float,
+    duration: float,
+    predictions,
+    class_names: List[str],
+) -> FileAnnotations:
+    """Format results into the format expected by the annotation tool.
+
+    Args:
+        file_id (str): File ID.
+        time_exp (float): Time expansion factor.
+        duration (float): Duration of audio file.
+        predictions (dict): Predictions.
+
+    Returns:
+        dict: Results in the format expected by the annotation tool.
+    """
+    # Get a single class prediction for the file
+    class_overall = pp.overall_class_pred(
+        predictions["det_probs"],
+        predictions["class_probs"],
+    )
+
     return {
         "id": file_id,
@@ -352,7 +358,7 @@ def format_results(
         "notes": "Automatically generated.",
         "time_exp": time_exp,
         "duration": round(float(duration), 4),
-        "annotation": annotations,
+        "annotation": get_annotations_from_preds(predictions, class_names),
         "class_name": class_names[np.argmax(class_overall)],
     }

@@ -383,7 +389,7 @@ def convert_results(
        dict: Dictionary with results.

    """
-    pred_dict = format_results(
+    pred_dict = format_single_result(
        file_id,
        time_exp,
        duration,
@@ -490,47 +496,11 @@ def save_results_to_file(results, op_path: str) -> None:
         json.dump(results["pred_dict"], jsonfile, indent=2, sort_keys=True)


-class SpectrogramParameters(TypedDict):
-    """Parameters for generating spectrograms."""
-
-    fft_win_length: float
-    """Length of the FFT window in seconds."""
-
-    fft_overlap: float
-    """Percentage of overlap between FFT windows."""
-
-    spec_height: int
-    """Height of the spectrogram in pixels."""
-
-    resize_factor: float
-    """Factor to resize the spectrogram by."""
-
-    spec_divide_factor: int
-    """Factor to divide the spectrogram by."""
-
-    device: torch.device
-    """Device to store the spectrogram on."""
-
-    max_freq: int
-    """Maximum frequency to display in the spectrogram."""
-
-    min_freq: int
-    """Minimum frequency to display in the spectrogram."""
-
-    spec_scale: str
-    """Scale to use for the spectrogram."""
-
-    denoise_spec_avg: bool
-    """Whether to denoise the spectrogram by averaging."""
-
-    max_scale_spec: bool
-    """Whether to scale the spectrogram so that its max is 1."""
-
-
 def compute_spectrogram(
     audio: np.ndarray,
     sampling_rate: int,
-    params: SpectrogramParameters,
+    params: au.SpectrogramParameters,
+    device: torch.device,
     return_np: bool = False,
 ) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
     """Compute a spectrogram from an audio array.
@@ -578,7 +548,7 @@ def compute_spectrogram(
     spec, _ = au.generate_spectrogram(audio, sampling_rate, params)

     # convert to pytorch
-    spec = torch.from_numpy(spec).to(params["device"])
+    spec = torch.from_numpy(spec).to(device)

     # add batch and channel dimensions
     spec = spec.unsqueeze(0).unsqueeze(0)
@@ -672,9 +642,6 @@ class ProcessingConfiguration(TypedDict):
     scale_raw_audio: bool
     """Whether to scale the raw audio to be between -1 and 1."""

-    device: torch.device
-    """Device to run the model on."""
-
     class_names: List[str]
     """Names of the classes the model can detect."""

@@ -721,33 +688,12 @@ class ProcessingConfiguration(TypedDict):
     """Whether to return spectrogram slices."""


-def process_spectrogram(
+def _process_spectrogram(
     spec: torch.Tensor,
     samplerate: int,
-    model: torch.nn.Module,
+    model: models.DetectionModel,
     config: ProcessingConfiguration,
-):
-    """Process a spectrogram with detection model.
-
-    Will run non-maximum suppression on the output of the model.
-
-    Parameters
-    ----------
-    spec : torch.Tensor
-
-    samplerate : int
-
-    model : torch.nn.Module
-        Detection model.
-
-    config : pp.NonMaximumSuppressionConfig
-        Parameters for non-maximum suppression.
-
-    Returns
-    -------
-    pred_nms : Dict[str, np.ndarray]
-    features : Dict[str, np.ndarray]
-    """
+) -> Tuple[List[Annotation], List[np.ndarray]]:
     # evaluate model
     with torch.no_grad():
         outputs = model(spec, return_feats=config["cnn_features"])
@@ -781,12 +727,96 @@ def process_spectrogram(
     return pred_nms, features


+def process_spectrogram(
+    spec: torch.Tensor,
+    samplerate: int,
+    model: models.DetectionModel,
+    config: ProcessingConfiguration,
+) -> Tuple[List[Annotation], List[np.ndarray]]:
+    """Process a spectrogram with detection model.
+
+    Will run non-maximum suppression on the output of the model.
+
+    Parameters
+    ----------
+    spec : torch.Tensor
+
+    samplerate : int
+
+    model : torch.nn.Module
+        Detection model.
+
+    config : pp.NonMaximumSuppressionConfig
+        Parameters for non-maximum suppression.
+
+    Returns
+    -------
+    annotations : List[Annotation]
+        List of annotations predicted by the model.
+    features : List[np.ndarray]
+        List of CNN features associated with each annotation.
+        Is empty if `config["cnn_features"]` is False.
+    """
+    pred_nms, features = _process_spectrogram(
+        spec,
+        samplerate,
+        model,
+        config,
+    )
+
+    annotations = get_annotations_from_preds(
+        pred_nms,
+        config["class_names"],
+    )
+
+    return annotations, features
+
+
+def _process_audio_array(
+    audio: np.ndarray,
+    sampling_rate: int,
+    model: torch.nn.Module,
+    config: ProcessingConfiguration,
+    device: torch.device,
+) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
+    # load audio file and compute spectrogram
+    _, spec, _ = compute_spectrogram(
+        audio,
+        sampling_rate,
+        {
+            "fft_win_length": config["fft_win_length"],
+            "fft_overlap": config["fft_overlap"],
+            "spec_height": config["spec_height"],
+            "resize_factor": config["resize_factor"],
+            "spec_divide_factor": config["spec_divide_factor"],
+            "max_freq": config["max_freq"],
+            "min_freq": config["min_freq"],
+            "spec_scale": config["spec_scale"],
+            "denoise_spec_avg": config["denoise_spec_avg"],
+            "max_scale_spec": config["max_scale_spec"],
+        },
+        device,
+        return_np=False,
+    )
+
+    # process spectrogram with model
+    pred_nms, features = _process_spectrogram(
+        spec,
+        sampling_rate,
+        model,
+        config,
+    )
+
+    return pred_nms, features, spec
+
+
 def process_audio_array(
     audio: np.ndarray,
     sampling_rate: int,
     model: torch.nn.Module,
     config: ProcessingConfiguration,
-):
+    device: torch.device,
+) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
     """Process a single audio array with detection model.

     Parameters
@@ -801,47 +831,42 @@ def process_audio_array(
     config : ProcessingConfiguration
         Configuration for processing.

+    device : torch.device
+        Device to use for processing.
+
     Returns
     -------
-    pred_nms : Dict[str, np.ndarray]
-    features : Dict[str, np.ndarray]
-    spec_np : np.ndarray
-    """
-    # load audio file and compute spectrogram
-    _, spec, spec_np = compute_spectrogram(
-        audio,
-        sampling_rate,
-        {
-            "fft_win_length": config["fft_win_length"],
-            "fft_overlap": config["fft_overlap"],
-            "spec_height": config["spec_height"],
-            "resize_factor": config["resize_factor"],
-            "spec_divide_factor": config["spec_divide_factor"],
-            "device": config["device"],
-            "max_freq": config["max_freq"],
-            "min_freq": config["min_freq"],
-            "spec_scale": config["spec_scale"],
-            "denoise_spec_avg": config["denoise_spec_avg"],
-            "max_scale_spec": config["max_scale_spec"],
-        },
-        return_np=config["spec_features"] or config["spec_slices"],
-    )
-
-    # process spectrogram with model
-    pred_nms, features = process_spectrogram(
-        spec,
+    annotations : List[Annotation]
+        List of annotations predicted by the model.
+    features : List[np.ndarray]
+        List of CNN features associated with each annotation.
+
+    spec : torch.Tensor
+        Spectrogram of the audio used as input.
+
+    """
+    pred_nms, features, spec = _process_audio_array(
+        audio,
         sampling_rate,
         model,
         config,
+        device,
     )

-    return pred_nms, features, spec_np
+    annotations = get_annotations_from_preds(
+        pred_nms,
+        config["class_names"],
+    )
+
+    return annotations, features, spec


 def process_file(
     audio_file: str,
     model: torch.nn.Module,
     config: ProcessingConfiguration,
+    device: torch.device,
 ) -> Union[RunResults, Any]:
     """Process a single audio file with detection model.

@@ -872,7 +897,7 @@ def process_file(
     spec_slices = []

     # load audio file
-    sampling_rate, audio_full = au.load_audio_file(
+    sampling_rate, audio_full = au.load_audio(
         audio_file,
         time_exp_fact=config.get("time_expansion", 1) or 1,
         target_samp_rate=config["target_samp_rate"],
@@ -881,7 +906,7 @@ def process_file(
     )

     # loop through larger file and split into chunks
-    # TODO fix so that it overlaps correctly and takes care of
+    # TODO: fix so that it overlaps correctly and takes care of
     # duplicate detections at borders
     for chunk_time, audio in iterate_over_chunks(
         audio_full,
@@ -889,11 +914,12 @@ def process_file(
         config["chunk_size"],
     ):
         # Run detection model on chunk
-        pred_nms, features, spec_np = process_audio_array(
+        pred_nms, features, spec_np = _process_audio_array(
             audio,
             sampling_rate,
             model,
             config,
+            device,
         )

         # add chunk time to start and end times
@@ -965,11 +991,7 @@ def summarize_results(results, predictions, config):
     )


-def get_default_config(**kwargs) -> ProcessingConfiguration:
-    """Get default configuration for running detection model."""
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    args: ProcessingConfiguration = {
+DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
     "detection_threshold": DETECTION_THRESHOLD,
     "spec_slices": False,
     "chunk_size": 3,
@@ -983,7 +1005,6 @@ def get_default_config(**kwargs) -> ProcessingConfiguration:
     "spec_divide_factor": SPEC_DIVIDE_FACTOR,
     "spec_height": SPEC_HEIGHT,
     "scale_raw_audio": SCALE_RAW_AUDIO,
-    "device": device,
     "class_names": [],
     "time_expansion": 1,
     "top_n": 3,
@@ -997,7 +1018,3 @@ def get_default_config(**kwargs) -> ProcessingConfiguration:
     "denoise_spec_avg": DENOISE_SPEC_AVG,
     "max_scale_spec": MAX_SCALE_SPEC,
 }
-    return {
-        **args,
-        **kwargs,
-    }
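
The kwargs-override behaviour of the removed get_default_config now lives in bat_detect.api.get_config, which merges overrides into DEFAULT_PROCESSING_CONFIGURATIONS with dict unpacking. A usage sketch:

    config = get_config(detection_threshold=0.5)  # override one key, keep the rest
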
@@ -114,7 +114,7 @@ if __name__ == "__main__":
     # load audio and crop
     print("\nProcessing: " + os.path.basename(args_cmd["audio_file"]))
     print("\nOutput directory: " + args_cmd["op_dir"])
-    sampling_rate, audio = au.load_audio_file(
+    sampling_rate, audio = au.load_audio(
         args_cmd["audio_file"],
         args_cmd["time_exp"],
         params_bd["target_samp_rate"],
@@ -96,7 +96,7 @@ if __name__ == "__main__":
     # load audio file
     print("\nProcessing: " + os.path.basename(audio_file))
     print("\nOutput directory: " + op_dir)
-    sampling_rate, audio = au.load_audio_file(
+    sampling_rate, audio = au.load_audio(
         audio_file, args["time_expansion_factor"], params["target_samp_rate"]
     )
     audio = audio[
@@ -72,7 +72,7 @@ def load_data(
     sampling_rates = []
     file_names = []
     for cur_file in anns:
-        sampling_rate, audio_orig = au.load_audio_file(
+        sampling_rate, audio_orig = au.load_audio(
             cur_file["file_path"],
             cur_file["time_exp"],
             params["target_samp_rate"],
tests/__init__.py (new file, empty)
tests/test_api.py (new file, empty)

tests/test_bat_detect.py (new file, 213 lines)
@@ -0,0 +1,213 @@
+"""Test bat detect module API."""
+
+import os
+from glob import glob
+
+import numpy as np
+import torch
+from torch import nn
+
+from bat_detect.api import (
+    generate_spectrogram,
+    get_config,
+    list_audio_files,
+    load_audio,
+    load_model,
+    process_audio,
+    process_file,
+    process_spectrogram,
+)
+
+PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
+TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
+
+
+def test_load_model_with_default_params():
+    """Test loading model with default parameters."""
+    model, params = load_model()
+
+    assert model is not None
+    assert isinstance(model, nn.Module)
+
+    assert params is not None
+    assert isinstance(params, dict)
+
+    assert "model_name" in params
+    assert "num_filters" in params
+    assert "emb_dim" in params
+    assert "ip_height" in params
+    assert "resize_factor" in params
+    assert "class_names" in params
+
+    assert params["model_name"] == "Net2DFast"
+    assert params["num_filters"] == 128
+    assert params["emb_dim"] == 0
+    assert params["ip_height"] == 128
+    assert params["resize_factor"] == 0.5
+    assert len(params["class_names"]) == 17
+
+
+def test_list_audio_files():
+    """Test listing audio files."""
+    audio_files = list_audio_files(TEST_DATA_DIR)
+
+    assert len(audio_files) == 3
+    assert all(path.endswith((".wav", ".WAV")) for path in audio_files)
+
+
+def test_load_audio():
+    """Test loading audio."""
+    samplerate, audio = load_audio(TEST_DATA[0])
+
+    assert audio is not None
+    assert samplerate == 256000
+    assert isinstance(audio, np.ndarray)
+    assert audio.shape == (128000,)
+
+
+def test_generate_spectrogram():
+    """Test generating spectrogram."""
+    samplerate, audio = load_audio(TEST_DATA[0])
+    spectrogram = generate_spectrogram(audio, samplerate)
+
+    assert spectrogram is not None
+    assert isinstance(spectrogram, torch.Tensor)
+    assert spectrogram.shape == (1, 1, 128, 512)
+
+
+def test_get_default_config():
+    """Test getting default configuration."""
+    config = get_config()
+
+    assert config is not None
+    assert isinstance(config, dict)
+
+    assert config["target_samp_rate"] == 256000
+    assert config["fft_win_length"] == 0.002
+    assert config["fft_overlap"] == 0.75
+    assert config["resize_factor"] == 0.5
+    assert config["spec_divide_factor"] == 32
+    assert config["spec_height"] == 256
+    assert config["spec_scale"] == "pcen"
+    assert config["denoise_spec_avg"] is True
+    assert config["max_scale_spec"] is False
+    assert config["scale_raw_audio"] is False
+    assert len(config["class_names"]) == 0
+    assert config["detection_threshold"] == 0.01
+    assert config["time_expansion"] == 1
+    assert config["top_n"] == 3
+    assert config["return_raw_preds"] is False
+    assert config["max_duration"] is None
+    assert config["nms_kernel_size"] == 9
+    assert config["max_freq"] == 120000
+    assert config["min_freq"] == 10000
+    assert config["nms_top_k_per_sec"] == 200
+    assert config["quiet"] is True
+    assert config["chunk_size"] == 3
+    assert config["cnn_features"] is False
+    assert config["spec_features"] is False
+    assert config["spec_slices"] is False
+
+
+def test_process_file_with_model():
+    """Test processing file with model."""
+    model, params = load_model()
+    config = get_config(**params)
+    predictions = process_file(TEST_DATA[0], model, config=config)
+
+    assert predictions is not None
+    assert isinstance(predictions, dict)
+
+    assert "pred_dict" in predictions
+    assert "spec_feats" in predictions
+    assert "spec_feat_names" in predictions
+    assert "cnn_feats" in predictions
+    assert "cnn_feat_names" in predictions
+    assert "spec_slices" in predictions
+
+    # By default will not return spectrogram features
+    assert predictions["spec_feats"] is None
+    assert predictions["spec_feat_names"] is None
+    assert predictions["cnn_feats"] is None
+    assert predictions["cnn_feat_names"] is None
+    assert predictions["spec_slices"] is None
+
+    # Check that predictions are returned
+    assert isinstance(predictions["pred_dict"], dict)
+    pred_dict = predictions["pred_dict"]
+    assert pred_dict["id"] == os.path.basename(TEST_DATA[0])
+    assert pred_dict["annotated"] is False
+    assert pred_dict["issues"] is False
+    assert pred_dict["notes"] == "Automatically generated."
+    assert pred_dict["time_exp"] == 1
+    assert pred_dict["duration"] == 0.5
+    assert pred_dict["class_name"] is not None
+    assert len(pred_dict["annotation"]) > 0
+
+
+def test_process_spectrogram_with_model():
+    """Test processing spectrogram with model."""
+    model, params = load_model()
+    config = get_config(**params)
+    samplerate, audio = load_audio(TEST_DATA[0])
+    spectrogram = generate_spectrogram(audio, samplerate)
+    predictions, features = process_spectrogram(
+        spectrogram,
+        samplerate,
+        model,
+        config=config,
+    )
+
+    assert predictions is not None
+    assert isinstance(predictions, list)
+    assert len(predictions) > 0
+    sample_pred = predictions[0]
+    assert isinstance(sample_pred, dict)
+    assert "class" in sample_pred
+    assert "class_prob" in sample_pred
+    assert "det_prob" in sample_pred
+    assert "start_time" in sample_pred
+    assert "end_time" in sample_pred
+    assert "low_freq" in sample_pred
+    assert "high_freq" in sample_pred
+
+    assert features is not None
+    assert isinstance(features, list)
+    # By default will not return cnn features
+    assert len(features) == 0
+
+
+def test_process_audio_with_model():
+    """Test processing audio with model."""
+    model, params = load_model()
+    config = get_config(**params)
+    samplerate, audio = load_audio(TEST_DATA[0])
+    predictions, features, spec = process_audio(
+        audio,
+        samplerate,
+        model,
+        config=config,
+    )
+
+    assert predictions is not None
+    assert isinstance(predictions, list)
+    assert len(predictions) > 0
+    sample_pred = predictions[0]
+    assert isinstance(sample_pred, dict)
+    assert "class" in sample_pred
+    assert "class_prob" in sample_pred
+    assert "det_prob" in sample_pred
+    assert "start_time" in sample_pred
+    assert "end_time" in sample_pred
+    assert "low_freq" in sample_pred
+    assert "high_freq" in sample_pred
+
+    assert features is not None
+    assert isinstance(features, list)
+    # By default will not return cnn features
+    assert len(features) == 0
+
+    assert spec is not None
+    assert isinstance(spec, torch.Tensor)
+    assert spec.shape == (1, 1, 128, 512)
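
Assuming the example_data/audio directory ships with the repository (the tests above glob it for .wav files), the suite should run with a plain pytest invocation such as "pytest tests/test_bat_detect.py".
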
tests/test_cli.py (new file, empty)