|
|
@@ -68,13 +68,13 @@ class Trainer(TrainerEventTrigger): |
|
|
|
|
|
|
|
:param model: 训练所需要的模型,目前支持 pytorch; |
|
|
|
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle 等 |
|
|
|
国产框架的训练模式;其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练 |
|
|
|
国产框架的训练模式;其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练 |
|
|
|
:param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict; |
|
|
|
:param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; |
|
|
|
:param device: 该参数用来指定具体训练时使用的机器;注意当该参数为 None 时,fastNLP 不会将模型和数据进行设备之间的移动处理,但是你 |
|
|
|
可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也 |
|
|
|
可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前 |
|
|
|
自己构造 DDP 的多进程场景); |
|
|
|
可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也 |
|
|
|
可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前 |
|
|
|
自己构造 DDP 的多进程场景); |
|
|
|
device 的可选输入如下所示: |
|
|
|
1. 可选输入:str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, 可见的第二个GPU中; |
|
|
|
2. torch.device:将模型装载到torch.device上; |
|
|
|