mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Fixing small errors in tests
This commit is contained in:
parent
ece1a2073d
commit
257e1e01bf
@ -7,7 +7,20 @@ import pytest
|
||||
import soundfile as sf
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.targets import call_type
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets import (
|
||||
TargetConfig,
|
||||
TermRegistry,
|
||||
build_targets,
|
||||
call_type,
|
||||
)
|
||||
from batdetect2.targets.classes import ClassesConfig, TargetClass
|
||||
from batdetect2.targets.filtering import FilterConfig, FilterRule
|
||||
from batdetect2.targets.terms import TagInfo
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.types import ClipLabeller
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -293,3 +306,80 @@ def create_annotation_project():
|
||||
)
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_term_registry() -> TermRegistry:
|
||||
"""Fixture for a sample TermRegistry."""
|
||||
registry = TermRegistry()
|
||||
registry.add_custom_term("class")
|
||||
registry.add_custom_term("order")
|
||||
registry.add_custom_term("species")
|
||||
registry.add_custom_term("call_type")
|
||||
registry.add_custom_term("quality")
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_preprocessor() -> PreprocessorProtocol:
|
||||
return build_preprocessor()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bat_tag() -> TagInfo:
|
||||
return TagInfo(key="class", value="bat")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def noise_tag() -> TagInfo:
|
||||
return TagInfo(key="class", value="noise")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def myomyo_tag() -> TagInfo:
|
||||
return TagInfo(key="species", value="Myotis myotis")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pippip_tag() -> TagInfo:
|
||||
return TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_target_config(
|
||||
sample_term_registry: TermRegistry,
|
||||
bat_tag: TagInfo,
|
||||
noise_tag: TagInfo,
|
||||
myomyo_tag: TagInfo,
|
||||
pippip_tag: TagInfo,
|
||||
) -> TargetConfig:
|
||||
return TargetConfig(
|
||||
filtering=FilterConfig(
|
||||
rules=[FilterRule(match_type="exclude", tags=[noise_tag])]
|
||||
),
|
||||
classes=ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(name="pippip", tags=[pippip_tag]),
|
||||
TargetClass(name="myomyo", tags=[myomyo_tag]),
|
||||
],
|
||||
generic_class=[bat_tag],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_targets(
|
||||
sample_target_config: TargetConfig,
|
||||
sample_term_registry: TermRegistry,
|
||||
) -> TargetProtocol:
|
||||
return build_targets(
|
||||
sample_target_config,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_labeller(
|
||||
sample_targets: TargetProtocol,
|
||||
) -> ClipLabeller:
|
||||
return build_clip_labeler(sample_targets)
|
||||
|
@ -13,9 +13,9 @@ from batdetect2.targets.classes import (
|
||||
_get_default_class_name,
|
||||
_get_default_classes,
|
||||
_is_target_class,
|
||||
build_generic_class_tags,
|
||||
build_sound_event_decoder,
|
||||
build_sound_event_encoder,
|
||||
build_generic_class_tags,
|
||||
get_class_names_from_config,
|
||||
load_classes_config,
|
||||
load_decoder_from_config,
|
||||
|
@ -7,15 +7,14 @@ from soundevent import data
|
||||
from batdetect2.targets.filtering import (
|
||||
FilterConfig,
|
||||
FilterRule,
|
||||
build_sound_event_filter,
|
||||
build_filter_from_rule,
|
||||
build_sound_event_filter,
|
||||
contains_tags,
|
||||
does_not_have_tags,
|
||||
equal_tags,
|
||||
has_any_tag,
|
||||
load_filter_config,
|
||||
load_filter_from_config,
|
||||
merge_filters,
|
||||
)
|
||||
from batdetect2.targets.terms import TagInfo, generic_class
|
||||
|
||||
|
@ -9,11 +9,13 @@ from batdetect2.targets import (
|
||||
ReplaceRule,
|
||||
TagInfo,
|
||||
TransformConfig,
|
||||
build_transform_from_rule,
|
||||
build_transformation_from_config,
|
||||
)
|
||||
from batdetect2.targets.terms import TermRegistry
|
||||
from batdetect2.targets.transform import DerivationRegistry
|
||||
from batdetect2.targets.transform import (
|
||||
DerivationRegistry,
|
||||
build_transform_from_rule,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -4,7 +4,11 @@ import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||
from batdetect2.targets.rois import ROIConfig
|
||||
from batdetect2.targets.terms import TagInfo, TermRegistry
|
||||
from batdetect2.train.labels import generate_heatmaps
|
||||
from tests.test_targets.test_transform import term_registry
|
||||
|
||||
recording = data.Recording(
|
||||
samplerate=256_000,
|
||||
@ -22,7 +26,9 @@ clip = data.Clip(
|
||||
)
|
||||
|
||||
|
||||
def test_generated_heatmaps_have_correct_dimensions():
|
||||
def test_generated_heatmaps_have_correct_dimensions(
|
||||
sample_targets: TargetProtocol,
|
||||
):
|
||||
spec = xr.DataArray(
|
||||
data=np.random.rand(100, 100),
|
||||
dims=["time", "frequency"],
|
||||
@ -49,8 +55,7 @@ def test_generated_heatmaps_have_correct_dimensions():
|
||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||
clip_annotation.sound_events,
|
||||
spec,
|
||||
class_names=["bat", "cat"],
|
||||
encoder=lambda _: "bat",
|
||||
targets=sample_targets,
|
||||
)
|
||||
|
||||
assert isinstance(detection_heatmap, xr.DataArray)
|
||||
@ -60,7 +65,10 @@ def test_generated_heatmaps_have_correct_dimensions():
|
||||
assert isinstance(class_heatmap, xr.DataArray)
|
||||
assert class_heatmap.shape == (2, 100, 100)
|
||||
assert class_heatmap.dims == ("category", "time", "frequency")
|
||||
assert class_heatmap.coords["category"].values.tolist() == ["bat", "cat"]
|
||||
assert class_heatmap.coords["category"].values.tolist() == [
|
||||
"pippip",
|
||||
"myomyo",
|
||||
]
|
||||
|
||||
assert isinstance(size_heatmap, xr.DataArray)
|
||||
assert size_heatmap.shape == (2, 100, 100)
|
||||
@ -71,7 +79,22 @@ def test_generated_heatmaps_have_correct_dimensions():
|
||||
]
|
||||
|
||||
|
||||
def test_generated_heatmap_are_non_zero_at_correct_positions():
|
||||
def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
sample_target_config: TargetConfig,
|
||||
sample_term_registry: TermRegistry,
|
||||
pippip_tag: TagInfo,
|
||||
):
|
||||
config = sample_target_config.model_copy(
|
||||
update=dict(
|
||||
roi=ROIConfig(
|
||||
time_scale=1,
|
||||
frequency_scale=1,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
targets = build_targets(config, term_registry=sample_term_registry)
|
||||
|
||||
spec = xr.DataArray(
|
||||
data=np.random.rand(100, 100),
|
||||
dims=["time", "frequency"],
|
||||
@ -91,6 +114,12 @@ def test_generated_heatmap_are_non_zero_at_correct_positions():
|
||||
coordinates=[10, 10, 20, 20],
|
||||
),
|
||||
),
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=sample_term_registry[pippip_tag.key],
|
||||
value=pippip_tag.value,
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -98,13 +127,10 @@ def test_generated_heatmap_are_non_zero_at_correct_positions():
|
||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||
clip_annotation.sound_events,
|
||||
spec,
|
||||
class_names=["bat", "cat"],
|
||||
encoder=lambda _: "bat",
|
||||
time_scale=1,
|
||||
frequency_scale=1,
|
||||
targets=targets,
|
||||
)
|
||||
assert size_heatmap.sel(time=10, frequency=10, dimension="width") == 10
|
||||
assert size_heatmap.sel(time=10, frequency=10, dimension="height") == 10
|
||||
assert class_heatmap.sel(time=10, frequency=10, category="bat") == 1.0
|
||||
assert class_heatmap.sel(time=10, frequency=10, category="cat") == 0.0
|
||||
assert class_heatmap.sel(time=10, frequency=10, category="pippip") == 1.0
|
||||
assert class_heatmap.sel(time=10, frequency=10, category="myomyo") == 0.0
|
||||
assert detection_heatmap.sel(time=10, frequency=10) == 1.0
|
||||
|
Loading…
Reference in New Issue
Block a user