# Mirror of https://github.com/macaodha/batdetect2.git
# (synced 2026-04-04 15:20:19 +02:00) - 23 lines, 597 B, Python.
from torch import nn
|
|
from torch.optim import SGD, Adam
|
|
|
|
from batdetect2.train.optimizers import OptimizerImportConfig, build_optimizer
|
|
|
|
|
|
def test_build_optimizer_defaults_to_adam():
    """Calling ``build_optimizer`` without a ``config`` should produce an Adam optimizer."""
    parameters = nn.Linear(4, 2).parameters()

    assert isinstance(build_optimizer(parameters), Adam)
|
|
|
|
|
|
def test_build_optimizer_supports_import_config():
    """An ``OptimizerImportConfig`` should instantiate the named optimizer class.

    The config targets ``torch.optim.SGD`` with ``lr=1e-3``. Besides the
    resulting type, also check that the configured ``arguments`` were actually
    forwarded to the optimizer constructor rather than silently dropped.
    """
    model = nn.Linear(4, 2)
    config = OptimizerImportConfig(
        target="torch.optim.SGD",
        arguments={"lr": 1e-3},
    )

    optimizer = build_optimizer(model.parameters(), config=config)

    assert isinstance(optimizer, SGD)
    # torch.optim.Optimizer stores its constructor hyperparameters in
    # `defaults`; the configured learning rate must have reached SGD.
    assert optimizer.defaults["lr"] == 1e-3