|
|
@@ -282,32 +282,42 @@ class Trainer(TrainerEventTrigger): |
|
|
|
|
|
|
|
:kwargs: |
|
|
|
* *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数: |
|
|
|
|
|
|
|
* ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 |
|
|
|
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; |
|
|
|
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; |
|
|
|
* set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; |
|
|
|
* torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; |
|
|
|
* *paddle_kwargs* -- 用于在指定 ``driver`` 为 'paddle' 时设定具体 driver 实例的一些参数: |
|
|
|
|
|
|
|
* fleet_kwargs -- 用于在使用 ``PaddleFleetDriver`` 时指定 ``DataParallel`` 和 ``fleet`` 初始化时的参数,包括: |
|
|
|
|
|
|
|
* is_collective -- 是否使用 paddle 集群式的分布式训练方法,目前仅支持为 True 的情况; |
|
|
|
* role_maker -- 初始化 ``fleet`` 分布式训练 API 时使用的 ``RoleMaker`` |
|
|
|
* 其它用于初始化 ``DataParallel`` 的参数; |
|
|
|
* *data_device* -- 一个具体的 driver 实例中,有 ``model_device`` 和 ``data_device``,前者表示模型所在的设备,后者表示 |
|
|
|
当 ``model_device`` 为 None 时应当将数据迁移到哪个设备; |
|
|
|
当 ``model_device`` 为 None 时应当将数据迁移到哪个设备; |
|
|
|
|
|
|
|
.. note:: |
|
|
|
.. note:: |
|
|
|
|
|
|
|
注意您在绝大部分情况下不会用到该参数! |
|
|
|
|
|
|
|
1. 当 driver 实例的 ``model_device`` 不为 None 时,该参数无效; |
|
|
|
2. 对于 pytorch,仅当用户自己通过 ``python -m torch.distributed.launch`` 并且自己初始化 ``init_process_group`` 时, |
|
|
|
driver 实例的 ``model_device`` 才会为 None; |
|
|
|
3. 对于 paddle,仅当用户自己通过 ``python -m paddle.distributed.launch`` 并且自己初始化 :func:`~init_parallel_env` 或 |
|
|
|
:meth:`fleet.init` 时,driver 实例的 ``model_device`` 才会为 None; |
|
|
|
|
|
|
|
* *use_dist_sampler* -- 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch |
|
|
|
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 |
|
|
|
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 |
|
|
|
* *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``;默认为 ``True``; |
|
|
|
* *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: |
|
|
|
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 |
|
|
|
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; |
|
|
|
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 |
|
|
|
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; |
|
|
|
|
|
|
|
注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``; |
|
|
|
注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``; |
|
|
|
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, |
|
|
|
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果 |
|
|
|
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 |
|
|
|
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果 |
|
|
|
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 |
|
|
|
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。 |
|
|
|
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。 |
|
|
|
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。 |
|
|
|