Update tests after incorporating term registry from soundevent

This commit is contained in:
mbsantiago 2025-08-12 18:44:18 +01:00
parent 51d0a49da9
commit 59aaf07af5
5 changed files with 172 additions and 236 deletions

View File

@ -254,11 +254,13 @@ class TestLoadBatDetect2Files:
assert clip_ann.clip.recording.duration == 5.0 assert clip_ann.clip.recording.duration == 5.0
assert len(clip_ann.sound_events) == 1 assert len(clip_ann.sound_events) == 1
assert clip_ann.notes[0].message == "Standard notes." assert clip_ann.notes[0].message == "Standard notes."
clip_tag = data.find_tag(clip_ann.tags, "Class") clip_tag = data.find_tag(clip_ann.tags, term_label="Class")
assert clip_tag is not None assert clip_tag is not None
assert clip_tag.value == "Myotis" assert clip_tag.value == "Myotis"
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class") recording_tag = data.find_tag(
clip_ann.clip.recording.tags, term_label="Class"
)
assert recording_tag is not None assert recording_tag is not None
assert recording_tag.value == "Myotis" assert recording_tag.value == "Myotis"
@ -271,15 +273,15 @@ class TestLoadBatDetect2Files:
40000, 40000,
] ]
se_class_tag = data.find_tag(se_ann.tags, "Class") se_class_tag = data.find_tag(se_ann.tags, term_label="Class")
assert se_class_tag is not None assert se_class_tag is not None
assert se_class_tag.value == "Myotis" assert se_class_tag.value == "Myotis"
se_event_tag = data.find_tag(se_ann.tags, "Call Type") se_event_tag = data.find_tag(se_ann.tags, term_label="Call Type")
assert se_event_tag is not None assert se_event_tag is not None
assert se_event_tag.value == "Echolocation" assert se_event_tag.value == "Echolocation"
se_individual_tag = data.find_tag(se_ann.tags, "Individual") se_individual_tag = data.find_tag(se_ann.tags, term_label="Individual")
assert se_individual_tag is not None assert se_individual_tag is not None
assert se_individual_tag.value == "0" assert se_individual_tag.value == "0"
@ -439,7 +441,7 @@ class TestLoadBatDetect2Merged:
assert clip_ann.clip.recording.duration == 5.0 assert clip_ann.clip.recording.duration == 5.0
assert len(clip_ann.sound_events) == 1 assert len(clip_ann.sound_events) == 1
clip_class_tag = data.find_tag(clip_ann.tags, "Class") clip_class_tag = data.find_tag(clip_ann.tags, term_label="Class")
assert clip_class_tag is not None assert clip_class_tag is not None
assert clip_class_tag.value == "Myotis" assert clip_class_tag.value == "Myotis"

View File

