|
|
@@ -18,7 +18,7 @@ class TestBasicModel(object): |
|
|
|
@pytest.mark.parametrize("cls", [LeNet5, SymbolNet]) |
|
|
|
@pytest.mark.parametrize("criterion", [nn.CrossEntropyLoss]) |
|
|
|
@pytest.mark.parametrize("optimizer", [torch.optim.RMSprop]) |
|
|
|
@pytest.mark.parametrize("device", [torch.device("cpu"), torch.device("cuda:0")]) |
|
|
|
@pytest.mark.parametrize("device", [torch.device("cpu")]) |
|
|
|
def test_models(self, num_classes, image_size, cls, criterion, optimizer, device): |
|
|
|
cls = cls(num_classes=num_classes, image_size=image_size) |
|
|
|
criterion = criterion() |
|
|
|