Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
5d8d577399
9 changed files with 303 additions and 147 deletions
  1. +2
    -0
      fastNLP/core/controllers/evaluator.py
  2. +276
    -64
      fastNLP/core/controllers/trainer.py
  3. +5
    -3
      fastNLP/core/drivers/paddle_driver/fleet.py
  4. +7
    -22
      fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py
  5. +3
    -1
      fastNLP/core/drivers/paddle_driver/single_device.py
  6. +1
    -1
      tests/core/controllers/_test_trainer_fleet.py
  7. +1
    -1
      tests/core/controllers/_test_trainer_fleet_outside.py
  8. +2
    -3
      tests/core/controllers/test_trainer_paddle.py
  9. +6
    -52
      tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py

+ 2
- 0
fastNLP/core/controllers/evaluator.py View File

@@ -56,6 +56,8 @@ class Evaluator:
* ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 * ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; {'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等;
* torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; * torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
* *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上;
注意如果 model_device 为 None,那么 data_device 不会起作用;
* *model_use_eval_mode* (``bool``) -- * *model_use_eval_mode* (``bool``) --
是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的
dropout 与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论 dropout 与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论


+ 276
- 64
fastNLP/core/controllers/trainer.py View File

@@ -1,4 +1,10 @@
from typing import Union, Optional, List, Callable, Dict, Sequence, BinaryIO, IO
"""
``Trainer`` 是 fastNLP 用于训练模型的专门的训练器,其支持多种不同的驱动模式 ``Driver``,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产
的训练框架;新版的 fastNLP 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需要自己实现
模型部分,而将训练层面的逻辑完全地交给 fastNLP;
"""

from typing import Union, Optional, List, Callable, Dict, BinaryIO
from functools import partial from functools import partial
from collections import defaultdict from collections import defaultdict
import copy import copy
@@ -7,7 +13,6 @@ from dataclasses import is_dataclass
import os import os
from pathlib import Path from pathlib import Path
import io import io
import inspect


__all__ = [ __all__ = [
'Trainer', 'Trainer',
@@ -62,12 +67,20 @@ class Trainer(TrainerEventTrigger):
**kwargs **kwargs
): ):
r""" r"""
`Trainer` 是 fastNLP 用于训练模型的专门的训练器,其支持多种不同的驱动模式,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产
的训练框架;新版的 fastNLP 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需
要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP;
:param model: 训练所需要的模型,例如 ``torch.nn.Module``;

.. note::

当使用 pytorch 时,注意参数 ``model`` 在大多数情况下为 ``nn.Module``。但是您仍能够通过使用一些特定的组合来使用情况,如下所示:


:param model: 训练所需要的模型,目前支持 pytorch;
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch",],之后我们会加入 jittor、paddle 等
1. 当希望使用 ``DataParallel`` 时,您应当使用 ``TorchSingleDriver``,意味着您在初始化 ``Trainer`` 时参数 ``device`` 不应当为
一个 ``List``;

2. 当您选择自己初始化 ``init_process_group`` 时(这种情况要求您传入的 ``model`` 参数一定为 ``DistributedDataParallel``),
您应当使用 ``TorchDDPDriver``,意味着您需要通过 ``python -m torch.distributed.launch`` 的方式来启动训练,此时参数 ``device``
应当设置为 None(此时我们会忽略该参数),具体见下面对于参数 ``device`` 的更详细的解释。

:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch"],之后我们会加入 jittor、paddle 等
国产框架的训练模式;其中 "torch" 表示使用 ``TorchSingleDriver`` 或者 ``TorchDDPDriver``,具体使用哪一种取决于参数 ``device`` 国产框架的训练模式;其中 "torch" 表示使用 ``TorchSingleDriver`` 或者 ``TorchDDPDriver``,具体使用哪一种取决于参数 ``device``
的设置; 的设置;
:param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict; :param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict;
@@ -80,79 +93,248 @@ class Trainer(TrainerEventTrigger):
device 的可选输入如下所示: device 的可选输入如下所示:


* *str*: 例如 'cpu', 'cuda', 'cuda:0', 'cuda:1' 等; * *str*: 例如 'cpu', 'cuda', 'cuda:0', 'cuda:1' 等;
* *torch.device*: 将模型装载到 ``torch.device`` 上
* *torch.device*: 例如 'torch.device("cuda:0")'
* *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver`; * *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver`;
* *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用 ``TorchDDPDriver``,不管您传入的列表的长度是 1 还是其它值; * *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用 ``TorchDDPDriver``,不管您传入的列表的长度是 1 还是其它值;
* *None*: 为None则不对模型进行任何处理;
* *None*: 仅当用户自己通过训练框架提供的并行训练启动脚本开启 ddp 进程时为 None;

.. note::

如果希望使用 ``TorchDDPDriver``,在初始化 ``Trainer`` 时您应当使用::

Trainer(driver="torch", device=[0, 1])

注意如果这时 ``device=[0]``,我们仍旧会使用 ``TorchDDPDriver``。

如果希望使用 ``TorchSingleDriver``,则在初始化 ``Trainer`` 时您应当使用::

Trainer(driver="torch", device=0)


.. node::
.. warning::


如果希望使用 ``TorchDDPDriver``
注意参数 ``device`` 仅当您通过 pytorch 或者其它训练框架自身的并行训练启动脚本启动 ddp 训练时才允许为 ``None``!


例如,当您使用::

python -m torch.distributed.launch --nproc_per_node 2 train.py

来使用 ``TorchDDPDriver`` 时,此时参数 ``device`` 不再有效(不管您是否自己初始化 ``init_process_group``),我们将直接
通过 ``torch.device(f"cuda:{local_rank}")`` 来获取当前进程所使用的的具体的 gpu 设备。因此此时您需要使用 ``os.environ["CUDA_VISIBLE_DEVICES"]``
来指定要使用的具体的 gpu 设备。

另一点需要注意的是,当您没有选择自己初始化 ``init_process_group`` 时,我们仍旧会帮助您把模型和数据迁移到当前进程所使用的
具体的 gpu 设备上。但是如果您选择自己在 ``Trainer`` 初始化前(意味着在 ``driver`` 的 ``setup`` 前)初始化 ``init_process_group``,
那么对于模型的迁移应当完全由您自己来完成。此时对于数据的迁移,如果您在 ``Trainer`` 初始化时指定了参数 ``data_device``,那么
我们会将数据迁移到 ``data_device`` 上;如果其为 None,那么将数据迁移到正确的设备上应当由您自己来完成。

对于使用 ``TorchDDPDriver`` 的更多细节,请见 :class:`fastNLP.core.drivers.torch_driver.TorchDDPDriver`。


:param n_epochs: 训练总共的 epoch 的数量,默认为 20; :param n_epochs: 训练总共的 epoch 的数量,默认为 20;
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认
为 None;
:param batch_step_fn: 定制每次 train batch 执行的函数。该函数应接受两个参数为 `trainer` 和`batch`,不需要要返回值;可以
参考 fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop中的batch_step_fn函数。
:param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`,
不需要有返回值;可以参考 fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop中的batch_step_fn函数。
:param train_fn: 用来控制 `Trainer` 在训练的前向传播过程中是调用模型的哪一个函数,例如是 `train_step` 还是 `forward`;
默认为 None,如果该值是 None,那么我们会默认使用 `train_step` 当做前向传播的函数,如果在模型中没有找到该方法,
则使用模型默认的前向传播函数。
:param evaluate_fn: 用来控制 `Trainer` 中内置的 `Evaluator` 的模式,应当为 None 或者一个字符串;其使用方式和 train_fn 类似;
注意该参数我们会直接传给 Trainer 中内置的 Evaluator(如果不为 None);如果该值为 None ,将首先尝试寻找模型中是否有
evaluate_step 这个函数,如果没有则使用 forward 函数。
:param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类;
:param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()};
:param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch evaluate 一次;为正数则表示每隔几个 batch evaluate 一次;
为函数时表示用户自己传入的用于控制 Trainer 中的 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并
返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。
:param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是
一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的
value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它
类型,那么我们将会直接报错;如果 input_mapping 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里;
注意该参数会被传进 `Evaluator` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 `device` 为 None 时);
如果 train 和 evaluate 需要使用不同的 input_mapping, 请使用 train_input_mapping 与 evaluate_input_mapping 设置。
:param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个
函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型,
如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value;
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;
如果 train 和 evaluate 需要使用不同的 output_mapping, 请使用 train_output_mapping 与 evaluate_output_mapping 设置。
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为;
如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值
为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `evaluate_step` 和 `test_step`;
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1;
:param fp16: 是否开启混合精度训练;默认为 False;
:param monitor: 当存在 evaluate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有
在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配
的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。
如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。
:param larger_better: monitor 的值是否是越大越好。
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None;
:param kwargs: 一些其它的可能需要的参数,见下方的说明
为 None;
:param batch_step_fn: 定制每次训练时前向运行一个 batch 的数据所执行的函数。该函数应接受两个参数为 ``trainer`` 和 ``batch``,
不需要要返回值;更详细的使用位置和说明请见 :meth:`fastNLP.core.controllers.TrainBatchLoop.batch_step_fn`;
:param evaluate_batch_step_fn: 定制每次验证时前向运行一个 batch 的数据所执行的函数。该函数应接受的两个参数为 ``evaluator`` 和 ``batch``,
不需要有返回值;可以参考 :meth:`fastNLP.core.controllers.EvaluateBatchLoop.batch_step_fn`;
:param train_fn: 用来控制 ``Trainer`` 在训练的前向传播过程中是调用模型的哪一个函数,例如是 ``train_step`` 还是 ``forward``;
默认为 ``None``,如果该值是 ``None``,那么我们会默认使用 ``train_step`` 当做前向传播的函数,如果在模型的定义类中没有找到该方法,
则使用模型默认的前向传播函数,例如对于 pytorch 来说就是 ``forward``。

.. note::
在 fastNLP 中,对于训练时使用的前向传播函数的查找逻辑如下所示:

1. 如果 ``train_fn`` 为 None,那么在 model 的类 Model 中寻找方法 ``Model.train_step``;如果没有找到,那么默认使用 ``Model.forward``;
2. 如果 ``train_fn`` 为一个字符串,例如 'my_step_fn',那么我们首先会在 model 的类 Model 中寻找方法 ``Model.my_step_fn``,
如果没有找到,那么会直接报错;

:param evaluate_fn: 用来控制 ``Trainer`` 中内置的 ``Evaluator`` 在验证的前向传播过程中是调用模型的哪一个函数,应当为 ``None``
或者一个字符串;其使用方式和 train_fn 类似;具体可见 :class:`fastNLP.core.controllers.Evaluator`;
:param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 ``Callback`` 类;具体可见
:class:`fastNLP.core.callbacks.Callback`;
:param metrics: 用于传给 ``Trainer`` 内部的 ``Evaluator`` 实例来进行训练过程中的验证。其应当为一个字典,其中 key 表示 monitor,
例如 {"acc1": AccMetric(), "acc2": AccMetric()};

目前我们支持的 ``metric`` 的种类有以下几种:

1. fastNLP 自己的 ``metric``:详见 :class:`fastNLP.core.metrics.Metric`;
2. torchmetrics;
3. allennlp.training.metrics;
4. paddle.metric;

:param evaluate_every: 用来控制 ``Trainer`` 内部的 ``Evaluator`` 验证的频率,其可以为负数、正数或者函数:

1. 为负数时表示每隔几个 ``epoch`` evaluate 一次;
2. 为正数则表示每隔几个 ``batch`` evaluate 一次;
3. 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并
返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate;

.. note::

如果参数 ``evaluate_every`` 为函数,其应当类似:

>>> def my_evaluate_every(trainer) -> bool:
... if (trainer.global_forward_batches+1) % 1000 == 0:
... return True
... else:
... return False

该函数表示当每经过 1000 个 batch,``Trainer`` 中内置的 ``Evaluator`` 就会验证一次;

另一个需要注意的事情在于该函数会在每一次 batch 的结尾进行调用,当该函数返回 ``True`` 时,``Evaluator`` 才会进行验证;

:param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理:

1. 如果 ``input_mapping`` 是一个字典:

1. 如果此时 batch 也是一个 ``Dict``,那么我们会把 batch 中同样在 ``input_mapping`` 中的 key 修改为 ``input_mapping`` 的对应 ``key`` 的 ``value``;
2. 如果此时 batch 是一个 ``dataclass``,那么我们会先将其转换为一个 ``Dict``,然后再进行上述转换;
3. 如果此时 batch 此时是其它类型,那么我们将会直接报错;
2. 如果 ``input_mapping`` 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里;

注意该参数会被传进 ``Evaluator`` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 ``device`` 为 ``None`` 时);
如果 ``Trainer`` 和 ``Evaluator`` 需要使用不同的 ``input_mapping``, 请使用 ``train_input_mapping`` 与 ``evaluate_input_mapping`` 分别进行设置。

:param output_mapping: 应当为一个字典或者函数。作用和 ``input_mapping`` 类似,区别在于其用于转换输出:

1. 如果 ``output_mapping`` 是一个 ``Dict``,那么我们需要模型的输出必须是 ``Dict`` 或者 ``dataclass`` 类型:

1. 如果此时模型的输出是一个 ``Dict``,那么我们会把输出中同样在 ``output_mapping`` 中的 key 修改为 ``output_mapping`` 的对应 key 的 value;
2. 如果此时模型的输出是一个 ``dataclass``,那么我们会先将其转换为一个 Dict,然后再进行上述转换;
2. 如果 ``output_mapping`` 是一个函数,那么我们将会直接将模型的输出传给该函数;

如果 ``Trainer`` 和 ``Evaluator`` 需要使用不同的 ``output_mapping``, 请使用 ``train_output_mapping`` 与 ``evaluate_output_mapping`` 分别进行设置;

.. note::

``input_mapping`` 和 ``output_mapping`` 与 fastNLP 的一个特殊的概念 **'参数绑定'** 高度相关,它们的存在也是为了 fastNLP
中的参数匹配能够正确地运行;

.. todo::
之后链接上 参数匹配 的文档;

.. warning::

如果 ``Trainer`` 的参数 ``output_mapping`` 不为 ``None``,请保证其返回的一定是一个字典,并且其中含有关键字 **'loss'**;

:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为;

1. 如果该值为 ``False``,并且当 batch 为字典时,我们会根据**前向函数**所需要的参数从 batch 中提取对应的对象,然后传入到**前向函数**中;
2. 如果该值为 ``True``,那么我们会将 batch 直接透传给模型;

.. todo::
之后链接上 参数匹配 的文档;

函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`;

:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 才让优化器迭代一次,默认为 1;
:param fp16: 是否开启混合精度训练,默认为 False;
:param monitor: 对于一些特殊的 ``Callback``,例如 :class:`fastNLP.core.callbacks.CheckpointCallback`,它们需要参数 ``monitor``
来从 ``Evaluator`` 的验证结果中获取当前评测的值,从而来判断是否执行一些特殊的操作。例如,对于 ``CheckpointCallback`` 而言,如果我们
想要每隔一个 epoch 让 ``Evaluator`` 进行一次验证,然后保存训练以来的最好的结果;那么我们需要这样设置:

.. code-block::

trainer = Trainer(
...,
metrics={'acc': accMetric()},
callbacks=[CheckpointCallback(
...,
monitor='acc',
topk=1
)]
)

这意味着对于 ``CheckpointCallback`` 来说,*'acc'* 就是一个监测的指标,用于在 ``Evaluator`` 验证后取出其需要监测的那个指标的值。

``Trainer`` 中的参数 ``monitor`` 的作用在于为没有设置 ``monitor`` 参数但是需要该参数的 *callback* 实例设置该值。关于 ``monitor``
参数更详细的说明,请见 :class:`fastNLP.core.callbacks.CheckpointCallback`;

注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效;

:param larger_better: 对于需要参数 ``monitor`` 的 *callback* 来说,``monitor`` 的值是否是越大越好;类似于 ``monitor``,其作用
在于为没有设置 ``larger_better`` 参数但是需要该参数的 *callback* 实例设置该值;

注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效;

:param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 None;

.. note::

marker 的使用场景主要在于如果一个脚本中含有多个 ``Trainer`` 实例,并且含有多个使用 ``Trainer.on`` 修饰的函数时,不同的函数属于
不同的 ``Trainer`` 实例;

此时,通过将修饰器 ``Trainer.on`` 的参数 ``marker`` 和 ``Trainer`` 的参数 ``marker`` 置为相同,就可以使得该函数只会在这一
``Trainer`` 实例中被调用;例如,

.. code-block::

@Trainer.on(Event.on_train_begin(), marker='trainer1')
def fn(trainer):
...

trainer = Trainer(
...,
marker='trainer1'
)

另一点需要说明的是,如果一个被 ``Trainer.on`` 修饰的函数,其修饰时没有指明 ``marker``,那么会将该函数传给代码位于其之后的
第一个 ``Trainer`` 实例,即使该 ``Trainer`` 实例的 marker 不为 None;这一点详见 :meth:`~fastNLP.core.controllers.Trainer.on`

:kwargs: :kwargs:
* *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数: * *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数:
* ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 * ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; {'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等;
* set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; * set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None;
* torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; * torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
* *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上;
注意如果 model_device 为 None,那么 data_device 不会起作用;
* *use_dist_sampler* -- 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。
* *evaluate_use_dist_sampler* -- 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True;
* *data_device* -- 一个具体的 driver 实例中,有 ``model_device`` 和 ``data_device``,前者表示模型所在的设备,后者表示
当 ``model_device`` 为 None 时应当将数据迁移到哪个设备;

.. note::

注意您在绝大部分情况下不会用到该参数!

1. 当 driver 实例的 ``model_device`` 不为 None 时,该参数无效;
2. 对于 pytorch,仅当用户自己通过 ``python -m torch.distributed.launch`` 并且自己初始化 ``init_process_group`` 时,
driver 实例的 ``model_device`` 才会为 None;

* *use_dist_sampler* -- 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。
* *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``;默认为 ``True``;
* *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: * *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";

注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``;
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, * *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象,
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。
* *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。
* *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Evaluator`` 中。与 output_mapping 互斥。

.. note::
``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证;
``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,你需要保证这几个参数得到正确的传入:

必须的参数:1. ``metrics``;2. ``evaluate_dataloaders``;

可选的其它参数:1. ``evaluate_batch_step_fn;2. ``evaluate_fn``;3. ``evaluate_every``;4. ``input_mapping``;
5. ``output_mapping``; 6. ``model_wo_auto_param_call``;7. ``fp16``;8. ``monitor``;9. ``larger_better``;

.. warning::


如果 ``Trainer`` 中内置的 ``Evaluator`` 实例不为 ``None``,那么需要注意 ``Trainer`` 中的一些参数是与 ``Evaluator`` 一致的,它们分别为:

1. ``Evaluator`` 在初始化时的 ``driver`` 参数是 ``Trainer`` 中已经实例化过的 driver;这一点使得一些参数对于 ``Trainer`` 内部的
``Evaluator`` 没有用处,例如 ``device``,``torch_kwargs``,``data_device`` 和 ``output_from_new_proc`` 等;
2. ``input_mapping``,``output_mapping``,``model_wo_auto_param_call`` 和 ``fp16`` 是 ``Trainer`` 和其内部默认的
``Evaluator`` 是一致的;

当然,对于 ``input_mapping`` 和 ``output_mapping``,您可以通过添加 ``kwargs`` 中的参数 ``evaluate_input_mapping`` 和
``evaluate_output_mapping`` 来单独为 ``Evaluator`` 进行更细致的订制。

另一方面,注意一些专门独属于 ``Evaluator`` 的参数仅当 ``Evaluator`` 不为 None 时才会生效。


""" """
self.model = model self.model = model
@@ -174,7 +356,7 @@ class Trainer(TrainerEventTrigger):
evaluate_input_mapping = kwargs.get('evaluate_input_mapping', None) evaluate_input_mapping = kwargs.get('evaluate_input_mapping', None)
evaluate_output_mapping = kwargs.get('evaluate_output_mapping', None) evaluate_output_mapping = kwargs.get('evaluate_output_mapping', None)


train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping = \
train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping = \
_get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping, _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping,
evaluate_input_mapping, evaluate_output_mapping) evaluate_input_mapping, evaluate_output_mapping)


