Moved types to a single module

Santiago Martinez 2023-02-26 14:27:03 +00:00
parent f0b0f28379
commit 6f2bb605d3
10 changed files with 620 additions and 625 deletions

View File

@@ -3,10 +3,19 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 
-import bat_detect.detector.models as md
 import bat_detect.utils.audio_utils as au
 import bat_detect.utils.detector_utils as du
-from bat_detect.detector.parameters import TARGET_SAMPLERATE_HZ
+from bat_detect.detector.parameters import (
+    DEFAULT_PROCESSING_CONFIGURATIONS,
+    DEFAULT_SPECTROGRAM_PARAMETERS,
+    TARGET_SAMPLERATE_HZ,
+)
+from bat_detect.types import (
+    Annotation,
+    DetectionModel,
+    ProcessingConfiguration,
+    SpectrogramParameters,
+)
 from bat_detect.utils.detector_utils import list_audio_files, load_model
 
 # Use GPU if available
@@ -24,12 +33,12 @@ __all__ = [
 ]
 
 
-def get_config(**kwargs) -> du.ProcessingConfiguration:
+def get_config(**kwargs) -> ProcessingConfiguration:
     """Get default processing configuration.
 
     Can be used to override default parameters by passing keyword arguments.
     """
-    return {**du.DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}
+    return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}
 
 
 def load_audio(
@@ -73,7 +82,7 @@ def load_audio(
 def generate_spectrogram(
     audio: np.ndarray,
     samp_rate: int,
-    config: Optional[au.SpectrogramParameters] = None,
+    config: Optional[SpectrogramParameters] = None,
     device: torch.device = DEVICE,
 ) -> torch.Tensor:
     """Generate spectrogram from audio array.
@@ -93,7 +102,7 @@ def generate_spectrogram(
         Spectrogram.
     """
     if config is None:
-        config = au.DEFAULT_SPECTROGRAM_PARAMETERS
+        config = DEFAULT_SPECTROGRAM_PARAMETERS
 
     _, spec, _ = du.compute_spectrogram(
         audio,
@@ -108,8 +117,8 @@ def generate_spectrogram(
 def process_file(
     audio_file: str,
-    model: md.DetectionModel,
-    config: Optional[du.ProcessingConfiguration] = None,
+    model: DetectionModel,
+    config: Optional[ProcessingConfiguration] = None,
     device: torch.device = DEVICE,
 ) -> du.RunResults:
     """Process audio file with model.
@@ -126,7 +135,7 @@ def process_file(
         Device to use, by default tries to use GPU if available.
     """
     if config is None:
-        config = du.DEFAULT_PROCESSING_CONFIGURATIONS
+        config = DEFAULT_PROCESSING_CONFIGURATIONS
 
     return du.process_file(
         audio_file,
@@ -139,9 +148,9 @@ def process_file(
 def process_spectrogram(
     spec: torch.Tensor,
     samp_rate: int,
-    model: md.DetectionModel,
-    config: Optional[du.ProcessingConfiguration] = None,
-) -> Tuple[List[du.Annotation], List[np.ndarray]]:
+    model: DetectionModel,
+    config: Optional[ProcessingConfiguration] = None,
+) -> Tuple[List[Annotation], List[np.ndarray]]:
     """Process spectrogram with model.
 
     Parameters
@@ -160,7 +169,7 @@ def process_spectrogram(
         DetectionResult
     """
     if config is None:
-        config = du.DEFAULT_PROCESSING_CONFIGURATIONS
+        config = DEFAULT_PROCESSING_CONFIGURATIONS
 
     return du.process_spectrogram(
         spec,
@@ -173,10 +182,10 @@ def process_spectrogram(
 def process_audio(
     audio: np.ndarray,
     samp_rate: int,
-    model: md.DetectionModel,
-    config: Optional[du.ProcessingConfiguration] = None,
+    model: DetectionModel,
+    config: Optional[ProcessingConfiguration] = None,
     device: torch.device = DEVICE,
-) -> Tuple[List[du.Annotation], List[np.ndarray], torch.Tensor]:
+) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
     """Process audio array with model.
 
     Parameters
@@ -204,7 +213,7 @@ def process_audio(
         Spectrogram of the audio used for prediction.
     """
     if config is None:
-        config = du.DEFAULT_PROCESSING_CONFIGURATIONS
+        config = DEFAULT_PROCESSING_CONFIGURATIONS
 
     return du.process_audio_array(
         audio,
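A brief sketch of the refactored api surface may help here (not part of the commit). The override semantics of get_config follow directly from the dict merge in its body; the audio path below is hypothetical.

# Sketch: using bat_detect.api after this change. "example.wav" is a
# hypothetical path; keyword arguments to get_config override the defaults
# because later entries win in a dict merge.
from bat_detect import api

config = api.get_config(detection_threshold=0.3, quiet=False)
assert config["detection_threshold"] == 0.3  # overrides the default

model, params = api.load_model()  # bundled default weights
results = api.process_file("example.wav", model, config=config)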

View File

@ -7,7 +7,7 @@ Example usage:
import argparse import argparse
import os import os
import bat_detect.utils.detector_utils as du from bat_detect import api
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

View File

@@ -1,5 +1,3 @@
-from typing import NamedTuple, Optional
-
 import torch
 import torch.fft
 import torch.nn.functional as F
@@ -12,86 +10,15 @@ from bat_detect.detector.model_helpers import (
     ConvBlockUpStandard,
     SelfAttention,
 )
+from bat_detect.types import ModelOutput
 
-try:
-    from typing import Protocol
-except ImportError:
-    from typing_extensions import Protocol
-
 __all__ = [
     "Net2DFast",
     "Net2DFastNoAttn",
     "Net2DFastNoCoordConv",
-    "ModelOutput",
-    "DetectionModel",
 ]
 
 
-class ModelOutput(NamedTuple):
-    """Output of the detection model."""
-
-    pred_det: torch.Tensor
-    """Tensor with predicted detection probabilities."""
-
-    pred_size: torch.Tensor
-    """Tensor with predicted bounding box sizes."""
-
-    pred_class: torch.Tensor
-    """Tensor with predicted class probabilities."""
-
-    pred_class_un_norm: torch.Tensor
-    """Tensor with predicted class probabilities before softmax."""
-
-    features: torch.Tensor
-    """Tensor with intermediate features."""
-
-
-class DetectionModel(Protocol):
-    """Protocol for detection models.
-
-    This protocol is used to define the interface for the detection models.
-    This allows us to use the same code for training and inference, even
-    though the models are different.
-    """
-
-    num_classes: int
-    """Number of classes the model can classify."""
-
-    emb_dim: int
-    """Dimension of the embedding vector."""
-
-    num_filts: int
-    """Number of filters in the model."""
-
-    resize_factor: float
-    """Factor by which the input is resized."""
-
-    ip_height: int
-    """Height of the input image."""
-
-    def forward(
-        self,
-        ip: torch.Tensor,
-        return_feats: bool = False,
-    ) -> ModelOutput:
-        """Forward pass of the model.
-
-        When `return_feats` is `True`, the model should return the
-        intermediate features of the model.
-        """
-
-    def __call__(
-        self,
-        ip: torch.Tensor,
-        return_feats: bool = False,
-    ) -> ModelOutput:
-        """Forward pass of the model.
-
-        When `return_feats` is `True`, the model should return the
-        intermediate features of the model.
-        """
-
-
 class Net2DFast(nn.Module):
     def __init__(
         self,
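Since ModelOutput now lives in bat_detect.types, a short sketch of how a NamedTuple output behaves (not part of the commit; the tensor shapes are illustrative assumptions, not values taken from the code):

# Sketch: ModelOutput is a NamedTuple, so results support both positional
# unpacking and named field access. Shapes below are made up.
import torch

from bat_detect.types import ModelOutput

out = ModelOutput(
    pred_det=torch.rand(1, 1, 128, 512),
    pred_size=torch.rand(1, 2, 128, 512),
    pred_class=torch.rand(1, 17, 128, 512),
    pred_class_un_norm=torch.rand(1, 17, 128, 512),
    features=torch.rand(1, 32, 128, 512),
)
pred_det, pred_size, *_ = out    # tuple-style unpacking
print(out.pred_class.shape)      # access by field name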

View File

@@ -1,6 +1,11 @@
 import datetime
 import os
 
+from bat_detect.types import (
+    ProcessingConfiguration,
+    SpectrogramParameters,
+)
+
 TARGET_SAMPLERATE_HZ = 256000
 FFT_WIN_LENGTH_S = 512 / 256000.0
 FFT_OVERLAP = 0.75
@@ -18,6 +23,56 @@ DENOISE_SPEC_AVG = True
 MAX_SCALE_SPEC = False
 
+DEFAULT_MODEL_PATH = os.path.join(
+    os.path.dirname(os.path.dirname(__file__)),
+    "models",
+    "Net2DFast_UK_same.pth.tar",
+)
+
+DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
+    "fft_win_length": FFT_WIN_LENGTH_S,
+    "fft_overlap": FFT_OVERLAP,
+    "spec_height": SPEC_HEIGHT,
+    "resize_factor": RESIZE_FACTOR,
+    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
+    "max_freq": MAX_FREQ_HZ,
+    "min_freq": MIN_FREQ_HZ,
+    "spec_scale": SPEC_SCALE,
+    "denoise_spec_avg": DENOISE_SPEC_AVG,
+    "max_scale_spec": MAX_SCALE_SPEC,
+}
+
+DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
+    "detection_threshold": DETECTION_THRESHOLD,
+    "spec_slices": False,
+    "chunk_size": 3,
+    "spec_features": False,
+    "cnn_features": False,
+    "quiet": True,
+    "target_samp_rate": TARGET_SAMPLERATE_HZ,
+    "fft_win_length": FFT_WIN_LENGTH_S,
+    "fft_overlap": FFT_OVERLAP,
+    "resize_factor": RESIZE_FACTOR,
+    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
+    "spec_height": SPEC_HEIGHT,
+    "scale_raw_audio": SCALE_RAW_AUDIO,
+    "class_names": [],
+    "time_expansion": 1,
+    "top_n": 3,
+    "return_raw_preds": False,
+    "max_duration": None,
+    "nms_kernel_size": NMS_KERNEL_SIZE,
+    "max_freq": MAX_FREQ_HZ,
+    "min_freq": MIN_FREQ_HZ,
+    "nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
+    "spec_scale": SPEC_SCALE,
+    "denoise_spec_avg": DENOISE_SPEC_AVG,
+    "max_scale_spec": MAX_SCALE_SPEC,
+}
+
+
 def mk_dir(path):
     if not os.path.isdir(path):
         os.makedirs(path)
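The relocated DEFAULT_MODEL_PATH climbs two directory levels up from parameters.py; a small sketch of how that join resolves, using a hypothetical install location and POSIX paths (not part of the commit):

# Sketch: how DEFAULT_MODEL_PATH resolves. The prefix below is a
# hypothetical stand-in for __file__ at bat_detect/detector/parameters.py.
import os

file_path = "/opt/venv/bat_detect/detector/parameters.py"
model_path = os.path.join(
    os.path.dirname(os.path.dirname(file_path)),  # -> /opt/venv/bat_detect
    "models",
    "Net2DFast_UK_same.pth.tar",
)
assert model_path == "/opt/venv/bat_detect/models/Net2DFast_UK_same.pth.tar"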

View File

@@ -1,24 +1,18 @@
 import argparse
 import json
-import os
-import sys
+import warnings
 
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-import torch.nn.functional as F
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-sys.path.append(os.path.join("..", ".."))
-
-import warnings
-
-import bat_detect.detector.models as models
-import bat_detect.detector.parameters as parameters
+from bat_detect.detector import models
+from bat_detect.detector import parameters
+from bat_detect.train import losses
 import bat_detect.detector.post_process as pp
 import bat_detect.train.audio_dataloader as adl
 import bat_detect.train.evaluate as evl
-import bat_detect.train.losses as losses
 import bat_detect.train.train_split as ts
 import bat_detect.train.train_utils as tu
 import bat_detect.utils.plot_utils as pu

bat_detect/types.py (new file, +304 lines)
View File

@@ -0,0 +1,304 @@
+"""Types used in the code base."""
+from typing import List, Optional, NamedTuple
+
+import numpy as np
+import torch
+
+try:
+    from typing import TypedDict
+except ImportError:
+    from typing_extensions import TypedDict
+
+try:
+    from typing import Protocol
+except ImportError:
+    from typing_extensions import Protocol
+
+
+class SpectrogramParameters(TypedDict):
+    """Parameters for generating spectrograms."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Percentage of overlap between FFT windows."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    resize_factor: float
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: int
+    """Factor to divide the spectrogram by."""
+
+    max_freq: int
+    """Maximum frequency to display in the spectrogram."""
+
+    min_freq: int
+    """Minimum frequency to display in the spectrogram."""
+
+    spec_scale: str
+    """Scale to use for the spectrogram."""
+
+    denoise_spec_avg: bool
+    """Whether to denoise the spectrogram by averaging."""
+
+    max_scale_spec: bool
+    """Whether to scale the spectrogram so that its max is 1."""
+
+
+class ModelParameters(TypedDict):
+    """Model parameters."""
+
+    model_name: str
+    """Model name."""
+
+    num_filters: int
+    """Number of filters."""
+
+    emb_dim: int
+    """Embedding dimension."""
+
+    ip_height: int
+    """Input height in pixels."""
+
+    resize_factor: float
+    """Resize factor."""
+
+    class_names: List[str]
+    """Class names. The model is trained to detect these classes."""
+
+
+DictWithClass = TypedDict("DictWithClass", {"class": str})
+
+
+class Annotation(DictWithClass):
+    """Format of annotations.
+
+    This is the format of a single annotation as expected by the annotation
+    tool.
+    """
+
+    start_time: float
+    """Start time in seconds."""
+
+    end_time: float
+    """End time in seconds."""
+
+    low_freq: int
+    """Low frequency in Hz."""
+
+    high_freq: int
+    """High frequency in Hz."""
+
+    class_prob: float
+    """Probability of class assignment."""
+
+    det_prob: float
+    """Probability of detection."""
+
+    individual: str
+    """Individual ID."""
+
+    event: str
+    """Type of detected event."""
+
+
+class FileAnnotations(TypedDict):
+    """Format of results.
+
+    This is the format of the results expected by the annotation tool.
+    """
+
+    id: str
+    """File ID."""
+
+    annotated: bool
+    """Whether file has been annotated."""
+
+    duration: float
+    """Duration of audio file."""
+
+    issues: bool
+    """Whether file has issues."""
+
+    time_exp: float
+    """Time expansion factor."""
+
+    class_name: str
+    """Class predicted at file level."""
+
+    notes: str
+    """Notes of file."""
+
+    annotation: List[Annotation]
+    """List of annotations."""
+
+
+class RunResults(TypedDict):
+    """Run results."""
+
+    pred_dict: FileAnnotations
+    """Predictions in the format expected by the annotation tool."""
+
+    spec_feats: Optional[List[np.ndarray]]
+    """Spectrogram features."""
+
+    spec_feat_names: Optional[List[str]]
+    """Spectrogram feature names."""
+
+    cnn_feats: Optional[List[np.ndarray]]
+    """CNN features."""
+
+    cnn_feat_names: Optional[List[str]]
+    """CNN feature names."""
+
+    spec_slices: Optional[List[np.ndarray]]
+    """Spectrogram slices."""
+
+
+class ResultParams(TypedDict):
+    """Result parameters."""
+
+    class_names: List[str]
+    """Class names."""
+
+
+class ProcessingConfiguration(TypedDict):
+    """Parameters for processing audio files."""
+
+    # audio parameters
+    target_samp_rate: int
+    """Target sampling rate of the audio."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Percentage of overlap between FFT windows."""
+
+    resize_factor: float
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: int
+    """Factor to divide the spectrogram by."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    spec_scale: str
+    """Scale to use for the spectrogram."""
+
+    denoise_spec_avg: bool
+    """Whether to denoise the spectrogram by averaging."""
+
+    max_scale_spec: bool
+    """Whether to scale the spectrogram so that its max is 1."""
+
+    scale_raw_audio: bool
+    """Whether to scale the raw audio to be between -1 and 1."""
+
+    class_names: List[str]
+    """Names of the classes the model can detect."""
+
+    detection_threshold: float
+    """Threshold for detection probability."""
+
+    time_expansion: Optional[float]
+    """Time expansion factor of the processed recordings."""
+
+    top_n: int
+    """Number of top detections to keep."""
+
+    return_raw_preds: bool
+    """Whether to return raw predictions."""
+
+    max_duration: Optional[float]
+    """Maximum duration of audio file to process in seconds."""
+
+    nms_kernel_size: int
+    """Size of the kernel for non-maximum suppression."""
+
+    max_freq: int
+    """Maximum frequency to consider in Hz."""
+
+    min_freq: int
+    """Minimum frequency to consider in Hz."""
+
+    nms_top_k_per_sec: float
+    """Number of top detections to keep per second."""
+
+    quiet: bool
+    """Whether to suppress output."""
+
+    chunk_size: float
+    """Size of chunks to process in seconds."""
+
+    cnn_features: bool
+    """Whether to return CNN features."""
+
+    spec_features: bool
+    """Whether to return spectrogram features."""
+
+    spec_slices: bool
+    """Whether to return spectrogram slices."""
+
+
+class ModelOutput(NamedTuple):
+    """Output of the detection model."""
+
+    pred_det: torch.Tensor
+    """Tensor with predicted detection probabilities."""
+
+    pred_size: torch.Tensor
+    """Tensor with predicted bounding box sizes."""
+
+    pred_class: torch.Tensor
+    """Tensor with predicted class probabilities."""
+
+    pred_class_un_norm: torch.Tensor
+    """Tensor with predicted class probabilities before softmax."""
+
+    features: torch.Tensor
+    """Tensor with intermediate features."""
+
+
+class DetectionModel(Protocol):
+    """Protocol for detection models.
+
+    This protocol is used to define the interface for the detection models.
+    This allows us to use the same code for training and inference, even
+    though the models are different.
+    """
+
+    num_classes: int
+    """Number of classes the model can classify."""
+
+    emb_dim: int
+    """Dimension of the embedding vector."""
+
+    num_filts: int
+    """Number of filters in the model."""
+
+    resize_factor: float
+    """Factor by which the input is resized."""
+
+    ip_height: int
+    """Height of the input image."""
+
+    def forward(
+        self,
+        spec: torch.Tensor,
+        return_feats: bool = False,
+    ) -> ModelOutput:
+        """Forward pass of the model."""
+
+    def __call__(
+        self,
+        spec: torch.Tensor,
+        return_feats: bool = False,
+    ) -> ModelOutput:
+        """Forward pass of the model."""

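One detail from the new module worth noting: Annotation needs a "class" key, but class is a reserved word in Python, so the functional TypedDict form is used. A self-contained sketch of the same trick (MiniAnnotation is a hypothetical cut-down version of the real Annotation type):

# Sketch: the functional TypedDict syntax allows a key named "class",
# which cannot be declared as an annotation in a class body.
from typing import TypedDict

DictWithClass = TypedDict("DictWithClass", {"class": str})

class MiniAnnotation(DictWithClass):
    start_time: float
    end_time: float

ann: MiniAnnotation = {"class": "Myotis", "start_time": 0.10, "end_time": 0.15}
print(ann["class"])  # dictionary access works; attribute access would not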
View File

@@ -38,54 +38,6 @@ __all__ = [
 ]
 
 
-class SpectrogramParameters(TypedDict):
-    """Parameters for generating spectrograms."""
-
-    fft_win_length: float
-    """Length of the FFT window in seconds."""
-
-    fft_overlap: float
-    """Percentage of overlap between FFT windows."""
-
-    spec_height: int
-    """Height of the spectrogram in pixels."""
-
-    resize_factor: float
-    """Factor to resize the spectrogram by."""
-
-    spec_divide_factor: int
-    """Factor to divide the spectrogram by."""
-
-    max_freq: int
-    """Maximum frequency to display in the spectrogram."""
-
-    min_freq: int
-    """Minimum frequency to display in the spectrogram."""
-
-    spec_scale: str
-    """Scale to use for the spectrogram."""
-
-    denoise_spec_avg: bool
-    """Whether to denoise the spectrogram by averaging."""
-
-    max_scale_spec: bool
-    """Whether to scale the spectrogram so that its max is 1."""
-
-
-DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
-    "fft_win_length": FFT_WIN_LENGTH_S,
-    "fft_overlap": FFT_OVERLAP,
-    "spec_height": SPEC_HEIGHT,
-    "resize_factor": RESIZE_FACTOR,
-    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
-    "max_freq": MAX_FREQ_HZ,
-    "min_freq": MIN_FREQ_HZ,
-    "spec_scale": SPEC_SCALE,
-    "denoise_spec_avg": DENOISE_SPEC_AVG,
-    "max_scale_spec": MAX_SCALE_SPEC,
-}
-
-
 def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
     nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
     noverlap = np.floor(fft_overlap * nfft)

View File

@@ -11,34 +11,16 @@ import bat_detect.detector.compute_features as feats
 import bat_detect.detector.post_process as pp
 import bat_detect.utils.audio_utils as au
 from bat_detect.detector import models
-from bat_detect.detector.parameters import (
-    DENOISE_SPEC_AVG,
-    DETECTION_THRESHOLD,
-    FFT_OVERLAP,
-    FFT_WIN_LENGTH_S,
-    MAX_FREQ_HZ,
-    MAX_SCALE_SPEC,
-    MIN_FREQ_HZ,
-    NMS_KERNEL_SIZE,
-    NMS_TOP_K_PER_SEC,
-    RESIZE_FACTOR,
-    SCALE_RAW_AUDIO,
-    SPEC_DIVIDE_FACTOR,
-    SPEC_HEIGHT,
-    SPEC_SCALE,
-    TARGET_SAMPLERATE_HZ,
-)
-
-try:
-    from typing import TypedDict
-except ImportError:
-    from typing_extensions import TypedDict
-
-DEFAULT_MODEL_PATH = os.path.join(
-    os.path.dirname(os.path.dirname(__file__)),
-    "models",
-    "Net2DFast_UK_same.pth.tar",
-)
+from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
+from bat_detect.types import (
+    Annotation,
+    FileAnnotations,
+    ModelParameters,
+    ProcessingConfiguration,
+    SpectrogramParameters,
+    ResultParams,
+    RunResults,
+    DetectionModel,
+)
 
 __all__ = [
@@ -50,8 +32,6 @@ __all__ = [
     "process_spectrogram",
     "process_audio_array",
     "process_file",
-    "DEFAULT_MODEL_PATH",
-    "DEFAULT_PROCESSING_CONFIGURATIONS",
 ]
@@ -77,33 +57,11 @@ def list_audio_files(ip_dir: str) -> List[str]:
     return matches
 
 
-class ModelParameters(TypedDict):
-    """Model parameters."""
-
-    model_name: str
-    """Model name."""
-
-    num_filters: int
-    """Number of filters."""
-
-    emb_dim: int
-    """Embedding dimension."""
-
-    ip_height: int
-    """Input height in pixels."""
-
-    resize_factor: float
-    """Resize factor."""
-
-    class_names: List[str]
-    """Class names. The model is trained to detect these classes."""
-
-
 def load_model(
     model_path: str = DEFAULT_MODEL_PATH,
     load_weights: bool = True,
     device: Optional[torch.device] = None,
-) -> Tuple[models.DetectionModel, ModelParameters]:
+) -> Tuple[DetectionModel, ModelParameters]:
     """Load model from file.
 
     Args:
@@ -129,7 +87,7 @@ def load_model(
     params = net_params["params"]
 
-    model: models.DetectionModel
+    model: DetectionModel
 
     if params["model_name"] == "Net2DFast":
         model = models.Net2DFast(
@@ -189,101 +147,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
     return predictions_m, spec_feats, cnn_feats, spec_slices
 
 
-DictWithClass = TypedDict("DictWithClass", {"class": str})
-
-
-class Annotation(DictWithClass):
-    """Format of annotations.
-
-    This is the format of a single annotation as expected by the annotation
-    tool.
-    """
-
-    start_time: float
-    """Start time in seconds."""
-
-    end_time: float
-    """End time in seconds."""
-
-    low_freq: int
-    """Low frequency in Hz."""
-
-    high_freq: int
-    """High frequency in Hz."""
-
-    class_prob: float
-    """Probability of class assignment."""
-
-    det_prob: float
-    """Probability of detection."""
-
-    individual: str
-    """Individual ID."""
-
-    event: str
-    """Type of detected event."""
-
-
-class FileAnnotations(TypedDict):
-    """Format of results.
-
-    This is the format of the results expected by the annotation tool.
-    """
-
-    id: str
-    """File ID."""
-
-    annotated: bool
-    """Whether file has been annotated."""
-
-    duration: float
-    """Duration of audio file."""
-
-    issues: bool
-    """Whether file has issues."""
-
-    time_exp: float
-    """Time expansion factor."""
-
-    class_name: str
-    """Class predicted at file level."""
-
-    notes: str
-    """Notes of file."""
-
-    annotation: List[Annotation]
-    """List of annotations."""
-
-
-class RunResults(TypedDict):
-    """Run results."""
-
-    pred_dict: FileAnnotations
-    """Predictions in the format expected by the annotation tool."""
-
-    spec_feats: Optional[List[np.ndarray]]
-    """Spectrogram features."""
-
-    spec_feat_names: Optional[List[str]]
-    """Spectrogram feature names."""
-
-    cnn_feats: Optional[List[np.ndarray]]
-    """CNN features."""
-
-    cnn_feat_names: Optional[List[str]]
-    """CNN feature names."""
-
-    spec_slices: Optional[List[np.ndarray]]
-    """Spectrogram slices."""
-
-
-class ResultParams(TypedDict):
-    """Result parameters."""
-
-    class_names: List[str]
-    """Class names."""
-
-
 def get_annotations_from_preds(
     predictions,
     class_names: List[str],
@@ -499,7 +362,7 @@ def save_results_to_file(results, op_path: str) -> None:
 def compute_spectrogram(
     audio: np.ndarray,
     sampling_rate: int,
-    params: au.SpectrogramParameters,
+    params: SpectrogramParameters,
     device: torch.device,
     return_np: bool = False,
 ) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
@@ -608,90 +471,10 @@ def iterate_over_chunks(
         yield chunk_start, audio[start_sample:end_sample]
 
 
-class ProcessingConfiguration(TypedDict):
-    """Parameters for processing audio files."""
-
-    # audio parameters
-    target_samp_rate: int
-    """Target sampling rate of the audio."""
-
-    fft_win_length: float
-    """Length of the FFT window in seconds."""
-
-    fft_overlap: float
-    """Percentage of overlap between FFT windows."""
-
-    resize_factor: float
-    """Factor to resize the spectrogram by."""
-
-    spec_divide_factor: int
-    """Factor to divide the spectrogram by."""
-
-    spec_height: int
-    """Height of the spectrogram in pixels."""
-
-    spec_scale: str
-    """Scale to use for the spectrogram."""
-
-    denoise_spec_avg: bool
-    """Whether to denoise the spectrogram by averaging."""
-
-    max_scale_spec: bool
-    """Whether to scale the spectrogram so that its max is 1."""
-
-    scale_raw_audio: bool
-    """Whether to scale the raw audio to be between -1 and 1."""
-
-    class_names: List[str]
-    """Names of the classes the model can detect."""
-
-    detection_threshold: float
-    """Threshold for detection probability."""
-
-    time_expansion: Optional[float]
-    """Time expansion factor of the processed recordings."""
-
-    top_n: int
-    """Number of top detections to keep."""
-
-    return_raw_preds: bool
-    """Whether to return raw predictions."""
-
-    max_duration: Optional[float]
-    """Maximum duration of audio file to process in seconds."""
-
-    nms_kernel_size: int
-    """Size of the kernel for non-maximum suppression."""
-
-    max_freq: int
-    """Maximum frequency to consider in Hz."""
-
-    min_freq: int
-    """Minimum frequency to consider in Hz."""
-
-    nms_top_k_per_sec: float
-    """Number of top detections to keep per second."""
-
-    quiet: bool
-    """Whether to suppress output."""
-
-    chunk_size: float
-    """Size of chunks to process in seconds."""
-
-    cnn_features: bool
-    """Whether to return CNN features."""
-
-    spec_features: bool
-    """Whether to return spectrogram features."""
-
-    spec_slices: bool
-    """Whether to return spectrogram slices."""
-
-
 def _process_spectrogram(
     spec: torch.Tensor,
     samplerate: int,
-    model: models.DetectionModel,
+    model: DetectionModel,
     config: ProcessingConfiguration,
 ) -> Tuple[List[Annotation], List[np.ndarray]]:
     # evaluate model
@@ -730,7 +513,7 @@ def _process_spectrogram(
 def process_spectrogram(
     spec: torch.Tensor,
     samplerate: int,
-    model: models.DetectionModel,
+    model: DetectionModel,
     config: ProcessingConfiguration,
 ) -> Tuple[List[Annotation], List[np.ndarray]]:
     """Process a spectrogram with detection model.
@@ -775,7 +558,7 @@ def process_spectrogram(
 def _process_audio_array(
     audio: np.ndarray,
     sampling_rate: int,
-    model: torch.nn.Module,
+    model: DetectionModel,
     config: ProcessingConfiguration,
     device: torch.device,
 ) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
@@ -813,7 +596,7 @@ def _process_audio_array(
 def process_audio_array(
     audio: np.ndarray,
     sampling_rate: int,
-    model: torch.nn.Module,
+    model: DetectionModel,
    config: ProcessingConfiguration,
     device: torch.device,
 ) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
@@ -864,7 +647,7 @@ def process_audio_array(
 def process_file(
     audio_file: str,
-    model: torch.nn.Module,
+    model: DetectionModel,
     config: ProcessingConfiguration,
     device: torch.device,
 ) -> Union[RunResults, Any]:
@@ -989,32 +772,3 @@ def summarize_results(results, predictions, config):
             config["class_names"][class_index].ljust(30)
             + str(round(class_overall[class_index], 3))
         )
-
-
-DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
-    "detection_threshold": DETECTION_THRESHOLD,
-    "spec_slices": False,
-    "chunk_size": 3,
-    "spec_features": False,
-    "cnn_features": False,
-    "quiet": True,
-    "target_samp_rate": TARGET_SAMPLERATE_HZ,
-    "fft_win_length": FFT_WIN_LENGTH_S,
-    "fft_overlap": FFT_OVERLAP,
-    "resize_factor": RESIZE_FACTOR,
-    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
-    "spec_height": SPEC_HEIGHT,
-    "scale_raw_audio": SCALE_RAW_AUDIO,
-    "class_names": [],
-    "time_expansion": 1,
-    "top_n": 3,
-    "return_raw_preds": False,
-    "max_duration": None,
-    "nms_kernel_size": NMS_KERNEL_SIZE,
-    "max_freq": MAX_FREQ_HZ,
-    "min_freq": MIN_FREQ_HZ,
-    "nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
-    "spec_scale": SPEC_SCALE,
-    "denoise_spec_avg": DENOISE_SPEC_AVG,
-    "max_scale_spec": MAX_SCALE_SPEC,
-}
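Since compute_spectrogram now takes the shared SpectrogramParameters type, a sketch of a direct call, mirroring what api.generate_spectrogram does internally (not part of the commit; the synthetic audio buffer is an assumption):

# Sketch: calling detector_utils.compute_spectrogram with the shared
# defaults. The random buffer stands in for real audio; 128000 samples
# is 0.5 s at the 256 kHz target rate.
import numpy as np
import torch

import bat_detect.utils.detector_utils as du
from bat_detect.detector.parameters import DEFAULT_SPECTROGRAM_PARAMETERS

audio = np.random.randn(128000).astype(np.float32)
duration, spec, _ = du.compute_spectrogram(
    audio,
    256000,
    DEFAULT_SPECTROGRAM_PARAMETERS,
    torch.device("cpu"),
)
print(duration, spec.shape)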

View File

@@ -0,0 +1,211 @@
+"""Test bat detect module API."""
+import os
+from glob import glob
+
+import numpy as np
+import torch
+from torch import nn
+
+from bat_detect.api import (
+    generate_spectrogram,
+    get_config,
+    list_audio_files,
+    load_audio,
+    load_model,
+    process_audio,
+    process_file,
+    process_spectrogram,
+)
+
+PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
+TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
+
+
+def test_load_model_with_default_params():
+    """Test loading model with default parameters."""
+    model, params = load_model()
+
+    assert model is not None
+    assert isinstance(model, nn.Module)
+
+    assert params is not None
+    assert isinstance(params, dict)
+    assert "model_name" in params
+    assert "num_filters" in params
+    assert "emb_dim" in params
+    assert "ip_height" in params
+    assert "resize_factor" in params
+    assert "class_names" in params
+
+    assert params["model_name"] == "Net2DFast"
+    assert params["num_filters"] == 128
+    assert params["emb_dim"] == 0
+    assert params["ip_height"] == 128
+    assert params["resize_factor"] == 0.5
+    assert len(params["class_names"]) == 17
+
+
+def test_list_audio_files():
+    """Test listing audio files."""
+    audio_files = list_audio_files(TEST_DATA_DIR)
+
+    assert len(audio_files) == 3
+    assert all(path.endswith((".wav", ".WAV")) for path in audio_files)
+
+
+def test_load_audio():
+    """Test loading audio."""
+    samplerate, audio = load_audio(TEST_DATA[0])
+
+    assert audio is not None
+    assert samplerate == 256000
+    assert isinstance(audio, np.ndarray)
+    assert audio.shape == (128000,)
+
+
+def test_generate_spectrogram():
+    """Test generating spectrogram."""
+    samplerate, audio = load_audio(TEST_DATA[0])
+    spectrogram = generate_spectrogram(audio, samplerate)
+
+    assert spectrogram is not None
+    assert isinstance(spectrogram, torch.Tensor)
+    assert spectrogram.shape == (1, 1, 128, 512)
+
+
+def test_get_default_config():
+    """Test getting default configuration."""
+    config = get_config()
+
+    assert config is not None
+    assert isinstance(config, dict)
+
+    assert config["target_samp_rate"] == 256000
+    assert config["fft_win_length"] == 0.002
+    assert config["fft_overlap"] == 0.75
+    assert config["resize_factor"] == 0.5
+    assert config["spec_divide_factor"] == 32
+    assert config["spec_height"] == 256
+    assert config["spec_scale"] == "pcen"
+    assert config["denoise_spec_avg"] is True
+    assert config["max_scale_spec"] is False
+    assert config["scale_raw_audio"] is False
+    assert len(config["class_names"]) == 0
+    assert config["detection_threshold"] == 0.01
+    assert config["time_expansion"] == 1
+    assert config["top_n"] == 3
+    assert config["return_raw_preds"] is False
+    assert config["max_duration"] is None
+    assert config["nms_kernel_size"] == 9
+    assert config["max_freq"] == 120000
+    assert config["min_freq"] == 10000
+    assert config["nms_top_k_per_sec"] == 200
+    assert config["quiet"] is True
+    assert config["chunk_size"] == 3
+    assert config["cnn_features"] is False
+    assert config["spec_features"] is False
+    assert config["spec_slices"] is False
+
+
+def test_process_file_with_model():
+    """Test processing file with model."""
+    model, params = load_model()
+    config = get_config(**params)
+    predictions = process_file(TEST_DATA[0], model, config=config)
+
+    assert predictions is not None
+    assert isinstance(predictions, dict)
+
+    assert "pred_dict" in predictions
+    assert "spec_feats" in predictions
+    assert "spec_feat_names" in predictions
+    assert "cnn_feats" in predictions
+    assert "cnn_feat_names" in predictions
+    assert "spec_slices" in predictions
+
+    # By default will not return spectrogram features
+    assert predictions["spec_feats"] is None
+    assert predictions["spec_feat_names"] is None
+    assert predictions["cnn_feats"] is None
+    assert predictions["cnn_feat_names"] is None
+    assert predictions["spec_slices"] is None
+
+    # Check that predictions are returned
+    assert isinstance(predictions["pred_dict"], dict)
+    pred_dict = predictions["pred_dict"]
+    assert pred_dict["id"] == os.path.basename(TEST_DATA[0])
+    assert pred_dict["annotated"] is False
+    assert pred_dict["issues"] is False
+    assert pred_dict["notes"] == "Automatically generated."
+    assert pred_dict["time_exp"] == 1
+    assert pred_dict["duration"] == 0.5
+    assert pred_dict["class_name"] is not None
+    assert len(pred_dict["annotation"]) > 0
+
+
+def test_process_spectrogram_with_model():
+    """Test processing spectrogram with model."""
+    model, params = load_model()
+    config = get_config(**params)
+    samplerate, audio = load_audio(TEST_DATA[0])
+    spectrogram = generate_spectrogram(audio, samplerate)
+    predictions, features = process_spectrogram(
+        spectrogram,
+        samplerate,
+        model,
+        config=config,
+    )
+
+    assert predictions is not None
+    assert isinstance(predictions, list)
+    assert len(predictions) > 0
+    sample_pred = predictions[0]
+    assert isinstance(sample_pred, dict)
+    assert "class" in sample_pred
+    assert "class_prob" in sample_pred
+    assert "det_prob" in sample_pred
+    assert "start_time" in sample_pred
+    assert "end_time" in sample_pred
+    assert "low_freq" in sample_pred
+    assert "high_freq" in sample_pred
+
+    assert features is not None
+    assert isinstance(features, list)
+    assert len(features) == 1
+
+
+def test_process_audio_with_model():
+    """Test processing audio with model."""
+    model, params = load_model()
+    config = get_config(**params)
+    samplerate, audio = load_audio(TEST_DATA[0])
+    predictions, features, spec = process_audio(
+        audio,
+        samplerate,
+        model,
+        config=config,
+    )
+
+    assert predictions is not None
+    assert isinstance(predictions, list)
+    assert len(predictions) > 0
+    sample_pred = predictions[0]
+    assert isinstance(sample_pred, dict)
+    assert "class" in sample_pred
+    assert "class_prob" in sample_pred
+    assert "det_prob" in sample_pred
+    assert "start_time" in sample_pred
+    assert "end_time" in sample_pred
+    assert "low_freq" in sample_pred
+    assert "high_freq" in sample_pred
+
+    assert features is not None
+    assert isinstance(features, list)
+    assert len(features) == 1
+
+    assert spec is not None
+    assert isinstance(spec, torch.Tensor)
+    assert spec.shape == (1, 1, 128, 512)

View File

@@ -1,211 +0,0 @@
File deleted. The 211 removed lines are identical to the test module added above; the tests were moved to their new location without content changes.