From 59aaf07af5a017d88cea9255bc84f1e5a32b2738 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 12 Aug 2025 18:44:18 +0100 Subject: [PATCH] Update tests after incorporating term registry from soundevent --- .../test_annotations/test_batdetect2.py | 14 +- tests/test_postprocessing/test_decoding.py | 231 +++++++++--------- tests/test_targets/test_classes.py | 118 +++------ tests/test_train/test_labels.py | 10 +- tests/test_train/test_preprocessing.py | 35 ++- 5 files changed, 172 insertions(+), 236 deletions(-) diff --git a/tests/test_data/test_annotations/test_batdetect2.py b/tests/test_data/test_annotations/test_batdetect2.py index 5274bc0..83ef326 100644 --- a/tests/test_data/test_annotations/test_batdetect2.py +++ b/tests/test_data/test_annotations/test_batdetect2.py @@ -254,11 +254,13 @@ class TestLoadBatDetect2Files: assert clip_ann.clip.recording.duration == 5.0 assert len(clip_ann.sound_events) == 1 assert clip_ann.notes[0].message == "Standard notes." - clip_tag = data.find_tag(clip_ann.tags, "Class") + clip_tag = data.find_tag(clip_ann.tags, term_label="Class") assert clip_tag is not None assert clip_tag.value == "Myotis" - recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class") + recording_tag = data.find_tag( + clip_ann.clip.recording.tags, term_label="Class" + ) assert recording_tag is not None assert recording_tag.value == "Myotis" @@ -271,15 +273,15 @@ class TestLoadBatDetect2Files: 40000, ] - se_class_tag = data.find_tag(se_ann.tags, "Class") + se_class_tag = data.find_tag(se_ann.tags, term_label="Class") assert se_class_tag is not None assert se_class_tag.value == "Myotis" - se_event_tag = data.find_tag(se_ann.tags, "Call Type") + se_event_tag = data.find_tag(se_ann.tags, term_label="Call Type") assert se_event_tag is not None assert se_event_tag.value == "Echolocation" - se_individual_tag = data.find_tag(se_ann.tags, "Individual") + se_individual_tag = data.find_tag(se_ann.tags, term_label="Individual") assert se_individual_tag is not None assert se_individual_tag.value == "0" @@ -439,7 +441,7 @@ class TestLoadBatDetect2Merged: assert clip_ann.clip.recording.duration == 5.0 assert len(clip_ann.sound_events) == 1 - clip_class_tag = data.find_tag(clip_ann.tags, "Class") + clip_class_tag = data.find_tag(clip_ann.tags, term_label="Class") assert clip_class_tag is not None assert clip_class_tag.value == "Myotis" diff --git a/tests/test_postprocessing/test_decoding.py b/tests/test_postprocessing/test_decoding.py index 9d06919..4dbb431 100644 --- a/tests/test_postprocessing/test_decoding.py +++ b/tests/test_postprocessing/test_decoding.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Optional import numpy as np import pytest @@ -16,35 +16,11 @@ from batdetect2.postprocess.decoding import ( get_prediction_features, ) from batdetect2.postprocess.types import RawPrediction +from batdetect2.targets.types import TargetProtocol @pytest.fixture -def dummy_geometry_builder(): - """A simple GeometryBuilder that creates a BBox around the point.""" - - def _builder( - position: Tuple[float, float], - dimensions: xr.DataArray, - class_name: Optional[str] = None, - ) -> data.BoundingBox: - time, freq = position - width = dimensions.sel(dimension="width").item() - height = dimensions.sel(dimension="height").item() - return data.BoundingBox( - coordinates=[ - time - width / 2, - freq - height / 2, - time + width / 2, - freq + height / 2, - ] - ) - - return _builder - - -@pytest.fixture -def dummy_sound_event_decoder(): - """A simple SoundEventDecoder mapping names to tags.""" +def dummy_targets() -> TargetProtocol: tag_map = { "bat": [ data.Tag(term=data.term_from_key(key="species"), value="Myotis") @@ -57,18 +33,56 @@ def dummy_sound_event_decoder(): ], } - def _decoder(class_name: str) -> List[data.Tag]: - return tag_map.get(class_name.lower(), []) + class DummyTargets(TargetProtocol): + class_names = [ + "bat", + "noise", + "unknown", + ] - return _decoder + dimension_names = ["width", "height"] + generic_class_tags = [ + data.Tag( + term=data.term_from_key(key="detector"), value="batdetect2" + ) + ] -@pytest.fixture -def generic_tags() -> List[data.Tag]: - """Sample generic tags.""" - return [ - data.Tag(term=data.term_from_key(key="detector"), value="batdetect2") - ] + def filter(self, sound_event: data.SoundEventAnnotation): + return True + + def transform(self, sound_event: data.SoundEventAnnotation): + return sound_event + + def encode_class( + self, sound_event: data.SoundEventAnnotation + ) -> Optional[str]: + return "bat" + + def decode_class(self, class_label: str) -> List[data.Tag]: + return tag_map.get(class_label.lower(), []) + + def encode_roi(self, sound_event: data.SoundEventAnnotation): + return np.array([0.0, 0.0]), np.array([0.0, 0.0]) + + def decode_roi( + self, + position, + size: np.ndarray, + class_name: Optional[str] = None, + ): + time, freq = position + width, height = size + return data.BoundingBox( + coordinates=[ + time - width / 2, + freq - height / 2, + time + width / 2, + freq + height / 2, + ] + ) + + return DummyTargets() @pytest.fixture @@ -156,7 +170,7 @@ def empty_detection_dataset() -> xr.Dataset: """Creates an empty detection dataset with correct structure.""" detection_coords = { "time": ("detection", np.array([], dtype=np.float64)), - "freq": ("detection", np.array([], dtype=np.float64)), + "frequency": ("detection", np.array([], dtype=np.float64)), } scores = xr.DataArray( np.array([], dtype=np.float64), @@ -184,7 +198,7 @@ def empty_detection_dataset() -> xr.Dataset: ) return xr.Dataset( { - "score": scores, + "scores": scores, "dimensions": dimensions, "classes": classes, "features": features, @@ -215,8 +229,8 @@ def sample_raw_predictions() -> List[RawPrediction]: 300 + 16 / 2, ] ), - class_scores=pred1_classes, - features=pred1_features, + class_scores=pred1_classes.values, + features=pred1_features.values, ) pred2_classes = xr.DataArray( @@ -237,8 +251,8 @@ def sample_raw_predictions() -> List[RawPrediction]: 200 + 12 / 2, ] ), - class_scores=pred2_classes, - features=pred2_features, + class_scores=pred2_classes.values, + features=pred2_features.values, ) pred3_classes = xr.DataArray( @@ -259,18 +273,17 @@ def sample_raw_predictions() -> List[RawPrediction]: 60.0, ] ), - class_scores=pred3_classes, - features=pred3_features, + class_scores=pred3_classes.values, + features=pred3_features.values, ) return [pred1, pred2, pred3] -def test_convert_xr_dataset_basic( - sample_detection_dataset, dummy_geometry_builder -): +def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets): """Test basic conversion of a dataset to RawPrediction list.""" raw_predictions = convert_xr_dataset_to_raw_prediction( - sample_detection_dataset, dummy_geometry_builder + sample_detection_dataset, + dummy_targets.decode_roi, ) assert isinstance(raw_predictions, list) @@ -286,11 +299,11 @@ def test_convert_xr_dataset_basic( 20 + 7 / 2, 300 + 16 / 2, ] - xr.testing.assert_allclose( + np.testing.assert_allclose( pred1.class_scores, sample_detection_dataset["classes"].sel(detection=0), ) - xr.testing.assert_allclose( + np.testing.assert_allclose( pred1.features, sample_detection_dataset["features"].sel(detection=0) ) @@ -304,21 +317,20 @@ def test_convert_xr_dataset_basic( 10 + 3 / 2, 200 + 12 / 2, ] - xr.testing.assert_allclose( + np.testing.assert_allclose( pred2.class_scores, sample_detection_dataset["classes"].sel(detection=1), ) - xr.testing.assert_allclose( + np.testing.assert_allclose( pred2.features, sample_detection_dataset["features"].sel(detection=1) ) -def test_convert_xr_dataset_empty( - empty_detection_dataset, dummy_geometry_builder -): +def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets): """Test conversion of an empty dataset.""" raw_predictions = convert_xr_dataset_to_raw_prediction( - empty_detection_dataset, dummy_geometry_builder + empty_detection_dataset, + dummy_targets.decode_roi, ) assert isinstance(raw_predictions, list) assert len(raw_predictions) == 0 @@ -327,8 +339,7 @@ def test_convert_xr_dataset_empty( def test_convert_raw_to_sound_event_basic( sample_raw_predictions, sample_recording, - dummy_sound_event_decoder, - generic_tags, + dummy_targets, ): """Test basic conversion, default threshold, multi-label.""" @@ -337,8 +348,7 @@ def test_convert_raw_to_sound_event_basic( se_pred = convert_raw_prediction_to_sound_event_prediction( raw_prediction=raw_pred, recording=sample_recording, - sound_event_decoder=dummy_sound_event_decoder, - generic_class_tags=generic_tags, + targets=dummy_targets, ) assert isinstance(se_pred, data.SoundEventPrediction) @@ -357,10 +367,11 @@ def test_convert_raw_to_sound_event_basic( ) assert feat_dict["batdetect2:f0"] == 7.0 + generic_tags = dummy_targets.generic_class_tags expected_tags = { (generic_tags[0].term.name, generic_tags[0].value, 0.9), - ("soundevent:category", "noise", 0.85), - ("soundevent:species", "Myotis", 0.43), + ("category", "noise", 0.85), + ("dwc:scientificName", "Myotis", 0.43), } actual_tags = { (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags @@ -369,10 +380,7 @@ def test_convert_raw_to_sound_event_basic( def test_convert_raw_to_sound_event_thresholding( - sample_raw_predictions, - sample_recording, - dummy_sound_event_decoder, - generic_tags, + sample_raw_predictions, sample_recording, dummy_targets ): """Test effect of classification threshold.""" raw_pred = sample_raw_predictions[0] @@ -381,15 +389,15 @@ def test_convert_raw_to_sound_event_thresholding( se_pred = convert_raw_prediction_to_sound_event_prediction( raw_prediction=raw_pred, recording=sample_recording, - sound_event_decoder=dummy_sound_event_decoder, - generic_class_tags=generic_tags, + targets=dummy_targets, classification_threshold=high_threshold, top_class_only=False, ) + generic_tags = dummy_targets.generic_class_tags expected_tags = { (generic_tags[0].term.name, generic_tags[0].value, 0.9), - ("soundevent:category", "noise", 0.85), + ("category", "noise", 0.85), } actual_tags = { (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags @@ -400,8 +408,7 @@ def test_convert_raw_to_sound_event_thresholding( def test_convert_raw_to_sound_event_no_threshold( sample_raw_predictions, sample_recording, - dummy_sound_event_decoder, - generic_tags, + dummy_targets, ): """Test when classification_threshold is None.""" raw_pred = sample_raw_predictions[2] @@ -409,16 +416,16 @@ def test_convert_raw_to_sound_event_no_threshold( se_pred = convert_raw_prediction_to_sound_event_prediction( raw_prediction=raw_pred, recording=sample_recording, - sound_event_decoder=dummy_sound_event_decoder, - generic_class_tags=generic_tags, + targets=dummy_targets, classification_threshold=None, top_class_only=False, ) + generic_tags = dummy_targets.generic_class_tags expected_tags = { (generic_tags[0].term.name, generic_tags[0].value, 0.15), - ("soundevent:species", "Myotis", 0.05), - ("soundevent:category", "noise", 0.02), + ("dwc:scientificName", "Myotis", 0.05), + ("category", "noise", 0.02), } actual_tags = { (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags @@ -429,8 +436,7 @@ def test_convert_raw_to_sound_event_no_threshold( def test_convert_raw_to_sound_event_top_class( sample_raw_predictions, sample_recording, - dummy_sound_event_decoder, - generic_tags, + dummy_targets, ): """Test top_class_only=True behavior.""" raw_pred = sample_raw_predictions[0] @@ -438,15 +444,15 @@ def test_convert_raw_to_sound_event_top_class( se_pred = convert_raw_prediction_to_sound_event_prediction( raw_prediction=raw_pred, recording=sample_recording, - sound_event_decoder=dummy_sound_event_decoder, - generic_class_tags=generic_tags, + targets=dummy_targets, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, top_class_only=True, ) + generic_tags = dummy_targets.generic_class_tags expected_tags = { (generic_tags[0].term.name, generic_tags[0].value, 0.9), - ("soundevent:category", "noise", 0.85), + ("category", "noise", 0.85), } actual_tags = { (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags @@ -457,8 +463,7 @@ def test_convert_raw_to_sound_event_top_class( def test_convert_raw_to_sound_event_all_below_threshold( sample_raw_predictions, sample_recording, - dummy_sound_event_decoder, - generic_tags, + dummy_targets, ): """Test when all class scores are below the default threshold.""" raw_pred = sample_raw_predictions[2] @@ -466,12 +471,12 @@ def test_convert_raw_to_sound_event_all_below_threshold( se_pred = convert_raw_prediction_to_sound_event_prediction( raw_prediction=raw_pred, recording=sample_recording, - sound_event_decoder=dummy_sound_event_decoder, - generic_class_tags=generic_tags, + targets=dummy_targets, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, top_class_only=False, ) + generic_tags = dummy_targets.generic_class_tags expected_tags = { (generic_tags[0].term.name, generic_tags[0].value, 0.15), } @@ -484,15 +489,13 @@ def test_convert_raw_to_sound_event_all_below_threshold( def test_convert_raw_list_to_clip_basic( sample_raw_predictions, sample_clip, - dummy_sound_event_decoder, - generic_tags, + dummy_targets, ): """Test converting a list of RawPredictions to a ClipPrediction.""" clip_pred = convert_raw_predictions_to_clip_prediction( raw_predictions=sample_raw_predictions, clip=sample_clip, - sound_event_decoder=dummy_sound_event_decoder, - generic_class_tags=generic_tags, + targets=dummy_targets, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, top_class_only=False, ) @@ -515,23 +518,19 @@ def test_convert_raw_list_to_clip_basic( (pt.tag.term.name, pt.tag.value, pt.score) for pt in clip_pred.sound_events[2].tags } + generic_tags = dummy_targets.generic_class_tags expected_tags3 = { (generic_tags[0].term.name, generic_tags[0].value, 0.15), } assert se_pred3_tags == expected_tags3 -def test_convert_raw_list_to_clip_empty( - sample_clip, - dummy_sound_event_decoder, - generic_tags, -): +def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets): """Test converting an empty list of RawPredictions.""" clip_pred = convert_raw_predictions_to_clip_prediction( raw_predictions=[], clip=sample_clip, - sound_event_decoder=dummy_sound_event_decoder, - generic_class_tags=generic_tags, + targets=dummy_targets, ) assert isinstance(clip_pred, data.ClipPrediction) @@ -542,16 +541,14 @@ def test_convert_raw_list_to_clip_empty( def test_convert_raw_list_to_clip_passes_args( sample_raw_predictions, sample_clip, - dummy_sound_event_decoder, - generic_tags, + dummy_targets, ): """Test that arguments like top_class_only are passed through.""" clip_pred = convert_raw_predictions_to_clip_prediction( raw_predictions=sample_raw_predictions, clip=sample_clip, - sound_event_decoder=dummy_sound_event_decoder, - generic_class_tags=generic_tags, + targets=dummy_targets, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, top_class_only=True, ) @@ -562,16 +559,18 @@ def test_convert_raw_list_to_clip_passes_args( (pt.tag.term.name, pt.tag.value, pt.score) for pt in clip_pred.sound_events[0].tags } + generic_tags = dummy_targets.generic_class_tags expected_tags1 = { (generic_tags[0].term.name, generic_tags[0].value, 0.9), - ("soundevent:category", "noise", 0.85), + ("category", "noise", 0.85), } assert se_pred1_tags == expected_tags1 -def test_get_generic_tags_basic(generic_tags): +def test_get_generic_tags_basic(dummy_targets): """Test creation of generic tags with score.""" detection_score = 0.75 + generic_tags = dummy_targets.generic_class_tags predicted_tags = get_generic_tags( detection_score=detection_score, generic_class_tags=generic_tags ) @@ -589,17 +588,19 @@ def test_get_prediction_features_basic(): coords={"feature": ["feat1", "feat2", "feat3"]}, dims=["feature"], ) - features = get_prediction_features(feature_data) + features = get_prediction_features(feature_data.values) assert len(features) == 3 for feature, feat_name, feat_value in zip( - features, ["feat1", "feat2", "feat3"], [1.1, 2.2, 3.3] + features, + ["f0", "f1", "f2"], + [1.1, 2.2, 3.3], ): assert isinstance(feature, data.Feature) assert feature.term.name == f"batdetect2:{feat_name}" assert feature.value == feat_value -def test_get_class_tags_basic(dummy_sound_event_decoder): +def test_get_class_tags_basic(dummy_targets): """Test creation of class tags based on scores and decoder.""" class_scores = xr.DataArray( [0.6, 0.2, 0.9], @@ -607,8 +608,8 @@ def test_get_class_tags_basic(dummy_sound_event_decoder): dims=["category"], ) predicted_tags = get_class_tags( - class_scores=class_scores, - sound_event_decoder=dummy_sound_event_decoder, + class_scores=class_scores.values, + targets=dummy_targets, ) assert len(predicted_tags) == 3 tag_values = [pt.tag.value for pt in predicted_tags] @@ -622,7 +623,7 @@ def test_get_class_tags_basic(dummy_sound_event_decoder): assert 0.9 in tag_scores -def test_get_class_tags_thresholding(dummy_sound_event_decoder): +def test_get_class_tags_thresholding(dummy_targets): """Test class tag creation with a threshold.""" class_scores = xr.DataArray( [0.6, 0.2, 0.9], @@ -631,8 +632,8 @@ def test_get_class_tags_thresholding(dummy_sound_event_decoder): ) threshold = 0.5 predicted_tags = get_class_tags( - class_scores=class_scores, - sound_event_decoder=dummy_sound_event_decoder, + class_scores=class_scores.values, + targets=dummy_targets, threshold=threshold, ) @@ -643,7 +644,7 @@ def test_get_class_tags_thresholding(dummy_sound_event_decoder): assert "uncertain" in tag_values -def test_get_class_tags_top_class_only(dummy_sound_event_decoder): +def test_get_class_tags_top_class_only(dummy_targets): """Test class tag creation with top_class_only.""" class_scores = xr.DataArray( [0.6, 0.2, 0.9], @@ -651,8 +652,8 @@ def test_get_class_tags_top_class_only(dummy_sound_event_decoder): dims=["category"], ) predicted_tags = get_class_tags( - class_scores=class_scores, - sound_event_decoder=dummy_sound_event_decoder, + class_scores=class_scores.values, + targets=dummy_targets, top_class_only=True, ) @@ -661,11 +662,11 @@ def test_get_class_tags_top_class_only(dummy_sound_event_decoder): assert predicted_tags[0].score == 0.9 -def test_get_class_tags_empty(dummy_sound_event_decoder): +def test_get_class_tags_empty(dummy_targets): """Test with empty class scores.""" class_scores = xr.DataArray([], coords={"category": []}, dims=["category"]) predicted_tags = get_class_tags( - class_scores=class_scores, - sound_event_decoder=dummy_sound_event_decoder, + class_scores=class_scores.values, + targets=dummy_targets, ) assert len(predicted_tags) == 0 diff --git a/tests/test_targets/test_classes.py b/tests/test_targets/test_classes.py index fc4c155..587b502 100644 --- a/tests/test_targets/test_classes.py +++ b/tests/test_targets/test_classes.py @@ -5,6 +5,7 @@ from uuid import uuid4 import pytest from pydantic import ValidationError from soundevent import data +from soundevent.terms import get_term from batdetect2.targets.classes import ( DEFAULT_SPECIES_LIST, @@ -21,26 +22,19 @@ from batdetect2.targets.classes import ( load_decoder_from_config, load_encoder_from_config, ) -from batdetect2.targets.terms import TagInfo, TermRegistry +from batdetect2.targets.terms import TagInfo @pytest.fixture def sample_annotation( sound_event: data.SoundEvent, - sample_term_registry: TermRegistry, ) -> data.SoundEventAnnotation: """Fixture for a sample SoundEventAnnotation.""" return data.SoundEventAnnotation( sound_event=sound_event, tags=[ - data.Tag( - term=sample_term_registry.get_term("species"), - value="Pipistrellus pipistrellus", - ), - data.Tag( - term=sample_term_registry.get_term("quality"), - value="Good", - ), + data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore + data.Tag(key="quality", value="Good"), # type: ignore ], ) @@ -136,59 +130,33 @@ def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]): def test_is_target_class_match_all( sample_annotation: data.SoundEventAnnotation, - sample_term_registry: TermRegistry, ): tags = { - data.Tag( - term=sample_term_registry["species"], - value="Pipistrellus pipistrellus", - ), - data.Tag(term=sample_term_registry["quality"], value="Good"), + data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore + data.Tag(key="quality", value="Good"), # type: ignore } assert is_target_class(sample_annotation, tags, match_all=True) is True - tags = { - data.Tag( - term=sample_term_registry["species"], - value="Pipistrellus pipistrellus", - ) - } + tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore assert is_target_class(sample_annotation, tags, match_all=True) is True - tags = { - data.Tag( - term=sample_term_registry["species"], value="Myotis daubentonii" - ) - } + tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore assert is_target_class(sample_annotation, tags, match_all=True) is False def test_is_target_class_match_any( sample_annotation: data.SoundEventAnnotation, - sample_term_registry: TermRegistry, ): tags = { - data.Tag( - term=sample_term_registry["species"], - value="Pipistrellus pipistrellus", - ), - data.Tag(term=sample_term_registry["quality"], value="Good"), + data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore + data.Tag(key="quality", value="Good"), # type: ignore } assert is_target_class(sample_annotation, tags, match_all=False) is True - tags = { - data.Tag( - term=sample_term_registry["species"], - value="Pipistrellus pipistrellus", - ) - } + tags = {data.Tag(key="species", value="Pipistrellus pipistrellus")} # type: ignore assert is_target_class(sample_annotation, tags, match_all=False) is True - tags = { - data.Tag( - term=sample_term_registry["species"], value="Myotis daubentonii" - ) - } + tags = {data.Tag(key="species", value="Myotis daubentonii")} # type: ignore assert is_target_class(sample_annotation, tags, match_all=False) is False @@ -208,7 +176,6 @@ def test_get_class_names_from_config(): def test_build_encoder_from_config( sample_annotation: data.SoundEventAnnotation, - sample_term_registry: TermRegistry, ): config = ClassesConfig( classes=[ @@ -220,25 +187,18 @@ def test_build_encoder_from_config( ) ] ) - encoder = build_sound_event_encoder( - config, - term_registry=sample_term_registry, - ) + encoder = build_sound_event_encoder(config) result = encoder(sample_annotation) assert result == "pippip" config = ClassesConfig(classes=[]) - encoder = build_sound_event_encoder( - config, - term_registry=sample_term_registry, - ) + encoder = build_sound_event_encoder(config) result = encoder(sample_annotation) assert result is None def test_load_encoder_from_config_valid( sample_annotation: data.SoundEventAnnotation, - sample_term_registry: TermRegistry, create_temp_yaml: Callable[[str], Path], ): yaml_content = """ @@ -249,10 +209,7 @@ def test_load_encoder_from_config_valid( value: Pipistrellus pipistrellus """ temp_yaml_path = create_temp_yaml(yaml_content) - encoder = load_encoder_from_config( - temp_yaml_path, - term_registry=sample_term_registry, - ) + encoder = load_encoder_from_config(temp_yaml_path) # We cannot directly compare the function, so we test it. result = encoder(sample_annotation) # type: ignore assert result == "pippip" @@ -260,7 +217,6 @@ def test_load_encoder_from_config_valid( def test_load_encoder_from_config_invalid( create_temp_yaml: Callable[[str], Path], - sample_term_registry: TermRegistry, ): yaml_content = """ classes: @@ -271,10 +227,7 @@ def test_load_encoder_from_config_invalid( """ temp_yaml_path = create_temp_yaml(yaml_content) with pytest.raises(KeyError): - load_encoder_from_config( - temp_yaml_path, - term_registry=sample_term_registry, - ) + load_encoder_from_config(temp_yaml_path) def test_get_default_class_name(): @@ -291,7 +244,7 @@ def test_get_default_classes(): assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0] -def test_build_decoder_from_config(sample_term_registry: TermRegistry): +def test_build_decoder_from_config(): config = ClassesConfig( classes=[ TargetClass( @@ -304,12 +257,10 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry): ], generic_class=[TagInfo(key="order", value="Chiroptera")], ) - decoder = build_sound_event_decoder( - config, term_registry=sample_term_registry - ) + decoder = build_sound_event_decoder(config) tags = decoder("pippip") assert len(tags) == 1 - assert tags[0].term == sample_term_registry["call_type"] + assert tags[0].term == get_term("event") assert tags[0].value == "Echolocation" # Test when output_tags is None, should fall back to tags @@ -324,32 +275,25 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry): ], generic_class=[TagInfo(key="order", value="Chiroptera")], ) - decoder = build_sound_event_decoder( - config, term_registry=sample_term_registry - ) + decoder = build_sound_event_decoder(config) tags = decoder("pippip") assert len(tags) == 1 - assert tags[0].term == sample_term_registry["species"] + assert tags[0].term == get_term("species") assert tags[0].value == "Pipistrellus pipistrellus" # Test raise_on_unmapped=True - decoder = build_sound_event_decoder( - config, term_registry=sample_term_registry, raise_on_unmapped=True - ) + decoder = build_sound_event_decoder(config, raise_on_unmapped=True) with pytest.raises(ValueError): decoder("unknown_class") # Test raise_on_unmapped=False - decoder = build_sound_event_decoder( - config, term_registry=sample_term_registry, raise_on_unmapped=False - ) + decoder = build_sound_event_decoder(config, raise_on_unmapped=False) tags = decoder("unknown_class") assert len(tags) == 0 def test_load_decoder_from_config_valid( create_temp_yaml: Callable[[str], Path], - sample_term_registry: TermRegistry, ): yaml_content = """ classes: @@ -366,17 +310,15 @@ def test_load_decoder_from_config_valid( """ temp_yaml_path = create_temp_yaml(yaml_content) decoder = load_decoder_from_config( - temp_yaml_path, term_registry=sample_term_registry + temp_yaml_path, ) tags = decoder("pippip") assert len(tags) == 1 - assert tags[0].term == sample_term_registry["call_type"] + assert tags[0].term == get_term("call_type") assert tags[0].value == "Echolocation" -def test_build_generic_class_tags_from_config( - sample_term_registry: TermRegistry, -): +def test_build_generic_class_tags_from_config(): config = ClassesConfig( classes=[ TargetClass( @@ -391,11 +333,9 @@ def test_build_generic_class_tags_from_config( TagInfo(key="call_type", value="Echolocation"), ], ) - generic_tags = build_generic_class_tags( - config, term_registry=sample_term_registry - ) + generic_tags = build_generic_class_tags(config) assert len(generic_tags) == 2 - assert generic_tags[0].term == sample_term_registry["order"] + assert generic_tags[0].term == get_term("order") assert generic_tags[0].value == "Chiroptera" - assert generic_tags[1].term == sample_term_registry["call_type"] + assert generic_tags[1].term == get_term("call_type") assert generic_tags[1].value == "Echolocation" diff --git a/tests/test_train/test_labels.py b/tests/test_train/test_labels.py index 6e4e23c..17e21bd 100644 --- a/tests/test_train/test_labels.py +++ b/tests/test_train/test_labels.py @@ -80,7 +80,6 @@ def test_generated_heatmaps_have_correct_dimensions( def test_generated_heatmap_are_non_zero_at_correct_positions( sample_target_config: TargetConfig, - sample_term_registry: TermRegistry, pippip_tag: TagInfo, ): config = sample_target_config.model_copy( @@ -92,7 +91,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions( ) ) - targets = build_targets(config, term_registry=sample_term_registry) + targets = build_targets(config) spec = xr.DataArray( data=np.random.rand(100, 100), @@ -113,12 +112,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions( coordinates=[10, 10, 20, 20], ), ), - tags=[ - data.Tag( - term=sample_term_registry[pippip_tag.key], - value=pippip_tag.value, - ) - ], + tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)], # type: ignore ) ], ) diff --git a/tests/test_train/test_preprocessing.py b/tests/test_train/test_preprocessing.py index 8727ee9..ee97fd2 100644 --- a/tests/test_train/test_preprocessing.py +++ b/tests/test_train/test_preprocessing.py @@ -2,12 +2,12 @@ import pytest import torch import xarray as xr from soundevent import data +from soundevent.terms import get_term from batdetect2.models.types import ModelOutput from batdetect2.postprocess import build_postprocessor, load_postprocess_config from batdetect2.preprocess import build_preprocessor, load_preprocessing_config from batdetect2.targets import build_targets, load_target_config -from batdetect2.targets.terms import get_term_from_key from batdetect2.train.labels import build_clip_labeler, load_label_config from batdetect2.train.preprocess import generate_train_example @@ -15,7 +15,6 @@ from batdetect2.train.preprocess import generate_train_example @pytest.fixture def build_from_config( create_temp_yaml, - sample_term_registry, ): def build(yaml_content): config_path = create_temp_yaml(yaml_content) @@ -31,9 +30,7 @@ def build_from_config( field="postprocessing", ) - targets = build_targets( - targets_config, term_registry=sample_term_registry - ) + targets = build_targets(targets_config) preprocessor = build_preprocessor(preprocessing_config) labeller = build_clip_labeler( targets=targets, @@ -54,7 +51,6 @@ def build_from_config( # TODO: better name def test_generated_train_example_has_expected_outputs( build_from_config, - sample_term_registry, recording, ): yaml_content = """ @@ -78,10 +74,11 @@ def test_generated_train_example_has_expected_outputs( _, preprocessor, labeller, _ = build_from_config(yaml_content) geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) - species = get_term_from_key("species", term_registry=sample_term_registry) se1 = data.SoundEventAnnotation( sound_event=data.SoundEvent(recording=recording, geometry=geometry), - tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")], + tags=[ + data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore + ], ) clip_annotation = data.ClipAnnotation( clip=data.Clip(start_time=0, end_time=0.5, recording=recording), @@ -108,7 +105,6 @@ def test_generated_train_example_has_expected_outputs( def test_encoding_decoding_roundtrip_recovers_object( build_from_config, - sample_term_registry, recording, ): yaml_content = """ @@ -131,10 +127,11 @@ def test_encoding_decoding_roundtrip_recovers_object( _, preprocessor, labeller, postprocessor = build_from_config(yaml_content) geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000]) - species = get_term_from_key("species", term_registry=sample_term_registry) se1 = data.SoundEventAnnotation( sound_event=data.SoundEvent(recording=recording, geometry=geometry), - tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")], + tags=[ + data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore + ], ) clip = data.Clip(start_time=0, end_time=0.5, recording=recording) clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) @@ -171,14 +168,16 @@ def test_encoding_decoding_roundtrip_recovers_object( assert len(recovered.tags) == 2 predicted_species_tag = next( - iter(t for t in recovered.tags if t.tag.term == species), None + iter(t for t in recovered.tags if t.tag.term == get_term("species")), + None, ) assert predicted_species_tag is not None assert predicted_species_tag.score == 1 assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus" predicted_order_tag = next( - iter(t for t in recovered.tags if t.tag.term.label == "order"), None + iter(t for t in recovered.tags if t.tag.term == get_term("order")), + None, ) assert predicted_order_tag is not None assert predicted_order_tag.score == 1 @@ -187,7 +186,6 @@ def test_encoding_decoding_roundtrip_recovers_object( def test_encoding_decoding_roundtrip_recovers_object_with_roi_override( build_from_config, - sample_term_registry, recording, ): yaml_content = """ @@ -217,10 +215,9 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override( _, preprocessor, labeller, postprocessor = build_from_config(yaml_content) geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000]) - species = get_term_from_key("species", term_registry=sample_term_registry) se1 = data.SoundEventAnnotation( sound_event=data.SoundEvent(recording=recording, geometry=geometry), - tags=[data.Tag(term=species, value="Myotis myotis")], + tags=[data.Tag(key="species", value="Myotis myotis")], # type: ignore ) clip = data.Clip(start_time=0, end_time=0.5, recording=recording) clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) @@ -257,14 +254,16 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override( assert len(recovered.tags) == 2 predicted_species_tag = next( - iter(t for t in recovered.tags if t.tag.term == species), None + iter(t for t in recovered.tags if t.tag.term == get_term("species")), + None, ) assert predicted_species_tag is not None assert predicted_species_tag.score == 1 assert predicted_species_tag.tag.value == "Myotis myotis" predicted_order_tag = next( - iter(t for t in recovered.tags if t.tag.term.label == "order"), None + iter(t for t in recovered.tags if t.tag.term == get_term("order")), + None, ) assert predicted_order_tag is not None assert predicted_order_tag.score == 1