Checked that the returned features are a 2D NumPy array

This commit is contained in:
Santiago Martinez 2023-04-07 14:46:38 -06:00
parent 29074c689e
commit 0a22f1798e
3 changed files with 57 additions and 30 deletions

View File

@ -115,6 +115,7 @@ from batdetect2.types import (
DetectionModel,
ModelOutput,
ProcessingConfiguration,
RunResults,
SpectrogramParameters,
)
from batdetect2.utils.detector_utils import list_audio_files, load_model
@ -134,6 +135,7 @@ __all__ = [
"process_audio",
"process_file",
"process_spectrogram",
"print_summary",
]
@ -150,11 +152,11 @@ def get_config(**kwargs) -> ProcessingConfiguration:
Can be used to override default parameters by passing keyword arguments.
"""
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **PARAMS, **kwargs} # type: ignore
# Default processing configuration
CONFIG = get_config(**PARAMS)
CONFIG = get_config()
def load_audio(
@ -270,7 +272,7 @@ def process_spectrogram(
samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], List[np.ndarray]]:
) -> Tuple[List[Annotation], np.ndarray]:
"""Process spectrogram with model.
Parameters
@ -289,7 +291,11 @@ def process_spectrogram(
Returns
-------
DetectionResult
detections : List[Annotation]
List of detections.
features: np.ndarray
An array of features. The array has shape (n_detections, n_features)
where each row is a feature vector for a detection.
"""
if config is None:
config = CONFIG
@ -308,7 +314,7 @@ def process_audio(
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE,
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
"""Process audio array with model.
Parameters
@ -329,10 +335,9 @@ def process_audio(
-------
annotations : List[Annotation]
List of predicted annotations.
features: List[np.ndarray]
List of extracted features for each annotation.
features: np.ndarray
An array of features. The array has shape (n_detections, n_features)
where each row is a feature vector for a detection.
spec : torch.Tensor
Spectrogram of the audio used for prediction.
"""
@ -395,3 +400,26 @@ model: DetectionModel = MODEL
config: ProcessingConfiguration = CONFIG
"""Default processing configuration."""
def print_summary(results: RunResults) -> None:
    """Print a tab-separated summary of the detections in a run result.

    Writes a header with the recording id, the number of detected calls,
    a column header and then one row per detection to standard output.

    Parameters
    ----------
    results : RunResults
        Run results. Only ``results["pred_dict"]`` is read; it must provide
        an ``"id"`` entry and an ``"annotation"`` list whose items have the
        keys ``"start_time"``, ``"class_prob"``, ``"low_freq"`` and
        ``"class"``.
    """
    # Look the prediction dictionary up once instead of on every access.
    predictions = results["pred_dict"]

    print(f"Results for {predictions['id']}")

    annotations = predictions["annotation"]
    # The trailing "\n" leaves a blank line before the table header.
    print(f"{len(annotations)} calls detected\n")
    print("time\tprob\tlfreq\tspecies_name")
    for ann in annotations:
        print(
            f"{ann['start_time']}\t{ann['class_prob']}"
            f"\t{ann['low_freq']}\t{ann['class']}"
        )

View File

@ -295,7 +295,6 @@ def save_results_to_file(results, op_path: str) -> None:
op_path (str): Output path.
"""
# make directory if it does not exist
if not os.path.isdir(os.path.dirname(op_path)):
os.makedirs(os.path.dirname(op_path))
@ -474,7 +473,7 @@ def _process_spectrogram(
samplerate: int,
model: DetectionModel,
config: ProcessingConfiguration,
) -> Tuple[PredictionResults, List[np.ndarray]]:
) -> Tuple[PredictionResults, np.ndarray]:
# evaluate model
with torch.no_grad():
outputs = model(spec)
@ -504,7 +503,7 @@ def _process_spectrogram(
):
pred_nms["class_probs"] = class_probs[:-1, :]
return pred_nms, features
return pred_nms, np.concatenate(features, axis=0)
def postprocess_model_outputs(
@ -550,7 +549,7 @@ def process_spectrogram(
samplerate: int,
model: DetectionModel,
config: ProcessingConfiguration,
) -> Tuple[List[Annotation], List[np.ndarray]]:
) -> Tuple[List[Annotation], np.ndarray]:
"""Process a spectrogram with detection model.
Will run non-maximum suppression on the output of the model.
@ -569,10 +568,11 @@ def process_spectrogram(
Returns
-------
annotations : List[Annotation]
List of annotations predicted by the model.
features : List[np.ndarray]
List of CNN features associated with each annotation.
detections: List[Annotation]
List of detections predicted by the model.
features : np.ndarray
An array of CNN features associated with each annotation.
The array is of shape (num_detections, num_features).
Is empty if `config["cnn_features"]` is False.
"""
pred_nms, features = _process_spectrogram(
@ -582,12 +582,12 @@ def process_spectrogram(
config,
)
annotations = get_annotations_from_preds(
detections = get_annotations_from_preds(
pred_nms,
config["class_names"],
)
return annotations, features
return detections, features
def _process_audio_array(
@ -596,7 +596,7 @@ def _process_audio_array(
model: DetectionModel,
config: ProcessingConfiguration,
device: torch.device,
) -> Tuple[PredictionResults, List[np.ndarray], torch.Tensor]:
) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
# load audio file and compute spectrogram
_, spec, _ = compute_spectrogram(
audio,
@ -634,7 +634,7 @@ def process_audio_array(
model: DetectionModel,
config: ProcessingConfiguration,
device: torch.device,
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
"""Process a single audio array with detection model.
Parameters
@ -656,10 +656,9 @@ def process_audio_array(
-------
annotations : List[Annotation]
List of annotations predicted by the model.
features : List[np.ndarray]
List of CNN features associated with each annotation.
features : np.ndarray
Array of CNN features associated with each annotation.
The array is of shape (num_detections, num_features).
spec : torch.Tensor
Spectrogram of the audio used as input.

View File

@ -81,7 +81,7 @@ def test_get_default_config():
assert config["denoise_spec_avg"] is True
assert config["max_scale_spec"] is False
assert config["scale_raw_audio"] is False
assert len(config["class_names"]) == 0
assert len(config["class_names"]) == 17
assert config["detection_threshold"] == 0.01
assert config["time_expansion"] == 1
assert config["top_n"] == 3
@ -193,8 +193,8 @@ def test_process_spectrogram_with_default_model():
assert "high_freq" in sample_pred
assert features is not None
assert isinstance(features, list)
assert len(features) == 1
assert isinstance(features, np.ndarray)
assert len(features) == len(predictions)
def test_process_audio_with_default_model():
@ -216,8 +216,8 @@ def test_process_audio_with_default_model():
assert "high_freq" in sample_pred
assert features is not None
assert isinstance(features, list)
assert len(features) == 1
assert isinstance(features, np.ndarray)
assert len(features) == len(predictions)
assert spec is not None
assert isinstance(spec, torch.Tensor)