@@ -273,7 +455,7 @@ class Trainer(TrainerEventTrigger):
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。
progress_bar = progress_bar.name progress_bar = progress_bar.name
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn,
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None),
@@ -302,7 +484,7 @@ class Trainer(TrainerEventTrigger):
def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1,
num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True,
catch_KeyboardInterrupt=None): catch_KeyboardInterrupt=None):
"""
r"""
注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint
去保存断点重训的文件; 去保存断点重训的文件;
:param num_train_batch_per_epoch: 每个 epoch 运行多少个 batch 即停止,-1 为根据 dataloader 有多少个 batch 决定。 :param num_train_batch_per_epoch: 每个 epoch 运行多少个 batch 即停止,-1 为根据 dataloader 有多少个 batch 决定。
@@ -491,6 +673,36 @@ class Trainer(TrainerEventTrigger):
# do something # do something
# 以上函数会在 Trainer 每个新的 batch 开始的时候执行,但是是两个 batch 才执行一次。 # 以上函数会在 Trainer 每个新的 batch 开始的时候执行,但是是两个 batch 才执行一次。


.. note::


例如:

.. code-block::

@Trainer.on(Event.on_train_begin())
def fn1(trainer):
...

@Trainer.on(Event.on_train_epoch_begin())
def fn2(trainer):
...

