mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Update tests after incorporating term registry from soundevent
This commit is contained in:
parent
51d0a49da9
commit
59aaf07af5
@ -254,11 +254,13 @@ class TestLoadBatDetect2Files:
|
||||
assert clip_ann.clip.recording.duration == 5.0
|
||||
assert len(clip_ann.sound_events) == 1
|
||||
assert clip_ann.notes[0].message == "Standard notes."
|
||||
clip_tag = data.find_tag(clip_ann.tags, "Class")
|
||||
clip_tag = data.find_tag(clip_ann.tags, term_label="Class")
|
||||
assert clip_tag is not None
|
||||
assert clip_tag.value == "Myotis"
|
||||
|
||||
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class")
|
||||
recording_tag = data.find_tag(
|
||||
clip_ann.clip.recording.tags, term_label="Class"
|
||||
)
|
||||
assert recording_tag is not None
|
||||
assert recording_tag.value == "Myotis"
|
||||
|
||||
@ -271,15 +273,15 @@ class TestLoadBatDetect2Files:
|
||||
40000,
|
||||
]
|
||||
|
||||
se_class_tag = data.find_tag(se_ann.tags, "Class")
|
||||
se_class_tag = data.find_tag(se_ann.tags, term_label="Class")
|
||||
assert se_class_tag is not None
|
||||
assert se_class_tag.value == "Myotis"
|
||||
|
||||
se_event_tag = data.find_tag(se_ann.tags, "Call Type")
|
||||
se_event_tag = data.find_tag(se_ann.tags, term_label="Call Type")
|
||||
assert se_event_tag is not None
|
||||
assert se_event_tag.value == "Echolocation"
|
||||
|
||||
se_individual_tag = data.find_tag(se_ann.tags, "Individual")
|
||||
se_individual_tag = data.find_tag(se_ann.tags, term_label="Individual")
|
||||
assert se_individual_tag is not None
|
||||
assert se_individual_tag.value == "0"
|
||||
|
||||
@ -439,7 +441,7 @@ class TestLoadBatDetect2Merged:
|
||||
assert clip_ann.clip.recording.duration == 5.0
|
||||
assert len(clip_ann.sound_events) == 1
|
||||
|
||||
clip_class_tag = data.find_tag(clip_ann.tags, "Class")
|
||||
clip_class_tag = data.find_tag(clip_ann.tags, term_label="Class")
|
||||
assert clip_class_tag is not None
|
||||
assert clip_class_tag.value == "Myotis"
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@ -16,35 +16,11 @@ from batdetect2.postprocess.decoding import (
|
||||
get_prediction_features,
|
||||
)
|
||||
from batdetect2.postprocess.types import RawPrediction
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_geometry_builder():
|
||||
"""A simple GeometryBuilder that creates a BBox around the point."""
|
||||
|
||||
def _builder(
|
||||
position: Tuple[float, float],
|
||||
dimensions: xr.DataArray,
|
||||
class_name: Optional[str] = None,
|
||||
) -> data.BoundingBox:
|
||||
time, freq = position
|
||||
width = dimensions.sel(dimension="width").item()
|
||||
height = dimensions.sel(dimension="height").item()
|
||||
return data.BoundingBox(
|
||||
coordinates=[
|
||||
time - width / 2,
|
||||
freq - height / 2,
|
||||
time + width / 2,
|
||||
freq + height / 2,
|
||||
]
|
||||
)
|
||||
|
||||
return _builder
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_sound_event_decoder():
|
||||
"""A simple SoundEventDecoder mapping names to tags."""
|
||||
def dummy_targets() -> TargetProtocol:
|
||||
tag_map = {
|
||||
"bat": [
|
||||
data.Tag(term=data.term_from_key(key="species"), value="Myotis")
|
||||
@ -57,18 +33,56 @@ def dummy_sound_event_decoder():
|
||||
],
|
||||
}
|
||||
|
||||
def _decoder(class_name: str) -> List[data.Tag]:
|
||||
return tag_map.get(class_name.lower(), [])
|
||||
class DummyTargets(TargetProtocol):
|
||||
class_names = [
|
||||
"bat",
|
||||
"noise",
|
||||
"unknown",
|
||||
]
|
||||
|
||||
return _decoder
|
||||
dimension_names = ["width", "height"]
|
||||
|
||||
generic_class_tags = [
|
||||
data.Tag(
|
||||
term=data.term_from_key(key="detector"), value="batdetect2"
|
||||
)
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def generic_tags() -> List[data.Tag]:
|
||||
"""Sample generic tags."""
|
||||
return [
|
||||
data.Tag(term=data.term_from_key(key="detector"), value="batdetect2")
|
||||
]
|
||||
def filter(self, sound_event: data.SoundEventAnnotation):
|
||||
return True
|
||||
|
||||
def transform(self, sound_event: data.SoundEventAnnotation):
|
||||
return sound_event
|
||||
|
||||
def encode_class(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> Optional[str]:
|
||||
return "bat"
|
||||
|
||||
def decode_class(self, class_label: str) -> List[data.Tag]:
|
||||
return tag_map.get(class_label.lower(), [])
|
||||
|
||||
def encode_roi(self, sound_event: data.SoundEventAnnotation):
|
||||
return np.array([0.0, 0.0]), np.array([0.0, 0.0])
|
||||
|
||||
def decode_roi(
|
||||
self,
|
||||
position,
|
||||
size: np.ndarray,
|
||||
class_name: Optional[str] = None,
|
||||
):
|
||||
time, freq = position
|
||||
width, height = size
|
||||
return data.BoundingBox(
|
||||
coordinates=[
|
||||
time - width / 2,
|
||||
freq - height / 2,
|
||||
time + width / 2,
|
||||
freq + height / 2,
|
||||
]
|
||||
)
|
||||
|
||||
return DummyTargets()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -156,7 +170,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
||||
"""Creates an empty detection dataset with correct structure."""
|
||||
detection_coords = {
|
||||
"time": ("detection", np.array([], dtype=np.float64)),
|
||||
"freq": ("detection", np.array([], dtype=np.float64)),
|
||||
"frequency": ("detection", np.array([], dtype=np.float64)),
|
||||
}
|
||||
scores = xr.DataArray(
|
||||
np.array([], dtype=np.float64),
|
||||
@ -184,7 +198,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
||||
)
|
||||
return xr.Dataset(
|
||||
{
|
||||
"score": scores,
|
||||
"scores": scores,
|
||||
"dimensions": dimensions,
|
||||
"classes": classes,
|
||||
"features": features,
|
||||
@ -215,8 +229,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
300 + 16 / 2,
|
||||
]
|
||||
),
|
||||
class_scores=pred1_classes,
|
||||
features=pred1_features,
|
||||
class_scores=pred1_classes.values,
|
||||
features=pred1_features.values,
|
||||
)
|
||||
|
||||
pred2_classes = xr.DataArray(
|
||||
@ -237,8 +251,8 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
200 + 12 / 2,
|
||||
]
|
||||
),
|
||||
class_scores=pred2_classes,
|
||||
features=pred2_features,
|
||||
class_scores=pred2_classes.values,
|
||||
features=pred2_features.values,
|
||||
)
|
||||
|
||||
pred3_classes = xr.DataArray(
|
||||
@ -259,18 +273,17 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
60.0,
|
||||
]
|
||||
),
|
||||
class_scores=pred3_classes,
|
||||
features=pred3_features,
|
||||
class_scores=pred3_classes.values,
|
||||
features=pred3_features.values,
|
||||
)
|
||||
return [pred1, pred2, pred3]
|
||||
|
||||
|
||||
def test_convert_xr_dataset_basic(
|
||||
sample_detection_dataset, dummy_geometry_builder
|
||||
):
|
||||
def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
|
||||
"""Test basic conversion of a dataset to RawPrediction list."""
|
||||
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
||||
sample_detection_dataset, dummy_geometry_builder
|
||||
sample_detection_dataset,
|
||||
dummy_targets.decode_roi,
|
||||
)
|
||||
|
||||
assert isinstance(raw_predictions, list)
|
||||
@ -286,11 +299,11 @@ def test_convert_xr_dataset_basic(
|
||||
20 + 7 / 2,
|
||||
300 + 16 / 2,
|
||||
]
|
||||
xr.testing.assert_allclose(
|
||||
np.testing.assert_allclose(
|
||||
pred1.class_scores,
|
||||
sample_detection_dataset["classes"].sel(detection=0),
|
||||
)
|
||||
xr.testing.assert_allclose(
|
||||
np.testing.assert_allclose(
|
||||
pred1.features, sample_detection_dataset["features"].sel(detection=0)
|
||||
)
|
||||
|
||||
@ -304,21 +317,20 @@ def test_convert_xr_dataset_basic(
|
||||
10 + 3 / 2,
|
||||
200 + 12 / 2,
|
||||
]
|
||||
xr.testing.assert_allclose(
|
||||
np.testing.assert_allclose(
|
||||
pred2.class_scores,
|
||||
sample_detection_dataset["classes"].sel(detection=1),
|
||||
)
|
||||
xr.testing.assert_allclose(
|
||||
np.testing.assert_allclose(
|
||||
pred2.features, sample_detection_dataset["features"].sel(detection=1)
|
||||
)
|
||||
|
||||
|
||||
def test_convert_xr_dataset_empty(
|
||||
empty_detection_dataset, dummy_geometry_builder
|
||||
):
|
||||
def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets):
|
||||
"""Test conversion of an empty dataset."""
|
||||
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
||||
empty_detection_dataset, dummy_geometry_builder
|
||||
empty_detection_dataset,
|
||||
dummy_targets.decode_roi,
|
||||
)
|
||||
assert isinstance(raw_predictions, list)
|
||||
assert len(raw_predictions) == 0
|
||||
@ -327,8 +339,7 @@ def test_convert_xr_dataset_empty(
|
||||
def test_convert_raw_to_sound_event_basic(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
dummy_targets,
|
||||
):
|
||||
"""Test basic conversion, default threshold, multi-label."""
|
||||
|
||||
@ -337,8 +348,7 @@ def test_convert_raw_to_sound_event_basic(
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
)
|
||||
|
||||
assert isinstance(se_pred, data.SoundEventPrediction)
|
||||
@ -357,10 +367,11 @@ def test_convert_raw_to_sound_event_basic(
|
||||
)
|
||||
assert feat_dict["batdetect2:f0"] == 7.0
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("soundevent:category", "noise", 0.85),
|
||||
("soundevent:species", "Myotis", 0.43),
|
||||
("category", "noise", 0.85),
|
||||
("dwc:scientificName", "Myotis", 0.43),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
@ -369,10 +380,7 @@ def test_convert_raw_to_sound_event_basic(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_thresholding(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
sample_raw_predictions, sample_recording, dummy_targets
|
||||
):
|
||||
"""Test effect of classification threshold."""
|
||||
raw_pred = sample_raw_predictions[0]
|
||||
@ -381,15 +389,15 @@ def test_convert_raw_to_sound_event_thresholding(
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
classification_threshold=high_threshold,
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("soundevent:category", "noise", 0.85),
|
||||
("category", "noise", 0.85),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
@ -400,8 +408,7 @@ def test_convert_raw_to_sound_event_thresholding(
|
||||
def test_convert_raw_to_sound_event_no_threshold(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
dummy_targets,
|
||||
):
|
||||
"""Test when classification_threshold is None."""
|
||||
raw_pred = sample_raw_predictions[2]
|
||||
@ -409,16 +416,16 @@ def test_convert_raw_to_sound_event_no_threshold(
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
classification_threshold=None,
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||
("soundevent:species", "Myotis", 0.05),
|
||||
("soundevent:category", "noise", 0.02),
|
||||
("dwc:scientificName", "Myotis", 0.05),
|
||||
("category", "noise", 0.02),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
@ -429,8 +436,7 @@ def test_convert_raw_to_sound_event_no_threshold(
|
||||
def test_convert_raw_to_sound_event_top_class(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
dummy_targets,
|
||||
):
|
||||
"""Test top_class_only=True behavior."""
|
||||
raw_pred = sample_raw_predictions[0]
|
||||
@ -438,15 +444,15 @@ def test_convert_raw_to_sound_event_top_class(
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=True,
|
||||
)
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("soundevent:category", "noise", 0.85),
|
||||
("category", "noise", 0.85),
|
||||
}
|
||||
actual_tags = {
|
||||
(pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags
|
||||
@ -457,8 +463,7 @@ def test_convert_raw_to_sound_event_top_class(
|
||||
def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
dummy_targets,
|
||||
):
|
||||
"""Test when all class scores are below the default threshold."""
|
||||
raw_pred = sample_raw_predictions[2]
|
||||
@ -466,12 +471,12 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||
}
|
||||
@ -484,15 +489,13 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
def test_convert_raw_list_to_clip_basic(
|
||||
sample_raw_predictions,
|
||||
sample_clip,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
dummy_targets,
|
||||
):
|
||||
"""Test converting a list of RawPredictions to a ClipPrediction."""
|
||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions=sample_raw_predictions,
|
||||
clip=sample_clip,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=False,
|
||||
)
|
||||
@ -515,23 +518,19 @@ def test_convert_raw_list_to_clip_basic(
|
||||
(pt.tag.term.name, pt.tag.value, pt.score)
|
||||
for pt in clip_pred.sound_events[2].tags
|
||||
}
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags3 = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||
}
|
||||
assert se_pred3_tags == expected_tags3
|
||||
|
||||
|
||||
def test_convert_raw_list_to_clip_empty(
|
||||
sample_clip,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
):
|
||||
def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
|
||||
"""Test converting an empty list of RawPredictions."""
|
||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions=[],
|
||||
clip=sample_clip,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
)
|
||||
|
||||
assert isinstance(clip_pred, data.ClipPrediction)
|
||||
@ -542,16 +541,14 @@ def test_convert_raw_list_to_clip_empty(
|
||||
def test_convert_raw_list_to_clip_passes_args(
|
||||
sample_raw_predictions,
|
||||
sample_clip,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
dummy_targets,
|
||||
):
|
||||
"""Test that arguments like top_class_only are passed through."""
|
||||
|
||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions=sample_raw_predictions,
|
||||
clip=sample_clip,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
targets=dummy_targets,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=True,
|
||||
)
|
||||
@ -562,16 +559,18 @@ def test_convert_raw_list_to_clip_passes_args(
|
||||
(pt.tag.term.name, pt.tag.value, pt.score)
|
||||
for pt in clip_pred.sound_events[0].tags
|
||||
}
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
expected_tags1 = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("soundevent:category", "noise", 0.85),
|
||||
("category", "noise", 0.85),
|
||||
}
|
||||
assert se_pred1_tags == expected_tags1
|
||||
|
||||
|
||||
def test_get_generic_tags_basic(generic_tags):
|
||||
def test_get_generic_tags_basic(dummy_targets):
|
||||
"""Test creation of generic tags with score."""
|
||||
detection_score = 0.75
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
predicted_tags = get_generic_tags(
|
||||
detection_score=detection_score, generic_class_tags=generic_tags
|
||||
)
|
||||
@ -589,17 +588,19 @@ def test_get_prediction_features_basic():
|
||||
coords={"feature": ["feat1", "feat2", "feat3"]},
|
||||
dims=["feature"],
|
||||
)
|
||||
features = get_prediction_features(feature_data)
|
||||
features = get_prediction_features(feature_data.values)
|
||||
assert len(features) == 3
|
||||
for feature, feat_name, feat_value in zip(
|
||||
features, ["feat1", "feat2", "feat3"], [1.1, 2.2, 3.3]
|
||||
features,
|
||||
["f0", "f1", "f2"],
|
||||
[1.1, 2.2, 3.3],
|
||||
):
|
||||
assert isinstance(feature, data.Feature)
|
||||
assert feature.term.name == f"batdetect2:{feat_name}"
|
||||
assert feature.value == feat_value
|
||||
|
||||
|
||||
def test_get_class_tags_basic(dummy_sound_event_decoder):
|
||||
def test_get_class_tags_basic(dummy_targets):
|
||||
"""Test creation of class tags based on scores and decoder."""
|
||||
class_scores = xr.DataArray(
|
||||
[0.6, 0.2, 0.9],
|
||||
@ -607,8 +608,8 @@ def test_get_class_tags_basic(dummy_sound_event_decoder):
|
||||
dims=["category"],
|
||||
)
|
||||
predicted_tags = get_class_tags(
|
||||
class_scores=class_scores,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
class_scores=class_scores.values,
|
||||
targets=dummy_targets,
|
||||
)
|
||||
assert len(predicted_tags) == 3
|
||||
tag_values = [pt.tag.value for pt in predicted_tags]
|
||||
@ -622,7 +623,7 @@ def test_get_class_tags_basic(dummy_sound_event_decoder):
|
||||
assert 0.9 in tag_scores
|
||||
|
||||
|
||||
def test_get_class_tags_thresholding(dummy_sound_event_decoder):
|
||||
def test_get_class_tags_thresholding(dummy_targets):
|
||||
"""Test class tag creation with a threshold."""
|
||||
class_scores = xr.DataArray(
|
||||
[0.6, 0.2, 0.9],
|
||||
@ -631,8 +632,8 @@ def test_get_class_tags_thresholding(dummy_sound_event_decoder):
|
||||
)
|
||||
threshold = 0.5
|
||||
predicted_tags = get_class_tags(
|
||||
class_scores=class_scores,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
class_scores=class_scores.values,
|
||||
targets=dummy_targets,
|
||||
threshold=threshold,
|
||||
)
|
||||
|
||||
@ -643,7 +644,7 @@ def test_get_class_tags_thresholding(dummy_sound_event_decoder):
|
||||
assert "uncertain" in tag_values
|
||||
|
||||
|
||||
def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
|
||||
def test_get_class_tags_top_class_only(dummy_targets):
|
||||
"""Test class tag creation with top_class_only."""
|
||||
class_scores = xr.DataArray(
|
||||
[0.6, 0.2, 0.9],
|
||||
@ -651,8 +652,8 @@ def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
|
||||
dims=["category"],
|
||||
)
|
||||
predicted_tags = get_class_tags(
|
||||
class_scores=class_scores,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
class_scores=class_scores.values,
|
||||
targets=dummy_targets,
|
||||
top_class_only=True,
|
||||
)
|
||||
|
||||
@ -661,11 +662,11 @@ def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
|
||||
assert predicted_tags[0].score == 0.9
|
||||
|
||||
|
||||
def test_get_class_tags_empty(dummy_sound_event_decoder):
|
||||
def test_get_class_tags_empty(dummy_targets):
|
||||
"""Test with empty class scores."""
|
||||
class_scores = xr.DataArray([], coords={"category": []}, dims=["category"])
|
||||
predicted_tags = get_class_tags(
|
||||
class_scores=class_scores,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
class_scores=class_scores.values,
|
||||
targets=dummy_targets,
|
||||
)
|
||||
assert len(predicted_tags) == 0
|
||||
|
||||
@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from soundevent import data
|
||||
from soundevent.terms import get_term
|
||||
|
||||
from batdetect2.targets.classes import (
|
||||
DEFAULT_SPECIES_LIST,
|
||||
@ -21,26 +22,19 @@ from batdetect2.targets.classes import (
|
||||
load_decoder_from_config,
|
||||
load_encoder_from_config,
|
||||
)
|
||||
from batdetect2.targets.terms import TagInfo, TermRegistry
|
||||
from batdetect2.targets.terms import TagInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_annotation(
|
||||
sound_event: data.SoundEvent,
|
||||
sample_term_registry: TermRegistry,
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Fixture for a sample SoundEventAnnotation."""
|
||||
return data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=sample_term_registry.get_term("species"),
|
||||
value="Pipistrellus pipistrellus",
|
||||
),
|
||||
data.Tag(
|
||||
term=sample_term_registry.get_term("quality"),
|
||||
value="Good",
|
||||
),
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
data.Tag(key="quality", value="Good"), # type: ignore
|
||||
],
|
||||
)
|
||||
|
||||
@ -136,59 +130,33 @@ def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
|
||||
|
||||
def test_is_target_class_match_all(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"],
|
||||
value="Pipistrellus pipistrellus",
|
||||
),
|
||||
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
data.Tag(key="quality", value="Good"), # type: ignore
|
||||
}
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"],
|
||||
value="Pipistrellus pipistrellus",
|
||||
)
|
||||
}
|
||||
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"], value="Myotis daubentonii"
|
||||
)
|
||||
}
|
||||
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is False
|
||||
|
||||
|
||||
def test_is_target_class_match_any(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"],
|
||||
value="Pipistrellus pipistrellus",
|
||||
),
|
||||
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
data.Tag(key="quality", value="Good"), # type: ignore
|
||||
}
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"],
|
||||
value="Pipistrellus pipistrellus",
|
||||
)
|
||||
}
|
||||
tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"], value="Myotis daubentonii"
|
||||
)
|
||||
}
|
||||
tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is False
|
||||
|
||||
|
||||
@ -208,7 +176,6 @@ def test_get_class_names_from_config():
|
||||
|
||||
def test_build_encoder_from_config(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
@ -220,25 +187,18 @@ def test_build_encoder_from_config(
|
||||
)
|
||||
]
|
||||
)
|
||||
encoder = build_sound_event_encoder(
|
||||
config,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
encoder = build_sound_event_encoder(config)
|
||||
result = encoder(sample_annotation)
|
||||
assert result == "pippip"
|
||||
|
||||
config = ClassesConfig(classes=[])
|
||||
encoder = build_sound_event_encoder(
|
||||
config,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
encoder = build_sound_event_encoder(config)
|
||||
result = encoder(sample_annotation)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_load_encoder_from_config_valid(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
sample_term_registry: TermRegistry,
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
yaml_content = """
|
||||
@ -249,10 +209,7 @@ def test_load_encoder_from_config_valid(
|
||||
value: Pipistrellus pipistrellus
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
encoder = load_encoder_from_config(
|
||||
temp_yaml_path,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
encoder = load_encoder_from_config(temp_yaml_path)
|
||||
# We cannot directly compare the function, so we test it.
|
||||
result = encoder(sample_annotation) # type: ignore
|
||||
assert result == "pippip"
|
||||
@ -260,7 +217,6 @@ def test_load_encoder_from_config_valid(
|
||||
|
||||
def test_load_encoder_from_config_invalid(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
yaml_content = """
|
||||
classes:
|
||||
@ -271,10 +227,7 @@ def test_load_encoder_from_config_invalid(
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
with pytest.raises(KeyError):
|
||||
load_encoder_from_config(
|
||||
temp_yaml_path,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
load_encoder_from_config(temp_yaml_path)
|
||||
|
||||
|
||||
def test_get_default_class_name():
|
||||
@ -291,7 +244,7 @@ def test_get_default_classes():
|
||||
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
|
||||
|
||||
|
||||
def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
||||
def test_build_decoder_from_config():
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
@ -304,12 +257,10 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
||||
],
|
||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||
)
|
||||
decoder = build_sound_event_decoder(
|
||||
config, term_registry=sample_term_registry
|
||||
)
|
||||
decoder = build_sound_event_decoder(config)
|
||||
tags = decoder("pippip")
|
||||
assert len(tags) == 1
|
||||
assert tags[0].term == sample_term_registry["call_type"]
|
||||
assert tags[0].term == get_term("event")
|
||||
assert tags[0].value == "Echolocation"
|
||||
|
||||
# Test when output_tags is None, should fall back to tags
|
||||
@ -324,32 +275,25 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
||||
],
|
||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||
)
|
||||
decoder = build_sound_event_decoder(
|
||||
config, term_registry=sample_term_registry
|
||||
)
|
||||
decoder = build_sound_event_decoder(config)
|
||||
tags = decoder("pippip")
|
||||
assert len(tags) == 1
|
||||
assert tags[0].term == sample_term_registry["species"]
|
||||
assert tags[0].term == get_term("species")
|
||||
assert tags[0].value == "Pipistrellus pipistrellus"
|
||||
|
||||
# Test raise_on_unmapped=True
|
||||
decoder = build_sound_event_decoder(
|
||||
config, term_registry=sample_term_registry, raise_on_unmapped=True
|
||||
)
|
||||
decoder = build_sound_event_decoder(config, raise_on_unmapped=True)
|
||||
with pytest.raises(ValueError):
|
||||
decoder("unknown_class")
|
||||
|
||||
# Test raise_on_unmapped=False
|
||||
decoder = build_sound_event_decoder(
|
||||
config, term_registry=sample_term_registry, raise_on_unmapped=False
|
||||
)
|
||||
decoder = build_sound_event_decoder(config, raise_on_unmapped=False)
|
||||
tags = decoder("unknown_class")
|
||||
assert len(tags) == 0
|
||||
|
||||
|
||||
def test_load_decoder_from_config_valid(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
yaml_content = """
|
||||
classes:
|
||||
@ -366,17 +310,15 @@ def test_load_decoder_from_config_valid(
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
decoder = load_decoder_from_config(
|
||||
temp_yaml_path, term_registry=sample_term_registry
|
||||
temp_yaml_path,
|
||||
)
|
||||
tags = decoder("pippip")
|
||||
assert len(tags) == 1
|
||||
assert tags[0].term == sample_term_registry["call_type"]
|
||||
assert tags[0].term == get_term("call_type")
|
||||
assert tags[0].value == "Echolocation"
|
||||
|
||||
|
||||
def test_build_generic_class_tags_from_config(
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
def test_build_generic_class_tags_from_config():
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
@ -391,11 +333,9 @@ def test_build_generic_class_tags_from_config(
|
||||
TagInfo(key="call_type", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
generic_tags = build_generic_class_tags(
|
||||
config, term_registry=sample_term_registry
|
||||
)
|
||||
generic_tags = build_generic_class_tags(config)
|
||||
assert len(generic_tags) == 2
|
||||
assert generic_tags[0].term == sample_term_registry["order"]
|
||||
assert generic_tags[0].term == get_term("order")
|
||||
assert generic_tags[0].value == "Chiroptera"
|
||||
assert generic_tags[1].term == sample_term_registry["call_type"]
|
||||
assert generic_tags[1].term == get_term("call_type")
|
||||
assert generic_tags[1].value == "Echolocation"
|
||||
|
||||
@ -80,7 +80,6 @@ def test_generated_heatmaps_have_correct_dimensions(
|
||||
|
||||
def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
sample_target_config: TargetConfig,
|
||||
sample_term_registry: TermRegistry,
|
||||
pippip_tag: TagInfo,
|
||||
):
|
||||
config = sample_target_config.model_copy(
|
||||
@ -92,7 +91,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
)
|
||||
)
|
||||
|
||||
targets = build_targets(config, term_registry=sample_term_registry)
|
||||
targets = build_targets(config)
|
||||
|
||||
spec = xr.DataArray(
|
||||
data=np.random.rand(100, 100),
|
||||
@ -113,12 +112,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
coordinates=[10, 10, 20, 20],
|
||||
),
|
||||
),
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=sample_term_registry[pippip_tag.key],
|
||||
value=pippip_tag.value,
|
||||
)
|
||||
],
|
||||
tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@ -2,12 +2,12 @@ import pytest
|
||||
import torch
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
from soundevent.terms import get_term
|
||||
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||
from batdetect2.targets import build_targets, load_target_config
|
||||
from batdetect2.targets.terms import get_term_from_key
|
||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||
from batdetect2.train.preprocess import generate_train_example
|
||||
|
||||
@ -15,7 +15,6 @@ from batdetect2.train.preprocess import generate_train_example
|
||||
@pytest.fixture
|
||||
def build_from_config(
|
||||
create_temp_yaml,
|
||||
sample_term_registry,
|
||||
):
|
||||
def build(yaml_content):
|
||||
config_path = create_temp_yaml(yaml_content)
|
||||
@ -31,9 +30,7 @@ def build_from_config(
|
||||
field="postprocessing",
|
||||
)
|
||||
|
||||
targets = build_targets(
|
||||
targets_config, term_registry=sample_term_registry
|
||||
)
|
||||
targets = build_targets(targets_config)
|
||||
preprocessor = build_preprocessor(preprocessing_config)
|
||||
labeller = build_clip_labeler(
|
||||
targets=targets,
|
||||
@ -54,7 +51,6 @@ def build_from_config(
|
||||
# TODO: better name
|
||||
def test_generated_train_example_has_expected_outputs(
|
||||
build_from_config,
|
||||
sample_term_registry,
|
||||
recording,
|
||||
):
|
||||
yaml_content = """
|
||||
@ -78,10 +74,11 @@ def test_generated_train_example_has_expected_outputs(
|
||||
_, preprocessor, labeller, _ = build_from_config(yaml_content)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||
tags=[
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
],
|
||||
)
|
||||
clip_annotation = data.ClipAnnotation(
|
||||
clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
|
||||
@ -108,7 +105,6 @@ def test_generated_train_example_has_expected_outputs(
|
||||
|
||||
def test_encoding_decoding_roundtrip_recovers_object(
|
||||
build_from_config,
|
||||
sample_term_registry,
|
||||
recording,
|
||||
):
|
||||
yaml_content = """
|
||||
@ -131,10 +127,11 @@ def test_encoding_decoding_roundtrip_recovers_object(
|
||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||
tags=[
|
||||
data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore
|
||||
],
|
||||
)
|
||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||
@ -171,14 +168,16 @@ def test_encoding_decoding_roundtrip_recovers_object(
|
||||
assert len(recovered.tags) == 2
|
||||
|
||||
predicted_species_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term == species), None
|
||||
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
||||
None,
|
||||
)
|
||||
assert predicted_species_tag is not None
|
||||
assert predicted_species_tag.score == 1
|
||||
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
|
||||
|
||||
predicted_order_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
|
||||
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
||||
None,
|
||||
)
|
||||
assert predicted_order_tag is not None
|
||||
assert predicted_order_tag.score == 1
|
||||
@ -187,7 +186,6 @@ def test_encoding_decoding_roundtrip_recovers_object(
|
||||
|
||||
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
||||
build_from_config,
|
||||
sample_term_registry,
|
||||
recording,
|
||||
):
|
||||
yaml_content = """
|
||||
@ -217,10 +215,9 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Myotis myotis")],
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore
|
||||
)
|
||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||
@ -257,14 +254,16 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
||||
assert len(recovered.tags) == 2
|
||||
|
||||
predicted_species_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term == species), None
|
||||
iter(t for t in recovered.tags if t.tag.term == get_term("species")),
|
||||
None,
|
||||
)
|
||||
assert predicted_species_tag is not None
|
||||
assert predicted_species_tag.score == 1
|
||||
assert predicted_species_tag.tag.value == "Myotis myotis"
|
||||
|
||||
predicted_order_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
|
||||
iter(t for t in recovered.tags if t.tag.term == get_term("order")),
|
||||
None,
|
||||
)
|
||||
assert predicted_order_tag is not None
|
||||
assert predicted_order_tag.score == 1
|
||||
|
||||
Loading…
Reference in New Issue
Block a user