mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Fix tests
This commit is contained in:
parent
9c8b8fb200
commit
2913fa59a4
@ -7,6 +7,8 @@ import pytest
|
||||
import soundfile as sf
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.data import DatasetConfig, load_dataset
|
||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets import (
|
||||
@ -383,3 +385,27 @@ def sample_labeller(
|
||||
sample_targets: TargetProtocol,
|
||||
) -> ClipLabeller:
|
||||
return build_clip_labeler(sample_targets)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_dataset(example_data_dir: Path) -> DatasetConfig:
|
||||
return DatasetConfig(
|
||||
name="test dataset",
|
||||
description="test dataset",
|
||||
sources=[
|
||||
BatDetect2FilesAnnotations(
|
||||
name="example annotations",
|
||||
audio_dir=example_data_dir / "audio",
|
||||
annotations_dir=example_data_dir / "anns",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_annotations(
|
||||
example_dataset: DatasetConfig,
|
||||
) -> List[data.ClipAnnotation]:
|
||||
annotations = load_dataset(example_dataset)
|
||||
assert len(annotations) == 3
|
||||
return annotations
|
||||
|
@ -254,11 +254,11 @@ class TestLoadBatDetect2Files:
|
||||
assert clip_ann.clip.recording.duration == 5.0
|
||||
assert len(clip_ann.sound_events) == 1
|
||||
assert clip_ann.notes[0].message == "Standard notes."
|
||||
clip_tag = data.find_tag(clip_ann.tags, "class")
|
||||
clip_tag = data.find_tag(clip_ann.tags, "Class")
|
||||
assert clip_tag is not None
|
||||
assert clip_tag.value == "Myotis"
|
||||
|
||||
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "class")
|
||||
recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class")
|
||||
assert recording_tag is not None
|
||||
assert recording_tag.value == "Myotis"
|
||||
|
||||
@ -271,15 +271,15 @@ class TestLoadBatDetect2Files:
|
||||
40000,
|
||||
]
|
||||
|
||||
se_class_tag = data.find_tag(se_ann.tags, "class")
|
||||
se_class_tag = data.find_tag(se_ann.tags, "Class")
|
||||
assert se_class_tag is not None
|
||||
assert se_class_tag.value == "Myotis"
|
||||
|
||||
se_event_tag = data.find_tag(se_ann.tags, "event")
|
||||
se_event_tag = data.find_tag(se_ann.tags, "Call Type")
|
||||
assert se_event_tag is not None
|
||||
assert se_event_tag.value == "Echolocation"
|
||||
|
||||
se_individual_tag = data.find_tag(se_ann.tags, "individual")
|
||||
se_individual_tag = data.find_tag(se_ann.tags, "Individual")
|
||||
assert se_individual_tag is not None
|
||||
assert se_individual_tag.value == "0"
|
||||
|
||||
@ -439,7 +439,7 @@ class TestLoadBatDetect2Merged:
|
||||
assert clip_ann.clip.recording.duration == 5.0
|
||||
assert len(clip_ann.sound_events) == 1
|
||||
|
||||
clip_class_tag = data.find_tag(clip_ann.tags, "class")
|
||||
clip_class_tag = data.find_tag(clip_ann.tags, "Class")
|
||||
assert clip_class_tag is not None
|
||||
assert clip_class_tag.value == "Myotis"
|
||||
|
||||
|
@ -98,7 +98,7 @@ def sample_detection_dataset() -> xr.Dataset:
|
||||
expected_freqs = np.array([300, 200])
|
||||
detection_coords = {
|
||||
"time": ("detection", expected_times),
|
||||
"freq": ("detection", expected_freqs),
|
||||
"frequency": ("detection", expected_freqs),
|
||||
}
|
||||
|
||||
scores_data = np.array([0.9, 0.8], dtype=np.float64)
|
||||
@ -106,7 +106,7 @@ def sample_detection_dataset() -> xr.Dataset:
|
||||
scores_data,
|
||||
coords=detection_coords,
|
||||
dims=["detection"],
|
||||
name="scores",
|
||||
name="score",
|
||||
)
|
||||
|
||||
dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32)
|
||||
@ -183,7 +183,7 @@ def empty_detection_dataset() -> xr.Dataset:
|
||||
)
|
||||
return xr.Dataset(
|
||||
{
|
||||
"scores": scores,
|
||||
"score": scores,
|
||||
"dimensions": dimensions,
|
||||
"classes": classes,
|
||||
"features": features,
|
||||
@ -206,10 +206,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
)
|
||||
pred1 = RawPrediction(
|
||||
detection_score=0.9,
|
||||
start_time=20 - 7 / 2,
|
||||
end_time=20 + 7 / 2,
|
||||
low_freq=300 - 16 / 2,
|
||||
high_freq=300 + 16 / 2,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[
|
||||
20 - 7 / 2,
|
||||
300 - 16 / 2,
|
||||
20 + 7 / 2,
|
||||
300 + 16 / 2,
|
||||
]
|
||||
),
|
||||
class_scores=pred1_classes,
|
||||
features=pred1_features,
|
||||
)
|
||||
@ -224,10 +228,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
)
|
||||
pred2 = RawPrediction(
|
||||
detection_score=0.8,
|
||||
start_time=10 - 3 / 2,
|
||||
end_time=10 + 3 / 2,
|
||||
low_freq=200 - 12 / 2,
|
||||
high_freq=200 + 12 / 2,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[
|
||||
10 - 3 / 2,
|
||||
200 - 12 / 2,
|
||||
10 + 3 / 2,
|
||||
200 + 12 / 2,
|
||||
]
|
||||
),
|
||||
class_scores=pred2_classes,
|
||||
features=pred2_features,
|
||||
)
|
||||
@ -242,10 +250,14 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
)
|
||||
pred3 = RawPrediction(
|
||||
detection_score=0.15,
|
||||
start_time=5.0,
|
||||
end_time=6.0,
|
||||
low_freq=50.0,
|
||||
high_freq=60.0,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[
|
||||
5.0,
|
||||
50.0,
|
||||
6.0,
|
||||
60.0,
|
||||
]
|
||||
),
|
||||
class_scores=pred3_classes,
|
||||
features=pred3_features,
|
||||
)
|
||||
@ -267,10 +279,12 @@ def test_convert_xr_dataset_basic(
|
||||
assert isinstance(pred1, RawPrediction)
|
||||
assert pred1.detection_score == 0.9
|
||||
|
||||
assert pred1.start_time == 20 - 7 / 2
|
||||
assert pred1.end_time == 20 + 7 / 2
|
||||
assert pred1.low_freq == 300 - 16 / 2
|
||||
assert pred1.high_freq == 300 + 16 / 2
|
||||
assert pred1.geometry.coordinates == [
|
||||
20 - 7 / 2,
|
||||
300 - 16 / 2,
|
||||
20 + 7 / 2,
|
||||
300 + 16 / 2,
|
||||
]
|
||||
xr.testing.assert_allclose(
|
||||
pred1.class_scores,
|
||||
sample_detection_dataset["classes"].sel(detection=0),
|
||||
@ -283,10 +297,12 @@ def test_convert_xr_dataset_basic(
|
||||
assert isinstance(pred2, RawPrediction)
|
||||
assert pred2.detection_score == 0.8
|
||||
|
||||
assert pred2.start_time == 10 - 3 / 2
|
||||
assert pred2.end_time == 10 + 3 / 2
|
||||
assert pred2.low_freq == 200 - 12 / 2
|
||||
assert pred2.high_freq == 200 + 12 / 2
|
||||
assert pred2.geometry.coordinates == [
|
||||
10 - 3 / 2,
|
||||
200 - 12 / 2,
|
||||
10 + 3 / 2,
|
||||
200 + 12 / 2,
|
||||
]
|
||||
xr.testing.assert_allclose(
|
||||
pred2.class_scores,
|
||||
sample_detection_dataset["classes"].sel(detection=1),
|
||||
@ -331,15 +347,7 @@ def test_convert_raw_to_sound_event_basic(
|
||||
assert isinstance(se, data.SoundEvent)
|
||||
assert se.recording == sample_recording
|
||||
assert isinstance(se.geometry, data.BoundingBox)
|
||||
np.testing.assert_allclose(
|
||||
se.geometry.coordinates,
|
||||
[
|
||||
raw_pred.start_time,
|
||||
raw_pred.low_freq,
|
||||
raw_pred.end_time,
|
||||
raw_pred.high_freq,
|
||||
],
|
||||
)
|
||||
assert se.geometry == raw_pred.geometry
|
||||
assert len(se.features) == len(raw_pred.features)
|
||||
|
||||
feat_dict = {f.term.name: f.value for f in se.features}
|
||||
|
@ -1,6 +1,6 @@
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from typing import Callable, Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@ -307,7 +307,7 @@ def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
|
||||
def test_resize_spectrogram(
|
||||
sample_spec: xr.DataArray,
|
||||
height: int,
|
||||
resize_factor: float | None,
|
||||
resize_factor: Union[float, None],
|
||||
expected_freq_size: int,
|
||||
expected_time_factor: float,
|
||||
):
|
||||
|
@ -4,6 +4,7 @@ from typing import Callable, List, Set
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.targets.filtering import (
|
||||
FilterConfig,
|
||||
FilterRule,
|
||||
@ -176,3 +177,34 @@ rules:
|
||||
filter_result = load_filter_from_config(test_config_path)
|
||||
annotation = create_annotation(["tag1", "tag3"])
|
||||
assert filter_result(annotation) is False
|
||||
|
||||
|
||||
def test_default_filtering_over_example_dataset(
|
||||
example_annotations: List[data.ClipAnnotation],
|
||||
):
|
||||
targets = build_targets()
|
||||
|
||||
clip1 = example_annotations[0]
|
||||
clip2 = example_annotations[1]
|
||||
clip3 = example_annotations[2]
|
||||
|
||||
assert (
|
||||
sum(
|
||||
[targets.filter(sound_event) for sound_event in clip1.sound_events]
|
||||
)
|
||||
== 9
|
||||
)
|
||||
|
||||
assert (
|
||||
sum(
|
||||
[targets.filter(sound_event) for sound_event in clip2.sound_events]
|
||||
)
|
||||
== 15
|
||||
)
|
||||
|
||||
assert (
|
||||
sum(
|
||||
[targets.filter(sound_event) for sound_event in clip3.sound_events]
|
||||
)
|
||||
== 20
|
||||
)
|
||||
|
@ -9,8 +9,8 @@ from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.train.augmentations import (
|
||||
add_echo,
|
||||
mix_examples,
|
||||
select_subclip,
|
||||
)
|
||||
from batdetect2.train.clips import select_subclip
|
||||
from batdetect2.train.preprocess import generate_train_example
|
||||
from batdetect2.train.types import ClipLabeller
|
||||
|
||||
@ -121,7 +121,7 @@ def test_selected_random_subclip_has_the_correct_width(
|
||||
preprocessor=sample_preprocessor,
|
||||
labeller=sample_labeller,
|
||||
)
|
||||
subclip = select_subclip(original, width=100)
|
||||
subclip = select_subclip(original, start=0, span=100)
|
||||
|
||||
assert subclip["spectrogram"].shape[1] == 100
|
||||
|
||||
@ -142,7 +142,7 @@ def test_add_echo_after_subclip(
|
||||
|
||||
assert original.sizes["time"] > 512
|
||||
|
||||
subclip = select_subclip(original, width=512)
|
||||
subclip = select_subclip(original, start=0, span=512)
|
||||
with_echo = add_echo(subclip, preprocessor=sample_preprocessor)
|
||||
|
||||
assert with_echo.sizes["time"] == 512
|
||||
|
Loading…
Reference in New Issue
Block a user