* 修复 tests/core/controllers/_test_trainer_jittor.py,使其可以正常运行 Trainer 并不接收 validate_dataloaders 参数,改为 evaluate_dataloaders 即可。 * jittor single driver 支持 cpu 和 gpu 的切换tags/v1.0.0alpha
@@ -8,7 +8,7 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
import jittor | |||||
import jittor as jt | |||||
__all__ = [ | __all__ = [ | ||||
"JittorSingleDriver", | "JittorSingleDriver", | ||||
@@ -105,6 +105,9 @@ class JittorSingleDriver(JittorDriver): | |||||
def setup(self): | def setup(self): | ||||
""" | """ | ||||
使用单个 GPU 时,jittor 底层自动实现调配,无需额外操作 | |||||
支持 cpu 和 gpu 的切换 | |||||
""" | """ | ||||
pass | |||||
if self.model_device in ["cpu", None]: | |||||
jt.flags.use_cuda = 0 # 使用 cpu | |||||
else: | |||||
jt.flags.use_cuda = 1 # 使用 cuda |
@@ -225,7 +225,7 @@ if __name__ == "__main__": | |||||
device=[0,1,2,3,4], | device=[0,1,2,3,4], | ||||
optimizers=optimizer, | optimizers=optimizer, | ||||
train_dataloader=train_dataloader, | train_dataloader=train_dataloader, | ||||
validate_dataloaders=val_dataloader, | |||||
evaluate_dataloaders=val_dataloader, | |||||
validate_every=-1, | validate_every=-1, | ||||
input_mapping=None, | input_mapping=None, | ||||
output_mapping=None, | output_mapping=None, | ||||
@@ -69,7 +69,8 @@ class TrainJittorConfig: | |||||
shuffle: bool = True | shuffle: bool = True | ||||
@pytest.mark.parametrize("driver,device", [("jittor", None)]) | |||||
@pytest.mark.parametrize("driver", ["jittor"]) | |||||
@pytest.mark.parametrize("device", ["cpu", 1]) | |||||
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) | @pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) | ||||
@pytest.mark.jittor | @pytest.mark.jittor | ||||
def test_trainer_jittor( | def test_trainer_jittor( | ||||
@@ -134,4 +135,5 @@ def test_trainer_jittor( | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
# test_trainer_jittor("jittor", None, [RichCallback(100)]) | # test_trainer_jittor("jittor", None, [RichCallback(100)]) | ||||
# test_trainer_jittor("jittor", 1, [RichCallback(100)]) | |||||
pytest.main(['test_trainer_jittor.py']) # 只运行此模块 | pytest.main(['test_trainer_jittor.py']) # 只运行此模块 |