|
@@ -15,8 +15,6 @@ if _NEED_IMPORT_JITTOR: |
|
|
else: |
|
|
else: |
|
|
from fastNLP.core.utils.dummy_class import DummyClass as Module |
|
|
from fastNLP.core.utils.dummy_class import DummyClass as Module |
|
|
from fastNLP.core.utils.dummy_class import DummyClass as Dataset |
|
|
from fastNLP.core.utils.dummy_class import DummyClass as Dataset |
|
|
jt.flags.use_cuda=1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class JittorNormalModel_Classification(Module): |
|
|
class JittorNormalModel_Classification(Module): |
|
|
""" |
|
|
""" |
|
@@ -73,6 +71,7 @@ class TrainJittorConfig: |
|
|
@pytest.mark.parametrize("driver", ["jittor"]) |
|
|
@pytest.mark.parametrize("driver", ["jittor"]) |
|
|
@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", None]) |
|
|
@pytest.mark.parametrize("device", ["cpu", "gpu", "cuda", None]) |
|
|
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) |
|
|
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) |
|
|
|
|
|
@pytest.mark.jittor |
|
|
def test_trainer_jittor( |
|
|
def test_trainer_jittor( |
|
|
driver, |
|
|
driver, |
|
|
device, |
|
|
device, |
|
|