mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Added working postprocess decoding tests
This commit is contained in:
parent
1f4454693e
commit
3abebc9c17
@ -102,10 +102,6 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
)
|
||||
|
||||
start_time, low_freq, end_time, high_freq = compute_bounds(geom)
|
||||
|
||||
classes = det_info.classes
|
||||
features = det_info.features
|
||||
|
||||
detections.append(
|
||||
RawPrediction(
|
||||
detection_score=det_info.score,
|
||||
@ -113,8 +109,8 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
end_time=end_time,
|
||||
low_freq=low_freq,
|
||||
high_freq=high_freq,
|
||||
class_scores=classes,
|
||||
features=features,
|
||||
class_scores=det_info.classes,
|
||||
features=det_info.features,
|
||||
)
|
||||
)
|
||||
|
||||
@ -256,7 +252,79 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction.high_freq,
|
||||
]
|
||||
),
|
||||
features=[
|
||||
features=get_prediction_features(raw_prediction.features),
|
||||
)
|
||||
|
||||
tags = [
|
||||
*get_generic_tags(
|
||||
raw_prediction.detection_score,
|
||||
generic_class_tags=generic_class_tags,
|
||||
),
|
||||
*get_class_tags(
|
||||
raw_prediction.class_scores,
|
||||
sound_event_decoder,
|
||||
top_class_only=top_class_only,
|
||||
threshold=classification_threshold,
|
||||
),
|
||||
]
|
||||
|
||||
return data.SoundEventPrediction(
|
||||
sound_event=sound_event,
|
||||
score=raw_prediction.detection_score,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
|
||||
def get_generic_tags(
    detection_score: float,
    generic_class_tags: List[data.Tag],
) -> List[data.PredictedTag]:
    """Wrap the generic tags as predictions scored by the detection score.

    Every tag in `generic_class_tags` is paired with the same
    `detection_score` inside a `PredictedTag`. The order of the input
    tags is preserved.

    Parameters
    ----------
    detection_score : float
        Score attached to every returned tag.
    generic_class_tags : List[data.Tag]
        Base tags that describe the generic category.

    Returns
    -------
    List[data.PredictedTag]
        One `PredictedTag` per input tag, in input order, each carrying
        `detection_score`. Empty when `generic_class_tags` is empty.
    """
    predicted_tags: List[data.PredictedTag] = []
    for base_tag in generic_class_tags:
        predicted_tags.append(
            data.PredictedTag(tag=base_tag, score=detection_score)
        )
    return predicted_tags
|
||||
|
||||
|
||||
def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
|
||||
"""Convert an extracted feature vector DataArray into soundevent Features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : xr.DataArray
|
||||
A 1D xarray DataArray containing feature values, indexed by a coordinate
|
||||
named 'feature' which holds the feature names (e.g., output of selecting
|
||||
features for one detection from `extract_detection_xr_dataset`).
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.Feature]
|
||||
A list of `soundevent.data.Feature` objects.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- This function creates basic `Term` objects using the feature coordinate
|
||||
names with a "batdetect2:" prefix.
|
||||
"""
|
||||
return [
|
||||
data.Feature(
|
||||
term=data.Term(
|
||||
name=f"batdetect2:{feat_name}",
|
||||
@ -265,24 +333,49 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
),
|
||||
value=value,
|
||||
)
|
||||
for feat_name, value in _iterate_over_array(
|
||||
raw_prediction.features
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
tags = [
|
||||
data.PredictedTag(tag=tag, score=raw_prediction.detection_score)
|
||||
for tag in generic_class_tags
|
||||
for feat_name, value in _iterate_over_array(features)
|
||||
]
|
||||
|
||||
class_scores = raw_prediction.class_scores
|
||||
|
||||
if classification_threshold is not None:
|
||||
class_scores = class_scores.where(
|
||||
class_scores > classification_threshold,
|
||||
drop=True,
|
||||
)
|
||||
def get_class_tags(
|
||||
class_scores: xr.DataArray,
|
||||
sound_event_decoder: SoundEventDecoder,
|
||||
top_class_only: bool = False,
|
||||
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
) -> List[data.PredictedTag]:
|
||||
"""Generate specific PredictedTags based on class scores and decoder.
|
||||
|
||||
Filters class scores by the threshold, sorts remaining scores descending,
|
||||
decodes the class name(s) into base tags using the `sound_event_decoder`,
|
||||
and creates `PredictedTag` objects associating the class score. Stops after
|
||||
the first (top) class if `top_class_only` is True.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
class_scores : xr.DataArray
|
||||
A 1D xarray DataArray containing class probabilities/scores, indexed
|
||||
by a 'category' coordinate holding the class names.
|
||||
sound_event_decoder : SoundEventDecoder
|
||||
Function to map a class name string to a list of base `data.Tag`
|
||||
objects.
|
||||
top_class_only : bool, default=False
|
||||
If True, only generate tags for the single highest-scoring class above
|
||||
the threshold.
|
||||
threshold : float, optional
|
||||
Minimum score for a class to be considered. If None, all classes are
|
||||
processed (or top-1 if `top_class_only` is True). Defaults to
|
||||
`DEFAULT_CLASSIFICATION_THRESHOLD`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.PredictedTag]
|
||||
A list of `PredictedTag` objects for the class(es) that passed the
|
||||
threshold, ordered by score if `top_class_only` is False.
|
||||
"""
|
||||
tags = []
|
||||
|
||||
if threshold is not None:
|
||||
class_scores = class_scores.where(class_scores > threshold, drop=True)
|
||||
|
||||
for class_name, score in _iterate_sorted(class_scores):
|
||||
class_tags = sound_event_decoder(class_name)
|
||||
@ -298,11 +391,7 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
if top_class_only:
|
||||
break
|
||||
|
||||
return data.SoundEventPrediction(
|
||||
sound_event=sound_event,
|
||||
score=raw_prediction.detection_score,
|
||||
tags=tags,
|
||||
)
|
||||
return tags
|
||||
|
||||
|
||||
def _iterate_over_array(array: xr.DataArray):
|
||||
@ -314,7 +403,7 @@ def _iterate_over_array(array: xr.DataArray):
|
||||
|
||||
def _iterate_sorted(array: xr.DataArray):
    """Yield ``(name, score)`` pairs from an array, highest score first.

    The names are read from the coordinate attached to the array's first
    dimension and the scores from the array's own values.

    Parameters
    ----------
    array : xr.DataArray
        Array of scores. Only the first dimension is consulted, so this
        assumes a 1D array (TODO confirm against callers).

    Yields
    ------
    Tuple[str, float]
        The coordinate label converted to ``str`` and the matching value
        converted to ``float``, in order of decreasing value. Nothing is
        yielded for an empty array.
    """
    dim_name = array.dims[0]
    names = array.coords[dim_name].values
    scores = array.values

    # argsort is ascending; negating the scores gives a descending order.
    for index in np.argsort(-scores):
        yield str(names[index]), float(scores[index])
|
||||
|
@ -4,21 +4,20 @@ from typing import List, Tuple
|
||||
import numpy as np
|
||||
import pytest
|
||||
import xarray as xr
|
||||
|
||||
# Removed dataclass import as MockRawPrediction is replaced
|
||||
from soundevent import data
|
||||
|
||||
# Import functions to test
|
||||
from batdetect2.postprocess.decoding import (
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
convert_raw_prediction_to_sound_event_prediction,
|
||||
convert_raw_predictions_to_clip_prediction,
|
||||
convert_xr_dataset_to_raw_prediction,
|
||||
get_class_tags,
|
||||
get_generic_tags,
|
||||
get_prediction_features,
|
||||
)
|
||||
from batdetect2.postprocess.types import RawPrediction
|
||||
|
||||
|
||||
# Dummy GeometryBuilder function fixture
|
||||
@pytest.fixture
|
||||
def dummy_geometry_builder():
|
||||
"""A simple GeometryBuilder that creates a BBox around the point."""
|
||||
@ -30,7 +29,6 @@ def dummy_geometry_builder():
|
||||
time, freq = position
|
||||
width = dimensions.sel(dimension="width").item()
|
||||
height = dimensions.sel(dimension="height").item()
|
||||
# Assume position is the center
|
||||
return data.BoundingBox(
|
||||
coordinates=[
|
||||
time - width / 2,
|
||||
@ -43,7 +41,6 @@ def dummy_geometry_builder():
|
||||
return _builder
|
||||
|
||||
|
||||
# Dummy SoundEventDecoder function fixture
|
||||
@pytest.fixture
|
||||
def dummy_sound_event_decoder():
|
||||
"""A simple SoundEventDecoder mapping names to tags."""
|
||||
@ -94,12 +91,9 @@ def sample_clip(sample_recording) -> data.Clip:
|
||||
)
|
||||
|
||||
|
||||
# Fixture for a detection dataset (adapted from test_extraction)
|
||||
@pytest.fixture
|
||||
def sample_detection_dataset() -> xr.Dataset:
|
||||
"""Creates a sample detection dataset suitable for decoding."""
|
||||
# Based on test_extraction's corrected expectations
|
||||
# Detections: (t=20, f=300, s=0.9), (t=10, f=200, s=0.8)
|
||||
expected_times = np.array([20, 10])
|
||||
expected_freqs = np.array([300, 200])
|
||||
detection_coords = {
|
||||
@ -125,7 +119,7 @@ def sample_detection_dataset() -> xr.Dataset:
|
||||
|
||||
classes_data = np.array(
|
||||
[[0.43, 0.85], [0.24, 0.66]],
|
||||
dtype=np.float32, # Simplified values
|
||||
dtype=np.float32,
|
||||
)
|
||||
classes = xr.DataArray(
|
||||
classes_data,
|
||||
@ -198,13 +192,10 @@ def empty_detection_dataset() -> xr.Dataset:
|
||||
)
|
||||
|
||||
|
||||
# Fixture for sample RawPrediction objects (using the actual type)
|
||||
@pytest.fixture
|
||||
def sample_raw_predictions() -> List[RawPrediction]:
|
||||
"""Manually crafted RawPrediction objects using the actual type."""
|
||||
# Corresponds roughly to sample_detection_dataset after geometry building
|
||||
# Det 1: t=20, f=300, s=0.9, w=7, h=16, classes=[0.43, 0.85], feats=[7, 16, 25, 34]
|
||||
# Det 2: t=10, f=200, s=0.8, w=3, h=12, classes=[0.24, 0.66], feats=[ 3, 12, 21, 30]
|
||||
|
||||
pred1_classes = xr.DataArray(
|
||||
[0.43, 0.85], coords={"category": ["bat", "noise"]}, dims=["category"]
|
||||
)
|
||||
@ -213,12 +204,12 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["feature"],
|
||||
)
|
||||
pred1 = RawPrediction( # Use RawPrediction directly
|
||||
pred1 = RawPrediction(
|
||||
detection_score=0.9,
|
||||
start_time=20 - 7 / 2,
|
||||
end_time=20 + 7 / 2, # 16.5, 23.5
|
||||
end_time=20 + 7 / 2,
|
||||
low_freq=300 - 16 / 2,
|
||||
high_freq=300 + 16 / 2, # 292, 308
|
||||
high_freq=300 + 16 / 2,
|
||||
class_scores=pred1_classes,
|
||||
features=pred1_features,
|
||||
)
|
||||
@ -231,25 +222,25 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["feature"],
|
||||
)
|
||||
pred2 = RawPrediction( # Use RawPrediction directly
|
||||
pred2 = RawPrediction(
|
||||
detection_score=0.8,
|
||||
start_time=10 - 3 / 2,
|
||||
end_time=10 + 3 / 2, # 8.5, 11.5
|
||||
end_time=10 + 3 / 2,
|
||||
low_freq=200 - 12 / 2,
|
||||
high_freq=200 + 12 / 2, # 194, 206
|
||||
high_freq=200 + 12 / 2,
|
||||
class_scores=pred2_classes,
|
||||
features=pred2_features,
|
||||
)
|
||||
|
||||
pred3_classes = xr.DataArray(
|
||||
[0.05, 0.02], coords={"category": ["bat", "noise"]}, dims=["category"]
|
||||
) # Below default threshold
|
||||
)
|
||||
pred3_features = xr.DataArray(
|
||||
[1.0, 2.0, 3.0, 4.0],
|
||||
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["feature"],
|
||||
)
|
||||
pred3 = RawPrediction( # Use RawPrediction directly
|
||||
pred3 = RawPrediction(
|
||||
detection_score=0.15,
|
||||
start_time=5.0,
|
||||
end_time=6.0,
|
||||
@ -261,9 +252,6 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
return [pred1, pred2, pred3]
|
||||
|
||||
|
||||
# --- Tests for convert_xr_dataset_to_raw_prediction ---
|
||||
|
||||
|
||||
def test_convert_xr_dataset_basic(
|
||||
sample_detection_dataset, dummy_geometry_builder
|
||||
):
|
||||
@ -275,16 +263,14 @@ def test_convert_xr_dataset_basic(
|
||||
assert isinstance(raw_predictions, list)
|
||||
assert len(raw_predictions) == 2
|
||||
|
||||
# Check first prediction (score=0.9)
|
||||
pred1 = raw_predictions[0]
|
||||
assert isinstance(pred1, RawPrediction) # Check against the actual type
|
||||
assert pred1.detection_score == pytest.approx(0.9)
|
||||
# Check bounds derived from dummy_geometry_builder (center pos assumed)
|
||||
# t=20, f=300, w=7, h=16
|
||||
assert pred1.start_time == pytest.approx(20 - 7 / 2)
|
||||
assert pred1.end_time == pytest.approx(20 + 7 / 2)
|
||||
assert pred1.low_freq == pytest.approx(300 - 16 / 2)
|
||||
assert pred1.high_freq == pytest.approx(300 + 16 / 2)
|
||||
assert isinstance(pred1, RawPrediction)
|
||||
assert pred1.detection_score == 0.9
|
||||
|
||||
assert pred1.start_time == 20 - 7 / 2
|
||||
assert pred1.end_time == 20 + 7 / 2
|
||||
assert pred1.low_freq == 300 - 16 / 2
|
||||
assert pred1.high_freq == 300 + 16 / 2
|
||||
xr.testing.assert_allclose(
|
||||
pred1.class_scores,
|
||||
sample_detection_dataset["classes"].sel(detection=0),
|
||||
@ -293,15 +279,14 @@ def test_convert_xr_dataset_basic(
|
||||
pred1.features, sample_detection_dataset["features"].sel(detection=0)
|
||||
)
|
||||
|
||||
# Check second prediction (score=0.8)
|
||||
pred2 = raw_predictions[1]
|
||||
assert isinstance(pred2, RawPrediction) # Check against the actual type
|
||||
assert pred2.detection_score == pytest.approx(0.8)
|
||||
# t=10, f=200, w=3, h=12
|
||||
assert pred2.start_time == pytest.approx(10 - 3 / 2)
|
||||
assert pred2.end_time == pytest.approx(10 + 3 / 2)
|
||||
assert pred2.low_freq == pytest.approx(200 - 12 / 2)
|
||||
assert pred2.high_freq == pytest.approx(200 + 12 / 2)
|
||||
assert isinstance(pred2, RawPrediction)
|
||||
assert pred2.detection_score == 0.8
|
||||
|
||||
assert pred2.start_time == 10 - 3 / 2
|
||||
assert pred2.end_time == 10 + 3 / 2
|
||||
assert pred2.low_freq == 200 - 12 / 2
|
||||
assert pred2.high_freq == 200 + 12 / 2
|
||||
xr.testing.assert_allclose(
|
||||
pred2.class_scores,
|
||||
sample_detection_dataset["classes"].sel(detection=1),
|
||||
@ -311,9 +296,6 @@ def test_convert_xr_dataset_basic(
|
||||
)
|
||||
|
||||
|
||||
# ...(rest of the tests remain unchanged as they accessed attributes correctly)...
|
||||
|
||||
|
||||
def test_convert_xr_dataset_empty(
|
||||
empty_detection_dataset, dummy_geometry_builder
|
||||
):
|
||||
@ -325,9 +307,6 @@ def test_convert_xr_dataset_empty(
|
||||
assert len(raw_predictions) == 0
|
||||
|
||||
|
||||
# --- Tests for convert_raw_prediction_to_sound_event_prediction ---
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_basic(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
@ -335,7 +314,7 @@ def test_convert_raw_to_sound_event_basic(
|
||||
generic_tags,
|
||||
):
|
||||
"""Test basic conversion, default threshold, multi-label."""
|
||||
# score=0.9, classes=[0.43(bat), 0.85(noise)]
|
||||
|
||||
raw_pred = sample_raw_predictions[0]
|
||||
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
@ -343,14 +322,11 @@ def test_convert_raw_to_sound_event_basic(
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
# classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD (0.1),
|
||||
# top_class_only=False,
|
||||
)
|
||||
|
||||
assert isinstance(se_pred, data.SoundEventPrediction)
|
||||
assert se_pred.score == pytest.approx(raw_pred.detection_score)
|
||||
assert se_pred.score == raw_pred.detection_score
|
||||
|
||||
# Check SoundEvent
|
||||
se = se_pred.sound_event
|
||||
assert isinstance(se, data.SoundEvent)
|
||||
assert se.recording == sample_recording
|
||||
@ -365,27 +341,21 @@ def test_convert_raw_to_sound_event_basic(
|
||||
],
|
||||
)
|
||||
assert len(se.features) == len(raw_pred.features)
|
||||
# Simple check for feature presence and value type
|
||||
|
||||
feat_dict = {f.term.name: f.value for f in se.features}
|
||||
assert "batdetect2:f0" in feat_dict and isinstance(
|
||||
feat_dict["batdetect2:f0"], float
|
||||
)
|
||||
assert feat_dict["batdetect2:f0"] == pytest.approx(7.0)
|
||||
assert feat_dict["batdetect2:f0"] == 7.0
|
||||
|
||||
# Check Tags
|
||||
# Expected: Generic(0.9), Noise(0.85), Bat(0.43)
|
||||
# Note: Order might depend on sortby implementation detail, compare as sets
|
||||
expected_tags = {
|
||||
# Generic Tag
|
||||
(generic_tags[0].key, generic_tags[0].value, 0.9),
|
||||
# Noise Tag (score 0.85 > 0.1)
|
||||
("category", "noise", 0.85),
|
||||
# Bat Tag (score 0.43 > 0.1)
|
||||
("species", "Myotis", 0.43),
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("soundevent:category", "noise", 0.85),
|
||||
("soundevent:species", "Myotis", 0.43),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
}
|
||||
print("expected", expected_tags)
|
||||
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
|
||||
print("actual", actual_tags)
|
||||
assert actual_tags == expected_tags
|
||||
|
||||
|
||||
@ -396,9 +366,7 @@ def test_convert_raw_to_sound_event_thresholding(
|
||||
generic_tags,
|
||||
):
|
||||
"""Test effect of classification threshold."""
|
||||
raw_pred = sample_raw_predictions[
|
||||
0
|
||||
] # score=0.9, classes=[0.43(bat), 0.85(noise)]
|
||||
raw_pred = sample_raw_predictions[0]
|
||||
high_threshold = 0.5
|
||||
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
@ -406,16 +374,17 @@ def test_convert_raw_to_sound_event_thresholding(
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
classification_threshold=high_threshold, # Only noise should pass
|
||||
classification_threshold=high_threshold,
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
# Expected: Generic(0.9), Noise(0.85) - Bat (0.43) is below threshold
|
||||
expected_tags = {
|
||||
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)),
|
||||
("category", "noise", pytest.approx(0.85)),
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("soundevent:category", "noise", 0.85),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
}
|
||||
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
|
||||
assert actual_tags == expected_tags
|
||||
|
||||
|
||||
@ -426,27 +395,25 @@ def test_convert_raw_to_sound_event_no_threshold(
|
||||
generic_tags,
|
||||
):
|
||||
"""Test when classification_threshold is None."""
|
||||
raw_pred = sample_raw_predictions[
|
||||
2
|
||||
] # score=0.15, classes=[0.05(bat), 0.02(noise)]
|
||||
# Both classes are below default threshold, but should be included if None
|
||||
raw_pred = sample_raw_predictions[2]
|
||||
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
classification_threshold=None, # No thresholding
|
||||
classification_threshold=None,
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
# Expected: Generic(0.15), Bat(0.05), Noise(0.02)
|
||||
expected_tags = {
|
||||
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)),
|
||||
("species", "Myotis", pytest.approx(0.05)),
|
||||
("category", "noise", pytest.approx(0.02)),
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||
("soundevent:species", "Myotis", 0.05),
|
||||
("soundevent:category", "noise", 0.02),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
}
|
||||
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
|
||||
assert actual_tags == expected_tags
|
||||
|
||||
|
||||
@ -457,10 +424,7 @@ def test_convert_raw_to_sound_event_top_class(
|
||||
generic_tags,
|
||||
):
|
||||
"""Test top_class_only=True behavior."""
|
||||
raw_pred = sample_raw_predictions[
|
||||
0
|
||||
] # score=0.9, classes=[0.43(bat), 0.85(noise)]
|
||||
# Highest score is noise (0.85)
|
||||
raw_pred = sample_raw_predictions[0]
|
||||
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
@ -468,15 +432,16 @@ def test_convert_raw_to_sound_event_top_class(
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=True, # Only include top class (noise)
|
||||
top_class_only=True,
|
||||
)
|
||||
|
||||
# Expected: Generic(0.9), Noise(0.85)
|
||||
expected_tags = {
|
||||
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)),
|
||||
("category", "noise", pytest.approx(0.85)),
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("soundevent:category", "noise", 0.85),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
}
|
||||
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
|
||||
assert actual_tags == expected_tags
|
||||
|
||||
|
||||
@ -487,30 +452,26 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
generic_tags,
|
||||
):
|
||||
"""Test when all class scores are below the default threshold."""
|
||||
raw_pred = sample_raw_predictions[
|
||||
2
|
||||
] # score=0.15, classes=[0.05(bat), 0.02(noise)]
|
||||
raw_pred = sample_raw_predictions[2]
|
||||
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, # 0.1
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
# Expected: Only Generic(0.15) tag, as others are below threshold
|
||||
expected_tags = {
|
||||
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)),
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
}
|
||||
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
|
||||
assert actual_tags == expected_tags
|
||||
|
||||
|
||||
# --- Tests for convert_raw_predictions_to_clip_prediction ---
|
||||
|
||||
|
||||
def test_convert_raw_list_to_clip_basic(
|
||||
sample_raw_predictions,
|
||||
sample_clip,
|
||||
@ -531,25 +492,22 @@ def test_convert_raw_list_to_clip_basic(
|
||||
assert clip_pred.clip == sample_clip
|
||||
assert len(clip_pred.sound_events) == len(sample_raw_predictions)
|
||||
|
||||
# Check if the contained sound events seem correct (basic check)
|
||||
assert clip_pred.sound_events[0].score == pytest.approx(
|
||||
assert clip_pred.sound_events[0].score == (
|
||||
sample_raw_predictions[0].detection_score
|
||||
)
|
||||
assert clip_pred.sound_events[1].score == pytest.approx(
|
||||
assert clip_pred.sound_events[1].score == (
|
||||
sample_raw_predictions[1].detection_score
|
||||
)
|
||||
assert clip_pred.sound_events[2].score == pytest.approx(
|
||||
assert clip_pred.sound_events[2].score == (
|
||||
sample_raw_predictions[2].detection_score
|
||||
)
|
||||
|
||||
# Check if tags were generated correctly for one event (e.g., the last one)
|
||||
# Pred 3 has score 0.15, classes [0.05, 0.02]. Only generic tag expected.
|
||||
se_pred3_tags = {
|
||||
(pt.tag.key, pt.tag.value, pt.score)
|
||||
(pt.tag.term.name, pt.tag.value, pt.score)
|
||||
for pt in clip_pred.sound_events[2].tags
|
||||
}
|
||||
expected_tags3 = {
|
||||
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)),
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||
}
|
||||
assert se_pred3_tags == expected_tags3
|
||||
|
||||
@ -579,26 +537,126 @@ def test_convert_raw_list_to_clip_passes_args(
|
||||
generic_tags,
|
||||
):
|
||||
"""Test that arguments like top_class_only are passed through."""
|
||||
# Use top_class_only = True
|
||||
|
||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions=sample_raw_predictions,
|
||||
clip=sample_clip,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=True, # <<-- Argument being tested
|
||||
top_class_only=True,
|
||||
)
|
||||
|
||||
assert len(clip_pred.sound_events) == 3
|
||||
|
||||
# Check tags for the first prediction (score=0.9, classes=[0.43(bat), 0.85(noise)])
|
||||
# With top_class_only=True, expect Generic(0.9) and Noise(0.85) only
|
||||
se_pred1_tags = {
|
||||
(pt.tag.key, pt.tag.value, pt.score)
|
||||
(pt.tag.term.name, pt.tag.value, pt.score)
|
||||
for pt in clip_pred.sound_events[0].tags
|
||||
}
|
||||
expected_tags1 = {
|
||||
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)),
|
||||
("category", "noise", pytest.approx(0.85)),
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("soundevent:category", "noise", 0.85),
|
||||
}
|
||||
assert se_pred1_tags == expected_tags1
|
||||
|
||||
|
||||
def test_get_generic_tags_basic(generic_tags):
    """Every generic tag is wrapped in a PredictedTag carrying the score."""
    score = 0.75

    result = get_generic_tags(
        detection_score=score,
        generic_class_tags=generic_tags,
    )

    # One predicted tag per generic tag, all sharing the detection score.
    assert len(result) == len(generic_tags)
    assert all(isinstance(item, data.PredictedTag) for item in result)
    assert all(item.score == score for item in result)
    assert all(item.tag in generic_tags for item in result)
|
||||
|
||||
|
||||
def test_get_prediction_features_basic():
    """A feature DataArray becomes one Feature per named entry."""
    names = ["feat1", "feat2", "feat3"]
    values = [1.1, 2.2, 3.3]
    feature_array = xr.DataArray(
        values,
        coords={"feature": names},
        dims=["feature"],
    )

    converted = get_prediction_features(feature_array)

    assert len(converted) == 3
    # Each feature keeps its value and gets the "batdetect2:" name prefix.
    for item, name, value in zip(converted, names, values):
        assert isinstance(item, data.Feature)
        assert item.term.name == f"batdetect2:{name}"
        assert item.value == value
|
||||
|
||||
|
||||
def test_get_class_tags_basic(dummy_sound_event_decoder):
    """With default arguments, all three classes produce a scored tag."""
    scores = xr.DataArray(
        [0.6, 0.2, 0.9],
        coords={"category": ["bat", "noise", "unknown"]},
        dims=["category"],
    )

    result = get_class_tags(
        class_scores=scores,
        sound_event_decoder=dummy_sound_event_decoder,
    )

    assert len(result) == 3
    decoded_values = [predicted.tag.value for predicted in result]
    decoded_scores = [predicted.score for predicted in result]

    # Membership only: this test does not pin the order of the tags.
    for expected_value in ("Myotis", "noise", "uncertain"):
        assert expected_value in decoded_values
    for expected_score in (0.6, 0.2, 0.9):
        assert expected_score in decoded_scores
|
||||
|
||||
|
||||
def test_get_class_tags_thresholding(dummy_sound_event_decoder):
    """Classes scoring at or below the threshold produce no tag."""
    scores = xr.DataArray(
        [0.6, 0.2, 0.9],
        coords={"category": ["bat", "noise", "unknown"]},
        dims=["category"],
    )
    cutoff = 0.5

    result = get_class_tags(
        class_scores=scores,
        sound_event_decoder=dummy_sound_event_decoder,
        threshold=cutoff,
    )

    # Only "bat" (0.6) and "unknown" (0.9) exceed the cutoff.
    assert len(result) == 2
    kept_values = [predicted.tag.value for predicted in result]
    assert "Myotis" in kept_values
    assert "noise" not in kept_values
    assert "uncertain" in kept_values
|
||||
|
||||
|
||||
def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
    """With top_class_only=True, only the highest-scoring class is tagged."""
    scores = xr.DataArray(
        [0.6, 0.2, 0.9],
        coords={"category": ["bat", "noise", "unknown"]},
        dims=["category"],
    )

    result = get_class_tags(
        class_scores=scores,
        sound_event_decoder=dummy_sound_event_decoder,
        top_class_only=True,
    )

    assert len(result) == 1
    # "unknown" has the highest score (0.9) and decodes to "uncertain".
    top = result[0]
    assert top.tag.value == "uncertain"
    assert top.score == 0.9
|
||||
|
||||
|
||||
def test_get_class_tags_empty(dummy_sound_event_decoder):
    """An empty class-score array yields no tags."""
    no_scores = xr.DataArray(
        [],
        coords={"category": []},
        dims=["category"],
    )

    result = get_class_tags(
        class_scores=no_scores,
        sound_event_decoder=dummy_sound_event_decoder,
    )

    assert len(result) == 0
|
||||
|
Loading…
Reference in New Issue
Block a user