diff --git a/app.py b/app.py index dd35265..1b820b2 100644 --- a/app.py +++ b/app.py @@ -77,7 +77,7 @@ def make_prediction(file_name=None, detection_threshold=0.3): def generate_results_image(audio_file, anns): # load audio - sampling_rate, audio = au.load_audio_file( + sampling_rate, audio = au.load_audio( audio_file, args["time_expansion_factor"], params["target_samp_rate"], diff --git a/bat_detect/api.py b/bat_detect/api.py new file mode 100644 index 0000000..34fdd6c --- /dev/null +++ b/bat_detect/api.py @@ -0,0 +1,215 @@ +from typing import List, Optional, Tuple + +import numpy as np +import torch + +import bat_detect.detector.models as md +import bat_detect.utils.audio_utils as au +import bat_detect.utils.detector_utils as du +from bat_detect.detector.parameters import TARGET_SAMPLERATE_HZ +from bat_detect.utils.detector_utils import list_audio_files, load_model + +# Use GPU if available +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +__all__ = [ + "load_model", + "load_audio", + "list_audio_files", + "generate_spectrogram", + "get_config", + "process_file", + "process_spectrogram", + "process_audio", +] + + +def get_config(**kwargs) -> du.ProcessingConfiguration: + """Get default processing configuration. + + Can be used to override default parameters by passing keyword arguments. + """ + return {**du.DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} + + +def load_audio( + path: str, + time_exp_fact: float = 1, + target_samp_rate: int = TARGET_SAMPLERATE_HZ, + scale: bool = False, + max_duration: Optional[float] = None, +) -> Tuple[int, np.ndarray]: + """Load audio from file. + + Parameters + ---------- + path : str + Path to audio file. + time_exp_fact : float, optional + Time expansion factor, by default 1 + target_samp_rate : int, optional + Target sample rate, by default 256000 + scale : bool, optional + Scale audio to [-1, 1], by default False + max_duration : Optional[float], optional + Maximum duration of audio in seconds, by default None + + Returns + ------- + np.ndarray + Audio data. + int + Sample rate. + """ + return au.load_audio( + path, + time_exp_fact, + target_samp_rate, + scale, + max_duration, + ) + + +def generate_spectrogram( + audio: np.ndarray, + samp_rate: int, + config: Optional[au.SpectrogramParameters] = None, + device: torch.device = DEVICE, +) -> torch.Tensor: + """Generate spectrogram from audio array. + + Parameters + ---------- + audio : np.ndarray + Audio data. + samp_rate : int + Sample rate. + config : Optional[SpectrogramParameters], optional + Spectrogram parameters, by default None (uses default parameters). + + Returns + ------- + torch.Tensor + Spectrogram. + """ + if config is None: + config = au.DEFAULT_SPECTROGRAM_PARAMETERS + + _, spec, _ = du.compute_spectrogram( + audio, + samp_rate, + config, + return_np=False, + device=device, + ) + + return spec + + +def process_file( + audio_file: str, + model: md.DetectionModel, + config: Optional[du.ProcessingConfiguration] = None, + device: torch.device = DEVICE, +) -> du.RunResults: + """Process audio file with model. + + Parameters + ---------- + audio_file : str + Path to audio file. + model : DetectionModel + Detection model. + config : Optional[ProcessingConfiguration], optional + Processing configuration, by default None (uses default parameters). + device : torch.device, optional + Device to use, by default tries to use GPU if available. 
+ """ + if config is None: + config = du.DEFAULT_PROCESSING_CONFIGURATIONS + + return du.process_file( + audio_file, + model, + config, + device, + ) + + +def process_spectrogram( + spec: torch.Tensor, + samp_rate: int, + model: md.DetectionModel, + config: Optional[du.ProcessingConfiguration] = None, +) -> Tuple[List[du.Annotation], List[np.ndarray]]: + """Process spectrogram with model. + + Parameters + ---------- + spec : torch.Tensor + Spectrogram. + samp_rate : int + Sample rate of the audio from which the spectrogram was generated. + model : DetectionModel + Detection model. + config : Optional[ProcessingConfiguration], optional + Processing configuration, by default None (uses default parameters). + + Returns + ------- + DetectionResult + """ + if config is None: + config = du.DEFAULT_PROCESSING_CONFIGURATIONS + + return du.process_spectrogram( + spec, + samp_rate, + model, + config, + ) + + +def process_audio( + audio: np.ndarray, + samp_rate: int, + model: md.DetectionModel, + config: Optional[du.ProcessingConfiguration] = None, + device: torch.device = DEVICE, +) -> Tuple[List[du.Annotation], List[np.ndarray], torch.Tensor]: + """Process audio array with model. + + Parameters + ---------- + audio : np.ndarray + Audio data. + samp_rate : int + Sample rate. + model : DetectionModel + Detection model. + config : Optional[ProcessingConfiguration], optional + Processing configuration, by default None (uses default parameters). + device : torch.device, optional + Device to use, by default tries to use GPU if available. + + Returns + ------- + annotations : List[Annotation] + List of predicted annotations. + + features: List[np.ndarray] + List of extracted features for each annotation. + + spec : torch.Tensor + Spectrogram of the audio used for prediction. 
+ """ + if config is None: + config = du.DEFAULT_PROCESSING_CONFIGURATIONS + + return du.process_audio_array( + audio, + samp_rate, + model, + config, + device, + ) diff --git a/bat_detect/command.py b/bat_detect/cli.py similarity index 98% rename from bat_detect/command.py rename to bat_detect/cli.py index b680b5d..9ad4d32 100644 --- a/bat_detect/command.py +++ b/bat_detect/cli.py @@ -92,7 +92,7 @@ def main(): model, params = du.load_model(args["model_path"]) print("\nInput directory: " + args["audio_dir"]) - files = du.get_audio_files(args["audio_dir"]) + files = du.list_audio_files(args["audio_dir"]) print(f"Number of audio files: {len(files)}") print("\nSaving results to: " + args["ann_dir"]) diff --git a/bat_detect/detector/models.py b/bat_detect/detector/models.py index 94b98ad..b9a18dd 100644 --- a/bat_detect/detector/models.py +++ b/bat_detect/detector/models.py @@ -1,9 +1,11 @@ +from typing import NamedTuple, Optional + import torch import torch.fft import torch.nn.functional as F from torch import nn -from .model_helpers import ( +from bat_detect.detector.model_helpers import ( ConvBlockDownCoordF, ConvBlockDownStandard, ConvBlockUpF, @@ -11,13 +13,88 @@ from .model_helpers import ( SelfAttention, ) +try: + from typing import Protocol +except ImportError: + from typing_extensions import Protocol + __all__ = [ "Net2DFast", "Net2DFastNoAttn", "Net2DFastNoCoordConv", + "ModelOutput", + "DetectionModel", ] +class ModelOutput(NamedTuple): + """Output of the detection model.""" + + pred_det: torch.Tensor + """Tensor with predict detection probabilities.""" + + pred_size: torch.Tensor + """Tensor with predicted bounding box sizes.""" + + pred_class: torch.Tensor + """Tensor with predicted class probabilities.""" + + pred_class_un_norm: torch.Tensor + """Tensor with predicted class probabilities before softmax.""" + + pred_emb: Optional[torch.Tensor] + """Tensor with embeddings.""" + + features: Optional[torch.Tensor] + """Tensor with intermediate features.""" + + +class DetectionModel(Protocol): + """Protocol for detection models. + + This protocol is used to define the interface for the detection models. + This allows us to use the same code for training and inference, even + though the models are different. + """ + + num_classes: int + """Number of classes the model can classify.""" + + emb_dim: int + """Dimension of the embedding vector.""" + + num_filts: int + """Number of filters in the model.""" + + resize_factor: float + """Factor by which the input is resized.""" + + ip_height: int + """Height of the input image.""" + + def forward( + self, + ip: torch.Tensor, + return_feats: bool = False, + ) -> ModelOutput: + """Forward pass of the model. + + When `return_feats` is `True`, the model should return the + intermediate features of the model. + """ + + def __call__( + self, + ip: torch.Tensor, + return_feats: bool = False, + ) -> ModelOutput: + """Forward pass of the model. 
+ + When `return_feats` is `True`, the model should return the + int + """ + + class Net2DFast(nn.Module): def __init__( self, @@ -27,7 +104,7 @@ class Net2DFast(nn.Module): ip_height=128, resize_factor=0.5, ): - super(Net2DFast, self).__init__() + super().__init__() self.num_classes = num_classes self.emb_dim = emb_dim self.num_filts = num_filts @@ -102,7 +179,7 @@ class Net2DFast(nn.Module): num_filts, self.emb_dim, kernel_size=1, padding=0 ) - def forward(self, ip, return_feats=False): + def forward(self, ip, return_feats=False) -> ModelOutput: # encoder x1 = self.conv_dn_0(ip) @@ -125,17 +202,14 @@ class Net2DFast(nn.Module): cls = self.conv_classes_op(x) comb = torch.softmax(cls, 1) - op = {} - op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1) - op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True) - op["pred_class"] = comb - op["pred_class_un_norm"] = cls - if self.emb_dim > 0: - op["pred_emb"] = self.conv_emb(x) - if return_feats: - op["features"] = x - - return op + return ModelOutput( + pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), + pred_size=F.relu(self.conv_size_op(x), inplace=True), + pred_class=comb, + pred_class_un_norm=cls, + pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None, + features=x if return_feats else None, + ) class Net2DFastNoAttn(nn.Module): @@ -147,7 +221,7 @@ class Net2DFastNoAttn(nn.Module): ip_height=128, resize_factor=0.5, ): - super(Net2DFastNoAttn, self).__init__() + super().__init__() self.num_classes = num_classes self.emb_dim = emb_dim @@ -219,8 +293,7 @@ class Net2DFastNoAttn(nn.Module): num_filts, self.emb_dim, kernel_size=1, padding=0 ) - def forward(self, ip, return_feats=False): - + def forward(self, ip, return_feats=False) -> ModelOutput: x1 = self.conv_dn_0(ip) x2 = self.conv_dn_1(x1) x3 = self.conv_dn_2(x2) @@ -237,17 +310,14 @@ class Net2DFastNoAttn(nn.Module): cls = self.conv_classes_op(x) comb = torch.softmax(cls, 1) - op = {} - op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1) - op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True) - op["pred_class"] = comb - op["pred_class_un_norm"] = cls - if self.emb_dim > 0: - op["pred_emb"] = self.conv_emb(x) - if return_feats: - op["features"] = x - - return op + return ModelOutput( + pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), + pred_size=F.relu(self.conv_size_op(x), inplace=True), + pred_class=comb, + pred_class_un_norm=cls, + pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None, + features=x if return_feats else None, + ) class Net2DFastNoCoordConv(nn.Module): @@ -259,7 +329,7 @@ class Net2DFastNoCoordConv(nn.Module): ip_height=128, resize_factor=0.5, ): - super(Net2DFastNoCoordConv, self).__init__() + super().__init__() self.num_classes = num_classes self.emb_dim = emb_dim @@ -333,7 +403,7 @@ class Net2DFastNoCoordConv(nn.Module): num_filts, self.emb_dim, kernel_size=1, padding=0 ) - def forward(self, ip, return_feats=False): + def forward(self, ip, return_feats=False) -> ModelOutput: x1 = self.conv_dn_0(ip) x2 = self.conv_dn_1(x1) @@ -352,14 +422,11 @@ class Net2DFastNoCoordConv(nn.Module): cls = self.conv_classes_op(x) comb = torch.softmax(cls, 1) - op = {} - op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1) - op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True) - op["pred_class"] = comb - op["pred_class_un_norm"] = cls - if self.emb_dim > 0: - op["pred_emb"] = self.conv_emb(x) - if return_feats: - op["features"] = x - - return op + return ModelOutput( + pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), + 
pred_size=F.relu(self.conv_size_op(x), inplace=True), + pred_class=comb, + pred_class_un_norm=cls, + pred_emb=self.conv_emb(x) if self.emb_dim > 0 else None, + features=x if return_feats else None, + ) diff --git a/bat_detect/detector/post_process.py b/bat_detect/detector/post_process.py index fbbd410..eb8cfbc 100644 --- a/bat_detect/detector/post_process.py +++ b/bat_detect/detector/post_process.py @@ -5,6 +5,8 @@ import numpy as np import torch from torch import nn +from bat_detect.detector.models import ModelOutput + try: from typing import TypedDict except ImportError: @@ -106,24 +108,8 @@ class PredictionResults(TypedDict): """Class probabilities.""" -class ModelOutputs(TypedDict): - """Outputs of the model.""" - - pred_det: torch.Tensor - """Detection probabilities.""" - - pred_size: torch.Tensor - """Box sizes.""" - - pred_class: Optional[torch.Tensor] - """Class probabilities.""" - - features: Optional[torch.Tensor] - """Features extracted by the model.""" - - def run_nms( - outputs: ModelOutputs, + outputs: ModelOutput, params: NonMaximumSuppressionConfig, sampling_rate: np.ndarray, ) -> Tuple[List[PredictionResults], List[np.ndarray]]: @@ -135,16 +121,14 @@ def run_nms( the features. Each element of the lists corresponds to one element of the batch. """ - - pred_det = outputs["pred_det"] # probability of box - pred_size = outputs["pred_size"] # box size + pred_det, pred_size, pred_class, _, _, features = outputs pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"]) freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[ -2 ] - # NOTE there will be small differences depending on which sampling rate is chosen + # NOTE: there will be small differences depending on which sampling rate is chosen # as we are choosing the same sampling rate for the entire batch duration = x_coords_to_time( pred_det.shape[-1], @@ -172,10 +156,16 @@ def run_nms( pred["x_pos"] = x_pos[num_detection, valid_inds] pred["y_pos"] = y_pos[num_detection, valid_inds] pred["bb_width"] = pred_size[ - num_detection, 0, pred["y_pos"], pred["x_pos"] + num_detection, + 0, + pred["y_pos"], + pred["x_pos"], ] pred["bb_height"] = pred_size[ - num_detection, 1, pred["y_pos"], pred["x_pos"] + num_detection, + 1, + pred["y_pos"], + pred["x_pos"], ] pred["start_times"] = x_coords_to_time( pred["x_pos"].float() / params["resize_factor"], @@ -198,7 +188,6 @@ def run_nms( ) # extract the per class votes - pred_class = outputs.get("pred_class") if pred_class is not None: pred["class_probs"] = pred_class[ num_detection, @@ -208,7 +197,6 @@ def run_nms( ] # extract the model features - features = outputs.get("features") if features is not None: feat = features[ num_detection, diff --git a/bat_detect/train/audio_dataloader.py b/bat_detect/train/audio_dataloader.py index f7790a6..cce8255 100644 --- a/bat_detect/train/audio_dataloader.py +++ b/bat_detect/train/audio_dataloader.py @@ -373,7 +373,7 @@ class AudioLoader(torch.utils.data.Dataset): index = np.random.randint(0, len(self.data_anns)) audio_file = self.data_anns[index]["file_path"] - sampling_rate, audio_raw = au.load_audio_file( + sampling_rate, audio_raw = au.load_audio( audio_file, self.data_anns[index]["time_exp"], self.params["target_samp_rate"], diff --git a/bat_detect/utils/audio_utils.py b/bat_detect/utils/audio_utils.py index 23bfc2c..cd90a80 100644 --- a/bat_detect/utils/audio_utils.py +++ b/bat_detect/utils/audio_utils.py @@ -5,13 +5,87 @@ import librosa import numpy as np import torch +from bat_detect.detector.parameters import ( 
+ DENOISE_SPEC_AVG, + DETECTION_THRESHOLD, + FFT_OVERLAP, + FFT_WIN_LENGTH_S, + MAX_FREQ_HZ, + MAX_SCALE_SPEC, + MIN_FREQ_HZ, + NMS_KERNEL_SIZE, + NMS_TOP_K_PER_SEC, + RESIZE_FACTOR, + SCALE_RAW_AUDIO, + SPEC_DIVIDE_FACTOR, + SPEC_HEIGHT, + SPEC_SCALE, +) + from . import wavfile +try: + from typing import TypedDict +except ImportError: + from typing_extensions import TypedDict + __all__ = [ - "load_audio_file", + "load_audio", + "generate_spectrogram", + "pad_audio", + "SpectrogramParameters", + "DEFAULT_SPECTROGRAM_PARAMETERS", ] +class SpectrogramParameters(TypedDict): + """Parameters for generating spectrograms.""" + + fft_win_length: float + """Length of the FFT window in seconds.""" + + fft_overlap: float + """Percentage of overlap between FFT windows.""" + + spec_height: int + """Height of the spectrogram in pixels.""" + + resize_factor: float + """Factor to resize the spectrogram by.""" + + spec_divide_factor: int + """Factor to divide the spectrogram by.""" + + max_freq: int + """Maximum frequency to display in the spectrogram.""" + + min_freq: int + """Minimum frequency to display in the spectrogram.""" + + spec_scale: str + """Scale to use for the spectrogram.""" + + denoise_spec_avg: bool + """Whether to denoise the spectrogram by averaging.""" + + max_scale_spec: bool + """Whether to scale the spectrogram so that its max is 1.""" + + +DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = { + "fft_win_length": FFT_WIN_LENGTH_S, + "fft_overlap": FFT_OVERLAP, + "spec_height": SPEC_HEIGHT, + "resize_factor": RESIZE_FACTOR, + "spec_divide_factor": SPEC_DIVIDE_FACTOR, + "max_freq": MAX_FREQ_HZ, + "min_freq": MIN_FREQ_HZ, + "spec_scale": SPEC_SCALE, + "denoise_spec_avg": DENOISE_SPEC_AVG, + "max_scale_spec": MAX_SCALE_SPEC, +} + + def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap): nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor noverlap = np.floor(fft_overlap * nfft) @@ -36,7 +110,10 @@ def generate_spectrogram( # generate spectrogram spec = gen_mag_spectrogram( - audio, sampling_rate, params["fft_win_length"], params["fft_overlap"] + audio, + sampling_rate, + params["fft_win_length"], + params["fft_overlap"], ) # crop to min/max freq @@ -70,6 +147,7 @@ def generate_spectrogram( spec = np.log1p(log_scaling * spec_cropped) elif params["spec_scale"] == "pcen": spec = pcen(spec_cropped, sampling_rate) + elif params["spec_scale"] == "none": pass @@ -109,13 +187,13 @@ def generate_spectrogram( return spec, spec_for_viz -def load_audio_file( +def load_audio( audio_file: str, time_exp_fact: float, target_samp_rate: int, scale: bool = False, max_duration: Optional[float] = None, -): +) -> Tuple[int, np.ndarray]: """Load an audio file and resample it to the target sampling rate. The audio is also scaled to [-1, 1] and clipped to the maximum duration. diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py index f748396..9e34ea2 100644 --- a/bat_detect/utils/detector_utils.py +++ b/bat_detect/utils/detector_utils.py @@ -43,19 +43,19 @@ DEFAULT_MODEL_PATH = os.path.join( __all__ = [ "load_model", - "get_audio_files", - "get_default_config", - "format_results", + "list_audio_files", + "format_single_result", "save_results_to_file", "iterate_over_chunks", "process_spectrogram", "process_audio_array", "process_file", "DEFAULT_MODEL_PATH", + "DEFAULT_PROCESSING_CONFIGURATIONS", ] -def get_audio_files(ip_dir: str) -> List[str]: +def list_audio_files(ip_dir: str) -> List[str]: """Get all audio files in directory. 
Args: @@ -98,13 +98,12 @@ class ModelParameters(TypedDict): class_names: List[str] """Class names. The model is trained to detect these classes.""" - device: torch.device - def load_model( model_path: str = DEFAULT_MODEL_PATH, load_weights: bool = True, -) -> Tuple[torch.nn.Module, ModelParameters]: + device: Optional[torch.device] = None, +) -> Tuple[models.DetectionModel, ModelParameters]: """Load model from file. Args: @@ -120,7 +119,8 @@ def load_model( """ # load model - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not os.path.isfile(model_path): raise FileNotFoundError("Model file not found.") @@ -128,9 +128,8 @@ def load_model( net_params = torch.load(model_path, map_location=device) params = net_params["params"] - params["device"] = device - model: torch.nn.Module + model: models.DetectionModel if params["model_name"] == "Net2DFast": model = models.Net2DFast( @@ -162,7 +161,7 @@ def load_model( if load_weights: model.load_state_dict(net_params["state_dict"]) - model = model.to(params["device"]) + model = model.to(device) model.eval() return model, params @@ -285,30 +284,11 @@ class ResultParams(TypedDict): """Class names.""" -def format_results( - file_id: str, - time_exp: float, - duration: float, +def get_annotations_from_preds( predictions, class_names: List[str], -) -> FileAnnotations: - """Format results into the format expected by the annotation tool. - - Args: - file_id (str): File ID. - time_exp (float): Time expansion factor. - duration (float): Duration of audio file. - predictions (dict): Predictions. - - Returns: - dict: Results in the format expected by the annotation tool. - """ - # Get a single class prediction for the file - class_overall = pp.overall_class_pred( - predictions["det_probs"], - predictions["class_probs"], - ) - +) -> List[Annotation]: + """Get list of annotations from predictions.""" # Get the best class prediction probability and index for each detection class_prob_best = predictions["class_probs"].max(0) class_ind_best = predictions["class_probs"].argmax(0) @@ -344,6 +324,32 @@ def format_results( predictions["det_probs"], ) ] + return annotations + + +def format_single_result( + file_id: str, + time_exp: float, + duration: float, + predictions, + class_names: List[str], +) -> FileAnnotations: + """Format results into the format expected by the annotation tool. + + Args: + file_id (str): File ID. + time_exp (float): Time expansion factor. + duration (float): Duration of audio file. + predictions (dict): Predictions. + + Returns: + dict: Results in the format expected by the annotation tool. + """ + # Get a single class prediction for the file + class_overall = pp.overall_class_pred( + predictions["det_probs"], + predictions["class_probs"], + ) return { "id": file_id, @@ -352,7 +358,7 @@ def format_results( "notes": "Automatically generated.", "time_exp": time_exp, "duration": round(float(duration), 4), - "annotation": annotations, + "annotation": get_annotations_from_preds(predictions, class_names), "class_name": class_names[np.argmax(class_overall)], } @@ -383,7 +389,7 @@ def convert_results( dict: Dictionary with results. 
""" - pred_dict = format_results( + pred_dict = format_single_result( file_id, time_exp, duration, @@ -490,47 +496,11 @@ def save_results_to_file(results, op_path: str) -> None: json.dump(results["pred_dict"], jsonfile, indent=2, sort_keys=True) -class SpectrogramParameters(TypedDict): - """Parameters for generating spectrograms.""" - - fft_win_length: float - """Length of the FFT window in seconds.""" - - fft_overlap: float - """Percentage of overlap between FFT windows.""" - - spec_height: int - """Height of the spectrogram in pixels.""" - - resize_factor: float - """Factor to resize the spectrogram by.""" - - spec_divide_factor: int - """Factor to divide the spectrogram by.""" - - device: torch.device - """Device to store the spectrogram on.""" - - max_freq: int - """Maximum frequency to display in the spectrogram.""" - - min_freq: int - """Minimum frequency to display in the spectrogram.""" - - spec_scale: str - """Scale to use for the spectrogram.""" - - denoise_spec_avg: bool - """Whether to denoise the spectrogram by averaging.""" - - max_scale_spec: bool - """Whether to scale the spectrogram so that its max is 1.""" - - def compute_spectrogram( audio: np.ndarray, sampling_rate: int, - params: SpectrogramParameters, + params: au.SpectrogramParameters, + device: torch.device, return_np: bool = False, ) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]: """Compute a spectrogram from an audio array. @@ -578,7 +548,7 @@ def compute_spectrogram( spec, _ = au.generate_spectrogram(audio, sampling_rate, params) # convert to pytorch - spec = torch.from_numpy(spec).to(params["device"]) + spec = torch.from_numpy(spec).to(device) # add batch and channel dimensions spec = spec.unsqueeze(0).unsqueeze(0) @@ -672,9 +642,6 @@ class ProcessingConfiguration(TypedDict): scale_raw_audio: bool """Whether to scale the raw audio to be between -1 and 1.""" - device: torch.device - """Device to run the model on.""" - class_names: List[str] """Names of the classes the model can detect.""" @@ -721,33 +688,12 @@ class ProcessingConfiguration(TypedDict): """Whether to return spectrogram slices.""" -def process_spectrogram( +def _process_spectrogram( spec: torch.Tensor, samplerate: int, - model: torch.nn.Module, + model: models.DetectionModel, config: ProcessingConfiguration, -): - """Process a spectrogram with detection model. - - Will run non-maximum suppression on the output of the model. - - Parameters - ---------- - spec : torch.Tensor - - samplerate : int - - model : torch.nn.Module - Detection model. - - config : pp.NonMaximumSuppressionConfig - Parameters for non-maximum suppression. - - Returns - ------- - pred_nms : Dict[str, np.ndarray] - features : Dict[str, np.ndarray] - """ +) -> Tuple[List[Annotation], List[np.ndarray]]: # evaluate model with torch.no_grad(): outputs = model(spec, return_feats=config["cnn_features"]) @@ -781,12 +727,96 @@ def process_spectrogram( return pred_nms, features +def process_spectrogram( + spec: torch.Tensor, + samplerate: int, + model: models.DetectionModel, + config: ProcessingConfiguration, +) -> Tuple[List[Annotation], List[np.ndarray]]: + """Process a spectrogram with detection model. + + Will run non-maximum suppression on the output of the model. + + Parameters + ---------- + spec : torch.Tensor + + samplerate : int + + model : torch.nn.Module + Detection model. + + config : pp.NonMaximumSuppressionConfig + Parameters for non-maximum suppression. + + Returns + ------- + annotations : List[Annotation] + List of annotations predicted by the model. 
+ features : List[np.ndarray] + List of CNN features associated with each annotation. + Is empty if `config["cnn_features"]` is False. + """ + pred_nms, features = _process_spectrogram( + spec, + samplerate, + model, + config, + ) + + annotations = get_annotations_from_preds( + pred_nms, + config["class_names"], + ) + + return annotations, features + + +def _process_audio_array( + audio: np.ndarray, + sampling_rate: int, + model: torch.nn.Module, + config: ProcessingConfiguration, + device: torch.device, +) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]: + # load audio file and compute spectrogram + _, spec, _ = compute_spectrogram( + audio, + sampling_rate, + { + "fft_win_length": config["fft_win_length"], + "fft_overlap": config["fft_overlap"], + "spec_height": config["spec_height"], + "resize_factor": config["resize_factor"], + "spec_divide_factor": config["spec_divide_factor"], + "max_freq": config["max_freq"], + "min_freq": config["min_freq"], + "spec_scale": config["spec_scale"], + "denoise_spec_avg": config["denoise_spec_avg"], + "max_scale_spec": config["max_scale_spec"], + }, + device, + return_np=False, + ) + + # process spectrogram with model + pred_nms, features = _process_spectrogram( + spec, + sampling_rate, + model, + config, + ) + + return pred_nms, features, spec + + def process_audio_array( audio: np.ndarray, sampling_rate: int, model: torch.nn.Module, config: ProcessingConfiguration, -): + device: torch.device, +) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]: """Process a single audio array with detection model. Parameters @@ -801,47 +831,42 @@ def process_audio_array( config : ProcessingConfiguration Configuration for processing. + device : torch.device + Device to use for processing. + Returns ------- - pred_nms : Dict[str, np.ndarray] - features : Dict[str, np.ndarray] - spec_np : np.ndarray - """ - # load audio file and compute spectrogram - _, spec, spec_np = compute_spectrogram( - audio, - sampling_rate, - { - "fft_win_length": config["fft_win_length"], - "fft_overlap": config["fft_overlap"], - "spec_height": config["spec_height"], - "resize_factor": config["resize_factor"], - "spec_divide_factor": config["spec_divide_factor"], - "device": config["device"], - "max_freq": config["max_freq"], - "min_freq": config["min_freq"], - "spec_scale": config["spec_scale"], - "denoise_spec_avg": config["denoise_spec_avg"], - "max_scale_spec": config["max_scale_spec"], - }, - return_np=config["spec_features"] or config["spec_slices"], - ) + annotations : List[Annotation] + List of annotations predicted by the model. - # process spectrogram with model - pred_nms, features = process_spectrogram( - spec, + features : List[np.ndarray] + List of CNN features associated with each annotation. + + spec : torch.Tensor + Spectrogram of the audio used as input. + + """ + pred_nms, features, spec = _process_audio_array( + audio, sampling_rate, model, config, + device, ) - return pred_nms, features, spec_np + annotations = get_annotations_from_preds( + pred_nms, + config["class_names"], + ) + + return annotations, features, spec def process_file( audio_file: str, model: torch.nn.Module, config: ProcessingConfiguration, + device: torch.device, ) -> Union[RunResults, Any]: """Process a single audio file with detection model. 
@@ -872,7 +897,7 @@ def process_file( spec_slices = [] # load audio file - sampling_rate, audio_full = au.load_audio_file( + sampling_rate, audio_full = au.load_audio( audio_file, time_exp_fact=config.get("time_expansion", 1) or 1, target_samp_rate=config["target_samp_rate"], @@ -881,7 +906,7 @@ def process_file( ) # loop through larger file and split into chunks - # TODO fix so that it overlaps correctly and takes care of + # TODO: fix so that it overlaps correctly and takes care of # duplicate detections at borders for chunk_time, audio in iterate_over_chunks( audio_full, @@ -889,11 +914,12 @@ def process_file( config["chunk_size"], ): # Run detection model on chunk - pred_nms, features, spec_np = process_audio_array( + pred_nms, features, spec_np = _process_audio_array( audio, sampling_rate, model, config, + device, ) # add chunk time to start and end times @@ -965,39 +991,30 @@ def summarize_results(results, predictions, config): ) -def get_default_config(**kwargs) -> ProcessingConfiguration: - """Get default configuration for running detection model.""" - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - args: ProcessingConfiguration = { - "detection_threshold": DETECTION_THRESHOLD, - "spec_slices": False, - "chunk_size": 3, - "spec_features": False, - "cnn_features": False, - "quiet": True, - "target_samp_rate": TARGET_SAMPLERATE_HZ, - "fft_win_length": FFT_WIN_LENGTH_S, - "fft_overlap": FFT_OVERLAP, - "resize_factor": RESIZE_FACTOR, - "spec_divide_factor": SPEC_DIVIDE_FACTOR, - "spec_height": SPEC_HEIGHT, - "scale_raw_audio": SCALE_RAW_AUDIO, - "device": device, - "class_names": [], - "time_expansion": 1, - "top_n": 3, - "return_raw_preds": False, - "max_duration": None, - "nms_kernel_size": NMS_KERNEL_SIZE, - "max_freq": MAX_FREQ_HZ, - "min_freq": MIN_FREQ_HZ, - "nms_top_k_per_sec": NMS_TOP_K_PER_SEC, - "spec_scale": SPEC_SCALE, - "denoise_spec_avg": DENOISE_SPEC_AVG, - "max_scale_spec": MAX_SCALE_SPEC, - } - return { - **args, - **kwargs, - } +DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = { + "detection_threshold": DETECTION_THRESHOLD, + "spec_slices": False, + "chunk_size": 3, + "spec_features": False, + "cnn_features": False, + "quiet": True, + "target_samp_rate": TARGET_SAMPLERATE_HZ, + "fft_win_length": FFT_WIN_LENGTH_S, + "fft_overlap": FFT_OVERLAP, + "resize_factor": RESIZE_FACTOR, + "spec_divide_factor": SPEC_DIVIDE_FACTOR, + "spec_height": SPEC_HEIGHT, + "scale_raw_audio": SCALE_RAW_AUDIO, + "class_names": [], + "time_expansion": 1, + "top_n": 3, + "return_raw_preds": False, + "max_duration": None, + "nms_kernel_size": NMS_KERNEL_SIZE, + "max_freq": MAX_FREQ_HZ, + "min_freq": MIN_FREQ_HZ, + "nms_top_k_per_sec": NMS_TOP_K_PER_SEC, + "spec_scale": SPEC_SCALE, + "denoise_spec_avg": DENOISE_SPEC_AVG, + "max_scale_spec": MAX_SCALE_SPEC, +} diff --git a/scripts/gen_spec_image.py b/scripts/gen_spec_image.py index a7cb0a6..c8f8639 100644 --- a/scripts/gen_spec_image.py +++ b/scripts/gen_spec_image.py @@ -114,7 +114,7 @@ if __name__ == "__main__": # load audio and crop print("\nProcessing: " + os.path.basename(args_cmd["audio_file"])) print("\nOutput directory: " + args_cmd["op_dir"]) - sampling_rate, audio = au.load_audio_file( + sampling_rate, audio = au.load_audio( args_cmd["audio_file"], args_cmd["time_exp"], params_bd["target_samp_rate"], diff --git a/scripts/gen_spec_video.py b/scripts/gen_spec_video.py index f7185a0..2588ede 100644 --- a/scripts/gen_spec_video.py +++ b/scripts/gen_spec_video.py @@ -96,7 +96,7 @@ if __name__ == 
"__main__": # load audio file print("\nProcessing: " + os.path.basename(audio_file)) print("\nOutput directory: " + op_dir) - sampling_rate, audio = au.load_audio_file( + sampling_rate, audio = au.load_audio( audio_file, args["time_expansion_factor"], params["target_samp_rate"] ) audio = audio[ diff --git a/scripts/viz_helpers.py b/scripts/viz_helpers.py index a286037..5044b8e 100644 --- a/scripts/viz_helpers.py +++ b/scripts/viz_helpers.py @@ -72,7 +72,7 @@ def load_data( sampling_rates = [] file_names = [] for cur_file in anns: - sampling_rate, audio_orig = au.load_audio_file( + sampling_rate, audio_orig = au.load_audio( cur_file["file_path"], cur_file["time_exp"], params["target_samp_rate"], diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_bat_detect.py b/tests/test_bat_detect.py new file mode 100644 index 0000000..6440261 --- /dev/null +++ b/tests/test_bat_detect.py @@ -0,0 +1,213 @@ +"""Test bat detect module API.""" + +import os +from glob import glob + +import numpy as np +import torch +from torch import nn + +from bat_detect.api import ( + generate_spectrogram, + get_config, + list_audio_files, + load_audio, + load_model, + process_audio, + process_file, + process_spectrogram, +) + +PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio") +TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav")) + + +def test_load_model_with_default_params(): + """Test loading model with default parameters.""" + model, params = load_model() + + assert model is not None + assert isinstance(model, nn.Module) + + assert params is not None + assert isinstance(params, dict) + + assert "model_name" in params + assert "num_filters" in params + assert "emb_dim" in params + assert "ip_height" in params + assert "resize_factor" in params + assert "class_names" in params + + assert params["model_name"] == "Net2DFast" + assert params["num_filters"] == 128 + assert params["emb_dim"] == 0 + assert params["ip_height"] == 128 + assert params["resize_factor"] == 0.5 + assert len(params["class_names"]) == 17 + + +def test_list_audio_files(): + """Test listing audio files.""" + audio_files = list_audio_files(TEST_DATA_DIR) + + assert len(audio_files) == 3 + assert all(path.endswith((".wav", ".WAV")) for path in audio_files) + + +def test_load_audio(): + """Test loading audio.""" + samplerate, audio = load_audio(TEST_DATA[0]) + + assert audio is not None + assert samplerate == 256000 + assert isinstance(audio, np.ndarray) + assert audio.shape == (128000,) + + +def test_generate_spectrogram(): + """Test generating spectrogram.""" + samplerate, audio = load_audio(TEST_DATA[0]) + spectrogram = generate_spectrogram(audio, samplerate) + + assert spectrogram is not None + assert isinstance(spectrogram, torch.Tensor) + assert spectrogram.shape == (1, 1, 128, 512) + + +def test_get_default_config(): + """Test getting default configuration.""" + config = get_config() + + assert config is not None + assert isinstance(config, dict) + + assert config["target_samp_rate"] == 256000 + assert config["fft_win_length"] == 0.002 + assert config["fft_overlap"] == 0.75 + assert config["resize_factor"] == 0.5 + assert config["spec_divide_factor"] == 32 + assert config["spec_height"] == 256 + assert config["spec_scale"] == "pcen" + assert config["denoise_spec_avg"] is True + assert 
config["max_scale_spec"] is False + assert config["scale_raw_audio"] is False + assert len(config["class_names"]) == 0 + assert config["detection_threshold"] == 0.01 + assert config["time_expansion"] == 1 + assert config["top_n"] == 3 + assert config["return_raw_preds"] is False + assert config["max_duration"] is None + assert config["nms_kernel_size"] == 9 + assert config["max_freq"] == 120000 + assert config["min_freq"] == 10000 + assert config["nms_top_k_per_sec"] == 200 + assert config["quiet"] is True + assert config["chunk_size"] == 3 + assert config["cnn_features"] is False + assert config["spec_features"] is False + assert config["spec_slices"] is False + + +def test_process_file_with_model(): + """Test processing file with model.""" + model, params = load_model() + config = get_config(**params) + predictions = process_file(TEST_DATA[0], model, config=config) + + assert predictions is not None + assert isinstance(predictions, dict) + + assert "pred_dict" in predictions + assert "spec_feats" in predictions + assert "spec_feat_names" in predictions + assert "cnn_feats" in predictions + assert "cnn_feat_names" in predictions + assert "spec_slices" in predictions + + # By default will not return spectrogram features + assert predictions["spec_feats"] is None + assert predictions["spec_feat_names"] is None + assert predictions["cnn_feats"] is None + assert predictions["cnn_feat_names"] is None + assert predictions["spec_slices"] is None + + # Check that predictions are returned + assert isinstance(predictions["pred_dict"], dict) + pred_dict = predictions["pred_dict"] + assert pred_dict["id"] == os.path.basename(TEST_DATA[0]) + assert pred_dict["annotated"] is False + assert pred_dict["issues"] is False + assert pred_dict["notes"] == "Automatically generated." 
+ assert pred_dict["time_exp"] == 1 + assert pred_dict["duration"] == 0.5 + assert pred_dict["class_name"] is not None + assert len(pred_dict["annotation"]) > 0 + + +def test_process_spectrogram_with_model(): + """Test processing spectrogram with model.""" + model, params = load_model() + config = get_config(**params) + samplerate, audio = load_audio(TEST_DATA[0]) + spectrogram = generate_spectrogram(audio, samplerate) + predictions, features = process_spectrogram( + spectrogram, + samplerate, + model, + config=config, + ) + + assert predictions is not None + assert isinstance(predictions, list) + assert len(predictions) > 0 + sample_pred = predictions[0] + assert isinstance(sample_pred, dict) + assert "class" in sample_pred + assert "class_prob" in sample_pred + assert "det_prob" in sample_pred + assert "start_time" in sample_pred + assert "end_time" in sample_pred + assert "low_freq" in sample_pred + assert "high_freq" in sample_pred + + assert features is not None + assert isinstance(features, list) + # By default will not return cnn features + assert len(features) == 0 + + +def test_process_audio_with_model(): + """Test processing audio with model.""" + model, params = load_model() + config = get_config(**params) + samplerate, audio = load_audio(TEST_DATA[0]) + predictions, features, spec = process_audio( + audio, + samplerate, + model, + config=config, + ) + + assert predictions is not None + assert isinstance(predictions, list) + assert len(predictions) > 0 + sample_pred = predictions[0] + assert isinstance(sample_pred, dict) + assert "class" in sample_pred + assert "class_prob" in sample_pred + assert "det_prob" in sample_pred + assert "start_time" in sample_pred + assert "end_time" in sample_pred + assert "low_freq" in sample_pred + assert "high_freq" in sample_pred + + assert features is not None + assert isinstance(features, list) + # By default will not return cnn features + assert len(features) == 0 + + assert spec is not None + assert isinstance(spec, torch.Tensor) + assert spec.shape == (1, 1, 128, 512) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..e69de29
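The new bat_detect.api module introduced in this patch bundles the renamed helpers (load_audio, list_audio_files) with thin wrappers around detector_utils, so callers no longer assemble a ProcessingConfiguration by hand. What follows is a minimal usage sketch mirroring the calls exercised in tests/test_bat_detect.py; "path/to/recording.wav" is a placeholder path, not a file shipped with the repository:

# Usage sketch for the new bat_detect.api module (audio path is a placeholder).
from bat_detect.api import (
    generate_spectrogram,
    get_config,
    load_audio,
    load_model,
    process_audio,
    process_file,
    process_spectrogram,
)

# Load the default model and derive a processing config from its parameters.
model, params = load_model()
config = get_config(**params)

# File-level workflow: results["pred_dict"] holds the per-file annotations.
results = process_file("path/to/recording.wav", model, config=config)
for ann in results["pred_dict"]["annotation"]:
    print(ann["class"], ann["start_time"], ann["end_time"], ann["det_prob"])

# Array-level workflow: load the audio yourself, then run the model on it.
samplerate, audio = load_audio("path/to/recording.wav")
annotations, features, spec = process_audio(audio, samplerate, model, config=config)

# Spectrogram-level workflow: compute the spectrogram explicitly first.
spec = generate_spectrogram(audio, samplerate)
annotations, features = process_spectrogram(spec, samplerate, model, config=config)

Note that the device is now threaded through as an explicit argument (defaulting to CUDA when available) rather than stored in the configuration dictionaries, which is why ModelParameters, SpectrogramParameters, and ProcessingConfiguration all lose their device field in this patch.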