@ -1,5 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple from typing import List, Optional
import numpy as np import numpy as np
import pytest import pytest
@ -16,35 +16,11 @@ from batdetect2.postprocess.decoding import (
get_prediction_features, get_prediction_features,
) )
from batdetect2.postprocess.types import RawPrediction from batdetect2.postprocess.types import RawPrediction
from batdetect2.targets.types import TargetProtocol
@pytest.fixture @pytest.fixture
def dummy_geometry_builder(): def dummy_targets() -> TargetProtocol:
"""A simple GeometryBuilder that creates a BBox around the point."""
def _builder(
position: Tuple[float, float],
dimensions: xr.DataArray,
class_name: Optional[str] = None,
) -> data.BoundingBox:
time, freq = position
width = dimensions.sel(dimension="width").item()
height = dimensions.sel(dimension="height").item()
return data.BoundingBox(
coordinates=[
time - width / 2,
freq - height / 2,
time + width / 2,
freq + height / 2,
]
)
return _builder
@pytest.fixture
def dummy_sound_event_decoder():
"""A simple SoundEventDecoder mapping names to tags."""
tag_map = { tag_map = {
"bat": [ "bat": [
data.Tag(term=data.term_from_key(key="species"), value="Myotis") data.Tag(term=data.term_from_key(key="species"), value="Myotis")
@ -57,19 +33,57 @@ def dummy_sound_event_decoder():
], ],
} }
def _decoder(class_name: str) -> List[data.Tag]: class DummyTargets(TargetProtocol):
return tag_map.get(class_name.lower(), []) class_names = [
"bat",
return _decoder "noise",
"unknown",
@pytest.fixture
def generic_tags() -> List[data.Tag]:
"""Sample generic tags."""
return [
data.Tag(term=data.term_from_key(key="detector"), value="batdetect2")
] ]
dimension_names = ["width", "height"]
generic_class_tags = [
data.Tag(
term=data.term_from_key(key="detector"), value="batdetect2"
)
]
def filter(self, sound_event: data.SoundEventAnnotation):
return True
def transform(self, sound_event: data.SoundEventAnnotation):
return sound_event
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
return "bat"
def decode_class(self, class_label: str) -> List[data.Tag]:
return tag_map.get(class_label.lower(), [])
def encode_roi(self, sound_event: data.SoundEventAnnotation):
return np.array([0.0, 0.0]), np.array([0.0, 0.0])
def decode_roi(
self,
position,
size: np.ndarray,
class_name: Optional[str] = None,
):
time, freq = position
width, height = size
return data.BoundingBox(
coordinates=[
time - width / 2,
freq - height / 2,
time + width / 2,
freq + height / 2,
]
)
return DummyTargets()
@pytest.fixture @pytest.fixture
def sample_recording() -> data.Recording: def sample_recording() -> data.Recording:
@ -156,7 +170,7 @@ def empty_detection_dataset() -> xr.Dataset:
"""Creates an empty detection dataset with correct structure.""" """Creates an empty detection dataset with correct structure."""
detection_coords = { detection_coords = {
"time": ("detection", np.array([], dtype=np.float64)), "time": ("detection", np.array([], dtype=np.float64)),
"freq": ("detection", np.array([], dtype=np.float64)), "frequency": ("detection", np.array([], dtype=np.float64)),
} }
scores = xr.DataArray( scores = xr.DataArray(
np.array([], dtype=np.float64), np.array([], dtype=np.float64),
@ -184,7 +198,7 @@ def empty_detection_dataset() -> xr.Dataset:
) )
return xr.Dataset( return xr.Dataset(
{ {
"score": scores, "scores": scores,
"dimensions": dimensions, "dimensions": dimensions,
"classes": classes, "classes": classes,
"features": features, "features": features,
@ -215,8 +229,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
300 + 16 / 2, 300 + 16 / 2,
] ]
), ),
class_scores=pred1_classes, class_scores=pred1_classes.values,
features=pred1_features, features=pred1_features.values,
) )
pred2_classes = xr.DataArray( pred2_classes = xr.DataArray(
@ -237,8 +251,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
200 + 12 / 2, 200 + 12 / 2,
] ]
), ),
class_scores=pred2_classes, class_scores=pred2_classes.values,
features=pred2_features, features=pred2_features.values,
) )
pred3_classes = xr.DataArray( pred3_classes = xr.DataArray(
@ -259,18 +273,17 @@ def sample_raw_predictions() -> List[RawPrediction]:
60.0, 60.0,
] ]
), ),
class_scores=pred3_classes, class_scores=pred3_classes.values,
features=pred3_features, features=pred3_features.values,
) )
return [pred1, pred2, pred3] return [pred1, pred2, pred3]
def test_convert_xr_dataset_basic( def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
sample_detection_dataset, dummy_geometry_builder
):
"""Test basic conversion of a dataset to RawPrediction list.""" """Test basic conversion of a dataset to RawPrediction list."""
raw_predictions = convert_xr_dataset_to_raw_prediction( raw_predictions = convert_xr_dataset_to_raw_prediction(
sample_detection_dataset, dummy_geometry_builder sample_detection_dataset,
dummy_targets.decode_roi,
) )
assert isinstance(raw_predictions, list) assert isinstance(raw_predictions, list)
@ -286,11 +299,11 @@ def test_convert_xr_dataset_basic(
20 + 7 / 2, 20 + 7 / 2,
300 + 16 / 2, 300 + 16 / 2,
] ]
xr.testing.assert_allclose( np.testing.assert_allclose(
pred1.class_scores, pred1.class_scores,
sample_detection_dataset["classes"].sel(detection=0), sample_detection_dataset["classes"].sel(detection=0),
) )
xr.testing.assert_allclose( np.testing.assert_allclose(
pred1.features, sample_detection_dataset["features"].sel(detection=0) pred1.features, sample_detection_dataset["features"].sel(detection=0)
) )
@ -304,21 +317,20 @@ def test_convert_xr_dataset_basic(
10 + 3 / 2, 10 + 3 / 2,
200 + 12 / 2, 200 + 12 / 2,
] ]
xr.testing.assert_allclose( np.testing.assert_allclose(
pred2.class_scores, pred2.class_scores,
sample_detection_dataset["classes"].sel(detection=1), sample_detection_dataset["classes"].sel(detection=1),
) )
xr.testing.assert_allclose( np.testing.assert_allclose(
pred2.features, sample_detection_dataset["features"].sel(detection=1) pred2.features, sample_detection_dataset["features"].sel(detection=1)
) )
def test_convert_xr_dataset_empty( def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets):
empty_detection_dataset, dummy_geometry_builder
):
"""Test conversion of an empty dataset.""" """Test conversion of an empty dataset."""
raw_predictions = convert_xr_dataset_to_raw_prediction( raw_predictions = convert_xr_dataset_to_raw_prediction(
empty_detection_dataset, dummy_geometry_builder empty_detection_dataset,
dummy_targets.decode_roi,
) )
assert isinstance(raw_predictions, list) assert isinstance(raw_predictions, list)
assert len(raw_predictions) == 0 assert len(raw_predictions) == 0
@ -327,8 +339,7 @@ def test_convert_xr_dataset_empty(
def test_convert_raw_to_sound_event_basic( def test_convert_raw_to_sound_event_basic(
sample_raw_predictions, sample_raw_predictions,
sample_recording, sample_recording,
dummy_sound_event_decoder, dummy_targets,
generic_tags,
): ):
"""Test basic conversion, default threshold, multi-label.""" """Test basic conversion, default threshold, multi-label."""
@ -337,8 +348,7 @@ def test_convert_raw_to_sound_event_basic(
se_pred = convert_raw_prediction_to_sound_event_prediction( se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred, raw_prediction=raw_pred,
recording=sample_recording, recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder, targets=dummy_targets,
generic_class_tags=generic_tags,
) )
assert isinstance(se_pred, data.SoundEventPrediction) assert isinstance(se_pred, data.SoundEventPrediction)
@ -357,10 +367,11 @@ def test_convert_raw_to_sound_event_basic(
) )
assert feat_dict["batdetect2:f0"] == 7.0 assert feat_dict["batdetect2:f0"] == 7.0
generic_tags = dummy_targets.generic_class_tags
expected_tags = { expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9), (generic_tags[0].term.name, generic_tags[0].value, 0.9),
("soundevent:category", "noise", 0.85), ("category", "noise", 0.85),
("soundevent:species", "Myotis", 0.43), ("dwc:scientificName", "Myotis", 0.43),
} }
actual_tags = { actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -369,10 +380,7 @@ def test_convert_raw_to_sound_event_basic(
def test_convert_raw_to_sound_event_thresholding( def test_convert_raw_to_sound_event_thresholding(
sample_raw_predictions, sample_raw_predictions, sample_recording, dummy_targets
sample_recording,
dummy_sound_event_decoder,
generic_tags,
): ):
"""Test effect of classification threshold.""" """Test effect of classification threshold."""
raw_pred = sample_raw_predictions[0] raw_pred = sample_raw_predictions[0]
@ -381,15 +389,15 @@ def test_convert_raw_to_sound_event_thresholding(
se_pred = convert_raw_prediction_to_sound_event_prediction( se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred, raw_prediction=raw_pred,
recording=sample_recording, recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder, targets=dummy_targets,
generic_class_tags=generic_tags,
classification_threshold=high_threshold, classification_threshold=high_threshold,
top_class_only=False, top_class_only=False,
) )
generic_tags = dummy_targets.generic_class_tags
expected_tags = { expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9), (generic_tags[0].term.name, generic_tags[0].value, 0.9),
("soundevent:category", "noise", 0.85), ("category", "noise", 0.85),
} }
actual_tags = { actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -400,8 +408,7 @@ def test_convert_raw_to_sound_event_thresholding(
def test_convert_raw_to_sound_event_no_threshold( def test_convert_raw_to_sound_event_no_threshold(
sample_raw_predictions, sample_raw_predictions,
sample_recording, sample_recording,
dummy_sound_event_decoder, dummy_targets,
generic_tags,
): ):
"""Test when classification_threshold is None.""" """Test when classification_threshold is None."""
raw_pred = sample_raw_predictions[2] raw_pred = sample_raw_predictions[2]
@ -409,16 +416,16 @@ def test_convert_raw_to_sound_event_no_threshold(
se_pred = convert_raw_prediction_to_sound_event_prediction( se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred, raw_prediction=raw_pred,
recording=sample_recording, recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder, targets=dummy_targets,
generic_class_tags=generic_tags,
classification_threshold=None, classification_threshold=None,
top_class_only=False, top_class_only=False,
) )
generic_tags = dummy_targets.generic_class_tags
expected_tags = { expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15), (generic_tags[0].term.name, generic_tags[0].value, 0.15),
("soundevent:species", "Myotis", 0.05), ("dwc:scientificName", "Myotis", 0.05),
("soundevent:category", "noise", 0.02), ("category", "noise", 0.02),
} }
actual_tags = { actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -429,8 +436,7 @@ def test_convert_raw_to_sound_event_no_threshold(
def test_convert_raw_to_sound_event_top_class( def test_convert_raw_to_sound_event_top_class(
sample_raw_predictions, sample_raw_predictions,
sample_recording, sample_recording,
dummy_sound_event_decoder, dummy_targets,
generic_tags,
): ):
"""Test top_class_only=True behavior.""" """Test top_class_only=True behavior."""
raw_pred = sample_raw_predictions[0] raw_pred = sample_raw_predictions[0]
@ -438,15 +444,15 @@ def test_convert_raw_to_sound_event_top_class(
se_pred = convert_raw_prediction_to_sound_event_prediction( se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred, raw_prediction=raw_pred,
recording=sample_recording, recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder, targets=dummy_targets,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=True, top_class_only=True,
) )
generic_tags = dummy_targets.generic_class_tags
expected_tags = { expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9), (generic_tags[0].term.name, generic_tags[0].value, 0.9),
("soundevent:category", "noise", 0.85), ("category", "noise", 0.85),
} }
actual_tags = { actual_tags = {
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
@ -457,8 +463,7 @@ def test_convert_raw_to_sound_event_top_class(
def test_convert_raw_to_sound_event_all_below_threshold( def test_convert_raw_to_sound_event_all_below_threshold(
sample_raw_predictions, sample_raw_predictions,
sample_recording, sample_recording,
dummy_sound_event_decoder, dummy_targets,
generic_tags,
): ):
"""Test when all class scores are below the default threshold.""" """Test when all class scores are below the default threshold."""
raw_pred = sample_raw_predictions[2] raw_pred = sample_raw_predictions[2]
@ -466,12 +471,12 @@ def test_convert_raw_to_sound_event_all_below_threshold(
se_pred = convert_raw_prediction_to_sound_event_prediction( se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred, raw_prediction=raw_pred,
recording=sample_recording, recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder, targets=dummy_targets,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=False, top_class_only=False,
) )
generic_tags = dummy_targets.generic_class_tags
expected_tags = { expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15), (generic_tags[0].term.name, generic_tags[0].value, 0.15),
} }
@ -484,15 +489,13 @@ def test_convert_raw_to_sound_event_all_below_threshold(
def test_convert_raw_list_to_clip_basic( def test_convert_raw_list_to_clip_basic(
sample_raw_predictions, sample_raw_predictions,
sample_clip, sample_clip,
dummy_sound_event_decoder, dummy_targets,
generic_tags,
): ):
"""Test converting a list of RawPredictions to a ClipPrediction.""" """Test converting a list of RawPredictions to a ClipPrediction."""
clip_pred = convert_raw_predictions_to_clip_prediction( clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=sample_raw_predictions, raw_predictions=sample_raw_predictions,
clip=sample_clip, clip=sample_clip,
sound_event_decoder=dummy_sound_event_decoder, targets=dummy_targets,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=False, top_class_only=False,
) )
@ -515,23 +518,19 @@ def test_convert_raw_list_to_clip_basic(
(pt.tag.term.name, pt.tag.value, pt.score) (pt.tag.term.name, pt.tag.value, pt.score)
for pt in clip_pred.sound_events[2].tags for pt in clip_pred.sound_events[2].tags
} }
generic_tags = dummy_targets.generic_class_tags
expected_tags3 = { expected_tags3 = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15), (generic_tags[0].term.name, generic_tags[0].value, 0.15),
} }
assert se_pred3_tags == expected_tags3 assert se_pred3_tags == expected_tags3
def test_convert_raw_list_to_clip_empty( def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
sample_clip,
dummy_sound_event_decoder,
generic_tags,
):
"""Test converting an empty list of RawPredictions.""" """Test converting an empty list of RawPredictions."""
clip_pred = convert_raw_predictions_to_clip_prediction( clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=[], raw_predictions=[],
clip=sample_clip, clip=sample_clip,
sound_event_decoder=dummy_sound_event_decoder, targets=dummy_targets,
generic_class_tags=generic_tags,
) )
assert isinstance(clip_pred, data.ClipPrediction) assert isinstance(clip_pred, data.ClipPrediction)
@ -542,16 +541,14 @@ def test_convert_raw_list_to_clip_empty(
def test_convert_raw_list_to_clip_passes_args( def test_convert_raw_list_to_clip_passes_args(
sample_raw_predictions, sample_raw_predictions,
sample_clip, sample_clip,
dummy_sound_event_decoder, dummy_targets,
generic_tags,
): ):
"""Test that arguments like top_class_only are passed through.""" """Test that arguments like top_class_only are passed through."""
clip_pred = convert_raw_predictions_to_clip_prediction( clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=sample_raw_predictions, raw_predictions=sample_raw_predictions,
clip=sample_clip, clip=sample_clip,
sound_event_decoder=dummy_sound_event_decoder, targets=dummy_targets,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=True, top_class_only=True,
) )
@ -562,16 +559,18 @@ def test_convert_raw_list_to_clip_passes_args(
(pt.tag.term.name, pt.tag.value, pt.score) (pt.tag.term.name, pt.tag.value, pt.score)
for pt in clip_pred.sound_events[0].tags for pt in clip_pred.sound_events[0].tags
} }
generic_tags = dummy_targets.generic_class_tags
expected_tags1 = { expected_tags1 = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9), (generic_tags[0].term.name, generic_tags[0].value, 0.9),
("soundevent:category", "noise", 0.85), ("category", "noise", 0.85),
} }
assert se_pred1_tags == expected_tags1 assert se_pred1_tags == expected_tags1
def test_get_generic_tags_basic(generic_tags): def test_get_generic_tags_basic(dummy_targets):
"""Test creation of generic tags with score.""" """Test creation of generic tags with score."""
detection_score = 0.75 detection_score = 0.75
generic_tags = dummy_targets.generic_class_tags
predicted_tags = get_generic_tags( predicted_tags = get_generic_tags(
detection_score=detection_score, generic_class_tags=generic_tags detection_score=detection_score, generic_class_tags=generic_tags
) )
@ -589,17 +588,19 @@ def test_get_prediction_features_basic():
coords={"feature": ["feat1", "feat2", "feat3"]}, coords={"feature": ["feat1", "feat2", "feat3"]},
dims=["feature"], dims=["feature"],
) )
features = get_prediction_features(feature_data) features = get_prediction_features(feature_data.values)
assert len(features) == 3 assert len(features) == 3
for feature, feat_name, feat_value in zip( for feature, feat_name, feat_value in zip(
features, ["feat1", "feat2", "feat3"], [1.1, 2.2, 3.3] features,
["f0", "f1", "f2"],
[1.1, 2.2, 3.3],
): ):
assert isinstance(feature, data.Feature) assert isinstance(feature, data.Feature)
assert feature.term.name == f"batdetect2:{feat_name}" assert feature.term.name == f"batdetect2:{feat_name}"
assert feature.value == feat_value assert feature.value == feat_value
def test_get_class_tags_basic(dummy_sound_event_decoder): def test_get_class_tags_basic(dummy_targets):
"""Test creation of class tags based on scores and decoder.""" """Test creation of class tags based on scores and decoder."""
class_scores = xr.DataArray( class_scores = xr.DataArray(
[0.6, 0.2, 0.9], [0.6, 0.2, 0.9],
@ -607,8 +608,8 @@ def test_get_class_tags_basic(dummy_sound_event_decoder):
dims=["category"], dims=["category"],
) )
predicted_tags = get_class_tags( predicted_tags = get_class_tags(
class_scores=class_scores, class_scores=class_scores.values,
sound_event_decoder=dummy_sound_event_decoder, targets=dummy_targets,
) )
assert len(predicted_tags) == 3 assert len(predicted_tags) == 3
tag_values = [pt.tag.value for pt in predicted_tags] tag_values = [pt.tag.value for pt in predicted_tags]
@ -622,7 +623,7 @@ def test_get_class_tags_basic(dummy_sound_event_decoder):
assert 0.9 in tag_scores assert 0.9 in tag_scores
def test_get_class_tags_thresholding(dummy_sound_event_decoder): def test_get_class_tags_thresholding(dummy_targets):
"""Test class tag creation with a threshold.""" """Test class tag creation with a threshold."""
class_scores = xr.DataArray( class_scores = xr.DataArray(
[0.6, 0.2, 0.9], [0.6, 0.2, 0.9],
@ -631,8 +632,8 @@ def test_get_class_tags_thresholding(dummy_sound_event_decoder):
) )
threshold = 0.5 threshold = 0.5
predicted_tags = get_class_tags( predicted_tags = get_class_tags(
class_scores=class_scores, class_scores=class_scores.values,
sound_event_decoder=dummy_sound_event_decoder, targets=dummy_targets,
threshold=threshold, threshold=threshold,
) )
@ -643,7 +644,7 @@ def test_get_class_tags_thresholding(dummy_sound_event_decoder):
assert "uncertain" in tag_values assert "uncertain" in tag_values
def test_get_class_tags_top_class_only(dummy_sound_event_decoder): def test_get_class_tags_top_class_only(dummy_targets):
"""Test class tag creation with top_class_only.""" """Test class tag creation with top_class_only."""
class_scores = xr.DataArray( class_scores = xr.DataArray(
[0.6, 0.2, 0.9], [0.6, 0.2, 0.9],
@ -651,8 +652,8 @@ def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
dims=["category"], dims=["category"],
) )
predicted_tags = get_class_tags( predicted_tags = get_class_tags(
class_scores=class_scores, class_scores=class_scores.values,
sound_event_decoder=dummy_sound_event_decoder, targets=dummy_targets,
top_class_only=True, top_class_only=True,
) )
@ -661,11 +662,11 @@ def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
assert predicted_tags[0].score == 0.9 assert predicted_tags[0].score == 0.9
def test_get_class_tags_empty(dummy_sound_event_decoder): def test_get_class_tags_empty(dummy_targets):
"""Test with empty class scores.""" """Test with empty class scores."""
class_scores = xr.DataArray([], coords={"category": []}, dims=["category"]) class_scores = xr.DataArray([], coords={"category": []}, dims=["category"])
predicted_tags = get_class_tags( predicted_tags = get_class_tags(
class_scores=class_scores, class_scores=class_scores.values,
sound_event_decoder=dummy_sound_event_decoder, targets=dummy_targets,
) )
assert len(predicted_tags) == 0 assert len(predicted_tags) == 0

