diff --git a/tests/conftest.py b/tests/conftest.py index 05fdcbb..791950c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,20 @@ import pytest import soundfile as sf from soundevent import data, terms -from batdetect2.targets import call_type +from batdetect2.preprocess import build_preprocessor +from batdetect2.preprocess.types import PreprocessorProtocol +from batdetect2.targets import ( + TargetConfig, + TermRegistry, + build_targets, + call_type, +) +from batdetect2.targets.classes import ClassesConfig, TargetClass +from batdetect2.targets.filtering import FilterConfig, FilterRule +from batdetect2.targets.terms import TagInfo +from batdetect2.targets.types import TargetProtocol +from batdetect2.train.labels import build_clip_labeler +from batdetect2.train.types import ClipLabeller @pytest.fixture @@ -293,3 +306,80 @@ def create_annotation_project(): ) return factory + + +@pytest.fixture +def sample_term_registry() -> TermRegistry: + """Fixture for a sample TermRegistry.""" + registry = TermRegistry() + registry.add_custom_term("class") + registry.add_custom_term("order") + registry.add_custom_term("species") + registry.add_custom_term("call_type") + registry.add_custom_term("quality") + return registry + + +@pytest.fixture +def sample_preprocessor() -> PreprocessorProtocol: + return build_preprocessor() + + +@pytest.fixture +def bat_tag() -> TagInfo: + return TagInfo(key="class", value="bat") + + +@pytest.fixture +def noise_tag() -> TagInfo: + return TagInfo(key="class", value="noise") + + +@pytest.fixture +def myomyo_tag() -> TagInfo: + return TagInfo(key="species", value="Myotis myotis") + + +@pytest.fixture +def pippip_tag() -> TagInfo: + return TagInfo(key="species", value="Pipistrellus pipistrellus") + + +@pytest.fixture +def sample_target_config( + sample_term_registry: TermRegistry, + bat_tag: TagInfo, + noise_tag: TagInfo, + myomyo_tag: TagInfo, + pippip_tag: TagInfo, +) -> TargetConfig: + return TargetConfig( + filtering=FilterConfig( + rules=[FilterRule(match_type="exclude", tags=[noise_tag])] + ), + classes=ClassesConfig( + classes=[ + TargetClass(name="pippip", tags=[pippip_tag]), + TargetClass(name="myomyo", tags=[myomyo_tag]), + ], + generic_class=[bat_tag], + ), + ) + + +@pytest.fixture +def sample_targets( + sample_target_config: TargetConfig, + sample_term_registry: TermRegistry, +) -> TargetProtocol: + return build_targets( + sample_target_config, + term_registry=sample_term_registry, + ) + + +@pytest.fixture +def sample_labeller( + sample_targets: TargetProtocol, +) -> ClipLabeller: + return build_clip_labeler(sample_targets) diff --git a/tests/test_targets/test_classes.py b/tests/test_targets/test_classes.py index c75c4b2..bda228b 100644 --- a/tests/test_targets/test_classes.py +++ b/tests/test_targets/test_classes.py @@ -13,9 +13,9 @@ from batdetect2.targets.classes import ( _get_default_class_name, _get_default_classes, _is_target_class, + build_generic_class_tags, build_sound_event_decoder, build_sound_event_encoder, - build_generic_class_tags, get_class_names_from_config, load_classes_config, load_decoder_from_config, diff --git a/tests/test_targets/test_filtering.py b/tests/test_targets/test_filtering.py index 426266c..069f42c 100644 --- a/tests/test_targets/test_filtering.py +++ b/tests/test_targets/test_filtering.py @@ -7,15 +7,14 @@ from soundevent import data from batdetect2.targets.filtering import ( FilterConfig, FilterRule, - build_sound_event_filter, build_filter_from_rule, + build_sound_event_filter, contains_tags, does_not_have_tags, equal_tags, has_any_tag, load_filter_config, load_filter_from_config, - merge_filters, ) from batdetect2.targets.terms import TagInfo, generic_class diff --git a/tests/test_targets/test_transform.py b/tests/test_targets/test_transform.py index c55ec07..ababd8d 100644 --- a/tests/test_targets/test_transform.py +++ b/tests/test_targets/test_transform.py @@ -9,11 +9,13 @@ from batdetect2.targets import ( ReplaceRule, TagInfo, TransformConfig, - build_transform_from_rule, build_transformation_from_config, ) from batdetect2.targets.terms import TermRegistry -from batdetect2.targets.transform import DerivationRegistry +from batdetect2.targets.transform import ( + DerivationRegistry, + build_transform_from_rule, +) @pytest.fixture diff --git a/tests/test_train/test_labels.py b/tests/test_train/test_labels.py index e86aae7..9f55242 100644 --- a/tests/test_train/test_labels.py +++ b/tests/test_train/test_labels.py @@ -4,7 +4,11 @@ import numpy as np import xarray as xr from soundevent import data +from batdetect2.targets import TargetConfig, TargetProtocol, build_targets +from batdetect2.targets.rois import ROIConfig +from batdetect2.targets.terms import TagInfo, TermRegistry from batdetect2.train.labels import generate_heatmaps +from tests.test_targets.test_transform import term_registry recording = data.Recording( samplerate=256_000, @@ -22,7 +26,9 @@ clip = data.Clip( ) -def test_generated_heatmaps_have_correct_dimensions(): +def test_generated_heatmaps_have_correct_dimensions( + sample_targets: TargetProtocol, +): spec = xr.DataArray( data=np.random.rand(100, 100), dims=["time", "frequency"], @@ -49,8 +55,7 @@ def test_generated_heatmaps_have_correct_dimensions(): detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( clip_annotation.sound_events, spec, - class_names=["bat", "cat"], - encoder=lambda _: "bat", + targets=sample_targets, ) assert isinstance(detection_heatmap, xr.DataArray) @@ -60,7 +65,10 @@ def test_generated_heatmaps_have_correct_dimensions(): assert isinstance(class_heatmap, xr.DataArray) assert class_heatmap.shape == (2, 100, 100) assert class_heatmap.dims == ("category", "time", "frequency") - assert class_heatmap.coords["category"].values.tolist() == ["bat", "cat"] + assert class_heatmap.coords["category"].values.tolist() == [ + "pippip", + "myomyo", + ] assert isinstance(size_heatmap, xr.DataArray) assert size_heatmap.shape == (2, 100, 100) @@ -71,7 +79,22 @@ def test_generated_heatmaps_have_correct_dimensions(): ] -def test_generated_heatmap_are_non_zero_at_correct_positions(): +def test_generated_heatmap_are_non_zero_at_correct_positions( + sample_target_config: TargetConfig, + sample_term_registry: TermRegistry, + pippip_tag: TagInfo, +): + config = sample_target_config.model_copy( + update=dict( + roi=ROIConfig( + time_scale=1, + frequency_scale=1, + ) + ) + ) + + targets = build_targets(config, term_registry=sample_term_registry) + spec = xr.DataArray( data=np.random.rand(100, 100), dims=["time", "frequency"], @@ -91,6 +114,12 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(): coordinates=[10, 10, 20, 20], ), ), + tags=[ + data.Tag( + term=sample_term_registry[pippip_tag.key], + value=pippip_tag.value, + ) + ], ) ], ) @@ -98,13 +127,10 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(): detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( clip_annotation.sound_events, spec, - class_names=["bat", "cat"], - encoder=lambda _: "bat", - time_scale=1, - frequency_scale=1, + targets=targets, ) assert size_heatmap.sel(time=10, frequency=10, dimension="width") == 10 assert size_heatmap.sel(time=10, frequency=10, dimension="height") == 10 - assert class_heatmap.sel(time=10, frequency=10, category="bat") == 1.0 - assert class_heatmap.sel(time=10, frequency=10, category="cat") == 0.0 + assert class_heatmap.sel(time=10, frequency=10, category="pippip") == 1.0 + assert class_heatmap.sel(time=10, frequency=10, category="myomyo") == 0.0 assert detection_heatmap.sel(time=10, frequency=10) == 1.0