
[bugfix] Support the Jittor single driver, and add test cases (#413)

* Support the Jittor single driver

* Add test cases for the Jittor single driver
Letian Li · 3 years ago
commit a294955e32
3 changed files with 145 additions and 5 deletions:
  1. fastNLP/core/drivers/jittor_driver/jittor_driver.py (+4 / -3)
  2. fastNLP/core/drivers/jittor_driver/single_device.py (+8 / -2)
  3. tests/core/controllers/test_trainer_jittor.py (+133 / -0)

fastNLP/core/drivers/jittor_driver/jittor_driver.py (+4 / -3)

@@ -33,11 +33,12 @@ class JittorDriver(Driver):
                             f"`jittor.Module` type.")
        super(JittorDriver, self).__init__(model)

        self.model = model

        self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
        self.grad_scaler = _grad_scaler()

        # controls whether to skip the parameter matching done in `auto_param_call`;
        self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)

    @staticmethod
    def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
        # JittorDataLoader is implemented in fastNLP
@@ -152,4 +153,4 @@ class JittorDriver(Driver):
    # def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx):
    #     # ensures correctness when shuffle=True under ddp training, since the sampler's shuffle random seed must be identical on every process;
    #     if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
-    #         dataloader.batch_sampler.set_epoch(cur_epoch_idx)
+    #         dataloader.batch_sampler.set_epoch(cur_epoch_idx)
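
For context on the `_build_fp16_env(dummy=not fp16)` call above: when fp16 is disabled, such a helper typically returns no-op stand-ins so the training loop can use the same autocast/scaler code path either way. A minimal sketch of that pattern; `_DummyGradScaler` and the function body here are illustrative, not fastNLP's actual implementation:

    from contextlib import nullcontext

    class _DummyGradScaler:
        """No-op stand-in mirroring a GradScaler's scale/step/update API."""

        def scale(self, loss):
            return loss

        def step(self, optimizer):
            optimizer.step()

        def update(self):
            pass

    def _build_fp16_env(dummy: bool = False):
        # Matches the call site above:
        #   auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
        #   grad_scaler = _grad_scaler()
        if dummy:
            return nullcontext, _DummyGradScaler
        raise NotImplementedError("a real fp16 autocast/scaler pair would go here")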

fastNLP/core/drivers/jittor_driver/single_device.py (+8 / -2)

@@ -60,8 +60,8 @@ class JittorSingleDriver(JittorDriver):
            logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
            return fn, None
        elif fn in {"train_step", "evaluate_step"}:
-            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
-            return self.model, self.model.forward
+            logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
+            return self.model, self.model.execute
        else:
            raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
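
The fix above reflects that Jittor modules define `execute` rather than PyTorch's `forward`, and `Module.__call__` dispatches to it, so `self.model.forward` does not exist on a plain Jittor module. A self-contained illustration (the `TinyModel` here is hypothetical, not part of this patch):

    import jittor as jt
    from jittor import nn, Module

    class TinyModel(Module):
        def __init__(self):
            super(TinyModel, self).__init__()
            self.linear = nn.Linear(4, 2)

        def execute(self, x):
            # Jittor's counterpart of PyTorch's forward; Module.__call__ dispatches here
            return self.linear(x)

    model = TinyModel()
    out = model(jt.randn((3, 4)))  # invokes execute, not forward
    print(out.shape)  # [3, 2]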

@@ -98,3 +98,9 @@ class JittorSingleDriver(JittorDriver):
            return dataloader
        else:
            return dataloader
+
+    def setup(self):
+        """
+        On a single GPU, Jittor handles device placement automatically, so no extra setup is needed.
+        """
+        pass
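
For reference, single-GPU device selection in Jittor is a global flag rather than per-driver state, which is why `setup` can be a no-op here. A sketch of the usual pattern, assuming the standard `jt.flags.use_cuda` flag:

    import jittor as jt

    # Enable CUDA globally when available; Jittor then places computation on the
    # GPU automatically, with no per-tensor .to(device) calls as in PyTorch.
    if jt.has_cuda:
        jt.flags.use_cuda = 1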

tests/core/controllers/test_trainer_jittor.py (+133 / -0)

@@ -0,0 +1,133 @@
import pytest

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.controllers.evaluator import Evaluator
from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.callbacks.progress_callback import RichCallback
from fastNLP.core.dataloaders.jittor_dataloader.fdl import JittorDataLoader
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

if _NEED_IMPORT_JITTOR:
    import jittor as jt
    from jittor import nn, Module
    from jittor.dataset import Dataset


class JittorNormalModel_Classification(Module):
    """
    A basic Jittor classification model
    """

    def __init__(self, num_labels, feature_dimension):
        super(JittorNormalModel_Classification, self).__init__()
        self.num_labels = num_labels

        self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
        self.ac1 = nn.ReLU()
        self.linear2 = nn.Linear(in_features=64, out_features=32)
        self.ac2 = nn.ReLU()
        self.output = nn.Linear(in_features=32, out_features=num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def execute(self, x):
        # the Jittor counterpart of PyTorch's forward method
        x = self.ac1(self.linear1(x))
        x = self.ac2(self.linear2(x))
        x = self.output(x)
        return x

    def train_step(self, x, y):
        x = self(x)
        return {"loss": self.loss_fn(x, y)}

    def evaluate_step(self, x, y):
        x = self(x)
        return {"pred": x, "target": y.reshape((-1,))}


class JittorRandomMaxDataset(Dataset):
    def __init__(self, num_samples, num_features):
        super(JittorRandomMaxDataset, self).__init__()
        self.x = jt.randn((num_samples, num_features))
        # each label is the index of the largest entry in the corresponding row
        self.y = self.x.argmax(dim=1)[0]

    def __len__(self):
        return len(self.y)

    def __getitem__(self, item):
        return {"x": self.x[item], "y": self.y[item]}


class TrainJittorConfig:
    num_labels: int = 5
    feature_dimension: int = 5
    lr: float = 1e-1
    batch_size: int = 4
    shuffle: bool = True

@pytest.mark.parametrize("driver,device", [("jittor", None)])
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
def test_trainer_jittor(
        driver,
        device,
        callbacks,
        n_epochs=3,
):
    model = JittorNormalModel_Classification(
        num_labels=TrainJittorConfig.num_labels,
        feature_dimension=TrainJittorConfig.feature_dimension
    )
    optimizer = nn.SGD(model.parameters(), lr=TrainJittorConfig.lr)
    train_dataloader = JittorDataLoader(
        dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
        batch_size=TrainJittorConfig.batch_size,
        shuffle=True,
        # num_workers=4,
    )
    val_dataloader = JittorDataLoader(
        dataset=JittorRandomMaxDataset(500, TrainJittorConfig.feature_dimension),
        batch_size=TrainJittorConfig.batch_size,
        shuffle=True,
        # num_workers=4,
    )
    test_dataloader = JittorDataLoader(
        dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
        batch_size=TrainJittorConfig.batch_size,
        shuffle=True,
        # num_workers=4,
    )
    metrics = {"acc": Accuracy()}

    trainer = Trainer(
        model=model,
        driver=driver,
        device=device,
        optimizers=optimizer,
        train_dataloader=train_dataloader,
        evaluate_dataloaders=val_dataloader,
        validate_every=-1,
        evaluate_fn="evaluate_step",
        input_mapping=None,
        output_mapping=None,
        metrics=metrics,
        n_epochs=n_epochs,
        callbacks=callbacks,
        # progress_bar="rich"
    )
    trainer.run()

    evaluator = Evaluator(
        model=model,
        driver=driver,
        dataloaders=test_dataloader,
        evaluate_fn="evaluate_step",
        metrics=metrics,
    )
    metric_results = evaluator.run()
    assert metric_results["acc#acc"] > 0.80


if __name__ == "__main__":
    # test_trainer_jittor("jittor", None, [RichCallback(100)])
    pytest.main(['test_trainer_jittor.py'])  # run only this test module