View File

@ -5,6 +5,7 @@ from uuid import uuid4
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from soundevent import data from soundevent import data
from soundevent.terms import get_term
from batdetect2.targets.classes import ( from batdetect2.targets.classes import (
DEFAULT_SPECIES_LIST, DEFAULT_SPECIES_LIST,
@ -21,26 +22,19 @@ from batdetect2.targets.classes import (
load_decoder_from_config, load_decoder_from_config,
load_encoder_from_config, load_encoder_from_config,
) )
from batdetect2.targets.terms import TagInfo, TermRegistry from batdetect2.targets.terms import TagInfo
@pytest.fixture @pytest.fixture
def sample_annotation( def sample_annotation(
sound_event: data.SoundEvent, sound_event: data.SoundEvent,
sample_term_registry: TermRegistry,
) -> data.SoundEventAnnotation: ) -> data.SoundEventAnnotation:
"""Fixture for a sample SoundEventAnnotation.""" """Fixture for a sample SoundEventAnnotation."""
return data.SoundEventAnnotation( return data.SoundEventAnnotation(
sound_event=sound_event, sound_event=sound_event,
tags=[ tags=[
data.Tag( data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
term=sample_term_registry.get_term("species"), data.Tag(key="quality", value="Good"), # type: ignore
value="Pipistrellus pipistrellus",
),
data.Tag(
term=sample_term_registry.get_term("quality"),
value="Good",
),
], ],
) )
@ -136,59 +130,33 @@ def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
def test_is_target_class_match_all( def test_is_target_class_match_all(
sample_annotation: data.SoundEventAnnotation, sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
): ):
tags = { tags = {
data.Tag( data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
term=sample_term_registry["species"], data.Tag(key="quality", value="Good"), # type: ignore
value="Pipistrellus pipistrellus",
),
data.Tag(term=sample_term_registry["quality"], value="Good"),
} }
assert is_target_class(sample_annotation, tags, match_all=True) is True assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = { tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
)
}
assert is_target_class(sample_annotation, tags, match_all=True) is True assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = { tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
data.Tag(
term=sample_term_registry["species"], value="Myotis daubentonii"
)
}
assert is_target_class(sample_annotation, tags, match_all=True) is False assert is_target_class(sample_annotation, tags, match_all=True) is False
def test_is_target_class_match_any( def test_is_target_class_match_any(
sample_annotation: data.SoundEventAnnotation, sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
): ):
tags = { tags = {
data.Tag( data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
term=sample_term_registry["species"], data.Tag(key="quality", value="Good"), # type: ignore
value="Pipistrellus pipistrellus",
),
data.Tag(term=sample_term_registry["quality"], value="Good"),
} }
assert is_target_class(sample_annotation, tags, match_all=False) is True assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = { tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
)
}
assert is_target_class(sample_annotation, tags, match_all=False) is True assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = { tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
data.Tag(
term=sample_term_registry["species"], value="Myotis daubentonii"
)
}
assert is_target_class(sample_annotation, tags, match_all=False) is False assert is_target_class(sample_annotation, tags, match_all=False) is False
@ -208,7 +176,6 @@ def test_get_class_names_from_config():
def test_build_encoder_from_config( def test_build_encoder_from_config(
sample_annotation: data.SoundEventAnnotation, sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
): ):
config = ClassesConfig( config = ClassesConfig(
classes=[ classes=[
@ -220,25 +187,18 @@ def test_build_encoder_from_config(
) )
] ]
) )
encoder = build_sound_event_encoder( encoder = build_sound_event_encoder(config)
config,
term_registry=sample_term_registry,
)
result = encoder(sample_annotation) result = encoder(sample_annotation)
assert result == "pippip" assert result == "pippip"
config = ClassesConfig(classes=[]) config = ClassesConfig(classes=[])
encoder = build_sound_event_encoder( encoder = build_sound_event_encoder(config)
config,
term_registry=sample_term_registry,
)
result = encoder(sample_annotation) result = encoder(sample_annotation)
assert result is None assert result is None
def test_load_encoder_from_config_valid( def test_load_encoder_from_config_valid(
sample_annotation: data.SoundEventAnnotation, sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
create_temp_yaml: Callable[[str], Path], create_temp_yaml: Callable[[str], Path],
): ):
yaml_content = """ yaml_content = """
@ -249,10 +209,7 @@ def test_load_encoder_from_config_valid(
value: Pipistrellus pipistrellus value: Pipistrellus pipistrellus
""" """
temp_yaml_path = create_temp_yaml(yaml_content) temp_yaml_path = create_temp_yaml(yaml_content)
encoder = load_encoder_from_config( encoder = load_encoder_from_config(temp_yaml_path)
temp_yaml_path,
term_registry=sample_term_registry,
)
# We cannot directly compare the function, so we test it. # We cannot directly compare the function, so we test it.
result = encoder(sample_annotation) # type: ignore result = encoder(sample_annotation) # type: ignore
assert result == "pippip" assert result == "pippip"
@ -260,7 +217,6 @@ def test_load_encoder_from_config_valid(
def test_load_encoder_from_config_invalid( def test_load_encoder_from_config_invalid(
create_temp_yaml: Callable[[str], Path], create_temp_yaml: Callable[[str], Path],
sample_term_registry: TermRegistry,
): ):
yaml_content = """ yaml_content = """
classes: classes:
@ -271,10 +227,7 @@ def test_load_encoder_from_config_invalid(
""" """
temp_yaml_path = create_temp_yaml(yaml_content) temp_yaml_path = create_temp_yaml(yaml_content)
with pytest.raises(KeyError): with pytest.raises(KeyError):
load_encoder_from_config( load_encoder_from_config(temp_yaml_path)
temp_yaml_path,
term_registry=sample_term_registry,
)
def test_get_default_class_name(): def test_get_default_class_name():
@ -291,7 +244,7 @@ def test_get_default_classes():
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0] assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
def test_build_decoder_from_config(sample_term_registry: TermRegistry): def test_build_decoder_from_config():
config = ClassesConfig( config = ClassesConfig(
classes=[ classes=[
TargetClass( TargetClass(
@ -304,12 +257,10 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
], ],
generic_class=[TagInfo(key="order", value="Chiroptera")], generic_class=[TagInfo(key="order", value="Chiroptera")],
) )
decoder = build_sound_event_decoder( decoder = build_sound_event_decoder(config)
config, term_registry=sample_term_registry
)
tags = decoder("pippip") tags = decoder("pippip")
assert len(tags) == 1 assert len(tags) == 1
assert tags[0].term == sample_term_registry["call_type"] assert tags[0].term == get_term("event")
assert tags[0].value == "Echolocation" assert tags[0].value == "Echolocation"
# Test when output_tags is None, should fall back to tags # Test when output_tags is None, should fall back to tags
@ -324,32 +275,25 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
], ],
generic_class=[TagInfo(key="order", value="Chiroptera")], generic_class=[TagInfo(key="order", value="Chiroptera")],
) )
decoder = build_sound_event_decoder( decoder = build_sound_event_decoder(config)
config, term_registry=sample_term_registry
)
tags = decoder("pippip") tags = decoder("pippip")
assert len(tags) == 1 assert len(tags) == 1
assert tags[0].term == sample_term_registry["species"] assert tags[0].term == get_term("species")
assert tags[0].value == "Pipistrellus pipistrellus" assert tags[0].value == "Pipistrellus pipistrellus"
# Test raise_on_unmapped=True # Test raise_on_unmapped=True
decoder = build_sound_event_decoder( decoder = build_sound_event_decoder(config, raise_on_unmapped=True)
config, term_registry=sample_term_registry, raise_on_unmapped=True
)
with pytest.raises(ValueError): with pytest.raises(ValueError):
decoder("unknown_class") decoder("unknown_class")
# Test raise_on_unmapped=False # Test raise_on_unmapped=False
decoder = build_sound_event_decoder( decoder = build_sound_event_decoder(config, raise_on_unmapped=False)
config, term_registry=sample_term_registry, raise_on_unmapped=False
)
tags = decoder("unknown_class") tags = decoder("unknown_class")
assert len(tags) == 0 assert len(tags) == 0
def test_load_decoder_from_config_valid( def test_load_decoder_from_config_valid(
create_temp_yaml: Callable[[str], Path], create_temp_yaml: Callable[[str], Path],
sample_term_registry: TermRegistry,
): ):
yaml_content = """ yaml_content = """
classes: classes:
@ -366,17 +310,15 @@ def test_load_decoder_from_config_valid(
""" """
temp_yaml_path = create_temp_yaml(yaml_content) temp_yaml_path = create_temp_yaml(yaml_content)
decoder = load_decoder_from_config( decoder = load_decoder_from_config(
temp_yaml_path, term_registry=sample_term_registry temp_yaml_path,
) )
tags = decoder("pippip") tags = decoder("pippip")
assert len(tags) == 1 assert len(tags) == 1
assert tags[0].term == sample_term_registry["call_type"] assert tags[0].term == get_term("call_type")
assert tags[0].value == "Echolocation" assert tags[0].value == "Echolocation"
def test_build_generic_class_tags_from_config( def test_build_generic_class_tags_from_config():
sample_term_registry: TermRegistry,
):
config = ClassesConfig( config = ClassesConfig(
classes=[ classes=[
TargetClass( TargetClass(
@ -391,11 +333,9 @@ def test_build_generic_class_tags_from_config(
TagInfo(key="call_type", value="Echolocation"), TagInfo(key="call_type", value="Echolocation"),
], ],
) )
generic_tags = build_generic_class_tags( generic_tags = build_generic_class_tags(config)
config, term_registry=sample_term_registry
)
assert len(generic_tags) == 2 assert len(generic_tags) == 2
assert generic_tags[0].term == sample_term_registry["order"] assert generic_tags[0].term == get_term("order")
assert generic_tags[0].value == "Chiroptera" assert generic_tags[0].value == "Chiroptera"
assert generic_tags[1].term == sample_term_registry["call_type"] assert generic_tags[1].term == get_term("call_type")
assert generic_tags[1].value == "Echolocation" assert generic_tags[1].value == "Echolocation"

View File

@ -80,7 +80,6 @@ def test_generated_heatmaps_have_correct_dimensions(
def test_generated_heatmap_are_non_zero_at_correct_positions( def test_generated_heatmap_are_non_zero_at_correct_positions(
sample_target_config: TargetConfig, sample_target_config: TargetConfig,
sample_term_registry: TermRegistry,
pippip_tag: TagInfo, pippip_tag: TagInfo,
): ):
config = sample_target_config.model_copy( config = sample_target_config.model_copy(
@ -92,7 +91,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
) )
) )
targets = build_targets(config, term_registry=sample_term_registry) targets = build_targets(config)
spec = xr.DataArray( spec = xr.DataArray(
data=np.random.rand(100, 100), data=np.random.rand(100, 100),
@ -113,12 +112,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
coordinates=[10, 10, 20, 20], coordinates=[10, 10, 20, 20],
), ),
), ),
tags=[ tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore
data.Tag(
term=sample_term_registry[pippip_tag.key],
value=pippip_tag.value,
)
],
) )
], ],
) )