trainer1 = Trainer(
...,
marker='trainer1'
)

@Trainer.on(Event.on_fetch_data_begin())
def fn3(trainer):
...

trainer2 = Trainer(
...,
marker='trainer2'
)


注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前; 注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前;


:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机。每个时机运行的函数应该包含 :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机。每个时机运行的函数应该包含


+ 5
- 3
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -238,14 +238,16 @@ class PaddleFleetDriver(PaddleDriver):
self.gloo_rendezvous_dir = None self.gloo_rendezvous_dir = None


# 分布式环境的其它参数设置 # 分布式环境的其它参数设置
self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {})
paddle_kwargs = kwargs.get("paddle_kwargs", {})
self._fleet_kwargs = paddle_kwargs.get("fleet_kwargs", {})
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) check_user_specific_params(self._fleet_kwargs, DataParallel.__init__)
# fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档 # fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档
self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy())
self.is_collective = self._fleet_kwargs.get("is_collective", True)
self.is_collective = self._fleet_kwargs.pop("is_collective", True)
if not self.is_collective: if not self.is_collective:
raise NotImplementedError("FastNLP only support `collective` for distributed training now.") raise NotImplementedError("FastNLP only support `collective` for distributed training now.")
self.role_maker = self._fleet_kwargs.get("role_maker", None)
self.role_maker = self._fleet_kwargs.pop("role_maker", None)


