Fixing small errors in tests

This commit is contained in:
mbsantiago 2025-04-22 08:51:21 +01:00
parent ece1a2073d
commit 257e1e01bf
5 changed files with 134 additions and 17 deletions

View File

@ -7,7 +7,20 @@ import pytest
import soundfile as sf
from soundevent import data, terms
from batdetect2.targets import call_type
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import (
TargetConfig,
TermRegistry,
build_targets,
call_type,
)
from batdetect2.targets.classes import ClassesConfig, TargetClass
from batdetect2.targets.filtering import FilterConfig, FilterRule
from batdetect2.targets.terms import TagInfo
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.types import ClipLabeller
@pytest.fixture
@ -293,3 +306,80 @@ def create_annotation_project():
)
return factory
@pytest.fixture
def sample_term_registry() -> TermRegistry:
"""Fixture for a sample TermRegistry."""
registry = TermRegistry()
registry.add_custom_term("class")
registry.add_custom_term("order")
registry.add_custom_term("species")
registry.add_custom_term("call_type")
registry.add_custom_term("quality")
return registry
@pytest.fixture
def sample_preprocessor() -> PreprocessorProtocol:
return build_preprocessor()
@pytest.fixture
def bat_tag() -> TagInfo:
return TagInfo(key="class", value="bat")
@pytest.fixture
def noise_tag() -> TagInfo:
return TagInfo(key="class", value="noise")
@pytest.fixture
def myomyo_tag() -> TagInfo:
return TagInfo(key="species", value="Myotis myotis")
@pytest.fixture
def pippip_tag() -> TagInfo:
return TagInfo(key="species", value="Pipistrellus pipistrellus")
@pytest.fixture
def sample_target_config(
sample_term_registry: TermRegistry,
bat_tag: TagInfo,
noise_tag: TagInfo,
myomyo_tag: TagInfo,
pippip_tag: TagInfo,
) -> TargetConfig:
return TargetConfig(
filtering=FilterConfig(
rules=[FilterRule(match_type="exclude", tags=[noise_tag])]
),
classes=ClassesConfig(
classes=[
TargetClass(name="pippip", tags=[pippip_tag]),
TargetClass(name="myomyo", tags=[myomyo_tag]),
],
generic_class=[bat_tag],
),
)
@pytest.fixture
def sample_targets(
sample_target_config: TargetConfig,
sample_term_registry: TermRegistry,
) -> TargetProtocol:
return build_targets(
sample_target_config,
term_registry=sample_term_registry,
)
@pytest.fixture
def sample_labeller(
sample_targets: TargetProtocol,
) -> ClipLabeller:
return build_clip_labeler(sample_targets)

View File

@ -13,9 +13,9 @@ from batdetect2.targets.classes import (
_get_default_class_name,
_get_default_classes,
_is_target_class,
build_generic_class_tags,
build_sound_event_decoder,
build_sound_event_encoder,
build_generic_class_tags,
get_class_names_from_config,
load_classes_config,
load_decoder_from_config,

View File

@ -7,15 +7,14 @@ from soundevent import data
from batdetect2.targets.filtering import (
FilterConfig,
FilterRule,
build_sound_event_filter,
build_filter_from_rule,
build_sound_event_filter,
contains_tags,
does_not_have_tags,
equal_tags,
has_any_tag,
load_filter_config,
load_filter_from_config,
merge_filters,
)
from batdetect2.targets.terms import TagInfo, generic_class

View File

@ -9,11 +9,13 @@ from batdetect2.targets import (
ReplaceRule,
TagInfo,
TransformConfig,
build_transform_from_rule,
build_transformation_from_config,
)
from batdetect2.targets.terms import TermRegistry
from batdetect2.targets.transform import DerivationRegistry
from batdetect2.targets.transform import (
DerivationRegistry,
build_transform_from_rule,
)
@pytest.fixture

View File

@ -4,7 +4,11 @@ import numpy as np
import xarray as xr
from soundevent import data
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
from batdetect2.targets.rois import ROIConfig
from batdetect2.targets.terms import TagInfo, TermRegistry
from batdetect2.train.labels import generate_heatmaps
from tests.test_targets.test_transform import term_registry
recording = data.Recording(
samplerate=256_000,
@ -22,7 +26,9 @@ clip = data.Clip(
)
def test_generated_heatmaps_have_correct_dimensions():
def test_generated_heatmaps_have_correct_dimensions(
sample_targets: TargetProtocol,
):
spec = xr.DataArray(
data=np.random.rand(100, 100),
dims=["time", "frequency"],
@ -49,8 +55,7 @@ def test_generated_heatmaps_have_correct_dimensions():
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
clip_annotation.sound_events,
spec,
class_names=["bat", "cat"],
encoder=lambda _: "bat",
targets=sample_targets,
)
assert isinstance(detection_heatmap, xr.DataArray)
@ -60,7 +65,10 @@ def test_generated_heatmaps_have_correct_dimensions():
assert isinstance(class_heatmap, xr.DataArray)
assert class_heatmap.shape == (2, 100, 100)
assert class_heatmap.dims == ("category", "time", "frequency")
assert class_heatmap.coords["category"].values.tolist() == ["bat", "cat"]
assert class_heatmap.coords["category"].values.tolist() == [
"pippip",
"myomyo",
]
assert isinstance(size_heatmap, xr.DataArray)
assert size_heatmap.shape == (2, 100, 100)
@ -71,7 +79,22 @@ def test_generated_heatmaps_have_correct_dimensions():
]
def test_generated_heatmap_are_non_zero_at_correct_positions():
def test_generated_heatmap_are_non_zero_at_correct_positions(
sample_target_config: TargetConfig,
sample_term_registry: TermRegistry,
pippip_tag: TagInfo,
):
config = sample_target_config.model_copy(
update=dict(
roi=ROIConfig(
time_scale=1,
frequency_scale=1,
)
)
)
targets = build_targets(config, term_registry=sample_term_registry)
spec = xr.DataArray(
data=np.random.rand(100, 100),
dims=["time", "frequency"],
@ -91,6 +114,12 @@ def test_generated_heatmap_are_non_zero_at_correct_positions():
coordinates=[10, 10, 20, 20],
),
),
tags=[
data.Tag(
term=sample_term_registry[pippip_tag.key],
value=pippip_tag.value,
)
],
)
],
)
@ -98,13 +127,10 @@ def test_generated_heatmap_are_non_zero_at_correct_positions():
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
clip_annotation.sound_events,
spec,
class_names=["bat", "cat"],
encoder=lambda _: "bat",
time_scale=1,
frequency_scale=1,
targets=targets,
)
assert size_heatmap.sel(time=10, frequency=10, dimension="width") == 10
assert size_heatmap.sel(time=10, frequency=10, dimension="height") == 10
assert class_heatmap.sel(time=10, frequency=10, category="bat") == 1.0
assert class_heatmap.sel(time=10, frequency=10, category="cat") == 0.0
assert class_heatmap.sel(time=10, frequency=10, category="pippip") == 1.0
assert class_heatmap.sel(time=10, frequency=10, category="myomyo") == 0.0
assert detection_heatmap.sel(time=10, frequency=10) == 1.0