Fix tests

This commit is contained in:
mbsantiago 2025-04-30 22:51:49 +01:00
parent 9c8b8fb200
commit 2913fa59a4
6 changed files with 109 additions and 43 deletions

View File

@ -7,6 +7,8 @@ import pytest
import soundfile as sf import soundfile as sf
from soundevent import data, terms from soundevent import data, terms
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import ( from batdetect2.targets import (
@ -383,3 +385,27 @@ def sample_labeller(
sample_targets: TargetProtocol, sample_targets: TargetProtocol,
) -> ClipLabeller: ) -> ClipLabeller:
return build_clip_labeler(sample_targets) return build_clip_labeler(sample_targets)
@pytest.fixture
def example_dataset(example_data_dir: Path) -> DatasetConfig:
    """Return a dataset config with one source under ``example_data_dir``.

    The single source reads audio from ``example_data_dir / "audio"`` and
    annotations from ``example_data_dir / "anns"``.
    """
    annotation_source = BatDetect2FilesAnnotations(
        name="example annotations",
        audio_dir=example_data_dir / "audio",
        annotations_dir=example_data_dir / "anns",
    )
    return DatasetConfig(
        name="test dataset",
        description="test dataset",
        sources=[annotation_source],
    )
@pytest.fixture
def example_annotations(
    example_dataset: DatasetConfig,
) -> List[data.ClipAnnotation]:
    """Load the clip annotations described by ``example_dataset``.

    Asserts that exactly three clip annotations were loaded before
    handing them to the test.
    """
    loaded = load_dataset(example_dataset)
    assert len(loaded) == 3
    return loaded

View File

@ -254,11 +254,11 @@ class TestLoadBatDetect2Files:
assert clip_ann.clip.recording.duration == 5.0 assert clip_ann.clip.recording.duration == 5.0
assert len(clip_ann.sound_events) == 1 assert len(clip_ann.sound_events) == 1
assert clip_ann.notes[0].message == "Standard notes." assert clip_ann.notes[0].message == "Standard notes."
clip_tag = data.find_tag(clip_ann.tags, "class") clip_tag = data.find_tag(clip_ann.tags, "Class")
assert clip_tag is not None assert clip_tag is not None
assert clip_tag.value == "Myotis" assert clip_tag.value == "Myotis"
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "class") recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class")
assert recording_tag is not None assert recording_tag is not None
assert recording_tag.value == "Myotis" assert recording_tag.value == "Myotis"
@ -271,15 +271,15 @@ class TestLoadBatDetect2Files:
40000, 40000,
] ]
se_class_tag = data.find_tag(se_ann.tags, "class") se_class_tag = data.find_tag(se_ann.tags, "Class")
assert se_class_tag is not None assert se_class_tag is not None
assert se_class_tag.value == "Myotis" assert se_class_tag.value == "Myotis"
se_event_tag = data.find_tag(se_ann.tags, "event") se_event_tag = data.find_tag(se_ann.tags, "Call Type")
assert se_event_tag is not None assert se_event_tag is not None
assert se_event_tag.value == "Echolocation" assert se_event_tag.value == "Echolocation"
se_individual_tag = data.find_tag(se_ann.tags, "individual") se_individual_tag = data.find_tag(se_ann.tags, "Individual")
assert se_individual_tag is not None assert se_individual_tag is not None
assert se_individual_tag.value == "0" assert se_individual_tag.value == "0"
@ -439,7 +439,7 @@ class TestLoadBatDetect2Merged:
assert clip_ann.clip.recording.duration == 5.0 assert clip_ann.clip.recording.duration == 5.0
assert len(clip_ann.sound_events) == 1 assert len(clip_ann.sound_events) == 1
clip_class_tag = data.find_tag(clip_ann.tags, "class") clip_class_tag = data.find_tag(clip_ann.tags, "Class")
assert clip_class_tag is not None assert clip_class_tag is not None
assert clip_class_tag.value == "Myotis" assert clip_class_tag.value == "Myotis"

View File