if self.local_rank == 0 and not is_in_paddle_dist(): if self.local_rank == 0 and not is_in_paddle_dist():
# 由于使用driver时模型一定会被初始化,因此在一开始程序一定会占用一部分显存来存放模型,然而这部分显存没有 # 由于使用driver时模型一定会被初始化,因此在一开始程序一定会占用一部分显存来存放模型,然而这部分显存没有


+ 7
- 22
fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py View File

@@ -22,12 +22,14 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
2、如果检测到输入的 `driver` 是 `paddle` 但 `device` 包含了多个设备,那么我们会给出警告并且自动返回多卡的 Driver 2、如果检测到输入的 `driver` 是 `paddle` 但 `device` 包含了多个设备,那么我们会给出警告并且自动返回多卡的 Driver
3、如果检测到输入的 `driver` 是 `fleet` 但 `device` 仅有一个设备,那么我们会给出警告但仍旧返回多卡的 Driver 3、如果检测到输入的 `driver` 是 `fleet` 但 `device` 仅有一个设备,那么我们会给出警告但仍旧返回多卡的 Driver


:param driver: 该参数的值应为以下之一:["paddle", "fleet"];
:param driver: 使用的 ``driver`` 类型,在这个函数中仅支持 ``paddle``
:param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致; :param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致;
:param model: 训练或者评测的具体的模型; :param model: 训练或者评测的具体的模型;


