diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml new file mode 100644 index 0000000..bdaab28 --- /dev/null +++ b/.github/workflows/python-publish.yml @@ -0,0 +1,39 @@ +# This workflow will upload a Python Package using Twine when a release is created +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries + +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. + +name: Upload Python Package + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + deploy: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build + - name: Build package + run: python -m build + - name: Publish package + uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.gitignore b/.gitignore index 252a12b..ad423c1 100644 --- a/.gitignore +++ b/.gitignore @@ -65,7 +65,7 @@ ipython_config.py # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide -.pdm.toml +.pdm-python # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ @@ -102,7 +102,11 @@ experiments/* .virtual_documents .ipynb_checkpoints *.ipynb -!batdetect2_notebook.ipynb -# Batdetect Models [Include] +# Bump2version +.bumpversion.cfg + +# DO Include +!batdetect2_notebook.ipynb !batdetect2/models/*.pth.tar +!tests/data/*.wav diff --git a/README.md b/README.md index f62a078..e128ec6 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ pip install batdetect2 ``` Alternatively, download this code from the repository (by clicking on the green button on top right) and unzip it. -Once unziped, run this from extracted folder. +Once unzipped, run this from extracted folder. ```bash pip install . @@ -98,7 +98,7 @@ You can integrate the detections or the extracted features to your custom analys ## Training the model on your own data -Take a look at the steps outlined in fintuning readme [here](bat_detect/finetune/readme.md) for a description of how to train your own model. +Take a look at the steps outlined in finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model. ## Data and annotations diff --git a/batdetect2/__init__.py b/batdetect2/__init__.py index 1f356cc..887a342 100644 --- a/batdetect2/__init__.py +++ b/batdetect2/__init__.py @@ -1 +1 @@ -__version__ = '1.0.0' +__version__ = '1.0.7' diff --git a/batdetect2/cli.py b/batdetect2/cli.py index b5ef01a..66aada8 100644 --- a/batdetect2/cli.py +++ b/batdetect2/cli.py @@ -5,6 +5,7 @@ import click from batdetect2 import api from batdetect2.detector.parameters import DEFAULT_MODEL_PATH +from batdetect2.types import ProcessingConfiguration from batdetect2.utils.detector_utils import save_results_to_file CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -77,6 +78,7 @@ def detect( audio_dir: str, ann_dir: str, detection_threshold: float, + time_expansion_factor: int, **args, ): """Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR. @@ -103,16 +105,23 @@ def detect( **{ **params, **args, + "time_expansion": time_expansion_factor, "spec_slices": False, "chunk_size": 2, "detection_threshold": detection_threshold, } ) + if not args["quiet"]: + print_config(config) + # process files error_files = [] - for audio_file in files: + for index, audio_file in enumerate(files): try: + if not args["quiet"]: + click.echo(f"\n{index} {audio_file}") + results = api.process_file(audio_file, model, config=config) if args["save_preds_if_empty"] or ( @@ -133,5 +142,12 @@ def detect( click.echo(f" {err}") +def print_config(config: ProcessingConfiguration): + """Print the processing configuration.""" + click.echo("\nProcessing Configuration:") + click.echo(f"Time Expansion Factor: {config.get('time_expansion')}") + click.echo(f"Detection Threshold: {config.get('detection_threshold')}") + + if __name__ == "__main__": cli() diff --git a/batdetect2/detector/compute_features.py b/batdetect2/detector/compute_features.py index 368c2db..b53b0cb 100644 --- a/batdetect2/detector/compute_features.py +++ b/batdetect2/detector/compute_features.py @@ -1,22 +1,27 @@ +"""Functions to compute features from predictions.""" +from typing import Dict, Optional + import numpy as np +from batdetect2 import types +from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ + def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq): + """Convert spectrogram index to frequency in Hz.""" "" spec_ind = spec_height - spec_ind return round( (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2 ) -def extract_spec_slices(spec, pred_nms, params): - """ - Extracts spectrogram slices from spectrogram based on detected call locations. - """ +def extract_spec_slices(spec, pred_nms): + """Extract spectrogram slices from spectrogram. + The slices are extracted based on detected call locations. + """ x_pos = pred_nms["x_pos"] - y_pos = pred_nms["y_pos"] bb_width = pred_nms["bb_width"] - bb_height = pred_nms["bb_height"] slices = [] # add 20% padding either side of call @@ -35,100 +40,273 @@ def extract_spec_slices(spec, pred_nms, params): return slices -def get_feature_names(): - feature_names = [ - "duration", - "low_freq_bb", - "high_freq_bb", - "bandwidth", - "max_power_bb", - "max_power", - "max_power_first", - "max_power_second", - "call_interval", - ] - return feature_names +def compute_duration( + prediction: types.Prediction, + **_, +) -> float: + """Compute duration of call in seconds.""" + return round(prediction["end_time"] - prediction["start_time"], 5) -def get_feats(spec, pred_nms, params): +def compute_low_freq( + prediction: types.Prediction, + **_, +) -> float: + """Compute lowest frequency in call in Hz.""" + return int(prediction["low_freq"]) + + +def compute_high_freq( + prediction: types.Prediction, + **_, +) -> float: + """Compute highest frequency in call in Hz.""" + return int(prediction["high_freq"]) + + +def compute_bandwidth( + prediction: types.Prediction, + **_, +) -> float: + """Compute bandwidth of call in Hz.""" + return int(prediction["high_freq"] - prediction["low_freq"]) + + +def compute_max_power_bb( + prediction: types.Prediction, + spec: Optional[np.ndarray] = None, + min_freq: int = MIN_FREQ_HZ, + max_freq: int = MAX_FREQ_HZ, + **_, +) -> float: + """Compute frequency with maximum power in call in Hz. + + This is the frequency with the maximum power in the bounding box of the + call. """ - Extracts features from spectrogram based on detected call locations. - Condsider re-extracting spectrogram for this to get better temporal resolution. + if spec is None: + return np.nan + + x_start = max(0, prediction["x_pos"]) + x_end = min( + spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] + ) + + # y low is the lowest freq but it will have a higher value due to array + # starting at 0 at top + y_low = min(spec.shape[0] - 1, prediction["y_pos"]) + y_high = max(0, prediction["y_pos"] - prediction["bb_height"]) + + spec_bb = spec[y_high:y_low, x_start:x_end] + power_per_freq_band = np.sum(spec_bb, axis=1) + + try: + max_power_ind = np.argmax(power_per_freq_band) + except ValueError: + # If the call is too short, the bounding box might be empty. + # In this case, return NaN. + return np.nan + + return int( + convert_int_to_freq( + y_high + max_power_ind, + spec.shape[0], + min_freq, + max_freq, + ) + ) + + +def compute_max_power( + prediction: types.Prediction, + spec: Optional[np.ndarray] = None, + min_freq: int = MIN_FREQ_HZ, + max_freq: int = MAX_FREQ_HZ, + **_, +) -> float: + """Compute frequency with maximum power in during the call in Hz.""" + if spec is None: + return np.nan + + x_start = max(0, prediction["x_pos"]) + x_end = min( + spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] + ) + spec_call = spec[:, x_start:x_end] + power_per_freq_band = np.sum(spec_call, axis=1) + max_power_ind = np.argmax(power_per_freq_band) + return int( + convert_int_to_freq( + max_power_ind, + spec.shape[0], + min_freq, + max_freq, + ) + ) + + +def compute_max_power_first( + prediction: types.Prediction, + spec: Optional[np.ndarray] = None, + min_freq: int = MIN_FREQ_HZ, + max_freq: int = MAX_FREQ_HZ, + **_, +) -> float: + """Compute frequency with maximum power in first half of call in Hz.""" + if spec is None: + return np.nan + + x_start = max(0, prediction["x_pos"]) + x_end = min( + spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] + ) + spec_call = spec[:, x_start:x_end] + first_half = spec_call[:, : int(spec_call.shape[1] / 2)] + power_per_freq_band = np.sum(first_half, axis=1) + max_power_ind = np.argmax(power_per_freq_band) + return int( + convert_int_to_freq( + max_power_ind, + spec.shape[0], + min_freq, + max_freq, + ) + ) + + +def compute_max_power_second( + prediction: types.Prediction, + spec: Optional[np.ndarray] = None, + min_freq: int = MIN_FREQ_HZ, + max_freq: int = MAX_FREQ_HZ, + **_, +) -> float: + """Compute frequency with maximum power in second half of call in Hz.""" + if spec is None: + return np.nan + + x_start = max(0, prediction["x_pos"]) + x_end = min( + spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] + ) + spec_call = spec[:, x_start:x_end] + second_half = spec_call[:, int(spec_call.shape[1] / 2) :] + power_per_freq_band = np.sum(second_half, axis=1) + max_power_ind = np.argmax(power_per_freq_band) + return int( + convert_int_to_freq( + max_power_ind, + spec.shape[0], + min_freq, + max_freq, + ) + ) + + +def compute_call_interval( + prediction: types.Prediction, + previous: Optional[types.Prediction] = None, + **_, +) -> float: + """Compute time between this call and the previous call in seconds.""" + if previous is None: + return np.nan + return round(prediction["start_time"] - previous["end_time"], 5) + + +# NOTE: The order of the features in this dictionary is important. The +# features are extracted in this order and the order of the columns in the +# output csv file is determined by this order. In order to avoid breaking +# changes in the output csv file, new features should be added to the end of +# this dictionary. +FEATURES: Dict[str, types.FeatureExtractor] = { + "duration": compute_duration, + "low_freq_bb": compute_low_freq, + "high_freq_bb": compute_high_freq, + "bandwidth": compute_bandwidth, + "max_power_bb": compute_max_power_bb, + "max_power": compute_max_power, + "max_power_first": compute_max_power_first, + "max_power_second": compute_max_power_second, + "call_interval": compute_call_interval, +} + + +def get_feats( + spec: np.ndarray, + pred_nms: types.PredictionResults, + params: types.FeatureExtractionParameters, +): + """Extract features from spectrogram based on detected call locations. + + The features extracted are: + + - duration: duration of call in seconds + - low_freq: lowest frequency in call in kHz + - high_freq: highest frequency in call in kHz + - bandwidth: high_freq - low_freq + - max_power_bb: frequency with maximum power in call in kHz + - max_power: frequency with maximum power in spectrogram in kHz + - max_power_first: frequency with maximum power in first half of call in + kHz. + - max_power_second: frequency with maximum power in second half of call in + kHz. + - call_interval: time between this call and the previous call in seconds + + Consider re-extracting spectrogram for this to get better temporal + resolution. For more possible features check out: https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt + + Parameters + ---------- + spec : np.ndarray + Spectrogram from which to extract features. + + pred_nms : types.PredictionResults + Information about detected calls from which to extract features. + + params : types.FeatureExtractionParameters + Parameters for feature extraction. + + Returns + ------- + features : np.ndarray + Extracted features for each detected call. Shape is + (num_detections, num_features). """ - - x_pos = pred_nms["x_pos"] - y_pos = pred_nms["y_pos"] - bb_width = pred_nms["bb_width"] - bb_height = pred_nms["bb_height"] - - feature_names = get_feature_names() num_detections = len(pred_nms["det_probs"]) - features = ( - np.ones((num_detections, len(feature_names)), dtype=np.float32) * -1 - ) + features = np.empty((num_detections, len(FEATURES)), dtype=np.float32) + previous = None - for ff in range(num_detections): - x_start = int(np.maximum(0, x_pos[ff])) - x_end = int( - np.minimum(spec.shape[1] - 1, np.round(x_pos[ff] + bb_width[ff])) - ) - # y low is the lowest freq but it will have a higher value due to array starting at 0 at top - y_low = int(np.minimum(spec.shape[0] - 1, y_pos[ff])) - y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff]))) - spec_slice = spec[:, x_start:x_end] + for row in range(num_detections): + prediction: types.Prediction = { + "det_prob": float(pred_nms["det_probs"][row]), + "class_prob": pred_nms["class_probs"][:, row], + "start_time": float(pred_nms["start_times"][row]), + "end_time": float(pred_nms["end_times"][row]), + "low_freq": float(pred_nms["low_freqs"][row]), + "high_freq": float(pred_nms["high_freqs"][row]), + "x_pos": int(pred_nms["x_pos"][row]), + "y_pos": int(pred_nms["y_pos"][row]), + "bb_width": int(pred_nms["bb_width"][row]), + "bb_height": int(pred_nms["bb_height"][row]), + } - if spec_slice.shape[1] > 1: - features[ff, 0] = round( - pred_nms["end_times"][ff] - pred_nms["start_times"][ff], 5 - ) - features[ff, 1] = int(pred_nms["low_freqs"][ff]) - features[ff, 2] = int(pred_nms["high_freqs"][ff]) - features[ff, 3] = int( - pred_nms["high_freqs"][ff] - pred_nms["low_freqs"][ff] - ) - features[ff, 4] = int( - convert_int_to_freq( - y_high + spec_slice[y_high:y_low, :].sum(1).argmax(), - spec.shape[0], - params["min_freq"], - params["max_freq"], - ) - ) - features[ff, 5] = int( - convert_int_to_freq( - spec_slice.sum(1).argmax(), - spec.shape[0], - params["min_freq"], - params["max_freq"], - ) - ) - hlf_val = spec_slice.shape[1] // 2 - - features[ff, 6] = int( - convert_int_to_freq( - spec_slice[:, :hlf_val].sum(1).argmax(), - spec.shape[0], - params["min_freq"], - params["max_freq"], - ) - ) - features[ff, 7] = int( - convert_int_to_freq( - spec_slice[:, hlf_val:].sum(1).argmax(), - spec.shape[0], - params["min_freq"], - params["max_freq"], - ) + for col, feature in enumerate(FEATURES.values()): + features[row, col] = feature( + prediction, + previous=previous, + spec=spec, + **params, ) - if ff > 0: - features[ff, 8] = round( - pred_nms["start_times"][ff] - - pred_nms["start_times"][ff - 1], - 5, - ) + previous = prediction return features + + +def get_feature_names(): + """Get names of features in the order they are extracted.""" + return list(FEATURES.keys()) diff --git a/batdetect2/detector/post_process.py b/batdetect2/detector/post_process.py index a2ba353..b47eec6 100644 --- a/batdetect2/detector/post_process.py +++ b/batdetect2/detector/post_process.py @@ -19,13 +19,15 @@ def x_coords_to_time( ) -> float: """Convert x coordinates of spectrogram to time in seconds. - Args: + Parameters + ---------- x_pos: X position of the detection in pixels. sampling_rate: Sampling rate of the audio in Hz. fft_win_length: Length of the FFT window in seconds. fft_overlap: Overlap of the FFT windows in seconds. - Returns: + Returns + ------- Time in seconds. """ nfft = int(fft_win_length * sampling_rate) @@ -134,12 +136,12 @@ def run_nms( y_pos[num_detection, valid_inds], x_pos[num_detection, valid_inds], ].transpose(0, 1) - feat = feat.detach().numpy().astype(np.float32) + feat = feat.detach().cpu().numpy().astype(np.float32) feats.append(feat) # convert to numpy for key, value in pred.items(): - pred[key] = value.detach().numpy().astype(np.float32) + pred[key] = value.detach().cpu().numpy().astype(np.float32) preds.append(pred) # type: ignore diff --git a/batdetect2/plot.py b/batdetect2/plot.py index 1f1a343..e436dae 100644 --- a/batdetect2/plot.py +++ b/batdetect2/plot.py @@ -61,7 +61,7 @@ def spectrogram( """ # Convert to numpy array if needed if isinstance(spec, torch.Tensor): - spec = spec.numpy() + spec = spec.detach().cpu().numpy() # Remove batch and channel dimensions if present spec = spec.squeeze() @@ -265,7 +265,7 @@ def detection( # Add class label txt = " ".join([sp[:3] for sp in det["class"].split(" ")]) font_info = { - "color": "white", + "color": edgecolor, "size": 10, "weight": "bold", "alpha": rect.get_alpha(), diff --git a/batdetect2/train/train_model.py b/batdetect2/train/train_model.py index 759c2d7..e38de39 100644 --- a/batdetect2/train/train_model.py +++ b/batdetect2/train/train_model.py @@ -7,15 +7,14 @@ import numpy as np import torch from torch.optim.lr_scheduler import CosineAnnealingLR -from batdetect2.detector import models -from batdetect2.detector import parameters -from batdetect2.train import losses import batdetect2.detector.post_process as pp import batdetect2.train.audio_dataloader as adl import batdetect2.train.evaluate as evl import batdetect2.train.train_split as ts import batdetect2.train.train_utils as tu import batdetect2.utils.plot_utils as pu +from batdetect2.detector import models, parameters +from batdetect2.train import losses warnings.filterwarnings("ignore", category=UserWarning) @@ -84,7 +83,6 @@ def save_image( def loss_fun( outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq ): - # detection loss loss = params["det_loss_weight"] * det_criterion( outputs["pred_det"], gt_det @@ -108,7 +106,6 @@ def loss_fun( def train( model, epoch, data_loader, det_criterion, optimizer, scheduler, params ): - model.train() train_loss = tu.AverageMeter() @@ -119,7 +116,6 @@ def train( print("\nEpoch", epoch) for batch_idx, inputs in enumerate(data_loader): - data = inputs["spec"].to(params["device"]) gt_det = inputs["y_2d_det"].to(params["device"]) gt_size = inputs["y_2d_size"].to(params["device"]) @@ -172,7 +168,6 @@ def test(model, epoch, data_loader, det_criterion, params): with torch.no_grad(): for batch_idx, inputs in enumerate(data_loader): - data = inputs["spec"].to(params["device"]) gt_det = inputs["y_2d_det"].to(params["device"]) gt_size = inputs["y_2d_size"].to(params["device"]) @@ -279,7 +274,7 @@ def parse_gt_data(inputs): is_valid = inputs["is_valid"][ind] == 1 gt = {} for kk in keys: - gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32) + gt[kk] = inputs[kk][ind][is_valid].cpu().numpy().astype(np.float32) gt["duration"] = inputs["duration"][ind].item() gt["file_id"] = inputs["file_id"][ind].item() gt["class_id_file"] = inputs["class_id_file"][ind].item() @@ -318,8 +313,7 @@ def select_model(params): return model -if __name__ == "__main__": - +def main(): plt.close("all") params = parameters.get_params(True) @@ -501,7 +495,6 @@ if __name__ == "__main__": # # main train loop for epoch in range(0, params["num_epochs"] + 1): - train_loss = train( model, epoch, @@ -550,3 +543,7 @@ if __name__ == "__main__": # TODO: args variable does not exist # if not args["do_not_save_images"]: # save_images_batch(model, test_loader, params) + + +if __name__ == "__main__": + main() diff --git a/batdetect2/types.py b/batdetect2/types.py index 8a6437c..019564a 100644 --- a/batdetect2/types.py +++ b/batdetect2/types.py @@ -1,5 +1,5 @@ """Types used in the code base.""" -from typing import List, NamedTuple, Optional +from typing import List, NamedTuple, Optional, Union import numpy as np import torch @@ -25,10 +25,13 @@ except ImportError: __all__ = [ "Annotation", "DetectionModel", + "FeatureExtractionParameters", + "FeatureExtractor", "FileAnnotations", "ModelOutput", "ModelParameters", "NonMaximumSuppressionConfig", + "Prediction", "PredictionResults", "ProcessingConfiguration", "ResultParams", @@ -316,6 +319,40 @@ class ModelOutput(NamedTuple): """Tensor with intermediate features.""" +class Prediction(TypedDict): + """Singe prediction.""" + + det_prob: float + """Detection probability.""" + + x_pos: int + """X position of the detection in pixels.""" + + y_pos: int + """Y position of the detection in pixels.""" + + bb_width: int + """Width of the detection in pixels.""" + + bb_height: int + """Height of the detection in pixels.""" + + start_time: float + """Start time of the detection in seconds.""" + + end_time: float + """End time of the detection in seconds.""" + + low_freq: float + """Low frequency of the detection in Hz.""" + + high_freq: float + """High frequency of the detection in Hz.""" + + class_prob: np.ndarray + """Vector holding the probability of each class.""" + + class PredictionResults(TypedDict): """Results of the prediction. @@ -422,6 +459,16 @@ class NonMaximumSuppressionConfig(TypedDict): """Threshold for detection probability.""" +class FeatureExtractionParameters(TypedDict): + """Parameters that control the feature extraction function.""" + + min_freq: int + """Minimum frequency to consider in Hz.""" + + max_freq: int + """Maximum frequency to consider in Hz.""" + + class HeatmapParameters(TypedDict): """Parameters that control the heatmap generation function.""" diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index d6d2b13..8d6ca7f 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -2,6 +2,7 @@ import json import os from typing import Any, Iterator, List, Optional, Tuple, Union +import librosa import numpy as np import pandas as pd import torch @@ -66,7 +67,6 @@ def list_audio_files(ip_dir: str) -> List[str]: Raises: FileNotFoundError: Input directory not found. - """ matches = [] for root, _, filenames in os.walk(ip_dir): @@ -143,7 +143,19 @@ def load_model( def _merge_results(predictions, spec_feats, cnn_feats, spec_slices): - predictions_m = {} + predictions_m = { + "det_probs": np.array([]), + "x_pos": np.array([]), + "y_pos": np.array([]), + "bb_widths": np.array([]), + "bb_heights": np.array([]), + "start_times": np.array([]), + "end_times": np.array([]), + "low_freqs": np.array([]), + "high_freqs": np.array([]), + "class_probs": np.array([]), + } + num_preds = np.sum([len(pp["det_probs"]) for pp in predictions]) if num_preds > 0: @@ -151,10 +163,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices): predictions_m[key] = np.hstack( [pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0] ) - else: - # hack in case where no detected calls as we need some of the key - # names in dict - predictions_m = predictions[0] if len(spec_feats) > 0: spec_feats = np.vstack(spec_feats) @@ -226,11 +234,19 @@ def format_single_result( Returns: dict: Results in the format expected by the annotation tool. """ - # Get a single class prediction for the file - class_overall = pp.overall_class_pred( - predictions["det_probs"], - predictions["class_probs"], - ) + try: + # Get a single class prediction for the file + class_overall = pp.overall_class_pred( + predictions["det_probs"], + predictions["class_probs"], + ) + class_name = class_names[np.argmax(class_overall)] + annotations = get_annotations_from_preds(predictions, class_names) + except (np.AxisError, ValueError): + # No detections + class_overall = np.zeros(len(class_names)) + class_name = "None" + annotations = [] return { "id": file_id, @@ -239,8 +255,8 @@ def format_single_result( "notes": "Automatically generated.", "time_exp": time_exp, "duration": round(float(duration), 4), - "annotation": get_annotations_from_preds(predictions, class_names), - "class_name": class_names[np.argmax(class_overall)], + "annotation": annotations, + "class_name": class_name, } @@ -253,6 +269,7 @@ def convert_results( spec_feats, cnn_feats, spec_slices, + nyquist_freq: Optional[float] = None, ) -> RunResults: """Convert results to dictionary as expected by the annotation tool. @@ -268,8 +285,8 @@ def convert_results( Returns: dict: Dictionary with results. - """ + pred_dict = format_single_result( file_id, time_exp, @@ -278,6 +295,14 @@ def convert_results( params["class_names"], ) + # Remove high frequency detections + if nyquist_freq is not None: + pred_dict["annotation"] = [ + pred + for pred in pred_dict["annotation"] + if pred["high_freq"] <= nyquist_freq + ] + # combine into final results dictionary results: RunResults = { "pred_dict": pred_dict, @@ -310,7 +335,6 @@ def save_results_to_file(results, op_path: str) -> None: Args: results (dict): Results. op_path (str): Output path. - """ # make directory if it does not exist if not os.path.isdir(os.path.dirname(op_path)): @@ -472,7 +496,6 @@ def iterate_over_chunks( chunk_start : float Start time of chunk in seconds. chunk : np.ndarray - """ nsamples = audio.shape[0] duration_full = nsamples / samplerate @@ -678,7 +701,6 @@ def process_audio_array( The array is of shape (num_detections, num_features). spec : torch.Tensor Spectrogram of the audio used as input. - """ pred_nms, features, spec = _process_audio_array( audio, @@ -730,6 +752,10 @@ def process_file( cnn_feats = [] spec_slices = [] + # Get original sampling rate + file_samp_rate = librosa.get_samplerate(audio_file) + orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1 + # load audio file sampling_rate, audio_full = au.load_audio( audio_file, @@ -757,7 +783,7 @@ def process_file( ) # convert to numpy - spec_np = spec.detach().cpu().numpy() + spec_np = spec.detach().cpu().numpy().squeeze() # add chunk time to start and end times pred_nms["start_times"] += chunk_time @@ -777,9 +803,7 @@ def process_file( if config["spec_slices"]: # FIX: This is not currently working. Returns empty slices - spec_slices.extend( - feats.extract_spec_slices(spec_np, pred_nms, config) - ) + spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms)) # Merge results from chunks predictions, spec_feats, cnn_feats, spec_slices = _merge_results( @@ -799,6 +823,7 @@ def process_file( spec_feats=spec_feats, cnn_feats=cnn_feats, spec_slices=spec_slices, + nyquist_freq=orig_samp_rate / 2, ) # summarize results diff --git a/pdm.lock b/pdm.lock index 68480c5..e625f95 100644 --- a/pdm.lock +++ b/pdm.lock @@ -453,7 +453,7 @@ summary = "Backport of pathlib-compatible object wrapper for zip files" [metadata] lock_version = "4.1" -content_hash = "sha256:2401b930c14b3b7e107372f0103cccebff74691b6bcd54148d832ce847df5673" +content_hash = "sha256:667d4d2891fb85565cb04d84d0970eaac799bf272e3c4d7e4e6fea0b33c241fb" [metadata.files] "appdirs 1.4.4" = [ diff --git a/pyproject.toml b/pyproject.toml index 2a60909..42cdebe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ dev = [ [project] name = "batdetect2" -version = "1.0.0" +version = "1.0.7" description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings." authors = [ { "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" }, @@ -19,7 +19,7 @@ dependencies = [ "pandas", "scikit-learn", "scipy", - "torch<2", + "torch>=1.13.1,<2", "torchaudio", "torchvision", "click", @@ -53,10 +53,10 @@ requires = ["pdm-pep517>=1.0.0"] build-backend = "pdm.pep517.api" [project.scripts] -batdetect2 = "bat_detect.cli:cli" +batdetect2 = "batdetect2.cli:cli" [tool.black] -line-length = 80 +line-length = 79 [[tool.mypy.overrides]] module = [ diff --git a/tests/data/20230322_172000_selec2.wav b/tests/data/20230322_172000_selec2.wav new file mode 100644 index 0000000..2a2c1c5 Binary files /dev/null and b/tests/data/20230322_172000_selec2.wav differ diff --git a/tests/test_api.py b/tests/test_api.py index 942a1f1..d28c733 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,11 +1,14 @@ """Test bat detect module API.""" +from pathlib import Path + import os from glob import glob import numpy as np import torch from torch import nn +import soundfile as sf from batdetect2 import api @@ -262,3 +265,20 @@ def test_process_file_with_spec_slices(): assert "spec_slices" in results assert isinstance(results["spec_slices"], list) assert len(results["spec_slices"]) == len(detections) + + + +def test_process_file_with_empty_predictions_does_not_fail( + tmp_path: Path, +): + """Test process file with empty predictions does not fail.""" + # Create empty file + empty_file = tmp_path / "empty.wav" + empty_wav = np.zeros((0, 1), dtype=np.float32) + sf.write(empty_file, empty_wav, 256000) + + # Process file + results = api.process_file(str(empty_file)) + + assert results is not None + assert len(results["pred_dict"]["annotation"]) == 0 diff --git a/tests/test_cli.py b/tests/test_cli.py index ffad17e..4038533 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,7 @@ """Test the command line interface.""" +from pathlib import Path from click.testing import CliRunner +import pandas as pd from batdetect2.cli import cli @@ -42,3 +44,67 @@ def test_cli_detect_command_on_test_audio(tmp_path): assert results_dir.exists() assert len(list(results_dir.glob("*.csv"))) == 3 assert len(list(results_dir.glob("*.json"))) == 3 + + +def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path): + """Test the detect command with a non-trivial time expansion factor.""" + results_dir = tmp_path / "results" + + # Remove results dir if it exists + if results_dir.exists(): + results_dir.rmdir() + + runner = CliRunner() + result = runner.invoke( + cli, + [ + "detect", + "example_data/audio", + str(results_dir), + "0.3", + "--time_expansion_factor", + "10", + ], + ) + + assert result.exit_code == 0 + assert 'Time Expansion Factor: 10' in result.stdout + + + +def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path): + """Test the detect command with the spec feature flag.""" + results_dir = tmp_path / "results" + + # Remove results dir if it exists + if results_dir.exists(): + results_dir.rmdir() + + runner = CliRunner() + result = runner.invoke( + cli, + [ + "detect", + "example_data/audio", + str(results_dir), + "0.3", + "--spec_features", + ], + ) + assert result.exit_code == 0 + assert results_dir.exists() + + + csv_files = [path.name for path in results_dir.glob("*.csv")] + + expected_files = [ + "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv", + "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv", + "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv" + ] + + for expected_file in expected_files: + assert expected_file in csv_files + + df = pd.read_csv(results_dir / expected_file) + assert not (df.duration == -1).any() diff --git a/tests/test_detections.py b/tests/test_detections.py new file mode 100644 index 0000000..7f17012 --- /dev/null +++ b/tests/test_detections.py @@ -0,0 +1,23 @@ +"""Test suite to ensure that model detections are not incorrect.""" + +import os + +from batdetect2 import api + +DATA_DIR = os.path.join(os.path.dirname(__file__), "data") + + +def test_no_detections_above_nyquist(): + """Test that no detections are made above the nyquist frequency.""" + # Recording donated by @@kdarras + path = os.path.join(DATA_DIR, "20230322_172000_selec2.wav") + + # This recording has a sampling rate of 192 kHz + nyquist = 192_000 / 2 + + output = api.process_file(path) + predictions = output["pred_dict"] + assert len(predictions["annotation"]) != 0 + assert all( + pred["high_freq"] < nyquist for pred in predictions["annotation"] + ) diff --git a/tests/test_features.py b/tests/test_features.py new file mode 100644 index 0000000..1271fda --- /dev/null +++ b/tests/test_features.py @@ -0,0 +1,291 @@ +"""Test suite for feature extraction functions.""" + +import logging + +import librosa +import numpy as np +import pytest + +import batdetect2.detector.compute_features as feats +from batdetect2 import api, types +from batdetect2.utils import audio_utils as au + +numba_logger = logging.getLogger("numba") +numba_logger.setLevel(logging.WARNING) + + +def index_to_freq( + index: int, + spec_height: int, + min_freq: int, + max_freq: int, +) -> float: + """Convert spectrogram index to frequency in Hz.""" + index = spec_height - index + return round( + (index / float(spec_height)) * (max_freq - min_freq) + min_freq, 2 + ) + + +def index_to_time( + index: int, + spec_width: int, + spec_duration: float, +) -> float: + """Convert spectrogram index to time in seconds.""" + return round((index / float(spec_width)) * spec_duration, 2) + + +def test_get_feats_function_with_empty_spectrogram(): + """Test get_feats function with empty spectrogram. + + This tests that the overall flow of the function works, even if the + spectrogram is empty. + """ + spec_duration = 3 + spec_width = 100 + spec_height = 100 + min_freq = 10_000 + max_freq = 120_000 + spectrogram = np.zeros((spec_height, spec_width)) + + x_pos = 20 + y_pos = 80 + bb_width = 20 + bb_height = 20 + + start_time = index_to_time(x_pos, spec_width, spec_duration) + end_time = index_to_time(x_pos + bb_width, spec_width, spec_duration) + low_freq = index_to_freq(y_pos, spec_height, min_freq, max_freq) + high_freq = index_to_freq( + y_pos - bb_height, spec_height, min_freq, max_freq + ) + + pred_nms: types.PredictionResults = { + "det_probs": np.array([1]), + "class_probs": np.array([[1]]), + "x_pos": np.array([x_pos]), + "y_pos": np.array([y_pos]), + "bb_width": np.array([bb_width]), + "bb_height": np.array([bb_height]), + "start_times": np.array([start_time]), + "end_times": np.array([end_time]), + "low_freqs": np.array([low_freq]), + "high_freqs": np.array([high_freq]), + } + + params: types.FeatureExtractionParameters = { + "min_freq": min_freq, + "max_freq": max_freq, + } + + features = feats.get_feats(spectrogram, pred_nms, params) + assert low_freq < high_freq + assert isinstance(features, np.ndarray) + assert features.shape == (len(pred_nms["det_probs"]), 9) + assert np.isclose( + features[0], + np.array( + [ + end_time - start_time, + low_freq, + high_freq, + high_freq - low_freq, + high_freq, + max_freq, + max_freq, + max_freq, + np.nan, + ] + ), + equal_nan=True, + ).all() + + +@pytest.mark.parametrize( + "max_power", + [ + 30_000, + 31_000, + 32_000, + 33_000, + 34_000, + 35_000, + 36_000, + 37_000, + 38_000, + 39_000, + 40_000, + ], +) +def test_compute_max_power_bb(max_power: int): + """Test compute_max_power_bb function.""" + duration = 1 + samplerate = 256_000 + min_freq = 0 + max_freq = 128_000 + + start_time = 0.3 + end_time = 0.6 + low_freq = 30_000 + high_freq = 40_000 + + audio = np.zeros((int(duration * samplerate),)) + + # Add a signal during the time and frequency range of interest + audio[ + int(start_time * samplerate) : int(end_time * samplerate) + ] = 0.5 * librosa.tone( + max_power, sr=samplerate, duration=end_time - start_time + ) + + # Add a more powerful signal outside frequency range of interest + audio[ + int(start_time * samplerate) : int(end_time * samplerate) + ] += 2 * librosa.tone( + 80_000, sr=samplerate, duration=end_time - start_time + ) + + params = api.get_config( + min_freq=min_freq, + max_freq=max_freq, + target_samp_rate=samplerate, + ) + + spec, _ = au.generate_spectrogram( + audio, + samplerate, + params, + ) + + x_start = int( + au.time_to_x_coords( + start_time, + samplerate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + + x_end = int( + au.time_to_x_coords( + end_time, + samplerate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + + num_freq_bins = spec.shape[0] + y_low = num_freq_bins - int(num_freq_bins * low_freq / max_freq) + y_high = num_freq_bins - int(num_freq_bins * high_freq / max_freq) + + prediction: types.Prediction = { + "det_prob": 1, + "class_prob": np.ones((1,)), + "x_pos": x_start, + "y_pos": int(y_low), + "bb_width": int(x_end - x_start), + "bb_height": int(y_low - y_high), + "start_time": start_time, + "end_time": end_time, + "low_freq": low_freq, + "high_freq": high_freq, + } + + print(prediction) + + max_power_bb = feats.compute_max_power_bb( + prediction, + spec, + min_freq=min_freq, + max_freq=max_freq, + ) + + assert abs(max_power_bb - max_power) <= 500 + + +def test_compute_max_power(): + """Test compute_max_power_bb function.""" + duration = 3 + samplerate = 16_000 + min_freq = 0 + max_freq = 8_000 + + start_time = 1 + end_time = 2 + low_freq = 3_000 + high_freq = 4_000 + max_power = 5_000 + + audio = np.zeros((int(duration * samplerate),)) + + # Add a signal during the time and frequency range of interest + audio[ + int(start_time * samplerate) : int(end_time * samplerate) + ] = 0.5 * librosa.tone( + 3_500, sr=samplerate, duration=end_time - start_time + ) + + # Add a more powerful signal outside frequency range of interest + audio[ + int(start_time * samplerate) : int(end_time * samplerate) + ] += 2 * librosa.tone( + max_power, sr=samplerate, duration=end_time - start_time + ) + + params = api.get_config( + min_freq=min_freq, + max_freq=max_freq, + target_samp_rate=samplerate, + ) + + spec, _ = au.generate_spectrogram( + audio, + samplerate, + params, + ) + + x_start = int( + au.time_to_x_coords( + start_time, + samplerate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + + x_end = int( + au.time_to_x_coords( + end_time, + samplerate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + + num_freq_bins = spec.shape[0] + y_low = int(num_freq_bins * low_freq / max_freq) + y_high = int(num_freq_bins * high_freq / max_freq) + + prediction: types.Prediction = { + "det_prob": 1, + "class_prob": np.ones((1,)), + "x_pos": x_start, + "y_pos": int(y_high), + "bb_width": int(x_end - x_start), + "bb_height": int(y_high - y_low), + "start_time": start_time, + "end_time": end_time, + "low_freq": low_freq, + "high_freq": high_freq, + } + + computed_max_power = feats.compute_max_power( + prediction, + spec, + min_freq=min_freq, + max_freq=max_freq, + ) + + assert abs(computed_max_power - max_power) < 100