Browse Source

little change

tags/v1.0.0alpha
YWMditto 2 years ago
parent
commit
02738b84bf
2 changed files with 3 additions and 3 deletions
  1. +1
    -1
      fastNLP/core/controllers/evaluator.py
  2. +2
    -2
      fastNLP/core/controllers/trainer.py

+ 1
- 1
fastNLP/core/controllers/evaluator.py View File

@@ -54,7 +54,7 @@ class Evaluator:
:kwargs: :kwargs:
* *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数: * *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数:
* ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 * ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等;
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等;
* torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; * torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
* *model_use_eval_mode* (``bool``) -- * *model_use_eval_mode* (``bool``) --
是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的


+ 2
- 2
fastNLP/core/controllers/trainer.py View File

@@ -87,7 +87,7 @@ class Trainer(TrainerEventTrigger):


.. node:: .. node::


如果希望使用 ``TorchDDPDriver``
如果希望使用 ``TorchDDPDriver``




:param n_epochs: 训练总共的 epoch 的数量,默认为 20; :param n_epochs: 训练总共的 epoch 的数量,默认为 20;
@@ -134,7 +134,7 @@ class Trainer(TrainerEventTrigger):
:kwargs: :kwargs:
* *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数: * *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数:
* ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 * ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等;
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等;
* set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; * set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None;
* torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; * torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
* *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上; * *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上;


Loading…
Cancel
Save