Added click to dependencies and added CLI tests

Santiago Martinez 2023-02-26 18:40:35 +00:00
parent 6f2bb605d3
commit a2deab9f3f
9 changed files with 312 additions and 233 deletions

View File

@@ -38,7 +38,7 @@ def get_config(**kwargs) -> ProcessingConfiguration:
Can be used to override default parameters by passing keyword arguments.
"""
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore
def load_audio(
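The merge-override idiom above relies on later entries winning in a dict literal, so keyword arguments take precedence over the defaults. A minimal sketch with placeholder defaults (not the real DEFAULT_PROCESSING_CONFIGURATIONS):

    DEFAULTS = {"detection_threshold": 0.1, "chunk_size": 2}

    def get_config(**kwargs):
        # Later entries win, so any kwarg overrides its default.
        return {**DEFAULTS, **kwargs}

    assert get_config(detection_threshold=0.3) == {
        "detection_threshold": 0.3,
        "chunk_size": 2,
    }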

View File

@@ -1,136 +1,141 @@
"""Main script for running BatDetect2 on audio files.
Example usage:
python command.py /path/to/audio/ /path/to/ann/ 0.1
"""
import argparse
"""BatDetect2 command line interface."""
import os
import warnings
from bat_detect import api
warnings.filterwarnings("ignore", category=UserWarning)
import click # noqa: E402
from bat_detect import api # noqa: E402
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH # noqa: E402
from bat_detect.utils.detector_utils import save_results_to_file # noqa: E402
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
def parse_args():
info_str = (
"\nBatDetect2 - Detection and Classification\n"
+ " Assumes audio files are mono, not stereo.\n"
+ ' Spaces in the input paths will throw an error. Wrap in quotes "".\n'
+ " Input files should be short in duration e.g. < 30 seconds.\n"
)
INFO_STR = """
BatDetect2 - Detection and Classification
Assumes audio files are mono, not stereo.
Spaces in the input paths will throw an error. Wrap in quotes.
Input files should be short in duration e.g. < 30 seconds.
"""
print(info_str)
parser = argparse.ArgumentParser()
parser.add_argument("audio_dir", type=str, help="Input directory for audio")
parser.add_argument(
"ann_dir",
type=str,
help="Output directory for where the predictions will be stored",
@click.group()
def cli():
"""BatDetect2 - Bat Call Detection and Classification."""
click.echo(INFO_STR)
@cli.command()
@click.argument(
"audio_dir",
type=click.Path(exists=True),
)
parser.add_argument(
@click.argument(
"ann_dir",
type=click.Path(exists=False),
)
@click.argument(
"detection_threshold",
type=float,
help="Cut-off probability for detector e.g. 0.1",
)
parser.add_argument(
@click.option(
"--cnn_features",
action="store_true",
is_flag=True,
default=False,
dest="cnn_features",
help="Extracts CNN call features",
)
parser.add_argument(
@click.option(
"--spec_features",
action="store_true",
is_flag=True,
default=False,
dest="spec_features",
help="Extracts low level call features",
)
parser.add_argument(
@click.option(
"--time_expansion_factor",
type=int,
default=1,
dest="time_expansion_factor",
help="The time expansion factor used for all files (default is 1)",
)
parser.add_argument(
@click.option(
"--quiet",
action="store_true",
is_flag=True,
default=False,
dest="quiet",
help="Minimize output printing",
)
parser.add_argument(
@click.option(
"--save_preds_if_empty",
action="store_true",
is_flag=True,
default=False,
dest="save_preds_if_empty",
help="Save empty annotation file if no detections made.",
)
parser.add_argument(
@click.option(
"--model_path",
type=str,
default=du.DEFAULT_MODEL_PATH,
default=DEFAULT_MODEL_PATH,
help="Path to trained BatDetect2 model",
)
args = vars(parser.parse_args())
args["spec_slices"] = False # used for visualization
# files longer than this (in seconds) will be broken down into smaller chunks
args["chunk_size"] = 2
args["ann_dir"] = os.path.join(args["ann_dir"], "")
args["quiet"] = True
return args
def main():
args = parse_args()
print("Loading model: " + args["model_path"])
model, params = du.load_model(args["model_path"])
print("\nInput directory: " + args["audio_dir"])
files = du.list_audio_files(args["audio_dir"])
print(f"Number of audio files: {len(files)}")
print("\nSaving results to: " + args["ann_dir"])
default_config = du.get_default_config()
# set up run config
run_config = {
**default_config,
def detect(
audio_dir: str,
ann_dir: str,
detection_threshold: float,
**args,
):
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
DETECTION_THRESHOLD is the cut-off probability for the detector; any
prediction scoring below it is discarded. Values range between 0 and 1.
Assumes audio files are mono, not stereo.
Spaces in the input paths will throw an error. Wrap in quotes.
Input files should be short in duration e.g. < 30 seconds.
"""
click.echo(f"Loading model: {args['model_path']}")
model, params = api.load_model(args["model_path"])
click.echo(f"\nInput directory: {audio_dir}")
files = api.list_audio_files(audio_dir)
click.echo(f"Number of audio files: {len(files)}")
click.echo(f"\nSaving results to: {ann_dir}")
config = api.get_config(
**{
**params,
**args,
"spec_slices": False,
"chunk_size": 2,
"detection_threshold": detection_threshold,
}
)
# process files
error_files = []
for audio_file in files:
try:
results = du.process_file(audio_file, model, run_config)
results = api.process_file(audio_file, model, config=config)
if args["save_preds_if_empty"] or (
len(results["pred_dict"]["annotation"]) > 0
):
results_path = audio_file.replace(
args["audio_dir"], args["ann_dir"]
)
du.save_results_to_file(results, results_path)
results_path = audio_file.replace(audio_dir, ann_dir)
save_results_to_file(results, results_path)
except (RuntimeError, ValueError, LookupError) as err:
# TODO: Check what other errors can be thrown
error_files.append(audio_file)
print(f"Error processing file!: {err}")
click.echo(f"Error processing file!: {err}")
raise err
print("\nResults saved to: " + args["ann_dir"])
click.echo(f"\nResults saved to: {ann_dir}")
if len(error_files) > 0:
print("\nUnable to process the follow files:")
click.echo("\nUnable to process the follow files:")
for err in error_files:
print(" " + err)
click.echo(f" {err}")
if __name__ == "__main__":
main()
cli()
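For reference, the click pattern adopted above boils down to a small self-contained skeleton (illustrative names, not the BatDetect2 code itself). Unlike argparse, click derives the usage text from decorators and docstrings, which is why the old parse_args function disappears entirely:

    import click

    @click.group()
    def cli():
        """Top-level group; its body runs before any subcommand."""

    @cli.command()
    @click.argument("audio_dir", type=click.Path(exists=True))
    @click.argument("detection_threshold", type=float)
    @click.option("--quiet", is_flag=True, default=False, help="Minimize output.")
    def detect(audio_dir: str, detection_threshold: float, quiet: bool):
        if not quiet:
            click.echo(f"Processing {audio_dir} at threshold {detection_threshold}")

    if __name__ == "__main__":
        cli()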

View File

@@ -1,16 +1,12 @@
"""Post-processing of the output of the model."""
from typing import List, Optional, Tuple
from typing import List, Tuple, Union
import numpy as np
import torch
from torch import nn
from bat_detect.detector.models import ModelOutput
try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict
from bat_detect.types import NonMaximumSuppressionConfig, PredictionResults
np.seterr(divide="ignore", invalid="ignore")
@@ -42,72 +38,6 @@ def overall_class_pred(det_prob, class_prob):
return weighted_pred / weighted_pred.sum()
class NonMaximumSuppressionConfig(TypedDict):
"""Configuration for non-maximum suppression."""
nms_kernel_size: int
"""Size of the kernel for non-maximum suppression."""
max_freq: int
"""Maximum frequency to consider in Hz."""
min_freq: int
"""Minimum frequency to consider in Hz."""
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Overlap of the FFT windows in seconds."""
resize_factor: float
"""Factor by which the input was resized."""
nms_top_k_per_sec: float
"""Number of top detections to keep per second."""
detection_threshold: float
"""Threshold for detection probability."""
class PredictionResults(TypedDict):
"""Results of the prediction.
Each key is a list of length `num_detections` containing the
corresponding values for each detection.
"""
det_probs: np.ndarray
"""Detection probabilities."""
x_pos: np.ndarray
"""X position of the detection in pixels."""
y_pos: np.ndarray
"""Y position of the detection in pixels."""
bb_width: np.ndarray
"""Width of the detection in pixels."""
bb_height: np.ndarray
"""Height of the detection in pixels."""
start_times: np.ndarray
"""Start times of the detections in seconds."""
end_times: np.ndarray
"""End times of the detections in seconds."""
low_freqs: np.ndarray
"""Low frequencies of the detections in Hz."""
high_freqs: np.ndarray
"""High frequencies of the detections in Hz."""
class_probs: Optional[np.ndarray]
"""Class probabilities."""
def run_nms(
outputs: ModelOutput,
params: NonMaximumSuppressionConfig,
@@ -128,8 +58,8 @@ def run_nms(
-2
]
# NOTE: there will be small differences depending on which sampling rate is chosen
# as we are choosing the same sampling rate for the entire batch
# NOTE: there will be small differences depending on which sampling rate
# is chosen as we are choosing the same sampling rate for the entire batch
duration = x_coords_to_time(
pred_det.shape[-1],
int(sampling_rate[0].item()),
@@ -211,16 +141,21 @@ def run_nms(
for key, value in pred.items():
pred[key] = value.cpu().numpy().astype(np.float32)
preds.append(pred)
preds.append(pred) # type: ignore
return preds, feats
def non_max_suppression(heat, kernel_size):
def non_max_suppression(
heat: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]],
):
# kernel size can be a single int or an (h, w) tuple
if isinstance(kernel_size, int):
kernel_size_h = kernel_size
kernel_size_w = kernel_size
else:
kernel_size_h, kernel_size_w = kernel_size
pad_h = (kernel_size_h - 1) // 2
pad_w = (kernel_size_w - 1) // 2
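The body shown above only computes the padding; the usual way this kind of heatmap NMS is completed (a CenterNet-style sketch, not necessarily the project's exact code) is to max-pool the heatmap and keep only positions that equal their local maximum:

    import torch
    import torch.nn.functional as F

    def nms_sketch(heat: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
        # heat: (batch, channels, height, width) detection heatmap
        pad = (kernel_size - 1) // 2
        hmax = F.max_pool2d(heat, kernel_size, stride=1, padding=pad)
        keep = (hmax == heat).float()  # 1.0 where a pixel is a local maximum
        return heat * keep             # suppress everything else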

View File

@@ -1,5 +1,5 @@
"""Types used in the code base."""
from typing import List, Optional, NamedTuple
from typing import List, NamedTuple, Optional
import numpy as np
import torch
@@ -16,6 +16,27 @@ except ImportError:
from typing_extensions import Protocol
try:
from typing import NotRequired
except ImportError:
from typing_extensions import NotRequired
__all__ = [
"Annotation",
"DetectionModel",
"FileAnnotations",
"ModelOutput",
"ModelParameters",
"NonMaximumSuppressionConfig",
"PredictionResults",
"ProcessingConfiguration",
"ResultParams",
"RunResults",
"SpectrogramParameters",
]
class SpectrogramParameters(TypedDict):
"""Parameters for generating spectrograms."""
@@ -144,19 +165,19 @@ class RunResults(TypedDict):
pred_dict: FileAnnotations
"""Predictions in the format expected by the annotation tool."""
spec_feats: Optional[List[np.ndarray]]
spec_feats: NotRequired[List[np.ndarray]]
"""Spectrogram features."""
spec_feat_names: Optional[List[str]]
spec_feat_names: NotRequired[List[str]]
"""Spectrogram feature names."""
cnn_feats: Optional[List[np.ndarray]]
cnn_feats: NotRequired[List[np.ndarray]]
"""CNN features."""
cnn_feat_names: Optional[List[str]]
cnn_feat_names: NotRequired[List[str]]
"""CNN feature names."""
spec_slices: Optional[List[np.ndarray]]
spec_slices: NotRequired[List[np.ndarray]]
"""Spectrogram slices."""
@@ -166,6 +187,15 @@ class ResultParams(TypedDict):
class_names: List[str]
"""Class names."""
spec_features: bool
"""Whether to return spectrogram features."""
cnn_features: bool
"""Whether to return CNN features."""
spec_slices: bool
"""Whether to return spectrogram slices."""
class ProcessingConfiguration(TypedDict):
"""Parameters for processing audio files."""
@@ -266,6 +296,44 @@ class ModelOutput(NamedTuple):
"""Tensor with intermediate features."""
class PredictionResults(TypedDict):
"""Results of the prediction.
Each key maps to an array of length `num_detections` holding the
corresponding value for each detection.
"""
det_probs: np.ndarray
"""Detection probabilities."""
x_pos: np.ndarray
"""X position of the detection in pixels."""
y_pos: np.ndarray
"""Y position of the detection in pixels."""
bb_width: np.ndarray
"""Width of the detection in pixels."""
bb_height: np.ndarray
"""Height of the detection in pixels."""
start_times: np.ndarray
"""Start times of the detections in seconds."""
end_times: np.ndarray
"""End times of the detections in seconds."""
low_freqs: np.ndarray
"""Low frequencies of the detections in Hz."""
high_freqs: np.ndarray
"""High frequencies of the detections in Hz."""
class_probs: Optional[np.ndarray]
"""Class probabilities."""
class DetectionModel(Protocol):
"""Protocol for detection models.
@@ -286,19 +354,49 @@ class DetectionModel(Protocol):
resize_factor: float
"""Factor by which the input is resized."""
ip_height: int
ip_height_rs: int
"""Height of the input image."""
def forward(
self,
spec: torch.Tensor,
ip: torch.Tensor,
return_feats: bool = False,
) -> ModelOutput:
"""Forward pass of the model."""
...
def __call__(
self,
spec: torch.Tensor,
ip: torch.Tensor,
return_feats: bool = False,
) -> ModelOutput:
"""Forward pass of the model."""
...
class NonMaximumSuppressionConfig(TypedDict):
"""Configuration for non-maximum suppression."""
nms_kernel_size: int
"""Size of the kernel for non-maximum suppression."""
max_freq: int
"""Maximum frequency to consider in Hz."""
min_freq: int
"""Minimum frequency to consider in Hz."""
fft_win_length: float
"""Length of the FFT window in seconds."""
fft_overlap: float
"""Overlap of the FFT windows in seconds."""
resize_factor: float
"""Factor by which the input was resized."""
nms_top_k_per_sec: float
"""Number of top detections to keep per second."""
detection_threshold: float
"""Threshold for detection probability."""

View File

@@ -14,13 +14,14 @@ from bat_detect.detector import models
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
from bat_detect.types import (
Annotation,
DetectionModel,
FileAnnotations,
ModelParameters,
PredictionResults,
ProcessingConfiguration,
SpectrogramParameters,
ResultParams,
RunResults,
DetectionModel
SpectrogramParameters,
)
__all__ = [
@@ -73,10 +74,8 @@ def load_model(
Raises:
FileNotFoundError: Model file not found.
ValueError: Unknown model.
ValueError: Unknown model name.
"""
# load model
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -135,7 +134,8 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
)
else:
# hack in case where no detected calls as we need some of the key names in dict
# hack for when no calls are detected, as we still need some of
# the key names in the dict
predictions_m = predictions[0]
if len(spec_feats) > 0:
@@ -263,27 +263,22 @@ def convert_results(
# combine into final results dictionary
results: RunResults = {
"pred_dict": pred_dict,
"spec_feats": None,
"spec_feat_names": None,
"cnn_feats": None,
"cnn_feat_names": None,
"spec_slices": None,
}
# add spectrogram features if they exist
if len(spec_feats) > 0:
if len(spec_feats) > 0 and params["spec_features"]:
results["spec_feats"] = spec_feats
results["spec_feat_names"] = feats.get_feature_names()
# add CNN features if they exist
if len(cnn_feats) > 0:
if len(cnn_feats) > 0 and params["cnn_features"]:
results["cnn_feats"] = cnn_feats
results["cnn_feat_names"] = [
str(ii) for ii in range(cnn_feats.shape[1])
]
# add spectrogram slices if they exist
if len(spec_slices) > 0:
if len(spec_slices) > 0 and params["spec_slices"]:
results["spec_slices"] = spec_slices
return results
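Together with the NotRequired change in the types module, the construction pattern above is: start from the required key only, then attach each optional output when it was both computed and requested. A reduced sketch:

    def build_results(pred_dict, spec_feats, want_spec_features):
        results = {"pred_dict": pred_dict}
        if len(spec_feats) > 0 and want_spec_features:
            results["spec_feats"] = spec_feats  # absent otherwise, never None
        return results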
@@ -292,6 +287,8 @@ def convert_results(
def save_results_to_file(results, op_path: str) -> None:
"""Save results to file.
Will create the output directory if it does not exist.
Args:
results (dict): Results.
op_path (str): Output path.
@@ -476,7 +473,7 @@ def _process_spectrogram(
samplerate: int,
model: DetectionModel,
config: ProcessingConfiguration,
) -> Tuple[List[Annotation], List[np.ndarray]]:
) -> Tuple[PredictionResults, List[np.ndarray]]:
# evaluate model
with torch.no_grad():
outputs = model(spec)
@@ -493,7 +490,6 @@ def _process_spectrogram(
"resize_factor": config["resize_factor"],
"nms_top_k_per_sec": config["nms_top_k_per_sec"],
"detection_threshold": config["detection_threshold"],
"max_scale_spec": config["max_scale_spec"],
},
np.array([float(samplerate)]),
)
@@ -561,7 +557,7 @@ def _process_audio_array(
model: DetectionModel,
config: ProcessingConfiguration,
device: torch.device,
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
) -> Tuple[PredictionResults, List[np.ndarray], torch.Tensor]:
# load audio file and compute spectrogram
_, spec, _ = compute_spectrogram(
audio,
@@ -712,7 +708,9 @@ def process_file(
predictions.append(pred_nms)
# extract features - if there are any calls detected
if pred_nms["det_probs"].shape[0] > 0:
if pred_nms["det_probs"].shape[0] == 0:
continue
if config["spec_features"]:
spec_feats.append(feats.get_feats(spec_np, pred_nms, config))

View File

@@ -15,6 +15,7 @@ dependencies = [
"torch",
"torchaudio",
"torchvision",
"click",
]
requires-python = ">=3.8"
readme = "README.md"
@@ -56,3 +57,8 @@ module = [
"pandas",
]
ignore_missing_imports = true
[tool.pylsp-mypy]
enabled = false
live_mode = true
strict = true

View File

@@ -7,3 +7,4 @@ scipy==1.9.3
torch==1.13.0
torchaudio==0.13.0
torchvision==0.14.0
click

View File

@@ -120,18 +120,13 @@ def test_process_file_with_model():
assert isinstance(predictions, dict)
assert "pred_dict" in predictions
assert "spec_feats" in predictions
assert "spec_feat_names" in predictions
assert "cnn_feats" in predictions
assert "cnn_feat_names" in predictions
assert "spec_slices" in predictions
# By default will not return spectrogram features
assert predictions["spec_feats"] is None
assert predictions["spec_feat_names"] is None
assert predictions["cnn_feats"] is None
assert predictions["cnn_feat_names"] is None
assert predictions["spec_slices"] is None
# By default the optional feature outputs are not returned
assert "spec_feats" not in predictions
assert "spec_feat_names" not in predictions
assert "cnn_feats" not in predictions
assert "cnn_feat_names" not in predictions
assert "spec_slices" not in predictions
# Check that predictions are returned
assert isinstance(predictions["pred_dict"], dict)

View File

@@ -0,0 +1,41 @@
"""Test the command line interface."""
from click.testing import CliRunner
from bat_detect.cli import cli
def test_cli_base_command():
runner = CliRunner()
result = runner.invoke(cli, ["--help"])
assert result.exit_code == 0
assert "BatDetect2 - Bat Call Detection and Classification" in result.output
def test_cli_detect_command_help():
runner = CliRunner()
result = runner.invoke(cli, ["detect", "--help"])
assert result.exit_code == 0
assert "Detect bat calls in files in AUDIO_DIR" in result.output
def test_cli_detect_command_on_test_audio(tmp_path):
results_dir = tmp_path / "results"
# Remove results dir if it exists
if results_dir.exists():
results_dir.rmdir()
runner = CliRunner()
result = runner.invoke(
cli,
[
"detect",
"example_data/audio",
str(results_dir),
"0.3",
],
)
assert result.exit_code == 0
assert results_dir.exists()
assert len(list(results_dir.glob("*.csv"))) == 3
assert len(list(results_dir.glob("*.json"))) == 3