:return: 返回构造的 `Driver` 实例。 :return: 返回构造的 `Driver` 实例。
""" """
if driver != "paddle":
raise ValueError("When initialize PaddleDriver, parameter `driver` must be 'paddle'.")
if is_in_paddle_launch_dist(): if is_in_paddle_launch_dist():
if device is not None: if device is not None:
logger.warning_once("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull " logger.warning_once("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull "
@@ -37,9 +39,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
# TODO 目前一个进程仅对应一个卡,所以暂时传入一个 int # TODO 目前一个进程仅对应一个卡,所以暂时传入一个 int
return PaddleFleetDriver(model, device[0], True, **kwargs) return PaddleFleetDriver(model, device[0], True, **kwargs)


if driver not in {"paddle", "fleet"}:
raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].")

user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES") user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES")
if user_visible_devices is None: if user_visible_devices is None:
raise RuntimeError("`USER_CUDA_VISIBLE_DEVICES` cannot be None, please check if you have set " raise RuntimeError("`USER_CUDA_VISIBLE_DEVICES` cannot be None, please check if you have set "
@@ -64,22 +63,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
" the available gpu number.") " the available gpu number.")
elif device is not None and not isinstance(device, str): elif device is not None and not isinstance(device, str):
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
if isinstance(device, List):
return PaddleFleetDriver(model, device, **kwargs)
else:
return PaddleSingleDriver(model, device, **kwargs)


