mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Added click to dependencies and made cli tests
This commit is contained in:
parent
6f2bb605d3
commit
a2deab9f3f
@ -38,7 +38,7 @@ def get_config(**kwargs) -> ProcessingConfiguration:
|
|||||||
|
|
||||||
Can be used to override default parameters by passing keyword arguments.
|
Can be used to override default parameters by passing keyword arguments.
|
||||||
"""
|
"""
|
||||||
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs}
|
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def load_audio(
|
def load_audio(
|
||||||
|
@ -1,136 +1,141 @@
|
|||||||
"""Main script for running BatDetect2 on audio files.
|
"""BatDetect2 command line interface."""
|
||||||
|
|
||||||
Example usage:
|
|
||||||
python command.py /path/to/audio/ /path/to/ann/ 0.1
|
|
||||||
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
from bat_detect import api
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
|
||||||
|
import click # noqa: E402
|
||||||
|
|
||||||
|
from bat_detect import api # noqa: E402
|
||||||
|
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH # noqa: E402
|
||||||
|
from bat_detect.utils.detector_utils import save_results_to_file # noqa: E402
|
||||||
|
|
||||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
INFO_STR = """
|
||||||
info_str = (
|
BatDetect2 - Detection and Classification
|
||||||
"\nBatDetect2 - Detection and Classification\n"
|
Assumes audio files are mono, not stereo.
|
||||||
+ " Assumes audio files are mono, not stereo.\n"
|
Spaces in the input paths will throw an error. Wrap in quotes.
|
||||||
+ ' Spaces in the input paths will throw an error. Wrap in quotes "".\n'
|
Input files should be short in duration e.g. < 30 seconds.
|
||||||
+ " Input files should be short in duration e.g. < 30 seconds.\n"
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
def cli():
|
||||||
|
"""BatDetect2 - Bat Call Detection and Classification."""
|
||||||
|
click.echo(INFO_STR)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
@click.argument(
|
||||||
|
"audio_dir",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
)
|
||||||
|
@click.argument(
|
||||||
|
"ann_dir",
|
||||||
|
type=click.Path(exists=False),
|
||||||
|
)
|
||||||
|
@click.argument(
|
||||||
|
"detection_threshold",
|
||||||
|
type=float,
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--cnn_features",
|
||||||
|
is_flag=True,
|
||||||
|
default=False,
|
||||||
|
help="Extracts CNN call features",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--spec_features",
|
||||||
|
is_flag=True,
|
||||||
|
default=False,
|
||||||
|
help="Extracts low level call features",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--time_expansion_factor",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The time expansion factor used for all files (default is 1)",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--quiet",
|
||||||
|
is_flag=True,
|
||||||
|
default=False,
|
||||||
|
help="Minimize output printing",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--save_preds_if_empty",
|
||||||
|
is_flag=True,
|
||||||
|
default=False,
|
||||||
|
help="Save empty annotation file if no detections made.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--model_path",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_MODEL_PATH,
|
||||||
|
help="Path to trained BatDetect2 model",
|
||||||
|
)
|
||||||
|
def detect(
|
||||||
|
audio_dir: str,
|
||||||
|
ann_dir: str,
|
||||||
|
detection_threshold: float,
|
||||||
|
**args,
|
||||||
|
):
|
||||||
|
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
|
||||||
|
|
||||||
|
DETECTION_THRESHOLD is the detection threshold. All predictions with a
|
||||||
|
score below this threshold will be discarded. Values between 0 and 1.
|
||||||
|
|
||||||
|
Assumes audio files are mono, not stereo.
|
||||||
|
|
||||||
|
Spaces in the input paths will throw an error. Wrap in quotes.
|
||||||
|
|
||||||
|
Input files should be short in duration e.g. < 30 seconds.
|
||||||
|
"""
|
||||||
|
click.echo(f"Loading model: {args['model_path']}")
|
||||||
|
model, params = api.load_model(args["model_path"])
|
||||||
|
|
||||||
|
click.echo(f"\nInput directory: {audio_dir}")
|
||||||
|
files = api.list_audio_files(audio_dir)
|
||||||
|
|
||||||
|
click.echo(f"Number of audio files: {len(files)}")
|
||||||
|
click.echo(f"\nSaving results to: {ann_dir}")
|
||||||
|
|
||||||
|
config = api.get_config(
|
||||||
|
**{
|
||||||
|
**params,
|
||||||
|
**args,
|
||||||
|
"spec_slices": False,
|
||||||
|
"chunk_size": 2,
|
||||||
|
"detection_threshold": detection_threshold,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
print(info_str)
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("audio_dir", type=str, help="Input directory for audio")
|
|
||||||
parser.add_argument(
|
|
||||||
"ann_dir",
|
|
||||||
type=str,
|
|
||||||
help="Output directory for where the predictions will be stored",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"detection_threshold",
|
|
||||||
type=float,
|
|
||||||
help="Cut-off probability for detector e.g. 0.1",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--cnn_features",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
dest="cnn_features",
|
|
||||||
help="Extracts CNN call features",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--spec_features",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
dest="spec_features",
|
|
||||||
help="Extracts low level call features",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--time_expansion_factor",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
dest="time_expansion_factor",
|
|
||||||
help="The time expansion factor used for all files (default is 1)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--quiet",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
dest="quiet",
|
|
||||||
help="Minimize output printing",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--save_preds_if_empty",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
dest="save_preds_if_empty",
|
|
||||||
help="Save empty annotation file if no detections made.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_path",
|
|
||||||
type=str,
|
|
||||||
default=du.DEFAULT_MODEL_PATH,
|
|
||||||
help="Path to trained BatDetect2 model",
|
|
||||||
)
|
|
||||||
args = vars(parser.parse_args())
|
|
||||||
|
|
||||||
args["spec_slices"] = False # used for visualization
|
|
||||||
# if files greater than this amount (seconds) they will be broken down into small chunks
|
|
||||||
args["chunk_size"] = 2
|
|
||||||
args["ann_dir"] = os.path.join(args["ann_dir"], "")
|
|
||||||
args["quiet"] = True
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
args = parse_args()
|
|
||||||
|
|
||||||
print("Loading model: " + args["model_path"])
|
|
||||||
model, params = du.load_model(args["model_path"])
|
|
||||||
|
|
||||||
print("\nInput directory: " + args["audio_dir"])
|
|
||||||
files = du.list_audio_files(args["audio_dir"])
|
|
||||||
|
|
||||||
print(f"Number of audio files: {len(files)}")
|
|
||||||
print("\nSaving results to: " + args["ann_dir"])
|
|
||||||
|
|
||||||
default_config = du.get_default_config()
|
|
||||||
|
|
||||||
# set up run config
|
|
||||||
run_config = {
|
|
||||||
**default_config,
|
|
||||||
**args,
|
|
||||||
**params,
|
|
||||||
}
|
|
||||||
|
|
||||||
# process files
|
# process files
|
||||||
error_files = []
|
error_files = []
|
||||||
for audio_file in files:
|
for audio_file in files:
|
||||||
try:
|
try:
|
||||||
results = du.process_file(audio_file, model, run_config)
|
results = api.process_file(audio_file, model, config=config)
|
||||||
|
|
||||||
if args["save_preds_if_empty"] or (
|
if args["save_preds_if_empty"] or (
|
||||||
len(results["pred_dict"]["annotation"]) > 0
|
len(results["pred_dict"]["annotation"]) > 0
|
||||||
):
|
):
|
||||||
results_path = audio_file.replace(
|
results_path = audio_file.replace(audio_dir, ann_dir)
|
||||||
args["audio_dir"], args["ann_dir"]
|
save_results_to_file(results, results_path)
|
||||||
)
|
|
||||||
du.save_results_to_file(results, results_path)
|
|
||||||
except (RuntimeError, ValueError, LookupError) as err:
|
except (RuntimeError, ValueError, LookupError) as err:
|
||||||
|
# TODO: Check what other errors can be thrown
|
||||||
error_files.append(audio_file)
|
error_files.append(audio_file)
|
||||||
print(f"Error processing file!: {err}")
|
click.echo(f"Error processing file!: {err}")
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
print("\nResults saved to: " + args["ann_dir"])
|
click.echo(f"\nResults saved to: {ann_dir}")
|
||||||
|
|
||||||
if len(error_files) > 0:
|
if len(error_files) > 0:
|
||||||
print("\nUnable to process the follow files:")
|
click.echo("\nUnable to process the follow files:")
|
||||||
for err in error_files:
|
for err in error_files:
|
||||||
print(" " + err)
|
click.echo(f" {err}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
cli()
|
||||||
|
@ -1,16 +1,12 @@
|
|||||||
"""Post-processing of the output of the model."""
|
"""Post-processing of the output of the model."""
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from bat_detect.detector.models import ModelOutput
|
from bat_detect.detector.models import ModelOutput
|
||||||
|
from bat_detect.types import NonMaximumSuppressionConfig, PredictionResults
|
||||||
try:
|
|
||||||
from typing import TypedDict
|
|
||||||
except ImportError:
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
np.seterr(divide="ignore", invalid="ignore")
|
np.seterr(divide="ignore", invalid="ignore")
|
||||||
|
|
||||||
@ -42,72 +38,6 @@ def overall_class_pred(det_prob, class_prob):
|
|||||||
return weighted_pred / weighted_pred.sum()
|
return weighted_pred / weighted_pred.sum()
|
||||||
|
|
||||||
|
|
||||||
class NonMaximumSuppressionConfig(TypedDict):
|
|
||||||
"""Configuration for non-maximum suppression."""
|
|
||||||
|
|
||||||
nms_kernel_size: int
|
|
||||||
"""Size of the kernel for non-maximum suppression."""
|
|
||||||
|
|
||||||
max_freq: int
|
|
||||||
"""Maximum frequency to consider in Hz."""
|
|
||||||
|
|
||||||
min_freq: int
|
|
||||||
"""Minimum frequency to consider in Hz."""
|
|
||||||
|
|
||||||
fft_win_length: float
|
|
||||||
"""Length of the FFT window in seconds."""
|
|
||||||
|
|
||||||
fft_overlap: float
|
|
||||||
"""Overlap of the FFT windows in seconds."""
|
|
||||||
|
|
||||||
resize_factor: float
|
|
||||||
"""Factor by which the input was resized."""
|
|
||||||
|
|
||||||
nms_top_k_per_sec: float
|
|
||||||
"""Number of top detections to keep per second."""
|
|
||||||
|
|
||||||
detection_threshold: float
|
|
||||||
"""Threshold for detection probability."""
|
|
||||||
|
|
||||||
|
|
||||||
class PredictionResults(TypedDict):
|
|
||||||
"""Results of the prediction.
|
|
||||||
|
|
||||||
Each key is a list of length `num_detections` containing the
|
|
||||||
corresponding values for each detection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
det_probs: np.ndarray
|
|
||||||
"""Detection probabilities."""
|
|
||||||
|
|
||||||
x_pos: np.ndarray
|
|
||||||
"""X position of the detection in pixels."""
|
|
||||||
|
|
||||||
y_pos: np.ndarray
|
|
||||||
"""Y position of the detection in pixels."""
|
|
||||||
|
|
||||||
bb_width: np.ndarray
|
|
||||||
"""Width of the detection in pixels."""
|
|
||||||
|
|
||||||
bb_height: np.ndarray
|
|
||||||
"""Height of the detection in pixels."""
|
|
||||||
|
|
||||||
start_times: np.ndarray
|
|
||||||
"""Start times of the detections in seconds."""
|
|
||||||
|
|
||||||
end_times: np.ndarray
|
|
||||||
"""End times of the detections in seconds."""
|
|
||||||
|
|
||||||
low_freqs: np.ndarray
|
|
||||||
"""Low frequencies of the detections in Hz."""
|
|
||||||
|
|
||||||
high_freqs: np.ndarray
|
|
||||||
"""High frequencies of the detections in Hz."""
|
|
||||||
|
|
||||||
class_probs: Optional[np.ndarray]
|
|
||||||
"""Class probabilities."""
|
|
||||||
|
|
||||||
|
|
||||||
def run_nms(
|
def run_nms(
|
||||||
outputs: ModelOutput,
|
outputs: ModelOutput,
|
||||||
params: NonMaximumSuppressionConfig,
|
params: NonMaximumSuppressionConfig,
|
||||||
@ -128,8 +58,8 @@ def run_nms(
|
|||||||
-2
|
-2
|
||||||
]
|
]
|
||||||
|
|
||||||
# NOTE: there will be small differences depending on which sampling rate is chosen
|
# NOTE: there will be small differences depending on which sampling rate
|
||||||
# as we are choosing the same sampling rate for the entire batch
|
# is chosen as we are choosing the same sampling rate for the entire batch
|
||||||
duration = x_coords_to_time(
|
duration = x_coords_to_time(
|
||||||
pred_det.shape[-1],
|
pred_det.shape[-1],
|
||||||
int(sampling_rate[0].item()),
|
int(sampling_rate[0].item()),
|
||||||
@ -211,16 +141,21 @@ def run_nms(
|
|||||||
for key, value in pred.items():
|
for key, value in pred.items():
|
||||||
pred[key] = value.cpu().numpy().astype(np.float32)
|
pred[key] = value.cpu().numpy().astype(np.float32)
|
||||||
|
|
||||||
preds.append(pred)
|
preds.append(pred) # type: ignore
|
||||||
|
|
||||||
return preds, feats
|
return preds, feats
|
||||||
|
|
||||||
|
|
||||||
def non_max_suppression(heat, kernel_size):
|
def non_max_suppression(
|
||||||
|
heat: torch.Tensor,
|
||||||
|
kernel_size: Union[int, Tuple[int, int]],
|
||||||
|
):
|
||||||
# kernel can be an int or list/tuple
|
# kernel can be an int or list/tuple
|
||||||
if isinstance(kernel_size, int):
|
if isinstance(kernel_size, int):
|
||||||
kernel_size_h = kernel_size
|
kernel_size_h = kernel_size
|
||||||
kernel_size_w = kernel_size
|
kernel_size_w = kernel_size
|
||||||
|
else:
|
||||||
|
kernel_size_h, kernel_size_w = kernel_size
|
||||||
|
|
||||||
pad_h = (kernel_size_h - 1) // 2
|
pad_h = (kernel_size_h - 1) // 2
|
||||||
pad_w = (kernel_size_w - 1) // 2
|
pad_w = (kernel_size_w - 1) // 2
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Types used in the code base."""
|
"""Types used in the code base."""
|
||||||
from typing import List, Optional, NamedTuple
|
from typing import List, NamedTuple, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -16,6 +16,27 @@ except ImportError:
|
|||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import NotRequired
|
||||||
|
except ImportError:
|
||||||
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Annotation",
|
||||||
|
"DetectionModel",
|
||||||
|
"FileAnnotations",
|
||||||
|
"ModelOutput",
|
||||||
|
"ModelParameters",
|
||||||
|
"NonMaximumSuppressionConfig",
|
||||||
|
"PredictionResults",
|
||||||
|
"ProcessingConfiguration",
|
||||||
|
"ResultParams",
|
||||||
|
"RunResults",
|
||||||
|
"SpectrogramParameters",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class SpectrogramParameters(TypedDict):
|
class SpectrogramParameters(TypedDict):
|
||||||
"""Parameters for generating spectrograms."""
|
"""Parameters for generating spectrograms."""
|
||||||
|
|
||||||
@ -144,19 +165,19 @@ class RunResults(TypedDict):
|
|||||||
pred_dict: FileAnnotations
|
pred_dict: FileAnnotations
|
||||||
"""Predictions in the format expected by the annotation tool."""
|
"""Predictions in the format expected by the annotation tool."""
|
||||||
|
|
||||||
spec_feats: Optional[List[np.ndarray]]
|
spec_feats: NotRequired[List[np.ndarray]]
|
||||||
"""Spectrogram features."""
|
"""Spectrogram features."""
|
||||||
|
|
||||||
spec_feat_names: Optional[List[str]]
|
spec_feat_names: NotRequired[List[str]]
|
||||||
"""Spectrogram feature names."""
|
"""Spectrogram feature names."""
|
||||||
|
|
||||||
cnn_feats: Optional[List[np.ndarray]]
|
cnn_feats: NotRequired[List[np.ndarray]]
|
||||||
"""CNN features."""
|
"""CNN features."""
|
||||||
|
|
||||||
cnn_feat_names: Optional[List[str]]
|
cnn_feat_names: NotRequired[List[str]]
|
||||||
"""CNN feature names."""
|
"""CNN feature names."""
|
||||||
|
|
||||||
spec_slices: Optional[List[np.ndarray]]
|
spec_slices: NotRequired[List[np.ndarray]]
|
||||||
"""Spectrogram slices."""
|
"""Spectrogram slices."""
|
||||||
|
|
||||||
|
|
||||||
@ -166,6 +187,15 @@ class ResultParams(TypedDict):
|
|||||||
class_names: List[str]
|
class_names: List[str]
|
||||||
"""Class names."""
|
"""Class names."""
|
||||||
|
|
||||||
|
spec_features: bool
|
||||||
|
"""Whether to return spectrogram features."""
|
||||||
|
|
||||||
|
cnn_features: bool
|
||||||
|
"""Whether to return CNN features."""
|
||||||
|
|
||||||
|
spec_slices: bool
|
||||||
|
"""Whether to return spectrogram slices."""
|
||||||
|
|
||||||
|
|
||||||
class ProcessingConfiguration(TypedDict):
|
class ProcessingConfiguration(TypedDict):
|
||||||
"""Parameters for processing audio files."""
|
"""Parameters for processing audio files."""
|
||||||
@ -266,6 +296,44 @@ class ModelOutput(NamedTuple):
|
|||||||
"""Tensor with intermediate features."""
|
"""Tensor with intermediate features."""
|
||||||
|
|
||||||
|
|
||||||
|
class PredictionResults(TypedDict):
|
||||||
|
"""Results of the prediction.
|
||||||
|
|
||||||
|
Each key is a list of length `num_detections` containing the
|
||||||
|
corresponding values for each detection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
det_probs: np.ndarray
|
||||||
|
"""Detection probabilities."""
|
||||||
|
|
||||||
|
x_pos: np.ndarray
|
||||||
|
"""X position of the detection in pixels."""
|
||||||
|
|
||||||
|
y_pos: np.ndarray
|
||||||
|
"""Y position of the detection in pixels."""
|
||||||
|
|
||||||
|
bb_width: np.ndarray
|
||||||
|
"""Width of the detection in pixels."""
|
||||||
|
|
||||||
|
bb_height: np.ndarray
|
||||||
|
"""Height of the detection in pixels."""
|
||||||
|
|
||||||
|
start_times: np.ndarray
|
||||||
|
"""Start times of the detections in seconds."""
|
||||||
|
|
||||||
|
end_times: np.ndarray
|
||||||
|
"""End times of the detections in seconds."""
|
||||||
|
|
||||||
|
low_freqs: np.ndarray
|
||||||
|
"""Low frequencies of the detections in Hz."""
|
||||||
|
|
||||||
|
high_freqs: np.ndarray
|
||||||
|
"""High frequencies of the detections in Hz."""
|
||||||
|
|
||||||
|
class_probs: Optional[np.ndarray]
|
||||||
|
"""Class probabilities."""
|
||||||
|
|
||||||
|
|
||||||
class DetectionModel(Protocol):
|
class DetectionModel(Protocol):
|
||||||
"""Protocol for detection models.
|
"""Protocol for detection models.
|
||||||
|
|
||||||
@ -286,19 +354,49 @@ class DetectionModel(Protocol):
|
|||||||
resize_factor: float
|
resize_factor: float
|
||||||
"""Factor by which the input is resized."""
|
"""Factor by which the input is resized."""
|
||||||
|
|
||||||
ip_height: int
|
ip_height_rs: int
|
||||||
"""Height of the input image."""
|
"""Height of the input image."""
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
spec: torch.Tensor,
|
ip: torch.Tensor,
|
||||||
return_feats: bool = False,
|
return_feats: bool = False,
|
||||||
) -> ModelOutput:
|
) -> ModelOutput:
|
||||||
"""Forward pass of the model."""
|
"""Forward pass of the model."""
|
||||||
|
...
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
spec: torch.Tensor,
|
ip: torch.Tensor,
|
||||||
return_feats: bool = False,
|
return_feats: bool = False,
|
||||||
) -> ModelOutput:
|
) -> ModelOutput:
|
||||||
"""Forward pass of the model."""
|
"""Forward pass of the model."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class NonMaximumSuppressionConfig(TypedDict):
|
||||||
|
"""Configuration for non-maximum suppression."""
|
||||||
|
|
||||||
|
nms_kernel_size: int
|
||||||
|
"""Size of the kernel for non-maximum suppression."""
|
||||||
|
|
||||||
|
max_freq: int
|
||||||
|
"""Maximum frequency to consider in Hz."""
|
||||||
|
|
||||||
|
min_freq: int
|
||||||
|
"""Minimum frequency to consider in Hz."""
|
||||||
|
|
||||||
|
fft_win_length: float
|
||||||
|
"""Length of the FFT window in seconds."""
|
||||||
|
|
||||||
|
fft_overlap: float
|
||||||
|
"""Overlap of the FFT windows in seconds."""
|
||||||
|
|
||||||
|
resize_factor: float
|
||||||
|
"""Factor by which the input was resized."""
|
||||||
|
|
||||||
|
nms_top_k_per_sec: float
|
||||||
|
"""Number of top detections to keep per second."""
|
||||||
|
|
||||||
|
detection_threshold: float
|
||||||
|
"""Threshold for detection probability."""
|
||||||
|
@ -14,13 +14,14 @@ from bat_detect.detector import models
|
|||||||
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
|
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
|
||||||
from bat_detect.types import (
|
from bat_detect.types import (
|
||||||
Annotation,
|
Annotation,
|
||||||
|
DetectionModel,
|
||||||
FileAnnotations,
|
FileAnnotations,
|
||||||
ModelParameters,
|
ModelParameters,
|
||||||
|
PredictionResults,
|
||||||
ProcessingConfiguration,
|
ProcessingConfiguration,
|
||||||
SpectrogramParameters,
|
|
||||||
ResultParams,
|
ResultParams,
|
||||||
RunResults,
|
RunResults,
|
||||||
DetectionModel
|
SpectrogramParameters,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -73,10 +74,8 @@ def load_model(
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: Model file not found.
|
FileNotFoundError: Model file not found.
|
||||||
ValueError: Unknown model.
|
ValueError: Unknown model name.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# load model
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
@ -135,7 +134,8 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
|||||||
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
|
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# hack in case where no detected calls as we need some of the key names in dict
|
# hack in case where no detected calls as we need some of the key
|
||||||
|
# names in dict
|
||||||
predictions_m = predictions[0]
|
predictions_m = predictions[0]
|
||||||
|
|
||||||
if len(spec_feats) > 0:
|
if len(spec_feats) > 0:
|
||||||
@ -263,27 +263,22 @@ def convert_results(
|
|||||||
# combine into final results dictionary
|
# combine into final results dictionary
|
||||||
results: RunResults = {
|
results: RunResults = {
|
||||||
"pred_dict": pred_dict,
|
"pred_dict": pred_dict,
|
||||||
"spec_feats": None,
|
|
||||||
"spec_feat_names": None,
|
|
||||||
"cnn_feats": None,
|
|
||||||
"cnn_feat_names": None,
|
|
||||||
"spec_slices": None,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# add spectrogram features if they exist
|
# add spectrogram features if they exist
|
||||||
if len(spec_feats) > 0:
|
if len(spec_feats) > 0 and params["spec_features"]:
|
||||||
results["spec_feats"] = spec_feats
|
results["spec_feats"] = spec_feats
|
||||||
results["spec_feat_names"] = feats.get_feature_names()
|
results["spec_feat_names"] = feats.get_feature_names()
|
||||||
|
|
||||||
# add CNN features if they exist
|
# add CNN features if they exist
|
||||||
if len(cnn_feats) > 0:
|
if len(cnn_feats) > 0 and params["cnn_features"]:
|
||||||
results["cnn_feats"] = cnn_feats
|
results["cnn_feats"] = cnn_feats
|
||||||
results["cnn_feat_names"] = [
|
results["cnn_feat_names"] = [
|
||||||
str(ii) for ii in range(cnn_feats.shape[1])
|
str(ii) for ii in range(cnn_feats.shape[1])
|
||||||
]
|
]
|
||||||
|
|
||||||
# add spectrogram slices if they exist
|
# add spectrogram slices if they exist
|
||||||
if len(spec_slices) > 0:
|
if len(spec_slices) > 0 and params["spec_slices"]:
|
||||||
results["spec_slices"] = spec_slices
|
results["spec_slices"] = spec_slices
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@ -292,6 +287,8 @@ def convert_results(
|
|||||||
def save_results_to_file(results, op_path: str) -> None:
|
def save_results_to_file(results, op_path: str) -> None:
|
||||||
"""Save results to file.
|
"""Save results to file.
|
||||||
|
|
||||||
|
Will create the output directory if it does not exist.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
results (dict): Results.
|
results (dict): Results.
|
||||||
op_path (str): Output path.
|
op_path (str): Output path.
|
||||||
@ -476,7 +473,7 @@ def _process_spectrogram(
|
|||||||
samplerate: int,
|
samplerate: int,
|
||||||
model: DetectionModel,
|
model: DetectionModel,
|
||||||
config: ProcessingConfiguration,
|
config: ProcessingConfiguration,
|
||||||
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
) -> Tuple[PredictionResults, List[np.ndarray]]:
|
||||||
# evaluate model
|
# evaluate model
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(spec)
|
outputs = model(spec)
|
||||||
@ -493,7 +490,6 @@ def _process_spectrogram(
|
|||||||
"resize_factor": config["resize_factor"],
|
"resize_factor": config["resize_factor"],
|
||||||
"nms_top_k_per_sec": config["nms_top_k_per_sec"],
|
"nms_top_k_per_sec": config["nms_top_k_per_sec"],
|
||||||
"detection_threshold": config["detection_threshold"],
|
"detection_threshold": config["detection_threshold"],
|
||||||
"max_scale_spec": config["max_scale_spec"],
|
|
||||||
},
|
},
|
||||||
np.array([float(samplerate)]),
|
np.array([float(samplerate)]),
|
||||||
)
|
)
|
||||||
@ -561,7 +557,7 @@ def _process_audio_array(
|
|||||||
model: DetectionModel,
|
model: DetectionModel,
|
||||||
config: ProcessingConfiguration,
|
config: ProcessingConfiguration,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
|
) -> Tuple[PredictionResults, List[np.ndarray], torch.Tensor]:
|
||||||
# load audio file and compute spectrogram
|
# load audio file and compute spectrogram
|
||||||
_, spec, _ = compute_spectrogram(
|
_, spec, _ = compute_spectrogram(
|
||||||
audio,
|
audio,
|
||||||
@ -712,17 +708,19 @@ def process_file(
|
|||||||
predictions.append(pred_nms)
|
predictions.append(pred_nms)
|
||||||
|
|
||||||
# extract features - if there are any calls detected
|
# extract features - if there are any calls detected
|
||||||
if pred_nms["det_probs"].shape[0] > 0:
|
if pred_nms["det_probs"].shape[0] == 0:
|
||||||
if config["spec_features"]:
|
continue
|
||||||
spec_feats.append(feats.get_feats(spec_np, pred_nms, config))
|
|
||||||
|
|
||||||
if config["cnn_features"]:
|
if config["spec_features"]:
|
||||||
cnn_feats.append(features[0])
|
spec_feats.append(feats.get_feats(spec_np, pred_nms, config))
|
||||||
|
|
||||||
if config["spec_slices"]:
|
if config["cnn_features"]:
|
||||||
spec_slices.extend(
|
cnn_feats.append(features[0])
|
||||||
feats.extract_spec_slices(spec_np, pred_nms, config)
|
|
||||||
)
|
if config["spec_slices"]:
|
||||||
|
spec_slices.extend(
|
||||||
|
feats.extract_spec_slices(spec_np, pred_nms, config)
|
||||||
|
)
|
||||||
|
|
||||||
# Merge results from chunks
|
# Merge results from chunks
|
||||||
predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
|
predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
|
||||||
|
@ -15,6 +15,7 @@ dependencies = [
|
|||||||
"torch",
|
"torch",
|
||||||
"torchaudio",
|
"torchaudio",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
|
"click",
|
||||||
]
|
]
|
||||||
requires-python = ">=3.8"
|
requires-python = ">=3.8"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
@ -56,3 +57,8 @@ module = [
|
|||||||
"pandas",
|
"pandas",
|
||||||
]
|
]
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[tool.pylsp-mypy]
|
||||||
|
enabled = false
|
||||||
|
live_mode = true
|
||||||
|
strict = true
|
||||||
|
@ -7,3 +7,4 @@ scipy==1.9.3
|
|||||||
torch==1.13.0
|
torch==1.13.0
|
||||||
torchaudio==0.13.0
|
torchaudio==0.13.0
|
||||||
torchvision==0.14.0
|
torchvision==0.14.0
|
||||||
|
click
|
||||||
|
@ -120,18 +120,13 @@ def test_process_file_with_model():
|
|||||||
assert isinstance(predictions, dict)
|
assert isinstance(predictions, dict)
|
||||||
|
|
||||||
assert "pred_dict" in predictions
|
assert "pred_dict" in predictions
|
||||||
assert "spec_feats" in predictions
|
|
||||||
assert "spec_feat_names" in predictions
|
|
||||||
assert "cnn_feats" in predictions
|
|
||||||
assert "cnn_feat_names" in predictions
|
|
||||||
assert "spec_slices" in predictions
|
|
||||||
|
|
||||||
# By default will not return spectrogram features
|
# By default will not return other features
|
||||||
assert predictions["spec_feats"] is None
|
assert "spec_feats" not in predictions
|
||||||
assert predictions["spec_feat_names"] is None
|
assert "spec_feat_names" not in predictions
|
||||||
assert predictions["cnn_feats"] is None
|
assert "cnn_feats" not in predictions
|
||||||
assert predictions["cnn_feat_names"] is None
|
assert "cnn_feat_names" not in predictions
|
||||||
assert predictions["spec_slices"] is None
|
assert "spec_slices" not in predictions
|
||||||
|
|
||||||
# Check that predictions are returned
|
# Check that predictions are returned
|
||||||
assert isinstance(predictions["pred_dict"], dict)
|
assert isinstance(predictions["pred_dict"], dict)
|
||||||
|
@ -0,0 +1,41 @@
|
|||||||
|
"""Test the command line interface."""
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
from bat_detect.cli import cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_base_command():
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "BatDetect2 - Bat Call Detection and Classification" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_detect_command_help():
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["detect", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Detect bat calls in files in AUDIO_DIR" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_detect_command_on_test_audio(tmp_path):
|
||||||
|
results_dir = tmp_path / "results"
|
||||||
|
|
||||||
|
# Remove results dir if it exists
|
||||||
|
if results_dir.exists():
|
||||||
|
results_dir.rmdir()
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"detect",
|
||||||
|
"example_data/audio",
|
||||||
|
str(results_dir),
|
||||||
|
"0.3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert results_dir.exists()
|
||||||
|
assert len(list(results_dir.glob("*.csv"))) == 3
|
||||||
|
assert len(list(results_dir.glob("*.json"))) == 3
|
Loading…
Reference in New Issue
Block a user