mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Fix tests
This commit is contained in:
parent
9c8b8fb200
commit
2913fa59a4
@ -7,6 +7,8 @@ import pytest
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from soundevent import data, terms
|
from soundevent import data, terms
|
||||||
|
|
||||||
|
from batdetect2.data import DatasetConfig, load_dataset
|
||||||
|
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
@ -383,3 +385,27 @@ def sample_labeller(
|
|||||||
sample_targets: TargetProtocol,
|
sample_targets: TargetProtocol,
|
||||||
) -> ClipLabeller:
|
) -> ClipLabeller:
|
||||||
return build_clip_labeler(sample_targets)
|
return build_clip_labeler(sample_targets)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def example_dataset(example_data_dir: Path) -> DatasetConfig:
|
||||||
|
return DatasetConfig(
|
||||||
|
name="test dataset",
|
||||||
|
description="test dataset",
|
||||||
|
sources=[
|
||||||
|
BatDetect2FilesAnnotations(
|
||||||
|
name="example annotations",
|
||||||
|
audio_dir=example_data_dir / "audio",
|
||||||
|
annotations_dir=example_data_dir / "anns",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def example_annotations(
|
||||||
|
example_dataset: DatasetConfig,
|
||||||
|
) -> List[data.ClipAnnotation]:
|
||||||
|
annotations = load_dataset(example_dataset)
|
||||||
|
assert len(annotations) == 3
|
||||||
|
return annotations
|
||||||
|
@ -254,11 +254,11 @@ class TestLoadBatDetect2Files:
|
|||||||
assert clip_ann.clip.recording.duration == 5.0
|
assert clip_ann.clip.recording.duration == 5.0
|
||||||
assert len(clip_ann.sound_events) == 1
|
assert len(clip_ann.sound_events) == 1
|
||||||
assert clip_ann.notes[0].message == "Standard notes."
|
assert clip_ann.notes[0].message == "Standard notes."
|
||||||
clip_tag = data.find_tag(clip_ann.tags, "class")
|
clip_tag = data.find_tag(clip_ann.tags, "Class")
|
||||||
assert clip_tag is not None
|
assert clip_tag is not None
|
||||||
assert clip_tag.value == "Myotis"
|
assert clip_tag.value == "Myotis"
|
||||||
|
|
||||||
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "class")
|
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class")
|
||||||
assert recording_tag is not None
|
assert recording_tag is not None
|
||||||
assert recording_tag.value == "Myotis"
|
assert recording_tag.value == "Myotis"
|
||||||
|
|
||||||
@ -271,15 +271,15 @@ class TestLoadBatDetect2Files:
|
|||||||
40000,
|
40000,
|
||||||
]
|
]
|
||||||
|
|
||||||
se_class_tag = data.find_tag(se_ann.tags, "class")
|
se_class_tag = data.find_tag(se_ann.tags, "Class")
|
||||||
assert se_class_tag is not None
|
assert se_class_tag is not None
|
||||||
assert se_class_tag.value == "Myotis"
|
assert se_class_tag.value == "Myotis"
|
||||||
|
|
||||||
se_event_tag = data.find_tag(se_ann.tags, "event")
|
se_event_tag = data.find_tag(se_ann.tags, "Call Type")
|
||||||
assert se_event_tag is not None
|
assert se_event_tag is not None
|
||||||
assert se_event_tag.value == "Echolocation"
|
assert se_event_tag.value == "Echolocation"
|
||||||
|
|
||||||
se_individual_tag = data.find_tag(se_ann.tags, "individual")
|
se_individual_tag = data.find_tag(se_ann.tags, "Individual")
|
||||||
assert se_individual_tag is not None
|
assert se_individual_tag is not None
|
||||||
assert se_individual_tag.value == "0"
|
assert se_individual_tag.value == "0"
|
||||||
|
|
||||||
@ -439,7 +439,7 @@ class TestLoadBatDetect2Merged:
|
|||||||
assert clip_ann.clip.recording.duration == 5.0
|
assert clip_ann.clip.recording.duration == 5.0
|
||||||
assert len(clip_ann.sound_events) == 1
|
assert len(clip_ann.sound_events) == 1
|
||||||
|
|
||||||
clip_class_tag = data.find_tag(clip_ann.tags, "class")
|
clip_class_tag = data.find_tag(clip_ann.tags, "Class")
|
||||||
assert clip_class_tag is not None
|
assert clip_class_tag is not None
|
||||||
assert clip_class_tag.value == "Myotis"
|
assert clip_class_tag.value == "Myotis"
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ def sample_detection_dataset() -> xr.Dataset:
|
|||||||
expected_freqs = np.array([300, 200])
|
expected_freqs = np.array([300, 200])
|
||||||
detection_coords = {
|
detection_coords = {
|
||||||
"time": ("detection", expected_times),
|
"time": ("detection", expected_times),
|
||||||
"freq": ("detection", expected_freqs),
|
"frequency": ("detection", expected_freqs),
|
||||||
}
|
}
|
||||||
|
|
||||||
scores_data = np.array([0.9, 0.8], dtype=np.float64)
|
scores_data = np.array([0.9, 0.8], dtype=np.float64)
|
||||||
@ -106,7 +106,7 @@ def sample_detection_dataset() -> xr.Dataset:
|
|||||||
scores_data,
|
scores_data,
|
||||||
coords=detection_coords,
|
coords=detection_coords,
|
||||||
dims=["detection"],
|
dims=["detection"],
|
||||||
name="scores",
|
name="score",
|
||||||
)
|
)
|
||||||
|
|
||||||
dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32)
|
dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32)
|
||||||
@ -183,7 +183,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
|||||||
)
|
)
|
||||||
return xr.Dataset(
|
return xr.Dataset(
|
||||||
{
|
{
|
||||||
"scores": scores,
|
"score": scores,
|
||||||
"dimensions": dimensions,
|
"dimensions": dimensions,
|
||||||
"classes": classes,
|
"classes": classes,
|
||||||
"features": features,
|
"features": features,
|
||||||
@ -206,10 +206,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
)
|
)
|
||||||
pred1 = RawPrediction(
|
pred1 = RawPrediction(
|
||||||
detection_score=0.9,
|
detection_score=0.9,
|
||||||
start_time=20 - 7 / 2,
|
geometry=data.BoundingBox(
|
||||||
end_time=20 + 7 / 2,
|
coordinates=[
|
||||||
low_freq=300 - 16 / 2,
|
20 - 7 / 2,
|
||||||
high_freq=300 + 16 / 2,
|
300 - 16 / 2,
|
||||||
|
20 + 7 / 2,
|
||||||
|
300 + 16 / 2,
|
||||||
|
]
|
||||||
|
),
|
||||||
class_scores=pred1_classes,
|
class_scores=pred1_classes,
|
||||||
features=pred1_features,
|
features=pred1_features,
|
||||||
)
|
)
|
||||||
@ -224,10 +228,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
)
|
)
|
||||||
pred2 = RawPrediction(
|
pred2 = RawPrediction(
|
||||||
detection_score=0.8,
|
detection_score=0.8,
|
||||||
start_time=10 - 3 / 2,
|
geometry=data.BoundingBox(
|
||||||
end_time=10 + 3 / 2,
|
coordinates=[
|
||||||
low_freq=200 - 12 / 2,
|
10 - 3 / 2,
|
||||||
high_freq=200 + 12 / 2,
|
200 - 12 / 2,
|
||||||
|
10 + 3 / 2,
|
||||||
|
200 + 12 / 2,
|
||||||
|
]
|
||||||
|
),
|
||||||
class_scores=pred2_classes,
|
class_scores=pred2_classes,
|
||||||
features=pred2_features,
|
features=pred2_features,
|
||||||
)
|
)
|
||||||
@ -242,10 +250,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
)
|
)
|
||||||
pred3 = RawPrediction(
|
pred3 = RawPrediction(
|
||||||
detection_score=0.15,
|
detection_score=0.15,
|
||||||
start_time=5.0,
|
geometry=data.BoundingBox(
|
||||||
end_time=6.0,
|
coordinates=[
|
||||||
low_freq=50.0,
|
5.0,
|
||||||
high_freq=60.0,
|
50.0,
|
||||||
|
6.0,
|
||||||
|
60.0,
|
||||||
|
]
|
||||||
|
),
|
||||||
class_scores=pred3_classes,
|
class_scores=pred3_classes,
|
||||||
features=pred3_features,
|
features=pred3_features,
|
||||||
)
|
)
|
||||||
@ -267,10 +279,12 @@ def test_convert_xr_dataset_basic(
|
|||||||
assert isinstance(pred1, RawPrediction)
|
assert isinstance(pred1, RawPrediction)
|
||||||
assert pred1.detection_score == 0.9
|
assert pred1.detection_score == 0.9
|
||||||
|
|
||||||
assert pred1.start_time == 20 - 7 / 2
|
assert pred1.geometry.coordinates == [
|
||||||
assert pred1.end_time == 20 + 7 / 2
|
20 - 7 / 2,
|
||||||
assert pred1.low_freq == 300 - 16 / 2
|
300 - 16 / 2,
|
||||||
assert pred1.high_freq == 300 + 16 / 2
|
20 + 7 / 2,
|
||||||
|
300 + 16 / 2,
|
||||||
|
]
|
||||||
xr.testing.assert_allclose(
|
xr.testing.assert_allclose(
|
||||||
pred1.class_scores,
|
pred1.class_scores,
|
||||||
sample_detection_dataset["classes"].sel(detection=0),
|
sample_detection_dataset["classes"].sel(detection=0),
|
||||||
@ -283,10 +297,12 @@ def test_convert_xr_dataset_basic(
|
|||||||
assert isinstance(pred2, RawPrediction)
|
assert isinstance(pred2, RawPrediction)
|
||||||
assert pred2.detection_score == 0.8
|
assert pred2.detection_score == 0.8
|
||||||
|
|
||||||
assert pred2.start_time == 10 - 3 / 2
|
assert pred2.geometry.coordinates == [
|
||||||
assert pred2.end_time == 10 + 3 / 2
|
10 - 3 / 2,
|
||||||
assert pred2.low_freq == 200 - 12 / 2
|
200 - 12 / 2,
|
||||||
assert pred2.high_freq == 200 + 12 / 2
|
10 + 3 / 2,
|
||||||
|
200 + 12 / 2,
|
||||||
|
]
|
||||||
xr.testing.assert_allclose(
|
xr.testing.assert_allclose(
|
||||||
pred2.class_scores,
|
pred2.class_scores,
|
||||||
sample_detection_dataset["classes"].sel(detection=1),
|
sample_detection_dataset["classes"].sel(detection=1),
|
||||||
@ -331,15 +347,7 @@ def test_convert_raw_to_sound_event_basic(
|
|||||||
assert isinstance(se, data.SoundEvent)
|
assert isinstance(se, data.SoundEvent)
|
||||||
assert se.recording == sample_recording
|
assert se.recording == sample_recording
|
||||||
assert isinstance(se.geometry, data.BoundingBox)
|
assert isinstance(se.geometry, data.BoundingBox)
|
||||||
np.testing.assert_allclose(
|
assert se.geometry == raw_pred.geometry
|
||||||
se.geometry.coordinates,
|
|
||||||
[
|
|
||||||
raw_pred.start_time,
|
|
||||||
raw_pred.low_freq,
|
|
||||||
raw_pred.end_time,
|
|
||||||
raw_pred.high_freq,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
assert len(se.features) == len(raw_pred.features)
|
assert len(se.features) == len(raw_pred.features)
|
||||||
|
|
||||||
feat_dict = {f.term.name: f.value for f in se.features}
|
feat_dict = {f.term.name: f.value for f in se.features}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import math
|
import math
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@ -307,7 +307,7 @@ def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
|
|||||||
def test_resize_spectrogram(
|
def test_resize_spectrogram(
|
||||||
sample_spec: xr.DataArray,
|
sample_spec: xr.DataArray,
|
||||||
height: int,
|
height: int,
|
||||||
resize_factor: float | None,
|
resize_factor: Union[float, None],
|
||||||
expected_freq_size: int,
|
expected_freq_size: int,
|
||||||
expected_time_factor: float,
|
expected_time_factor: float,
|
||||||
):
|
):
|
||||||
|
@ -4,6 +4,7 @@ from typing import Callable, List, Set
|
|||||||
import pytest
|
import pytest
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.targets.filtering import (
|
from batdetect2.targets.filtering import (
|
||||||
FilterConfig,
|
FilterConfig,
|
||||||
FilterRule,
|
FilterRule,
|
||||||
@ -176,3 +177,34 @@ rules:
|
|||||||
filter_result = load_filter_from_config(test_config_path)
|
filter_result = load_filter_from_config(test_config_path)
|
||||||
annotation = create_annotation(["tag1", "tag3"])
|
annotation = create_annotation(["tag1", "tag3"])
|
||||||
assert filter_result(annotation) is False
|
assert filter_result(annotation) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_filtering_over_example_dataset(
|
||||||
|
example_annotations: List[data.ClipAnnotation],
|
||||||
|
):
|
||||||
|
targets = build_targets()
|
||||||
|
|
||||||
|
clip1 = example_annotations[0]
|
||||||
|
clip2 = example_annotations[1]
|
||||||
|
clip3 = example_annotations[2]
|
||||||
|
|
||||||
|
assert (
|
||||||
|
sum(
|
||||||
|
[targets.filter(sound_event) for sound_event in clip1.sound_events]
|
||||||
|
)
|
||||||
|
== 9
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
sum(
|
||||||
|
[targets.filter(sound_event) for sound_event in clip2.sound_events]
|
||||||
|
)
|
||||||
|
== 15
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
sum(
|
||||||
|
[targets.filter(sound_event) for sound_event in clip3.sound_events]
|
||||||
|
)
|
||||||
|
== 20
|
||||||
|
)
|
||||||
|
@ -9,8 +9,8 @@ from batdetect2.preprocess.types import PreprocessorProtocol
|
|||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
add_echo,
|
add_echo,
|
||||||
mix_examples,
|
mix_examples,
|
||||||
select_subclip,
|
|
||||||
)
|
)
|
||||||
|
from batdetect2.train.clips import select_subclip
|
||||||
from batdetect2.train.preprocess import generate_train_example
|
from batdetect2.train.preprocess import generate_train_example
|
||||||
from batdetect2.train.types import ClipLabeller
|
from batdetect2.train.types import ClipLabeller
|
||||||
|
|
||||||
@ -121,7 +121,7 @@ def test_selected_random_subclip_has_the_correct_width(
|
|||||||
preprocessor=sample_preprocessor,
|
preprocessor=sample_preprocessor,
|
||||||
labeller=sample_labeller,
|
labeller=sample_labeller,
|
||||||
)
|
)
|
||||||
subclip = select_subclip(original, width=100)
|
subclip = select_subclip(original, start=0, span=100)
|
||||||
|
|
||||||
assert subclip["spectrogram"].shape[1] == 100
|
assert subclip["spectrogram"].shape[1] == 100
|
||||||
|
|
||||||
@ -142,7 +142,7 @@ def test_add_echo_after_subclip(
|
|||||||
|
|
||||||
assert original.sizes["time"] > 512
|
assert original.sizes["time"] > 512
|
||||||
|
|
||||||
subclip = select_subclip(original, width=512)
|
subclip = select_subclip(original, start=0, span=512)
|
||||||
with_echo = add_echo(subclip, preprocessor=sample_preprocessor)
|
with_echo = add_echo(subclip, preprocessor=sample_preprocessor)
|
||||||
|
|
||||||
assert with_echo.sizes["time"] == 512
|
assert with_echo.sizes["time"] == 512
|
||||||
|
Loading…
Reference in New Issue
Block a user