if driver == "paddle":
if not isinstance(device, List):
return PaddleSingleDriver(model, device, **kwargs)
else:
logger.rank_zero_warning("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use"
"`Fleetriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter"
"`driver` as `PaddleFleetDriver`.")
return PaddleFleetDriver(model, device, **kwargs)
elif driver == "fleet":
if not isinstance(device, List):
if device == "cpu":
raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.")
logger.rank_zero_warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will"
"still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should "
"choose `paddle` driver.")
return PaddleFleetDriver(model, [device], **kwargs)
else:
return PaddleFleetDriver(model, device, **kwargs)

+ 3
- 1
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -1,4 +1,5 @@
import os import os
import contextlib
from typing import Optional, Dict, Union, Callable, Tuple from typing import Optional, Dict, Union, Callable, Tuple


from .paddle_driver import PaddleDriver from .paddle_driver import PaddleDriver
@@ -70,7 +71,8 @@ class PaddleSingleDriver(PaddleDriver):
""" """
device = get_device_from_visible(self.model_device, output_type=str) device = get_device_from_visible(self.model_device, output_type=str)
paddle.device.set_device(device) paddle.device.set_device(device)
self.model.to(device)
with contextlib.redirect_stdout(None):
self.model.to(device)


def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call: if isinstance(batch, Dict) and not self.wo_auto_param_call:


