Browse Source

small

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
f7b3fe6a06
1 changed files with 1 additions and 2 deletions
  1. +1
    -2
      tests/core/controllers/test_trainer_jittor.py

+ 1
- 2
tests/core/controllers/test_trainer_jittor.py View File

@@ -15,8 +15,6 @@ if _NEED_IMPORT_JITTOR:
else: else:
from fastNLP.core.utils.dummy_class import DummyClass as Module from fastNLP.core.utils.dummy_class import DummyClass as Module
from fastNLP.core.utils.dummy_class import DummyClass as Dataset from fastNLP.core.utils.dummy_class import DummyClass as Dataset
jt.flags.use_cuda=1



class JittorNormalModel_Classification(Module): class JittorNormalModel_Classification(Module):
""" """
@@ -73,6 +71,7 @@ class TrainJittorConfig:
@pytest.mark.parametrize("driver", ["jittor"]) @pytest.mark.parametrize("driver", ["jittor"])
@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", None]) @pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", None])
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) @pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
@pytest.mark.jittor
def test_trainer_jittor( def test_trainer_jittor(
driver, driver,
device, device,


Loading…
Cancel
Save