mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
35 lines
927 B
Python
35 lines
927 B
Python
import torch
from hypothesis import given, settings
from hypothesis import strategies as st

from batdetect2.models import ModelConfig, ModelType, build_architecture
@settings(deadline=None)
@given(
    input_width=st.integers(min_value=50, max_value=1500),
    input_height=st.integers(min_value=1, max_value=16),
    model_type=st.sampled_from(ModelType),
)
def test_model_can_process_spectrograms_of_any_width(
    input_width,
    input_height,
    model_type,
):
    """Check that each model type keeps the spatial size of its input.

    Hypothesis draws a width in [50, 1500], a height multiplier in [1, 16]
    and a member of ``ModelType``. A random tensor of shape
    ``(1, 1, 8 * multiplier, width)`` is passed through a freshly built
    model. The output must have shape
    ``(1, model.out_channels, 8 * multiplier, width)``.

    ``deadline=None`` disables Hypothesis' default per-example deadline.
    Building a model and running a forward pass on inputs up to 128 x 1500
    can exceed that deadline, which would make the test fail spuriously.
    """
    # Input height must be divisible by 8
    input_height = 8 * input_height

    # Named `spec` rather than `input` to avoid shadowing the builtin.
    spec = torch.rand([1, 1, input_height, input_width])

    model = build_architecture(
        config=ModelConfig(
            name=model_type,  # type: ignore
            input_height=input_height,
        ),
    )

    # Only the output shape is inspected, so skip autograd bookkeeping.
    with torch.no_grad():
        output = model(spec)

    # Compare the full shape at once: clearer failure message, and any
    # unexpected extra dimensions are caught too.
    assert tuple(output.shape) == (
        1,
        model.out_channels,
        input_height,
        input_width,
    )