mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Added click to dependencies and made cli tests
This commit is contained in:
parent
6f2bb605d3
commit
a2deab9f3f
@ -38,7 +38,7 @@ def get_config(**kwargs) -> ProcessingConfiguration:
|
||||
|
||||
Can be used to override default parameters by passing keyword arguments.
|
||||
"""
|
||||
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}
|
||||
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore
|
||||
|
||||
|
||||
def load_audio(
|
||||
|
@ -1,136 +1,141 @@
|
||||
"""Main script for running BatDetect2 on audio files.
|
||||
|
||||
Example usage:
|
||||
python command.py /path/to/audio/ /path/to/ann/ 0.1
|
||||
|
||||
"""
|
||||
import argparse
|
||||
"""BatDetect2 command line interface."""
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from bat_detect import api
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
import click # noqa: E402
|
||||
|
||||
from bat_detect import api # noqa: E402
|
||||
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH # noqa: E402
|
||||
from bat_detect.utils.detector_utils import save_results_to_file # noqa: E402
|
||||
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
def parse_args():
|
||||
info_str = (
|
||||
"\nBatDetect2 - Detection and Classification\n"
|
||||
+ " Assumes audio files are mono, not stereo.\n"
|
||||
+ ' Spaces in the input paths will throw an error. Wrap in quotes "".\n'
|
||||
+ " Input files should be short in duration e.g. < 30 seconds.\n"
|
||||
INFO_STR = """
|
||||
BatDetect2 - Detection and Classification
|
||||
Assumes audio files are mono, not stereo.
|
||||
Spaces in the input paths will throw an error. Wrap in quotes.
|
||||
Input files should be short in duration e.g. < 30 seconds.
|
||||
"""
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
"""BatDetect2 - Bat Call Detection and Classification."""
|
||||
click.echo(INFO_STR)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument(
|
||||
"audio_dir",
|
||||
type=click.Path(exists=True),
|
||||
)
|
||||
@click.argument(
|
||||
"ann_dir",
|
||||
type=click.Path(exists=False),
|
||||
)
|
||||
@click.argument(
|
||||
"detection_threshold",
|
||||
type=float,
|
||||
)
|
||||
@click.option(
|
||||
"--cnn_features",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Extracts CNN call features",
|
||||
)
|
||||
@click.option(
|
||||
"--spec_features",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Extracts low level call features",
|
||||
)
|
||||
@click.option(
|
||||
"--time_expansion_factor",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The time expansion factor used for all files (default is 1)",
|
||||
)
|
||||
@click.option(
|
||||
"--quiet",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Minimize output printing",
|
||||
)
|
||||
@click.option(
|
||||
"--save_preds_if_empty",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Save empty annotation file if no detections made.",
|
||||
)
|
||||
@click.option(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default=DEFAULT_MODEL_PATH,
|
||||
help="Path to trained BatDetect2 model",
|
||||
)
|
||||
def detect(
|
||||
audio_dir: str,
|
||||
ann_dir: str,
|
||||
detection_threshold: float,
|
||||
**args,
|
||||
):
|
||||
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
|
||||
|
||||
DETECTION_THRESHOLD is the detection threshold. All predictions with a
|
||||
score below this threshold will be discarded. Values between 0 and 1.
|
||||
|
||||
Assumes audio files are mono, not stereo.
|
||||
|
||||
Spaces in the input paths will throw an error. Wrap in quotes.
|
||||
|
||||
Input files should be short in duration e.g. < 30 seconds.
|
||||
"""
|
||||
click.echo(f"Loading model: {args['model_path']}")
|
||||
model, params = api.load_model(args["model_path"])
|
||||
|
||||
click.echo(f"\nInput directory: {audio_dir}")
|
||||
files = api.list_audio_files(audio_dir)
|
||||
|
||||
click.echo(f"Number of audio files: {len(files)}")
|
||||
click.echo(f"\nSaving results to: {ann_dir}")
|
||||
|
||||
config = api.get_config(
|
||||
**{
|
||||
**params,
|
||||
**args,
|
||||
"spec_slices": False,
|
||||
"chunk_size": 2,
|
||||
"detection_threshold": detection_threshold,
|
||||
}
|
||||
)
|
||||
|
||||
print(info_str)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("audio_dir", type=str, help="Input directory for audio")
|
||||
parser.add_argument(
|
||||
"ann_dir",
|
||||
type=str,
|
||||
help="Output directory for where the predictions will be stored",
|
||||
)
|
||||
parser.add_argument(
|
||||
"detection_threshold",
|
||||
type=float,
|
||||
help="Cut-off probability for detector e.g. 0.1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cnn_features",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="cnn_features",
|
||||
help="Extracts CNN call features",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--spec_features",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="spec_features",
|
||||
help="Extracts low level call features",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--time_expansion_factor",
|
||||
type=int,
|
||||
default=1,
|
||||
dest="time_expansion_factor",
|
||||
help="The time expansion factor used for all files (default is 1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quiet",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="quiet",
|
||||
help="Minimize output printing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_preds_if_empty",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="save_preds_if_empty",
|
||||
help="Save empty annotation file if no detections made.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default=du.DEFAULT_MODEL_PATH,
|
||||
help="Path to trained BatDetect2 model",
|
||||
)
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
args["spec_slices"] = False # used for visualization
|
||||
# if files greater than this amount (seconds) they will be broken down into small chunks
|
||||
args["chunk_size"] = 2
|
||||
args["ann_dir"] = os.path.join(args["ann_dir"], "")
|
||||
args["quiet"] = True
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
print("Loading model: " + args["model_path"])
|
||||
model, params = du.load_model(args["model_path"])
|
||||
|
||||
print("\nInput directory: " + args["audio_dir"])
|
||||
files = du.list_audio_files(args["audio_dir"])
|
||||
|
||||
print(f"Number of audio files: {len(files)}")
|
||||
print("\nSaving results to: " + args["ann_dir"])
|
||||
|
||||
default_config = du.get_default_config()
|
||||
|
||||
# set up run config
|
||||
run_config = {
|
||||
**default_config,
|
||||
**args,
|
||||
**params,
|
||||
}
|
||||
|
||||
# process files
|
||||
error_files = []
|
||||
for audio_file in files:
|
||||
try:
|
||||
results = du.process_file(audio_file, model, run_config)
|
||||
results = api.process_file(audio_file, model, config=config)
|
||||
|
||||
if args["save_preds_if_empty"] or (
|
||||
len(results["pred_dict"]["annotation"]) > 0
|
||||
):
|
||||
results_path = audio_file.replace(
|
||||
args["audio_dir"], args["ann_dir"]
|
||||
)
|
||||
du.save_results_to_file(results, results_path)
|
||||
results_path = audio_file.replace(audio_dir, ann_dir)
|
||||
save_results_to_file(results, results_path)
|
||||
except (RuntimeError, ValueError, LookupError) as err:
|
||||
# TODO: Check what other errors can be thrown
|
||||
error_files.append(audio_file)
|
||||
print(f"Error processing file!: {err}")
|
||||
click.echo(f"Error processing file!: {err}")
|
||||
raise err
|
||||
|
||||
print("\nResults saved to: " + args["ann_dir"])
|
||||
click.echo(f"\nResults saved to: {ann_dir}")
|
||||
|
||||
if len(error_files) > 0:
|
||||
print("\nUnable to process the follow files:")
|
||||
click.echo("\nUnable to process the follow files:")
|
||||
for err in error_files:
|
||||
print(" " + err)
|
||||
click.echo(f" {err}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
cli()
|
||||
|
@ -1,16 +1,12 @@
|
||||
"""Post-processing of the output of the model."""
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from bat_detect.detector.models import ModelOutput
|
||||
|
||||
try:
|
||||
from typing import TypedDict
|
||||
except ImportError:
|
||||
from typing_extensions import TypedDict
|
||||
from bat_detect.types import NonMaximumSuppressionConfig, PredictionResults
|
||||
|
||||
np.seterr(divide="ignore", invalid="ignore")
|
||||
|
||||
@ -42,72 +38,6 @@ def overall_class_pred(det_prob, class_prob):
|
||||
return weighted_pred / weighted_pred.sum()
|
||||
|
||||
|
||||
class NonMaximumSuppressionConfig(TypedDict):
|
||||
"""Configuration for non-maximum suppression."""
|
||||
|
||||
nms_kernel_size: int
|
||||
"""Size of the kernel for non-maximum suppression."""
|
||||
|
||||
max_freq: int
|
||||
"""Maximum frequency to consider in Hz."""
|
||||
|
||||
min_freq: int
|
||||
"""Minimum frequency to consider in Hz."""
|
||||
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
|
||||
fft_overlap: float
|
||||
"""Overlap of the FFT windows in seconds."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor by which the input was resized."""
|
||||
|
||||
nms_top_k_per_sec: float
|
||||
"""Number of top detections to keep per second."""
|
||||
|
||||
detection_threshold: float
|
||||
"""Threshold for detection probability."""
|
||||
|
||||
|
||||
class PredictionResults(TypedDict):
|
||||
"""Results of the prediction.
|
||||
|
||||
Each key is a list of length `num_detections` containing the
|
||||
corresponding values for each detection.
|
||||
"""
|
||||
|
||||
det_probs: np.ndarray
|
||||
"""Detection probabilities."""
|
||||
|
||||
x_pos: np.ndarray
|
||||
"""X position of the detection in pixels."""
|
||||
|
||||
y_pos: np.ndarray
|
||||
"""Y position of the detection in pixels."""
|
||||
|
||||
bb_width: np.ndarray
|
||||
"""Width of the detection in pixels."""
|
||||
|
||||
bb_height: np.ndarray
|
||||
"""Height of the detection in pixels."""
|
||||
|
||||
start_times: np.ndarray
|
||||
"""Start times of the detections in seconds."""
|
||||
|
||||
end_times: np.ndarray
|
||||
"""End times of the detections in seconds."""
|
||||
|
||||
low_freqs: np.ndarray
|
||||
"""Low frequencies of the detections in Hz."""
|
||||
|
||||
high_freqs: np.ndarray
|
||||
"""High frequencies of the detections in Hz."""
|
||||
|
||||
class_probs: Optional[np.ndarray]
|
||||
"""Class probabilities."""
|
||||
|
||||
|
||||
def run_nms(
|
||||
outputs: ModelOutput,
|
||||
params: NonMaximumSuppressionConfig,
|
||||
@ -128,8 +58,8 @@ def run_nms(
|
||||
-2
|
||||
]
|
||||
|
||||
# NOTE: there will be small differences depending on which sampling rate is chosen
|
||||
# as we are choosing the same sampling rate for the entire batch
|
||||
# NOTE: there will be small differences depending on which sampling rate
|
||||
# is chosen as we are choosing the same sampling rate for the entire batch
|
||||
duration = x_coords_to_time(
|
||||
pred_det.shape[-1],
|
||||
int(sampling_rate[0].item()),
|
||||
@ -211,16 +141,21 @@ def run_nms(
|
||||
for key, value in pred.items():
|
||||
pred[key] = value.cpu().numpy().astype(np.float32)
|
||||
|
||||
preds.append(pred)
|
||||
preds.append(pred) # type: ignore
|
||||
|
||||
return preds, feats
|
||||
|
||||
|
||||
def non_max_suppression(heat, kernel_size):
|
||||
def non_max_suppression(
|
||||
heat: torch.Tensor,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
):
|
||||
# kernel can be an int or list/tuple
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size_h = kernel_size
|
||||
kernel_size_w = kernel_size
|
||||
else:
|
||||
kernel_size_h, kernel_size_w = kernel_size
|
||||
|
||||
pad_h = (kernel_size_h - 1) // 2
|
||||
pad_w = (kernel_size_w - 1) // 2
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Types used in the code base."""
|
||||
from typing import List, Optional, NamedTuple
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -16,6 +16,27 @@ except ImportError:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
|
||||
try:
|
||||
from typing import NotRequired
|
||||
except ImportError:
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Annotation",
|
||||
"DetectionModel",
|
||||
"FileAnnotations",
|
||||
"ModelOutput",
|
||||
"ModelParameters",
|
||||
"NonMaximumSuppressionConfig",
|
||||
"PredictionResults",
|
||||
"ProcessingConfiguration",
|
||||
"ResultParams",
|
||||
"RunResults",
|
||||
"SpectrogramParameters",
|
||||
]
|
||||
|
||||
|
||||
class SpectrogramParameters(TypedDict):
|
||||
"""Parameters for generating spectrograms."""
|
||||
|
||||
@ -144,19 +165,19 @@ class RunResults(TypedDict):
|
||||
pred_dict: FileAnnotations
|
||||
"""Predictions in the format expected by the annotation tool."""
|
||||
|
||||
spec_feats: Optional[List[np.ndarray]]
|
||||
spec_feats: NotRequired[List[np.ndarray]]
|
||||
"""Spectrogram features."""
|
||||
|
||||
spec_feat_names: Optional[List[str]]
|
||||
spec_feat_names: NotRequired[List[str]]
|
||||
"""Spectrogram feature names."""
|
||||
|
||||
cnn_feats: Optional[List[np.ndarray]]
|
||||
cnn_feats: NotRequired[List[np.ndarray]]
|
||||
"""CNN features."""
|
||||
|
||||
cnn_feat_names: Optional[List[str]]
|
||||
cnn_feat_names: NotRequired[List[str]]
|
||||
"""CNN feature names."""
|
||||
|
||||
spec_slices: Optional[List[np.ndarray]]
|
||||
spec_slices: NotRequired[List[np.ndarray]]
|
||||
"""Spectrogram slices."""
|
||||
|
||||
|
||||
@ -166,6 +187,15 @@ class ResultParams(TypedDict):
|
||||
class_names: List[str]
|
||||
"""Class names."""
|
||||
|
||||
spec_features: bool
|
||||
"""Whether to return spectrogram features."""
|
||||
|
||||
cnn_features: bool
|
||||
"""Whether to return CNN features."""
|
||||
|
||||
spec_slices: bool
|
||||
"""Whether to return spectrogram slices."""
|
||||
|
||||
|
||||
class ProcessingConfiguration(TypedDict):
|
||||
"""Parameters for processing audio files."""
|
||||
@ -266,6 +296,44 @@ class ModelOutput(NamedTuple):
|
||||
"""Tensor with intermediate features."""
|
||||
|
||||
|
||||
class PredictionResults(TypedDict):
|
||||
"""Results of the prediction.
|
||||
|
||||
Each key is a list of length `num_detections` containing the
|
||||
corresponding values for each detection.
|
||||
"""
|
||||
|
||||
det_probs: np.ndarray
|
||||
"""Detection probabilities."""
|
||||
|
||||
x_pos: np.ndarray
|
||||
"""X position of the detection in pixels."""
|
||||
|
||||
y_pos: np.ndarray
|
||||
"""Y position of the detection in pixels."""
|
||||
|
||||
bb_width: np.ndarray
|
||||
"""Width of the detection in pixels."""
|
||||
|
||||
bb_height: np.ndarray
|
||||
"""Height of the detection in pixels."""
|
||||
|
||||
start_times: np.ndarray
|
||||
"""Start times of the detections in seconds."""
|
||||
|
||||
end_times: np.ndarray
|
||||
"""End times of the detections in seconds."""
|
||||
|
||||
low_freqs: np.ndarray
|
||||
"""Low frequencies of the detections in Hz."""
|
||||
|
||||
high_freqs: np.ndarray
|
||||
"""High frequencies of the detections in Hz."""
|
||||
|
||||
class_probs: Optional[np.ndarray]
|
||||
"""Class probabilities."""
|
||||
|
||||
|
||||
class DetectionModel(Protocol):
|
||||
"""Protocol for detection models.
|
||||
|
||||
@ -286,19 +354,49 @@ class DetectionModel(Protocol):
|
||||
resize_factor: float
|
||||
"""Factor by which the input is resized."""
|
||||
|
||||
ip_height: int
|
||||
ip_height_rs: int
|
||||
"""Height of the input image."""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
spec: torch.Tensor,
|
||||
ip: torch.Tensor,
|
||||
return_feats: bool = False,
|
||||
) -> ModelOutput:
|
||||
"""Forward pass of the model."""
|
||||
...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
spec: torch.Tensor,
|
||||
ip: torch.Tensor,
|
||||
return_feats: bool = False,
|
||||
) -> ModelOutput:
|
||||
"""Forward pass of the model."""
|
||||
...
|
||||
|
||||
|
||||
class NonMaximumSuppressionConfig(TypedDict):
|
||||
"""Configuration for non-maximum suppression."""
|
||||
|
||||
nms_kernel_size: int
|
||||
"""Size of the kernel for non-maximum suppression."""
|
||||
|
||||
max_freq: int
|
||||
"""Maximum frequency to consider in Hz."""
|
||||
|
||||
min_freq: int
|
||||
"""Minimum frequency to consider in Hz."""
|
||||
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
|
||||
fft_overlap: float
|
||||
"""Overlap of the FFT windows in seconds."""
|
||||
|
||||
resize_factor: float
|
||||
"""Factor by which the input was resized."""
|
||||
|
||||
nms_top_k_per_sec: float
|
||||
"""Number of top detections to keep per second."""
|
||||
|
||||
detection_threshold: float
|
||||
"""Threshold for detection probability."""
|
||||
|
@ -14,13 +14,14 @@ from bat_detect.detector import models
|
||||
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
|
||||
from bat_detect.types import (
|
||||
Annotation,
|
||||
DetectionModel,
|
||||
FileAnnotations,
|
||||
ModelParameters,
|
||||
PredictionResults,
|
||||
ProcessingConfiguration,
|
||||
SpectrogramParameters,
|
||||
ResultParams,
|
||||
RunResults,
|
||||
DetectionModel
|
||||
SpectrogramParameters,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -73,10 +74,8 @@ def load_model(
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: Model file not found.
|
||||
ValueError: Unknown model.
|
||||
ValueError: Unknown model name.
|
||||
"""
|
||||
|
||||
# load model
|
||||
if device is None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@ -135,7 +134,8 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
||||
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
|
||||
)
|
||||
else:
|
||||
# hack in case where no detected calls as we need some of the key names in dict
|
||||
# hack in case where no detected calls as we need some of the key
|
||||
# names in dict
|
||||
predictions_m = predictions[0]
|
||||
|
||||
if len(spec_feats) > 0:
|
||||
@ -263,27 +263,22 @@ def convert_results(
|
||||
# combine into final results dictionary
|
||||
results: RunResults = {
|
||||
"pred_dict": pred_dict,
|
||||
"spec_feats": None,
|
||||
"spec_feat_names": None,
|
||||
"cnn_feats": None,
|
||||
"cnn_feat_names": None,
|
||||
"spec_slices": None,
|
||||
}
|
||||
|
||||
# add spectrogram features if they exist
|
||||
if len(spec_feats) > 0:
|
||||
if len(spec_feats) > 0 and params["spec_features"]:
|
||||
results["spec_feats"] = spec_feats
|
||||
results["spec_feat_names"] = feats.get_feature_names()
|
||||
|
||||
# add CNN features if they exist
|
||||
if len(cnn_feats) > 0:
|
||||
if len(cnn_feats) > 0 and params["cnn_features"]:
|
||||
results["cnn_feats"] = cnn_feats
|
||||
results["cnn_feat_names"] = [
|
||||
str(ii) for ii in range(cnn_feats.shape[1])
|
||||
]
|
||||
|
||||
# add spectrogram slices if they exist
|
||||
if len(spec_slices) > 0:
|
||||
if len(spec_slices) > 0 and params["spec_slices"]:
|
||||
results["spec_slices"] = spec_slices
|
||||
|
||||
return results
|
||||
@ -292,6 +287,8 @@ def convert_results(
|
||||
def save_results_to_file(results, op_path: str) -> None:
|
||||
"""Save results to file.
|
||||
|
||||
Will create the output directory if it does not exist.
|
||||
|
||||
Args:
|
||||
results (dict): Results.
|
||||
op_path (str): Output path.
|
||||
@ -476,7 +473,7 @@ def _process_spectrogram(
|
||||
samplerate: int,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
||||
) -> Tuple[PredictionResults, List[np.ndarray]]:
|
||||
# evaluate model
|
||||
with torch.no_grad():
|
||||
outputs = model(spec)
|
||||
@ -493,7 +490,6 @@ def _process_spectrogram(
|
||||
"resize_factor": config["resize_factor"],
|
||||
"nms_top_k_per_sec": config["nms_top_k_per_sec"],
|
||||
"detection_threshold": config["detection_threshold"],
|
||||
"max_scale_spec": config["max_scale_spec"],
|
||||
},
|
||||
np.array([float(samplerate)]),
|
||||
)
|
||||
@ -561,7 +557,7 @@ def _process_audio_array(
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
device: torch.device,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
|
||||
) -> Tuple[PredictionResults, List[np.ndarray], torch.Tensor]:
|
||||
# load audio file and compute spectrogram
|
||||
_, spec, _ = compute_spectrogram(
|
||||
audio,
|
||||
@ -712,17 +708,19 @@ def process_file(
|
||||
predictions.append(pred_nms)
|
||||
|
||||
# extract features - if there are any calls detected
|
||||
if pred_nms["det_probs"].shape[0] > 0:
|
||||
if config["spec_features"]:
|
||||
spec_feats.append(feats.get_feats(spec_np, pred_nms, config))
|
||||
if pred_nms["det_probs"].shape[0] == 0:
|
||||
continue
|
||||
|
||||
if config["cnn_features"]:
|
||||
cnn_feats.append(features[0])
|
||||
if config["spec_features"]:
|
||||
spec_feats.append(feats.get_feats(spec_np, pred_nms, config))
|
||||
|
||||
if config["spec_slices"]:
|
||||
spec_slices.extend(
|
||||
feats.extract_spec_slices(spec_np, pred_nms, config)
|
||||
)
|
||||
if config["cnn_features"]:
|
||||
cnn_feats.append(features[0])
|
||||
|
||||
if config["spec_slices"]:
|
||||
spec_slices.extend(
|
||||
feats.extract_spec_slices(spec_np, pred_nms, config)
|
||||
)
|
||||
|
||||
# Merge results from chunks
|
||||
predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
|
||||
|
@ -15,6 +15,7 @@ dependencies = [
|
||||
"torch",
|
||||
"torchaudio",
|
||||
"torchvision",
|
||||
"click",
|
||||
]
|
||||
requires-python = ">=3.8"
|
||||
readme = "README.md"
|
||||
@ -56,3 +57,8 @@ module = [
|
||||
"pandas",
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.pylsp-mypy]
|
||||
enabled = false
|
||||
live_mode = true
|
||||
strict = true
|
||||
|
@ -7,3 +7,4 @@ scipy==1.9.3
|
||||
torch==1.13.0
|
||||
torchaudio==0.13.0
|
||||
torchvision==0.14.0
|
||||
click
|
||||
|
@ -120,18 +120,13 @@ def test_process_file_with_model():
|
||||
assert isinstance(predictions, dict)
|
||||
|
||||
assert "pred_dict" in predictions
|
||||
assert "spec_feats" in predictions
|
||||
assert "spec_feat_names" in predictions
|
||||
assert "cnn_feats" in predictions
|
||||
assert "cnn_feat_names" in predictions
|
||||
assert "spec_slices" in predictions
|
||||
|
||||
# By default will not return spectrogram features
|
||||
assert predictions["spec_feats"] is None
|
||||
assert predictions["spec_feat_names"] is None
|
||||
assert predictions["cnn_feats"] is None
|
||||
assert predictions["cnn_feat_names"] is None
|
||||
assert predictions["spec_slices"] is None
|
||||
# By default will not return other features
|
||||
assert "spec_feats" not in predictions
|
||||
assert "spec_feat_names" not in predictions
|
||||
assert "cnn_feats" not in predictions
|
||||
assert "cnn_feat_names" not in predictions
|
||||
assert "spec_slices" not in predictions
|
||||
|
||||
# Check that predictions are returned
|
||||
assert isinstance(predictions["pred_dict"], dict)
|
||||
|
@ -0,0 +1,41 @@
|
||||
"""Test the command line interface."""
|
||||
from click.testing import CliRunner
|
||||
|
||||
from bat_detect.cli import cli
|
||||
|
||||
|
||||
def test_cli_base_command():
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "BatDetect2 - Bat Call Detection and Classification" in result.output
|
||||
|
||||
|
||||
def test_cli_detect_command_help():
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["detect", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "Detect bat calls in files in AUDIO_DIR" in result.output
|
||||
|
||||
|
||||
def test_cli_detect_command_on_test_audio(tmp_path):
|
||||
results_dir = tmp_path / "results"
|
||||
|
||||
# Remove results dir if it exists
|
||||
if results_dir.exists():
|
||||
results_dir.rmdir()
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"detect",
|
||||
"example_data/audio",
|
||||
str(results_dir),
|
||||
"0.3",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert results_dir.exists()
|
||||
assert len(list(results_dir.glob("*.csv"))) == 3
|
||||
assert len(list(results_dir.glob("*.json"))) == 3
|
Loading…
Reference in New Issue
Block a user