mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Moved types to a single module
This commit is contained in:
parent
f0b0f28379
commit
6f2bb605d3
@ -3,10 +3,19 @@ from typing import List, Optional, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import bat_detect.detector.models as md
|
||||
import bat_detect.utils.audio_utils as au
|
||||
import bat_detect.utils.detector_utils as du
|
||||
from bat_detect.detector.parameters import TARGET_SAMPLERATE_HZ
|
||||
from bat_detect.detector.parameters import (
|
||||
DEFAULT_PROCESSING_CONFIGURATIONS,
|
||||
DEFAULT_SPECTROGRAM_PARAMETERS,
|
||||
TARGET_SAMPLERATE_HZ,
|
||||
)
|
||||
from bat_detect.types import (
|
||||
Annotation,
|
||||
DetectionModel,
|
||||
ProcessingConfiguration,
|
||||
SpectrogramParameters,
|
||||
)
|
||||
from bat_detect.utils.detector_utils import list_audio_files, load_model
|
||||
|
||||
# Use GPU if available
|
||||
@ -24,12 +33,12 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def get_config(**kwargs) -> du.ProcessingConfiguration:
|
||||
def get_config(**kwargs) -> ProcessingConfiguration:
|
||||
"""Get default processing configuration.
|
||||
|
||||
Can be used to override default parameters by passing keyword arguments.
|
||||
"""
|
||||
return {**du.DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}
|
||||
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}
|
||||
|
||||
|
||||
def load_audio(
|
||||
@ -73,7 +82,7 @@ def load_audio(
|
||||
def generate_spectrogram(
|
||||
audio: np.ndarray,
|
||||
samp_rate: int,
|
||||
config: Optional[au.SpectrogramParameters] = None,
|
||||
config: Optional[SpectrogramParameters] = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> torch.Tensor:
|
||||
"""Generate spectrogram from audio array.
|
||||
@ -93,7 +102,7 @@ def generate_spectrogram(
|
||||
Spectrogram.
|
||||
"""
|
||||
if config is None:
|
||||
config = au.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||
config = DEFAULT_SPECTROGRAM_PARAMETERS
|
||||
|
||||
_, spec, _ = du.compute_spectrogram(
|
||||
audio,
|
||||
@ -108,8 +117,8 @@ def generate_spectrogram(
|
||||
|
||||
def process_file(
|
||||
audio_file: str,
|
||||
model: md.DetectionModel,
|
||||
config: Optional[du.ProcessingConfiguration] = None,
|
||||
model: DetectionModel,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> du.RunResults:
|
||||
"""Process audio file with model.
|
||||
@ -126,7 +135,7 @@ def process_file(
|
||||
Device to use, by default tries to use GPU if available.
|
||||
"""
|
||||
if config is None:
|
||||
config = du.DEFAULT_PROCESSING_CONFIGURATIONS
|
||||
config = DEFAULT_PROCESSING_CONFIGURATIONS
|
||||
|
||||
return du.process_file(
|
||||
audio_file,
|
||||
@ -139,9 +148,9 @@ def process_file(
|
||||
def process_spectrogram(
|
||||
spec: torch.Tensor,
|
||||
samp_rate: int,
|
||||
model: md.DetectionModel,
|
||||
config: Optional[du.ProcessingConfiguration] = None,
|
||||
) -> Tuple[List[du.Annotation], List[np.ndarray]]:
|
||||
model: DetectionModel,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
||||
"""Process spectrogram with model.
|
||||
|
||||
Parameters
|
||||
@ -160,7 +169,7 @@ def process_spectrogram(
|
||||
DetectionResult
|
||||
"""
|
||||
if config is None:
|
||||
config = du.DEFAULT_PROCESSING_CONFIGURATIONS
|
||||
config = DEFAULT_PROCESSING_CONFIGURATIONS
|
||||
|
||||
return du.process_spectrogram(
|
||||
spec,
|
||||
@ -173,10 +182,10 @@ def process_spectrogram(
|
||||
def process_audio(
|
||||
audio: np.ndarray,
|
||||
samp_rate: int,
|
||||
model: md.DetectionModel,
|
||||
config: Optional[du.ProcessingConfiguration] = None,
|
||||
model: DetectionModel,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> Tuple[List[du.Annotation], List[np.ndarray], torch.Tensor]:
|
||||
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
|
||||
"""Process audio array with model.
|
||||
|
||||
Parameters
|
||||
@ -204,7 +213,7 @@ def process_audio(
|
||||
Spectrogram of the audio used for prediction.
|
||||
"""
|
||||
if config is None:
|
||||
config = du.DEFAULT_PROCESSING_CONFIGURATIONS
|
||||
config = DEFAULT_PROCESSING_CONFIGURATIONS
|
||||
|
||||
return du.process_audio_array(
|
||||
audio,
|
||||
|
@ -7,7 +7,7 @@ Example usage:
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import bat_detect.utils.detector_utils as du
|
||||
from bat_detect import api
|
||||
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
@ -1,5 +1,3 @@
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.fft
|
||||
import torch.nn.functional as F
|
||||
@ -12,86 +10,15 @@ from bat_detect.detector.model_helpers import (
|
||||
ConvBlockUpStandard,
|
||||
SelfAttention,
|
||||
)
|
||||
|
||||
try:
|
||||
from typing import Protocol
|
||||
except ImportError:
|
||||
from typing_extensions import Protocol
|
||||
from bat_detect.types import ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"Net2DFast",
|
||||
"Net2DFastNoAttn",
|
||||
"Net2DFastNoCoordConv",
|
||||
"ModelOutput",
|
||||
"DetectionModel",
|
||||
]
|
||||
|
||||
|
||||
class ModelOutput(NamedTuple):
|
||||
"""Output of the detection model."""
|
||||
|
||||
pred_det: torch.Tensor
|
||||
"""Tensor with predict detection probabilities."""
|
||||
|
||||
pred_size: torch.Tensor
|
||||
"""Tensor with predicted bounding box sizes."""
|
||||
|
||||
pred_class: torch.Tensor
|
||||
"""Tensor with predicted class probabilities."""
|
||||
|
||||
pred_class_un_norm: torch.Tensor
|
||||
"""Tensor with predicted class probabilities before softmax."""
|
||||
|
||||
features: torch.Tensor
|
||||
"""Tensor with intermediate features."""
|
||||
|
||||
|
||||
class DetectionModel(Protocol):
|
||||
"""Protocol for detection models.
|
||||
|
||||
This protocol is used to define the interface for the detection models.
|
||||
This allows us to use the same code for training and inference, even
|
||||
though the models are different.
|
||||
"""
|
||||
|
||||
num_classes: int
|
||||
"""Number of classes the model can classify."""
|
||||
|
||||
emb_dim: int
|
||||
"""Dimension of the embedding vector."""
|
||||
|
||||
num_filts: int
|
||||
"""Number of filters in the model."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor by which the input is resized."""
|
||||
|
||||
ip_height: int
|
||||
"""Height of the input image."""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ip: torch.Tensor,
|
||||
return_feats: bool = False,
|
||||
) -> ModelOutput:
|
||||
"""Forward pass of the model.
|
||||
|
||||
When `return_feats` is `True`, the model should return the
|
||||
intermediate features of the model.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
ip: torch.Tensor,
|
||||
return_feats: bool = False,
|
||||
) -> ModelOutput:
|
||||
"""Forward pass of the model.
|
||||
|
||||
When `return_feats` is `True`, the model should return the
|
||||
int
|
||||
"""
|
||||
|
||||
|
||||
class Net2DFast(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -1,6 +1,11 @@
|
||||
import datetime
|
||||
import os
|
||||
|
||||
from bat_detect.types import (
|
||||
ProcessingConfiguration,
|
||||
SpectrogramParameters,
|
||||
)
|
||||
|
||||
TARGET_SAMPLERATE_HZ = 256000
|
||||
FFT_WIN_LENGTH_S = 512 / 256000.0
|
||||
FFT_OVERLAP = 0.75
|
||||
@ -18,6 +23,56 @@ DENOISE_SPEC_AVG = True
|
||||
MAX_SCALE_SPEC = False
|
||||
|
||||
|
||||
# Location of the default model checkpoint: the "models" directory that sits
# next to the directory containing this module (i.e. one level up from here).
DEFAULT_MODEL_PATH = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "models",
    "Net2DFast_UK_same.pth.tar",
)
|
||||
|
||||
|
||||
# Default spectrogram settings, assembled from the module-level constants
# defined in this file. Used as the fallback when no spectrogram
# configuration is supplied by the caller.
DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
    "fft_win_length": FFT_WIN_LENGTH_S,
    "fft_overlap": FFT_OVERLAP,
    "spec_height": SPEC_HEIGHT,
    "resize_factor": RESIZE_FACTOR,
    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
    "max_freq": MAX_FREQ_HZ,
    "min_freq": MIN_FREQ_HZ,
    "spec_scale": SPEC_SCALE,
    "denoise_spec_avg": DENOISE_SPEC_AVG,
    "max_scale_spec": MAX_SCALE_SPEC,
}
|
||||
|
||||
|
||||
# Default processing configuration. Values come from the module-level
# constants where one exists; the remaining entries are literal defaults.
# NOTE: this is a shared module-level dict (and "class_names" is a shared
# list); copy it before modifying rather than mutating it in place.
DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
    "detection_threshold": DETECTION_THRESHOLD,
    "spec_slices": False,
    "chunk_size": 3,
    "spec_features": False,
    "cnn_features": False,
    "quiet": True,
    "target_samp_rate": TARGET_SAMPLERATE_HZ,
    "fft_win_length": FFT_WIN_LENGTH_S,
    "fft_overlap": FFT_OVERLAP,
    "resize_factor": RESIZE_FACTOR,
    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
    "spec_height": SPEC_HEIGHT,
    "scale_raw_audio": SCALE_RAW_AUDIO,
    "class_names": [],
    "time_expansion": 1,
    "top_n": 3,
    "return_raw_preds": False,
    "max_duration": None,
    "nms_kernel_size": NMS_KERNEL_SIZE,
    "max_freq": MAX_FREQ_HZ,
    "min_freq": MIN_FREQ_HZ,
    "nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
    "spec_scale": SPEC_SCALE,
    "denoise_spec_avg": DENOISE_SPEC_AVG,
    "max_scale_spec": MAX_SCALE_SPEC,
}
|
||||
|
||||
|
||||
def mk_dir(path):
    """Create the directory ``path``, including missing parents.

    Does nothing when the directory already exists.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok=True replaces the previous isdir()-then-makedirs() sequence,
    # which could fail if another process created the directory in between.
    os.makedirs(path, exist_ok=True)
|
||||
|
@ -1,24 +1,18 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
sys.path.append(os.path.join("..", ".."))
|
||||
|
||||
import warnings
|
||||
|
||||
import bat_detect.detector.models as models
|
||||
import bat_detect.detector.parameters as parameters
|
||||
from bat_detect.detector import models
|
||||
from bat_detect.detector import parameters
|
||||
from bat_detect.train import losses
|
||||
import bat_detect.detector.post_process as pp
|
||||
import bat_detect.train.audio_dataloader as adl
|
||||
import bat_detect.train.evaluate as evl
|
||||
import bat_detect.train.losses as losses
|
||||
import bat_detect.train.train_split as ts
|
||||
import bat_detect.train.train_utils as tu
|
||||
import bat_detect.utils.plot_utils as pu
|
||||
|
304
bat_detect/types.py
Normal file
304
bat_detect/types.py
Normal file
@ -0,0 +1,304 @@
|
||||
"""Types used in the code base."""
|
||||
from typing import List, Optional, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
from typing import TypedDict
|
||||
except ImportError:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
try:
|
||||
from typing import Protocol
|
||||
except ImportError:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
|
||||
class SpectrogramParameters(TypedDict):
    """Settings that control how a spectrogram is generated."""

    fft_win_length: float
    """FFT window length, in seconds."""

    fft_overlap: float
    """Percentage of overlap between consecutive FFT windows."""

    spec_height: int
    """Spectrogram height, in pixels."""

    resize_factor: float
    """Factor by which the spectrogram is resized."""

    spec_divide_factor: int
    """Factor by which the spectrogram is divided."""

    max_freq: int
    """Highest frequency shown in the spectrogram."""

    min_freq: int
    """Lowest frequency shown in the spectrogram."""

    spec_scale: str
    """Scale applied to the spectrogram."""

    denoise_spec_avg: bool
    """Whether the spectrogram is denoised by averaging."""

    max_scale_spec: bool
    """Whether the spectrogram is scaled so that its maximum equals 1."""
|
||||
|
||||
|
||||
class ModelParameters(TypedDict):
    """Parameters describing a detection model."""

    model_name: str
    """Name of the model."""

    num_filters: int
    """How many filters the model uses."""

    emb_dim: int
    """Dimension of the embedding."""

    ip_height: int
    """Height of the input, in pixels."""

    resize_factor: float
    """Factor used for resizing."""

    class_names: List[str]
    """Names of the classes the model is trained to detect."""
|
||||
|
||||
|
||||
# "class" is a reserved word and cannot be declared as a field with the
# class-based TypedDict syntax, so it is declared with the functional form
# and inherited by Annotation below.
DictWithClass = TypedDict("DictWithClass", {"class": str})


class Annotation(DictWithClass):
    """A single annotation.

    Follows the per-annotation format expected by the annotation tool.
    """

    start_time: float
    """Time at which the annotation starts, in seconds."""

    end_time: float
    """Time at which the annotation ends, in seconds."""

    low_freq: int
    """Lower frequency bound, in Hz."""

    high_freq: int
    """Upper frequency bound, in Hz."""

    class_prob: float
    """Probability of the class assignment."""

    det_prob: float
    """Probability of the detection."""

    individual: str
    """ID of the individual."""

    event: str
    """Type of the detected event."""
|
||||
|
||||
|
||||
class FileAnnotations(TypedDict):
    """Results for one file.

    Follows the result format expected by the annotation tool.
    """

    id: str
    """ID of the file."""

    annotated: bool
    """Whether the file has been annotated."""

    duration: float
    """Duration of the audio file."""

    issues: bool
    """Whether the file has issues."""

    time_exp: float
    """Time expansion factor."""

    class_name: str
    """Class predicted at the file level."""

    notes: str
    """Notes about the file."""

    annotation: List[Annotation]
    """The annotations of the file."""
|
||||
|
||||
|
||||
class RunResults(TypedDict):
    """Results of a run."""

    pred_dict: FileAnnotations
    """Predictions, in the format expected by the annotation tool."""

    spec_feats: Optional[List[np.ndarray]]
    """Spectrogram features, if any."""

    spec_feat_names: Optional[List[str]]
    """Names of the spectrogram features, if any."""

    cnn_feats: Optional[List[np.ndarray]]
    """CNN features, if any."""

    cnn_feat_names: Optional[List[str]]
    """Names of the CNN features, if any."""

    spec_slices: Optional[List[np.ndarray]]
    """Spectrogram slices, if any."""
|
||||
|
||||
|
||||
class ResultParams(TypedDict):
    """Parameters attached to results."""

    class_names: List[str]
    """Names of the classes."""
|
||||
|
||||
|
||||
class ProcessingConfiguration(TypedDict):
    """Parameters for processing audio files."""

    # audio parameters
    target_samp_rate: int
    """Target sampling rate of the audio."""

    fft_win_length: float
    """Length of the FFT window in seconds."""

    fft_overlap: float
    """Overlap between consecutive FFT windows (same meaning as
    ``SpectrogramParameters.fft_overlap``), not a window length."""

    resize_factor: float
    """Factor to resize the spectrogram by."""

    spec_divide_factor: int
    """Factor to divide the spectrogram by."""

    spec_height: int
    """Height of the spectrogram in pixels."""

    spec_scale: str
    """Scale to use for the spectrogram."""

    denoise_spec_avg: bool
    """Whether to denoise the spectrogram by averaging."""

    max_scale_spec: bool
    """Whether to scale the spectrogram so that its max is 1."""

    scale_raw_audio: bool
    """Whether to scale the raw audio to be between -1 and 1."""

    class_names: List[str]
    """Names of the classes the model can detect."""

    detection_threshold: float
    """Threshold for detection probability."""

    time_expansion: Optional[float]
    """Time expansion factor of the processed recordings."""

    top_n: int
    """Number of top detections to keep."""

    return_raw_preds: bool
    """Whether to return raw predictions."""

    max_duration: Optional[float]
    """Maximum duration of audio file to process in seconds."""

    nms_kernel_size: int
    """Size of the kernel for non-maximum suppression."""

    max_freq: int
    """Maximum frequency to consider in Hz."""

    min_freq: int
    """Minimum frequency to consider in Hz."""

    nms_top_k_per_sec: float
    """Number of top detections to keep per second."""

    quiet: bool
    """Whether to suppress output."""

    chunk_size: float
    """Size of chunks to process in seconds."""

    cnn_features: bool
    """Whether to return CNN features."""

    spec_features: bool
    """Whether to return spectrogram features."""

    spec_slices: bool
    """Whether to return spectrogram slices."""
|
||||
|
||||
|
||||
class ModelOutput(NamedTuple):
    """What the detection model returns."""

    pred_det: torch.Tensor
    """Predicted detection probabilities."""

    pred_size: torch.Tensor
    """Predicted bounding box sizes."""

    pred_class: torch.Tensor
    """Predicted class probabilities."""

    pred_class_un_norm: torch.Tensor
    """Predicted class probabilities before softmax."""

    features: torch.Tensor
    """Intermediate features."""
|
||||
|
||||
|
||||
class DetectionModel(Protocol):
    """Interface that detection models are expected to provide.

    Declaring the interface as a protocol lets the same training and
    inference code work with models that are implemented differently.
    """

    num_classes: int
    """How many classes the model can classify."""

    emb_dim: int
    """Dimension of the embedding vector."""

    num_filts: int
    """How many filters the model has."""

    resize_factor: float
    """Factor by which the input is resized."""

    ip_height: int
    """Height of the input image."""

    def forward(
        self,
        spec: torch.Tensor,
        return_feats: bool = False,
    ) -> ModelOutput:
        """Run the forward pass of the model on ``spec``."""
        ...

    def __call__(
        self,
        spec: torch.Tensor,
        return_feats: bool = False,
    ) -> ModelOutput:
        """Run the forward pass of the model on ``spec``."""
        ...
|
@ -38,54 +38,6 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class SpectrogramParameters(TypedDict):
|
||||
"""Parameters for generating spectrograms."""
|
||||
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
|
||||
fft_overlap: float
|
||||
"""Percentage of overlap between FFT windows."""
|
||||
|
||||
spec_height: int
|
||||
"""Height of the spectrogram in pixels."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor to resize the spectrogram by."""
|
||||
|
||||
spec_divide_factor: int
|
||||
"""Factor to divide the spectrogram by."""
|
||||
|
||||
max_freq: int
|
||||
"""Maximum frequency to display in the spectrogram."""
|
||||
|
||||
min_freq: int
|
||||
"""Minimum frequency to display in the spectrogram."""
|
||||
|
||||
spec_scale: str
|
||||
"""Scale to use for the spectrogram."""
|
||||
|
||||
denoise_spec_avg: bool
|
||||
"""Whether to denoise the spectrogram by averaging."""
|
||||
|
||||
max_scale_spec: bool
|
||||
"""Whether to scale the spectrogram so that its max is 1."""
|
||||
|
||||
|
||||
DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
|
||||
"fft_win_length": FFT_WIN_LENGTH_S,
|
||||
"fft_overlap": FFT_OVERLAP,
|
||||
"spec_height": SPEC_HEIGHT,
|
||||
"resize_factor": RESIZE_FACTOR,
|
||||
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
|
||||
"max_freq": MAX_FREQ_HZ,
|
||||
"min_freq": MIN_FREQ_HZ,
|
||||
"spec_scale": SPEC_SCALE,
|
||||
"denoise_spec_avg": DENOISE_SPEC_AVG,
|
||||
"max_scale_spec": MAX_SCALE_SPEC,
|
||||
}
|
||||
|
||||
|
||||
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
|
||||
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
|
||||
noverlap = np.floor(fft_overlap * nfft)
|
||||
|
@ -11,34 +11,16 @@ import bat_detect.detector.compute_features as feats
|
||||
import bat_detect.detector.post_process as pp
|
||||
import bat_detect.utils.audio_utils as au
|
||||
from bat_detect.detector import models
|
||||
from bat_detect.detector.parameters import (
|
||||
DENOISE_SPEC_AVG,
|
||||
DETECTION_THRESHOLD,
|
||||
FFT_OVERLAP,
|
||||
FFT_WIN_LENGTH_S,
|
||||
MAX_FREQ_HZ,
|
||||
MAX_SCALE_SPEC,
|
||||
MIN_FREQ_HZ,
|
||||
NMS_KERNEL_SIZE,
|
||||
NMS_TOP_K_PER_SEC,
|
||||
RESIZE_FACTOR,
|
||||
SCALE_RAW_AUDIO,
|
||||
SPEC_DIVIDE_FACTOR,
|
||||
SPEC_HEIGHT,
|
||||
SPEC_SCALE,
|
||||
TARGET_SAMPLERATE_HZ,
|
||||
)
|
||||
|
||||
try:
|
||||
from typing import TypedDict
|
||||
except ImportError:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
DEFAULT_MODEL_PATH = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"models",
|
||||
"Net2DFast_UK_same.pth.tar",
|
||||
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
|
||||
from bat_detect.types import (
|
||||
Annotation,
|
||||
FileAnnotations,
|
||||
ModelParameters,
|
||||
ProcessingConfiguration,
|
||||
SpectrogramParameters,
|
||||
ResultParams,
|
||||
RunResults,
|
||||
DetectionModel
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -50,8 +32,6 @@ __all__ = [
|
||||
"process_spectrogram",
|
||||
"process_audio_array",
|
||||
"process_file",
|
||||
"DEFAULT_MODEL_PATH",
|
||||
"DEFAULT_PROCESSING_CONFIGURATIONS",
|
||||
]
|
||||
|
||||
|
||||
@ -77,33 +57,11 @@ def list_audio_files(ip_dir: str) -> List[str]:
|
||||
return matches
|
||||
|
||||
|
||||
class ModelParameters(TypedDict):
|
||||
"""Model parameters."""
|
||||
|
||||
model_name: str
|
||||
"""Model name."""
|
||||
|
||||
num_filters: int
|
||||
"""Number of filters."""
|
||||
|
||||
emb_dim: int
|
||||
"""Embedding dimension."""
|
||||
|
||||
ip_height: int
|
||||
"""Input height in pixels."""
|
||||
|
||||
resize_factor: float
|
||||
"""Resize factor."""
|
||||
|
||||
class_names: List[str]
|
||||
"""Class names. The model is trained to detect these classes."""
|
||||
|
||||
|
||||
def load_model(
|
||||
model_path: str = DEFAULT_MODEL_PATH,
|
||||
load_weights: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Tuple[models.DetectionModel, ModelParameters]:
|
||||
) -> Tuple[DetectionModel, ModelParameters]:
|
||||
"""Load model from file.
|
||||
|
||||
Args:
|
||||
@ -129,7 +87,7 @@ def load_model(
|
||||
|
||||
params = net_params["params"]
|
||||
|
||||
model: models.DetectionModel
|
||||
model: DetectionModel
|
||||
|
||||
if params["model_name"] == "Net2DFast":
|
||||
model = models.Net2DFast(
|
||||
@ -189,101 +147,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
||||
return predictions_m, spec_feats, cnn_feats, spec_slices
|
||||
|
||||
|
||||
DictWithClass = TypedDict("DictWithClass", {"class": str})
|
||||
|
||||
|
||||
class Annotation(DictWithClass):
|
||||
"""Format of annotations.
|
||||
|
||||
This is the format of a single annotation as expected by the annotation
|
||||
tool.
|
||||
"""
|
||||
|
||||
start_time: float
|
||||
"""Start time in seconds."""
|
||||
|
||||
end_time: float
|
||||
"""End time in seconds."""
|
||||
|
||||
low_freq: int
|
||||
"""Low frequency in Hz."""
|
||||
|
||||
high_freq: int
|
||||
"""High frequency in Hz."""
|
||||
|
||||
class_prob: float
|
||||
"""Probability of class assignment."""
|
||||
|
||||
det_prob: float
|
||||
"""Probability of detection."""
|
||||
|
||||
individual: str
|
||||
"""Individual ID."""
|
||||
|
||||
event: str
|
||||
"""Type of detected event."""
|
||||
|
||||
|
||||
class FileAnnotations(TypedDict):
|
||||
"""Format of results.
|
||||
|
||||
This is the format of the results expected by the annotation tool.
|
||||
"""
|
||||
|
||||
id: str
|
||||
"""File ID."""
|
||||
|
||||
annotated: bool
|
||||
"""Whether file has been annotated."""
|
||||
|
||||
duration: float
|
||||
"""Duration of audio file."""
|
||||
|
||||
issues: bool
|
||||
"""Whether file has issues."""
|
||||
|
||||
time_exp: float
|
||||
"""Time expansion factor."""
|
||||
|
||||
class_name: str
|
||||
"""Class predicted at file level"""
|
||||
|
||||
notes: str
|
||||
"""Notes of file."""
|
||||
|
||||
annotation: List[Annotation]
|
||||
"""List of annotations."""
|
||||
|
||||
|
||||
class RunResults(TypedDict):
|
||||
"""Run results."""
|
||||
|
||||
pred_dict: FileAnnotations
|
||||
"""Predictions in the format expected by the annotation tool."""
|
||||
|
||||
spec_feats: Optional[List[np.ndarray]]
|
||||
"""Spectrogram features."""
|
||||
|
||||
spec_feat_names: Optional[List[str]]
|
||||
"""Spectrogram feature names."""
|
||||
|
||||
cnn_feats: Optional[List[np.ndarray]]
|
||||
"""CNN features."""
|
||||
|
||||
cnn_feat_names: Optional[List[str]]
|
||||
"""CNN feature names."""
|
||||
|
||||
spec_slices: Optional[List[np.ndarray]]
|
||||
"""Spectrogram slices."""
|
||||
|
||||
|
||||
class ResultParams(TypedDict):
|
||||
"""Result parameters."""
|
||||
|
||||
class_names: List[str]
|
||||
"""Class names."""
|
||||
|
||||
|
||||
def get_annotations_from_preds(
|
||||
predictions,
|
||||
class_names: List[str],
|
||||
@ -499,7 +362,7 @@ def save_results_to_file(results, op_path: str) -> None:
|
||||
def compute_spectrogram(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: int,
|
||||
params: au.SpectrogramParameters,
|
||||
params: SpectrogramParameters,
|
||||
device: torch.device,
|
||||
return_np: bool = False,
|
||||
) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
|
||||
@ -608,90 +471,10 @@ def iterate_over_chunks(
|
||||
yield chunk_start, audio[start_sample:end_sample]
|
||||
|
||||
|
||||
class ProcessingConfiguration(TypedDict):
|
||||
"""Parameters for processing audio files."""
|
||||
|
||||
# audio parameters
|
||||
target_samp_rate: int
|
||||
"""Target sampling rate of the audio."""
|
||||
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
|
||||
fft_overlap: float
|
||||
"""Length of the FFT window in samples."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor to resize the spectrogram by."""
|
||||
|
||||
spec_divide_factor: int
|
||||
"""Factor to divide the spectrogram by."""
|
||||
|
||||
spec_height: int
|
||||
"""Height of the spectrogram in pixels."""
|
||||
|
||||
spec_scale: str
|
||||
"""Scale to use for the spectrogram."""
|
||||
|
||||
denoise_spec_avg: bool
|
||||
"""Whether to denoise the spectrogram by averaging."""
|
||||
|
||||
max_scale_spec: bool
|
||||
"""Whether to scale the spectrogram so that its max is 1."""
|
||||
|
||||
scale_raw_audio: bool
|
||||
"""Whether to scale the raw audio to be between -1 and 1."""
|
||||
|
||||
class_names: List[str]
|
||||
"""Names of the classes the model can detect."""
|
||||
|
||||
detection_threshold: float
|
||||
"""Threshold for detection probability."""
|
||||
|
||||
time_expansion: Optional[float]
|
||||
"""Time expansion factor of the processed recordings."""
|
||||
|
||||
top_n: int
|
||||
"""Number of top detections to keep."""
|
||||
|
||||
return_raw_preds: bool
|
||||
"""Whether to return raw predictions."""
|
||||
|
||||
max_duration: Optional[float]
|
||||
"""Maximum duration of audio file to process in seconds."""
|
||||
|
||||
nms_kernel_size: int
|
||||
"""Size of the kernel for non-maximum suppression."""
|
||||
|
||||
max_freq: int
|
||||
"""Maximum frequency to consider in Hz."""
|
||||
|
||||
min_freq: int
|
||||
"""Minimum frequency to consider in Hz."""
|
||||
|
||||
nms_top_k_per_sec: float
|
||||
"""Number of top detections to keep per second."""
|
||||
|
||||
quiet: bool
|
||||
"""Whether to suppress output."""
|
||||
|
||||
chunk_size: float
|
||||
"""Size of chunks to process in seconds."""
|
||||
|
||||
cnn_features: bool
|
||||
"""Whether to return CNN features."""
|
||||
|
||||
spec_features: bool
|
||||
"""Whether to return spectrogram features."""
|
||||
|
||||
spec_slices: bool
|
||||
"""Whether to return spectrogram slices."""
|
||||
|
||||
|
||||
def _process_spectrogram(
|
||||
spec: torch.Tensor,
|
||||
samplerate: int,
|
||||
model: models.DetectionModel,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
||||
# evaluate model
|
||||
@ -730,7 +513,7 @@ def _process_spectrogram(
|
||||
def process_spectrogram(
|
||||
spec: torch.Tensor,
|
||||
samplerate: int,
|
||||
model: models.DetectionModel,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
||||
"""Process a spectrogram with detection model.
|
||||
@ -775,7 +558,7 @@ def process_spectrogram(
|
||||
def _process_audio_array(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: int,
|
||||
model: torch.nn.Module,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
device: torch.device,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
|
||||
@ -813,7 +596,7 @@ def _process_audio_array(
|
||||
def process_audio_array(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: int,
|
||||
model: torch.nn.Module,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
device: torch.device,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
|
||||
@ -864,7 +647,7 @@ def process_audio_array(
|
||||
|
||||
def process_file(
|
||||
audio_file: str,
|
||||
model: torch.nn.Module,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
device: torch.device,
|
||||
) -> Union[RunResults, Any]:
|
||||
@ -989,32 +772,3 @@ def summarize_results(results, predictions, config):
|
||||
config["class_names"][class_index].ljust(30)
|
||||
+ str(round(class_overall[class_index], 3))
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
|
||||
"detection_threshold": DETECTION_THRESHOLD,
|
||||
"spec_slices": False,
|
||||
"chunk_size": 3,
|
||||
"spec_features": False,
|
||||
"cnn_features": False,
|
||||
"quiet": True,
|
||||
"target_samp_rate": TARGET_SAMPLERATE_HZ,
|
||||
"fft_win_length": FFT_WIN_LENGTH_S,
|
||||
"fft_overlap": FFT_OVERLAP,
|
||||
"resize_factor": RESIZE_FACTOR,
|
||||
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
|
||||
"spec_height": SPEC_HEIGHT,
|
||||
"scale_raw_audio": SCALE_RAW_AUDIO,
|
||||
"class_names": [],
|
||||
"time_expansion": 1,
|
||||
"top_n": 3,
|
||||
"return_raw_preds": False,
|
||||
"max_duration": None,
|
||||
"nms_kernel_size": NMS_KERNEL_SIZE,
|
||||
"max_freq": MAX_FREQ_HZ,
|
||||
"min_freq": MIN_FREQ_HZ,
|
||||
"nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
|
||||
"spec_scale": SPEC_SCALE,
|
||||
"denoise_spec_avg": DENOISE_SPEC_AVG,
|
||||
"max_scale_spec": MAX_SCALE_SPEC,
|
||||
}
|
||||
|
@ -0,0 +1,211 @@
|
||||
"""Test bat detect module API."""
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from bat_detect.api import (
|
||||
generate_spectrogram,
|
||||
get_config,
|
||||
list_audio_files,
|
||||
load_audio,
|
||||
load_model,
|
||||
process_audio,
|
||||
process_file,
|
||||
process_spectrogram,
|
||||
)
|
||||
|
||||
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
|
||||
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
|
||||
|
||||
|
||||
def test_load_model_with_default_params():
    """Check the model and parameters returned by a bare ``load_model()``."""
    model, params = load_model()

    assert model is not None
    assert isinstance(model, nn.Module)

    assert params is not None
    assert isinstance(params, dict)

    # Every one of these keys must be present in the parameter dict.
    required_keys = (
        "model_name",
        "num_filters",
        "emb_dim",
        "ip_height",
        "resize_factor",
        "class_names",
    )
    for key in required_keys:
        assert key in params

    # Values expected for the default model.
    expected_values = {
        "model_name": "Net2DFast",
        "num_filters": 128,
        "emb_dim": 0,
        "ip_height": 128,
        "resize_factor": 0.5,
    }
    for key, expected in expected_values.items():
        assert params[key] == expected

    assert len(params["class_names"]) == 17
|
||||
|
||||
|
||||
def test_list_audio_files():
    """Check that the example audio directory lists exactly three WAV files."""
    found = list_audio_files(TEST_DATA_DIR)

    assert len(found) == 3

    wav_suffixes = (".wav", ".WAV")
    for path in found:
        assert path.endswith(wav_suffixes)
|
||||
|
||||
|
||||
def test_load_audio():
    """Check the sample rate and array returned for the first test file."""
    rate, samples = load_audio(TEST_DATA[0])

    assert samples is not None
    assert rate == 256000
    assert isinstance(samples, np.ndarray)

    expected_shape = (128000,)
    assert samples.shape == expected_shape
|
||||
|
||||
|
||||
def test_generate_spectrogram():
    """Check type and shape of the spectrogram built from the first test file."""
    rate, samples = load_audio(TEST_DATA[0])
    spec = generate_spectrogram(samples, rate)

    assert spec is not None
    assert isinstance(spec, torch.Tensor)

    expected_shape = (1, 1, 128, 512)
    assert spec.shape == expected_shape
|
||||
|
||||
|
||||
def test_get_default_config():
    """Check every value of the configuration returned by ``get_config()``."""
    config = get_config()

    assert config is not None
    assert isinstance(config, dict)

    # Options compared by equality.
    expected_values = {
        "target_samp_rate": 256000,
        "fft_win_length": 0.002,
        "fft_overlap": 0.75,
        "resize_factor": 0.5,
        "spec_divide_factor": 32,
        "spec_height": 256,
        "spec_scale": "pcen",
        "detection_threshold": 0.01,
        "time_expansion": 1,
        "top_n": 3,
        "nms_kernel_size": 9,
        "max_freq": 120000,
        "min_freq": 10000,
        "nms_top_k_per_sec": 200,
        "chunk_size": 3,
    }
    for key, expected in expected_values.items():
        assert config[key] == expected

    # Boolean options are checked by identity, not truthiness.
    for key in ("denoise_spec_avg", "quiet"):
        assert config[key] is True

    disabled_flags = (
        "max_scale_spec",
        "scale_raw_audio",
        "return_raw_preds",
        "cnn_features",
        "spec_features",
        "spec_slices",
    )
    for key in disabled_flags:
        assert config[key] is False

    assert config["max_duration"] is None
    assert len(config["class_names"]) == 0
|
||||
|
||||
|
||||
def test_process_file_with_model():
    """Check the result dictionary produced by ``process_file``."""
    model, params = load_model()
    config = get_config(**params)
    audio_path = TEST_DATA[0]
    predictions = process_file(audio_path, model, config=config)

    assert predictions is not None
    assert isinstance(predictions, dict)

    feature_keys = (
        "spec_feats",
        "spec_feat_names",
        "cnn_feats",
        "cnn_feat_names",
        "spec_slices",
    )
    for key in ("pred_dict",) + feature_keys:
        assert key in predictions

    # By default no spectrogram or CNN features are returned.
    for key in feature_keys:
        assert predictions[key] is None

    # Check that predictions are returned.
    pred_dict = predictions["pred_dict"]
    assert isinstance(pred_dict, dict)
    assert pred_dict["id"] == os.path.basename(audio_path)
    for key in ("annotated", "issues"):
        assert pred_dict[key] is False
    assert pred_dict["notes"] == "Automatically generated."
    assert pred_dict["time_exp"] == 1
    assert pred_dict["duration"] == 0.5
    assert pred_dict["class_name"] is not None
    assert len(pred_dict["annotation"]) > 0
|
||||
|
||||
|
||||
def test_process_spectrogram_with_model():
    """Check predictions and features returned by ``process_spectrogram``."""
    model, params = load_model()
    config = get_config(**params)
    rate, samples = load_audio(TEST_DATA[0])
    spec = generate_spectrogram(samples, rate)

    predictions, features = process_spectrogram(
        spec, rate, model, config=config
    )

    assert predictions is not None
    assert isinstance(predictions, list)
    assert len(predictions) > 0

    # Inspect the first prediction only.
    first = predictions[0]
    assert isinstance(first, dict)
    prediction_fields = (
        "class",
        "class_prob",
        "det_prob",
        "start_time",
        "end_time",
        "low_freq",
        "high_freq",
    )
    for field in prediction_fields:
        assert field in first

    assert features is not None
    assert isinstance(features, list)
    assert len(features) == 1
|
||||
|
||||
|
||||
def test_process_audio_with_model():
    """Check predictions, features and spectrogram from ``process_audio``."""
    model, params = load_model()
    config = get_config(**params)
    rate, samples = load_audio(TEST_DATA[0])

    predictions, features, spec = process_audio(
        samples, rate, model, config=config
    )

    assert predictions is not None
    assert isinstance(predictions, list)
    assert len(predictions) > 0

    # Inspect the first prediction only.
    first = predictions[0]
    assert isinstance(first, dict)
    prediction_fields = (
        "class",
        "class_prob",
        "det_prob",
        "start_time",
        "end_time",
        "low_freq",
        "high_freq",
    )
    for field in prediction_fields:
        assert field in first

    assert features is not None
    assert isinstance(features, list)
    assert len(features) == 1

    assert spec is not None
    assert isinstance(spec, torch.Tensor)
    expected_shape = (1, 1, 128, 512)
    assert spec.shape == expected_shape
|
@ -1,211 +0,0 @@
|
||||
"""Test bat detect module API."""
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from bat_detect.api import (
|
||||
generate_spectrogram,
|
||||
get_config,
|
||||
list_audio_files,
|
||||
load_audio,
|
||||
load_model,
|
||||
process_audio,
|
||||
process_file,
|
||||
process_spectrogram,
|
||||
)
|
||||
|
||||
# Paths are resolved relative to this file so the tests do not depend on the
# current working directory.
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
# glob yields matches in an arbitrary, filesystem-dependent order. Sort them so
# that TEST_DATA[0] refers to the same recording on every machine.
TEST_DATA = sorted(glob(os.path.join(TEST_DATA_DIR, "*.wav")))
|
||||
|
||||
|
||||
def test_load_model_with_default_params():
    """Check the model and parameters returned by a bare ``load_model()``."""
    model, params = load_model()

    assert model is not None
    assert isinstance(model, nn.Module)

    assert params is not None
    assert isinstance(params, dict)

    # Every one of these keys must be present in the parameter dict.
    required_keys = (
        "model_name",
        "num_filters",
        "emb_dim",
        "ip_height",
        "resize_factor",
        "class_names",
    )
    for key in required_keys:
        assert key in params

    # Values expected for the default model.
    expected_values = {
        "model_name": "Net2DFast",
        "num_filters": 128,
        "emb_dim": 0,
        "ip_height": 128,
        "resize_factor": 0.5,
    }
    for key, expected in expected_values.items():
        assert params[key] == expected

    assert len(params["class_names"]) == 17
|
||||
|
||||
|
||||
def test_list_audio_files():
    """Check that the example audio directory lists exactly three WAV files."""
    found = list_audio_files(TEST_DATA_DIR)

    assert len(found) == 3

    wav_suffixes = (".wav", ".WAV")
    for path in found:
        assert path.endswith(wav_suffixes)
|
||||
|
||||
|
||||
def test_load_audio():
    """Check the sample rate and array returned for the first test file."""
    rate, samples = load_audio(TEST_DATA[0])

    assert samples is not None
    assert rate == 256000
    assert isinstance(samples, np.ndarray)

    expected_shape = (128000,)
    assert samples.shape == expected_shape
|
||||
|
||||
|
||||
def test_generate_spectrogram():
    """Check type and shape of the spectrogram built from the first test file."""
    rate, samples = load_audio(TEST_DATA[0])
    spec = generate_spectrogram(samples, rate)

    assert spec is not None
    assert isinstance(spec, torch.Tensor)

    expected_shape = (1, 1, 128, 512)
    assert spec.shape == expected_shape
|
||||
|
||||
|
||||
def test_get_default_config():
    """Check every value of the configuration returned by ``get_config()``."""
    config = get_config()

    assert config is not None
    assert isinstance(config, dict)

    # Options compared by equality.
    expected_values = {
        "target_samp_rate": 256000,
        "fft_win_length": 0.002,
        "fft_overlap": 0.75,
        "resize_factor": 0.5,
        "spec_divide_factor": 32,
        "spec_height": 256,
        "spec_scale": "pcen",
        "detection_threshold": 0.01,
        "time_expansion": 1,
        "top_n": 3,
        "nms_kernel_size": 9,
        "max_freq": 120000,
        "min_freq": 10000,
        "nms_top_k_per_sec": 200,
        "chunk_size": 3,
    }
    for key, expected in expected_values.items():
        assert config[key] == expected

    # Boolean options are checked by identity, not truthiness.
    for key in ("denoise_spec_avg", "quiet"):
        assert config[key] is True

    disabled_flags = (
        "max_scale_spec",
        "scale_raw_audio",
        "return_raw_preds",
        "cnn_features",
        "spec_features",
        "spec_slices",
    )
    for key in disabled_flags:
        assert config[key] is False

    assert config["max_duration"] is None
    assert len(config["class_names"]) == 0
|
||||
|
||||
|
||||
def test_process_file_with_model():
    """Check the result dictionary produced by ``process_file``."""
    model, params = load_model()
    config = get_config(**params)
    audio_path = TEST_DATA[0]
    predictions = process_file(audio_path, model, config=config)

    assert predictions is not None
    assert isinstance(predictions, dict)

    feature_keys = (
        "spec_feats",
        "spec_feat_names",
        "cnn_feats",
        "cnn_feat_names",
        "spec_slices",
    )
    for key in ("pred_dict",) + feature_keys:
        assert key in predictions

    # By default no spectrogram or CNN features are returned.
    for key in feature_keys:
        assert predictions[key] is None

    # Check that predictions are returned.
    pred_dict = predictions["pred_dict"]
    assert isinstance(pred_dict, dict)
    assert pred_dict["id"] == os.path.basename(audio_path)
    for key in ("annotated", "issues"):
        assert pred_dict[key] is False
    assert pred_dict["notes"] == "Automatically generated."
    assert pred_dict["time_exp"] == 1
    assert pred_dict["duration"] == 0.5
    assert pred_dict["class_name"] is not None
    assert len(pred_dict["annotation"]) > 0
|
||||
|
||||
|
||||
def test_process_spectrogram_with_model():
    """Check predictions and features returned by ``process_spectrogram``."""
    model, params = load_model()
    config = get_config(**params)
    rate, samples = load_audio(TEST_DATA[0])
    spec = generate_spectrogram(samples, rate)

    predictions, features = process_spectrogram(
        spec, rate, model, config=config
    )

    assert predictions is not None
    assert isinstance(predictions, list)
    assert len(predictions) > 0

    # Inspect the first prediction only.
    first = predictions[0]
    assert isinstance(first, dict)
    prediction_fields = (
        "class",
        "class_prob",
        "det_prob",
        "start_time",
        "end_time",
        "low_freq",
        "high_freq",
    )
    for field in prediction_fields:
        assert field in first

    assert features is not None
    assert isinstance(features, list)
    assert len(features) == 1
|
||||
|
||||
|
||||
def test_process_audio_with_model():
    """Check predictions, features and spectrogram from ``process_audio``."""
    model, params = load_model()
    config = get_config(**params)
    rate, samples = load_audio(TEST_DATA[0])

    predictions, features, spec = process_audio(
        samples, rate, model, config=config
    )

    assert predictions is not None
    assert isinstance(predictions, list)
    assert len(predictions) > 0

    # Inspect the first prediction only.
    first = predictions[0]
    assert isinstance(first, dict)
    prediction_fields = (
        "class",
        "class_prob",
        "det_prob",
        "start_time",
        "end_time",
        "low_freq",
        "high_freq",
    )
    for field in prediction_fields:
        assert field in first

    assert features is not None
    assert isinstance(features, list)
    assert len(features) == 1

    assert spec is not None
    assert isinstance(spec, torch.Tensor)
    expected_shape = (1, 1, 128, 512)
    assert spec.shape == expected_shape
|
Loading…
Reference in New Issue
Block a user