Fix tests

This commit is contained in:
mbsantiago 2025-04-30 22:51:49 +01:00
parent 9c8b8fb200
commit 2913fa59a4
6 changed files with 109 additions and 43 deletions

View File

@ -7,6 +7,8 @@ import pytest
import soundfile as sf
from soundevent import data, terms
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import (
@ -383,3 +385,27 @@ def sample_labeller(
sample_targets: TargetProtocol,
) -> ClipLabeller:
return build_clip_labeler(sample_targets)
@pytest.fixture
def example_dataset(example_data_dir: Path) -> DatasetConfig:
    """Build a dataset config pointing at the bundled example data.

    The config has a single BatDetect2 "files" annotation source that
    reads audio from ``<example_data_dir>/audio`` and annotations from
    ``<example_data_dir>/anns``.
    """
    example_source = BatDetect2FilesAnnotations(
        name="example annotations",
        audio_dir=example_data_dir / "audio",
        annotations_dir=example_data_dir / "anns",
    )
    return DatasetConfig(
        name="test dataset",
        description="test dataset",
        sources=[example_source],
    )
@pytest.fixture
def example_annotations(
    example_dataset: DatasetConfig,
) -> List[data.ClipAnnotation]:
    """Load the clip annotations described by ``example_dataset``.

    Fails early if the loaded dataset does not contain exactly three
    clip annotations, since dependent tests index into the result.
    """
    loaded = load_dataset(example_dataset)
    assert len(loaded) == 3
    return loaded

View File

@ -254,11 +254,11 @@ class TestLoadBatDetect2Files:
assert clip_ann.clip.recording.duration == 5.0
assert len(clip_ann.sound_events) == 1
assert clip_ann.notes[0].message == "Standard notes."
clip_tag = data.find_tag(clip_ann.tags, "class")
clip_tag = data.find_tag(clip_ann.tags, "Class")
assert clip_tag is not None
assert clip_tag.value == "Myotis"
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "class")
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class")
assert recording_tag is not None
assert recording_tag.value == "Myotis"
@ -271,15 +271,15 @@ class TestLoadBatDetect2Files:
40000,
]
se_class_tag = data.find_tag(se_ann.tags, "class")
se_class_tag = data.find_tag(se_ann.tags, "Class")
assert se_class_tag is not None
assert se_class_tag.value == "Myotis"
se_event_tag = data.find_tag(se_ann.tags, "event")
se_event_tag = data.find_tag(se_ann.tags, "Call Type")
assert se_event_tag is not None
assert se_event_tag.value == "Echolocation"
se_individual_tag = data.find_tag(se_ann.tags, "individual")
se_individual_tag = data.find_tag(se_ann.tags, "Individual")
assert se_individual_tag is not None
assert se_individual_tag.value == "0"
@ -439,7 +439,7 @@ class TestLoadBatDetect2Merged:
assert clip_ann.clip.recording.duration == 5.0
assert len(clip_ann.sound_events) == 1
clip_class_tag = data.find_tag(clip_ann.tags, "class")
clip_class_tag = data.find_tag(clip_ann.tags, "Class")
assert clip_class_tag is not None
assert clip_class_tag.value == "Myotis"

View File

