diff --git a/bat_detect/api.py b/bat_detect/api.py index b09d3b4..bf44670 100644 --- a/bat_detect/api.py +++ b/bat_detect/api.py @@ -38,7 +38,7 @@ def get_config(**kwargs) -> ProcessingConfiguration: Can be used to override default parameters by passing keyword arguments. """ - return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} + return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore def load_audio( diff --git a/bat_detect/cli.py b/bat_detect/cli.py index 9293745..c9c34db 100644 --- a/bat_detect/cli.py +++ b/bat_detect/cli.py @@ -1,136 +1,141 @@ -"""Main script for running BatDetect2 on audio files. - -Example usage: - python command.py /path/to/audio/ /path/to/ann/ 0.1 - -""" -import argparse +"""BatDetect2 command line interface.""" import os +import warnings -from bat_detect import api +warnings.filterwarnings("ignore", category=UserWarning) + +import click # noqa: E402 + +from bat_detect import api # noqa: E402 +from bat_detect.detector.parameters import DEFAULT_MODEL_PATH # noqa: E402 +from bat_detect.utils.detector_utils import save_results_to_file # noqa: E402 CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -def parse_args(): - info_str = ( - "\nBatDetect2 - Detection and Classification\n" - + " Assumes audio files are mono, not stereo.\n" - + ' Spaces in the input paths will throw an error. Wrap in quotes "".\n' - + " Input files should be short in duration e.g. < 30 seconds.\n" +INFO_STR = """ +BatDetect2 - Detection and Classification + Assumes audio files are mono, not stereo. + Spaces in the input paths will throw an error. Wrap in quotes. + Input files should be short in duration e.g. < 30 seconds. +""" + + +@click.group() +def cli(): + """BatDetect2 - Bat Call Detection and Classification.""" + click.echo(INFO_STR) + + +@cli.command() +@click.argument( + "audio_dir", + type=click.Path(exists=True), +) +@click.argument( + "ann_dir", + type=click.Path(exists=False), +) +@click.argument( + "detection_threshold", + type=float, +) +@click.option( + "--cnn_features", + is_flag=True, + default=False, + help="Extracts CNN call features", +) +@click.option( + "--spec_features", + is_flag=True, + default=False, + help="Extracts low level call features", +) +@click.option( + "--time_expansion_factor", + type=int, + default=1, + help="The time expansion factor used for all files (default is 1)", +) +@click.option( + "--quiet", + is_flag=True, + default=False, + help="Minimize output printing", +) +@click.option( + "--save_preds_if_empty", + is_flag=True, + default=False, + help="Save empty annotation file if no detections made.", +) +@click.option( + "--model_path", + type=str, + default=DEFAULT_MODEL_PATH, + help="Path to trained BatDetect2 model", +) +def detect( + audio_dir: str, + ann_dir: str, + detection_threshold: float, + **args, +): + """Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR. + + DETECTION_THRESHOLD is the detection threshold. All predictions with a + score below this threshold will be discarded. Values between 0 and 1. + + Assumes audio files are mono, not stereo. + + Spaces in the input paths will throw an error. Wrap in quotes. + + Input files should be short in duration e.g. < 30 seconds. + """ + click.echo(f"Loading model: {args['model_path']}") + model, params = api.load_model(args["model_path"]) + + click.echo(f"\nInput directory: {audio_dir}") + files = api.list_audio_files(audio_dir) + + click.echo(f"Number of audio files: {len(files)}") + click.echo(f"\nSaving results to: {ann_dir}") + + config = api.get_config( + **{ + **params, + **args, + "spec_slices": False, + "chunk_size": 2, + "detection_threshold": detection_threshold, + } ) - print(info_str) - parser = argparse.ArgumentParser() - parser.add_argument("audio_dir", type=str, help="Input directory for audio") - parser.add_argument( - "ann_dir", - type=str, - help="Output directory for where the predictions will be stored", - ) - parser.add_argument( - "detection_threshold", - type=float, - help="Cut-off probability for detector e.g. 0.1", - ) - parser.add_argument( - "--cnn_features", - action="store_true", - default=False, - dest="cnn_features", - help="Extracts CNN call features", - ) - parser.add_argument( - "--spec_features", - action="store_true", - default=False, - dest="spec_features", - help="Extracts low level call features", - ) - parser.add_argument( - "--time_expansion_factor", - type=int, - default=1, - dest="time_expansion_factor", - help="The time expansion factor used for all files (default is 1)", - ) - parser.add_argument( - "--quiet", - action="store_true", - default=False, - dest="quiet", - help="Minimize output printing", - ) - parser.add_argument( - "--save_preds_if_empty", - action="store_true", - default=False, - dest="save_preds_if_empty", - help="Save empty annotation file if no detections made.", - ) - parser.add_argument( - "--model_path", - type=str, - default=du.DEFAULT_MODEL_PATH, - help="Path to trained BatDetect2 model", - ) - args = vars(parser.parse_args()) - - args["spec_slices"] = False # used for visualization - # if files greater than this amount (seconds) they will be broken down into small chunks - args["chunk_size"] = 2 - args["ann_dir"] = os.path.join(args["ann_dir"], "") - args["quiet"] = True - - return args - - -def main(): - args = parse_args() - - print("Loading model: " + args["model_path"]) - model, params = du.load_model(args["model_path"]) - - print("\nInput directory: " + args["audio_dir"]) - files = du.list_audio_files(args["audio_dir"]) - - print(f"Number of audio files: {len(files)}") - print("\nSaving results to: " + args["ann_dir"]) - - default_config = du.get_default_config() - - # set up run config - run_config = { - **default_config, - **args, - **params, - } - # process files error_files = [] for audio_file in files: try: - results = du.process_file(audio_file, model, run_config) + results = api.process_file(audio_file, model, config=config) if args["save_preds_if_empty"] or ( len(results["pred_dict"]["annotation"]) > 0 ): - results_path = audio_file.replace( - args["audio_dir"], args["ann_dir"] - ) - du.save_results_to_file(results, results_path) + results_path = audio_file.replace(audio_dir, ann_dir) + save_results_to_file(results, results_path) except (RuntimeError, ValueError, LookupError) as err: + # TODO: Check what other errors can be thrown error_files.append(audio_file) - print(f"Error processing file!: {err}") + click.echo(f"Error processing file!: {err}") raise err - print("\nResults saved to: " + args["ann_dir"]) + click.echo(f"\nResults saved to: {ann_dir}") if len(error_files) > 0: - print("\nUnable to process the follow files:") + click.echo("\nUnable to process the follow files:") for err in error_files: - print(" " + err) + click.echo(f" {err}") if __name__ == "__main__": - main() + cli() diff --git a/bat_detect/detector/post_process.py b/bat_detect/detector/post_process.py index 763aeca..ae56f80 100644 --- a/bat_detect/detector/post_process.py +++ b/bat_detect/detector/post_process.py @@ -1,16 +1,12 @@ """Post-processing of the output of the model.""" -from typing import List, Optional, Tuple +from typing import List, Tuple, Union import numpy as np import torch from torch import nn from bat_detect.detector.models import ModelOutput - -try: - from typing import TypedDict -except ImportError: - from typing_extensions import TypedDict +from bat_detect.types import NonMaximumSuppressionConfig, PredictionResults np.seterr(divide="ignore", invalid="ignore") @@ -42,72 +38,6 @@ def overall_class_pred(det_prob, class_prob): return weighted_pred / weighted_pred.sum() -class NonMaximumSuppressionConfig(TypedDict): - """Configuration for non-maximum suppression.""" - - nms_kernel_size: int - """Size of the kernel for non-maximum suppression.""" - - max_freq: int - """Maximum frequency to consider in Hz.""" - - min_freq: int - """Minimum frequency to consider in Hz.""" - - fft_win_length: float - """Length of the FFT window in seconds.""" - - fft_overlap: float - """Overlap of the FFT windows in seconds.""" - - resize_factor: float - """Factor by which the input was resized.""" - - nms_top_k_per_sec: float - """Number of top detections to keep per second.""" - - detection_threshold: float - """Threshold for detection probability.""" - - -class PredictionResults(TypedDict): - """Results of the prediction. - - Each key is a list of length `num_detections` containing the - corresponding values for each detection. - """ - - det_probs: np.ndarray - """Detection probabilities.""" - - x_pos: np.ndarray - """X position of the detection in pixels.""" - - y_pos: np.ndarray - """Y position of the detection in pixels.""" - - bb_width: np.ndarray - """Width of the detection in pixels.""" - - bb_height: np.ndarray - """Height of the detection in pixels.""" - - start_times: np.ndarray - """Start times of the detections in seconds.""" - - end_times: np.ndarray - """End times of the detections in seconds.""" - - low_freqs: np.ndarray - """Low frequencies of the detections in Hz.""" - - high_freqs: np.ndarray - """High frequencies of the detections in Hz.""" - - class_probs: Optional[np.ndarray] - """Class probabilities.""" - - def run_nms( outputs: ModelOutput, params: NonMaximumSuppressionConfig, @@ -128,8 +58,8 @@ def run_nms( -2 ] - # NOTE: there will be small differences depending on which sampling rate is chosen - # as we are choosing the same sampling rate for the entire batch + # NOTE: there will be small differences depending on which sampling rate + # is chosen as we are choosing the same sampling rate for the entire batch duration = x_coords_to_time( pred_det.shape[-1], int(sampling_rate[0].item()), @@ -211,16 +141,21 @@ def run_nms( for key, value in pred.items(): pred[key] = value.cpu().numpy().astype(np.float32) - preds.append(pred) + preds.append(pred) # type: ignore return preds, feats -def non_max_suppression(heat, kernel_size): +def non_max_suppression( + heat: torch.Tensor, + kernel_size: Union[int, Tuple[int, int]], +): # kernel can be an int or list/tuple if isinstance(kernel_size, int): kernel_size_h = kernel_size kernel_size_w = kernel_size + else: + kernel_size_h, kernel_size_w = kernel_size pad_h = (kernel_size_h - 1) // 2 pad_w = (kernel_size_w - 1) // 2 diff --git a/bat_detect/types.py b/bat_detect/types.py index e961a28..c4b6297 100644 --- a/bat_detect/types.py +++ b/bat_detect/types.py @@ -1,5 +1,5 @@ """Types used in the code base.""" -from typing import List, Optional, NamedTuple +from typing import List, NamedTuple, Optional import numpy as np import torch @@ -16,6 +16,27 @@ except ImportError: from typing_extensions import Protocol +try: + from typing import NotRequired +except ImportError: + from typing_extensions import NotRequired + + +__all__ = [ + "Annotation", + "DetectionModel", + "FileAnnotations", + "ModelOutput", + "ModelParameters", + "NonMaximumSuppressionConfig", + "PredictionResults", + "ProcessingConfiguration", + "ResultParams", + "RunResults", + "SpectrogramParameters", +] + + class SpectrogramParameters(TypedDict): """Parameters for generating spectrograms.""" @@ -144,19 +165,19 @@ class RunResults(TypedDict): pred_dict: FileAnnotations """Predictions in the format expected by the annotation tool.""" - spec_feats: Optional[List[np.ndarray]] + spec_feats: NotRequired[List[np.ndarray]] """Spectrogram features.""" - spec_feat_names: Optional[List[str]] + spec_feat_names: NotRequired[List[str]] """Spectrogram feature names.""" - cnn_feats: Optional[List[np.ndarray]] + cnn_feats: NotRequired[List[np.ndarray]] """CNN features.""" - cnn_feat_names: Optional[List[str]] + cnn_feat_names: NotRequired[List[str]] """CNN feature names.""" - spec_slices: Optional[List[np.ndarray]] + spec_slices: NotRequired[List[np.ndarray]] """Spectrogram slices.""" @@ -166,6 +187,15 @@ class ResultParams(TypedDict): class_names: List[str] """Class names.""" + spec_features: bool + """Whether to return spectrogram features.""" + + cnn_features: bool + """Whether to return CNN features.""" + + spec_slices: bool + """Whether to return spectrogram slices.""" + class ProcessingConfiguration(TypedDict): """Parameters for processing audio files.""" @@ -266,6 +296,44 @@ class ModelOutput(NamedTuple): """Tensor with intermediate features.""" +class PredictionResults(TypedDict): + """Results of the prediction. + + Each key is a list of length `num_detections` containing the + corresponding values for each detection. + """ + + det_probs: np.ndarray + """Detection probabilities.""" + + x_pos: np.ndarray + """X position of the detection in pixels.""" + + y_pos: np.ndarray + """Y position of the detection in pixels.""" + + bb_width: np.ndarray + """Width of the detection in pixels.""" + + bb_height: np.ndarray + """Height of the detection in pixels.""" + + start_times: np.ndarray + """Start times of the detections in seconds.""" + + end_times: np.ndarray + """End times of the detections in seconds.""" + + low_freqs: np.ndarray + """Low frequencies of the detections in Hz.""" + + high_freqs: np.ndarray + """High frequencies of the detections in Hz.""" + + class_probs: Optional[np.ndarray] + """Class probabilities.""" + + class DetectionModel(Protocol): """Protocol for detection models. @@ -286,19 +354,49 @@ class DetectionModel(Protocol): resize_factor: float """Factor by which the input is resized.""" - ip_height: int + ip_height_rs: int """Height of the input image.""" def forward( self, - spec: torch.Tensor, + ip: torch.Tensor, return_feats: bool = False, ) -> ModelOutput: """Forward pass of the model.""" + ... def __call__( self, - spec: torch.Tensor, + ip: torch.Tensor, return_feats: bool = False, ) -> ModelOutput: """Forward pass of the model.""" + ... + + +class NonMaximumSuppressionConfig(TypedDict): + """Configuration for non-maximum suppression.""" + + nms_kernel_size: int + """Size of the kernel for non-maximum suppression.""" + + max_freq: int + """Maximum frequency to consider in Hz.""" + + min_freq: int + """Minimum frequency to consider in Hz.""" + + fft_win_length: float + """Length of the FFT window in seconds.""" + + fft_overlap: float + """Overlap of the FFT windows in seconds.""" + + resize_factor: float + """Factor by which the input was resized.""" + + nms_top_k_per_sec: float + """Number of top detections to keep per second.""" + + detection_threshold: float + """Threshold for detection probability.""" diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py index 3ebd69c..4194ddf 100644 --- a/bat_detect/utils/detector_utils.py +++ b/bat_detect/utils/detector_utils.py @@ -14,13 +14,14 @@ from bat_detect.detector import models from bat_detect.detector.parameters import DEFAULT_MODEL_PATH from bat_detect.types import ( Annotation, + DetectionModel, FileAnnotations, ModelParameters, + PredictionResults, ProcessingConfiguration, - SpectrogramParameters, ResultParams, RunResults, - DetectionModel + SpectrogramParameters, ) __all__ = [ @@ -73,10 +74,8 @@ def load_model( Raises: FileNotFoundError: Model file not found. - ValueError: Unknown model. + ValueError: Unknown model name. """ - - # load model if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -135,7 +134,8 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices): [pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0] ) else: - # hack in case where no detected calls as we need some of the key names in dict + # hack in case where no detected calls as we need some of the key + # names in dict predictions_m = predictions[0] if len(spec_feats) > 0: @@ -263,27 +263,22 @@ def convert_results( # combine into final results dictionary results: RunResults = { "pred_dict": pred_dict, - "spec_feats": None, - "spec_feat_names": None, - "cnn_feats": None, - "cnn_feat_names": None, - "spec_slices": None, } # add spectrogram features if they exist - if len(spec_feats) > 0: + if len(spec_feats) > 0 and params["spec_features"]: results["spec_feats"] = spec_feats results["spec_feat_names"] = feats.get_feature_names() # add CNN features if they exist - if len(cnn_feats) > 0: + if len(cnn_feats) > 0 and params["cnn_features"]: results["cnn_feats"] = cnn_feats results["cnn_feat_names"] = [ str(ii) for ii in range(cnn_feats.shape[1]) ] # add spectrogram slices if they exist - if len(spec_slices) > 0: + if len(spec_slices) > 0 and params["spec_slices"]: results["spec_slices"] = spec_slices return results @@ -292,6 +287,8 @@ def convert_results( def save_results_to_file(results, op_path: str) -> None: """Save results to file. + Will create the output directory if it does not exist. + Args: results (dict): Results. op_path (str): Output path. @@ -476,7 +473,7 @@ def _process_spectrogram( samplerate: int, model: DetectionModel, config: ProcessingConfiguration, -) -> Tuple[List[Annotation], List[np.ndarray]]: +) -> Tuple[PredictionResults, List[np.ndarray]]: # evaluate model with torch.no_grad(): outputs = model(spec) @@ -493,7 +490,6 @@ def _process_spectrogram( "resize_factor": config["resize_factor"], "nms_top_k_per_sec": config["nms_top_k_per_sec"], "detection_threshold": config["detection_threshold"], - "max_scale_spec": config["max_scale_spec"], }, np.array([float(samplerate)]), ) @@ -561,7 +557,7 @@ def _process_audio_array( model: DetectionModel, config: ProcessingConfiguration, device: torch.device, -) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]: +) -> Tuple[PredictionResults, List[np.ndarray], torch.Tensor]: # load audio file and compute spectrogram _, spec, _ = compute_spectrogram( audio, @@ -712,17 +708,19 @@ def process_file( predictions.append(pred_nms) # extract features - if there are any calls detected - if pred_nms["det_probs"].shape[0] > 0: - if config["spec_features"]: - spec_feats.append(feats.get_feats(spec_np, pred_nms, config)) + if pred_nms["det_probs"].shape[0] == 0: + continue - if config["cnn_features"]: - cnn_feats.append(features[0]) + if config["spec_features"]: + spec_feats.append(feats.get_feats(spec_np, pred_nms, config)) - if config["spec_slices"]: - spec_slices.extend( - feats.extract_spec_slices(spec_np, pred_nms, config) - ) + if config["cnn_features"]: + cnn_feats.append(features[0]) + + if config["spec_slices"]: + spec_slices.extend( + feats.extract_spec_slices(spec_np, pred_nms, config) + ) # Merge results from chunks predictions, spec_feats, cnn_feats, spec_slices = _merge_results( diff --git a/pyproject.toml b/pyproject.toml index 3be8e8c..d560fc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "torch", "torchaudio", "torchvision", + "click", ] requires-python = ">=3.8" readme = "README.md" @@ -56,3 +57,8 @@ module = [ "pandas", ] ignore_missing_imports = true + +[tool.pylsp-mypy] +enabled = false +live_mode = true +strict = true diff --git a/requirements.txt b/requirements.txt index 5bb8e16..cac4479 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ scipy==1.9.3 torch==1.13.0 torchaudio==0.13.0 torchvision==0.14.0 +click diff --git a/tests/test_api.py b/tests/test_api.py index 902c2df..1ee3231 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -120,18 +120,13 @@ def test_process_file_with_model(): assert isinstance(predictions, dict) assert "pred_dict" in predictions - assert "spec_feats" in predictions - assert "spec_feat_names" in predictions - assert "cnn_feats" in predictions - assert "cnn_feat_names" in predictions - assert "spec_slices" in predictions - # By default will not return spectrogram features - assert predictions["spec_feats"] is None - assert predictions["spec_feat_names"] is None - assert predictions["cnn_feats"] is None - assert predictions["cnn_feat_names"] is None - assert predictions["spec_slices"] is None + # By default will not return other features + assert "spec_feats" not in predictions + assert "spec_feat_names" not in predictions + assert "cnn_feats" not in predictions + assert "cnn_feat_names" not in predictions + assert "spec_slices" not in predictions # Check that predictions are returned assert isinstance(predictions["pred_dict"], dict) diff --git a/tests/test_cli.py b/tests/test_cli.py index e69de29..4570cf5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -0,0 +1,41 @@ +"""Test the command line interface.""" +from click.testing import CliRunner + +from bat_detect.cli import cli + + +def test_cli_base_command(): + runner = CliRunner() + result = runner.invoke(cli, ["--help"]) + assert result.exit_code == 0 + assert "BatDetect2 - Bat Call Detection and Classification" in result.output + + +def test_cli_detect_command_help(): + runner = CliRunner() + result = runner.invoke(cli, ["detect", "--help"]) + assert result.exit_code == 0 + assert "Detect bat calls in files in AUDIO_DIR" in result.output + + +def test_cli_detect_command_on_test_audio(tmp_path): + results_dir = tmp_path / "results" + + # Remove results dir if it exists + if results_dir.exists(): + results_dir.rmdir() + + runner = CliRunner() + result = runner.invoke( + cli, + [ + "detect", + "example_data/audio", + str(results_dir), + "0.3", + ], + ) + assert result.exit_code == 0 + assert results_dir.exists() + assert len(list(results_dir.glob("*.csv"))) == 3 + assert len(list(results_dir.glob("*.json"))) == 3