# batdetect2/tests/test_utils/test_arrays.py
# 2025-11-28 18:10:37 +00:00 · 60 lines · 1.6 KiB · Python

import torch
from batdetect2.core.arrays import adjust_width, extend_width
def test_extend_width():
    """Check the shape returned by ``extend_width`` for each axis of a 4-D tensor.

    Every call appends 100 entries along a single axis, so only that axis
    changes size; a negative axis counts from the end.
    """
    tensor = torch.rand([1, 1, 128, 100])

    # No ``axis`` given: for this 4-D input the last dimension is the one
    # that grows (same result as ``axis=3`` below).
    assert extend_width(tensor, 100).shape == (1, 1, 128, 200)

    expected_by_axis = {
        0: (101, 1, 128, 100),
        1: (1, 101, 128, 100),
        2: (1, 1, 228, 100),
        3: (1, 1, 128, 200),
        -2: (1, 1, 228, 100),  # same as axis=2
    }
    for axis, expected_shape in expected_by_axis.items():
        assert extend_width(tensor, 100, axis=axis).shape == expected_shape
def test_extends_with_value():
    """``extend_width`` fills the appended entries with ``value``.

    Also checks that the original entries are kept, unchanged, in front of
    the padding along the last axis.
    """
    array = torch.rand([1, 1, 128, 100])
    extended = extend_width(array, 100, value=-1)

    # The 100 appended entries along the last axis hold exactly ``value``.
    torch.testing.assert_close(
        extended[..., 100:],
        torch.full_like(array, -1),
        rtol=0,
        atol=0,
    )
    # The leading entries must still be the input data, bit for bit.
    torch.testing.assert_close(
        extended[..., :100],
        array,
        rtol=0,
        atol=0,
    )
def test_can_adjust_short_width():
    """``adjust_width`` grows an axis that is shorter than the target width.

    Only the selected axis changes size; all other dimensions are untouched.
    """
    tensor = torch.rand([1, 1, 128, 100])

    # Without ``axis``, the last dimension is adjusted for this 4-D input.
    assert adjust_width(tensor, 512).shape == (1, 1, 128, 512)

    cases = [
        (0, (512, 1, 128, 100)),
        (1, (1, 512, 128, 100)),
        (2, (1, 1, 512, 100)),
        (3, (1, 1, 128, 512)),
    ]
    for axis, expected_shape in cases:
        assert adjust_width(tensor, 512, axis=axis).shape == expected_shape
def test_can_adjust_long_width():
    """``adjust_width`` shrinks an axis that is longer than the target width.

    Checks the default axis, an explicit ``axis=3`` (the last axis of this
    4-D input), and the boundary where the target equals the current width.
    """
    array = torch.rand([1, 1, 128, 512])

    adjusted = adjust_width(array, 256)
    assert adjusted.shape == (1, 1, 128, 256)

    adjusted = adjust_width(array, 256, axis=3)
    assert adjusted.shape == (1, 1, 128, 256)

    # Boundary: a target equal to the current width leaves the shape as is.
    adjusted = adjust_width(array, 512)
    assert adjusted.shape == (1, 1, 128, 512)

    # NOTE(review): which end of the axis is cropped is not asserted here;
    # confirm the intended behavior before pinning it in a test.