Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 22:51:58 +02:00)
Added postprocess function to API
commit acf01f4970
parent b0d9576a24
@@ -15,6 +15,7 @@ from bat_detect.detector.parameters import (
 from bat_detect.types import (
     Annotation,
     DetectionModel,
+    ModelOutput,
     ProcessingConfiguration,
     SpectrogramParameters,
 )

@@ -24,16 +25,17 @@ from bat_detect.utils.detector_utils import list_audio_files, load_model
 warnings.filterwarnings("ignore", category=UserWarning, module="torch")

 __all__ = [
-    "load_model",
-    "load_audio",
-    "list_audio_files",
+    "config",
     "generate_spectrogram",
     "get_config",
+    "list_audio_files",
+    "load_audio",
+    "load_model",
+    "model",
+    "postprocess",
+    "process_audio",
     "process_file",
     "process_spectrogram",
-    "process_audio",
-    "model",
-    "config",
 ]

@@ -248,6 +250,48 @@ def process_audio(
     )


+def postprocess(
+    outputs: ModelOutput,
+    samp_rate: int = TARGET_SAMPLERATE_HZ,
+    config: Optional[ProcessingConfiguration] = None,
+) -> Tuple[List[Annotation], np.ndarray]:
+    """Postprocess model outputs.
+
+    Convert model tensor outputs to predicted bounding boxes and
+    extracted features.
+
+    Will run non-maximum suppression and remove overlapping annotations.
+
+    Parameters
+    ----------
+    outputs : ModelOutput
+        Model raw outputs.
+    samp_rate : int, optional
+        Sample rate of the audio from which the spectrogram was generated.
+        Defaults to 256000, which is the target sample rate of the default
+        model. Only change this if the outputs were generated from a
+        spectrogram with a different sample rate.
+    config : Optional[ProcessingConfiguration], optional
+        Processing configuration, by default None (uses default parameters).
+
+    Returns
+    -------
+    annotations : List[Annotation]
+        List of predicted annotations.
+    features : np.ndarray
+        An array of extracted features for each annotation. The shape of
+        the array is (n_annotations, n_features).
+    """
+    if config is None:
+        config = CONFIG
+
+    return du.postprocess_model_outputs(
+        outputs,
+        samp_rate,
+        config,
+    )
+
+
 model: DetectionModel = MODEL
 """Base detection model."""

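A minimal usage sketch of the new function, not part of the diff: it assumes the package is importable as bat_detect.api, uses a placeholder recording path, and otherwise only calls the existing API functions exercised by the test at the end of this commit.

    from bat_detect import api

    # Hypothetical end-to-end use of the new postprocess step; the path is a
    # placeholder, the other calls are existing API functions.
    audio = api.load_audio("example_recording.wav")
    spec = api.generate_spectrogram(audio)
    model_outputs = api.model(spec)

    # Raw tensors -> list of Annotation dicts plus an (n_annotations, n_features) array.
    detections, features = api.postprocess(model_outputs)
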
@@ -134,12 +134,12 @@ def run_nms(
                 y_pos[num_detection, valid_inds],
                 x_pos[num_detection, valid_inds],
             ].transpose(0, 1)
-            feat = feat.cpu().numpy().astype(np.float32)
+            feat = feat.detach().numpy().astype(np.float32)
             feats.append(feat)

         # convert to numpy
         for key, value in pred.items():
-            pred[key] = value.cpu().numpy().astype(np.float32)
+            pred[key] = value.detach().numpy().astype(np.float32)

         preds.append(pred)  # type: ignore

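The switch from .cpu() to .detach() matters because PyTorch refuses to call .numpy() on a tensor that is still attached to the autograd graph. A small standalone illustration, not part of the diff:

    import torch

    t = torch.ones(3, requires_grad=True) * 2.0  # attached to the autograd graph
    # t.numpy()  # would raise: can't call numpy() on a tensor that requires grad
    arr = t.detach().numpy()  # detach from the graph first, then convert
    print(arr)  # [2. 2. 2.]

    # A tensor living on the GPU would additionally need .cpu(),
    # e.g. t.detach().cpu().numpy().
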
|
@ -278,7 +278,23 @@ class ProcessingConfiguration(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
class ModelOutput(NamedTuple):
|
class ModelOutput(NamedTuple):
|
||||||
"""Output of the detection model."""
|
"""Output of the detection model.
|
||||||
|
|
||||||
|
Each of the tensors has a shape of
|
||||||
|
|
||||||
|
`(batch_size, num_channels,spec_height, spec_width)`.
|
||||||
|
|
||||||
|
Where `spec_height` and `spec_width` are the height and width of the
|
||||||
|
input spectrograms.
|
||||||
|
|
||||||
|
They contain localised information of:
|
||||||
|
|
||||||
|
1. The probability of a bounding box detection at the given location.
|
||||||
|
2. The predicted size of the bounding box at the given location.
|
||||||
|
3. The probabilities of each class at the given location.
|
||||||
|
4. Same as 3. but before softmax.
|
||||||
|
5. Features used to make the predictions at the given location.
|
||||||
|
"""
|
||||||
|
|
||||||
pred_det: torch.Tensor
|
pred_det: torch.Tensor
|
||||||
"""Tensor with predict detection probabilities."""
|
"""Tensor with predict detection probabilities."""
|
||||||
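To make the documented shape concrete, a hypothetical inspection snippet, not part of the diff; only pred_det is named in this commit, and api.model(spec) mirrors the test at the bottom.

    # Assumes `spec` is a spectrogram tensor accepted by the model and that
    # api.model returns the ModelOutput NamedTuple documented above.
    outputs = api.model(spec)
    # Per the docstring, each field is shaped
    # (batch_size, num_channels, spec_height, spec_width):
    print(outputs.pred_det.shape)
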
@@ -330,7 +346,7 @@ class PredictionResults(TypedDict):
     high_freqs: np.ndarray
     """High frequencies of the detections in Hz."""

-    class_probs: Optional[np.ndarray]
+    class_probs: np.ndarray
     """Class probabilities."""

|
@ -16,6 +16,7 @@ from bat_detect.types import (
|
|||||||
Annotation,
|
Annotation,
|
||||||
DetectionModel,
|
DetectionModel,
|
||||||
FileAnnotations,
|
FileAnnotations,
|
||||||
|
ModelOutput,
|
||||||
ModelParameters,
|
ModelParameters,
|
||||||
PredictionResults,
|
PredictionResults,
|
||||||
ProcessingConfiguration,
|
ProcessingConfiguration,
|
||||||
@@ -148,7 +149,7 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):


 def get_annotations_from_preds(
-    predictions,
+    predictions: PredictionResults,
     class_names: List[str],
 ) -> List[Annotation]:
     """Get list of annotations from predictions."""

@@ -194,7 +195,7 @@ def format_single_result(
     file_id: str,
     time_exp: float,
     duration: float,
-    predictions,
+    predictions: PredictionResults,
     class_names: List[str],
 ) -> FileAnnotations:
     """Format results into the format expected by the annotation tool.

@@ -506,6 +507,44 @@ def _process_spectrogram(
     return pred_nms, features


+def postprocess_model_outputs(
+    outputs: ModelOutput,
+    samp_rate: int,
+    config: ProcessingConfiguration,
+) -> Tuple[List[Annotation], np.ndarray]:
+    # run non-max suppression
+    pred_nms_list, features = pp.run_nms(
+        outputs,
+        {
+            "nms_kernel_size": config["nms_kernel_size"],
+            "max_freq": config["max_freq"],
+            "min_freq": config["min_freq"],
+            "fft_win_length": config["fft_win_length"],
+            "fft_overlap": config["fft_overlap"],
+            "resize_factor": config["resize_factor"],
+            "nms_top_k_per_sec": config["nms_top_k_per_sec"],
+            "detection_threshold": config["detection_threshold"],
+        },
+        np.array([float(samp_rate)]),
+    )
+
+    pred_nms = pred_nms_list[0]
+
+    # if we have a background class
+    class_probs = pred_nms.get("class_probs")
+    if (class_probs is not None) and (
+        class_probs.shape[0] > len(config["class_names"])
+    ):
+        pred_nms["class_probs"] = class_probs[:-1, :]
+
+    annotations = get_annotations_from_preds(
+        pred_nms,
+        config["class_names"],
+    )
+
+    return annotations, features[0]
+
+
 def process_spectrogram(
     spec: torch.Tensor,
     samplerate: int,

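Since postprocess_model_outputs takes its NMS and threshold settings from the ProcessingConfiguration dict, they can be tuned through the config passed to api.postprocess. A sketch, not part of the diff, assuming api.get_config accepts keyword overrides (its signature is not shown in this commit):

    # Hypothetical: raise the detection threshold before postprocessing.
    config = api.get_config(detection_threshold=0.5)
    detections, features = api.postprocess(model_outputs, config=config)
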
||||||
|
@ -224,3 +224,32 @@ def test_process_audio_with_default_model():
|
|||||||
assert spec is not None
|
assert spec is not None
|
||||||
assert isinstance(spec, torch.Tensor)
|
assert isinstance(spec, torch.Tensor)
|
||||||
assert spec.shape == (1, 1, 128, 512)
|
assert spec.shape == (1, 1, 128, 512)
|
||||||
|
|
||||||
|
|
||||||
|
def test_postprocess_model_outputs():
|
||||||
|
"""Test postprocessing model outputs."""
|
||||||
|
# Load model outputs
|
||||||
|
audio = api.load_audio(TEST_DATA[1])
|
||||||
|
spec = api.generate_spectrogram(audio)
|
||||||
|
model_outputs = api.model(spec)
|
||||||
|
|
||||||
|
# Postprocess outputs
|
||||||
|
predictions, features = api.postprocess(model_outputs)
|
||||||
|
|
||||||
|
assert predictions is not None
|
||||||
|
assert isinstance(predictions, list)
|
||||||
|
assert len(predictions) > 0
|
||||||
|
sample_pred = predictions[0]
|
||||||
|
assert isinstance(sample_pred, dict)
|
||||||
|
assert "class" in sample_pred
|
||||||
|
assert "class_prob" in sample_pred
|
||||||
|
assert "det_prob" in sample_pred
|
||||||
|
assert "start_time" in sample_pred
|
||||||
|
assert "end_time" in sample_pred
|
||||||
|
assert "low_freq" in sample_pred
|
||||||
|
assert "high_freq" in sample_pred
|
||||||
|
|
||||||
|
assert features is not None
|
||||||
|
assert isinstance(features, np.ndarray)
|
||||||
|
assert features.shape[0] == len(predictions)
|
||||||
|
assert features.shape[1] == 32
|
||||||
|