|
|
@@ -87,7 +87,7 @@ class Trainer(TrainerEventTrigger): |
|
|
|
|
|
|
|
.. node:: |
|
|
|
|
|
|
|
如果希望使用 ``TorchDDPDriver`` |
|
|
|
如果希望使用 ``TorchDDPDriver`` |
|
|
|
|
|
|
|
|
|
|
|
:param n_epochs: 训练总共的 epoch 的数量,默认为 20; |
|
|
@@ -134,7 +134,7 @@ class Trainer(TrainerEventTrigger): |
|
|
|
:kwargs: |
|
|
|
* *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数: |
|
|
|
* ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 |
|
|
|
{'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等; |
|
|
|
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; |
|
|
|
* set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; |
|
|
|
* torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; |
|
|
|
* *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上; |
|
|
|