diff --git a/tests/core/controllers/test_trainer_jittor.py b/tests/core/controllers/test_trainer_jittor.py index d0eac8cd..30e5e668 100644 --- a/tests/core/controllers/test_trainer_jittor.py +++ b/tests/core/controllers/test_trainer_jittor.py @@ -11,6 +11,9 @@ if _NEED_IMPORT_JITTOR: import jittor as jt from jittor import nn, Module from jittor.dataset import Dataset +else: + from fastNLP.core.utils.dummy_class import DummyClass as Module + from fastNLP.core.utils.dummy_class import DummyClass as Dataset class JittorNormalModel_Classification(Module): @@ -68,6 +71,7 @@ class TrainJittorConfig: @pytest.mark.parametrize("driver,device", [("jittor", None)]) @pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) +@pytest.mark.jittor def test_trainer_jittor( driver, device,