Browse Source

为 test_trainer_jittor 添加 DummyClass 和 pytest.mark.jittor

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
2a44af2519
1 changed files with 4 additions and 0 deletions
  1. +4
    -0
      tests/core/controllers/test_trainer_jittor.py

+ 4
- 0
tests/core/controllers/test_trainer_jittor.py View File

@@ -11,6 +11,9 @@ if _NEED_IMPORT_JITTOR:
import jittor as jt import jittor as jt
from jittor import nn, Module from jittor import nn, Module
from jittor.dataset import Dataset from jittor.dataset import Dataset
else:
from fastNLP.core.utils.dummy_class import DummyClass as Module
from fastNLP.core.utils.dummy_class import DummyClass as Dataset




class JittorNormalModel_Classification(Module): class JittorNormalModel_Classification(Module):
@@ -68,6 +71,7 @@ class TrainJittorConfig:


@pytest.mark.parametrize("driver,device", [("jittor", None)]) @pytest.mark.parametrize("driver,device", [("jittor", None)])
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) @pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
@pytest.mark.jittor
def test_trainer_jittor( def test_trainer_jittor(
driver, driver,
device, device,


Loading…
Cancel
Save