@ -98,7 +98,7 @@ def sample_detection_dataset() -> xr.Dataset:
expected_freqs = np.array([300, 200]) expected_freqs = np.array([300, 200])
detection_coords = { detection_coords = {
"time": ("detection", expected_times), "time": ("detection", expected_times),
"freq": ("detection", expected_freqs), "frequency": ("detection", expected_freqs),
} }
scores_data = np.array([0.9, 0.8], dtype=np.float64) scores_data = np.array([0.9, 0.8], dtype=np.float64)
@ -106,7 +106,7 @@ def sample_detection_dataset() -> xr.Dataset:
scores_data, scores_data,
coords=detection_coords, coords=detection_coords,
dims=["detection"], dims=["detection"],
name="scores", name="score",
) )
dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32) dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32)
@ -183,7 +183,7 @@ def empty_detection_dataset() -> xr.Dataset:
) )
return xr.Dataset( return xr.Dataset(
{ {
"scores": scores, "score": scores,
"dimensions": dimensions, "dimensions": dimensions,
"classes": classes, "classes": classes,
"features": features, "features": features,
@ -206,10 +206,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
) )
pred1 = RawPrediction( pred1 = RawPrediction(
detection_score=0.9, detection_score=0.9,
start_time=20 - 7 / 2, geometry=data.BoundingBox(
end_time=20 + 7 / 2, coordinates=[
low_freq=300 - 16 / 2, 20 - 7 / 2,
high_freq=300 + 16 / 2, 300 - 16 / 2,
20 + 7 / 2,
300 + 16 / 2,
]
),
class_scores=pred1_classes, class_scores=pred1_classes,
features=pred1_features, features=pred1_features,
) )
@ -224,10 +228,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
) )
pred2 = RawPrediction( pred2 = RawPrediction(
detection_score=0.8, detection_score=0.8,
start_time=10 - 3 / 2, geometry=data.BoundingBox(
end_time=10 + 3 / 2, coordinates=[
low_freq=200 - 12 / 2, 10 - 3 / 2,
high_freq=200 + 12 / 2, 200 - 12 / 2,
10 + 3 / 2,
200 + 12 / 2,
]
),
class_scores=pred2_classes, class_scores=pred2_classes,
features=pred2_features, features=pred2_features,
) )
@ -242,10 +250,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
) )
pred3 = RawPrediction( pred3 = RawPrediction(
detection_score=0.15, detection_score=0.15,
start_time=5.0, geometry=data.BoundingBox(
end_time=6.0, coordinates=[
low_freq=50.0, 5.0,
high_freq=60.0, 50.0,
6.0,
60.0,
]
),
class_scores=pred3_classes, class_scores=pred3_classes,
features=pred3_features, features=pred3_features,
) )
@ -267,10 +279,12 @@ def test_convert_xr_dataset_basic(
assert isinstance(pred1, RawPrediction) assert isinstance(pred1, RawPrediction)
assert pred1.detection_score == 0.9 assert pred1.detection_score == 0.9
assert pred1.start_time == 20 - 7 / 2 assert pred1.geometry.coordinates == [
assert pred1.end_time == 20 + 7 / 2 20 - 7 / 2,
assert pred1.low_freq == 300 - 16 / 2 300 - 16 / 2,
assert pred1.high_freq == 300 + 16 / 2 20 + 7 / 2,
300 + 16 / 2,
]
xr.testing.assert_allclose( xr.testing.assert_allclose(
pred1.class_scores, pred1.class_scores,
sample_detection_dataset["classes"].sel(detection=0), sample_detection_dataset["classes"].sel(detection=0),
@ -283,10 +297,12 @@ def test_convert_xr_dataset_basic(
assert isinstance(pred2, RawPrediction) assert isinstance(pred2, RawPrediction)
assert pred2.detection_score == 0.8 assert pred2.detection_score == 0.8
assert pred2.start_time == 10 - 3 / 2 assert pred2.geometry.coordinates == [
assert pred2.end_time == 10 + 3 / 2 10 - 3 / 2,
assert pred2.low_freq == 200 - 12 / 2 200 - 12 / 2,
assert pred2.high_freq == 200 + 12 / 2 10 + 3 / 2,
200 + 12 / 2,
]
xr.testing.assert_allclose( xr.testing.assert_allclose(
pred2.class_scores, pred2.class_scores,
sample_detection_dataset["classes"].sel(detection=1), sample_detection_dataset["classes"].sel(detection=1),
@ -331,15 +347,7 @@ def test_convert_raw_to_sound_event_basic(
assert isinstance(se, data.SoundEvent) assert isinstance(se, data.SoundEvent)
assert se.recording == sample_recording assert se.recording == sample_recording
assert isinstance(se.geometry, data.BoundingBox) assert isinstance(se.geometry, data.BoundingBox)
np.testing.assert_allclose( assert se.geometry == raw_pred.geometry
se.geometry.coordinates,
[
raw_pred.start_time,
raw_pred.low_freq,
raw_pred.end_time,
raw_pred.high_freq,
],
)
assert len(se.features) == len(raw_pred.features) assert len(se.features) == len(raw_pred.features)
feat_dict = {f.term.name: f.value for f in se.features} feat_dict = {f.term.name: f.value for f in se.features}

View File

@ -1,6 +1,6 @@
import math import math
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable, Union
import numpy as np import numpy as np
import pytest import pytest
@ -307,7 +307,7 @@ def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
def test_resize_spectrogram( def test_resize_spectrogram(
sample_spec: xr.DataArray, sample_spec: xr.DataArray,
height: int, height: int,
resize_factor: float | None, resize_factor: Union[float, None],
expected_freq_size: int, expected_freq_size: int,
expected_time_factor: float, expected_time_factor: float,
): ):

View File

@ -4,6 +4,7 @@ from typing import Callable, List, Set
import pytest import pytest
from soundevent import data from soundevent import data
from batdetect2.targets import build_targets
from batdetect2.targets.filtering import ( from batdetect2.targets.filtering import (
FilterConfig, FilterConfig,
FilterRule, FilterRule,
@ -176,3 +177,34 @@ rules:
filter_result = load_filter_from_config(test_config_path) filter_result = load_filter_from_config(test_config_path)
annotation = create_annotation(["tag1", "tag3"]) annotation = create_annotation(["tag1", "tag3"])
assert filter_result(annotation) is False assert filter_result(annotation) is False
def test_default_filtering_over_example_dataset(
    example_annotations: List[data.ClipAnnotation],
):
    """Check default-target filter totals for the three example clips.

    For each clip, the results of ``targets.filter`` over its sound
    events are summed and compared with the expected total.
    """
    targets = build_targets()

    # Index all three clips up front so a short list fails before any
    # filtering assertions run.
    first_clip = example_annotations[0]
    second_clip = example_annotations[1]
    third_clip = example_annotations[2]

    def filter_total(clip_annotation):
        results = [
            targets.filter(sound_event)
            for sound_event in clip_annotation.sound_events
        ]
        return sum(results)

    assert filter_total(first_clip) == 9
    assert filter_total(second_clip) == 15
    assert filter_total(third_clip) == 20

View File

@ -9,8 +9,8 @@ from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
add_echo, add_echo,
mix_examples, mix_examples,
select_subclip,
) )
from batdetect2.train.clips import select_subclip
from batdetect2.train.preprocess import generate_train_example from batdetect2.train.preprocess import generate_train_example
from batdetect2.train.types import ClipLabeller from batdetect2.train.types import ClipLabeller
@ -121,7 +121,7 @@ def test_selected_random_subclip_has_the_correct_width(
preprocessor=sample_preprocessor, preprocessor=sample_preprocessor,
labeller=sample_labeller, labeller=sample_labeller,
) )
subclip = select_subclip(original, width=100) subclip = select_subclip(original, start=0, span=100)
assert subclip["spectrogram"].shape[1] == 100 assert subclip["spectrogram"].shape[1] == 100
@ -142,7 +142,7 @@ def test_add_echo_after_subclip(
assert original.sizes["time"] > 512 assert original.sizes["time"] > 512
subclip = select_subclip(original, width=512) subclip = select_subclip(original, start=0, span=512)
with_echo = add_echo(subclip, preprocessor=sample_preprocessor) with_echo = add_echo(subclip, preprocessor=sample_preprocessor)
assert with_echo.sizes["time"] == 512 assert with_echo.sizes["time"] == 512