Mirror of https://github.com/macaodha/batdetect2.git, synced 2026-01-10 17:19:34 +01:00.

Commit c80078feee (parent 0bb0caddea): Removing stale tests
@@ -24,6 +24,9 @@ targets:
      - name: rhifer
        tags:
          - value: Rhinolophus ferrumequinum
+       roi:
+         name: anchor_bbox
+         anchor: top-left
    generic_class:
      - key: class
        value: Bat
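
Note: the added roi block gives the rhifer target an explicit region-of-interest mapper. Judging by the names, anchor_bbox with anchor: top-left reduces each annotated bounding box to a single anchor point at its top-left corner. A minimal sketch of that idea, where the function and the coordinate convention (time running left to right, frequency increasing upward) are assumptions, not batdetect2's actual implementation:

    # Hypothetical sketch of an anchor-style ROI mapper: collapse a bounding
    # box to one anchor point. Not batdetect2's actual code.
    def anchor_bbox(start_time, low_freq, end_time, high_freq, anchor="top-left"):
        if anchor == "top-left":
            # Earliest time, highest frequency of the box.
            return start_time, high_freq
        if anchor == "bottom-left":
            return start_time, low_freq
        raise ValueError(f"unknown anchor: {anchor}")

    print(anchor_bbox(0.10, 80_000.0, 0.12, 110_000.0))  # (0.1, 110000.0)
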
@@ -89,7 +89,7 @@ def build_spectrogram_builder(
        n_fft=n_fft,
        hop_length=hop_length,
        window_fn=get_spectrogram_window(conf.window_fn),
-       center=False,
+       center=True,
        power=1,
    )

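
Note: center here maps to the STFT centering flag (as in torchaudio's Spectrogram transform): with center=True the waveform is padded by n_fft // 2 on each side so every frame is centered on a hop boundary, which also changes the number of output frames. A quick sketch of the frame-count arithmetic, using illustrative values:

    # STFT frame counts with and without centering (illustrative values only).
    n_samples = 131_072  # e.g. 0.512 s of audio at 256 kHz
    n_fft = 512
    hop_length = 256

    # center=True: padded by n_fft // 2 on both sides, so the frame count
    # depends only on the hop length.
    frames_centered = 1 + n_samples // hop_length  # 513
    # center=False: every frame must fit inside the unpadded signal.
    frames_uncentered = 1 + (n_samples - n_fft) // hop_length  # 512
    print(frames_centered, frames_uncentered)
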
@@ -1,10 +1,8 @@
 from collections.abc import Callable
 
-import numpy as np
 import pytest
 import torch
-import xarray as xr
-from soundevent import arrays, data
+from soundevent import data
 
 from batdetect2.train.augmentations import (
     add_echo,
@@ -162,29 +160,10 @@ def test_selected_random_subclip_has_the_correct_width(
         labeller=sample_labeller,
     )
 
-    subclip = select_subclip(original, start=0, span=0.513)
-    assert subclip["spectrogram"].shape[1] == 512
-
-
-def test_add_echo_after_subclip(
-    sample_preprocessor: PreprocessorProtocol,
-    sample_audio_loader: AudioLoader,
-    sample_labeller: ClipLabeller,
-    create_recording: Callable[..., data.Recording],
-):
-    recording1 = create_recording(duration=2)
-    clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
-    clip_annotation_1 = data.ClipAnnotation(clip=clip1)
-    original = generate_train_example(
-        clip_annotation_1,
-        audio_loader=sample_audio_loader,
-        preprocessor=sample_preprocessor,
-        labeller=sample_labeller,
-    )
-
-    assert original.sizes["time"] > 512
-
-    subclip = select_subclip(original, start=0, span=0.513)
-    with_echo = add_echo(subclip, preprocessor=sample_preprocessor)
-
-    assert with_echo.sizes["time"] == 512
+    subclip = select_subclip(
+        original,
+        samplerate=256_000,
+        start=0,
+        duration=0.512,
+    )
+    assert subclip.spectrogram.shape[1] == 512
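
Note: the rewritten test asserts 512 time bins for a 0.512 s subclip at 256 kHz, which is consistent with a hop length of 256 samples; the hop value is an assumption here, not something the diff states:

    # Expected subclip width under an assumed hop of 256 samples.
    samplerate = 256_000  # Hz, from the test
    duration = 0.512  # seconds, from the test
    hop_length = 256  # assumption, not stated in the diff

    n_samples = int(duration * samplerate)  # 131072
    assert n_samples // hop_length == 512
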
@@ -1,323 +1,3 @@
 import numpy as np
-import pytest
-import xarray as xr
 
-from batdetect2.train.clips import (
-    _compute_expected_width,
-    select_subclip,
-)
+from batdetect2.train.clips import select_subclip
-
-AUDIO_SAMPLERATE = 48000
-
-SPEC_SAMPLERATE = 100
-SPEC_FREQS = 64
-CLIP_DURATION = 0.5
-
-
-CLIP_WIDTH_SPEC = int(np.floor(CLIP_DURATION * SPEC_SAMPLERATE))
-CLIP_WIDTH_AUDIO = int(np.floor(CLIP_DURATION * AUDIO_SAMPLERATE))
-MAX_EMPTY = 0.2
-
-
-def create_test_dataset(
-    duration_sec: float,
-    spec_samplerate: int = SPEC_SAMPLERATE,
-    audio_samplerate: int = AUDIO_SAMPLERATE,
-    num_freqs: int = SPEC_FREQS,
-    start_time: float = 0.0,
-) -> xr.Dataset:
-    """Creates a sample xr.Dataset for testing."""
-    time_step = 1 / spec_samplerate
-    audio_time_step = 1 / audio_samplerate
-
-    times = np.arange(start_time, start_time + duration_sec, step=time_step)
-    freqs = np.linspace(0, audio_samplerate / 2, num_freqs)
-    audio_times = np.arange(
-        start_time,
-        start_time + duration_sec,
-        step=audio_time_step,
-    )
-
-    num_time_steps = len(times)
-    num_audio_samples = len(audio_times)
-    spec_shape = (num_freqs, num_time_steps)
-
-    spectrogram_data = np.arange(num_time_steps).reshape(1, -1) * np.ones(
-        (num_freqs, 1)
-    )
-
-    spectrogram = xr.DataArray(
-        spectrogram_data.astype(np.float32),
-        coords=[("frequency", freqs), ("time", times)],
-        name="spectrogram",
-    )
-
-    detection = xr.DataArray(
-        np.ones(spec_shape, dtype=np.float32) * 0.5,
-        coords=spectrogram.coords,
-        name="detection",
-    )
-
-    classes = xr.DataArray(
-        np.ones((3, *spec_shape), dtype=np.float32),
-        coords=[
-            ("category", ["A", "B", "C"]),
-            ("frequency", freqs),
-            ("time", times),
-        ],
-        name="class",
-    )
-
-    size = xr.DataArray(
-        np.ones((2, *spec_shape), dtype=np.float32),
-        coords=[
-            ("dimension", ["height", "width"]),
-            ("frequency", freqs),
-            ("time", times),
-        ],
-        name="size",
-    )
-
-    audio_data = np.arange(num_audio_samples)
-    audio = xr.DataArray(
-        audio_data.astype(np.float32),
-        coords=[("audio_time", audio_times)],
-        name="audio",
-    )
-
-    metadata = xr.DataArray([1, 2, 3], dims=["other_dim"], name="metadata")
-
-    return xr.Dataset(
-        {
-            "audio": audio,
-            "spectrogram": spectrogram,
-            "detection": detection,
-            "class": classes,
-            "size": size,
-            "metadata": metadata,
-        }
-    ).assign_attrs(
-        samplerate=audio_samplerate,
-        spec_samplerate=spec_samplerate,
-    )
-
-
-@pytest.fixture
-def long_dataset() -> xr.Dataset:
-    """Dataset longer than the clip duration."""
-    return create_test_dataset(duration_sec=2.0)
-
-
-@pytest.fixture
-def short_dataset() -> xr.Dataset:
-    """Dataset shorter than the clip duration."""
-    return create_test_dataset(duration_sec=0.3)
-
-
-@pytest.fixture
-def exact_dataset() -> xr.Dataset:
-    """Dataset exactly the clip duration."""
-    return create_test_dataset(duration_sec=CLIP_DURATION - 1e-9)
-
-
-@pytest.fixture
-def offset_dataset() -> xr.Dataset:
-    """Dataset starting at a non-zero time."""
-    return create_test_dataset(duration_sec=1.0, start_time=0.5)
-
-
-def test_select_subclip_within_bounds(long_dataset):
-    start_time = 0.5
-    subclip = select_subclip(
-        long_dataset, span=CLIP_DURATION, start=start_time, dim="time"
-    )
-    expected_width = _compute_expected_width(
-        long_dataset, CLIP_DURATION, "time"
-    )
-
-    assert "time" in subclip.dims
-    assert subclip.dims["time"] == expected_width
-    assert subclip.spectrogram.dims == ("frequency", "time")
-    assert subclip.spectrogram.shape == (SPEC_FREQS, expected_width)
-    assert subclip.detection.shape == (SPEC_FREQS, expected_width)
-    assert subclip["class"].shape == (3, SPEC_FREQS, expected_width)
-    assert subclip.size.shape == (2, SPEC_FREQS, expected_width)
-    assert subclip.time.min() >= start_time
-    assert (
-        subclip.time.max() <= start_time + CLIP_DURATION + 1 / SPEC_SAMPLERATE
-    )
-
-    assert "metadata" in subclip
-    xr.testing.assert_equal(subclip.metadata, long_dataset.metadata)
-
-
-def test_select_subclip_pad_start(long_dataset):
-    start_time = -0.1
-    subclip = select_subclip(
-        long_dataset, span=CLIP_DURATION, start=start_time, dim="time"
-    )
-    expected_width = _compute_expected_width(
-        long_dataset, CLIP_DURATION, "time"
-    )
-    step = 1 / SPEC_SAMPLERATE
-    expected_pad_samples = int(np.floor(abs(start_time) / step))
-
-    assert subclip.dims["time"] == expected_width
-    assert subclip.spectrogram.shape[1] == expected_width
-
-    assert np.all(
-        subclip.spectrogram.isel(time=slice(0, expected_pad_samples)) == 0
-    )
-
-    assert np.any(
-        subclip.spectrogram.isel(time=slice(expected_pad_samples, None)) != 0
-    )
-    assert subclip.time.min() >= start_time
-    assert subclip.time.max() < start_time + CLIP_DURATION + step
-
-
-def test_select_subclip_pad_end(long_dataset):
-    original_duration = long_dataset.time.max() - long_dataset.time.min()
-    start_time = original_duration - 0.1
-    subclip = select_subclip(
-        long_dataset, span=CLIP_DURATION, start=start_time, dim="time"
-    )
-    expected_width = _compute_expected_width(
-        long_dataset, CLIP_DURATION, "time"
-    )
-    step = 1 / SPEC_SAMPLERATE
-    original_width = long_dataset.dims["time"]
-    expected_pad_samples = expected_width - (
-        original_width - int(np.floor(start_time / step))
-    )
-
-    assert subclip.sizes["time"] == expected_width
-    assert subclip.spectrogram.shape[1] == expected_width
-
-    assert np.all(
-        subclip.spectrogram.isel(
-            time=slice(expected_width - expected_pad_samples, None)
-        )
-        == 0
-    )
-
-    assert np.any(
-        subclip.spectrogram.isel(
-            time=slice(0, expected_width - expected_pad_samples)
-        )
-        != 0
-    )
-    assert subclip.time.min() >= start_time
-    assert subclip.time.max() < start_time + CLIP_DURATION + step
-
-
-def test_select_subclip_pad_both_short_dataset(short_dataset):
-    start_time = -0.1
-    subclip = select_subclip(
-        short_dataset, span=CLIP_DURATION, start=start_time, dim="time"
-    )
-    expected_width = _compute_expected_width(
-        short_dataset, CLIP_DURATION, "time"
-    )
-    step = 1 / SPEC_SAMPLERATE
-
-    assert subclip.dims["time"] == expected_width
-    assert subclip.spectrogram.shape[1] == expected_width
-
-    assert subclip.spectrogram.coords["time"][0] == pytest.approx(
-        start_time,
-        abs=step,
-    )
-    assert subclip.spectrogram.coords["time"][-1] == pytest.approx(
-        start_time + CLIP_DURATION - step,
-        abs=2 * step,
-    )
-
-
-def test_select_subclip_width_consistency(long_dataset):
-    expected_width = _compute_expected_width(
-        long_dataset, CLIP_DURATION, "time"
-    )
-    step = 1 / SPEC_SAMPLERATE
-
-    subclip_aligned = select_subclip(
-        long_dataset.copy(deep=True),
-        span=CLIP_DURATION,
-        start=5 * step,
-        dim="time",
-    )
-
-    subclip_offset = select_subclip(
-        long_dataset.copy(deep=True),
-        span=CLIP_DURATION,
-        start=5.3 * step,
-        dim="time",
-    )
-
-    assert subclip_aligned.sizes["time"] == expected_width
-    assert subclip_offset.sizes["time"] == expected_width
-    assert subclip_aligned.spectrogram.shape[1] == expected_width
-    assert subclip_offset.spectrogram.shape[1] == expected_width
-
-
-def test_select_subclip_different_dimension(long_dataset):
-    freq_coords = long_dataset.frequency.values
-    freq_min, freq_max = freq_coords.min(), freq_coords.max()
-    freq_span = (freq_max - freq_min) / 2
-    start_freq = freq_min + freq_span / 2
-
-    subclip = select_subclip(
-        long_dataset, span=freq_span, start=start_freq, dim="frequency"
-    )
-
-    assert "frequency" in subclip.dims
-    assert subclip.spectrogram.shape[0] < long_dataset.spectrogram.shape[0]
-    assert subclip.detection.shape[0] < long_dataset.detection.shape[0]
-    assert subclip["class"].shape[1] < long_dataset["class"].shape[1]
-    assert subclip.size.shape[1] < long_dataset.size.shape[1]
-
-    assert subclip.dims["time"] == long_dataset.dims["time"]
-    assert subclip.spectrogram.shape[1] == long_dataset.spectrogram.shape[1]
-
-    xr.testing.assert_equal(subclip.audio, long_dataset.audio)
-    assert subclip.dims["audio_time"] == long_dataset.dims["audio_time"]
-
-
-def test_select_subclip_fill_value(short_dataset):
-    fill_value = -999.0
-    subclip = select_subclip(
-        short_dataset,
-        span=CLIP_DURATION,
-        start=0,
-        dim="time",
-        fill_value=fill_value,
-    )
-
-    expected_width = _compute_expected_width(
-        short_dataset,
-        CLIP_DURATION,
-        "time",
-    )
-
-    assert subclip.dims["time"] == expected_width
-    assert np.all(subclip.spectrogram.sel(time=slice(0.3, None)) == fill_value)
-
-
-def test_select_subclip_no_overlap_raises_error(long_dataset):
-    original_duration = long_dataset.time.max() - long_dataset.time.min()
-
-    with pytest.raises(ValueError, match="does not overlap"):
-        select_subclip(
-            long_dataset,
-            span=CLIP_DURATION,
-            start=original_duration + 1.0,
-            dim="time",
-        )
-
-    with pytest.raises(ValueError, match="does not overlap"):
-        select_subclip(
-            long_dataset,
-            span=CLIP_DURATION,
-            start=-1.0 * CLIP_DURATION - 1.0,
-            dim="time",
-        )
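
Note: the deleted module covered the previous select_subclip interface, which sliced and padded an xr.Dataset along an arbitrary dimension with a configurable fill value. The surviving call site (see the updated test_augmentations.py hunk above) uses the replacement interface, so these tests are dropped as stale rather than ported:

    # Old interface exercised by the deleted tests:
    #   select_subclip(dataset, span=0.5, start=0.0, dim="time", fill_value=0.0)
    # New interface, as called in the updated test_augmentations.py:
    #   select_subclip(example, samplerate=256_000, start=0, duration=0.512)
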
@@ -1,10 +1,9 @@
 from pathlib import Path
 
-import numpy as np
-import xarray as xr
+import torch
 from soundevent import data
 
-from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
+from batdetect2.targets import TargetConfig, build_targets
 from batdetect2.targets.rois import AnchorBBoxMapperConfig
 from batdetect2.targets.terms import TagInfo
 from batdetect2.train.labels import generate_heatmaps
@@ -21,63 +20,10 @@ recording = data.Recording(
 clip = data.Clip(
     recording=recording,
     start_time=0,
-    end_time=1,
+    end_time=100,
 )
 
 
-def test_generated_heatmaps_have_correct_dimensions(
-    sample_targets: TargetProtocol,
-):
-    spec = xr.DataArray(
-        data=np.random.rand(100, 100),
-        dims=["time", "frequency"],
-        coords={
-            "time": np.linspace(0, 100, 100, endpoint=False),
-            "frequency": np.linspace(0, 100, 100, endpoint=False),
-        },
-    )
-
-    clip_annotation = data.ClipAnnotation(
-        clip=clip,
-        sound_events=[
-            data.SoundEventAnnotation(
-                sound_event=data.SoundEvent(
-                    recording=recording,
-                    geometry=data.BoundingBox(
-                        coordinates=[10, 10, 20, 20],
-                    ),
-                ),
-            )
-        ],
-    )
-
-    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
-        clip_annotation.sound_events,
-        spec,
-        targets=sample_targets,
-    )
-
-    assert isinstance(detection_heatmap, xr.DataArray)
-    assert detection_heatmap.shape == (100, 100)
-    assert detection_heatmap.dims == ("time", "frequency")
-
-    assert isinstance(class_heatmap, xr.DataArray)
-    assert class_heatmap.shape == (2, 100, 100)
-    assert class_heatmap.dims == ("category", "time", "frequency")
-    assert class_heatmap.coords["category"].values.tolist() == [
-        "pippip",
-        "myomyo",
-    ]
-
-    assert isinstance(size_heatmap, xr.DataArray)
-    assert size_heatmap.shape == (2, 100, 100)
-    assert size_heatmap.dims == ("dimension", "time", "frequency")
-    assert size_heatmap.coords["dimension"].values.tolist() == [
-        "width",
-        "height",
-    ]
-
-
 def test_generated_heatmap_are_non_zero_at_correct_positions(
     sample_target_config: TargetConfig,
     pippip_tag: TagInfo,
@@ -93,15 +39,6 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
 
     targets = build_targets(config)
 
-    spec = xr.DataArray(
-        data=np.random.rand(100, 100),
-        dims=["time", "frequency"],
-        coords={
-            "time": np.linspace(0, 100, 100, endpoint=False),
-            "frequency": np.linspace(0, 100, 100, endpoint=False),
-        },
-    )
-
     clip_annotation = data.ClipAnnotation(
         clip=clip,
         sound_events=[
@@ -109,7 +46,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
                 sound_event=data.SoundEvent(
                     recording=recording,
                     geometry=data.BoundingBox(
-                        coordinates=[10, 10, 20, 20],
+                        coordinates=[10, 10, 20, 30],
                     ),
                 ),
                 tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)],  # type: ignore
@@ -118,12 +55,16 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
     )
 
     detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
-        clip_annotation.sound_events,
-        spec,
+        clip_annotation,
+        torch.rand([100, 100]),
+        min_freq=0,
+        max_freq=100,
         targets=targets,
     )
-    assert size_heatmap.sel(time=10, frequency=10, dimension="width") == 10
-    assert size_heatmap.sel(time=10, frequency=10, dimension="height") == 10
-    assert class_heatmap.sel(time=10, frequency=10, category="pippip") == 1.0
-    assert class_heatmap.sel(time=10, frequency=10, category="myomyo") == 0.0
-    assert detection_heatmap.sel(time=10, frequency=10) == 1.0
+    pippip_index = targets.class_names.index("pippip")
+    myomyo_index = targets.class_names.index("myomyo")
+    assert size_heatmap[0, 10, 10] == 10
+    assert size_heatmap[1, 10, 10] == 20
+    assert class_heatmap[pippip_index, 10, 10] == 1.0
+    assert class_heatmap[myomyo_index, 10, 10] == 0.0
+    assert detection_heatmap[10, 10] == 1.0
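
Note: the assertions switch from coordinate-based xarray lookups to plain tensor indexing, and the expected sizes now follow from the widened box [10, 10, 20, 30]: width 20 - 10 = 10 and height 30 - 10 = 20. With 100 bins spanning 0-100 on each axis (min_freq=0, max_freq=100), bin index 10 corresponds to coordinate 10, so both styles address the same cell. A sketch of that index arithmetic, carrying over the linspace(..., endpoint=False) bin convention from the deleted spec fixture:

    # Map a coordinate to a heatmap bin index, assuming bins are laid out as
    # np.linspace(0, 100, 100, endpoint=False) on both axes.
    def to_index(value, lo=0.0, hi=100.0, n_bins=100):
        return int((value - lo) / (hi - lo) * n_bins)

    assert to_index(10.0) == 10  # coordinate 10 lands in bin 10
    assert to_index(99.9) == 99  # last bin covers [99, 100)
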
@@ -2,11 +2,11 @@ from pathlib import Path
 
 import lightning as L
 import torch
-import xarray as xr
 from soundevent import data
 
 from batdetect2.train import FullTrainingConfig, TrainingModule
 from batdetect2.train.train import build_training_module
+from batdetect2.typing.preprocess import AudioLoader
 
 
 def build_default_module():
@@ -19,7 +19,11 @@ def test_can_initialize_default_module():
     assert isinstance(module, L.LightningModule)
 
 
-def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip):
+def test_can_save_checkpoint(
+    tmp_path: Path,
+    clip: data.Clip,
+    sample_audio_loader: AudioLoader,
+):
     module = build_default_module()
     trainer = L.Trainer()
     path = tmp_path / "example.ckpt"
@@ -28,15 +32,14 @@ def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip):
 
     recovered = TrainingModule.load_from_checkpoint(path)
 
-    spec1 = module.model.preprocessor.preprocess_clip(clip)
-    spec2 = recovered.model.preprocessor.preprocess_clip(clip)
+    wav = torch.tensor(sample_audio_loader.load_clip(clip))
 
-    xr.testing.assert_equal(spec1, spec2)
+    spec1 = module.model.preprocessor(wav)
+    spec2 = recovered.model.preprocessor(wav)
 
-    input1 = torch.tensor([spec1.values]).unsqueeze(0)
-    input2 = torch.tensor([spec2.values]).unsqueeze(0)
+    torch.testing.assert_close(spec1, spec2, rtol=0, atol=0)
 
-    output1 = module(input1)
-    output2 = recovered(input2)
+    output1 = module(spec1.unsqueeze(0).unsqueeze(0))
+    output2 = recovered(spec2.unsqueeze(0).unsqueeze(0))
 
-    torch.testing.assert_close(output1, output2)
+    torch.testing.assert_close(output1, output2, rtol=0, atol=0)
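
Note: passing rtol=0, atol=0 makes torch.testing.assert_close demand exact equality instead of its default dtype-dependent tolerances, which fits this test: the restored module should run the identical preprocessing graph on the identical waveform. A minimal illustration:

    import torch

    a = torch.tensor([1.0, 2.0])
    b = a.clone()

    # Passes: the tensors are bitwise identical.
    torch.testing.assert_close(a, b, rtol=0, atol=0)

    # Even a tiny perturbation would fail with zero tolerances:
    # torch.testing.assert_close(a, b + 1e-7, rtol=0, atol=0)
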
@@ -88,12 +88,12 @@ def test_encoding_decoding_roundtrip_recovers_object(
     )
     predictions = postprocessor.get_predictions(
         ModelOutput(
-            detection_probs=encoded["detection_heatmap"]
-            .unsqueeze(0)
-            .unsqueeze(0),
-            size_preds=encoded["size_heatmap"].unsqueeze(0),
-            class_probs=encoded["class_heatmap"].unsqueeze(0),
-            features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0),
+            detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
+                0
+            ),
+            size_preds=encoded.size_heatmap.unsqueeze(0),
+            class_probs=encoded.class_heatmap.unsqueeze(0),
+            features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
         ),
         [clip],
     )[0]
@@ -182,12 +182,12 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
     )
     predictions = postprocessor.get_predictions(
         ModelOutput(
-            detection_probs=encoded["detection_heatmap"]
-            .unsqueeze(0)
-            .unsqueeze(0),
-            size_preds=encoded["size_heatmap"].unsqueeze(0),
-            class_probs=encoded["class_heatmap"].unsqueeze(0),
-            features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0),
+            detection_probs=encoded.detection_heatmap.unsqueeze(0).unsqueeze(
+                0
+            ),
+            size_preds=encoded.size_heatmap.unsqueeze(0),
+            class_probs=encoded.class_heatmap.unsqueeze(0),
+            features=encoded.spectrogram.unsqueeze(0).unsqueeze(0),
         ),
         [clip],
     )[0]
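
Note: in both roundtrip hunks, encoded changes from dict-style lookups to attribute access on an object holding torch tensors, and the unsqueeze calls add the leading batch (and, for the 2-D heatmaps, channel) axes that ModelOutput expects. A shape sketch with illustrative sizes:

    import torch

    # Illustrative heatmap shapes; the real sizes come from the encoder.
    detection_heatmap = torch.rand(128, 512)  # two spatial axes
    size_heatmap = torch.rand(2, 128, 512)  # (dimension, *spatial)

    # 2-D heatmap -> (batch=1, channel=1, *spatial)
    detection_probs = detection_heatmap.unsqueeze(0).unsqueeze(0)
    assert detection_probs.shape == (1, 1, 128, 512)

    # 3-D heatmap -> (batch=1, dimension, *spatial)
    size_preds = size_heatmap.unsqueeze(0)
    assert size_preds.shape == (1, 2, 128, 512)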