mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00

commit 6f2bb605d3 (parent f0b0f28379)

    Moved types to a single module
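
This commit gathers the type definitions that were previously spread across bat_detect.utils.detector_utils, bat_detect.utils.audio_utils, and bat_detect.detector.models into a new bat_detect/types.py module. A minimal sketch of the resulting import pattern, reconstructed from the hunks below (the old style is shown commented out):

    # Before: types were reached through the implementation modules.
    # import bat_detect.utils.detector_utils as du
    # config: du.ProcessingConfiguration = du.DEFAULT_PROCESSING_CONFIGURATIONS

    # After: shared types live in bat_detect.types; default values move to
    # bat_detect.detector.parameters.
    from bat_detect.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS
    from bat_detect.types import ProcessingConfiguration

    config: ProcessingConfiguration = DEFAULT_PROCESSING_CONFIGURATIONS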
@@ -3,10 +3,19 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 
-import bat_detect.detector.models as md
 import bat_detect.utils.audio_utils as au
 import bat_detect.utils.detector_utils as du
-from bat_detect.detector.parameters import TARGET_SAMPLERATE_HZ
+from bat_detect.detector.parameters import (
+    DEFAULT_PROCESSING_CONFIGURATIONS,
+    DEFAULT_SPECTROGRAM_PARAMETERS,
+    TARGET_SAMPLERATE_HZ,
+)
+from bat_detect.types import (
+    Annotation,
+    DetectionModel,
+    ProcessingConfiguration,
+    SpectrogramParameters,
+)
 from bat_detect.utils.detector_utils import list_audio_files, load_model
 
 # Use GPU if available
@@ -24,12 +33,12 @@ __all__ = [
 ]
 
 
-def get_config(**kwargs) -> du.ProcessingConfiguration:
+def get_config(**kwargs) -> ProcessingConfiguration:
     """Get default processing configuration.
 
     Can be used to override default parameters by passing keyword arguments.
     """
-    return {**du.DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}
+    return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}
 
 
 def load_audio(
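
The merge in get_config means keyword arguments simply shadow the defaults, since later entries win in a dict literal. A small illustration (the threshold value here is arbitrary; the fft_overlap default of 0.75 is asserted in the tests added by this commit):

    from bat_detect.api import get_config

    config = get_config(detection_threshold=0.5, quiet=False)
    assert config["detection_threshold"] == 0.5  # overridden by the caller
    assert config["fft_overlap"] == 0.75         # untouched default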
@@ -73,7 +82,7 @@ def load_audio(
 def generate_spectrogram(
     audio: np.ndarray,
     samp_rate: int,
-    config: Optional[au.SpectrogramParameters] = None,
+    config: Optional[SpectrogramParameters] = None,
     device: torch.device = DEVICE,
 ) -> torch.Tensor:
     """Generate spectrogram from audio array.
@@ -93,7 +102,7 @@ def generate_spectrogram(
         Spectrogram.
     """
     if config is None:
-        config = au.DEFAULT_SPECTROGRAM_PARAMETERS
+        config = DEFAULT_SPECTROGRAM_PARAMETERS
 
     _, spec, _ = du.compute_spectrogram(
         audio,
@@ -108,8 +117,8 @@ def generate_spectrogram(
 
 def process_file(
     audio_file: str,
-    model: md.DetectionModel,
-    config: Optional[du.ProcessingConfiguration] = None,
+    model: DetectionModel,
+    config: Optional[ProcessingConfiguration] = None,
     device: torch.device = DEVICE,
 ) -> du.RunResults:
     """Process audio file with model.
@@ -126,7 +135,7 @@ def process_file(
         Device to use, by default tries to use GPU if available.
     """
     if config is None:
-        config = du.DEFAULT_PROCESSING_CONFIGURATIONS
+        config = DEFAULT_PROCESSING_CONFIGURATIONS
 
     return du.process_file(
         audio_file,
@@ -139,9 +148,9 @@ def process_file(
 def process_spectrogram(
     spec: torch.Tensor,
     samp_rate: int,
-    model: md.DetectionModel,
-    config: Optional[du.ProcessingConfiguration] = None,
-) -> Tuple[List[du.Annotation], List[np.ndarray]]:
+    model: DetectionModel,
+    config: Optional[ProcessingConfiguration] = None,
+) -> Tuple[List[Annotation], List[np.ndarray]]:
     """Process spectrogram with model.
 
     Parameters
@@ -160,7 +169,7 @@ def process_spectrogram(
         DetectionResult
     """
     if config is None:
-        config = du.DEFAULT_PROCESSING_CONFIGURATIONS
+        config = DEFAULT_PROCESSING_CONFIGURATIONS
 
     return du.process_spectrogram(
         spec,
@@ -173,10 +182,10 @@ def process_spectrogram(
 def process_audio(
     audio: np.ndarray,
     samp_rate: int,
-    model: md.DetectionModel,
-    config: Optional[du.ProcessingConfiguration] = None,
+    model: DetectionModel,
+    config: Optional[ProcessingConfiguration] = None,
     device: torch.device = DEVICE,
-) -> Tuple[List[du.Annotation], List[np.ndarray], torch.Tensor]:
+) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
     """Process audio array with model.
 
     Parameters
@@ -204,7 +213,7 @@ def process_audio(
         Spectrogram of the audio used for prediction.
     """
     if config is None:
-        config = du.DEFAULT_PROCESSING_CONFIGURATIONS
+        config = DEFAULT_PROCESSING_CONFIGURATIONS
 
     return du.process_audio_array(
         audio,
@@ -7,7 +7,7 @@ Example usage:
 import argparse
 import os
 
-import bat_detect.utils.detector_utils as du
+from bat_detect import api
 
 CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
 
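
After this swap the example script depends only on the public api module. A hypothetical end-to-end run using the functions this commit touches (the audio path is a placeholder):

    from bat_detect import api

    model, params = api.load_model()   # default Net2DFast checkpoint
    config = api.get_config(**params)
    results = api.process_file("example.wav", model, config=config)
    print(results["pred_dict"]["class_name"])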
@@ -1,5 +1,3 @@
-from typing import NamedTuple, Optional
-
 import torch
 import torch.fft
 import torch.nn.functional as F
@@ -12,86 +10,15 @@ from bat_detect.detector.model_helpers import (
     ConvBlockUpStandard,
     SelfAttention,
 )
-try:
-    from typing import Protocol
-except ImportError:
-    from typing_extensions import Protocol
+from bat_detect.types import ModelOutput
 
 __all__ = [
     "Net2DFast",
     "Net2DFastNoAttn",
     "Net2DFastNoCoordConv",
-    "ModelOutput",
-    "DetectionModel",
 ]
 
 
-class ModelOutput(NamedTuple):
-    """Output of the detection model."""
-
-    pred_det: torch.Tensor
-    """Tensor with predict detection probabilities."""
-
-    pred_size: torch.Tensor
-    """Tensor with predicted bounding box sizes."""
-
-    pred_class: torch.Tensor
-    """Tensor with predicted class probabilities."""
-
-    pred_class_un_norm: torch.Tensor
-    """Tensor with predicted class probabilities before softmax."""
-
-    features: torch.Tensor
-    """Tensor with intermediate features."""
-
-
-class DetectionModel(Protocol):
-    """Protocol for detection models.
-
-    This protocol is used to define the interface for the detection models.
-    This allows us to use the same code for training and inference, even
-    though the models are different.
-    """
-
-    num_classes: int
-    """Number of classes the model can classify."""
-
-    emb_dim: int
-    """Dimension of the embedding vector."""
-
-    num_filts: int
-    """Number of filters in the model."""
-
-    resize_factor: float
-    """Factor by which the input is resized."""
-
-    ip_height: int
-    """Height of the input image."""
-
-    def forward(
-        self,
-        ip: torch.Tensor,
-        return_feats: bool = False,
-    ) -> ModelOutput:
-        """Forward pass of the model.
-
-        When `return_feats` is `True`, the model should return the
-        intermediate features of the model.
-        """
-
-    def __call__(
-        self,
-        ip: torch.Tensor,
-        return_feats: bool = False,
-    ) -> ModelOutput:
-        """Forward pass of the model.
-
-        When `return_feats` is `True`, the model should return the
-        intermediate features of the model.
-        """
-
-
 class Net2DFast(nn.Module):
     def __init__(
         self,
@@ -1,6 +1,11 @@
 import datetime
 import os
 
+from bat_detect.types import (
+    ProcessingConfiguration,
+    SpectrogramParameters,
+)
+
 TARGET_SAMPLERATE_HZ = 256000
 FFT_WIN_LENGTH_S = 512 / 256000.0
 FFT_OVERLAP = 0.75
@@ -18,6 +23,56 @@ DENOISE_SPEC_AVG = True
 MAX_SCALE_SPEC = False
 
 
+DEFAULT_MODEL_PATH = os.path.join(
+    os.path.dirname(os.path.dirname(__file__)),
+    "models",
+    "Net2DFast_UK_same.pth.tar",
+)
+
+DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
+    "fft_win_length": FFT_WIN_LENGTH_S,
+    "fft_overlap": FFT_OVERLAP,
+    "spec_height": SPEC_HEIGHT,
+    "resize_factor": RESIZE_FACTOR,
+    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
+    "max_freq": MAX_FREQ_HZ,
+    "min_freq": MIN_FREQ_HZ,
+    "spec_scale": SPEC_SCALE,
+    "denoise_spec_avg": DENOISE_SPEC_AVG,
+    "max_scale_spec": MAX_SCALE_SPEC,
+}
+
+DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
+    "detection_threshold": DETECTION_THRESHOLD,
+    "spec_slices": False,
+    "chunk_size": 3,
+    "spec_features": False,
+    "cnn_features": False,
+    "quiet": True,
+    "target_samp_rate": TARGET_SAMPLERATE_HZ,
+    "fft_win_length": FFT_WIN_LENGTH_S,
+    "fft_overlap": FFT_OVERLAP,
+    "resize_factor": RESIZE_FACTOR,
+    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
+    "spec_height": SPEC_HEIGHT,
+    "scale_raw_audio": SCALE_RAW_AUDIO,
+    "class_names": [],
+    "time_expansion": 1,
+    "top_n": 3,
+    "return_raw_preds": False,
+    "max_duration": None,
+    "nms_kernel_size": NMS_KERNEL_SIZE,
+    "max_freq": MAX_FREQ_HZ,
+    "min_freq": MIN_FREQ_HZ,
+    "nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
+    "spec_scale": SPEC_SCALE,
+    "denoise_spec_avg": DENOISE_SPEC_AVG,
+    "max_scale_spec": MAX_SCALE_SPEC,
+}
+
+
 def mk_dir(path):
     if not os.path.isdir(path):
         os.makedirs(path)
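
Since DEFAULT_SPECTROGRAM_PARAMETERS is a plain TypedDict value, callers can override individual keys with the same dict-merge idiom used in get_config. A sketch (the spec_height override is hypothetical; the default is SPEC_HEIGHT):

    from bat_detect.detector.parameters import DEFAULT_SPECTROGRAM_PARAMETERS
    from bat_detect.types import SpectrogramParameters

    config: SpectrogramParameters = {
        **DEFAULT_SPECTROGRAM_PARAMETERS,
        "spec_height": 512,  # hypothetical value for illustration
    }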
@@ -1,24 +1,18 @@
 import argparse
 import json
-import os
-import sys
+import warnings
 
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-import torch.nn.functional as F
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-sys.path.append(os.path.join("..", ".."))
-
-import warnings
+from bat_detect.detector import models
+from bat_detect.detector import parameters
+from bat_detect.train import losses
 
-import bat_detect.detector.models as models
-import bat_detect.detector.parameters as parameters
 import bat_detect.detector.post_process as pp
 import bat_detect.train.audio_dataloader as adl
 import bat_detect.train.evaluate as evl
-import bat_detect.train.losses as losses
 import bat_detect.train.train_split as ts
 import bat_detect.train.train_utils as tu
 import bat_detect.utils.plot_utils as pu
bat_detect/types.py (new file, 304 lines)
@@ -0,0 +1,304 @@
+"""Types used in the code base."""
+from typing import List, Optional, NamedTuple
+
+import numpy as np
+import torch
+
+try:
+    from typing import TypedDict
+except ImportError:
+    from typing_extensions import TypedDict
+
+try:
+    from typing import Protocol
+except ImportError:
+    from typing_extensions import Protocol
+
+
+class SpectrogramParameters(TypedDict):
+    """Parameters for generating spectrograms."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Percentage of overlap between FFT windows."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    resize_factor: float
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: int
+    """Factor to divide the spectrogram by."""
+
+    max_freq: int
+    """Maximum frequency to display in the spectrogram."""
+
+    min_freq: int
+    """Minimum frequency to display in the spectrogram."""
+
+    spec_scale: str
+    """Scale to use for the spectrogram."""
+
+    denoise_spec_avg: bool
+    """Whether to denoise the spectrogram by averaging."""
+
+    max_scale_spec: bool
+    """Whether to scale the spectrogram so that its max is 1."""
+
+
+class ModelParameters(TypedDict):
+    """Model parameters."""
+
+    model_name: str
+    """Model name."""
+
+    num_filters: int
+    """Number of filters."""
+
+    emb_dim: int
+    """Embedding dimension."""
+
+    ip_height: int
+    """Input height in pixels."""
+
+    resize_factor: float
+    """Resize factor."""
+
+    class_names: List[str]
+    """Class names. The model is trained to detect these classes."""
+
+
+DictWithClass = TypedDict("DictWithClass", {"class": str})
+
+
+class Annotation(DictWithClass):
+    """Format of annotations.
+
+    This is the format of a single annotation as expected by the annotation
+    tool.
+    """
+
+    start_time: float
+    """Start time in seconds."""
+
+    end_time: float
+    """End time in seconds."""
+
+    low_freq: int
+    """Low frequency in Hz."""
+
+    high_freq: int
+    """High frequency in Hz."""
+
+    class_prob: float
+    """Probability of class assignment."""
+
+    det_prob: float
+    """Probability of detection."""
+
+    individual: str
+    """Individual ID."""
+
+    event: str
+    """Type of detected event."""
+
+
+class FileAnnotations(TypedDict):
+    """Format of results.
+
+    This is the format of the results expected by the annotation tool.
+    """
+
+    id: str
+    """File ID."""
+
+    annotated: bool
+    """Whether file has been annotated."""
+
+    duration: float
+    """Duration of audio file."""
+
+    issues: bool
+    """Whether file has issues."""
+
+    time_exp: float
+    """Time expansion factor."""
+
+    class_name: str
+    """Class predicted at file level."""
+
+    notes: str
+    """Notes of file."""
+
+    annotation: List[Annotation]
+    """List of annotations."""
+
+
+class RunResults(TypedDict):
+    """Run results."""
+
+    pred_dict: FileAnnotations
+    """Predictions in the format expected by the annotation tool."""
+
+    spec_feats: Optional[List[np.ndarray]]
+    """Spectrogram features."""
+
+    spec_feat_names: Optional[List[str]]
+    """Spectrogram feature names."""
+
+    cnn_feats: Optional[List[np.ndarray]]
+    """CNN features."""
+
+    cnn_feat_names: Optional[List[str]]
+    """CNN feature names."""
+
+    spec_slices: Optional[List[np.ndarray]]
+    """Spectrogram slices."""
+
+
+class ResultParams(TypedDict):
+    """Result parameters."""
+
+    class_names: List[str]
+    """Class names."""
+
+
+class ProcessingConfiguration(TypedDict):
+    """Parameters for processing audio files."""
+
+    # audio parameters
+    target_samp_rate: int
+    """Target sampling rate of the audio."""
+
+    fft_win_length: float
+    """Length of the FFT window in seconds."""
+
+    fft_overlap: float
+    """Percentage of overlap between FFT windows."""
+
+    resize_factor: float
+    """Factor to resize the spectrogram by."""
+
+    spec_divide_factor: int
+    """Factor to divide the spectrogram by."""
+
+    spec_height: int
+    """Height of the spectrogram in pixels."""
+
+    spec_scale: str
+    """Scale to use for the spectrogram."""
+
+    denoise_spec_avg: bool
+    """Whether to denoise the spectrogram by averaging."""
+
+    max_scale_spec: bool
+    """Whether to scale the spectrogram so that its max is 1."""
+
+    scale_raw_audio: bool
+    """Whether to scale the raw audio to be between -1 and 1."""
+
+    class_names: List[str]
+    """Names of the classes the model can detect."""
+
+    detection_threshold: float
+    """Threshold for detection probability."""
+
+    time_expansion: Optional[float]
+    """Time expansion factor of the processed recordings."""
+
+    top_n: int
+    """Number of top detections to keep."""
+
+    return_raw_preds: bool
+    """Whether to return raw predictions."""
+
+    max_duration: Optional[float]
+    """Maximum duration of audio file to process in seconds."""
+
+    nms_kernel_size: int
+    """Size of the kernel for non-maximum suppression."""
+
+    max_freq: int
+    """Maximum frequency to consider in Hz."""
+
+    min_freq: int
+    """Minimum frequency to consider in Hz."""
+
+    nms_top_k_per_sec: float
+    """Number of top detections to keep per second."""
+
+    quiet: bool
+    """Whether to suppress output."""
+
+    chunk_size: float
+    """Size of chunks to process in seconds."""
+
+    cnn_features: bool
+    """Whether to return CNN features."""
+
+    spec_features: bool
+    """Whether to return spectrogram features."""
+
+    spec_slices: bool
+    """Whether to return spectrogram slices."""
+
+
+class ModelOutput(NamedTuple):
+    """Output of the detection model."""
+
+    pred_det: torch.Tensor
+    """Tensor with predicted detection probabilities."""
+
+    pred_size: torch.Tensor
+    """Tensor with predicted bounding box sizes."""
+
+    pred_class: torch.Tensor
+    """Tensor with predicted class probabilities."""
+
+    pred_class_un_norm: torch.Tensor
+    """Tensor with predicted class probabilities before softmax."""
+
+    features: torch.Tensor
+    """Tensor with intermediate features."""
+
+
+class DetectionModel(Protocol):
+    """Protocol for detection models.
+
+    This protocol is used to define the interface for the detection models.
+    This allows us to use the same code for training and inference, even
+    though the models are different.
+    """
+
+    num_classes: int
+    """Number of classes the model can classify."""
+
+    emb_dim: int
+    """Dimension of the embedding vector."""
+
+    num_filts: int
+    """Number of filters in the model."""
+
+    resize_factor: float
+    """Factor by which the input is resized."""
+
+    ip_height: int
+    """Height of the input image."""
+
+    def forward(
+        self,
+        spec: torch.Tensor,
+        return_feats: bool = False,
+    ) -> ModelOutput:
+        """Forward pass of the model."""
+
+    def __call__(
+        self,
+        spec: torch.Tensor,
+        return_feats: bool = False,
+    ) -> ModelOutput:
+        """Forward pass of the model."""
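
Because DetectionModel is a Protocol, any nn.Module with the matching attributes and forward signature satisfies the type structurally, with no inheritance required. A hypothetical stub that would pass a structural type check (the attribute values mirror those asserted in the tests added later in this commit):

    import torch
    from torch import nn

    from bat_detect.types import DetectionModel, ModelOutput

    class DummyModel(nn.Module):
        """Hypothetical model satisfying the DetectionModel protocol."""

        num_classes = 17
        emb_dim = 0
        num_filts = 128
        resize_factor = 0.5
        ip_height = 128

        def forward(self, spec: torch.Tensor, return_feats: bool = False) -> ModelOutput:
            # Placeholder outputs with the five fields ModelOutput expects.
            z = torch.zeros(spec.shape[0], 1, 1, 1)
            return ModelOutput(pred_det=z, pred_size=z, pred_class=z,
                               pred_class_un_norm=z, features=z)

    model: DetectionModel = DummyModel()  # accepted without inheriting the protocol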
@@ -38,54 +38,6 @@ __all__ = [
 ]
 
 
-class SpectrogramParameters(TypedDict):
-    """Parameters for generating spectrograms."""
-
-    fft_win_length: float
-    """Length of the FFT window in seconds."""
-
-    fft_overlap: float
-    """Percentage of overlap between FFT windows."""
-
-    spec_height: int
-    """Height of the spectrogram in pixels."""
-
-    resize_factor: float
-    """Factor to resize the spectrogram by."""
-
-    spec_divide_factor: int
-    """Factor to divide the spectrogram by."""
-
-    max_freq: int
-    """Maximum frequency to display in the spectrogram."""
-
-    min_freq: int
-    """Minimum frequency to display in the spectrogram."""
-
-    spec_scale: str
-    """Scale to use for the spectrogram."""
-
-    denoise_spec_avg: bool
-    """Whether to denoise the spectrogram by averaging."""
-
-    max_scale_spec: bool
-    """Whether to scale the spectrogram so that its max is 1."""
-
-
-DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
-    "fft_win_length": FFT_WIN_LENGTH_S,
-    "fft_overlap": FFT_OVERLAP,
-    "spec_height": SPEC_HEIGHT,
-    "resize_factor": RESIZE_FACTOR,
-    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
-    "max_freq": MAX_FREQ_HZ,
-    "min_freq": MIN_FREQ_HZ,
-    "spec_scale": SPEC_SCALE,
-    "denoise_spec_avg": DENOISE_SPEC_AVG,
-    "max_scale_spec": MAX_SCALE_SPEC,
-}
-
-
 def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
     nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
     noverlap = np.floor(fft_overlap * nfft)
@@ -11,34 +11,16 @@ import bat_detect.detector.compute_features as feats
 import bat_detect.detector.post_process as pp
 import bat_detect.utils.audio_utils as au
 from bat_detect.detector import models
-from bat_detect.detector.parameters import (
-    DENOISE_SPEC_AVG,
-    DETECTION_THRESHOLD,
-    FFT_OVERLAP,
-    FFT_WIN_LENGTH_S,
-    MAX_FREQ_HZ,
-    MAX_SCALE_SPEC,
-    MIN_FREQ_HZ,
-    NMS_KERNEL_SIZE,
-    NMS_TOP_K_PER_SEC,
-    RESIZE_FACTOR,
-    SCALE_RAW_AUDIO,
-    SPEC_DIVIDE_FACTOR,
-    SPEC_HEIGHT,
-    SPEC_SCALE,
-    TARGET_SAMPLERATE_HZ,
-)
-
-try:
-    from typing import TypedDict
-except ImportError:
-    from typing_extensions import TypedDict
-
-
-DEFAULT_MODEL_PATH = os.path.join(
-    os.path.dirname(os.path.dirname(__file__)),
-    "models",
-    "Net2DFast_UK_same.pth.tar",
+from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
+from bat_detect.types import (
+    Annotation,
+    FileAnnotations,
+    ModelParameters,
+    ProcessingConfiguration,
+    SpectrogramParameters,
+    ResultParams,
+    RunResults,
+    DetectionModel
 )
 
 __all__ = [
@@ -50,8 +32,6 @@ __all__ = [
     "process_spectrogram",
     "process_audio_array",
     "process_file",
-    "DEFAULT_MODEL_PATH",
-    "DEFAULT_PROCESSING_CONFIGURATIONS",
 ]
 
 
@@ -77,33 +57,11 @@ def list_audio_files(ip_dir: str) -> List[str]:
     return matches
 
 
-class ModelParameters(TypedDict):
-    """Model parameters."""
-
-    model_name: str
-    """Model name."""
-
-    num_filters: int
-    """Number of filters."""
-
-    emb_dim: int
-    """Embedding dimension."""
-
-    ip_height: int
-    """Input height in pixels."""
-
-    resize_factor: float
-    """Resize factor."""
-
-    class_names: List[str]
-    """Class names. The model is trained to detect these classes."""
-
-
 def load_model(
     model_path: str = DEFAULT_MODEL_PATH,
     load_weights: bool = True,
     device: Optional[torch.device] = None,
-) -> Tuple[models.DetectionModel, ModelParameters]:
+) -> Tuple[DetectionModel, ModelParameters]:
     """Load model from file.
 
     Args:
@@ -129,7 +87,7 @@ def load_model(
 
     params = net_params["params"]
 
-    model: models.DetectionModel
+    model: DetectionModel
 
     if params["model_name"] == "Net2DFast":
         model = models.Net2DFast(
@@ -189,101 +147,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
     return predictions_m, spec_feats, cnn_feats, spec_slices
 
 
-DictWithClass = TypedDict("DictWithClass", {"class": str})
-
-
-class Annotation(DictWithClass):
-    """Format of annotations.
-
-    This is the format of a single annotation as expected by the annotation
-    tool.
-    """
-
-    start_time: float
-    """Start time in seconds."""
-
-    end_time: float
-    """End time in seconds."""
-
-    low_freq: int
-    """Low frequency in Hz."""
-
-    high_freq: int
-    """High frequency in Hz."""
-
-    class_prob: float
-    """Probability of class assignment."""
-
-    det_prob: float
-    """Probability of detection."""
-
-    individual: str
-    """Individual ID."""
-
-    event: str
-    """Type of detected event."""
-
-
-class FileAnnotations(TypedDict):
-    """Format of results.
-
-    This is the format of the results expected by the annotation tool.
-    """
-
-    id: str
-    """File ID."""
-
-    annotated: bool
-    """Whether file has been annotated."""
-
-    duration: float
-    """Duration of audio file."""
-
-    issues: bool
-    """Whether file has issues."""
-
-    time_exp: float
-    """Time expansion factor."""
-
-    class_name: str
-    """Class predicted at file level"""
-
-    notes: str
-    """Notes of file."""
-
-    annotation: List[Annotation]
-    """List of annotations."""
-
-
-class RunResults(TypedDict):
-    """Run results."""
-
-    pred_dict: FileAnnotations
-    """Predictions in the format expected by the annotation tool."""
-
-    spec_feats: Optional[List[np.ndarray]]
-    """Spectrogram features."""
-
-    spec_feat_names: Optional[List[str]]
-    """Spectrogram feature names."""
-
-    cnn_feats: Optional[List[np.ndarray]]
-    """CNN features."""
-
-    cnn_feat_names: Optional[List[str]]
-    """CNN feature names."""
-
-    spec_slices: Optional[List[np.ndarray]]
-    """Spectrogram slices."""
-
-
-class ResultParams(TypedDict):
-    """Result parameters."""
-
-    class_names: List[str]
-    """Class names."""
-
-
 def get_annotations_from_preds(
     predictions,
     class_names: List[str],
@@ -499,7 +362,7 @@ def save_results_to_file(results, op_path: str) -> None:
 def compute_spectrogram(
     audio: np.ndarray,
     sampling_rate: int,
-    params: au.SpectrogramParameters,
+    params: SpectrogramParameters,
     device: torch.device,
     return_np: bool = False,
 ) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
@@ -608,90 +471,10 @@ def iterate_over_chunks(
     yield chunk_start, audio[start_sample:end_sample]
 
 
-class ProcessingConfiguration(TypedDict):
-    """Parameters for processing audio files."""
-
-    # audio parameters
-    target_samp_rate: int
-    """Target sampling rate of the audio."""
-
-    fft_win_length: float
-    """Length of the FFT window in seconds."""
-
-    fft_overlap: float
-    """Length of the FFT window in samples."""
-
-    resize_factor: float
-    """Factor to resize the spectrogram by."""
-
-    spec_divide_factor: int
-    """Factor to divide the spectrogram by."""
-
-    spec_height: int
-    """Height of the spectrogram in pixels."""
-
-    spec_scale: str
-    """Scale to use for the spectrogram."""
-
-    denoise_spec_avg: bool
-    """Whether to denoise the spectrogram by averaging."""
-
-    max_scale_spec: bool
-    """Whether to scale the spectrogram so that its max is 1."""
-
-    scale_raw_audio: bool
-    """Whether to scale the raw audio to be between -1 and 1."""
-
-    class_names: List[str]
-    """Names of the classes the model can detect."""
-
-    detection_threshold: float
-    """Threshold for detection probability."""
-
-    time_expansion: Optional[float]
-    """Time expansion factor of the processed recordings."""
-
-    top_n: int
-    """Number of top detections to keep."""
-
-    return_raw_preds: bool
-    """Whether to return raw predictions."""
-
-    max_duration: Optional[float]
-    """Maximum duration of audio file to process in seconds."""
-
-    nms_kernel_size: int
-    """Size of the kernel for non-maximum suppression."""
-
-    max_freq: int
-    """Maximum frequency to consider in Hz."""
-
-    min_freq: int
-    """Minimum frequency to consider in Hz."""
-
-    nms_top_k_per_sec: float
-    """Number of top detections to keep per second."""
-
-    quiet: bool
-    """Whether to suppress output."""
-
-    chunk_size: float
-    """Size of chunks to process in seconds."""
-
-    cnn_features: bool
-    """Whether to return CNN features."""
-
-    spec_features: bool
-    """Whether to return spectrogram features."""
-
-    spec_slices: bool
-    """Whether to return spectrogram slices."""
-
-
 def _process_spectrogram(
     spec: torch.Tensor,
     samplerate: int,
-    model: models.DetectionModel,
+    model: DetectionModel,
     config: ProcessingConfiguration,
 ) -> Tuple[List[Annotation], List[np.ndarray]]:
     # evaluate model
@@ -730,7 +513,7 @@ def _process_spectrogram(
 def process_spectrogram(
     spec: torch.Tensor,
     samplerate: int,
-    model: models.DetectionModel,
+    model: DetectionModel,
     config: ProcessingConfiguration,
 ) -> Tuple[List[Annotation], List[np.ndarray]]:
     """Process a spectrogram with detection model.
@@ -775,7 +558,7 @@ def process_spectrogram(
 def _process_audio_array(
     audio: np.ndarray,
     sampling_rate: int,
-    model: torch.nn.Module,
+    model: DetectionModel,
    config: ProcessingConfiguration,
     device: torch.device,
 ) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
@@ -813,7 +596,7 @@ def _process_audio_array(
 def process_audio_array(
     audio: np.ndarray,
     sampling_rate: int,
-    model: torch.nn.Module,
+    model: DetectionModel,
     config: ProcessingConfiguration,
     device: torch.device,
 ) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
@@ -864,7 +647,7 @@ def process_audio_array(
 
 def process_file(
     audio_file: str,
-    model: torch.nn.Module,
+    model: DetectionModel,
     config: ProcessingConfiguration,
     device: torch.device,
 ) -> Union[RunResults, Any]:
@@ -989,32 +772,3 @@ def summarize_results(results, predictions, config):
             config["class_names"][class_index].ljust(30)
             + str(round(class_overall[class_index], 3))
         )
-
-
-DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
-    "detection_threshold": DETECTION_THRESHOLD,
-    "spec_slices": False,
-    "chunk_size": 3,
-    "spec_features": False,
-    "cnn_features": False,
-    "quiet": True,
-    "target_samp_rate": TARGET_SAMPLERATE_HZ,
-    "fft_win_length": FFT_WIN_LENGTH_S,
-    "fft_overlap": FFT_OVERLAP,
-    "resize_factor": RESIZE_FACTOR,
-    "spec_divide_factor": SPEC_DIVIDE_FACTOR,
-    "spec_height": SPEC_HEIGHT,
-    "scale_raw_audio": SCALE_RAW_AUDIO,
-    "class_names": [],
-    "time_expansion": 1,
-    "top_n": 3,
-    "return_raw_preds": False,
-    "max_duration": None,
-    "nms_kernel_size": NMS_KERNEL_SIZE,
-    "max_freq": MAX_FREQ_HZ,
-    "min_freq": MIN_FREQ_HZ,
-    "nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
-    "spec_scale": SPEC_SCALE,
-    "denoise_spec_avg": DENOISE_SPEC_AVG,
-    "max_scale_spec": MAX_SCALE_SPEC,
-}
@@ -0,0 +1,211 @@
+"""Test bat detect module API."""
+import os
+from glob import glob
+
+import numpy as np
+import torch
+from torch import nn
+
+from bat_detect.api import (
+    generate_spectrogram,
+    get_config,
+    list_audio_files,
+    load_audio,
+    load_model,
+    process_audio,
+    process_file,
+    process_spectrogram,
+)
+
+PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
+TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
+
+
+def test_load_model_with_default_params():
+    """Test loading model with default parameters."""
+    model, params = load_model()
+
+    assert model is not None
+    assert isinstance(model, nn.Module)
+
+    assert params is not None
+    assert isinstance(params, dict)
+
+    assert "model_name" in params
+    assert "num_filters" in params
+    assert "emb_dim" in params
+    assert "ip_height" in params
+    assert "resize_factor" in params
+    assert "class_names" in params
+
+    assert params["model_name"] == "Net2DFast"
+    assert params["num_filters"] == 128
+    assert params["emb_dim"] == 0
+    assert params["ip_height"] == 128
+    assert params["resize_factor"] == 0.5
+    assert len(params["class_names"]) == 17
+
+
+def test_list_audio_files():
+    """Test listing audio files."""
+    audio_files = list_audio_files(TEST_DATA_DIR)
+
+    assert len(audio_files) == 3
+    assert all(path.endswith((".wav", ".WAV")) for path in audio_files)
+
+
+def test_load_audio():
+    """Test loading audio."""
+    samplerate, audio = load_audio(TEST_DATA[0])
+
+    assert audio is not None
+    assert samplerate == 256000
+    assert isinstance(audio, np.ndarray)
+    assert audio.shape == (128000,)
+
+
+def test_generate_spectrogram():
+    """Test generating spectrogram."""
+    samplerate, audio = load_audio(TEST_DATA[0])
+    spectrogram = generate_spectrogram(audio, samplerate)
+
+    assert spectrogram is not None
+    assert isinstance(spectrogram, torch.Tensor)
+    assert spectrogram.shape == (1, 1, 128, 512)
+
+
+def test_get_default_config():
+    """Test getting default configuration."""
+    config = get_config()
+
+    assert config is not None
+    assert isinstance(config, dict)
+
+    assert config["target_samp_rate"] == 256000
+    assert config["fft_win_length"] == 0.002
+    assert config["fft_overlap"] == 0.75
+    assert config["resize_factor"] == 0.5
+    assert config["spec_divide_factor"] == 32
+    assert config["spec_height"] == 256
+    assert config["spec_scale"] == "pcen"
+    assert config["denoise_spec_avg"] is True
+    assert config["max_scale_spec"] is False
+    assert config["scale_raw_audio"] is False
+    assert len(config["class_names"]) == 0
+    assert config["detection_threshold"] == 0.01
+    assert config["time_expansion"] == 1
+    assert config["top_n"] == 3
+    assert config["return_raw_preds"] is False
+    assert config["max_duration"] is None
+    assert config["nms_kernel_size"] == 9
+    assert config["max_freq"] == 120000
+    assert config["min_freq"] == 10000
+    assert config["nms_top_k_per_sec"] == 200
+    assert config["quiet"] is True
+    assert config["chunk_size"] == 3
+    assert config["cnn_features"] is False
+    assert config["spec_features"] is False
+    assert config["spec_slices"] is False
+
+
+def test_process_file_with_model():
+    """Test processing file with model."""
+    model, params = load_model()
+    config = get_config(**params)
+    predictions = process_file(TEST_DATA[0], model, config=config)
+
+    assert predictions is not None
+    assert isinstance(predictions, dict)
+
+    assert "pred_dict" in predictions
+    assert "spec_feats" in predictions
+    assert "spec_feat_names" in predictions
+    assert "cnn_feats" in predictions
+    assert "cnn_feat_names" in predictions
+    assert "spec_slices" in predictions
+
+    # By default will not return spectrogram features
+    assert predictions["spec_feats"] is None
+    assert predictions["spec_feat_names"] is None
+    assert predictions["cnn_feats"] is None
+    assert predictions["cnn_feat_names"] is None
+    assert predictions["spec_slices"] is None
+
+    # Check that predictions are returned
+    assert isinstance(predictions["pred_dict"], dict)
+    pred_dict = predictions["pred_dict"]
+    assert pred_dict["id"] == os.path.basename(TEST_DATA[0])
+    assert pred_dict["annotated"] is False
+    assert pred_dict["issues"] is False
+    assert pred_dict["notes"] == "Automatically generated."
+    assert pred_dict["time_exp"] == 1
+    assert pred_dict["duration"] == 0.5
+    assert pred_dict["class_name"] is not None
+    assert len(pred_dict["annotation"]) > 0
+
+
+def test_process_spectrogram_with_model():
+    """Test processing spectrogram with model."""
+    model, params = load_model()
+    config = get_config(**params)
+    samplerate, audio = load_audio(TEST_DATA[0])
+    spectrogram = generate_spectrogram(audio, samplerate)
+    predictions, features = process_spectrogram(
+        spectrogram,
+        samplerate,
+        model,
+        config=config,
+    )
+
+    assert predictions is not None
+    assert isinstance(predictions, list)
+    assert len(predictions) > 0
+    sample_pred = predictions[0]
+    assert isinstance(sample_pred, dict)
+    assert "class" in sample_pred
+    assert "class_prob" in sample_pred
+    assert "det_prob" in sample_pred
+    assert "start_time" in sample_pred
+    assert "end_time" in sample_pred
+    assert "low_freq" in sample_pred
+    assert "high_freq" in sample_pred
+
+    assert features is not None
+    assert isinstance(features, list)
+    assert len(features) == 1
+
+
+def test_process_audio_with_model():
+    """Test processing audio with model."""
+    model, params = load_model()
+    config = get_config(**params)
+    samplerate, audio = load_audio(TEST_DATA[0])
+    predictions, features, spec = process_audio(
+        audio,
+        samplerate,
+        model,
+        config=config,
+    )
+
+    assert predictions is not None
+    assert isinstance(predictions, list)
+    assert len(predictions) > 0
+    sample_pred = predictions[0]
+    assert isinstance(sample_pred, dict)
+    assert "class" in sample_pred
+    assert "class_prob" in sample_pred
+    assert "det_prob" in sample_pred
+    assert "start_time" in sample_pred
+    assert "end_time" in sample_pred
+    assert "low_freq" in sample_pred
+    assert "high_freq" in sample_pred
+
+    assert features is not None
+    assert isinstance(features, list)
+    assert len(features) == 1
+
+    assert spec is not None
+    assert isinstance(spec, torch.Tensor)
+    assert spec.shape == (1, 1, 128, 512)
@@ -1,211 +0,0 @@
(file relocated: the 211 lines removed here are identical to the 211-line test module added above)