Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 00:59:34 +01:00)
60 lines · 1.6 KiB · Python
import torch
|
|
|
|
from batdetect2.core.arrays import adjust_width, extend_width
|
|
|
|
|
|
def test_extend_width():
    """extend_width grows the selected axis by the requested amount.

    Without an ``axis`` argument the last dimension is extended; explicit
    positive indices and a negative index (``-2``) are also exercised.
    """
    original = torch.rand([1, 1, 128, 100])

    # Default axis: the last dimension grows from 100 to 200.
    assert extend_width(original, 100).shape == (1, 1, 128, 200)

    # Each explicit axis grows by exactly 100 while the others are untouched.
    expected_by_axis = {
        0: (101, 1, 128, 100),
        1: (1, 101, 128, 100),
        2: (1, 1, 228, 100),
        3: (1, 1, 128, 200),
        -2: (1, 1, 228, 100),
    }
    for axis, expected_shape in expected_by_axis.items():
        padded = extend_width(original, 100, axis=axis)
        assert padded.shape == expected_shape
|
|
|
|
|
|
def test_extends_with_value():
    """The region added by extend_width is filled with the given ``value``.

    The input is 100 wide on the last axis, so everything from index 100
    onward is padding and must equal -1 exactly (zero tolerance).
    """
    source = torch.rand([1, 1, 128, 100])
    padded = extend_width(source, 100, value=-1)

    # full_like keeps the input's dtype, matching ones_like(...) * -1.
    expected_padding = torch.full_like(source, -1)
    torch.testing.assert_close(
        padded[..., 100:],
        expected_padding,
        rtol=0,
        atol=0,
    )
|
|
|
|
|
|
def test_can_adjust_short_width():
    """adjust_width enlarges an axis that is shorter than the target size.

    Without an ``axis`` argument the last dimension is adjusted; each explicit
    axis 0-3 is then checked to reach the target length of 512.
    """
    narrow = torch.rand([1, 1, 128, 100])

    # Default axis: the last dimension is brought up to 512.
    assert adjust_width(narrow, 512).shape == (1, 1, 128, 512)

    expected_by_axis = {
        0: (512, 1, 128, 100),
        1: (1, 512, 128, 100),
        2: (1, 1, 512, 100),
        3: (1, 1, 128, 512),
    }
    for axis, expected_shape in expected_by_axis.items():
        resized = adjust_width(narrow, 512, axis=axis)
        assert resized.shape == expected_shape
|
|
|
|
|
|
def test_can_adjust_long_width():
    """adjust_width reduces the last axis when it exceeds the target size."""
    wide = torch.rand([1, 1, 128, 512])
    trimmed = adjust_width(wide, 256)
    assert trimmed.shape == (1, 1, 128, 256)
|