diff --git a/tests/core/controllers/test_trainer_jittor.py b/tests/core/controllers/test_trainer_jittor.py index 94d85f22..0fc0ff3c 100644 --- a/tests/core/controllers/test_trainer_jittor.py +++ b/tests/core/controllers/test_trainer_jittor.py @@ -15,8 +15,6 @@ if _NEED_IMPORT_JITTOR: else: from fastNLP.core.utils.dummy_class import DummyClass as Module from fastNLP.core.utils.dummy_class import DummyClass as Dataset -jt.flags.use_cuda=1 - class JittorNormalModel_Classification(Module): """ @@ -73,6 +71,7 @@ class TrainJittorConfig: @pytest.mark.parametrize("driver", ["jittor"]) @pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", None]) @pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) +@pytest.mark.jittor def test_trainer_jittor( driver, device,