diff --git a/tests/conftest.py b/tests/conftest.py index 791950c..556540f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,8 @@ import pytest import soundfile as sf from soundevent import data, terms +from batdetect2.data import DatasetConfig, load_dataset +from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets import ( @@ -383,3 +385,27 @@ def sample_labeller( sample_targets: TargetProtocol, ) -> ClipLabeller: return build_clip_labeler(sample_targets) + + +@pytest.fixture +def example_dataset(example_data_dir: Path) -> DatasetConfig: + return DatasetConfig( + name="test dataset", + description="test dataset", + sources=[ + BatDetect2FilesAnnotations( + name="example annotations", + audio_dir=example_data_dir / "audio", + annotations_dir=example_data_dir / "anns", + ) + ], + ) + + +@pytest.fixture +def example_annotations( + example_dataset: DatasetConfig, +) -> List[data.ClipAnnotation]: + annotations = load_dataset(example_dataset) + assert len(annotations) == 3 + return annotations diff --git a/tests/test_data/test_annotations/test_batdetect2.py b/tests/test_data/test_annotations/test_batdetect2.py index 8e4beb9..5274bc0 100644 --- a/tests/test_data/test_annotations/test_batdetect2.py +++ b/tests/test_data/test_annotations/test_batdetect2.py @@ -254,11 +254,11 @@ class TestLoadBatDetect2Files: assert clip_ann.clip.recording.duration == 5.0 assert len(clip_ann.sound_events) == 1 assert clip_ann.notes[0].message == "Standard notes." - clip_tag = data.find_tag(clip_ann.tags, "class") + clip_tag = data.find_tag(clip_ann.tags, "Class") assert clip_tag is not None assert clip_tag.value == "Myotis" - recording_tag = data.find_tag(clip_ann.clip.recording.tags, "class") + recording_tag = data.find_tag(clip_ann.clip.recording.tags, "Class") assert recording_tag is not None assert recording_tag.value == "Myotis" @@ -271,15 +271,15 @@ class TestLoadBatDetect2Files: 40000, ] - se_class_tag = data.find_tag(se_ann.tags, "class") + se_class_tag = data.find_tag(se_ann.tags, "Class") assert se_class_tag is not None assert se_class_tag.value == "Myotis" - se_event_tag = data.find_tag(se_ann.tags, "event") + se_event_tag = data.find_tag(se_ann.tags, "Call Type") assert se_event_tag is not None assert se_event_tag.value == "Echolocation" - se_individual_tag = data.find_tag(se_ann.tags, "individual") + se_individual_tag = data.find_tag(se_ann.tags, "Individual") assert se_individual_tag is not None assert se_individual_tag.value == "0" @@ -439,7 +439,7 @@ class TestLoadBatDetect2Merged: assert clip_ann.clip.recording.duration == 5.0 assert len(clip_ann.sound_events) == 1 - clip_class_tag = data.find_tag(clip_ann.tags, "class") + clip_class_tag = data.find_tag(clip_ann.tags, "Class") assert clip_class_tag is not None assert clip_class_tag.value == "Myotis" diff --git a/tests/test_postprocessing/test_decoding.py b/tests/test_postprocessing/test_decoding.py index 4580772..151146a 100644 --- a/tests/test_postprocessing/test_decoding.py +++ b/tests/test_postprocessing/test_decoding.py @@ -98,7 +98,7 @@ def sample_detection_dataset() -> xr.Dataset: expected_freqs = np.array([300, 200]) detection_coords = { "time": ("detection", expected_times), - "freq": ("detection", expected_freqs), + "frequency": ("detection", expected_freqs), } scores_data = np.array([0.9, 0.8], dtype=np.float64) @@ -106,7 +106,7 @@ def sample_detection_dataset() -> xr.Dataset: scores_data, coords=detection_coords, dims=["detection"], - name="scores", + name="score", ) dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32) @@ -183,7 +183,7 @@ def empty_detection_dataset() -> xr.Dataset: ) return xr.Dataset( { - "scores": scores, + "score": scores, "dimensions": dimensions, "classes": classes, "features": features, @@ -206,10 +206,14 @@ def sample_raw_predictions() -> List[RawPrediction]: ) pred1 = RawPrediction( detection_score=0.9, - start_time=20 - 7 / 2, - end_time=20 + 7 / 2, - low_freq=300 - 16 / 2, - high_freq=300 + 16 / 2, + geometry=data.BoundingBox( + coordinates=[ + 20 - 7 / 2, + 300 - 16 / 2, + 20 + 7 / 2, + 300 + 16 / 2, + ] + ), class_scores=pred1_classes, features=pred1_features, ) @@ -224,10 +228,14 @@ def sample_raw_predictions() -> List[RawPrediction]: ) pred2 = RawPrediction( detection_score=0.8, - start_time=10 - 3 / 2, - end_time=10 + 3 / 2, - low_freq=200 - 12 / 2, - high_freq=200 + 12 / 2, + geometry=data.BoundingBox( + coordinates=[ + 10 - 3 / 2, + 200 - 12 / 2, + 10 + 3 / 2, + 200 + 12 / 2, + ] + ), class_scores=pred2_classes, features=pred2_features, ) @@ -242,10 +250,14 @@ def sample_raw_predictions() -> List[RawPrediction]: ) pred3 = RawPrediction( detection_score=0.15, - start_time=5.0, - end_time=6.0, - low_freq=50.0, - high_freq=60.0, + geometry=data.BoundingBox( + coordinates=[ + 5.0, + 50.0, + 6.0, + 60.0, + ] + ), class_scores=pred3_classes, features=pred3_features, ) @@ -267,10 +279,12 @@ def test_convert_xr_dataset_basic( assert isinstance(pred1, RawPrediction) assert pred1.detection_score == 0.9 - assert pred1.start_time == 20 - 7 / 2 - assert pred1.end_time == 20 + 7 / 2 - assert pred1.low_freq == 300 - 16 / 2 - assert pred1.high_freq == 300 + 16 / 2 + assert pred1.geometry.coordinates == [ + 20 - 7 / 2, + 300 - 16 / 2, + 20 + 7 / 2, + 300 + 16 / 2, + ] xr.testing.assert_allclose( pred1.class_scores, sample_detection_dataset["classes"].sel(detection=0), @@ -283,10 +297,12 @@ def test_convert_xr_dataset_basic( assert isinstance(pred2, RawPrediction) assert pred2.detection_score == 0.8 - assert pred2.start_time == 10 - 3 / 2 - assert pred2.end_time == 10 + 3 / 2 - assert pred2.low_freq == 200 - 12 / 2 - assert pred2.high_freq == 200 + 12 / 2 + assert pred2.geometry.coordinates == [ + 10 - 3 / 2, + 200 - 12 / 2, + 10 + 3 / 2, + 200 + 12 / 2, + ] xr.testing.assert_allclose( pred2.class_scores, sample_detection_dataset["classes"].sel(detection=1), @@ -331,15 +347,7 @@ def test_convert_raw_to_sound_event_basic( assert isinstance(se, data.SoundEvent) assert se.recording == sample_recording assert isinstance(se.geometry, data.BoundingBox) - np.testing.assert_allclose( - se.geometry.coordinates, - [ - raw_pred.start_time, - raw_pred.low_freq, - raw_pred.end_time, - raw_pred.high_freq, - ], - ) + assert se.geometry == raw_pred.geometry assert len(se.features) == len(raw_pred.features) feat_dict = {f.term.name: f.value for f in se.features} diff --git a/tests/test_preprocessing/test_spectrogram.py b/tests/test_preprocessing/test_spectrogram.py index b25dbc9..de8ca78 100644 --- a/tests/test_preprocessing/test_spectrogram.py +++ b/tests/test_preprocessing/test_spectrogram.py @@ -1,6 +1,6 @@ import math from pathlib import Path -from typing import Callable +from typing import Callable, Union import numpy as np import pytest @@ -307,7 +307,7 @@ def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray): def test_resize_spectrogram( sample_spec: xr.DataArray, height: int, - resize_factor: float | None, + resize_factor: Union[float, None], expected_freq_size: int, expected_time_factor: float, ): diff --git a/tests/test_targets/test_filtering.py b/tests/test_targets/test_filtering.py index 069f42c..3036143 100644 --- a/tests/test_targets/test_filtering.py +++ b/tests/test_targets/test_filtering.py @@ -4,6 +4,7 @@ from typing import Callable, List, Set import pytest from soundevent import data +from batdetect2.targets import build_targets from batdetect2.targets.filtering import ( FilterConfig, FilterRule, @@ -176,3 +177,34 @@ rules: filter_result = load_filter_from_config(test_config_path) annotation = create_annotation(["tag1", "tag3"]) assert filter_result(annotation) is False + + +def test_default_filtering_over_example_dataset( + example_annotations: List[data.ClipAnnotation], +): + targets = build_targets() + + clip1 = example_annotations[0] + clip2 = example_annotations[1] + clip3 = example_annotations[2] + + assert ( + sum( + [targets.filter(sound_event) for sound_event in clip1.sound_events] + ) + == 9 + ) + + assert ( + sum( + [targets.filter(sound_event) for sound_event in clip2.sound_events] + ) + == 15 + ) + + assert ( + sum( + [targets.filter(sound_event) for sound_event in clip3.sound_events] + ) + == 20 + ) diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py index 5560d85..78a6251 100644 --- a/tests/test_train/test_augmentations.py +++ b/tests/test_train/test_augmentations.py @@ -9,8 +9,8 @@ from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.train.augmentations import ( add_echo, mix_examples, - select_subclip, ) +from batdetect2.train.clips import select_subclip from batdetect2.train.preprocess import generate_train_example from batdetect2.train.types import ClipLabeller @@ -121,7 +121,7 @@ def test_selected_random_subclip_has_the_correct_width( preprocessor=sample_preprocessor, labeller=sample_labeller, ) - subclip = select_subclip(original, width=100) + subclip = select_subclip(original, start=0, span=100) assert subclip["spectrogram"].shape[1] == 100 @@ -142,7 +142,7 @@ def test_add_echo_after_subclip( assert original.sizes["time"] > 512 - subclip = select_subclip(original, width=512) + subclip = select_subclip(original, start=0, span=512) with_echo = add_echo(subclip, preprocessor=sample_preprocessor) assert with_echo.sizes["time"] == 512