+ 1
- 1
tests/core/controllers/_test_trainer_fleet.py View File

@@ -76,7 +76,7 @@ def test_trainer_fleet(
trainer.run() trainer.run()


if __name__ == "__main__": if __name__ == "__main__":
driver = "fleet"
driver = "paddle"
device = [0,2,3] device = [0,2,3]
# driver = "paddle" # driver = "paddle"
# device = 2 # device = 2


+ 1
- 1
tests/core/controllers/_test_trainer_fleet_outside.py View File

@@ -83,7 +83,7 @@ def test_trainer_fleet(
trainer.run() trainer.run()


if __name__ == "__main__": if __name__ == "__main__":
driver = "fleet"
driver = "paddle"
device = [0,2,3] device = [0,2,3]
callbacks = [ callbacks = [
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),


+ 2
- 3
tests/core/controllers/test_trainer_paddle.py View File

@@ -24,13 +24,12 @@ class TrainPaddleConfig:
shuffle: bool = True shuffle: bool = True
evaluate_every = 2 evaluate_every = 2


@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])])
@pytest.mark.parametrize("device", ["cpu", 1, [0, 1]])
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
@pytest.mark.parametrize("callbacks", [[RichCallback(5)]]) @pytest.mark.parametrize("callbacks", [[RichCallback(5)]])
@pytest.mark.paddledist @pytest.mark.paddledist
@magic_argv_env_context @magic_argv_env_context
def test_trainer_paddle( def test_trainer_paddle(
driver,
device, device,
callbacks, callbacks,
n_epochs=2, n_epochs=2,
@@ -56,7 +55,7 @@ def test_trainer_paddle(
metrics = {"acc": Accuracy(backend="paddle")} metrics = {"acc": Accuracy(backend="paddle")}
trainer = Trainer( trainer = Trainer(
model=model, model=model,
driver=driver,
driver="paddle",
device=device, device=device,
optimizers=optimizers, optimizers=optimizers,
train_dataloader=train_dataloader, train_dataloader=train_dataloader,


+ 6
- 52
tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py View File

@@ -21,87 +21,41 @@ def test_incorrect_driver():
"device", "device",
["cpu", "gpu:0", 0] ["cpu", "gpu:0", 0]
) )
@pytest.mark.parametrize(
"driver",
["paddle"]
)
def test_get_single_device(driver, device):
def test_get_single_device(device):
""" """
测试正常情况下初始化 PaddleSingleDriver 的情况 测试正常情况下初始化 PaddleSingleDriver 的情况
""" """


model = PaddleNormalModel_Classification_1(2, 100) model = PaddleNormalModel_Classification_1(2, 100)
driver = initialize_paddle_driver(driver, device, model)
driver = initialize_paddle_driver("paddle", device, model)
assert isinstance(driver, PaddleSingleDriver) assert isinstance(driver, PaddleSingleDriver)


@pytest.mark.paddle
@pytest.mark.parametrize(
"device",
[0, 1, [1]]
)
@pytest.mark.parametrize(
"driver",
["fleet"]
)
@magic_argv_env_context
def test_get_fleet_2(driver, device):
"""
测试 fleet 多卡的初始化情况,但传入了单个 gpu
"""

model = PaddleNormalModel_Classification_1(64, 10)
driver = initialize_paddle_driver(driver, device, model)

assert isinstance(driver, PaddleFleetDriver)

@pytest.mark.paddle @pytest.mark.paddle
@pytest.mark.parametrize( @pytest.mark.parametrize(
"device", "device",
[[0, 2, 3], -1] [[0, 2, 3], -1]
) )
@pytest.mark.parametrize(
"driver",
["paddle", "fleet"]
)
@magic_argv_env_context @magic_argv_env_context
def test_get_fleet(driver, device):
def test_get_fleet(device):
""" """
测试 fleet 多卡的初始化情况 测试 fleet 多卡的初始化情况
""" """


model = PaddleNormalModel_Classification_1(64, 10) model = PaddleNormalModel_Classification_1(64, 10)
driver = initialize_paddle_driver(driver, device, model)
driver = initialize_paddle_driver("paddle", device, model)


assert isinstance(driver, PaddleFleetDriver) assert isinstance(driver, PaddleFleetDriver)


@pytest.mark.paddle
@pytest.mark.parametrize(
("driver", "device"),
[("fleet", "cpu")]
)
@magic_argv_env_context
def test_get_fleet_cpu(driver, device):
"""
测试试图在 cpu 上初始化分布式训练的情况
"""
model = PaddleNormalModel_Classification_1(64, 10)
with pytest.raises(ValueError):
driver = initialize_paddle_driver(driver, device, model)

@pytest.mark.paddle @pytest.mark.paddle
@pytest.mark.parametrize( @pytest.mark.parametrize(
"device", "device",
[-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1] [-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1]
) )
@pytest.mark.parametrize(
"driver",
["paddle", "fleet"]
)
@magic_argv_env_context @magic_argv_env_context
def test_device_out_of_range(driver, device):
def test_device_out_of_range(device):
""" """
测试传入的device超过范围的情况 测试传入的device超过范围的情况
""" """
model = PaddleNormalModel_Classification_1(2, 100) model = PaddleNormalModel_Classification_1(2, 100)
with pytest.raises(ValueError): with pytest.raises(ValueError):
driver = initialize_paddle_driver(driver, device, model)
driver = initialize_paddle_driver("paddle", device, model)

Loading…
Cancel
Save