Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 14:41:58 +02:00)

Added postprocess function to API

Parent: b0d9576a24
Commit: acf01f4970
@@ -15,6 +15,7 @@ from bat_detect.detector.parameters import (
 from bat_detect.types import (
     Annotation,
     DetectionModel,
+    ModelOutput,
     ProcessingConfiguration,
     SpectrogramParameters,
 )
@@ -24,16 +25,17 @@ from bat_detect.utils.detector_utils import list_audio_files, load_model
 warnings.filterwarnings("ignore", category=UserWarning, module="torch")
 
 __all__ = [
-    "load_model",
-    "load_audio",
-    "list_audio_files",
+    "config",
     "generate_spectrogram",
     "get_config",
+    "list_audio_files",
+    "load_audio",
+    "load_model",
+    "model",
+    "postprocess",
+    "process_audio",
     "process_file",
     "process_spectrogram",
-    "process_audio",
-    "model",
-    "config",
 ]
 
 
@@ -248,6 +250,48 @@ def process_audio(
     )
 
 
+def postprocess(
+    outputs: ModelOutput,
+    samp_rate: int = TARGET_SAMPLERATE_HZ,
+    config: Optional[ProcessingConfiguration] = None,
+) -> Tuple[List[Annotation], np.ndarray]:
+    """Postprocess model outputs.
+
+    Convert model tensor outputs to predicted bounding boxes and
+    extracted features.
+
+    Will run non-maximum suppression and remove overlapping annotations.
+
+    Parameters
+    ----------
+    outputs : ModelOutput
+        Model raw outputs.
+    samp_rate : int, optional
+        Sample rate of the audio from which the spectrogram was generated.
+        Defaults to 256000, the target sample rate of the default model.
+        Only change this if the spectrogram was generated from audio with a
+        different sample rate.
+    config : Optional[ProcessingConfiguration], optional
+        Processing configuration, by default None (uses default parameters).
+
+    Returns
+    -------
+    annotations : List[Annotation]
+        List of predicted annotations.
+    features : np.ndarray
+        An array of extracted features for each annotation. The shape of the
+        array is (n_annotations, n_features).
+    """
+    if config is None:
+        config = CONFIG
+
+    return du.postprocess_model_outputs(
+        outputs,
+        samp_rate,
+        config,
+    )
+
+
 model: DetectionModel = MODEL
 """Base detection model."""
 
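For context, the new `postprocess` entry point lets callers run the raw model and convert its outputs in separate steps. A minimal usage sketch, assuming the API module is imported as `from bat_detect import api` (as in the test added at the end of this commit) and a hypothetical recording path; not the package's documented example:

from bat_detect import api

# Hypothetical input file; any bat recording readable by load_audio would do.
audio = api.load_audio("example_recording.wav")
spec = api.generate_spectrogram(audio)

# Raw model pass: returns a ModelOutput tuple of tensors.
model_outputs = api.model(spec)

# New in this commit: NMS plus conversion to annotations and per-detection features.
predictions, features = api.postprocess(model_outputs)

for pred in predictions:
    print(pred["class"], pred["det_prob"], pred["start_time"], pred["end_time"])

# One feature vector per annotation: (n_annotations, n_features).
assert features.shape[0] == len(predictions)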
@@ -134,12 +134,12 @@ def run_nms(
             y_pos[num_detection, valid_inds],
             x_pos[num_detection, valid_inds],
         ].transpose(0, 1)
-        feat = feat.cpu().numpy().astype(np.float32)
+        feat = feat.detach().numpy().astype(np.float32)
         feats.append(feat)
 
         # convert to numpy
         for key, value in pred.items():
-            pred[key] = value.cpu().numpy().astype(np.float32)
+            pred[key] = value.detach().numpy().astype(np.float32)
 
         preds.append(pred)  # type: ignore
 
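Side note on the `.cpu()` → `.detach()` change above: `Tensor.numpy()` requires a CPU tensor that is detached from the autograd graph, so tensors produced by a forward pass outside `torch.no_grad()` must be detached first. A generic PyTorch illustration, not batdetect2 code:

import torch

t = torch.rand(3, requires_grad=True)
# t.numpy() would raise a RuntimeError because the tensor requires grad.
arr = t.detach().numpy()

# If the tensor might live on the GPU, chain both conversions.
arr = t.detach().cpu().numpy()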
@@ -278,7 +278,23 @@ class ProcessingConfiguration(TypedDict):
 
 
 class ModelOutput(NamedTuple):
-    """Output of the detection model."""
+    """Output of the detection model.
+
+    Each of the tensors has a shape of
+
+    `(batch_size, num_channels, spec_height, spec_width)`.
+
+    Where `spec_height` and `spec_width` are the height and width of the
+    input spectrograms.
+
+    They contain localised information of:
+
+    1. The probability of a bounding box detection at the given location.
+    2. The predicted size of the bounding box at the given location.
+    3. The probabilities of each class at the given location.
+    4. Same as 3. but before softmax.
+    5. Features used to make the predictions at the given location.
+    """
 
     pred_det: torch.Tensor
     """Tensor with predicted detection probabilities."""
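Because `ModelOutput` is a `NamedTuple`, the five tensors it describes can be inspected by field name or unpacked positionally. A small sketch assuming only what the diff shows (the `pred_det` field and the documented 4-D shape); the remaining field names are not visible here:

from bat_detect.types import ModelOutput

def describe_output(outputs: ModelOutput) -> None:
    # The detection probability map, shaped (batch, channels, spec_height, spec_width).
    print("pred_det:", tuple(outputs.pred_det.shape))
    # NamedTuple fields iterate in order; their names come from _fields.
    for name, tensor in zip(outputs._fields, outputs):
        print(f"{name}: {tuple(tensor.shape)}")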
@@ -330,7 +346,7 @@ class PredictionResults(TypedDict):
     high_freqs: np.ndarray
     """High frequencies of the detections in Hz."""
 
-    class_probs: Optional[np.ndarray]
+    class_probs: np.ndarray
     """Class probabilities."""
 
 
@@ -16,6 +16,7 @@ from bat_detect.types import (
     Annotation,
     DetectionModel,
     FileAnnotations,
+    ModelOutput,
     ModelParameters,
     PredictionResults,
     ProcessingConfiguration,
@@ -148,7 +149,7 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
 
 
 def get_annotations_from_preds(
-    predictions,
+    predictions: PredictionResults,
    class_names: List[str],
 ) -> List[Annotation]:
     """Get list of annotations from predictions."""
@@ -194,7 +195,7 @@ def format_single_result(
     file_id: str,
     time_exp: float,
     duration: float,
-    predictions,
+    predictions: PredictionResults,
     class_names: List[str],
 ) -> FileAnnotations:
     """Format results into the format expected by the annotation tool.
@@ -506,6 +507,44 @@ def _process_spectrogram(
     return pred_nms, features
 
 
+def postprocess_model_outputs(
+    outputs: ModelOutput,
+    samp_rate: int,
+    config: ProcessingConfiguration,
+) -> Tuple[List[Annotation], np.ndarray]:
+    # run non-max suppression
+    pred_nms_list, features = pp.run_nms(
+        outputs,
+        {
+            "nms_kernel_size": config["nms_kernel_size"],
+            "max_freq": config["max_freq"],
+            "min_freq": config["min_freq"],
+            "fft_win_length": config["fft_win_length"],
+            "fft_overlap": config["fft_overlap"],
+            "resize_factor": config["resize_factor"],
+            "nms_top_k_per_sec": config["nms_top_k_per_sec"],
+            "detection_threshold": config["detection_threshold"],
+        },
+        np.array([float(samp_rate)]),
+    )
+
+    pred_nms = pred_nms_list[0]
+
+    # if we have a background class
+    class_probs = pred_nms.get("class_probs")
+    if (class_probs is not None) and (
+        class_probs.shape[0] > len(config["class_names"])
+    ):
+        pred_nms["class_probs"] = class_probs[:-1, :]
+
+    annotations = get_annotations_from_preds(
+        pred_nms,
+        config["class_names"],
+    )
+
+    return annotations, features[0]
+
+
 def process_spectrogram(
     spec: torch.Tensor,
     samplerate: int,
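One detail in `postprocess_model_outputs` above worth calling out: when the model carries an extra background class, `class_probs` has one more row than there are entries in `config["class_names"]`, and that trailing row is sliced off before annotations are built. A standalone illustration of the slicing with synthetic values (the class names are hypothetical):

import numpy as np

class_names = ["Myotis daubentonii", "Pipistrellus pipistrellus"]  # hypothetical
# Probabilities for 4 detections: one row per class plus a background row.
class_probs = np.random.rand(len(class_names) + 1, 4)

if class_probs.shape[0] > len(class_names):
    class_probs = class_probs[:-1, :]  # drop the trailing background row

assert class_probs.shape[0] == len(class_names)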
@@ -224,3 +224,32 @@ def test_process_audio_with_default_model():
     assert spec is not None
     assert isinstance(spec, torch.Tensor)
     assert spec.shape == (1, 1, 128, 512)
+
+
+def test_postprocess_model_outputs():
+    """Test postprocessing model outputs."""
+    # Load model outputs
+    audio = api.load_audio(TEST_DATA[1])
+    spec = api.generate_spectrogram(audio)
+    model_outputs = api.model(spec)
+
+    # Postprocess outputs
+    predictions, features = api.postprocess(model_outputs)
+
+    assert predictions is not None
+    assert isinstance(predictions, list)
+    assert len(predictions) > 0
+    sample_pred = predictions[0]
+    assert isinstance(sample_pred, dict)
+    assert "class" in sample_pred
+    assert "class_prob" in sample_pred
+    assert "det_prob" in sample_pred
+    assert "start_time" in sample_pred
+    assert "end_time" in sample_pred
+    assert "low_freq" in sample_pred
+    assert "high_freq" in sample_pred
+
+    assert features is not None
+    assert isinstance(features, np.ndarray)
+    assert features.shape[0] == len(predictions)
+    assert features.shape[1] == 32