Compare commits

..

2 Commits

Author SHA1 Message Date
mbsantiago
a4a5a10da1 Remove torch version restrictions 2025-11-28 18:10:51 +00:00
mbsantiago
202c6cbab0 Fix tests 2025-11-28 18:10:37 +00:00
11 changed files with 222 additions and 60 deletions

View File

@ -27,8 +27,8 @@ dependencies = [
"seaborn>=0.13.2",
"soundevent[audio,geometry,plot]>=2.9.1",
"tensorboard>=2.16.2",
"torch>=1.13.1,<2.5.0",
"torchaudio>=1.13.1,<2.5.0",
"torch>=1.13.1",
"torchaudio>=1.13.1",
"torchvision>=0.14.0",
"tqdm>=4.66.2",
]

View File

@ -120,7 +120,7 @@ def get_sound_event_tags(
if annotation.event:
tags.append(data.Tag(key=event_key, value=annotation.event))
if annotation.individual:
if annotation.individual is not None:
tags.append(
data.Tag(key=individual_key, value=str(annotation.individual))
)

View File

@ -331,6 +331,7 @@ _scalers = {
class ScaleAmplitude(torch.nn.Module):
def __init__(self, scale: Literal["power", "db"]):
super().__init__()
self.scale = scale
self.scaler = _scalers[scale]()

View File

@ -0,0 +1,166 @@
from pathlib import Path
from uuid import uuid4
import numpy as np
import pytest
from soundevent import data
from batdetect2.data.predictions import ParquetOutputConfig, build_output_formatter
from batdetect2.typing import (
BatDetect2Prediction,
RawPrediction,
TargetProtocol,
)
@pytest.fixture
def sample_formatter(sample_targets: TargetProtocol):
    """Provide an output formatter built from a default ParquetOutputConfig.

    The formatter is constructed against the ``sample_targets`` fixture.
    """
    parquet_config = ParquetOutputConfig()
    formatter = build_output_formatter(
        config=parquet_config,
        targets=sample_targets,
    )
    return formatter
def test_roundtrip(
    sample_formatter,
    clip: data.Clip,
    sample_targets: TargetProtocol,
    tmp_path: Path,
):
    """Check that predictions saved by the formatter load back unchanged.

    Ten detections with random geometry, class scores and features are
    attached to a single clip, written to a parquet file under ``tmp_path``
    and read back. The recovered clip and every recovered detection must
    match what was written.
    """
    detections = [
        RawPrediction(
            geometry=data.BoundingBox(
                coordinates=list(np.random.uniform(size=[4]))
            ),
            detection_score=0.5,
            class_scores=np.random.uniform(
                size=len(sample_targets.class_names)
            ),
            features=np.random.uniform(size=32),
        )
        for _ in range(10)
    ]

    prediction = BatDetect2Prediction(clip=clip, predictions=detections)

    path = tmp_path / "predictions.parquet"

    sample_formatter.save(predictions=[prediction], path=path)

    assert path.exists()

    recovered = sample_formatter.load(path=path)

    assert len(recovered) == 1
    assert recovered[0].clip == prediction.clip

    # zip() stops at the shorter input, so without this check a loader that
    # dropped detections would still pass the pairwise comparison below.
    assert len(recovered[0].predictions) == len(detections)

    for recovered_prediction, detection in zip(
        recovered[0].predictions, detections
    ):
        assert (
            recovered_prediction.detection_score == detection.detection_score
        )
        # Array-valued fields are compared with a tolerance rather than
        # requiring bit-exact equality.
        assert np.allclose(
            recovered_prediction.class_scores, detection.class_scores
        )
        assert np.allclose(recovered_prediction.features, detection.features)
        assert recovered_prediction.geometry == detection.geometry
def test_multiple_clips(
    sample_formatter,
    clip: data.Clip,
    sample_targets: TargetProtocol,
    tmp_path: Path,
):
    """Check that predictions for two distinct clips both survive a save/load.

    The second clip is a copy of the ``clip`` fixture with a fresh UUID. Each
    clip carries a single random detection. After saving and loading, exactly
    two predictions must come back and their clip UUIDs must match the ones
    that were written.
    """

    def single_detection(score: float):
        # One-element detection list with random geometry, scores, features.
        return [
            RawPrediction(
                geometry=data.BoundingBox(
                    coordinates=list(np.random.uniform(size=[4]))
                ),
                detection_score=score,
                class_scores=np.random.uniform(
                    size=len(sample_targets.class_names)
                ),
                features=np.random.uniform(size=32),
            )
        ]

    # Same clip contents, different identity.
    clip2 = clip.model_copy(update={"uuid": uuid4()})

    detections1 = single_detection(0.8)
    detections2 = single_detection(0.9)

    predictions = [
        BatDetect2Prediction(clip=clip, predictions=detections1),
        BatDetect2Prediction(clip=clip2, predictions=detections2),
    ]

    path = tmp_path / "multi_predictions.parquet"
    sample_formatter.save(predictions=predictions, path=path)

    recovered = sample_formatter.load(path=path)
    assert len(recovered) == 2

    # Compare as sets so the test does not depend on the order in which the
    # formatter returns the clips.
    assert {p.clip.uuid for p in recovered} == {clip.uuid, clip2.uuid}
def test_complex_geometry(
    sample_formatter,
    clip: data.Clip,
    sample_targets: TargetProtocol,
    tmp_path: Path,
):
    """Check that a polygon geometry survives a save/load cycle.

    A single detection whose geometry is a closed, triangular ``data.Polygon``
    is written to parquet and read back. The recovered geometry must still be
    a ``data.Polygon`` and compare equal to the original.
    """
    # Closed ring: the last vertex repeats the first.
    ring = [
        [0.0, 10000.0],
        [0.1, 20000.0],
        [0.2, 10000.0],
        [0.0, 10000.0],
    ]
    polygon = data.Polygon(coordinates=[ring])

    class_scores = np.random.uniform(size=len(sample_targets.class_names))
    features = np.random.uniform(size=32)
    detection = RawPrediction(
        geometry=polygon,
        detection_score=0.95,
        class_scores=class_scores,
        features=features,
    )

    prediction = BatDetect2Prediction(clip=clip, predictions=[detection])

    path = tmp_path / "complex_geometry.parquet"
    sample_formatter.save(predictions=[prediction], path=path)
    recovered = sample_formatter.load(path=path)

    assert len(recovered) == 1
    recovered_predictions = recovered[0].predictions
    assert len(recovered_predictions) == 1

    # The geometry must come back as the same type, not just equal values.
    recovered_geometry = recovered_predictions[0].geometry
    assert isinstance(recovered_geometry, data.Polygon)
    assert recovered_geometry == polygon

View File

@ -1,6 +1,5 @@
"""Test suite for model functions."""
import warnings
from pathlib import Path
from typing import List
@ -12,12 +11,6 @@ from batdetect2 import api
from batdetect2.detector import parameters
def test_can_import_model_without_warnings():
with warnings.catch_warnings():
warnings.simplefilter("error")
api.load_model()
@settings(deadline=None, max_examples=5)
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_can_import_model_without_pickle(duration: float):

View File

@ -40,12 +40,14 @@ def dummy_targets() -> TargetProtocol:
dimension_names = ["width", "height"]
generic_class_tags = [
detection_class_tags = [
data.Tag(
term=data.term_from_key(key="detector"), value="batdetect2"
)
]
detection_class_name = "bat"
def filter(self, sound_event: data.SoundEventAnnotation):
return True
@ -80,7 +82,8 @@ def dummy_targets() -> TargetProtocol:
]
)
return DummyTargets()
t: TargetProtocol = DummyTargets()
return t
@pytest.fixture
@ -278,9 +281,9 @@ def sample_raw_predictions() -> List[RawPrediction]:
def test_convert_raw_to_sound_event_basic(
sample_raw_predictions,
sample_recording,
dummy_targets,
sample_raw_predictions: List[RawPrediction],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
"""Test basic conversion, default threshold, multi-label."""
@ -308,7 +311,7 @@ def test_convert_raw_to_sound_event_basic(
)
assert feat_dict["batdetect2:f0"] == 7.0
generic_tags = dummy_targets.generic_class_tags
generic_tags = dummy_targets.detection_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
("category", "noise", 0.85),
@ -321,7 +324,9 @@ def test_convert_raw_to_sound_event_basic(
def test_convert_raw_to_sound_event_thresholding(
sample_raw_predictions, sample_recording, dummy_targets
sample_raw_predictions: List[RawPrediction],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
"""Test effect of classification threshold."""
raw_pred = sample_raw_predictions[0]
@ -335,7 +340,7 @@ def test_convert_raw_to_sound_event_thresholding(
top_class_only=False,
)
generic_tags = dummy_targets.generic_class_tags
generic_tags = dummy_targets.detection_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
("category", "noise", 0.85),
@ -347,9 +352,9 @@ def test_convert_raw_to_sound_event_thresholding(
def test_convert_raw_to_sound_event_no_threshold(
sample_raw_predictions,
sample_recording,
dummy_targets,
sample_raw_predictions: List[RawPrediction],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
"""Test when classification_threshold is None."""
raw_pred = sample_raw_predictions[2]
@ -362,7 +367,7 @@ def test_convert_raw_to_sound_event_no_threshold(
top_class_only=False,
)
generic_tags = dummy_targets.generic_class_tags
generic_tags = dummy_targets.detection_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
("dwc:scientificName", "Myotis", 0.05),
@ -375,9 +380,9 @@ def test_convert_raw_to_sound_event_no_threshold(
def test_convert_raw_to_sound_event_top_class(
sample_raw_predictions,
sample_recording,
dummy_targets,
sample_raw_predictions: List[RawPrediction],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
"""Test top_class_only=True behavior."""
raw_pred = sample_raw_predictions[0]
@ -390,7 +395,7 @@ def test_convert_raw_to_sound_event_top_class(
top_class_only=True,
)
generic_tags = dummy_targets.generic_class_tags
generic_tags = dummy_targets.detection_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
("category", "noise", 0.85),
@ -402,9 +407,9 @@ def test_convert_raw_to_sound_event_top_class(
def test_convert_raw_to_sound_event_all_below_threshold(
sample_raw_predictions,
sample_recording,
dummy_targets,
sample_raw_predictions: List[RawPrediction],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
"""Test when all class scores are below the default threshold."""
raw_pred = sample_raw_predictions[2]
@ -417,7 +422,7 @@ def test_convert_raw_to_sound_event_all_below_threshold(
top_class_only=False,
)
generic_tags = dummy_targets.generic_class_tags
generic_tags = dummy_targets.detection_class_tags
expected_tags = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
}
@ -428,9 +433,9 @@ def test_convert_raw_to_sound_event_all_below_threshold(
def test_convert_raw_list_to_clip_basic(
sample_raw_predictions,
sample_clip,
dummy_targets,
sample_raw_predictions: List[RawPrediction],
sample_clip: data.Clip,
dummy_targets: TargetProtocol,
):
"""Test converting a list of RawPredictions to a ClipPrediction."""
clip_pred = convert_raw_predictions_to_clip_prediction(
@ -459,7 +464,7 @@ def test_convert_raw_list_to_clip_basic(
(pt.tag.term.name, pt.tag.value, pt.score)
for pt in clip_pred.sound_events[2].tags
}
generic_tags = dummy_targets.generic_class_tags
generic_tags = dummy_targets.detection_class_tags
expected_tags3 = {
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
}
@ -480,9 +485,9 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
def test_convert_raw_list_to_clip_passes_args(
sample_raw_predictions,
sample_clip,
dummy_targets,
sample_raw_predictions: List[RawPrediction],
sample_clip: data.Clip,
dummy_targets: TargetProtocol,
):
"""Test that arguments like top_class_only are passed through."""
@ -500,7 +505,7 @@ def test_convert_raw_list_to_clip_passes_args(
(pt.tag.term.name, pt.tag.value, pt.score)
for pt in clip_pred.sound_events[0].tags
}
generic_tags = dummy_targets.generic_class_tags
generic_tags = dummy_targets.detection_class_tags
expected_tags1 = {
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
("category", "noise", 0.85),
@ -508,10 +513,10 @@ def test_convert_raw_list_to_clip_passes_args(
assert se_pred1_tags == expected_tags1
def test_get_generic_tags_basic(dummy_targets):
def test_get_generic_tags_basic(dummy_targets: TargetProtocol):
"""Test creation of generic tags with score."""
detection_score = 0.75
generic_tags = dummy_targets.generic_class_tags
generic_tags = dummy_targets.detection_class_tags
predicted_tags = get_generic_tags(
detection_score=detection_score, generic_class_tags=generic_tags
)

View File

@ -6,7 +6,7 @@ import pytest
import soundfile as sf
from soundevent import data
from batdetect2.preprocess import audio
from batdetect2.audio import AudioConfig
def create_dummy_wave(
@ -56,5 +56,5 @@ def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
@pytest.fixture
def default_audio_config() -> audio.AudioConfig:
return audio.AudioConfig()
def default_audio_config() -> AudioConfig:
return AudioConfig()

View File

@ -3,15 +3,14 @@ import pytest
import soundfile as sf
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.preprocess import (
PreprocessingConfig,
build_preprocessor,
)
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.preprocess.spectrogram import (
ScaleAmplitudeConfig,
SpectralMeanSubstractionConfig,
SpectrogramConfig,
)
from batdetect2.targets.rois import (
DEFAULT_ANCHOR,
@ -457,7 +456,7 @@ def test_peak_energy_bbox_mapper_encode(generate_whistle):
# Instantiate the mapper with a preprocessor
preprocessor = build_preprocessor(
PreprocessingConfig.model_validate({"spectrogram": {"transforms": []}})
PreprocessingConfig(spectrogram_transforms=[])
)
mapper = PeakEnergyBBoxMapper(
preprocessor=preprocessor,
@ -553,7 +552,7 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
# Instantiate the mapper.
preprocessor = build_preprocessor(
PreprocessingConfig(spectrogram=SpectrogramConfig(transforms=[]))
PreprocessingConfig(spectrogram_transforms=[])
)
audio_loader = build_audio_loader()
mapper = PeakEnergyBBoxMapper(
@ -596,12 +595,10 @@ def test_build_roi_mapper_for_anchor_bbox():
def test_build_roi_mapper_for_peak_energy_bbox():
# Given
preproc_config = PreprocessingConfig(
spectrogram=SpectrogramConfig(
transforms=[
spectrogram_transforms=[
ScaleAmplitudeConfig(scale="db"),
SpectralMeanSubstractionConfig(),
]
),
)
config = PeakEnergyBBoxMapperConfig(
loading_buffer=0.99,

View File

@ -1,10 +1,10 @@
from batdetect2.configs import load_config
from batdetect2.train import FullTrainingConfig
from batdetect2.config import BatDetect2Config
from batdetect2.core import load_config
def test_example_config_is_valid(example_data_dir):
conf = load_config(
example_data_dir / "config.yaml",
schema=FullTrainingConfig,
schema=BatDetect2Config,
)
assert isinstance(conf, FullTrainingConfig)
assert isinstance(conf, BatDetect2Config)

View File

@ -40,7 +40,7 @@ def test_can_save_checkpoint(
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
output1 = module(spec1.unsqueeze(0))
output2 = recovered(spec2.unsqueeze(0))
output1 = module.model(wav.unsqueeze(0))
output2 = recovered.model(wav.unsqueeze(0))
torch.testing.assert_close(output1, output2, rtol=0, atol=0)

View File

@ -1,6 +1,6 @@
import torch
from batdetect2.utils.arrays import adjust_width, extend_width
from batdetect2.core.arrays import adjust_width, extend_width
def test_extend_width():