diff --git a/bat_detect/api.py b/bat_detect/api.py index 34fdd6c..b09d3b4 100644 --- a/bat_detect/api.py +++ b/bat_detect/api.py @@ -3,10 +3,19 @@ from typing import List, Optional, Tuple import numpy as np import torch -import bat_detect.detector.models as md import bat_detect.utils.audio_utils as au import bat_detect.utils.detector_utils as du -from bat_detect.detector.parameters import TARGET_SAMPLERATE_HZ +from bat_detect.detector.parameters import ( + DEFAULT_PROCESSING_CONFIGURATIONS, + DEFAULT_SPECTROGRAM_PARAMETERS, + TARGET_SAMPLERATE_HZ, +) +from bat_detect.types import ( + Annotation, + DetectionModel, + ProcessingConfiguration, + SpectrogramParameters, +) from bat_detect.utils.detector_utils import list_audio_files, load_model # Use GPU if available @@ -24,12 +33,12 @@ __all__ = [ ] -def get_config(**kwargs) -> du.ProcessingConfiguration: +def get_config(**kwargs) -> ProcessingConfiguration: """Get default processing configuration. Can be used to override default parameters by passing keyword arguments. """ - return {**du.DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} + return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} def load_audio( @@ -73,7 +82,7 @@ def load_audio( def generate_spectrogram( audio: np.ndarray, samp_rate: int, - config: Optional[au.SpectrogramParameters] = None, + config: Optional[SpectrogramParameters] = None, device: torch.device = DEVICE, ) -> torch.Tensor: """Generate spectrogram from audio array. @@ -93,7 +102,7 @@ def generate_spectrogram( Spectrogram. """ if config is None: - config = au.DEFAULT_SPECTROGRAM_PARAMETERS + config = DEFAULT_SPECTROGRAM_PARAMETERS _, spec, _ = du.compute_spectrogram( audio, @@ -108,8 +117,8 @@ def generate_spectrogram( def process_file( audio_file: str, - model: md.DetectionModel, - config: Optional[du.ProcessingConfiguration] = None, + model: DetectionModel, + config: Optional[ProcessingConfiguration] = None, device: torch.device = DEVICE, ) -> du.RunResults: """Process audio file with model. @@ -126,7 +135,7 @@ def process_file( Device to use, by default tries to use GPU if available. """ if config is None: - config = du.DEFAULT_PROCESSING_CONFIGURATIONS + config = DEFAULT_PROCESSING_CONFIGURATIONS return du.process_file( audio_file, @@ -139,9 +148,9 @@ def process_file( def process_spectrogram( spec: torch.Tensor, samp_rate: int, - model: md.DetectionModel, - config: Optional[du.ProcessingConfiguration] = None, -) -> Tuple[List[du.Annotation], List[np.ndarray]]: + model: DetectionModel, + config: Optional[ProcessingConfiguration] = None, +) -> Tuple[List[Annotation], List[np.ndarray]]: """Process spectrogram with model. Parameters @@ -160,7 +169,7 @@ def process_spectrogram( DetectionResult """ if config is None: - config = du.DEFAULT_PROCESSING_CONFIGURATIONS + config = DEFAULT_PROCESSING_CONFIGURATIONS return du.process_spectrogram( spec, @@ -173,10 +182,10 @@ def process_spectrogram( def process_audio( audio: np.ndarray, samp_rate: int, - model: md.DetectionModel, - config: Optional[du.ProcessingConfiguration] = None, + model: DetectionModel, + config: Optional[ProcessingConfiguration] = None, device: torch.device = DEVICE, -) -> Tuple[List[du.Annotation], List[np.ndarray], torch.Tensor]: +) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]: """Process audio array with model. Parameters @@ -204,7 +213,7 @@ def process_audio( Spectrogram of the audio used for prediction. """ if config is None: - config = du.DEFAULT_PROCESSING_CONFIGURATIONS + config = DEFAULT_PROCESSING_CONFIGURATIONS return du.process_audio_array( audio, diff --git a/bat_detect/cli.py b/bat_detect/cli.py index 9ad4d32..9293745 100644 --- a/bat_detect/cli.py +++ b/bat_detect/cli.py @@ -7,7 +7,7 @@ Example usage: import argparse import os -import bat_detect.utils.detector_utils as du +from bat_detect import api CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/bat_detect/detector/models.py b/bat_detect/detector/models.py index 47ec728..99a48e1 100644 --- a/bat_detect/detector/models.py +++ b/bat_detect/detector/models.py @@ -1,5 +1,3 @@ -from typing import NamedTuple, Optional - import torch import torch.fft import torch.nn.functional as F @@ -12,86 +10,15 @@ from bat_detect.detector.model_helpers import ( ConvBlockUpStandard, SelfAttention, ) - -try: - from typing import Protocol -except ImportError: - from typing_extensions import Protocol +from bat_detect.types import ModelOutput __all__ = [ "Net2DFast", "Net2DFastNoAttn", "Net2DFastNoCoordConv", - "ModelOutput", - "DetectionModel", ] -class ModelOutput(NamedTuple): - """Output of the detection model.""" - - pred_det: torch.Tensor - """Tensor with predict detection probabilities.""" - - pred_size: torch.Tensor - """Tensor with predicted bounding box sizes.""" - - pred_class: torch.Tensor - """Tensor with predicted class probabilities.""" - - pred_class_un_norm: torch.Tensor - """Tensor with predicted class probabilities before softmax.""" - - features: torch.Tensor - """Tensor with intermediate features.""" - - -class DetectionModel(Protocol): - """Protocol for detection models. - - This protocol is used to define the interface for the detection models. - This allows us to use the same code for training and inference, even - though the models are different. - """ - - num_classes: int - """Number of classes the model can classify.""" - - emb_dim: int - """Dimension of the embedding vector.""" - - num_filts: int - """Number of filters in the model.""" - - resize_factor: float - """Factor by which the input is resized.""" - - ip_height: int - """Height of the input image.""" - - def forward( - self, - ip: torch.Tensor, - return_feats: bool = False, - ) -> ModelOutput: - """Forward pass of the model. - - When `return_feats` is `True`, the model should return the - intermediate features of the model. - """ - - def __call__( - self, - ip: torch.Tensor, - return_feats: bool = False, - ) -> ModelOutput: - """Forward pass of the model. - - When `return_feats` is `True`, the model should return the - int - """ - - class Net2DFast(nn.Module): def __init__( self, diff --git a/bat_detect/detector/parameters.py b/bat_detect/detector/parameters.py index b7d9244..f733062 100644 --- a/bat_detect/detector/parameters.py +++ b/bat_detect/detector/parameters.py @@ -1,6 +1,11 @@ import datetime import os +from bat_detect.types import ( + ProcessingConfiguration, + SpectrogramParameters, +) + TARGET_SAMPLERATE_HZ = 256000 FFT_WIN_LENGTH_S = 512 / 256000.0 FFT_OVERLAP = 0.75 @@ -18,6 +23,56 @@ DENOISE_SPEC_AVG = True MAX_SCALE_SPEC = False +DEFAULT_MODEL_PATH = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "models", + "Net2DFast_UK_same.pth.tar", +) + + +DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = { + "fft_win_length": FFT_WIN_LENGTH_S, + "fft_overlap": FFT_OVERLAP, + "spec_height": SPEC_HEIGHT, + "resize_factor": RESIZE_FACTOR, + "spec_divide_factor": SPEC_DIVIDE_FACTOR, + "max_freq": MAX_FREQ_HZ, + "min_freq": MIN_FREQ_HZ, + "spec_scale": SPEC_SCALE, + "denoise_spec_avg": DENOISE_SPEC_AVG, + "max_scale_spec": MAX_SCALE_SPEC, +} + + +DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = { + "detection_threshold": DETECTION_THRESHOLD, + "spec_slices": False, + "chunk_size": 3, + "spec_features": False, + "cnn_features": False, + "quiet": True, + "target_samp_rate": TARGET_SAMPLERATE_HZ, + "fft_win_length": FFT_WIN_LENGTH_S, + "fft_overlap": FFT_OVERLAP, + "resize_factor": RESIZE_FACTOR, + "spec_divide_factor": SPEC_DIVIDE_FACTOR, + "spec_height": SPEC_HEIGHT, + "scale_raw_audio": SCALE_RAW_AUDIO, + "class_names": [], + "time_expansion": 1, + "top_n": 3, + "return_raw_preds": False, + "max_duration": None, + "nms_kernel_size": NMS_KERNEL_SIZE, + "max_freq": MAX_FREQ_HZ, + "min_freq": MIN_FREQ_HZ, + "nms_top_k_per_sec": NMS_TOP_K_PER_SEC, + "spec_scale": SPEC_SCALE, + "denoise_spec_avg": DENOISE_SPEC_AVG, + "max_scale_spec": MAX_SCALE_SPEC, +} + + def mk_dir(path): if not os.path.isdir(path): os.makedirs(path) diff --git a/bat_detect/train/train_model.py b/bat_detect/train/train_model.py index 3619576..1f4ea5f 100644 --- a/bat_detect/train/train_model.py +++ b/bat_detect/train/train_model.py @@ -1,24 +1,18 @@ import argparse import json -import os -import sys +import warnings import matplotlib.pyplot as plt import numpy as np import torch -import torch.nn.functional as F from torch.optim.lr_scheduler import CosineAnnealingLR -sys.path.append(os.path.join("..", "..")) - -import warnings - -import bat_detect.detector.models as models -import bat_detect.detector.parameters as parameters +from bat_detect.detector import models +from bat_detect.detector import parameters +from bat_detect.train import losses import bat_detect.detector.post_process as pp import bat_detect.train.audio_dataloader as adl import bat_detect.train.evaluate as evl -import bat_detect.train.losses as losses import bat_detect.train.train_split as ts import bat_detect.train.train_utils as tu import bat_detect.utils.plot_utils as pu diff --git a/bat_detect/types.py b/bat_detect/types.py new file mode 100644 index 0000000..e961a28 --- /dev/null +++ b/bat_detect/types.py @@ -0,0 +1,304 @@ +"""Types used in the code base.""" +from typing import List, Optional, NamedTuple + +import numpy as np +import torch + +try: + from typing import TypedDict +except ImportError: + from typing_extensions import TypedDict + + +try: + from typing import Protocol +except ImportError: + from typing_extensions import Protocol + + +class SpectrogramParameters(TypedDict): + """Parameters for generating spectrograms.""" + + fft_win_length: float + """Length of the FFT window in seconds.""" + + fft_overlap: float + """Percentage of overlap between FFT windows.""" + + spec_height: int + """Height of the spectrogram in pixels.""" + + resize_factor: float + """Factor to resize the spectrogram by.""" + + spec_divide_factor: int + """Factor to divide the spectrogram by.""" + + max_freq: int + """Maximum frequency to display in the spectrogram.""" + + min_freq: int + """Minimum frequency to display in the spectrogram.""" + + spec_scale: str + """Scale to use for the spectrogram.""" + + denoise_spec_avg: bool + """Whether to denoise the spectrogram by averaging.""" + + max_scale_spec: bool + """Whether to scale the spectrogram so that its max is 1.""" + + +class ModelParameters(TypedDict): + """Model parameters.""" + + model_name: str + """Model name.""" + + num_filters: int + """Number of filters.""" + + emb_dim: int + """Embedding dimension.""" + + ip_height: int + """Input height in pixels.""" + + resize_factor: float + """Resize factor.""" + + class_names: List[str] + """Class names. The model is trained to detect these classes.""" + + +DictWithClass = TypedDict("DictWithClass", {"class": str}) + + +class Annotation(DictWithClass): + """Format of annotations. + + This is the format of a single annotation as expected by the annotation + tool. + """ + + start_time: float + """Start time in seconds.""" + + end_time: float + """End time in seconds.""" + + low_freq: int + """Low frequency in Hz.""" + + high_freq: int + """High frequency in Hz.""" + + class_prob: float + """Probability of class assignment.""" + + det_prob: float + """Probability of detection.""" + + individual: str + """Individual ID.""" + + event: str + """Type of detected event.""" + + +class FileAnnotations(TypedDict): + """Format of results. + + This is the format of the results expected by the annotation tool. + """ + + id: str + """File ID.""" + + annotated: bool + """Whether file has been annotated.""" + + duration: float + """Duration of audio file.""" + + issues: bool + """Whether file has issues.""" + + time_exp: float + """Time expansion factor.""" + + class_name: str + """Class predicted at file level""" + + notes: str + """Notes of file.""" + + annotation: List[Annotation] + """List of annotations.""" + + +class RunResults(TypedDict): + """Run results.""" + + pred_dict: FileAnnotations + """Predictions in the format expected by the annotation tool.""" + + spec_feats: Optional[List[np.ndarray]] + """Spectrogram features.""" + + spec_feat_names: Optional[List[str]] + """Spectrogram feature names.""" + + cnn_feats: Optional[List[np.ndarray]] + """CNN features.""" + + cnn_feat_names: Optional[List[str]] + """CNN feature names.""" + + spec_slices: Optional[List[np.ndarray]] + """Spectrogram slices.""" + + +class ResultParams(TypedDict): + """Result parameters.""" + + class_names: List[str] + """Class names.""" + + +class ProcessingConfiguration(TypedDict): + """Parameters for processing audio files.""" + + # audio parameters + target_samp_rate: int + """Target sampling rate of the audio.""" + + fft_win_length: float + """Length of the FFT window in seconds.""" + + fft_overlap: float + """Length of the FFT window in samples.""" + + resize_factor: float + """Factor to resize the spectrogram by.""" + + spec_divide_factor: int + """Factor to divide the spectrogram by.""" + + spec_height: int + """Height of the spectrogram in pixels.""" + + spec_scale: str + """Scale to use for the spectrogram.""" + + denoise_spec_avg: bool + """Whether to denoise the spectrogram by averaging.""" + + max_scale_spec: bool + """Whether to scale the spectrogram so that its max is 1.""" + + scale_raw_audio: bool + """Whether to scale the raw audio to be between -1 and 1.""" + + class_names: List[str] + """Names of the classes the model can detect.""" + + detection_threshold: float + """Threshold for detection probability.""" + + time_expansion: Optional[float] + """Time expansion factor of the processed recordings.""" + + top_n: int + """Number of top detections to keep.""" + + return_raw_preds: bool + """Whether to return raw predictions.""" + + max_duration: Optional[float] + """Maximum duration of audio file to process in seconds.""" + + nms_kernel_size: int + """Size of the kernel for non-maximum suppression.""" + + max_freq: int + """Maximum frequency to consider in Hz.""" + + min_freq: int + """Minimum frequency to consider in Hz.""" + + nms_top_k_per_sec: float + """Number of top detections to keep per second.""" + + quiet: bool + """Whether to suppress output.""" + + chunk_size: float + """Size of chunks to process in seconds.""" + + cnn_features: bool + """Whether to return CNN features.""" + + spec_features: bool + """Whether to return spectrogram features.""" + + spec_slices: bool + """Whether to return spectrogram slices.""" + + +class ModelOutput(NamedTuple): + """Output of the detection model.""" + + pred_det: torch.Tensor + """Tensor with predict detection probabilities.""" + + pred_size: torch.Tensor + """Tensor with predicted bounding box sizes.""" + + pred_class: torch.Tensor + """Tensor with predicted class probabilities.""" + + pred_class_un_norm: torch.Tensor + """Tensor with predicted class probabilities before softmax.""" + + features: torch.Tensor + """Tensor with intermediate features.""" + + +class DetectionModel(Protocol): + """Protocol for detection models. + + This protocol is used to define the interface for the detection models. + This allows us to use the same code for training and inference, even + though the models are different. + """ + + num_classes: int + """Number of classes the model can classify.""" + + emb_dim: int + """Dimension of the embedding vector.""" + + num_filts: int + """Number of filters in the model.""" + + resize_factor: float + """Factor by which the input is resized.""" + + ip_height: int + """Height of the input image.""" + + def forward( + self, + spec: torch.Tensor, + return_feats: bool = False, + ) -> ModelOutput: + """Forward pass of the model.""" + + def __call__( + self, + spec: torch.Tensor, + return_feats: bool = False, + ) -> ModelOutput: + """Forward pass of the model.""" diff --git a/bat_detect/utils/audio_utils.py b/bat_detect/utils/audio_utils.py index cd90a80..bcb262e 100644 --- a/bat_detect/utils/audio_utils.py +++ b/bat_detect/utils/audio_utils.py @@ -38,54 +38,6 @@ __all__ = [ ] -class SpectrogramParameters(TypedDict): - """Parameters for generating spectrograms.""" - - fft_win_length: float - """Length of the FFT window in seconds.""" - - fft_overlap: float - """Percentage of overlap between FFT windows.""" - - spec_height: int - """Height of the spectrogram in pixels.""" - - resize_factor: float - """Factor to resize the spectrogram by.""" - - spec_divide_factor: int - """Factor to divide the spectrogram by.""" - - max_freq: int - """Maximum frequency to display in the spectrogram.""" - - min_freq: int - """Minimum frequency to display in the spectrogram.""" - - spec_scale: str - """Scale to use for the spectrogram.""" - - denoise_spec_avg: bool - """Whether to denoise the spectrogram by averaging.""" - - max_scale_spec: bool - """Whether to scale the spectrogram so that its max is 1.""" - - -DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = { - "fft_win_length": FFT_WIN_LENGTH_S, - "fft_overlap": FFT_OVERLAP, - "spec_height": SPEC_HEIGHT, - "resize_factor": RESIZE_FACTOR, - "spec_divide_factor": SPEC_DIVIDE_FACTOR, - "max_freq": MAX_FREQ_HZ, - "min_freq": MIN_FREQ_HZ, - "spec_scale": SPEC_SCALE, - "denoise_spec_avg": DENOISE_SPEC_AVG, - "max_scale_spec": MAX_SCALE_SPEC, -} - - def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap): nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor noverlap = np.floor(fft_overlap * nfft) diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py index f1a8c00..3ebd69c 100644 --- a/bat_detect/utils/detector_utils.py +++ b/bat_detect/utils/detector_utils.py @@ -11,34 +11,16 @@ import bat_detect.detector.compute_features as feats import bat_detect.detector.post_process as pp import bat_detect.utils.audio_utils as au from bat_detect.detector import models -from bat_detect.detector.parameters import ( - DENOISE_SPEC_AVG, - DETECTION_THRESHOLD, - FFT_OVERLAP, - FFT_WIN_LENGTH_S, - MAX_FREQ_HZ, - MAX_SCALE_SPEC, - MIN_FREQ_HZ, - NMS_KERNEL_SIZE, - NMS_TOP_K_PER_SEC, - RESIZE_FACTOR, - SCALE_RAW_AUDIO, - SPEC_DIVIDE_FACTOR, - SPEC_HEIGHT, - SPEC_SCALE, - TARGET_SAMPLERATE_HZ, -) - -try: - from typing import TypedDict -except ImportError: - from typing_extensions import TypedDict - - -DEFAULT_MODEL_PATH = os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "models", - "Net2DFast_UK_same.pth.tar", +from bat_detect.detector.parameters import DEFAULT_MODEL_PATH +from bat_detect.types import ( + Annotation, + FileAnnotations, + ModelParameters, + ProcessingConfiguration, + SpectrogramParameters, + ResultParams, + RunResults, + DetectionModel ) __all__ = [ @@ -50,8 +32,6 @@ __all__ = [ "process_spectrogram", "process_audio_array", "process_file", - "DEFAULT_MODEL_PATH", - "DEFAULT_PROCESSING_CONFIGURATIONS", ] @@ -77,33 +57,11 @@ def list_audio_files(ip_dir: str) -> List[str]: return matches -class ModelParameters(TypedDict): - """Model parameters.""" - - model_name: str - """Model name.""" - - num_filters: int - """Number of filters.""" - - emb_dim: int - """Embedding dimension.""" - - ip_height: int - """Input height in pixels.""" - - resize_factor: float - """Resize factor.""" - - class_names: List[str] - """Class names. The model is trained to detect these classes.""" - - def load_model( model_path: str = DEFAULT_MODEL_PATH, load_weights: bool = True, device: Optional[torch.device] = None, -) -> Tuple[models.DetectionModel, ModelParameters]: +) -> Tuple[DetectionModel, ModelParameters]: """Load model from file. Args: @@ -129,7 +87,7 @@ def load_model( params = net_params["params"] - model: models.DetectionModel + model: DetectionModel if params["model_name"] == "Net2DFast": model = models.Net2DFast( @@ -189,101 +147,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices): return predictions_m, spec_feats, cnn_feats, spec_slices -DictWithClass = TypedDict("DictWithClass", {"class": str}) - - -class Annotation(DictWithClass): - """Format of annotations. - - This is the format of a single annotation as expected by the annotation - tool. - """ - - start_time: float - """Start time in seconds.""" - - end_time: float - """End time in seconds.""" - - low_freq: int - """Low frequency in Hz.""" - - high_freq: int - """High frequency in Hz.""" - - class_prob: float - """Probability of class assignment.""" - - det_prob: float - """Probability of detection.""" - - individual: str - """Individual ID.""" - - event: str - """Type of detected event.""" - - -class FileAnnotations(TypedDict): - """Format of results. - - This is the format of the results expected by the annotation tool. - """ - - id: str - """File ID.""" - - annotated: bool - """Whether file has been annotated.""" - - duration: float - """Duration of audio file.""" - - issues: bool - """Whether file has issues.""" - - time_exp: float - """Time expansion factor.""" - - class_name: str - """Class predicted at file level""" - - notes: str - """Notes of file.""" - - annotation: List[Annotation] - """List of annotations.""" - - -class RunResults(TypedDict): - """Run results.""" - - pred_dict: FileAnnotations - """Predictions in the format expected by the annotation tool.""" - - spec_feats: Optional[List[np.ndarray]] - """Spectrogram features.""" - - spec_feat_names: Optional[List[str]] - """Spectrogram feature names.""" - - cnn_feats: Optional[List[np.ndarray]] - """CNN features.""" - - cnn_feat_names: Optional[List[str]] - """CNN feature names.""" - - spec_slices: Optional[List[np.ndarray]] - """Spectrogram slices.""" - - -class ResultParams(TypedDict): - """Result parameters.""" - - class_names: List[str] - """Class names.""" - - def get_annotations_from_preds( predictions, class_names: List[str], @@ -499,7 +362,7 @@ def save_results_to_file(results, op_path: str) -> None: def compute_spectrogram( audio: np.ndarray, sampling_rate: int, - params: au.SpectrogramParameters, + params: SpectrogramParameters, device: torch.device, return_np: bool = False, ) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]: @@ -608,90 +471,10 @@ def iterate_over_chunks( yield chunk_start, audio[start_sample:end_sample] -class ProcessingConfiguration(TypedDict): - """Parameters for processing audio files.""" - - # audio parameters - target_samp_rate: int - """Target sampling rate of the audio.""" - - fft_win_length: float - """Length of the FFT window in seconds.""" - - fft_overlap: float - """Length of the FFT window in samples.""" - - resize_factor: float - """Factor to resize the spectrogram by.""" - - spec_divide_factor: int - """Factor to divide the spectrogram by.""" - - spec_height: int - """Height of the spectrogram in pixels.""" - - spec_scale: str - """Scale to use for the spectrogram.""" - - denoise_spec_avg: bool - """Whether to denoise the spectrogram by averaging.""" - - max_scale_spec: bool - """Whether to scale the spectrogram so that its max is 1.""" - - scale_raw_audio: bool - """Whether to scale the raw audio to be between -1 and 1.""" - - class_names: List[str] - """Names of the classes the model can detect.""" - - detection_threshold: float - """Threshold for detection probability.""" - - time_expansion: Optional[float] - """Time expansion factor of the processed recordings.""" - - top_n: int - """Number of top detections to keep.""" - - return_raw_preds: bool - """Whether to return raw predictions.""" - - max_duration: Optional[float] - """Maximum duration of audio file to process in seconds.""" - - nms_kernel_size: int - """Size of the kernel for non-maximum suppression.""" - - max_freq: int - """Maximum frequency to consider in Hz.""" - - min_freq: int - """Minimum frequency to consider in Hz.""" - - nms_top_k_per_sec: float - """Number of top detections to keep per second.""" - - quiet: bool - """Whether to suppress output.""" - - chunk_size: float - """Size of chunks to process in seconds.""" - - cnn_features: bool - """Whether to return CNN features.""" - - spec_features: bool - """Whether to return spectrogram features.""" - - spec_slices: bool - """Whether to return spectrogram slices.""" - - def _process_spectrogram( spec: torch.Tensor, samplerate: int, - model: models.DetectionModel, + model: DetectionModel, config: ProcessingConfiguration, ) -> Tuple[List[Annotation], List[np.ndarray]]: # evaluate model @@ -730,7 +513,7 @@ def _process_spectrogram( def process_spectrogram( spec: torch.Tensor, samplerate: int, - model: models.DetectionModel, + model: DetectionModel, config: ProcessingConfiguration, ) -> Tuple[List[Annotation], List[np.ndarray]]: """Process a spectrogram with detection model. @@ -775,7 +558,7 @@ def process_spectrogram( def _process_audio_array( audio: np.ndarray, sampling_rate: int, - model: torch.nn.Module, + model: DetectionModel, config: ProcessingConfiguration, device: torch.device, ) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]: @@ -813,7 +596,7 @@ def _process_audio_array( def process_audio_array( audio: np.ndarray, sampling_rate: int, - model: torch.nn.Module, + model: DetectionModel, config: ProcessingConfiguration, device: torch.device, ) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]: @@ -864,7 +647,7 @@ def process_audio_array( def process_file( audio_file: str, - model: torch.nn.Module, + model: DetectionModel, config: ProcessingConfiguration, device: torch.device, ) -> Union[RunResults, Any]: @@ -989,32 +772,3 @@ def summarize_results(results, predictions, config): config["class_names"][class_index].ljust(30) + str(round(class_overall[class_index], 3)) ) - - -DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = { - "detection_threshold": DETECTION_THRESHOLD, - "spec_slices": False, - "chunk_size": 3, - "spec_features": False, - "cnn_features": False, - "quiet": True, - "target_samp_rate": TARGET_SAMPLERATE_HZ, - "fft_win_length": FFT_WIN_LENGTH_S, - "fft_overlap": FFT_OVERLAP, - "resize_factor": RESIZE_FACTOR, - "spec_divide_factor": SPEC_DIVIDE_FACTOR, - "spec_height": SPEC_HEIGHT, - "scale_raw_audio": SCALE_RAW_AUDIO, - "class_names": [], - "time_expansion": 1, - "top_n": 3, - "return_raw_preds": False, - "max_duration": None, - "nms_kernel_size": NMS_KERNEL_SIZE, - "max_freq": MAX_FREQ_HZ, - "min_freq": MIN_FREQ_HZ, - "nms_top_k_per_sec": NMS_TOP_K_PER_SEC, - "spec_scale": SPEC_SCALE, - "denoise_spec_avg": DENOISE_SPEC_AVG, - "max_scale_spec": MAX_SCALE_SPEC, -} diff --git a/tests/test_api.py b/tests/test_api.py index e69de29..902c2df 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -0,0 +1,211 @@ +"""Test bat detect module API.""" + +import os +from glob import glob + +import numpy as np +import torch +from torch import nn + +from bat_detect.api import ( + generate_spectrogram, + get_config, + list_audio_files, + load_audio, + load_model, + process_audio, + process_file, + process_spectrogram, +) + +PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio") +TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav")) + + +def test_load_model_with_default_params(): + """Test loading model with default parameters.""" + model, params = load_model() + + assert model is not None + assert isinstance(model, nn.Module) + + assert params is not None + assert isinstance(params, dict) + + assert "model_name" in params + assert "num_filters" in params + assert "emb_dim" in params + assert "ip_height" in params + assert "resize_factor" in params + assert "class_names" in params + + assert params["model_name"] == "Net2DFast" + assert params["num_filters"] == 128 + assert params["emb_dim"] == 0 + assert params["ip_height"] == 128 + assert params["resize_factor"] == 0.5 + assert len(params["class_names"]) == 17 + + +def test_list_audio_files(): + """Test listing audio files.""" + audio_files = list_audio_files(TEST_DATA_DIR) + + assert len(audio_files) == 3 + assert all(path.endswith((".wav", ".WAV")) for path in audio_files) + + +def test_load_audio(): + """Test loading audio.""" + samplerate, audio = load_audio(TEST_DATA[0]) + + assert audio is not None + assert samplerate == 256000 + assert isinstance(audio, np.ndarray) + assert audio.shape == (128000,) + + +def test_generate_spectrogram(): + """Test generating spectrogram.""" + samplerate, audio = load_audio(TEST_DATA[0]) + spectrogram = generate_spectrogram(audio, samplerate) + + assert spectrogram is not None + assert isinstance(spectrogram, torch.Tensor) + assert spectrogram.shape == (1, 1, 128, 512) + + +def test_get_default_config(): + """Test getting default configuration.""" + config = get_config() + + assert config is not None + assert isinstance(config, dict) + + assert config["target_samp_rate"] == 256000 + assert config["fft_win_length"] == 0.002 + assert config["fft_overlap"] == 0.75 + assert config["resize_factor"] == 0.5 + assert config["spec_divide_factor"] == 32 + assert config["spec_height"] == 256 + assert config["spec_scale"] == "pcen" + assert config["denoise_spec_avg"] is True + assert config["max_scale_spec"] is False + assert config["scale_raw_audio"] is False + assert len(config["class_names"]) == 0 + assert config["detection_threshold"] == 0.01 + assert config["time_expansion"] == 1 + assert config["top_n"] == 3 + assert config["return_raw_preds"] is False + assert config["max_duration"] is None + assert config["nms_kernel_size"] == 9 + assert config["max_freq"] == 120000 + assert config["min_freq"] == 10000 + assert config["nms_top_k_per_sec"] == 200 + assert config["quiet"] is True + assert config["chunk_size"] == 3 + assert config["cnn_features"] is False + assert config["spec_features"] is False + assert config["spec_slices"] is False + + +def test_process_file_with_model(): + """Test processing file with model.""" + model, params = load_model() + config = get_config(**params) + predictions = process_file(TEST_DATA[0], model, config=config) + + assert predictions is not None + assert isinstance(predictions, dict) + + assert "pred_dict" in predictions + assert "spec_feats" in predictions + assert "spec_feat_names" in predictions + assert "cnn_feats" in predictions + assert "cnn_feat_names" in predictions + assert "spec_slices" in predictions + + # By default will not return spectrogram features + assert predictions["spec_feats"] is None + assert predictions["spec_feat_names"] is None + assert predictions["cnn_feats"] is None + assert predictions["cnn_feat_names"] is None + assert predictions["spec_slices"] is None + + # Check that predictions are returned + assert isinstance(predictions["pred_dict"], dict) + pred_dict = predictions["pred_dict"] + assert pred_dict["id"] == os.path.basename(TEST_DATA[0]) + assert pred_dict["annotated"] is False + assert pred_dict["issues"] is False + assert pred_dict["notes"] == "Automatically generated." + assert pred_dict["time_exp"] == 1 + assert pred_dict["duration"] == 0.5 + assert pred_dict["class_name"] is not None + assert len(pred_dict["annotation"]) > 0 + + +def test_process_spectrogram_with_model(): + """Test processing spectrogram with model.""" + model, params = load_model() + config = get_config(**params) + samplerate, audio = load_audio(TEST_DATA[0]) + spectrogram = generate_spectrogram(audio, samplerate) + predictions, features = process_spectrogram( + spectrogram, + samplerate, + model, + config=config, + ) + + assert predictions is not None + assert isinstance(predictions, list) + assert len(predictions) > 0 + sample_pred = predictions[0] + assert isinstance(sample_pred, dict) + assert "class" in sample_pred + assert "class_prob" in sample_pred + assert "det_prob" in sample_pred + assert "start_time" in sample_pred + assert "end_time" in sample_pred + assert "low_freq" in sample_pred + assert "high_freq" in sample_pred + + assert features is not None + assert isinstance(features, list) + assert len(features) == 1 + + +def test_process_audio_with_model(): + """Test processing audio with model.""" + model, params = load_model() + config = get_config(**params) + samplerate, audio = load_audio(TEST_DATA[0]) + predictions, features, spec = process_audio( + audio, + samplerate, + model, + config=config, + ) + + assert predictions is not None + assert isinstance(predictions, list) + assert len(predictions) > 0 + sample_pred = predictions[0] + assert isinstance(sample_pred, dict) + assert "class" in sample_pred + assert "class_prob" in sample_pred + assert "det_prob" in sample_pred + assert "start_time" in sample_pred + assert "end_time" in sample_pred + assert "low_freq" in sample_pred + assert "high_freq" in sample_pred + + assert features is not None + assert isinstance(features, list) + assert len(features) == 1 + + assert spec is not None + assert isinstance(spec, torch.Tensor) + assert spec.shape == (1, 1, 128, 512) diff --git a/tests/test_bat_detect.py b/tests/test_bat_detect.py deleted file mode 100644 index 902c2df..0000000 --- a/tests/test_bat_detect.py +++ /dev/null @@ -1,211 +0,0 @@ -"""Test bat detect module API.""" - -import os -from glob import glob - -import numpy as np -import torch -from torch import nn - -from bat_detect.api import ( - generate_spectrogram, - get_config, - list_audio_files, - load_audio, - load_model, - process_audio, - process_file, - process_spectrogram, -) - -PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio") -TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav")) - - -def test_load_model_with_default_params(): - """Test loading model with default parameters.""" - model, params = load_model() - - assert model is not None - assert isinstance(model, nn.Module) - - assert params is not None - assert isinstance(params, dict) - - assert "model_name" in params - assert "num_filters" in params - assert "emb_dim" in params - assert "ip_height" in params - assert "resize_factor" in params - assert "class_names" in params - - assert params["model_name"] == "Net2DFast" - assert params["num_filters"] == 128 - assert params["emb_dim"] == 0 - assert params["ip_height"] == 128 - assert params["resize_factor"] == 0.5 - assert len(params["class_names"]) == 17 - - -def test_list_audio_files(): - """Test listing audio files.""" - audio_files = list_audio_files(TEST_DATA_DIR) - - assert len(audio_files) == 3 - assert all(path.endswith((".wav", ".WAV")) for path in audio_files) - - -def test_load_audio(): - """Test loading audio.""" - samplerate, audio = load_audio(TEST_DATA[0]) - - assert audio is not None - assert samplerate == 256000 - assert isinstance(audio, np.ndarray) - assert audio.shape == (128000,) - - -def test_generate_spectrogram(): - """Test generating spectrogram.""" - samplerate, audio = load_audio(TEST_DATA[0]) - spectrogram = generate_spectrogram(audio, samplerate) - - assert spectrogram is not None - assert isinstance(spectrogram, torch.Tensor) - assert spectrogram.shape == (1, 1, 128, 512) - - -def test_get_default_config(): - """Test getting default configuration.""" - config = get_config() - - assert config is not None - assert isinstance(config, dict) - - assert config["target_samp_rate"] == 256000 - assert config["fft_win_length"] == 0.002 - assert config["fft_overlap"] == 0.75 - assert config["resize_factor"] == 0.5 - assert config["spec_divide_factor"] == 32 - assert config["spec_height"] == 256 - assert config["spec_scale"] == "pcen" - assert config["denoise_spec_avg"] is True - assert config["max_scale_spec"] is False - assert config["scale_raw_audio"] is False - assert len(config["class_names"]) == 0 - assert config["detection_threshold"] == 0.01 - assert config["time_expansion"] == 1 - assert config["top_n"] == 3 - assert config["return_raw_preds"] is False - assert config["max_duration"] is None - assert config["nms_kernel_size"] == 9 - assert config["max_freq"] == 120000 - assert config["min_freq"] == 10000 - assert config["nms_top_k_per_sec"] == 200 - assert config["quiet"] is True - assert config["chunk_size"] == 3 - assert config["cnn_features"] is False - assert config["spec_features"] is False - assert config["spec_slices"] is False - - -def test_process_file_with_model(): - """Test processing file with model.""" - model, params = load_model() - config = get_config(**params) - predictions = process_file(TEST_DATA[0], model, config=config) - - assert predictions is not None - assert isinstance(predictions, dict) - - assert "pred_dict" in predictions - assert "spec_feats" in predictions - assert "spec_feat_names" in predictions - assert "cnn_feats" in predictions - assert "cnn_feat_names" in predictions - assert "spec_slices" in predictions - - # By default will not return spectrogram features - assert predictions["spec_feats"] is None - assert predictions["spec_feat_names"] is None - assert predictions["cnn_feats"] is None - assert predictions["cnn_feat_names"] is None - assert predictions["spec_slices"] is None - - # Check that predictions are returned - assert isinstance(predictions["pred_dict"], dict) - pred_dict = predictions["pred_dict"] - assert pred_dict["id"] == os.path.basename(TEST_DATA[0]) - assert pred_dict["annotated"] is False - assert pred_dict["issues"] is False - assert pred_dict["notes"] == "Automatically generated." - assert pred_dict["time_exp"] == 1 - assert pred_dict["duration"] == 0.5 - assert pred_dict["class_name"] is not None - assert len(pred_dict["annotation"]) > 0 - - -def test_process_spectrogram_with_model(): - """Test processing spectrogram with model.""" - model, params = load_model() - config = get_config(**params) - samplerate, audio = load_audio(TEST_DATA[0]) - spectrogram = generate_spectrogram(audio, samplerate) - predictions, features = process_spectrogram( - spectrogram, - samplerate, - model, - config=config, - ) - - assert predictions is not None - assert isinstance(predictions, list) - assert len(predictions) > 0 - sample_pred = predictions[0] - assert isinstance(sample_pred, dict) - assert "class" in sample_pred - assert "class_prob" in sample_pred - assert "det_prob" in sample_pred - assert "start_time" in sample_pred - assert "end_time" in sample_pred - assert "low_freq" in sample_pred - assert "high_freq" in sample_pred - - assert features is not None - assert isinstance(features, list) - assert len(features) == 1 - - -def test_process_audio_with_model(): - """Test processing audio with model.""" - model, params = load_model() - config = get_config(**params) - samplerate, audio = load_audio(TEST_DATA[0]) - predictions, features, spec = process_audio( - audio, - samplerate, - model, - config=config, - ) - - assert predictions is not None - assert isinstance(predictions, list) - assert len(predictions) > 0 - sample_pred = predictions[0] - assert isinstance(sample_pred, dict) - assert "class" in sample_pred - assert "class_prob" in sample_pred - assert "det_prob" in sample_pred - assert "start_time" in sample_pred - assert "end_time" in sample_pred - assert "low_freq" in sample_pred - assert "high_freq" in sample_pred - - assert features is not None - assert isinstance(features, list) - assert len(features) == 1 - - assert spec is not None - assert isinstance(spec, torch.Tensor) - assert spec.shape == (1, 1, 128, 512)