mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Compare commits
2 Commits
7336638fa9
...
a4a5a10da1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4a5a10da1 | ||
|
|
202c6cbab0 |
@ -27,8 +27,8 @@ dependencies = [
|
|||||||
"seaborn>=0.13.2",
|
"seaborn>=0.13.2",
|
||||||
"soundevent[audio,geometry,plot]>=2.9.1",
|
"soundevent[audio,geometry,plot]>=2.9.1",
|
||||||
"tensorboard>=2.16.2",
|
"tensorboard>=2.16.2",
|
||||||
"torch>=1.13.1,<2.5.0",
|
"torch>=1.13.1",
|
||||||
"torchaudio>=1.13.1,<2.5.0",
|
"torchaudio>=1.13.1",
|
||||||
"torchvision>=0.14.0",
|
"torchvision>=0.14.0",
|
||||||
"tqdm>=4.66.2",
|
"tqdm>=4.66.2",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -120,7 +120,7 @@ def get_sound_event_tags(
|
|||||||
if annotation.event:
|
if annotation.event:
|
||||||
tags.append(data.Tag(key=event_key, value=annotation.event))
|
tags.append(data.Tag(key=event_key, value=annotation.event))
|
||||||
|
|
||||||
if annotation.individual:
|
if annotation.individual is not None:
|
||||||
tags.append(
|
tags.append(
|
||||||
data.Tag(key=individual_key, value=str(annotation.individual))
|
data.Tag(key=individual_key, value=str(annotation.individual))
|
||||||
)
|
)
|
||||||
|
|||||||
@ -331,6 +331,7 @@ _scalers = {
|
|||||||
|
|
||||||
class ScaleAmplitude(torch.nn.Module):
|
class ScaleAmplitude(torch.nn.Module):
|
||||||
def __init__(self, scale: Literal["power", "db"]):
|
def __init__(self, scale: Literal["power", "db"]):
|
||||||
|
super().__init__()
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.scaler = _scalers[scale]()
|
self.scaler = _scalers[scale]()
|
||||||
|
|
||||||
|
|||||||
166
tests/test_data/test_predictions/test_parquet.py
Normal file
166
tests/test_data/test_predictions/test_parquet.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.data.predictions import ParquetOutputConfig, build_output_formatter
|
||||||
|
from batdetect2.typing import (
|
||||||
|
BatDetect2Prediction,
|
||||||
|
RawPrediction,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_formatter(sample_targets: TargetProtocol):
|
||||||
|
return build_output_formatter(
|
||||||
|
config=ParquetOutputConfig(),
|
||||||
|
targets=sample_targets,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_roundtrip(
|
||||||
|
sample_formatter,
|
||||||
|
clip: data.Clip,
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
tmp_path: Path,
|
||||||
|
):
|
||||||
|
detections = [
|
||||||
|
RawPrediction(
|
||||||
|
geometry=data.BoundingBox(
|
||||||
|
coordinates=list(np.random.uniform(size=[4]))
|
||||||
|
),
|
||||||
|
detection_score=0.5,
|
||||||
|
class_scores=np.random.uniform(
|
||||||
|
size=len(sample_targets.class_names)
|
||||||
|
),
|
||||||
|
features=np.random.uniform(size=32),
|
||||||
|
)
|
||||||
|
for _ in range(10)
|
||||||
|
]
|
||||||
|
|
||||||
|
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
|
||||||
|
|
||||||
|
path = tmp_path / "predictions.parquet"
|
||||||
|
|
||||||
|
sample_formatter.save(predictions=[prediction], path=path)
|
||||||
|
|
||||||
|
assert path.exists()
|
||||||
|
|
||||||
|
recovered = sample_formatter.load(path=path)
|
||||||
|
|
||||||
|
assert len(recovered) == 1
|
||||||
|
assert recovered[0].clip == prediction.clip
|
||||||
|
|
||||||
|
for recovered_prediction, detection in zip(
|
||||||
|
recovered[0].predictions, detections
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
recovered_prediction.detection_score == detection.detection_score
|
||||||
|
)
|
||||||
|
# Note: floating point comparison might need tolerance, but parquet should preserve float64
|
||||||
|
assert np.allclose(
|
||||||
|
recovered_prediction.class_scores, detection.class_scores
|
||||||
|
)
|
||||||
|
assert np.allclose(recovered_prediction.features, detection.features)
|
||||||
|
assert recovered_prediction.geometry == detection.geometry
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_clips(
|
||||||
|
sample_formatter,
|
||||||
|
clip: data.Clip,
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
tmp_path: Path,
|
||||||
|
):
|
||||||
|
# Create a second clip
|
||||||
|
clip2 = clip.model_copy(update={"uuid": uuid4()})
|
||||||
|
|
||||||
|
detections1 = [
|
||||||
|
RawPrediction(
|
||||||
|
geometry=data.BoundingBox(
|
||||||
|
coordinates=list(np.random.uniform(size=[4]))
|
||||||
|
),
|
||||||
|
detection_score=0.8,
|
||||||
|
class_scores=np.random.uniform(
|
||||||
|
size=len(sample_targets.class_names)
|
||||||
|
),
|
||||||
|
features=np.random.uniform(size=32),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
detections2 = [
|
||||||
|
RawPrediction(
|
||||||
|
geometry=data.BoundingBox(
|
||||||
|
coordinates=list(np.random.uniform(size=[4]))
|
||||||
|
),
|
||||||
|
detection_score=0.9,
|
||||||
|
class_scores=np.random.uniform(
|
||||||
|
size=len(sample_targets.class_names)
|
||||||
|
),
|
||||||
|
features=np.random.uniform(size=32),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
predictions = [
|
||||||
|
BatDetect2Prediction(clip=clip, predictions=detections1),
|
||||||
|
BatDetect2Prediction(clip=clip2, predictions=detections2),
|
||||||
|
]
|
||||||
|
|
||||||
|
path = tmp_path / "multi_predictions.parquet"
|
||||||
|
sample_formatter.save(predictions=predictions, path=path)
|
||||||
|
|
||||||
|
recovered = sample_formatter.load(path=path)
|
||||||
|
|
||||||
|
assert len(recovered) == 2
|
||||||
|
# Order might not be preserved if we don't sort, but implementation appends so it should be
|
||||||
|
# However, let's sort by clip uuid to be safe if needed, or just check existence
|
||||||
|
|
||||||
|
recovered_uuids = {p.clip.uuid for p in recovered}
|
||||||
|
expected_uuids = {clip.uuid, clip2.uuid}
|
||||||
|
assert recovered_uuids == expected_uuids
|
||||||
|
|
||||||
|
|
||||||
|
def test_complex_geometry(
|
||||||
|
sample_formatter,
|
||||||
|
clip: data.Clip,
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
tmp_path: Path,
|
||||||
|
):
|
||||||
|
# Create a polygon geometry
|
||||||
|
polygon = data.Polygon(
|
||||||
|
coordinates=[[
|
||||||
|
[0.0, 10000.0],
|
||||||
|
[0.1, 20000.0],
|
||||||
|
[0.2, 10000.0],
|
||||||
|
[0.0, 10000.0],
|
||||||
|
]]
|
||||||
|
)
|
||||||
|
|
||||||
|
detections = [
|
||||||
|
RawPrediction(
|
||||||
|
geometry=polygon,
|
||||||
|
detection_score=0.95,
|
||||||
|
class_scores=np.random.uniform(
|
||||||
|
size=len(sample_targets.class_names)
|
||||||
|
),
|
||||||
|
features=np.random.uniform(size=32),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
|
||||||
|
|
||||||
|
path = tmp_path / "complex_geometry.parquet"
|
||||||
|
sample_formatter.save(predictions=[prediction], path=path)
|
||||||
|
|
||||||
|
recovered = sample_formatter.load(path=path)
|
||||||
|
|
||||||
|
assert len(recovered) == 1
|
||||||
|
assert len(recovered[0].predictions) == 1
|
||||||
|
|
||||||
|
recovered_pred = recovered[0].predictions[0]
|
||||||
|
|
||||||
|
# Check if geometry is recovered correctly as a Polygon
|
||||||
|
assert isinstance(recovered_pred.geometry, data.Polygon)
|
||||||
|
assert recovered_pred.geometry == polygon
|
||||||
@ -1,6 +1,5 @@
|
|||||||
"""Test suite for model functions."""
|
"""Test suite for model functions."""
|
||||||
|
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@ -12,12 +11,6 @@ from batdetect2 import api
|
|||||||
from batdetect2.detector import parameters
|
from batdetect2.detector import parameters
|
||||||
|
|
||||||
|
|
||||||
def test_can_import_model_without_warnings():
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("error")
|
|
||||||
api.load_model()
|
|
||||||
|
|
||||||
|
|
||||||
@settings(deadline=None, max_examples=5)
|
@settings(deadline=None, max_examples=5)
|
||||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||||
def test_can_import_model_without_pickle(duration: float):
|
def test_can_import_model_without_pickle(duration: float):
|
||||||
|
|||||||
@ -40,12 +40,14 @@ def dummy_targets() -> TargetProtocol:
|
|||||||
|
|
||||||
dimension_names = ["width", "height"]
|
dimension_names = ["width", "height"]
|
||||||
|
|
||||||
generic_class_tags = [
|
detection_class_tags = [
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=data.term_from_key(key="detector"), value="batdetect2"
|
term=data.term_from_key(key="detector"), value="batdetect2"
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
detection_class_name = "bat"
|
||||||
|
|
||||||
def filter(self, sound_event: data.SoundEventAnnotation):
|
def filter(self, sound_event: data.SoundEventAnnotation):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -80,7 +82,8 @@ def dummy_targets() -> TargetProtocol:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
return DummyTargets()
|
t: TargetProtocol = DummyTargets()
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -278,9 +281,9 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_basic(
|
def test_convert_raw_to_sound_event_basic(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions: List[RawPrediction],
|
||||||
sample_recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
"""Test basic conversion, default threshold, multi-label."""
|
"""Test basic conversion, default threshold, multi-label."""
|
||||||
|
|
||||||
@ -308,7 +311,7 @@ def test_convert_raw_to_sound_event_basic(
|
|||||||
)
|
)
|
||||||
assert feat_dict["batdetect2:f0"] == 7.0
|
assert feat_dict["batdetect2:f0"] == 7.0
|
||||||
|
|
||||||
generic_tags = dummy_targets.generic_class_tags
|
generic_tags = dummy_targets.detection_class_tags
|
||||||
expected_tags = {
|
expected_tags = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||||
("category", "noise", 0.85),
|
("category", "noise", 0.85),
|
||||||
@ -321,7 +324,9 @@ def test_convert_raw_to_sound_event_basic(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_thresholding(
|
def test_convert_raw_to_sound_event_thresholding(
|
||||||
sample_raw_predictions, sample_recording, dummy_targets
|
sample_raw_predictions: List[RawPrediction],
|
||||||
|
sample_recording: data.Recording,
|
||||||
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
"""Test effect of classification threshold."""
|
"""Test effect of classification threshold."""
|
||||||
raw_pred = sample_raw_predictions[0]
|
raw_pred = sample_raw_predictions[0]
|
||||||
@ -335,7 +340,7 @@ def test_convert_raw_to_sound_event_thresholding(
|
|||||||
top_class_only=False,
|
top_class_only=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
generic_tags = dummy_targets.generic_class_tags
|
generic_tags = dummy_targets.detection_class_tags
|
||||||
expected_tags = {
|
expected_tags = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||||
("category", "noise", 0.85),
|
("category", "noise", 0.85),
|
||||||
@ -347,9 +352,9 @@ def test_convert_raw_to_sound_event_thresholding(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_no_threshold(
|
def test_convert_raw_to_sound_event_no_threshold(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions: List[RawPrediction],
|
||||||
sample_recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
"""Test when classification_threshold is None."""
|
"""Test when classification_threshold is None."""
|
||||||
raw_pred = sample_raw_predictions[2]
|
raw_pred = sample_raw_predictions[2]
|
||||||
@ -362,7 +367,7 @@ def test_convert_raw_to_sound_event_no_threshold(
|
|||||||
top_class_only=False,
|
top_class_only=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
generic_tags = dummy_targets.generic_class_tags
|
generic_tags = dummy_targets.detection_class_tags
|
||||||
expected_tags = {
|
expected_tags = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||||
("dwc:scientificName", "Myotis", 0.05),
|
("dwc:scientificName", "Myotis", 0.05),
|
||||||
@ -375,9 +380,9 @@ def test_convert_raw_to_sound_event_no_threshold(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_top_class(
|
def test_convert_raw_to_sound_event_top_class(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions: List[RawPrediction],
|
||||||
sample_recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
"""Test top_class_only=True behavior."""
|
"""Test top_class_only=True behavior."""
|
||||||
raw_pred = sample_raw_predictions[0]
|
raw_pred = sample_raw_predictions[0]
|
||||||
@ -390,7 +395,7 @@ def test_convert_raw_to_sound_event_top_class(
|
|||||||
top_class_only=True,
|
top_class_only=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
generic_tags = dummy_targets.generic_class_tags
|
generic_tags = dummy_targets.detection_class_tags
|
||||||
expected_tags = {
|
expected_tags = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||||
("category", "noise", 0.85),
|
("category", "noise", 0.85),
|
||||||
@ -402,9 +407,9 @@ def test_convert_raw_to_sound_event_top_class(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_all_below_threshold(
|
def test_convert_raw_to_sound_event_all_below_threshold(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions: List[RawPrediction],
|
||||||
sample_recording,
|
sample_recording: data.Recording,
|
||||||
dummy_targets,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
"""Test when all class scores are below the default threshold."""
|
"""Test when all class scores are below the default threshold."""
|
||||||
raw_pred = sample_raw_predictions[2]
|
raw_pred = sample_raw_predictions[2]
|
||||||
@ -417,7 +422,7 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
|||||||
top_class_only=False,
|
top_class_only=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
generic_tags = dummy_targets.generic_class_tags
|
generic_tags = dummy_targets.detection_class_tags
|
||||||
expected_tags = {
|
expected_tags = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||||
}
|
}
|
||||||
@ -428,9 +433,9 @@ def test_convert_raw_to_sound_event_all_below_threshold(
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_list_to_clip_basic(
|
def test_convert_raw_list_to_clip_basic(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions: List[RawPrediction],
|
||||||
sample_clip,
|
sample_clip: data.Clip,
|
||||||
dummy_targets,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
"""Test converting a list of RawPredictions to a ClipPrediction."""
|
"""Test converting a list of RawPredictions to a ClipPrediction."""
|
||||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||||
@ -459,7 +464,7 @@ def test_convert_raw_list_to_clip_basic(
|
|||||||
(pt.tag.term.name, pt.tag.value, pt.score)
|
(pt.tag.term.name, pt.tag.value, pt.score)
|
||||||
for pt in clip_pred.sound_events[2].tags
|
for pt in clip_pred.sound_events[2].tags
|
||||||
}
|
}
|
||||||
generic_tags = dummy_targets.generic_class_tags
|
generic_tags = dummy_targets.detection_class_tags
|
||||||
expected_tags3 = {
|
expected_tags3 = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.15),
|
||||||
}
|
}
|
||||||
@ -480,9 +485,9 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
|
|||||||
|
|
||||||
|
|
||||||
def test_convert_raw_list_to_clip_passes_args(
|
def test_convert_raw_list_to_clip_passes_args(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions: List[RawPrediction],
|
||||||
sample_clip,
|
sample_clip: data.Clip,
|
||||||
dummy_targets,
|
dummy_targets: TargetProtocol,
|
||||||
):
|
):
|
||||||
"""Test that arguments like top_class_only are passed through."""
|
"""Test that arguments like top_class_only are passed through."""
|
||||||
|
|
||||||
@ -500,7 +505,7 @@ def test_convert_raw_list_to_clip_passes_args(
|
|||||||
(pt.tag.term.name, pt.tag.value, pt.score)
|
(pt.tag.term.name, pt.tag.value, pt.score)
|
||||||
for pt in clip_pred.sound_events[0].tags
|
for pt in clip_pred.sound_events[0].tags
|
||||||
}
|
}
|
||||||
generic_tags = dummy_targets.generic_class_tags
|
generic_tags = dummy_targets.detection_class_tags
|
||||||
expected_tags1 = {
|
expected_tags1 = {
|
||||||
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
(generic_tags[0].term.name, generic_tags[0].value, 0.9),
|
||||||
("category", "noise", 0.85),
|
("category", "noise", 0.85),
|
||||||
@ -508,10 +513,10 @@ def test_convert_raw_list_to_clip_passes_args(
|
|||||||
assert se_pred1_tags == expected_tags1
|
assert se_pred1_tags == expected_tags1
|
||||||
|
|
||||||
|
|
||||||
def test_get_generic_tags_basic(dummy_targets):
|
def test_get_generic_tags_basic(dummy_targets: TargetProtocol):
|
||||||
"""Test creation of generic tags with score."""
|
"""Test creation of generic tags with score."""
|
||||||
detection_score = 0.75
|
detection_score = 0.75
|
||||||
generic_tags = dummy_targets.generic_class_tags
|
generic_tags = dummy_targets.detection_class_tags
|
||||||
predicted_tags = get_generic_tags(
|
predicted_tags = get_generic_tags(
|
||||||
detection_score=detection_score, generic_class_tags=generic_tags
|
detection_score=detection_score, generic_class_tags=generic_tags
|
||||||
)
|
)
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import pytest
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.preprocess import audio
|
from batdetect2.audio import AudioConfig
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_wave(
|
def create_dummy_wave(
|
||||||
@ -56,5 +56,5 @@ def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def default_audio_config() -> audio.AudioConfig:
|
def default_audio_config() -> AudioConfig:
|
||||||
return audio.AudioConfig()
|
return AudioConfig()
|
||||||
|
|||||||
@ -3,15 +3,14 @@ import pytest
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.audio import build_audio_loader
|
||||||
from batdetect2.preprocess import (
|
from batdetect2.preprocess import (
|
||||||
PreprocessingConfig,
|
PreprocessingConfig,
|
||||||
build_preprocessor,
|
build_preprocessor,
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess.audio import build_audio_loader
|
|
||||||
from batdetect2.preprocess.spectrogram import (
|
from batdetect2.preprocess.spectrogram import (
|
||||||
ScaleAmplitudeConfig,
|
ScaleAmplitudeConfig,
|
||||||
SpectralMeanSubstractionConfig,
|
SpectralMeanSubstractionConfig,
|
||||||
SpectrogramConfig,
|
|
||||||
)
|
)
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
DEFAULT_ANCHOR,
|
DEFAULT_ANCHOR,
|
||||||
@ -457,7 +456,7 @@ def test_peak_energy_bbox_mapper_encode(generate_whistle):
|
|||||||
|
|
||||||
# Instantiate the mapper with a preprocessor
|
# Instantiate the mapper with a preprocessor
|
||||||
preprocessor = build_preprocessor(
|
preprocessor = build_preprocessor(
|
||||||
PreprocessingConfig.model_validate({"spectrogram": {"transforms": []}})
|
PreprocessingConfig(spectrogram_transforms=[])
|
||||||
)
|
)
|
||||||
mapper = PeakEnergyBBoxMapper(
|
mapper = PeakEnergyBBoxMapper(
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
@ -553,7 +552,7 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
|
|||||||
|
|
||||||
# Instantiate the mapper.
|
# Instantiate the mapper.
|
||||||
preprocessor = build_preprocessor(
|
preprocessor = build_preprocessor(
|
||||||
PreprocessingConfig(spectrogram=SpectrogramConfig(transforms=[]))
|
PreprocessingConfig(spectrogram_transforms=[])
|
||||||
)
|
)
|
||||||
audio_loader = build_audio_loader()
|
audio_loader = build_audio_loader()
|
||||||
mapper = PeakEnergyBBoxMapper(
|
mapper = PeakEnergyBBoxMapper(
|
||||||
@ -596,12 +595,10 @@ def test_build_roi_mapper_for_anchor_bbox():
|
|||||||
def test_build_roi_mapper_for_peak_energy_bbox():
|
def test_build_roi_mapper_for_peak_energy_bbox():
|
||||||
# Given
|
# Given
|
||||||
preproc_config = PreprocessingConfig(
|
preproc_config = PreprocessingConfig(
|
||||||
spectrogram=SpectrogramConfig(
|
spectrogram_transforms=[
|
||||||
transforms=[
|
ScaleAmplitudeConfig(scale="db"),
|
||||||
ScaleAmplitudeConfig(scale="db"),
|
SpectralMeanSubstractionConfig(),
|
||||||
SpectralMeanSubstractionConfig(),
|
]
|
||||||
]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
config = PeakEnergyBBoxMapperConfig(
|
config = PeakEnergyBBoxMapperConfig(
|
||||||
loading_buffer=0.99,
|
loading_buffer=0.99,
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
from batdetect2.configs import load_config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.train import FullTrainingConfig
|
from batdetect2.core import load_config
|
||||||
|
|
||||||
|
|
||||||
def test_example_config_is_valid(example_data_dir):
|
def test_example_config_is_valid(example_data_dir):
|
||||||
conf = load_config(
|
conf = load_config(
|
||||||
example_data_dir / "config.yaml",
|
example_data_dir / "config.yaml",
|
||||||
schema=FullTrainingConfig,
|
schema=BatDetect2Config,
|
||||||
)
|
)
|
||||||
assert isinstance(conf, FullTrainingConfig)
|
assert isinstance(conf, BatDetect2Config)
|
||||||
|
|||||||
@ -40,7 +40,7 @@ def test_can_save_checkpoint(
|
|||||||
|
|
||||||
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
|
torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
|
||||||
|
|
||||||
output1 = module(spec1.unsqueeze(0))
|
output1 = module.model(wav.unsqueeze(0))
|
||||||
output2 = recovered(spec2.unsqueeze(0))
|
output2 = recovered.model(wav.unsqueeze(0))
|
||||||
|
|
||||||
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from batdetect2.utils.arrays import adjust_width, extend_width
|
from batdetect2.core.arrays import adjust_width, extend_width
|
||||||
|
|
||||||
|
|
||||||
def test_extend_width():
|
def test_extend_width():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user