View File

@ -2,12 +2,12 @@ import pytest
import torch import torch
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
from soundevent.terms import get_term
from batdetect2.models.types import ModelOutput from batdetect2.models.types import ModelOutput
from batdetect2.postprocess import build_postprocessor, load_postprocess_config from batdetect2.postprocess import build_postprocessor, load_postprocess_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config from batdetect2.targets import build_targets, load_target_config
from batdetect2.targets.terms import get_term_from_key
from batdetect2.train.labels import build_clip_labeler, load_label_config from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.preprocess import generate_train_example from batdetect2.train.preprocess import generate_train_example
@ -15,7 +15,6 @@ from batdetect2.train.preprocess import generate_train_example
@pytest.fixture @pytest.fixture
def build_from_config( def build_from_config(
create_temp_yaml, create_temp_yaml,
sample_term_registry,
): ):
def build(yaml_content): def build(yaml_content):
config_path = create_temp_yaml(yaml_content) config_path = create_temp_yaml(yaml_content)
@ -31,9 +30,7 @@ def build_from_config(
field="postprocessing", field="postprocessing",
) )
targets = build_targets( targets = build_targets(targets_config)
targets_config, term_registry=sample_term_registry
)
preprocessor = build_preprocessor(preprocessing_config) preprocessor = build_preprocessor(preprocessing_config)
labeller = build_clip_labeler( labeller = build_clip_labeler(
targets=targets, targets=targets,
@ -54,7 +51,6 @@ def build_from_config(
# TODO: better name # TODO: better name
def test_generated_train_example_has_expected_outputs( def test_generated_train_example_has_expected_outputs(
build_from_config, build_from_config,
sample_term_registry,
recording, recording,
): ):
yaml_content = """ yaml_content = """
@ -78,10 +74,11 @@ def test_generated_train_example_has_expected_outputs(
_, preprocessor, labeller, _ = build_from_config(yaml_content) _, preprocessor, labeller, _ = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
se1 = data.SoundEventAnnotation( se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry), sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")], tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
],
) )
clip_annotation = data.ClipAnnotation( clip_annotation = data.ClipAnnotation(
clip=data.Clip(start_time=0, end_time=0.5, recording=recording), clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
@ -108,7 +105,6 @@ def test_generated_train_example_has_expected_outputs(
def test_encoding_decoding_roundtrip_recovers_object( def test_encoding_decoding_roundtrip_recovers_object(
build_from_config, build_from_config,
sample_term_registry,
recording, recording,
): ):
yaml_content = """ yaml_content = """
@ -131,10 +127,11 @@ def test_encoding_decoding_roundtrip_recovers_object(
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content) _, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000]) geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
se1 = data.SoundEventAnnotation( se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry), sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")], tags=[
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
],
) )
clip = data.Clip(start_time=0, end_time=0.5, recording=recording) clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
@ -171,14 +168,16 @@ def test_encoding_decoding_roundtrip_recovers_object(
assert len(recovered.tags) == 2 assert len(recovered.tags) == 2
predicted_species_tag = next( predicted_species_tag = next(
iter(t for t in recovered.tags if t.tag.term == species), None iter(t for t in recovered.tags if t.tag.term == get_term("species")),
None,
) )
assert predicted_species_tag is not None assert predicted_species_tag is not None
assert predicted_species_tag.score == 1 assert predicted_species_tag.score == 1
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus" assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
predicted_order_tag = next( predicted_order_tag = next(
iter(t for t in recovered.tags if t.tag.term.label == "order"), None iter(t for t in recovered.tags if t.tag.term == get_term("order")),
None,
) )
assert predicted_order_tag is not None assert predicted_order_tag is not None
assert predicted_order_tag.score == 1 assert predicted_order_tag.score == 1
@ -187,7 +186,6 @@ def test_encoding_decoding_roundtrip_recovers_object(
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override( def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
build_from_config, build_from_config,
sample_term_registry,
recording, recording,
): ):
yaml_content = """ yaml_content = """
@ -217,10 +215,9 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content) _, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000]) geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
species = get_term_from_key("species", term_registry=sample_term_registry)
se1 = data.SoundEventAnnotation( se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent(recording=recording, geometry=geometry), sound_event=data.SoundEvent(recording=recording, geometry=geometry),
tags=[data.Tag(term=species, value="Myotis myotis")], tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
) )
clip = data.Clip(start_time=0, end_time=0.5, recording=recording) clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
@ -257,14 +254,16 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
assert len(recovered.tags) == 2 assert len(recovered.tags) == 2
predicted_species_tag = next( predicted_species_tag = next(
iter(t for t in recovered.tags if t.tag.term == species), None iter(t for t in recovered.tags if t.tag.term == get_term("species")),
None,
) )
assert predicted_species_tag is not None assert predicted_species_tag is not None
assert predicted_species_tag.score == 1 assert predicted_species_tag.score == 1
assert predicted_species_tag.tag.value == "Myotis myotis" assert predicted_species_tag.tag.value == "Myotis myotis"
predicted_order_tag = next( predicted_order_tag = next(
iter(t for t in recovered.tags if t.tag.term.label == "order"), None iter(t for t in recovered.tags if t.tag.term == get_term("order")),
None,
) )
assert predicted_order_tag is not None assert predicted_order_tag is not None
assert predicted_order_tag.score == 1 assert predicted_order_tag.score == 1