mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Fixing small errors in tests
This commit is contained in:
parent
ece1a2073d
commit
257e1e01bf
@ -7,7 +7,20 @@ import pytest
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from soundevent import data, terms
|
from soundevent import data, terms
|
||||||
|
|
||||||
from batdetect2.targets import call_type
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
from batdetect2.targets import (
|
||||||
|
TargetConfig,
|
||||||
|
TermRegistry,
|
||||||
|
build_targets,
|
||||||
|
call_type,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.classes import ClassesConfig, TargetClass
|
||||||
|
from batdetect2.targets.filtering import FilterConfig, FilterRule
|
||||||
|
from batdetect2.targets.terms import TagInfo
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
|
from batdetect2.train.types import ClipLabeller
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -293,3 +306,80 @@ def create_annotation_project():
|
|||||||
)
|
)
|
||||||
|
|
||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_term_registry() -> TermRegistry:
|
||||||
|
"""Fixture for a sample TermRegistry."""
|
||||||
|
registry = TermRegistry()
|
||||||
|
registry.add_custom_term("class")
|
||||||
|
registry.add_custom_term("order")
|
||||||
|
registry.add_custom_term("species")
|
||||||
|
registry.add_custom_term("call_type")
|
||||||
|
registry.add_custom_term("quality")
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_preprocessor() -> PreprocessorProtocol:
|
||||||
|
return build_preprocessor()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def bat_tag() -> TagInfo:
|
||||||
|
return TagInfo(key="class", value="bat")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def noise_tag() -> TagInfo:
|
||||||
|
return TagInfo(key="class", value="noise")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def myomyo_tag() -> TagInfo:
|
||||||
|
return TagInfo(key="species", value="Myotis myotis")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pippip_tag() -> TagInfo:
|
||||||
|
return TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_target_config(
|
||||||
|
sample_term_registry: TermRegistry,
|
||||||
|
bat_tag: TagInfo,
|
||||||
|
noise_tag: TagInfo,
|
||||||
|
myomyo_tag: TagInfo,
|
||||||
|
pippip_tag: TagInfo,
|
||||||
|
) -> TargetConfig:
|
||||||
|
return TargetConfig(
|
||||||
|
filtering=FilterConfig(
|
||||||
|
rules=[FilterRule(match_type="exclude", tags=[noise_tag])]
|
||||||
|
),
|
||||||
|
classes=ClassesConfig(
|
||||||
|
classes=[
|
||||||
|
TargetClass(name="pippip", tags=[pippip_tag]),
|
||||||
|
TargetClass(name="myomyo", tags=[myomyo_tag]),
|
||||||
|
],
|
||||||
|
generic_class=[bat_tag],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_targets(
|
||||||
|
sample_target_config: TargetConfig,
|
||||||
|
sample_term_registry: TermRegistry,
|
||||||
|
) -> TargetProtocol:
|
||||||
|
return build_targets(
|
||||||
|
sample_target_config,
|
||||||
|
term_registry=sample_term_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_labeller(
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
) -> ClipLabeller:
|
||||||
|
return build_clip_labeler(sample_targets)
|
||||||
|
@ -13,9 +13,9 @@ from batdetect2.targets.classes import (
|
|||||||
_get_default_class_name,
|
_get_default_class_name,
|
||||||
_get_default_classes,
|
_get_default_classes,
|
||||||
_is_target_class,
|
_is_target_class,
|
||||||
|
build_generic_class_tags,
|
||||||
build_sound_event_decoder,
|
build_sound_event_decoder,
|
||||||
build_sound_event_encoder,
|
build_sound_event_encoder,
|
||||||
build_generic_class_tags,
|
|
||||||
get_class_names_from_config,
|
get_class_names_from_config,
|
||||||
load_classes_config,
|
load_classes_config,
|
||||||
load_decoder_from_config,
|
load_decoder_from_config,
|
||||||
|
@ -7,15 +7,14 @@ from soundevent import data
|
|||||||
from batdetect2.targets.filtering import (
|
from batdetect2.targets.filtering import (
|
||||||
FilterConfig,
|
FilterConfig,
|
||||||
FilterRule,
|
FilterRule,
|
||||||
build_sound_event_filter,
|
|
||||||
build_filter_from_rule,
|
build_filter_from_rule,
|
||||||
|
build_sound_event_filter,
|
||||||
contains_tags,
|
contains_tags,
|
||||||
does_not_have_tags,
|
does_not_have_tags,
|
||||||
equal_tags,
|
equal_tags,
|
||||||
has_any_tag,
|
has_any_tag,
|
||||||
load_filter_config,
|
load_filter_config,
|
||||||
load_filter_from_config,
|
load_filter_from_config,
|
||||||
merge_filters,
|
|
||||||
)
|
)
|
||||||
from batdetect2.targets.terms import TagInfo, generic_class
|
from batdetect2.targets.terms import TagInfo, generic_class
|
||||||
|
|
||||||
|
@ -9,11 +9,13 @@ from batdetect2.targets import (
|
|||||||
ReplaceRule,
|
ReplaceRule,
|
||||||
TagInfo,
|
TagInfo,
|
||||||
TransformConfig,
|
TransformConfig,
|
||||||
build_transform_from_rule,
|
|
||||||
build_transformation_from_config,
|
build_transformation_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.terms import TermRegistry
|
from batdetect2.targets.terms import TermRegistry
|
||||||
from batdetect2.targets.transform import DerivationRegistry
|
from batdetect2.targets.transform import (
|
||||||
|
DerivationRegistry,
|
||||||
|
build_transform_from_rule,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -4,7 +4,11 @@ import numpy as np
|
|||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||||
|
from batdetect2.targets.rois import ROIConfig
|
||||||
|
from batdetect2.targets.terms import TagInfo, TermRegistry
|
||||||
from batdetect2.train.labels import generate_heatmaps
|
from batdetect2.train.labels import generate_heatmaps
|
||||||
|
from tests.test_targets.test_transform import term_registry
|
||||||
|
|
||||||
recording = data.Recording(
|
recording = data.Recording(
|
||||||
samplerate=256_000,
|
samplerate=256_000,
|
||||||
@ -22,7 +26,9 @@ clip = data.Clip(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_generated_heatmaps_have_correct_dimensions():
|
def test_generated_heatmaps_have_correct_dimensions(
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
):
|
||||||
spec = xr.DataArray(
|
spec = xr.DataArray(
|
||||||
data=np.random.rand(100, 100),
|
data=np.random.rand(100, 100),
|
||||||
dims=["time", "frequency"],
|
dims=["time", "frequency"],
|
||||||
@ -49,8 +55,7 @@ def test_generated_heatmaps_have_correct_dimensions():
|
|||||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||||
clip_annotation.sound_events,
|
clip_annotation.sound_events,
|
||||||
spec,
|
spec,
|
||||||
class_names=["bat", "cat"],
|
targets=sample_targets,
|
||||||
encoder=lambda _: "bat",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(detection_heatmap, xr.DataArray)
|
assert isinstance(detection_heatmap, xr.DataArray)
|
||||||
@ -60,7 +65,10 @@ def test_generated_heatmaps_have_correct_dimensions():
|
|||||||
assert isinstance(class_heatmap, xr.DataArray)
|
assert isinstance(class_heatmap, xr.DataArray)
|
||||||
assert class_heatmap.shape == (2, 100, 100)
|
assert class_heatmap.shape == (2, 100, 100)
|
||||||
assert class_heatmap.dims == ("category", "time", "frequency")
|
assert class_heatmap.dims == ("category", "time", "frequency")
|
||||||
assert class_heatmap.coords["category"].values.tolist() == ["bat", "cat"]
|
assert class_heatmap.coords["category"].values.tolist() == [
|
||||||
|
"pippip",
|
||||||
|
"myomyo",
|
||||||
|
]
|
||||||
|
|
||||||
assert isinstance(size_heatmap, xr.DataArray)
|
assert isinstance(size_heatmap, xr.DataArray)
|
||||||
assert size_heatmap.shape == (2, 100, 100)
|
assert size_heatmap.shape == (2, 100, 100)
|
||||||
@ -71,7 +79,22 @@ def test_generated_heatmaps_have_correct_dimensions():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_generated_heatmap_are_non_zero_at_correct_positions():
|
def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||||
|
sample_target_config: TargetConfig,
|
||||||
|
sample_term_registry: TermRegistry,
|
||||||
|
pippip_tag: TagInfo,
|
||||||
|
):
|
||||||
|
config = sample_target_config.model_copy(
|
||||||
|
update=dict(
|
||||||
|
roi=ROIConfig(
|
||||||
|
time_scale=1,
|
||||||
|
frequency_scale=1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
targets = build_targets(config, term_registry=sample_term_registry)
|
||||||
|
|
||||||
spec = xr.DataArray(
|
spec = xr.DataArray(
|
||||||
data=np.random.rand(100, 100),
|
data=np.random.rand(100, 100),
|
||||||
dims=["time", "frequency"],
|
dims=["time", "frequency"],
|
||||||
@ -91,6 +114,12 @@ def test_generated_heatmap_are_non_zero_at_correct_positions():
|
|||||||
coordinates=[10, 10, 20, 20],
|
coordinates=[10, 10, 20, 20],
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
tags=[
|
||||||
|
data.Tag(
|
||||||
|
term=sample_term_registry[pippip_tag.key],
|
||||||
|
value=pippip_tag.value,
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -98,13 +127,10 @@ def test_generated_heatmap_are_non_zero_at_correct_positions():
|
|||||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||||
clip_annotation.sound_events,
|
clip_annotation.sound_events,
|
||||||
spec,
|
spec,
|
||||||
class_names=["bat", "cat"],
|
targets=targets,
|
||||||
encoder=lambda _: "bat",
|
|
||||||
time_scale=1,
|
|
||||||
frequency_scale=1,
|
|
||||||
)
|
)
|
||||||
assert size_heatmap.sel(time=10, frequency=10, dimension="width") == 10
|
assert size_heatmap.sel(time=10, frequency=10, dimension="width") == 10
|
||||||
assert size_heatmap.sel(time=10, frequency=10, dimension="height") == 10
|
assert size_heatmap.sel(time=10, frequency=10, dimension="height") == 10
|
||||||
assert class_heatmap.sel(time=10, frequency=10, category="bat") == 1.0
|
assert class_heatmap.sel(time=10, frequency=10, category="pippip") == 1.0
|
||||||
assert class_heatmap.sel(time=10, frequency=10, category="cat") == 0.0
|
assert class_heatmap.sel(time=10, frequency=10, category="myomyo") == 0.0
|
||||||
assert detection_heatmap.sel(time=10, frequency=10) == 1.0
|
assert detection_heatmap.sel(time=10, frequency=10) == 1.0
|
||||||
|
Loading…
Reference in New Issue
Block a user