@ -98,7 +98,7 @@ def sample_detection_dataset() -> xr.Dataset:
expected_freqs = np.array([300, 200])
detection_coords = {
"time": ("detection", expected_times),
"freq": ("detection", expected_freqs),
"frequency": ("detection", expected_freqs),
}
scores_data = np.array([0.9, 0.8], dtype=np.float64)
@ -106,7 +106,7 @@ def sample_detection_dataset() -> xr.Dataset:
scores_data,
coords=detection_coords,
dims=["detection"],
name="scores",
name="score",
)
dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32)
@ -183,7 +183,7 @@ def empty_detection_dataset() -> xr.Dataset:
)
return xr.Dataset(
{
"scores": scores,
"score": scores,
"dimensions": dimensions,
"classes": classes,
"features": features,
@ -206,10 +206,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
)
pred1 = RawPrediction(
detection_score=0.9,
start_time=20 - 7 / 2,
end_time=20 + 7 / 2,
low_freq=300 - 16 / 2,
high_freq=300 + 16 / 2,
geometry=data.BoundingBox(
coordinates=[
20 - 7 / 2,
300 - 16 / 2,
20 + 7 / 2,
300 + 16 / 2,
]
),
class_scores=pred1_classes,
features=pred1_features,
)
@ -224,10 +228,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
)
pred2 = RawPrediction(
detection_score=0.8,
start_time=10 - 3 / 2,
end_time=10 + 3 / 2,
low_freq=200 - 12 / 2,
high_freq=200 + 12 / 2,
geometry=data.BoundingBox(
coordinates=[
10 - 3 / 2,
200 - 12 / 2,
10 + 3 / 2,
200 + 12 / 2,
]
),
class_scores=pred2_classes,
features=pred2_features,
)
@ -242,10 +250,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
)
pred3 = RawPrediction(
detection_score=0.15,
start_time=5.0,
end_time=6.0,
low_freq=50.0,
high_freq=60.0,
geometry=data.BoundingBox(
coordinates=[
5.0,
50.0,
6.0,
60.0,
]
),
class_scores=pred3_classes,
features=pred3_features,
)
@ -267,10 +279,12 @@ def test_convert_xr_dataset_basic(
assert isinstance(pred1, RawPrediction)
assert pred1.detection_score == 0.9
assert pred1.start_time == 20 - 7 / 2
assert pred1.end_time == 20 + 7 / 2
assert pred1.low_freq == 300 - 16 / 2
assert pred1.high_freq == 300 + 16 / 2
assert pred1.geometry.coordinates == [
20 - 7 / 2,
300 - 16 / 2,
20 + 7 / 2,
300 + 16 / 2,
]
xr.testing.assert_allclose(
pred1.class_scores,
sample_detection_dataset["classes"].sel(detection=0),
@ -283,10 +297,12 @@ def test_convert_xr_dataset_basic(
assert isinstance(pred2, RawPrediction)
assert pred2.detection_score == 0.8
assert pred2.start_time == 10 - 3 / 2
assert pred2.end_time == 10 + 3 / 2
assert pred2.low_freq == 200 - 12 / 2
assert pred2.high_freq == 200 + 12 / 2
assert pred2.geometry.coordinates == [
10 - 3 / 2,
200 - 12 / 2,
10 + 3 / 2,
200 + 12 / 2,
]
xr.testing.assert_allclose(
pred2.class_scores,
sample_detection_dataset["classes"].sel(detection=1),
@ -331,15 +347,7 @@ def test_convert_raw_to_sound_event_basic(
assert isinstance(se, data.SoundEvent)
assert se.recording == sample_recording
assert isinstance(se.geometry, data.BoundingBox)
np.testing.assert_allclose(
se.geometry.coordinates,
[
raw_pred.start_time,
raw_pred.low_freq,
raw_pred.end_time,
raw_pred.high_freq,
],
)
assert se.geometry == raw_pred.geometry
assert len(se.features) == len(raw_pred.features)
feat_dict = {f.term.name: f.value for f in se.features}

View File

@ -1,6 +1,6 @@
import math
from pathlib import Path
from typing import Callable
from typing import Callable, Union
import numpy as np
import pytest
@ -307,7 +307,7 @@ def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
def test_resize_spectrogram(
sample_spec: xr.DataArray,
height: int,
resize_factor: float | None,
resize_factor: Union[float, None],
expected_freq_size: int,
expected_time_factor: float,
):

View File

@ -4,6 +4,7 @@ from typing import Callable, List, Set
import pytest
from soundevent import data
from batdetect2.targets import build_targets
from batdetect2.targets.filtering import (
FilterConfig,
FilterRule,
@ -176,3 +177,34 @@ rules:
filter_result = load_filter_from_config(test_config_path)
annotation = create_annotation(["tag1", "tag3"])
assert filter_result(annotation) is False
def test_default_filtering_over_example_dataset(
    example_annotations: List[data.ClipAnnotation],
):
    """Check the filter of argument-less ``build_targets()`` on real data.

    For each of the first three example clip annotations, the sum of
    ``targets.filter(sound_event)`` over its sound events (presumably the
    number of events kept by the filter) must match a known value.
    """
    targets = build_targets()

    # Index all three clips up front so a short fixture fails before any
    # count is checked.
    first_clip = example_annotations[0]
    second_clip = example_annotations[1]
    third_clip = example_annotations[2]

    expected_totals = [
        (first_clip, 9),
        (second_clip, 15),
        (third_clip, 20),
    ]

    for clip_annotation, expected_total in expected_totals:
        total = sum(
            targets.filter(sound_event)
            for sound_event in clip_annotation.sound_events
        )
        assert total == expected_total

View File

@ -9,8 +9,8 @@ from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.train.augmentations import (
add_echo,
mix_examples,
select_subclip,
)
from batdetect2.train.clips import select_subclip
from batdetect2.train.preprocess import generate_train_example
from batdetect2.train.types import ClipLabeller
@ -121,7 +121,7 @@ def test_selected_random_subclip_has_the_correct_width(
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
subclip = select_subclip(original, width=100)
subclip = select_subclip(original, start=0, span=100)
assert subclip["spectrogram"].shape[1] == 100
@ -142,7 +142,7 @@ def test_add_echo_after_subclip(
assert original.sizes["time"] > 512
subclip = select_subclip(original, width=512)
subclip = select_subclip(original, start=0, span=512)
with_echo = add_echo(subclip, preprocessor=sample_preprocessor)
assert with_echo.sizes["time"] == 512