diff --git a/bat_detect/api.py b/bat_detect/api.py
index f05748d..ad1f1e1 100644
--- a/bat_detect/api.py
+++ b/bat_detect/api.py
@@ -15,6 +15,7 @@ from bat_detect.detector.parameters import (
 from bat_detect.types import (
     Annotation,
     DetectionModel,
+    ModelOutput,
     ProcessingConfiguration,
     SpectrogramParameters,
 )
@@ -24,16 +25,17 @@ from bat_detect.utils.detector_utils import list_audio_files, load_model
 warnings.filterwarnings("ignore", category=UserWarning, module="torch")
 
 __all__ = [
-    "load_model",
-    "load_audio",
-    "list_audio_files",
+    "config",
     "generate_spectrogram",
     "get_config",
+    "list_audio_files",
+    "load_audio",
+    "load_model",
+    "model",
+    "postprocess",
+    "process_audio",
     "process_file",
     "process_spectrogram",
-    "process_audio",
-    "model",
-    "config",
 ]
 
 
@@ -248,6 +250,48 @@ def process_audio(
     )
 
 
+def postprocess(
+    outputs: ModelOutput,
+    samp_rate: int = TARGET_SAMPLERATE_HZ,
+    config: Optional[ProcessingConfiguration] = None,
+) -> Tuple[List[Annotation], np.ndarray]:
+    """Postprocess model outputs.
+
+    Convert model tensor outputs to predicted bounding boxes and
+    extracted features.
+
+    Will run non-maximum suppression and remove overlapping annotations.
+
+    Parameters
+    ----------
+    outputs : ModelOutput
+        Model raw outputs.
+    samp_rate : int, Optional
+        Sample rate of the audio from which the spectrogram was generated.
+        Defaults to 256000 which is the target sample rate of the default
+        model. Only change if you generated outputs from a spectrogram with a
+        different sample rate.
+    config : Optional[ProcessingConfiguration], Optional
+        Processing configuration, by default None (uses default parameters).
+
+    Returns
+    -------
+    annotations : List[Annotation]
+        List of predicted annotations.
+    features: np.ndarray
+        An array of extracted features for each annotation. The shape of the
+        array is (n_annotations, n_features).
+    """
+    if config is None:
+        config = CONFIG
+
+    return du.postprocess_model_outputs(
+        outputs,
+        samp_rate,
+        config,
+    )
+
+
 model: DetectionModel = MODEL
 """Base detection model."""
 
diff --git a/bat_detect/detector/post_process.py b/bat_detect/detector/post_process.py
index ae56f80..5aa6895 100644
--- a/bat_detect/detector/post_process.py
+++ b/bat_detect/detector/post_process.py
@@ -134,12 +134,13 @@ def run_nms(
                 y_pos[num_detection, valid_inds],
                 x_pos[num_detection, valid_inds],
             ].transpose(0, 1)
-            feat = feat.cpu().numpy().astype(np.float32)
+            # detach from autograd, then move to CPU before numpy conversion
+            feat = feat.detach().cpu().numpy().astype(np.float32)
             feats.append(feat)
 
         # convert to numpy
         for key, value in pred.items():
-            pred[key] = value.cpu().numpy().astype(np.float32)
+            pred[key] = value.detach().cpu().numpy().astype(np.float32)
 
         preds.append(pred)  # type: ignore
 
diff --git a/bat_detect/types.py b/bat_detect/types.py
index c4b6297..5e20c48 100644
--- a/bat_detect/types.py
+++ b/bat_detect/types.py
@@ -278,7 +278,23 @@ class ProcessingConfiguration(TypedDict):
 
 
 class ModelOutput(NamedTuple):
-    """Output of the detection model."""
+    """Output of the detection model.
+
+    Each of the tensors has a shape of
+
+        `(batch_size, num_channels, spec_height, spec_width)`.
+
+    Where `spec_height` and `spec_width` are the height and width of the
+    input spectrograms.
+
+    They contain localised information of:
+
+    1. The probability of a bounding box detection at the given location.
+    2. The predicted size of the bounding box at the given location.
+    3. The probabilities of each class at the given location.
+    4. Same as 3. but before softmax.
+    5. Features used to make the predictions at the given location.
+    """
 
     pred_det: torch.Tensor
     """Tensor with predict detection probabilities."""
@@ -330,7 +346,7 @@ class PredictionResults(TypedDict):
     high_freqs: np.ndarray
     """High frequencies of the detections in Hz."""
 
-    class_probs: Optional[np.ndarray]
+    class_probs: np.ndarray
     """Class probabilities."""
 
 
diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py
index 4194ddf..cd71ee6 100644
--- a/bat_detect/utils/detector_utils.py
+++ b/bat_detect/utils/detector_utils.py
@@ -16,6 +16,7 @@ from bat_detect.types import (
     Annotation,
     DetectionModel,
     FileAnnotations,
+    ModelOutput,
     ModelParameters,
     PredictionResults,
     ProcessingConfiguration,
@@ -148,7 +149,7 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
 
 
 def get_annotations_from_preds(
-    predictions,
+    predictions: PredictionResults,
     class_names: List[str],
 ) -> List[Annotation]:
     """Get list of annotations from predictions."""
@@ -194,7 +195,7 @@ def format_single_result(
     file_id: str,
     time_exp: float,
     duration: float,
-    predictions,
+    predictions: PredictionResults,
     class_names: List[str],
 ) -> FileAnnotations:
     """Format results into the format expected by the annotation tool.
@@ -506,6 +507,51 @@ def _process_spectrogram(
     return pred_nms, features
 
 
+def postprocess_model_outputs(
+    outputs: ModelOutput,
+    samp_rate: int,
+    config: ProcessingConfiguration,
+) -> Tuple[List[Annotation], np.ndarray]:
+    """Convert raw model outputs into annotations and features.
+
+    Runs non-maximum suppression on the outputs and keeps the results for
+    the first item of the batch only. If the class probabilities have more
+    rows than there are class names, the last row is dropped before the
+    annotations are built.
+    """
+    # run non-max suppression
+    pred_nms_list, features = pp.run_nms(
+        outputs,
+        {
+            "nms_kernel_size": config["nms_kernel_size"],
+            "max_freq": config["max_freq"],
+            "min_freq": config["min_freq"],
+            "fft_win_length": config["fft_win_length"],
+            "fft_overlap": config["fft_overlap"],
+            "resize_factor": config["resize_factor"],
+            "nms_top_k_per_sec": config["nms_top_k_per_sec"],
+            "detection_threshold": config["detection_threshold"],
+        },
+        np.array([float(samp_rate)]),
+    )
+
+    pred_nms = pred_nms_list[0]
+
+    # if we have a background class
+    class_probs = pred_nms.get("class_probs")
+    if (class_probs is not None) and (
+        class_probs.shape[0] > len(config["class_names"])
+    ):
+        pred_nms["class_probs"] = class_probs[:-1, :]
+
+    annotations = get_annotations_from_preds(
+        pred_nms,
+        config["class_names"],
+    )
+
+    return annotations, features[0]
+
+
 def process_spectrogram(
     spec: torch.Tensor,
     samplerate: int,
diff --git a/tests/test_api.py b/tests/test_api.py
index 52ba40b..927d7be 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -224,3 +224,32 @@ def test_process_audio_with_default_model():
     assert spec is not None
     assert isinstance(spec, torch.Tensor)
     assert spec.shape == (1, 1, 128, 512)
+
+
+def test_postprocess_model_outputs():
+    """Test postprocessing model outputs."""
+    # Load model outputs
+    audio = api.load_audio(TEST_DATA[1])
+    spec = api.generate_spectrogram(audio)
+    model_outputs = api.model(spec)
+
+    # Postprocess outputs
+    predictions, features = api.postprocess(model_outputs)
+
+    assert predictions is not None
+    assert isinstance(predictions, list)
+    assert len(predictions) > 0
+    sample_pred = predictions[0]
+    assert isinstance(sample_pred, dict)
+    assert "class" in sample_pred
+    assert "class_prob" in sample_pred
+    assert "det_prob" in sample_pred
+    assert "start_time" in sample_pred
+    assert "end_time" in sample_pred
+    assert "low_freq" in sample_pred
+    assert "high_freq" in sample_pred
+
+    assert features is not None
+    assert isinstance(features, np.ndarray)
+    assert features.shape[0] == len(predictions)
+    assert features.shape[1] == 32