mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Fix tests
This commit is contained in:
parent
7336638fa9
commit
202c6cbab0
@ -120,7 +120,7 @@ def get_sound_event_tags(
|
||||
if annotation.event:
|
||||
tags.append(data.Tag(key=event_key, value=annotation.event))
|
||||
|
||||
if annotation.individual:
|
||||
if annotation.individual is not None:
|
||||
tags.append(
|
||||
data.Tag(key=individual_key, value=str(annotation.individual))
|
||||
)
|
||||
|
||||
@ -331,6 +331,7 @@ _scalers = {
|
||||
|
||||
class ScaleAmplitude(torch.nn.Module):
|
||||
def __init__(self, scale: Literal["power", "db"]):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.scaler = _scalers[scale]()
|
||||
|
||||
|
||||
166
tests/test_data/test_predictions/test_parquet.py
Normal file
166
tests/test_data/test_predictions/test_parquet.py
Normal file
@ -0,0 +1,166 @@
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.predictions import ParquetOutputConfig, build_output_formatter
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
RawPrediction,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_formatter(sample_targets: TargetProtocol):
|
||||
return build_output_formatter(
|
||||
config=ParquetOutputConfig(),
|
||||
targets=sample_targets,
|
||||
)
|
||||
|
||||
|
||||
def test_roundtrip(
|
||||
sample_formatter,
|
||||
clip: data.Clip,
|
||||
sample_targets: TargetProtocol,
|
||||
tmp_path: Path,
|
||||
):
|
||||
detections = [
|
||||
RawPrediction(
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=list(np.random.uniform(size=[4]))
|
||||
),
|
||||
detection_score=0.5,
|
||||
class_scores=np.random.uniform(
|
||||
size=len(sample_targets.class_names)
|
||||
),
|
||||
features=np.random.uniform(size=32),
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
|
||||
|
||||
path = tmp_path / "predictions.parquet"
|
||||
|
||||
sample_formatter.save(predictions=[prediction], path=path)
|
||||
|
||||
assert path.exists()
|
||||
|
||||
recovered = sample_formatter.load(path=path)
|
||||
|
||||
assert len(recovered) == 1
|
||||
assert recovered[0].clip == prediction.clip
|
||||
|
||||
for recovered_prediction, detection in zip(
|
||||
recovered[0].predictions, detections
|
||||
):
|
||||
assert (
|
||||
recovered_prediction.detection_score == detection.detection_score
|
||||
)
|
||||
# Note: floating point comparison might need tolerance, but parquet should preserve float64
|
||||
assert np.allclose(
|
||||
recovered_prediction.class_scores, detection.class_scores
|
||||
)
|
||||
assert np.allclose(recovered_prediction.features, detection.features)
|
||||
assert recovered_prediction.geometry == detection.geometry
|
||||
|
||||
|
||||
def test_multiple_clips(
|
||||
sample_formatter,
|
||||
clip: data.Clip,
|
||||
sample_targets: TargetProtocol,
|
||||
tmp_path: Path,
|
||||
):
|
||||
# Create a second clip
|
||||
clip2 = clip.model_copy(update={"uuid": uuid4()})
|
||||
|
||||
detections1 = [
|
||||
RawPrediction(
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=list(np.random.uniform(size=[4]))
|
||||
),
|
||||
detection_score=0.8,
|
||||
class_scores=np.random.uniform(
|
||||
size=len(sample_targets.class_names)
|
||||
),
|
||||
features=np.random.uniform(size=32),
|
||||
)
|
||||
]
|
||||
|
||||
detections2 = [
|
||||
RawPrediction(
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=list(np.random.uniform(size=[4]))
|
||||
),
|
||||
detection_score=0.9,
|
||||
class_scores=np.random.uniform(
|
||||
size=len(sample_targets.class_names)
|
||||
),
|
||||
features=np.random.uniform(size=32),
|
||||
)
|
||||
]
|
||||
|
||||
predictions = [
|
||||
BatDetect2Prediction(clip=clip, predictions=detections1),
|
||||
BatDetect2Prediction(clip=clip2, predictions=detections2),
|
||||
]
|
||||
|
||||
path = tmp_path / "multi_predictions.parquet"
|
||||
sample_formatter.save(predictions=predictions, path=path)
|
||||
|
||||
recovered = sample_formatter.load(path=path)
|
||||
|
||||
assert len(recovered) == 2
|
||||
# Order might not be preserved if we don't sort, but implementation appends so it should be
|
||||
# However, let's sort by clip uuid to be safe if needed, or just check existence
|
||||
|
||||
recovered_uuids = {p.clip.uuid for p in recovered}
|
||||
expected_uuids = {clip.uuid, clip2.uuid}
|
||||
assert recovered_uuids == expected_uuids
|
||||
|
||||
|
||||
def test_complex_geometry(
|
||||
sample_formatter,
|
||||
clip: data.Clip,
|
||||
sample_targets: TargetProtocol,
|
||||
tmp_path: Path,
|
||||
):
|
||||
# Create a polygon geometry
|
||||
polygon = data.Polygon(
|
||||
coordinates=[[
|
||||
[0.0, 10000.0],
|
||||
[0.1, 20000.0],
|
||||
[0.2, 10000.0],
|
||||
[0.0, 10000.0],
|
||||
]]
|
||||
)
|
||||
|
||||
detections = [
|
||||
RawPrediction(
|
||||
geometry=polygon,
|
||||
detection_score=0.95,
|
||||
class_scores=np.random.uniform(
|
||||
size=len(sample_targets.class_names)
|
||||
),
|
||||
features=np.random.uniform(size=32),
|
||||
)
|
||||
]
|
||||
|
||||
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
|
||||
|
||||
path = tmp_path / "complex_geometry.parquet"
|
||||
sample_formatter.save(predictions=[prediction], path=path)
|
||||
|
||||
recovered = sample_formatter.load(path=path)
|
||||
|
||||
assert len(recovered) == 1
|
||||
assert len(recovered[0].predictions) == 1
|
||||
|
||||
recovered_pred = recovered[0].predictions[0]
|
||||
|
||||
# Check if geometry is recovered correctly as a Polygon
|
||||
assert isinstance(recovered_pred.geometry, data.Polygon)
|
||||
assert recovered_pred.geometry == polygon
|
||||
@ -1,6 +1,5 @@
|
||||
"""Test suite for model functions."""
|
||||
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
@ -12,12 +11,6 @@ from batdetect2 import api
|
||||
from batdetect2.detector import parameters
|
||||
|
||||
|
||||
def test_can_import_model_without_warnings():
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
api.load_model()
|
||||
|
||||
|
||||
@settings(deadline=None, max_examples=5)
|
||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||
def test_can_import_model_without_pickle(duration: float):
|
||||
|
||||
@ -40,12 +40,14 @@ def dummy_targets() -> TargetProtocol:
|
||||
|
||||
dimension_names = ["width", "height"]
|
||||
|
||||
generic_class_tags = [
|
||||
detection_class_tags = [
|
||||
data.Tag(
|
||||
term=data.term_from_key(key="detector"), value="batdetect2"
|
||||
)
|
||||
]
|
||||
|
||||
detection_class_name = "bat"
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation):
|
||||
return True
|
||||
|
||||
@ -80,7 +82,8 @@ def dummy_targets() -> TargetProtocol:
|
||||
]
|
||||
)
|
||||
|
||||
return DummyTargets()
|
||||
t: TargetProtocol = DummyTargets()
|
||||
return t
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -278,9 +281,9 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_basic(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_targets,
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
"""Test basic conversion, default threshold, multi-label."""
|
||||
|
||||
@ -308,7 +311,7 @@ def test_convert_raw_to_sound_event_basic(
|
||||
)
|
||||
assert feat_dict["batdetect2:f0"] == 7.0
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
generic_tags = dummy_targets.detection_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("category", "noise", 0.85),
|
||||
@ -321,7 +324,9 @@ def test_convert_raw_to_sound_event_basic(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_thresholding(
|
||||
sample_raw_predictions, sample_recording, dummy_targets
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
"""Test effect of classification threshold."""
|
||||
raw_pred = sample_raw_predictions[0]
|
||||
@ -335,7 +340,7 @@ def test_convert_raw_to_sound_event_thresholding(
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
generic_tags = dummy_targets.detection_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("category", "noise", 0.85),
|
||||
@ -347,9 +352,9 @@ def test_convert_raw_to_sound_event_thresholding(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_no_threshold(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_targets,
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
"""Test when classification_threshold is None."""
|
||||
raw_pred = sample_raw_predictions[2]
|
||||
@ -362,7 +367,7 @@ def test_convert_raw_to_sound_event_no_threshold(
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
generic_tags = dummy_targets.detection_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||
("dwc:scientificName", "Myotis", 0.05),
|
||||
@ -375,9 +380,9 @@ def test_convert_raw_to_sound_event_no_threshold(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_top_class(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_targets,
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
"""Test top_class_only=True behavior."""
|
||||
raw_pred = sample_raw_predictions[0]
|
||||
@ -390,7 +395,7 @@ def test_convert_raw_to_sound_event_top_class(
|
||||
top_class_only=True,
|
||||
)
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
generic_tags = dummy_targets.detection_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("category", "noise", 0.85),
|
||||
@ -402,9 +407,9 @@ def test_convert_raw_to_sound_event_top_class(
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_targets,
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_recording: data.Recording,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
"""Test when all class scores are below the default threshold."""
|
||||
raw_pred = sample_raw_predictions[2]
|
||||
@ -417,7 +422,7 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
generic_tags = dummy_targets.detection_class_tags
|
||||
expected_tags = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||
}
|
||||
@ -428,9 +433,9 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
|
||||
|
||||
def test_convert_raw_list_to_clip_basic(
|
||||
sample_raw_predictions,
|
||||
sample_clip,
|
||||
dummy_targets,
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_clip: data.Clip,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
"""Test converting a list of RawPredictions to a ClipPrediction."""
|
||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||
@ -459,7 +464,7 @@ def test_convert_raw_list_to_clip_basic(
|
||||
(pt.tag.term.name, pt.tag.value, pt.score)
|
||||
for pt in clip_pred.sound_events[2].tags
|
||||
}
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
generic_tags = dummy_targets.detection_class_tags
|
||||
expected_tags3 = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||
}
|
||||
@ -480,9 +485,9 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
|
||||
|
||||
|
||||
def test_convert_raw_list_to_clip_passes_args(
|
||||
sample_raw_predictions,
|
||||
sample_clip,
|
||||
dummy_targets,
|
||||
sample_raw_predictions: List[RawPrediction],
|
||||
sample_clip: data.Clip,
|
||||
dummy_targets: TargetProtocol,
|
||||
):
|
||||
"""Test that arguments like top_class_only are passed through."""
|
||||
|
||||
@ -500,7 +505,7 @@ def test_convert_raw_list_to_clip_passes_args(
|
||||
(pt.tag.term.name, pt.tag.value, pt.score)
|
||||
for pt in clip_pred.sound_events[0].tags
|
||||
}
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
generic_tags = dummy_targets.detection_class_tags
|
||||
expected_tags1 = {
|
||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||
("category", "noise", 0.85),
|
||||
@ -508,10 +513,10 @@ def test_convert_raw_list_to_clip_passes_args(
|
||||
assert se_pred1_tags == expected_tags1
|
||||
|
||||
|
||||
def test_get_generic_tags_basic(dummy_targets):
|
||||
def test_get_generic_tags_basic(dummy_targets: TargetProtocol):
|
||||
"""Test creation of generic tags with score."""
|
||||
detection_score = 0.75
|
||||
generic_tags = dummy_targets.generic_class_tags
|
||||
generic_tags = dummy_targets.detection_class_tags
|
||||
predicted_tags = get_generic_tags(
|
||||
detection_score=detection_score, generic_class_tags=generic_tags
|
||||
)
|
||||
|
||||
@ -6,7 +6,7 @@ import pytest
|
||||
import soundfile as sf
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.preprocess import audio
|
||||
from batdetect2.audio import AudioConfig
|
||||
|
||||
|
||||
def create_dummy_wave(
|
||||
@ -56,5 +56,5 @@ def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_audio_config() -> audio.AudioConfig:
|
||||
return audio.AudioConfig()
|
||||
def default_audio_config() -> AudioConfig:
|
||||
return AudioConfig()
|
||||
|
||||
@ -3,15 +3,14 @@ import pytest
|
||||
import soundfile as sf
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.preprocess import (
|
||||
PreprocessingConfig,
|
||||
build_preprocessor,
|
||||
)
|
||||
from batdetect2.preprocess.audio import build_audio_loader
|
||||
from batdetect2.preprocess.spectrogram import (
|
||||
ScaleAmplitudeConfig,
|
||||
SpectralMeanSubstractionConfig,
|
||||
SpectrogramConfig,
|
||||
)
|
||||
from batdetect2.targets.rois import (
|
||||
DEFAULT_ANCHOR,
|
||||
@ -457,7 +456,7 @@ def test_peak_energy_bbox_mapper_encode(generate_whistle):
|
||||
|
||||
# Instantiate the mapper with a preprocessor
|
||||
preprocessor = build_preprocessor(
|
||||
PreprocessingConfig.model_validate({"spectrogram": {"transforms": []}})
|
||||
PreprocessingConfig(spectrogram_transforms=[])
|
||||
)
|
||||
mapper = PeakEnergyBBoxMapper(
|
||||
preprocessor=preprocessor,
|
||||
@ -553,7 +552,7 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
|
||||
|
||||
# Instantiate the mapper.
|
||||
preprocessor = build_preprocessor(
|
||||
PreprocessingConfig(spectrogram=SpectrogramConfig(transforms=[]))
|
||||
PreprocessingConfig(spectrogram_transforms=[])
|
||||
)
|
||||
audio_loader = build_audio_loader()
|
||||
mapper = PeakEnergyBBoxMapper(
|
||||
@ -596,12 +595,10 @@ def test_build_roi_mapper_for_anchor_bbox():
|
||||
def test_build_roi_mapper_for_peak_energy_bbox():
|
||||
# Given
|
||||
preproc_config = PreprocessingConfig(
|
||||
spectrogram=SpectrogramConfig(
|
||||
transforms=[
|
||||
spectrogram_transforms=[
|
||||
ScaleAmplitudeConfig(scale="db"),
|
||||
SpectralMeanSubstractionConfig(),
|
||||
]
|
||||
),
|
||||
)
|
||||
config = PeakEnergyBBoxMapperConfig(
|
||||
loading_buffer=0.99,
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
from batdetect2.configs import load_config
|
||||
from batdetect2.train import FullTrainingConfig
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.core import load_config
|
||||
|
||||
|
||||
def test_example_config_is_valid(example_data_dir):
|
||||
conf = load_config(
|
||||
example_data_dir / "config.yaml",
|
||||
schema=FullTrainingConfig,
|
||||
schema=BatDetect2Config,
|
||||
)
|
||||
assert isinstance(conf, FullTrainingConfig)
|
||||
assert isinstance(conf, BatDetect2Config)
|
||||
|
||||
@ -40,7 +40,7 @@ def test_can_save_checkpoint(
|
||||
|
||||
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
|
||||
|
||||
output1 = module(spec1.unsqueeze(0))
|
||||
output2 = recovered(spec2.unsqueeze(0))
|
||||
output1 = module.model(wav.unsqueeze(0))
|
||||
output2 = recovered.model(wav.unsqueeze(0))
|
||||
|
||||
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
from batdetect2.utils.arrays import adjust_width, extend_width
|
||||
from batdetect2.core.arrays import adjust_width, extend_width
|
||||
|
||||
|
||||
def test_extend_width():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user