Added postprocess function to API

This commit is contained in:
Santiago Martinez 2023-02-26 20:48:52 +00:00
parent b0d9576a24
commit acf01f4970
5 changed files with 140 additions and 12 deletions

View File

@ -15,6 +15,7 @@ from bat_detect.detector.parameters import (
from bat_detect.types import (
Annotation,
DetectionModel,
ModelOutput,
ProcessingConfiguration,
SpectrogramParameters,
)
@ -24,16 +25,17 @@ from bat_detect.utils.detector_utils import list_audio_files, load_model
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
__all__ = [
"load_model",
"load_audio",
"list_audio_files",
"config",
"generate_spectrogram",
"get_config",
"list_audio_files",
"load_audio",
"load_model",
"model",
"postprocess",
"process_audio",
"process_file",
"process_spectrogram",
"process_audio",
"model",
"config",
]
@ -248,6 +250,48 @@ def process_audio(
)
def postprocess(
outputs: ModelOutput,
samp_rate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], np.ndarray]:
"""Postprocess model outputs.
Convert model tensor outputs to predicted bounding boxes and
extracted features.
Will run non-maximum suppression and remove overlapping annotations.
Parameters
----------
outputs : ModelOutput
Model raw outputs.
samp_rate : int, Optional
Sample rate of the audio from which the spectrogram was generated.
Defaults to 256000 which is the target sample rate of the default
model. Only change if you generated outputs from a spectrogram with
sample rate.
config : Optional[ProcessingConfiguration], Optional
Processing configuration, by default None (uses default parameters).
Returns
-------
annotations : List[Annotation]
List of predicted annotations.
features: np.ndarray
An array of extracted features for each annotation. The shape of the
array is (n_annotations, n_features).
"""
if config is None:
config = CONFIG
return du.postprocess_model_outputs(
outputs,
samp_rate,
config,
)
model: DetectionModel = MODEL
"""Base detection model."""

View File

@ -134,12 +134,12 @@ def run_nms(
y_pos[num_detection, valid_inds],
x_pos[num_detection, valid_inds],
].transpose(0, 1)
feat = feat.cpu().numpy().astype(np.float32)
feat = feat.detach().numpy().astype(np.float32)
feats.append(feat)
# convert to numpy
for key, value in pred.items():
pred[key] = value.cpu().numpy().astype(np.float32)
pred[key] = value.detach().numpy().astype(np.float32)
preds.append(pred) # type: ignore

View File

@ -278,7 +278,23 @@ class ProcessingConfiguration(TypedDict):
class ModelOutput(NamedTuple):
"""Output of the detection model."""
"""Output of the detection model.
Each of the tensors has a shape of
`(batch_size, num_channels,spec_height, spec_width)`.
Where `spec_height` and `spec_width` are the height and width of the
input spectrograms.
They contain localised information of:
1. The probability of a bounding box detection at the given location.
2. The predicted size of the bounding box at the given location.
3. The probabilities of each class at the given location.
4. Same as 3. but before softmax.
5. Features used to make the predictions at the given location.
"""
pred_det: torch.Tensor
"""Tensor with predict detection probabilities."""
@ -330,7 +346,7 @@ class PredictionResults(TypedDict):
high_freqs: np.ndarray
"""High frequencies of the detections in Hz."""
class_probs: Optional[np.ndarray]
class_probs: np.ndarray
"""Class probabilities."""

View File

@ -16,6 +16,7 @@ from bat_detect.types import (
Annotation,
DetectionModel,
FileAnnotations,
ModelOutput,
ModelParameters,
PredictionResults,
ProcessingConfiguration,
@ -148,7 +149,7 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
def get_annotations_from_preds(
predictions,
predictions: PredictionResults,
class_names: List[str],
) -> List[Annotation]:
"""Get list of annotations from predictions."""
@ -194,7 +195,7 @@ def format_single_result(
file_id: str,
time_exp: float,
duration: float,
predictions,
predictions: PredictionResults,
class_names: List[str],
) -> FileAnnotations:
"""Format results into the format expected by the annotation tool.
@ -506,6 +507,44 @@ def _process_spectrogram(
return pred_nms, features
def postprocess_model_outputs(
outputs: ModelOutput,
samp_rate: int,
config: ProcessingConfiguration,
) -> Tuple[List[Annotation], np.ndarray]:
# run non-max suppression
pred_nms_list, features = pp.run_nms(
outputs,
{
"nms_kernel_size": config["nms_kernel_size"],
"max_freq": config["max_freq"],
"min_freq": config["min_freq"],
"fft_win_length": config["fft_win_length"],
"fft_overlap": config["fft_overlap"],
"resize_factor": config["resize_factor"],
"nms_top_k_per_sec": config["nms_top_k_per_sec"],
"detection_threshold": config["detection_threshold"],
},
np.array([float(samp_rate)]),
)
pred_nms = pred_nms_list[0]
# if we have a background class
class_probs = pred_nms.get("class_probs")
if (class_probs is not None) and (
class_probs.shape[0] > len(config["class_names"])
):
pred_nms["class_probs"] = class_probs[:-1, :]
annotations = get_annotations_from_preds(
pred_nms,
config["class_names"],
)
return annotations, features[0]
def process_spectrogram(
spec: torch.Tensor,
samplerate: int,

View File

@ -224,3 +224,32 @@ def test_process_audio_with_default_model():
assert spec is not None
assert isinstance(spec, torch.Tensor)
assert spec.shape == (1, 1, 128, 512)
def test_postprocess_model_outputs():
"""Test postprocessing model outputs."""
# Load model outputs
audio = api.load_audio(TEST_DATA[1])
spec = api.generate_spectrogram(audio)
model_outputs = api.model(spec)
# Postprocess outputs
predictions, features = api.postprocess(model_outputs)
assert predictions is not None
assert isinstance(predictions, list)
assert len(predictions) > 0
sample_pred = predictions[0]
assert isinstance(sample_pred, dict)
assert "class" in sample_pred
assert "class_prob" in sample_pred
assert "det_prob" in sample_pred
assert "start_time" in sample_pred
assert "end_time" in sample_pred
assert "low_freq" in sample_pred
assert "high_freq" in sample_pred
assert features is not None
assert isinstance(features, np.ndarray)
assert features.shape[0] == len(predictions)
assert features.shape[1] == 32