diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
index bcebc6d0..b751354d 100644
--- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py
+++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py
@@ -33,11 +33,12 @@ class JittorDriver(Driver):
                             f"`jittor.Module` type.")
         super(JittorDriver, self).__init__(model)
-        self.model = model
-
         self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
         self.grad_scaler = _grad_scaler()
+        # controls whether to disable the automatic parameter matching of auto_param_call;
+        self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
+
     @staticmethod
     def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         # JittorDataLoader is implemented inside fastNLP
@@ -152,4 +153,4 @@ class JittorDriver(Driver):
     # def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx):
     #     # ensures correctness when shuffle=True under ddp training: the sampler's shuffle must use the same random seed on every process;
     #     if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
-    #         dataloader.batch_sampler.set_epoch(cur_epoch_idx)
\ No newline at end of file
+    #         dataloader.batch_sampler.set_epoch(cur_epoch_idx)
diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py
index 695e6ec9..ab1e8595 100644
--- a/fastNLP/core/drivers/jittor_driver/single_device.py
+++ b/fastNLP/core/drivers/jittor_driver/single_device.py
@@ -60,8 +60,8 @@ class JittorSingleDriver(JittorDriver):
             logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
             return fn, None
         elif fn in {"train_step", "evaluate_step"}:
-            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
-            return self.model, self.model.forward
+            logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
+            return self.model, self.model.execute
         else:
             raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
@@ -98,3 +98,9 @@ class JittorSingleDriver(JittorDriver):
             return dataloader
         else:
             return dataloader
+
+    def setup(self):
+        """
+        On a single GPU, jittor handles device placement automatically, so no extra setup is required.
+        """
+        pass
diff --git a/tests/core/controllers/test_trainer_jittor.py b/tests/core/controllers/test_trainer_jittor.py
new file mode 100644
index 00000000..d0eac8cd
--- /dev/null
+++ b/tests/core/controllers/test_trainer_jittor.py
@@ -0,0 +1,133 @@
+import pytest
+
+from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.core.controllers.trainer import Evaluator
+from fastNLP.core.metrics.accuracy import Accuracy
+from fastNLP.core.callbacks.progress_callback import RichCallback
+from fastNLP.core.dataloaders.jittor_dataloader.fdl import JittorDataLoader
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+
+if _NEED_IMPORT_JITTOR:
+    import jittor as jt
+    from jittor import nn, Module
+    from jittor.dataset import Dataset
+
+
+class JittorNormalModel_Classification(Module):
+    """
+    A basic Jittor classification model
+    """
+
+    def __init__(self, num_labels, feature_dimension):
+        super(JittorNormalModel_Classification, self).__init__()
+        self.num_labels = num_labels
+
+        self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
+        self.ac1 = nn.ReLU()
+        self.linear2 = nn.Linear(in_features=64, out_features=32)
+        self.ac2 = nn.ReLU()
+        self.output = nn.Linear(in_features=32, out_features=num_labels)
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def execute(self, x):
+        # the counterpart of PyTorch's forward method
+        x = self.ac1(self.linear1(x))
+        x = self.ac2(self.linear2(x))
+        x = self.output(x)
+        return x
+
+    def train_step(self, x, y):
+        x = self(x)
+        return {"loss": self.loss_fn(x, y)}
+
+    def evaluate_step(self, x, y):
+        x = self(x)
+        return {"pred": x, "target": y.reshape((-1,))}
+
+
+class JittorRandomMaxDataset(Dataset):
+    def __init__(self, num_samples, num_features):
+        super(JittorRandomMaxDataset, self).__init__()
+        self.x = jt.randn((num_samples, num_features))
+        # jittor's argmax returns an (indices, values) pair; keep the indices
+        self.y = self.x.argmax(dim=1)[0]
+
+    def __len__(self):
+        return len(self.y)
+
+    def __getitem__(self, item):
+        return {"x": self.x[item], "y": self.y[item]}
+
+
+class TrainJittorConfig:
+    num_labels: int = 5
+    feature_dimension: int = 5
+    lr = 1e-1
+    batch_size: int = 4
+    shuffle: bool = True
+
+
+@pytest.mark.parametrize("driver,device", [("jittor", None)])
+@pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
+def test_trainer_jittor(
+        driver,
+        device,
+        callbacks,
+        n_epochs=3,
+):
+    model = JittorNormalModel_Classification(
+        num_labels=TrainJittorConfig.num_labels,
+        feature_dimension=TrainJittorConfig.feature_dimension
+    )
+    optimizer = nn.SGD(model.parameters(), lr=TrainJittorConfig.lr)
+    train_dataloader = JittorDataLoader(
+        dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
+        batch_size=TrainJittorConfig.batch_size,
+        shuffle=True,
+        # num_workers=4,
+    )
+    val_dataloader = JittorDataLoader(
+        dataset=JittorRandomMaxDataset(500, TrainJittorConfig.feature_dimension),
+        batch_size=TrainJittorConfig.batch_size,
+        shuffle=True,
+        # num_workers=4,
+    )
+    test_dataloader = JittorDataLoader(
+        dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
+        batch_size=TrainJittorConfig.batch_size,
+        shuffle=True,
+        # num_workers=4,
+    )
+    metrics = {"acc": Accuracy()}
+
+    trainer = Trainer(
+        model=model,
+        driver=driver,
+        device=device,
+        optimizers=optimizer,
+        train_dataloader=train_dataloader,
+        evaluate_dataloaders=val_dataloader,
+        validate_every=-1,
+        evaluate_fn="evaluate_step",
+        input_mapping=None,
+        output_mapping=None,
+        metrics=metrics,
+        n_epochs=n_epochs,
+        callbacks=callbacks,
+        # progress_bar="rich"
+    )
+    trainer.run()
+
+    evaluator = Evaluator(
+        model=model,
+        driver=driver,
+        dataloaders=test_dataloader,
+        evaluate_fn="evaluate_step",
+        metrics=metrics,
+    )
+    metric_results = evaluator.run()
+    assert metric_results["acc#acc"] > 0.80
+
+
+if __name__ == "__main__":
+    # test_trainer_jittor("jittor", None, [RichCallback(100)])
+    pytest.main(['test_trainer_jittor.py'])  # run only this module