Checked that the returned features are a 2D NumPy array

This commit is contained in:
Santiago Martinez 2023-04-07 14:46:38 -06:00
parent 29074c689e
commit 0a22f1798e
3 changed files with 57 additions and 30 deletions

View File

@ -115,6 +115,7 @@ from batdetect2.types import (
DetectionModel, DetectionModel,
ModelOutput, ModelOutput,
ProcessingConfiguration, ProcessingConfiguration,
RunResults,
SpectrogramParameters, SpectrogramParameters,
) )
from batdetect2.utils.detector_utils import list_audio_files, load_model from batdetect2.utils.detector_utils import list_audio_files, load_model
@ -134,6 +135,7 @@ __all__ = [
"process_audio", "process_audio",
"process_file", "process_file",
"process_spectrogram", "process_spectrogram",
"print_summary",
] ]
@ -150,11 +152,11 @@ def get_config(**kwargs) -> ProcessingConfiguration:
Can be used to override default parameters by passing keyword arguments. Can be used to override default parameters by passing keyword arguments.
""" """
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore return {**DEFAULT_PROCESSING_CONFIGURATIONS, **PARAMS, **kwargs} # type: ignore
# Default processing configuration # Default processing configuration
CONFIG = get_config(**PARAMS) CONFIG = get_config()
def load_audio( def load_audio(
@ -270,7 +272,7 @@ def process_spectrogram(
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None, config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], List[np.ndarray]]: ) -> Tuple[List[Annotation], np.ndarray]:
"""Process spectrogram with model. """Process spectrogram with model.
Parameters Parameters
@ -289,7 +291,11 @@ def process_spectrogram(
Returns Returns
------- -------
DetectionResult detections : List[Annotation]
List of detections.
features: np.ndarray
An array of features. The array has shape (n_detections, n_features)
where each row is a feature vector for a detection.
""" """
if config is None: if config is None:
config = CONFIG config = CONFIG
@ -308,7 +314,7 @@ def process_audio(
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None, config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]: ) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
"""Process audio array with model. """Process audio array with model.
Parameters Parameters
@ -329,10 +335,9 @@ def process_audio(
------- -------
annotations : List[Annotation] annotations : List[Annotation]
List of predicted annotations. List of predicted annotations.
features: np.ndarray
features: List[np.ndarray] An array of features. The array has shape (n_detections, n_features)
List of extracted features for each annotation. where each row is a feature vector for a detection.
spec : torch.Tensor spec : torch.Tensor
Spectrogram of the audio used for prediction. Spectrogram of the audio used for prediction.
""" """
@ -395,3 +400,26 @@ model: DetectionModel = MODEL
config: ProcessingConfiguration = CONFIG config: ProcessingConfiguration = CONFIG
"""Default processing configuration.""" """Default processing configuration."""
def print_summary(results: RunResults) -> None:
    """Print a summary of the results of a detection run.

    Writes to standard output the identifier stored under ``"id"``, the
    number of detected calls, and a tab-separated table with one row per
    detection.

    Parameters
    ----------
    results : RunResults
        Results of a detection run. Only ``results["pred_dict"]`` is read:
        its ``"id"`` entry and its ``"annotation"`` list. Each annotation
        must provide the ``"start_time"``, ``"class_prob"``, ``"low_freq"``
        and ``"class"`` keys.
    """
    predictions = results["pred_dict"]
    annotations = predictions["annotation"]

    print(f"Results for {predictions['id']}")
    # The explicit "\n" leaves a blank line between the count and the table.
    print(f"{len(annotations)} calls detected\n")
    print("time\tprob\tlfreq\tspecies_name")
    for ann in annotations:
        print(
            f"{ann['start_time']}\t{ann['class_prob']}"
            f"\t{ann['low_freq']}\t{ann['class']}"
        )

View File

@ -295,7 +295,6 @@ def save_results_to_file(results, op_path: str) -> None:
op_path (str): Output path. op_path (str): Output path.
""" """
# make directory if it does not exist # make directory if it does not exist
if not os.path.isdir(os.path.dirname(op_path)): if not os.path.isdir(os.path.dirname(op_path)):
os.makedirs(os.path.dirname(op_path)) os.makedirs(os.path.dirname(op_path))
@ -474,7 +473,7 @@ def _process_spectrogram(
samplerate: int, samplerate: int,
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
) -> Tuple[PredictionResults, List[np.ndarray]]: ) -> Tuple[PredictionResults, np.ndarray]:
# evaluate model # evaluate model
with torch.no_grad(): with torch.no_grad():
outputs = model(spec) outputs = model(spec)
@ -504,7 +503,7 @@ def _process_spectrogram(
): ):
pred_nms["class_probs"] = class_probs[:-1, :] pred_nms["class_probs"] = class_probs[:-1, :]
return pred_nms, features return pred_nms, np.concatenate(features, axis=0)
def postprocess_model_outputs( def postprocess_model_outputs(
@ -550,7 +549,7 @@ def process_spectrogram(
samplerate: int, samplerate: int,
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
) -> Tuple[List[Annotation], List[np.ndarray]]: ) -> Tuple[List[Annotation], np.ndarray]:
"""Process a spectrogram with detection model. """Process a spectrogram with detection model.
Will run non-maximum suppression on the output of the model. Will run non-maximum suppression on the output of the model.
@ -569,10 +568,11 @@ def process_spectrogram(
Returns Returns
------- -------
annotations : List[Annotation] detections: List[Annotation]
List of annotations predicted by the model. List of detections predicted by the model.
features : List[np.ndarray] features : np.ndarray
List of CNN features associated with each annotation. An array of CNN features associated with each annotation.
The array is of shape (num_detections, num_features).
Is empty if `config["cnn_features"]` is False. Is empty if `config["cnn_features"]` is False.
""" """
pred_nms, features = _process_spectrogram( pred_nms, features = _process_spectrogram(
@ -582,12 +582,12 @@ def process_spectrogram(
config, config,
) )
annotations = get_annotations_from_preds( detections = get_annotations_from_preds(
pred_nms, pred_nms,
config["class_names"], config["class_names"],
) )
return annotations, features return detections, features
def _process_audio_array( def _process_audio_array(
@ -596,7 +596,7 @@ def _process_audio_array(
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
device: torch.device, device: torch.device,
) -> Tuple[PredictionResults, List[np.ndarray], torch.Tensor]: ) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
# load audio file and compute spectrogram # load audio file and compute spectrogram
_, spec, _ = compute_spectrogram( _, spec, _ = compute_spectrogram(
audio, audio,
@ -634,7 +634,7 @@ def process_audio_array(
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
device: torch.device, device: torch.device,
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]: ) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
"""Process a single audio array with detection model. """Process a single audio array with detection model.
Parameters Parameters
@ -656,10 +656,9 @@ def process_audio_array(
------- -------
annotations : List[Annotation] annotations : List[Annotation]
List of annotations predicted by the model. List of annotations predicted by the model.
features : np.ndarray
features : List[np.ndarray] Array of CNN features associated with each annotation.
List of CNN features associated with each annotation. The array is of shape (num_detections, num_features).
spec : torch.Tensor spec : torch.Tensor
Spectrogram of the audio used as input. Spectrogram of the audio used as input.

View File

@ -81,7 +81,7 @@ def test_get_default_config():
assert config["denoise_spec_avg"] is True assert config["denoise_spec_avg"] is True
assert config["max_scale_spec"] is False assert config["max_scale_spec"] is False
assert config["scale_raw_audio"] is False assert config["scale_raw_audio"] is False
assert len(config["class_names"]) == 0 assert len(config["class_names"]) == 17
assert config["detection_threshold"] == 0.01 assert config["detection_threshold"] == 0.01
assert config["time_expansion"] == 1 assert config["time_expansion"] == 1
assert config["top_n"] == 3 assert config["top_n"] == 3
@ -193,8 +193,8 @@ def test_process_spectrogram_with_default_model():
assert "high_freq" in sample_pred assert "high_freq" in sample_pred
assert features is not None assert features is not None
assert isinstance(features, list) assert isinstance(features, np.ndarray)
assert len(features) == 1 assert len(features) == len(predictions)
def test_process_audio_with_default_model(): def test_process_audio_with_default_model():
@ -216,8 +216,8 @@ def test_process_audio_with_default_model():
assert "high_freq" in sample_pred assert "high_freq" in sample_pred
assert features is not None assert features is not None
assert isinstance(features, list) assert isinstance(features, np.ndarray)
assert len(features) == 1 assert len(features) == len(predictions)
assert spec is not None assert spec is not None
assert isinstance(spec, torch.Tensor) assert isinstance(spec, torch.Tensor)