diff --git a/batdetect2/detector/parameters.py b/batdetect2/detector/parameters.py index fcc7a8c..3e4f091 100644 --- a/batdetect2/detector/parameters.py +++ b/batdetect2/detector/parameters.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union from pydantic import BaseModel, Field, computed_field -from batdetect2.train.train_utils import ( +from batdetect2.train.legacy.train_utils import ( get_genus_mapping, get_short_class_names, ) diff --git a/tests/test_models/test_inputs.py b/tests/test_models/test_inputs.py index cac3548..af3e34d 100644 --- a/tests/test_models/test_inputs.py +++ b/tests/test_models/test_inputs.py @@ -2,7 +2,7 @@ import torch from hypothesis import given from hypothesis import strategies as st -from batdetect2.models import ModelConfig, ModelType, get_backbone +from batdetect2.models import ModelConfig, ModelType, build_architecture @given( @@ -20,7 +20,7 @@ def test_model_can_process_spectrograms_of_any_width( input = torch.rand([1, 1, input_height, input_width]) - model = get_backbone( + model = build_architecture( config=ModelConfig( name=model_type, # type: ignore input_height=input_height,