@@ -19,6 +19,7 @@ from fastNLP.core.utils import ( | |||||
check_user_specific_params, | check_user_specific_params, | ||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
is_in_paddle_dist, | is_in_paddle_dist, | ||||
rank_zero_rm | |||||
) | ) | ||||
from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
RandomBatchSampler, | RandomBatchSampler, | ||||
@@ -55,20 +56,134 @@ class PaddleFleetDriver(PaddleDriver): | |||||
fp16: bool = False, | fp16: bool = False, | ||||
**kwargs | **kwargs | ||||
): | ): | ||||
""" | |||||
采用fleet接口进行并行paddle训练的driver | |||||
PaddleFleetDriver 目前考虑支持的三种启动方式: | |||||
1. 用户自己不进行 fleet 的任何操作,直接使用我们的 Trainer,并且只运行一个 main 脚本,这时是由我们自己使用 open_subprocesses 拉起 | |||||
多个进程,然后由 Driver 自己进行初始化 | |||||
2. 其它情况同 1,但是用户自己使用 python -m paddle.distributed.launch 拉起; | |||||
3. 用户自己在外面初始化 Fleet,并且通过 python -m paddle.distributed.launch 拉起; | |||||
注意多机的启动强制要求用户在每一台机器上使用 python -m paddle.distributed.launch 启动; | |||||
如果用户自己在外面初始化了 fleet,那么 | |||||
parallel_device 为 None; | |||||
data_device 为 表示单卡的一个参数; | |||||
dist.is_initialized 为 true; | |||||
r""" | |||||
A Driver that launches multi-GPU training through PaddlePaddle's Fleet framework.
One thing to note is a peculiarity of PaddlePaddle: if we simply spawned the other workers from rank 0 without any restriction, PaddlePaddle would hang after the first forward pass or occupy every visible GPU. To work around this, importing fastNLP restricts the visible devices to card 0 via `CUDA_VISIBLE_DEVICES`; if the user had set this variable themselves, we keep their value in `USER_CUDA_VISIBLE_DEVICES` and translate between the two (see `fastNLP/envs/set_backend.py` for the details). When spawning the other workers we do the same and restrict each of them to its own device.
`PaddleFleetDriver` currently supports three launch modes:
1. the user does nothing distributed-specific and simply uses our Trainer; in this case we spawn the worker processes ourselves with `FleetLauncher`, and `PaddleFleetDriver` then calls `fleet.init` to build the communication group (case A);
2. the user still initializes nothing outside the Trainer, but spawns the processes themselves with python -m paddle.distributed.launch; we again call `fleet.init` to build the communication group (case B);
3. the user initializes the distributed environment themselves and launches with python -m paddle.distributed.launch; spawning the processes and building the communication group are then entirely the user's job, and we only set a few required attributes on `PaddleFleetDriver` during driver.setup (case C);
Note that multi-machine training always requires the user to run python -m paddle.distributed.launch on every machine, so `PaddleFleetDriver` does not store any information about how many machines are involved. A short usage sketch of the three cases follows this paragraph.
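How the three cases are typically entered (an illustrative sketch only; `MyModel` and the script name are placeholders, not part of fastNLP)::

    # case A: plain `python train.py`; fastNLP spawns the worker processes itself
    trainer = Trainer(model=MyModel(), driver="fleet", device=[0, 1])  # plus dataloaders, optimizers, ...

    # case B: the same Trainer code, but started with
    #   python -m paddle.distributed.launch --gpus "0,1" train.py
    # (the `device` argument is then ignored, see case B in Part 1 below)

    # case C: started as in case B, but the user also calls `fleet.init(...)`
    # before creating the Trainer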
Part 1: the three launch modes in detail:
(1) If `driver.setup` is called only once in the user's script (i.e. the script creates a single trainer/evaluator), `PaddleFleetDriver` does the following during its initialization and in `setup`:
-> case A: the model the user passes in must be a plain model (not wrapped in `DataParallel`), because wrapping requires fleet.init to have already been called to build the communication group; and a user who wanted to use two or more GPUs that way would have had to launch with paddle.distributed.launch, which would no longer be case A.
   We first call `FleetLauncher.launch` to spawn the worker processes, one per GPU requested in the trainer (e.g. with `Trainer`'s device=[0, 1, 6, 7] we spawn 4 processes on GPUs 0, 1, 6 and 7), and then call `fleet.init` to build the communication group between them.
   Note that every newly spawned process runs the user's launch script (e.g. main.py) from top to bottom, so all of them hit these two calls, but only process 0 actually executes `FleetLauncher.launch`; when process 0 reaches `fleet.init`, paddle blocks it until the other processes get there as well.
   Finally we set the device of each process, move the model onto it and wrap it with `DataParallel`; at this point the paddle distributed environment is fully configured.
-> case B: here we assume the user launched with paddle.distributed.launch but did not build the communication group themselves. Compared with case A, the main difference during the initialization and setup of `PaddleFleetDriver` is that the `device` argument given to the trainer no longer has any effect: the GPU each process uses is configured directly through `CUDA_VISIBLE_DEVICES`. A user who wants specific GPUs can therefore set that environment variable themselves (e.g. via os.environ["CUDA_VISIBLE_DEVICES"]; we keep a copy of it as described above). The remaining steps are the same as in case A.
-> case C: here we assume the user launched with paddle.distributed.launch and also built the communication group themselves. Essentially all related work is then the user's responsibility, including moving the model to the right GPU and wrapping it with `DataParallel`.
(2) If `driver.setup` is called two or more times in the script (i.e. the script creates two or more trainers/evaluators):
Note that in this case we guarantee that the later trainers/evaluators use a `PaddleFleetDriver` initialized in the same way as the first one; in other words, if trainer1 detected launch mode 'case A', trainer2 is guaranteed to detect 'case A' as well (even if that needs some extra handling). So the main question is how we make trainer2/3/... detect the same launch mode as trainer1; in short, we mark each launch mode with an environment variable:
`FASTNLP_DISTRIBUTED_CHECK` marks 'case A' and `fastnlp_torch_launch_not_ddp` marks 'case B', meaning that when `PaddleFleetDriver` is started in 'case A' we inject the string `FASTNLP_DISTRIBUTED_CHECK` into the environment, and in 'case B' we inject `fastnlp_torch_launch_not_ddp`. During the initialization and setup of trainer2's `PaddleFleetDriver`, if one of these variables is detected, the launch mode is switched to the one it marks, even if the other arguments look like a different mode.
Part 2: the corresponding implementation details:
1. How do we know that the communication group between the processes has already been built (fleet has been initialized)?
   parallel_helper._is_parallel_ctx_initialized();
2. How do we tell whether the processes were spawned by `python -m paddle.distributed.launch` or by our own `FleetLauncher.launch()`?
   When the user's script runs `import fastNLP`, we check whether the environment contains 'PADDLE_RANK_IN_NODE' and 'PADDLE_TRAINER_ID' but not `FASTNLP_DISTRIBUTED_CHECK`;
   if so, we inject the special value 'FASTNLP_BACKEND_LAUNCH' into the environment to record that the user spawned the processes with `python -m paddle.distributed.launch`;
3. the overall decision flow:
   enter PaddleFleetDriver.__init__
     |
     v
   were the processes spawned by paddle.distributed.launch
   (as opposed to our own FleetLauncher)?
     |-- no  -> case A (FleetLauncher spawns the workers itself)
     |-- yes -> did the user initialize fleet themselves?
                  |-- no  -> case B
                  |-- yes -> case C
   (the sketch after the table in item 4 shows the same decision in code)
4. The steps every case needs in order to finish setting up the distributed environment, and who is responsible for each:
                                       case A                    case B                       case C
   -------------------------------------------------------------------------------------------------------------------
   set the environment variables       FleetLauncher.launch      paddle.distributed.launch    paddle.distributed.launch
   that fleet needs
   -------------------------------------------------------------------------------------------------------------------
   spawn the processes                 FleetLauncher.launch      paddle.distributed.launch    paddle.distributed.launch
   -------------------------------------------------------------------------------------------------------------------
   call fleet.init                     PaddleFleetDriver.setup   PaddleFleetDriver.setup      done by the user
   -------------------------------------------------------------------------------------------------------------------
   set world_size and global_rank      PaddleFleetDriver.setup   PaddleFleetDriver.setup      PaddleFleetDriver.setup
   on PaddleFleetDriver
   -------------------------------------------------------------------------------------------------------------------
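The checks described in items 2 and 3 reduce roughly to the following sketch (illustrative only; the value written into FASTNLP_BACKEND_LAUNCH and the variable `launch_mode` are simplifications, not the actual implementation)::

    # at `import fastNLP` time: detect a paddle.distributed.launch start
    if ("PADDLE_RANK_IN_NODE" in os.environ and "PADDLE_TRAINER_ID" in os.environ
            and "FASTNLP_DISTRIBUTED_CHECK" not in os.environ):
        os.environ["FASTNLP_BACKEND_LAUNCH"] = "1"

    # inside PaddleFleetDriver: decide which case we are in
    if is_pull_by_paddle_run:                               # started by paddle.distributed.launch
        if parallel_helper._is_parallel_ctx_initialized():
            launch_mode = "case C"                          # the user already called fleet.init
        else:
            launch_mode = "case B"
    else:
        launch_mode = "case A"                              # FleetLauncher spawns the workers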
Part 3: other details:
1. Environment variables.
   The environment variables fastNLP's `PaddleFleetDriver` needs at runtime fall into two groups: those paddle fleet itself requires, configured as in the table above, and fastNLP's own; most of the latter are set when the user imports fastNLP.
2. The relationship between parallel_device, model_device and data_device.
   parallel_device is a parameter of `PaddleFleetDriver`; model_device and data_device are attributes of the driver.
   data_device is only specified by the user in case C; if it is not None, data is moved to data_device before the model's forward pass.
   model_device is always a single concrete device:
                        case A                           case B                  case C
   -------------------------------------------------------------------------------------------------------
   parallel_device      taken from the trainer's         CUDA_VISIBLE_DEVICES    CUDA_VISIBLE_DEVICES
                        `device` argument, which must
                        be a list of ints
   -------------------------------------------------------------------------------------------------------
   model_device         parallel_device[local_rank]      parallel_device         None
   -------------------------------------------------------------------------------------------------------
   data_device          model_device                     model_device            taken from the trainer's
                                                                                 `data_device` argument
   -------------------------------------------------------------------------------------------------------
3. What the wrapping model is for (`_FleetWrappingModel` here, modelled after torch's `_DDPWrappingModel`).
   We need to call the model's `train_step`, `evaluate_step` and `test_step` methods, but we also need `DataParallel`'s forward to synchronize the gradients across devices, so the model is first wrapped in an extra layer: a call first goes through `DataParallel`'s forward and then through the wrapper's forward, which decides whether to call the model's own forward or one of `train_step`, `evaluate_step` and `test_step` (a sketch of this dispatch, and of the pid handling in item 4, follows this list).
4. What `PaddleFleetDriver` does when one process raises an exception.
   Whatever the launch mode, `PaddleFleetDriver` records the pids of all processes at the end of `setup`; when one process raises an exception, the trainer calls the driver's on_exception, which uses os.kill to terminate the other processes.
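Conceptually, the wrapper's forward and the exception handling look roughly like this (a simplified sketch, not the actual implementation; the keyword names follow the call made in model_call below, while the pid list and the signal used are assumptions)::

    class _FleetWrappingModel(Layer):
        # wraps the user's model so that DataParallel.forward still synchronizes gradients
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, batch, fastnlp_fn=None, fastnlp_signature_fn=None, wo_auto_param_call=False):
            # `fastnlp_fn` is the bound method picked by `get_model_call_fn`
            # (the model's own forward, or train_step / evaluate_step / test_step)
            if isinstance(batch, dict) and not wo_auto_param_call:
                return auto_param_call(fastnlp_fn, batch, signature_fn=fastnlp_signature_fn)
            return fastnlp_fn(batch)

    # and, inside the driver, the exception handling boils down to:
    #     for pid in self._pids:          # pids of all ranks, recorded at the end of setup()
    #         if pid != os.getpid():
    #             os.kill(pid, signal.SIGKILL)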
""" | """ | ||||
super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs) | super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs) | ||||
@@ -78,6 +193,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
"when your value of parameter `device` is `None` in your `Trainer` instance.") | "when your value of parameter `device` is `None` in your `Trainer` instance.") | ||||
# 如果用户自己初始化了 paddle 的分布式训练那么一定是通过 launch 拉起的 | # 如果用户自己初始化了 paddle 的分布式训练那么一定是通过 launch 拉起的 | ||||
# this flag is set in initialize_paddle_driver.
self.is_pull_by_paddle_run = is_pull_by_paddle_run | self.is_pull_by_paddle_run = is_pull_by_paddle_run | ||||
self.parallel_device = parallel_device | self.parallel_device = parallel_device | ||||
# 在初始化时,如果发现 is_pull_by_paddle_run ,则将 parallel_device 设置成当前进程的gpu | # 在初始化时,如果发现 is_pull_by_paddle_run ,则将 parallel_device 设置成当前进程的gpu | ||||
@@ -98,7 +214,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
self.outside_fleet = True | self.outside_fleet = True | ||||
# 用户只有将模型上传到对应机器上后才能用 DataParallel 包裹,因此如果用户在外面初始化了 Fleet,那么在 PaddleFleetDriver 中 | # 用户只有将模型上传到对应机器上后才能用 DataParallel 包裹,因此如果用户在外面初始化了 Fleet,那么在 PaddleFleetDriver 中 | ||||
# 我们就直接将 model_device 置为 None; | |||||
# 我们就直接将 model_device 置为 None; | |||||
self._model_device = None | self._model_device = None | ||||
# 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | # 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | ||||
@@ -119,9 +235,12 @@ class PaddleFleetDriver(PaddleDriver): | |||||
self.world_size = None | self.world_size = None | ||||
self.global_rank = 0 | self.global_rank = 0 | ||||
self.gloo_rendezvous_dir = None | |||||
# other settings for the distributed environment
self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {}) | self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {}) | ||||
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) | check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) | ||||
# distributed strategy passed to fleet.init; see the PaddlePaddle documentation for details
self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) | self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) | ||||
self.is_collective = self._fleet_kwargs.get("is_collective", True) | self.is_collective = self._fleet_kwargs.get("is_collective", True) | ||||
if not self.is_collective: | if not self.is_collective: | ||||
@@ -145,7 +264,10 @@ class PaddleFleetDriver(PaddleDriver): | |||||
def setup(self): | def setup(self): | ||||
""" | """ | ||||
Spawn the other worker processes from the main process, which becomes rank 0, and configure the driver according to the launch mode:
1. if the script was started via paddle.distributed.launch, read the distributed attributes from the environment that has already been set up;
2. otherwise, use the FleetLauncher class to spawn the worker processes.
""" | """ | ||||
if self._has_setup: | if self._has_setup: | ||||
return | return | ||||
@@ -174,7 +296,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# 此时 parallel_helper._is_parallel_ctx_initialized() 一定为 False | # 此时 parallel_helper._is_parallel_ctx_initialized() 一定为 False | ||||
# parallel_device 是 list, | # parallel_device 是 list, | ||||
if not parallel_helper._is_parallel_ctx_initialized(): | if not parallel_helper._is_parallel_ctx_initialized(): | ||||
# 没有初始化分布式环境,且是主进程 | |||||
# spawn the worker processes and set the corresponding attributes
self.init_fleet_and_set() | self.init_fleet_and_set() | ||||
# 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 PaddleFleetDriver; | # 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 PaddleFleetDriver; | ||||
else: | else: | ||||
@@ -216,12 +338,13 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# 是 rank0 的话,则拉起其它子进程 | # 是 rank0 的话,则拉起其它子进程 | ||||
launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc) | launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc) | ||||
launcher.launch() | launcher.launch() | ||||
self.gloo_rendezvous_dir = launcher.gloo_rendezvous_dir | |||||
# 设置参数和初始化分布式环境 | # 设置参数和初始化分布式环境 | ||||
fleet.init(self.role_maker, self.is_collective, self.strategy) | fleet.init(self.role_maker, self.is_collective, self.strategy) | ||||
self.global_rank = int(os.getenv("PADDLE_TRAINER_ID")) | self.global_rank = int(os.getenv("PADDLE_TRAINER_ID")) | ||||
self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM")) | self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM")) | ||||
# 正常情况下不会Assert出问题,但还是保险一下 | |||||
# these asserts should never fire, but keep them as a safety net
assert self.global_rank is not None | assert self.global_rank is not None | ||||
assert self.world_size is not None | assert self.world_size is not None | ||||
assert self.world_size == len(self.parallel_device) | assert self.world_size == len(self.parallel_device) | ||||
@@ -235,10 +358,19 @@ class PaddleFleetDriver(PaddleDriver): | |||||
self.global_rank = paddledist.get_rank() | self.global_rank = paddledist.get_rank() | ||||
def barrier(self): | def barrier(self): | ||||
r""" | |||||
Synchronize the progress of the processes: a fast process that reaches this point waits for the slower ones, and execution only continues once every process has arrived. Only used in distributed training.
Note that the behaviour is affected by FASTNLP_NO_SYNC: the barrier is only actually executed when FASTNLP_NO_SYNC is absent from os.environ or smaller than 1.
""" | |||||
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 | ||||
paddledist.barrier() | paddledist.barrier() | ||||
def configure_fleet(self): | def configure_fleet(self): | ||||
""" | |||||
Wrap the model with DataParallel and with our own wrapping class.
""" | |||||
if not self._has_fleetwrapped and not isinstance(self.model, DataParallel): | if not self._has_fleetwrapped and not isinstance(self.model, DataParallel): | ||||
self.model = DataParallel( | self.model = DataParallel( | ||||
_FleetWrappingModel(self.model), | _FleetWrappingModel(self.model), | ||||
@@ -247,8 +379,14 @@ class PaddleFleetDriver(PaddleDriver): | |||||
self._has_fleetwrapped = True | self._has_fleetwrapped = True | ||||
def on_exception(self): | def on_exception(self): | ||||
if os.path.exists(self.gloo_rendezvous_dir): | |||||
shutil.rmtree(self.gloo_rendezvous_dir) | |||||
""" | |||||
Used to shut down the other processes correctly when an error occurs during training or prediction. This works because every multi-process driver records the pid of each process when it opens the subprocesses (open_subprocess); when an error occurs, the failing process kills the others by pid.
Therefore, for this function to work correctly, every multi-process driver must record each process's pid in its own open_subprocess (the function that spawns the processes).
""" | |||||
rank_zero_rm(self.gloo_rendezvous_dir) | |||||
super().on_exception() | super().on_exception() | ||||
@property | @property | ||||
@@ -282,6 +420,17 @@ class PaddleFleetDriver(PaddleDriver): | |||||
return self.model_device | return self.model_device | ||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | ||||
""" | |||||
Run one forward pass by calling `fn`.
Note that Trainer and Evaluator call this function to perform the forward pass; the `fn` passed in is the function returned by `get_model_call_fn`.
:param batch: the current batch of data; may be a dict or any other type;
:param fn: the function called to perform one computation step;
:param signature_fn: the signature function supplied by the Trainer for one forward pass. When `batch` is a dict we automatically use auto_param_call, and some wrapped models need to expose their real signature, e.g. DistributedDataParallel is called through forward but its signature should be taken from model.module.forward;
:return: whatever `fn` returns (expected to be a dict or dataclass, but we do not check it);
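Example (illustrative; the tensor shapes are arbitrary)::

    fn, signature_fn = driver.get_model_call_fn("train_step")
    batch = {"x": paddle.rand((4, 10)), "y": paddle.randint(0, 2, (4,))}
    outputs = driver.model_call(batch, fn, signature_fn)   # dict batch -> auto_param_call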
""" | |||||
if self._has_fleetwrapped: | if self._has_fleetwrapped: | ||||
return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn, | return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn, | ||||
wo_auto_param_call=self.wo_auto_param_call) | wo_auto_param_call=self.wo_auto_param_call) | ||||
@@ -292,6 +441,27 @@ class PaddleFleetDriver(PaddleDriver): | |||||
return fn(batch) | return fn(batch) | ||||
def get_model_call_fn(self, fn: str) -> Tuple: | def get_model_call_fn(self, fn: str) -> Tuple: | ||||
""" | |||||
Takes the Trainer's train_fn or the Evaluator's evaluate_fn and returns the function that will actually be passed to driver.model_call. Trainer and Evaluator call it after driver.setup.
The point of this function is to pull the concrete model-call function out of the driver and attach it to the Trainer or Evaluator instead. In the new design, which model method is used for a `train step` or `evaluate step` is decided by the extra arguments `train_fn` and `evaluate_fn`, which are controlled by the Trainer and the Evaluator respectively; the choice of the concrete `train step fn` / `evaluate step fn` therefore cannot happen when each driver is initialized (when the Trainer builds the first driver, the Evaluator does not exist yet, even though determining the `evaluate step fn` requires it), so the logic is abstracted into this function.
The function decides what to return based on `fn`:
1. if fn == "train_step" or "evaluate_step", check the model; if it does not define a method `fn`, fall back to the model's `forward` and emit a warning;
2. for any other string, raise an error if the model does not define a method `fn`;
Note that individual drivers may need extra checks: e.g. in DDPDriver, if the model passed in is already a DistributedDataParallel we can only call its forward, which deserves an extra warning. Be particularly careful that a driver may itself modify the model during setup (DDPDriver does), so it may need to remember what kind of model was originally passed to it.
:param fn: a string; it determines which method of the model to return;
:return: a tuple of two functions, later passed to driver.model_call;
""" | |||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if self._has_fleetwrapped: | if self._has_fleetwrapped: | ||||
if hasattr(model, fn): | if hasattr(model, fn): | ||||
@@ -316,7 +486,25 @@ class PaddleFleetDriver(PaddleDriver): | |||||
return self.model, model.forward | return self.model, model.forward | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], | def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], | ||||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||||
reproducible: bool = False): | |||||
r""" | |||||
Build a distributed and reproducible dataloader from the input dataloader.
:param dataloader: the dataloader whose distributed / reproducible counterpart should be produced;
:param dist: a string, one of [None, "dist", "unrepeatdist"]. None means the dataloader does not need to be switched to a distributed state. "dist" means the dataloader must yield the same number of batches on every gpu, which allows a few samples to be duplicated across gpus. "unrepeatdist" means the data iterated over all gpus, taken together, must be exactly the original dataset, which allows different gpus to see different numbers of batches. When the trainer kwarg `use_dist_sampler` is True this value is "dist", otherwise None; when the evaluator kwarg `use_dist_sampler` is True this value is "unrepeatdist", otherwise None.
Note that when dist is a ReproducibleSampler or ReproducibleBatchSampler, the call comes from driver.load during checkpoint resumption; when dist is a str or None, the call comes from the trainer during initialization;
:param reproducible: if False, nothing special is required; if True, the returned dataloader must be able to save its current iteration state so that it can be restored later;
:return: a new dataloader object with its sampler replaced (a new dataloader object must be returned here). If the input dataloader carries a ReproducibleSampler or ReproducibleBatchSampler, a re-initialized copy must be placed into the returned dataloader. If dist is None and reproducible is False, the original object may be returned unchanged. (See the usage sketch below.)
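How the trainer typically calls this method (an illustrative sketch)::

    dataloader = driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)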
""" | |||||
# 暂时不支持iterableDataset | # 暂时不支持iterableDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | "FastNLP does not support `IteratorDataset` now." | ||||
@@ -429,10 +617,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
@staticmethod | @staticmethod | ||||
def _check_optimizer_legality(optimizers): | def _check_optimizer_legality(optimizers): | ||||
""" | |||||
paddle存在设置分布式optimizers的函数,返回值为fleet.meta_optimizers.HybridParallelOptimizer | |||||
重写是为了防止单卡下也传入了分布式的优化器 | |||||
""" | |||||
# paddle provides helpers that build distributed optimizers, whose return type is fleet.meta_optimizers.HybridParallelOptimizer
DistribuedOptimizer = fleet.meta_optimizers.HybridParallelOptimizer | DistribuedOptimizer = fleet.meta_optimizers.HybridParallelOptimizer | ||||
for each_optimizer in optimizers: | for each_optimizer in optimizers: | ||||
if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)): | if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)): | ||||
@@ -20,7 +20,7 @@ from .utils import ( | |||||
# 记录各个进程信息 | # 记录各个进程信息 | ||||
class SubTrainer(object): | class SubTrainer(object): | ||||
""" | """ | ||||
和fastnlp的Triainer没有关系,仅用于统计节点内不同训练的一些信息 | |||||
Records information about the different training processes on a node; unrelated to fastNLP's Trainer.
""" | """ | ||||
def __init__(self, endpoint=None, rank=None): | def __init__(self, endpoint=None, rank=None): | ||||
self.devices = [] | self.devices = [] | ||||
@@ -30,8 +30,8 @@ class SubTrainer(object): | |||||
class FleetLauncher: | class FleetLauncher: | ||||
""" | """ | ||||
复原了 paddle 的 launch_collective 函数,将其集成到一个类里 | |||||
仅支持单机多卡的启动 | |||||
A simplified re-implementation of paddle's launch_collective function, wrapped in a class.
Only the single-machine case is supported.
""" | """ | ||||
def __init__( | def __init__( | ||||
self, | self, | ||||
@@ -45,17 +45,26 @@ class FleetLauncher: | |||||
self.setup() | self.setup() | ||||
def setup(self): | def setup(self): | ||||
""" | |||||
Initial setup: based on the given devices, find the ports to use for distributed training.
""" | |||||
self.set_endpoints() | self.set_endpoints() | ||||
self.sub_trainers = self.get_process_info() | self.sub_trainers = self.get_process_info() | ||||
def launch(self) -> int: | |||||
def launch(self): | |||||
""" | |||||
Start the distributed processes:
first set the environment variables PaddlePaddle's distributed training needs, then spawn the new subprocesses.
""" | |||||
# 设置环境变量 | # 设置环境变量 | ||||
self.global_envs = self.get_global_env() | self.global_envs = self.get_global_env() | ||||
self.open_subprocess() | self.open_subprocess() | ||||
reset_seed() | reset_seed() | ||||
def open_subprocess(self): | def open_subprocess(self): | ||||
""" | |||||
Read each rank's information from sub_trainers and create the new subprocesses with subprocess.Popen.
""" | |||||
if __main__.__spec__ is None: | if __main__.__spec__ is None: | ||||
# Script called as `python a/b/c.py` | # Script called as `python a/b/c.py` | ||||
@@ -77,6 +86,7 @@ class FleetLauncher: | |||||
current_env = copy.copy(self.global_envs) | current_env = copy.copy(self.global_envs) | ||||
for idx, t in enumerate(self.sub_trainers): | for idx, t in enumerate(self.sub_trainers): | ||||
# set the environment variables for each rank
proc_env = { | proc_env = { | ||||
# global_rank | # global_rank | ||||
"PADDLE_TRAINER_ID": f"{t.rank}", | "PADDLE_TRAINER_ID": f"{t.rank}", | ||||
@@ -108,6 +118,14 @@ class FleetLauncher: | |||||
os.environ.update(current_env) | os.environ.update(current_env) | ||||
def get_global_env(self): | def get_global_env(self): | ||||
""" | |||||
Set the global environment variables needed for distributed training, including:
1. the GLOO-related settings;
2. `PADDLE_TRAINERS_NUM`: the total number of processes;
3. `PADDLE_TRAINER_ENDPOINTS`: all addresses and ports in use;
4. `PADDLE_WORLD_DEVICE_IDS`: all devices in use;
5. FASTNLP_DISTRIBUTED_CHECK: a marker showing that the subprocesses were created by fastNLP; it stores the devices used for distributed training.
""" | |||||
global_envs = copy.copy(os.environ.copy()) | global_envs = copy.copy(os.environ.copy()) | ||||
self.gloo_rendezvous_dir = tempfile.mkdtemp() | self.gloo_rendezvous_dir = tempfile.mkdtemp() | ||||
@@ -137,7 +155,7 @@ class FleetLauncher: | |||||
def set_endpoints(self): | def set_endpoints(self): | ||||
""" | """ | ||||
Reference to `get_cluster_from_args` | |||||
Find the user-specified ports, or free ones, for distributed training; adapted from PaddlePaddle's `get_cluster_from_args`.
""" | """ | ||||
self.node_ip = "127.0.0.1" | self.node_ip = "127.0.0.1" | ||||
@@ -157,7 +175,7 @@ class FleetLauncher: | |||||
def get_process_info(self): | def get_process_info(self): | ||||
""" | """ | ||||
Reference to `get_cluster` | |||||
Collect the device, rank and port of every training process; adapted from PaddlePaddle's `get_cluster`.
""" | """ | ||||
sub_trainers = [] | sub_trainers = [] | ||||
assert len(self.endpoints) >= len( | assert len(self.endpoints) >= len( | ||||
@@ -17,14 +17,16 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
model: paddle.nn.Layer, **kwargs) -> PaddleDriver: | model: paddle.nn.Layer, **kwargs) -> PaddleDriver: | ||||
r""" | r""" | ||||
用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; | 用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; | ||||
Note that if the given `device` does not match `driver`, an error is raised directly.
1. If the current process is detected to have been started by the user via `python -m paddle.distributed.launch xxx.py`, the devices are set automatically to the user-specified ones (thanks to the special handling done when fastNLP is imported, they can be read from `CUDA_VISIBLE_DEVICES`).
2. If `driver` is `paddle` but `device` contains several devices, a warning is emitted and the multi-card Driver is returned.
3. If `driver` is `fleet` but `device` is a single device, a warning is emitted and the multi-card Driver is still returned.
:param driver: 该参数的值应为以下之一:["paddle", "fleet"]; | :param driver: 该参数的值应为以下之一:["paddle", "fleet"]; | ||||
:param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致; | :param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致; | ||||
:param model: 训练或者评测的具体的模型; | :param model: 训练或者评测的具体的模型; | ||||
:return: 返回一个元组,元组的第一个值是具体的基于 pytorch 的 `Driver` 实例,元组的第二个值是该 driver 的名字(用于检测一个脚本中 | |||||
先后 driver 的次序的正确问题); | |||||
:return: the constructed `Driver` instance.
""" | """ | ||||
if is_in_paddle_launch_dist(): | if is_in_paddle_launch_dist(): | ||||
if device is not None: | if device is not None: | ||||
@@ -47,9 +49,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | ||||
if device >= _could_use_device_num: | if device >= _could_use_device_num: | ||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | raise ValueError("The gpu device that parameter `device` specifies is not existed.") | ||||
if device != -1: | |||||
device = f"gpu:{device}" | |||||
else: | |||||
if device == -1: | |||||
device = list(range(_could_use_device_num)) | device = list(range(_could_use_device_num)) | ||||
elif isinstance(device, Sequence) and not isinstance(device, str): | elif isinstance(device, Sequence) and not isinstance(device, str): | ||||
device = list(set(device)) | device = list(set(device)) | ||||
@@ -61,9 +61,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
elif each >= _could_use_device_num: | elif each >= _could_use_device_num: | ||||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" | raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" | ||||
" the available gpu number.") | " the available gpu number.") | ||||
if len(device) == 1: | |||||
# 传入了 [1] 这样的,视为单卡。 | |||||
device = device[0] | |||||
elif device is not None and not isinstance(device, str): | elif device is not None and not isinstance(device, str): | ||||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | ||||
@@ -82,6 +79,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | ||||
"still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | "still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | ||||
"choose `paddle` driver.") | "choose `paddle` driver.") | ||||
return PaddleFleetDriver(model, device, **kwargs) | |||||
return PaddleFleetDriver(model, [device], **kwargs) | |||||
else: | else: | ||||
return PaddleFleetDriver(model, device, **kwargs) | return PaddleFleetDriver(model, device, **kwargs) |
@@ -19,7 +19,12 @@ from fastNLP.envs import ( | |||||
rank_zero_call, | rank_zero_call, | ||||
) | ) | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||||
from fastNLP.core.samplers import ( | |||||
ReproducibleBatchSampler, | |||||
ReproducibleSampler, | |||||
RandomBatchSampler, | |||||
RandomSampler, | |||||
) | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
@@ -29,7 +34,7 @@ if _NEED_IMPORT_PADDLE: | |||||
Dataset, | Dataset, | ||||
Sampler, | Sampler, | ||||
BatchSampler, | BatchSampler, | ||||
RandomSampler, | |||||
RandomSampler as PaddleRandomSampler, | |||||
) | ) | ||||
from paddle.optimizer import Optimizer | from paddle.optimizer import Optimizer | ||||
@@ -333,6 +338,9 @@ class PaddleDriver(Driver): | |||||
sampler = dataloader_args.batch_sampler | sampler = dataloader_args.batch_sampler | ||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | elif isinstance(dataloader_args.sampler, ReproducibleSampler): | ||||
sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
elif isinstance(dataloader_args.sampler, PaddleRandomSampler): | |||||
sampler = RandomSampler(dataloader_args.sampler.data_source) | |||||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||||
elif self.is_distributed(): | elif self.is_distributed(): | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | ||||
"`ReproducibleSampler`.") | "`ReproducibleSampler`.") | ||||
@@ -464,7 +472,7 @@ class PaddleDriver(Driver): | |||||
res.sampler = dataloader.batch_sampler.sampler | res.sampler = dataloader.batch_sampler.sampler | ||||
if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | ||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | res.shuffle = dataloader.batch_sampler.sampler.shuffle | ||||
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): | |||||
elif isinstance(dataloader.batch_sampler.sampler, PaddleRandomSampler): | |||||
res.shuffle = True | res.shuffle = True | ||||
else: | else: | ||||
res.shuffle = False | res.shuffle = False | ||||
@@ -474,7 +482,7 @@ class PaddleDriver(Driver): | |||||
res.sampler = batch_sampler.sampler | res.sampler = batch_sampler.sampler | ||||
if hasattr(batch_sampler.sampler, "shuffle"): | if hasattr(batch_sampler.sampler, "shuffle"): | ||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | res.shuffle = dataloader.batch_sampler.sampler.shuffle | ||||
elif isinstance(batch_sampler.sampler, RandomSampler): | |||||
elif isinstance(batch_sampler.sampler, PaddleRandomSampler): | |||||
res.shuffle = True | res.shuffle = True | ||||
else: | else: | ||||
res.shuffle = False | res.shuffle = False | ||||
@@ -31,6 +31,9 @@ __all__ = [ | |||||
] | ] | ||||
class PaddleSingleDriver(PaddleDriver): | class PaddleSingleDriver(PaddleDriver): | ||||
""" | |||||
Driver for paddle training on CPU or a single GPU.
""" | |||||
def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs): | def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs): | ||||
if isinstance(model, DataParallel): | if isinstance(model, DataParallel): | ||||
raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") | raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") | ||||
@@ -59,18 +62,53 @@ class PaddleSingleDriver(PaddleDriver): | |||||
self.world_size = 1 | self.world_size = 1 | ||||
def setup(self): | def setup(self): | ||||
r""" | |||||
Initialize the training environment: set the device for the current run and move the model onto it.
""" | |||||
device = self.model_device | device = self.model_device | ||||
device = get_device_from_visible(device, output_type=str) | device = get_device_from_visible(device, output_type=str) | ||||
paddle.device.set_device(device) | paddle.device.set_device(device) | ||||
self.model.to(device) | self.model.to(device) | ||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | ||||
""" | |||||
Run one forward pass by calling `fn`.
Note that Trainer and Evaluator call this function to perform the forward pass; the `fn` passed in is the function returned by `get_model_call_fn`.
:param batch: the current batch of data; may be a dict or any other type;
:param fn: the function called to perform one computation step;
:param signature_fn: the signature function supplied by the Trainer for one forward pass. When `batch` is a dict we automatically use auto_param_call, and some wrapped models need to expose their real signature, e.g. DistributedDataParallel is called through forward but its signature should be taken from model.module.forward;
:return: whatever `fn` returns (expected to be a dict or dataclass, but we do not check it);
""" | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | if isinstance(batch, Dict) and not self.wo_auto_param_call: | ||||
return auto_param_call(fn, batch, signature_fn=signature_fn) | return auto_param_call(fn, batch, signature_fn=signature_fn) | ||||
else: | else: | ||||
return fn(batch) | return fn(batch) | ||||
def get_model_call_fn(self, fn: str) -> Tuple: | def get_model_call_fn(self, fn: str) -> Tuple: | ||||
""" | |||||
Takes the Trainer's train_fn or the Evaluator's evaluate_fn and returns the function that will actually be passed to driver.model_call. Trainer and Evaluator call it after driver.setup.
The point of this function is to pull the concrete model-call function out of the driver and attach it to the Trainer or Evaluator instead. In the new design, which model method is used for a `train step` or `evaluate step` is decided by the extra arguments `train_fn` and `evaluate_fn`, which are controlled by the Trainer and the Evaluator respectively; the choice of the concrete `train step fn` / `evaluate step fn` therefore cannot happen when each driver is initialized (when the Trainer builds the first driver, the Evaluator does not exist yet, even though determining the `evaluate step fn` requires it), so the logic is abstracted into this function.
The function decides what to return based on `fn`:
1. if fn == "train_step" or "evaluate_step", check the model; if it does not define a method `fn`, fall back to the model's `forward` and emit a warning;
2. for any other string, raise an error if the model does not define a method `fn`;
Note that individual drivers may need extra checks: e.g. in DDPDriver, if the model passed in is already a DistributedDataParallel we can only call its forward, which deserves an extra warning. Be particularly careful that a driver may itself modify the model during setup (DDPDriver does), so it may need to remember what kind of model was originally passed to it.
:param fn: a string; it determines which method of the model to return;
:return: a tuple of two functions, later passed to driver.model_call;
""" | |||||
if hasattr(self.model, fn): | if hasattr(self.model, fn): | ||||
fn = getattr(self.model, fn) | fn = getattr(self.model, fn) | ||||
if not callable(fn): | if not callable(fn): | ||||
@@ -95,6 +133,24 @@ class PaddleSingleDriver(PaddleDriver): | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | ||||
reproducible: bool = False): | reproducible: bool = False): | ||||
r""" | |||||
Build a distributed and reproducible dataloader from the input dataloader.
:param dataloader: the dataloader whose distributed / reproducible counterpart should be produced;
:param dist: a string, one of [None, "dist", "unrepeatdist"]. None means the dataloader does not need to be switched to a distributed state. "dist" means the dataloader must yield the same number of batches on every gpu, which allows a few samples to be duplicated across gpus. "unrepeatdist" means the data iterated over all gpus, taken together, must be exactly the original dataset, which allows different gpus to see different numbers of batches. When the trainer kwarg `use_dist_sampler` is True this value is "dist", otherwise None; when the evaluator kwarg `use_dist_sampler` is True this value is "unrepeatdist", otherwise None.
Note that when dist is a ReproducibleSampler or ReproducibleBatchSampler, the call comes from driver.load during checkpoint resumption; when dist is a str or None, the call comes from the trainer during initialization;
:param reproducible: if False, nothing special is required; if True, the returned dataloader must be able to save its current iteration state so that it can be restored later;
:return: a new dataloader object with its sampler replaced (a new dataloader object must be returned here). If the input dataloader carries a ReproducibleSampler or ReproducibleBatchSampler, a re-initialized copy must be placed into the returned dataloader. If dist is None and reproducible is False, the original object may be returned unchanged.
""" | |||||
# 暂时不支持iterableDataset | # 暂时不支持iterableDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
@@ -69,7 +69,6 @@ def paddle_seed_everything(seed: Optional[int] = None, workers: bool = False) -> | |||||
os.environ[FASTNLP_SEED_WORKERS] = f"{int(workers)}" | os.environ[FASTNLP_SEED_WORKERS] = f"{int(workers)}" | ||||
return seed | return seed | ||||
def reset_seed() -> None: | def reset_seed() -> None: | ||||
""" | """ | ||||
fleet 会开启多个进程,因此当用户在脚本中指定 seed_everything 时,在开启多个脚本后,会在每个脚本内重新 | fleet 会开启多个进程,因此当用户在脚本中指定 seed_everything 时,在开启多个脚本后,会在每个脚本内重新 | ||||
@@ -80,16 +79,10 @@ def reset_seed() -> None: | |||||
if seed is not None: | if seed is not None: | ||||
paddle_seed_everything(int(seed), workers=bool(int(workers))) | paddle_seed_everything(int(seed), workers=bool(int(workers))) | ||||
class ForwardState(IntEnum): | |||||
TRAIN = 0 | |||||
VALIDATE = 1 | |||||
TEST = 2 | |||||
PREDICT = 3 | |||||
class _FleetWrappingModel(Layer): | class _FleetWrappingModel(Layer): | ||||
""" | """ | ||||
参考_DDPWrappingModel,paddle的分布式训练也需要用paddle.nn.DataParallel进行包装,采用和 | |||||
pytorch相似的处理方式 | |||||
Modelled after _DDPWrappingModel: paddle's distributed training also requires the model to be wrapped in paddle.nn.DataParallel,
and we handle it the same way as in pytorch.
""" | """ | ||||
def __init__(self, model: 'nn.Layer'): | def __init__(self, model: 'nn.Layer'): | ||||
super(_FleetWrappingModel, self).__init__() | super(_FleetWrappingModel, self).__init__() | ||||
@@ -109,7 +102,6 @@ class _FleetWrappingModel(Layer): | |||||
class DummyGradScaler: | class DummyGradScaler: | ||||
""" | """ | ||||
用于仿造的GradScaler对象,防止重复写大量的if判断 | 用于仿造的GradScaler对象,防止重复写大量的if判断 | ||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
pass | pass | ||||
@@ -152,6 +144,9 @@ def _build_fp16_env(dummy=False): | |||||
return auto_cast, GradScaler | return auto_cast, GradScaler | ||||
def find_free_ports(num): | def find_free_ports(num): | ||||
""" | |||||
Find `num` free ports.
""" | |||||
def __free_port(): | def __free_port(): | ||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: | ||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, | s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, | ||||
@@ -178,18 +173,11 @@ def find_free_ports(num): | |||||
return None | return None | ||||
def get_host_name_ip(): | |||||
try: | |||||
host_name = socket.gethostname() | |||||
host_ip = socket.gethostbyname(host_name) | |||||
return host_name, host_ip | |||||
except: | |||||
return None | |||||
def get_device_from_visible(device: Union[str, int], output_type=int): | def get_device_from_visible(device: Union[str, int], output_type=int): | ||||
""" | """ | ||||
在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 | 在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 | ||||
如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 | 如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 | ||||
:param device: 未转化的设备名 | :param device: 未转化的设备名 | ||||
:param output_type: 返回值的类型 | :param output_type: 返回值的类型 | ||||
:return: 转化后的设备id | :return: 转化后的设备id | ||||
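Example (from the description above; illustrative)::

    >>> os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
    >>> get_device_from_visible(3)   # output_type defaults to int
    1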
@@ -76,7 +76,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic | |||||
logger.info("Notice you are using `torch_ddp` driver, but your chosen `device` is only one gpu, we will " | logger.info("Notice you are using `torch_ddp` driver, but your chosen `device` is only one gpu, we will " | ||||
"still use `TorchDDPDriver` for you, but if you mean using `torch_ddp`, you should " | "still use `TorchDDPDriver` for you, but if you mean using `torch_ddp`, you should " | ||||
"choose `torch` driver.") | "choose `torch` driver.") | ||||
return TorchDDPDriver(model, device, **kwargs) | |||||
return TorchDDPDriver(model, [device], **kwargs) | |||||
else: | else: | ||||
return TorchDDPDriver(model, device, **kwargs) | return TorchDDPDriver(model, device, **kwargs) | ||||
elif driver == "fairscale": | elif driver == "fairscale": | ||||
@@ -218,6 +218,8 @@ class TorchDriver(Driver): | |||||
# 2. 保存模型的状态; | # 2. 保存模型的状态; | ||||
if should_save_model: | if should_save_model: | ||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if not os.path.exists(folder): | |||||
os.mkdir(folder) | |||||
if only_state_dict: | if only_state_dict: | ||||
model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | ||||
# 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; | # 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; | ||||
@@ -401,7 +403,17 @@ class TorchDriver(Driver): | |||||
res.sampler = dataloader.batch_sampler.sampler | res.sampler = dataloader.batch_sampler.sampler | ||||
if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | if hasattr(dataloader.batch_sampler.sampler, "shuffle"): | ||||
res.shuffle = dataloader.batch_sampler.sampler.shuffle | res.shuffle = dataloader.batch_sampler.sampler.shuffle | ||||
elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): | |||||
elif isinstance(dataloader.batch_sampler.sampler, TorchRandomSampler): | |||||
res.shuffle = True | |||||
else: | |||||
res.shuffle = False | |||||
# the RandomBatchSampler case
elif hasattr(dataloader.batch_sampler, "batch_sampler"): | |||||
batch_sampler = dataloader.batch_sampler.batch_sampler | |||||
res.sampler = batch_sampler.sampler | |||||
if hasattr(batch_sampler.sampler, "shuffle"): | |||||
res.shuffle = batch_sampler.sampler.shuffle
elif isinstance(batch_sampler.sampler, TorchRandomSampler): | |||||
res.shuffle = True | res.shuffle = True | ||||
else: | else: | ||||
res.shuffle = False | res.shuffle = False | ||||
@@ -416,7 +416,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
@property | @property | ||||
def batch_idx_in_epoch(self): | def batch_idx_in_epoch(self): | ||||
if self.drop_last: | if self.drop_last: | ||||
return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size | |||||
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
else: | else: | ||||
return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \ | |||||
(len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size | |||||
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
(self.num_left_samples + self.batch_size - 1) // self.batch_size |
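# A worked example of the formula above (illustrative numbers): with len(dataset)=100,
# num_replicas=2, batch_size=4, num_left_samples=20 on this rank and drop_last=False:
#   batches per rank per epoch = (100 // 2 + 4 - 1) // 4 = 13
#   batches still to consume   = (20 + 4 - 1) // 4 = 5
#   batch_idx_in_epoch         = 13 - 5 = 8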
@@ -22,6 +22,13 @@ from .utils import apply_to_collection | |||||
def paddle_to(data, device: Union[str, int]): | def paddle_to(data, device: Union[str, int]): | ||||
""" | |||||
Move `data` to the given `device`.
:param data: the tensor to move;
:param device: the target device, either a `str` or an `int`;
:return: the moved tensor;
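Example (illustrative)::

    t = paddle.rand((2, 2))
    t_cpu = paddle_to(t, "cpu")   # move back to CPU
    t_gpu = paddle_to(t, 0)       # move to gpu:0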
""" | |||||
if device == "cpu": | if device == "cpu": | ||||
return data.cpu() | return data.cpu() | ||||
@@ -31,6 +38,9 @@ def paddle_to(data, device: Union[str, int]): | |||||
def get_paddle_gpu_str(device: Union[str, int]): | def get_paddle_gpu_str(device: Union[str, int]): | ||||
""" | """ | ||||
获得 `gpu:x` 类型的设备名 | 获得 `gpu:x` 类型的设备名 | ||||
:param device: a device index or device name;
:return: the corresponding device name in `gpu:x` format;
""" | """ | ||||
if isinstance(device, str): | if isinstance(device, str): | ||||
return device.replace("cuda", "gpu") | return device.replace("cuda", "gpu") | ||||
@@ -38,7 +48,10 @@ def get_paddle_gpu_str(device: Union[str, int]): | |||||
def get_paddle_device_id(device: Union[str, int]): | def get_paddle_device_id(device: Union[str, int]): | ||||
""" | """ | ||||
获得 gpu 的设备id,注意不要传入 `cpu` 。 | |||||
Get the device id of a gpu.
:param device: a device index or device name;
:return: the corresponding device index;
""" | """ | ||||
if isinstance(device, int): | if isinstance(device, int): | ||||
return device | return device | ||||
@@ -16,7 +16,7 @@ from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.core import rank_zero_rm | from fastNLP.core import rank_zero_rm | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDatset | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -53,7 +53,7 @@ def model_and_optimizers(request): | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension | feature_dimension=ArgMaxDatasetConfig.feature_dimension | ||||
) | ) | ||||
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | ||||
dataset = TorchArgMaxDatset( | |||||
dataset = TorchArgMaxDataset( | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension, | feature_dimension=ArgMaxDatasetConfig.feature_dimension, | ||||
data_num=ArgMaxDatasetConfig.data_num, | data_num=ArgMaxDatasetConfig.data_num, | ||||
seed=ArgMaxDatasetConfig.seed | seed=ArgMaxDatasetConfig.seed | ||||
@@ -19,7 +19,7 @@ from fastNLP.core import Evaluator | |||||
from fastNLP.core.utils.utils import safe_rm | from fastNLP.core.utils.utils import safe_rm | ||||
from fastNLP.core.drivers.torch_driver import TorchSingleDriver | from fastNLP.core.drivers.torch_driver import TorchSingleDriver | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDatset | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
@@ -55,7 +55,7 @@ def model_and_optimizers(request): | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension | feature_dimension=ArgMaxDatasetConfig.feature_dimension | ||||
) | ) | ||||
trainer_params.optimizers = optim.SGD(trainer_params.model.parameters(), lr=0.01) | trainer_params.optimizers = optim.SGD(trainer_params.model.parameters(), lr=0.01) | ||||
dataset = TorchArgMaxDatset( | |||||
dataset = TorchArgMaxDataset( | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension, | feature_dimension=ArgMaxDatasetConfig.feature_dimension, | ||||
data_num=ArgMaxDatasetConfig.data_num, | data_num=ArgMaxDatasetConfig.data_num, | ||||
seed=ArgMaxDatasetConfig.seed | seed=ArgMaxDatasetConfig.seed | ||||
@@ -24,7 +24,7 @@ from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.core import rank_zero_rm | from fastNLP.core import rank_zero_rm | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDatset | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
from fastNLP.core.metrics import Metric | from fastNLP.core.metrics import Metric | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -64,7 +64,7 @@ def model_and_optimizers(request): | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension | feature_dimension=ArgMaxDatasetConfig.feature_dimension | ||||
) | ) | ||||
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | ||||
dataset = TorchArgMaxDatset( | |||||
dataset = TorchArgMaxDataset( | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension, | feature_dimension=ArgMaxDatasetConfig.feature_dimension, | ||||
data_num=ArgMaxDatasetConfig.data_num, | data_num=ArgMaxDatasetConfig.data_num, | ||||
seed=ArgMaxDatasetConfig.seed | seed=ArgMaxDatasetConfig.seed | ||||
@@ -11,7 +11,7 @@ from torchmetrics import Accuracy | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDatset | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset | |||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
@@ -80,7 +80,7 @@ def model_and_optimizers(request): | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension | feature_dimension=ArgMaxDatasetConfig.feature_dimension | ||||
) | ) | ||||
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) | ||||
dataset = TorchArgMaxDatset( | |||||
dataset = TorchArgMaxDataset( | |||||
feature_dimension=ArgMaxDatasetConfig.feature_dimension, | feature_dimension=ArgMaxDatasetConfig.feature_dimension, | ||||
data_num=ArgMaxDatasetConfig.data_num, | data_num=ArgMaxDatasetConfig.data_num, | ||||
seed=ArgMaxDatasetConfig.seed | seed=ArgMaxDatasetConfig.seed | ||||
@@ -527,7 +527,7 @@ class TestSaveLoad: | |||||
@classmethod | @classmethod | ||||
def setup_class(cls): | def setup_class(cls): | ||||
# 不在这里 setup 的话会报错 | # 不在这里 setup 的话会报错 | ||||
cls.driver = generate_driver(10, 10) | |||||
cls.driver = generate_driver(10, 10, device=[0,1]) | |||||
def setup_method(self): | def setup_method(self): | ||||
self.dataset = PaddleRandomMaxDataset(20, 10) | self.dataset = PaddleRandomMaxDataset(20, 10) | ||||
@@ -633,7 +633,7 @@ class TestSaveLoad: | |||||
batch_sampler=BucketedBatchSampler( | batch_sampler=BucketedBatchSampler( | ||||
self.dataset, | self.dataset, | ||||
length=[10 for i in range(len(self.dataset))], | length=[10 for i in range(len(self.dataset))], | ||||
batch_size=4, | |||||
batch_size=2, | |||||
) | ) | ||||
) | ) | ||||
dataloader.batch_sampler.set_distributed( | dataloader.batch_sampler.set_distributed( | ||||
@@ -19,7 +19,7 @@ def test_incorrect_driver(): | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
["cpu", "gpu:0", 0, [1]] | |||||
["cpu", "gpu:0", 0] | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"driver", | "driver", | ||||
@@ -27,7 +27,7 @@ def test_incorrect_driver(): | |||||
) | ) | ||||
def test_get_single_device(driver, device): | def test_get_single_device(driver, device): | ||||
""" | """ | ||||
测试正常情况下初始化PaddleSingleDriver的情况 | |||||
Test normal initialization of PaddleSingleDriver.
""" | """ | ||||
model = PaddleNormalModel_Classification_1(2, 100) | model = PaddleNormalModel_Classification_1(2, 100) | ||||
@@ -36,7 +36,7 @@ def test_get_single_device(driver, device): | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
[0, 1] | |||||
[0, 1, [1]] | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"driver", | "driver", | ||||
@@ -45,7 +45,7 @@ def test_get_single_device(driver, device): | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_get_fleet_2(driver, device): | def test_get_fleet_2(driver, device): | ||||
""" | """ | ||||
测试 fleet 多卡的初始化情况 | |||||
Test fleet multi-card initialization when only a single gpu is passed in.
""" | """ | ||||
model = PaddleNormalModel_Classification_1(64, 10) | model = PaddleNormalModel_Classification_1(64, 10) | ||||
@@ -34,7 +34,7 @@ class TestPaddleDriverFunctions: | |||||
def test_check_single_optimizer_legality(self): | def test_check_single_optimizer_legality(self): | ||||
""" | """ | ||||
测试传入单个optimizer时的表现 | |||||
Test the behaviour when a single optimizer is passed in.
""" | """ | ||||
optimizer = paddle.optimizer.Adam( | optimizer = paddle.optimizer.Adam( | ||||
parameters=self.driver.model.parameters(), | parameters=self.driver.model.parameters(), | ||||
@@ -50,7 +50,7 @@ class TestPaddleDriverFunctions: | |||||
def test_check_optimizers_legality(self): | def test_check_optimizers_legality(self): | ||||
""" | """ | ||||
测试传入optimizer list的表现 | |||||
Test the behaviour when a list of optimizers is passed in.
""" | """ | ||||
optimizers = [ | optimizers = [ | ||||
paddle.optimizer.Adam( | paddle.optimizer.Adam( | ||||
@@ -70,13 +70,13 @@ class TestPaddleDriverFunctions: | |||||
def test_check_dataloader_legality_in_train(self): | def test_check_dataloader_legality_in_train(self): | ||||
""" | """ | ||||
测试is_train参数为True时,_check_dataloader_legality函数的表现 | |||||
Test the behaviour of _check_dataloader_legality when `is_train` is True.
""" | """ | ||||
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||||
dataloader = DataLoader(PaddleNormalDataset()) | |||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | ||||
# batch_size 和 batch_sampler 均为 None 的情形 | # batch_size 和 batch_sampler 均为 None 的情形 | ||||
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | ||||
@@ -90,29 +90,29 @@ class TestPaddleDriverFunctions: | |||||
def test_check_dataloader_legality_in_test(self): | def test_check_dataloader_legality_in_test(self): | ||||
""" | """ | ||||
测试is_train参数为False时,_check_dataloader_legality函数的表现 | |||||
Test the behaviour of _check_dataloader_legality when `is_train` is False.
""" | """ | ||||
# 此时传入的应该是dict | # 此时传入的应该是dict | ||||
dataloader = { | dataloader = { | ||||
"train": paddle.io.DataLoader(PaddleNormalDataset()), | |||||
"test":paddle.io.DataLoader(PaddleNormalDataset()) | |||||
"train": DataLoader(PaddleNormalDataset()), | |||||
"test":DataLoader(PaddleNormalDataset()) | |||||
} | } | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | ||||
# batch_size 和 batch_sampler 均为 None 的情形 | # batch_size 和 batch_sampler 均为 None 的情形 | ||||
dataloader = { | dataloader = { | ||||
"train": paddle.io.DataLoader(PaddleNormalDataset()), | |||||
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
"train": DataLoader(PaddleNormalDataset()), | |||||
"test":DataLoader(PaddleNormalDataset(), batch_size=None) | |||||
} | } | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | ||||
# 传入的不是dict,应该报错 | |||||
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||||
# not a dict, should raise an error
dataloader = DataLoader(PaddleNormalDataset()) | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | ||||
# 创建torch的dataloader | |||||
# create a torch dataloader
train_loader = torch.utils.data.DataLoader( | train_loader = torch.utils.data.DataLoader( | ||||
TorchNormalDataset(), | TorchNormalDataset(), | ||||
batch_size=32, shuffle=True | batch_size=32, shuffle=True | ||||
@@ -127,7 +127,7 @@ class TestPaddleDriverFunctions: | |||||
def test_tensor_to_numeric(self): | def test_tensor_to_numeric(self): | ||||
""" | """ | ||||
测试tensor_to_numeric函数 | |||||
Test the tensor_to_numeric function.
""" | """ | ||||
# 单个张量 | # 单个张量 | ||||
tensor = paddle.to_tensor(3) | tensor = paddle.to_tensor(3) | ||||
@@ -180,7 +180,7 @@ class TestPaddleDriverFunctions: | |||||
def test_set_model_mode(self): | def test_set_model_mode(self): | ||||
""" | """ | ||||
测试set_model_mode函数 | |||||
测试 set_model_mode 函数 | |||||
""" | """ | ||||
self.driver.set_model_mode("train") | self.driver.set_model_mode("train") | ||||
assert self.driver.model.training | assert self.driver.model.training | ||||
@@ -192,14 +192,14 @@ class TestPaddleDriverFunctions: | |||||
def test_move_model_to_device_cpu(self): | def test_move_model_to_device_cpu(self): | ||||
""" | """ | ||||
测试move_model_to_device函数 | |||||
测试 move_model_to_device 函数 | |||||
""" | """ | ||||
PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") | PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") | ||||
assert self.driver.model.linear1.weight.place.is_cpu_place() | assert self.driver.model.linear1.weight.place.is_cpu_place() | ||||
def test_move_model_to_device_gpu(self): | def test_move_model_to_device_gpu(self): | ||||
""" | """ | ||||
测试move_model_to_device函数 | |||||
测试 move_model_to_device 函数 | |||||
""" | """ | ||||
PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu") | PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu") | ||||
assert self.driver.model.linear1.weight.place.is_gpu_place() | assert self.driver.model.linear1.weight.place.is_gpu_place() | ||||
@@ -207,7 +207,7 @@ class TestPaddleDriverFunctions: | |||||
def test_worker_init_function(self): | def test_worker_init_function(self): | ||||
""" | """ | ||||
测试worker_init_function | |||||
测试 worker_init_function | |||||
""" | """ | ||||
# 先确保不影响运行 | # 先确保不影响运行 | ||||
# TODO:正确性 | # TODO:正确性 | ||||
@@ -215,7 +215,7 @@ class TestPaddleDriverFunctions: | |||||
def test_set_deterministic_dataloader(self): | def test_set_deterministic_dataloader(self): | ||||
""" | """ | ||||
测试set_deterministic_dataloader | |||||
测试 set_deterministic_dataloader | |||||
""" | """ | ||||
# 先确保不影响运行 | # 先确保不影响运行 | ||||
# TODO:正确性 | # TODO:正确性 | ||||
@@ -224,7 +224,7 @@ class TestPaddleDriverFunctions: | |||||
def test_set_sampler_epoch(self): | def test_set_sampler_epoch(self): | ||||
""" | """ | ||||
测试set_sampler_epoch | |||||
测试 set_sampler_epoch | |||||
""" | """ | ||||
# 先确保不影响运行 | # 先确保不影响运行 | ||||
# TODO:正确性 | # TODO:正确性 | ||||
@@ -336,7 +336,7 @@ class TestSingleDeviceFunction: | |||||
def test_move_data_to_device(self): | def test_move_data_to_device(self): | ||||
""" | """ | ||||
这个函数仅调用了paddle_move_data_to_device,测试例在tests/core/utils/test_paddle_utils.py中 | |||||
这个函数仅调用了 paddle_move_data_to_device ,测试例在 tests/core/utils/test_paddle_utils.py 中 | |||||
就不重复测试了 | 就不重复测试了 | ||||
""" | """ | ||||
self.driver.move_data_to_device(paddle.rand((32, 64))) | self.driver.move_data_to_device(paddle.rand((32, 64))) | ||||
@@ -490,9 +490,6 @@ class TestSetDistReproDataloader: | |||||
else: | else: | ||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | ||||
# 加载 num_consumed_samples_array,设置正确取出的 batch 数目 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
# 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range | # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range | ||||
left_idxes = set() | left_idxes = set() | ||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | ||||
@@ -510,7 +507,6 @@ class TestSetDistReproDataloader: | |||||
new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
else: | else: | ||||
batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
num_consumed_samples = num_consumed_batches * batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | ||||
# 重新构造 dataloader | # 重新构造 dataloader | ||||
batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) | batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) | ||||
@@ -0,0 +1,788 @@ | |||||
import pytest | |||||
import os | |||||
from pathlib import Path | |||||
os.environ["FASTNLP_BACKEND"] = "torch" | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from fastNLP.core.samplers import ( | |||||
RandomSampler, | |||||
UnrepeatedSampler, | |||||
BucketedBatchSampler, | |||||
UnrepeatedRandomSampler, | |||||
UnrepeatedSequentialSampler, | |||||
) | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP.core import rank_zero_rm | |||||
import torch | |||||
import torch.distributed as dist | |||||
from torch.utils.data import DataLoader, BatchSampler | |||||
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): | |||||
torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) | |||||
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | |||||
device = [torch.device(i) for i in device] | |||||
driver = TorchDDPDriver( | |||||
model=torch_model, | |||||
parallel_device=device, | |||||
fp16=fp16, | |||||
output_from_new_proc=output_from_new_proc | |||||
) | |||||
driver.set_optimizers(torch_opt) | |||||
driver.setup() | |||||
return driver | |||||
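# A minimal usage sketch of the helper above, kept as comments so the module stays import-safe | |||||
# (assumes two visible GPUs, matching the default device=[0, 1]): | |||||
#   driver = generate_driver(num_labels=10, feature_dimension=10) | |||||
#   assert driver.is_distributed() | |||||
#   batch = driver.move_data_to_device(torch.rand((4, 10))) | |||||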
def dataloader_with_bucketedbatchsampler(dataset, length, batch_size, shuffle, drop_last): | |||||
""" | |||||
Build a DataLoader whose batch_sampler is a BucketedBatchSampler. | |||||
""" | |||||
dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_sampler=BucketedBatchSampler( | |||||
dataset, | |||||
length, | |||||
batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=drop_last, | |||||
), | |||||
) | |||||
return dataloader | |||||
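# Usage sketch (comments only; `lengths` is an illustrative per-sample length list, as in the save/load tests below): | |||||
#   lengths = [10 for _ in range(len(dataset))] | |||||
#   loader = dataloader_with_bucketedbatchsampler(dataset, lengths, batch_size=4, shuffle=True, drop_last=False) | |||||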
def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0, unrepeated=False): | |||||
""" | |||||
Build a DataLoader whose sampler is a RandomSampler (or an UnrepeatedRandomSampler when unrepeated=True). | |||||
""" | |||||
if unrepeated: | |||||
sampler = UnrepeatedRandomSampler(dataset, shuffle, seed) | |||||
else: | |||||
sampler = RandomSampler(dataset, shuffle, seed=seed) | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
sampler=sampler, | |||||
drop_last=drop_last, | |||||
batch_size=batch_size | |||||
) | |||||
return dataloader | |||||
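# Usage sketch (comments only); pass unrepeated=True to build the loader around an UnrepeatedRandomSampler instead: | |||||
#   loader = dataloader_with_randomsampler(dataset, batch_size=4, shuffle=True, drop_last=False, unrepeated=True) | |||||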
############################################################################ | |||||
# | |||||
# Tests for some functions of TorchDDPDriver | |||||
# | |||||
############################################################################ | |||||
class TestDDPDriverFunction: | |||||
""" | |||||
Test class for some simple TorchDDPDriver functions; these mostly check that the calls run without import or runtime errors. | |||||
""" | |||||
@classmethod | |||||
def setup_class(cls): | |||||
cls.driver = generate_driver(10, 10) | |||||
@magic_argv_env_context | |||||
def test_multi_drivers(self): | |||||
""" | |||||
Test the case where several TorchDDPDriver instances are created. | |||||
""" | |||||
driver2 = generate_driver(20, 10) | |||||
with pytest.raises(RuntimeError): | |||||
# different device settings should raise an error | |||||
driver3 = generate_driver(20, 3, device=[0,1,2]) | |||||
assert False | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_move_data_to_device(self): | |||||
""" | |||||
This function only calls torch_move_data_to_device; its test cases live in tests/core/utils/test_torch_utils.py, | |||||
so they are not repeated here. | |||||
""" | |||||
self.driver.move_data_to_device(torch.rand((32, 64))) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_is_distributed(self): | |||||
""" | |||||
Test the is_distributed function. | |||||
""" | |||||
assert self.driver.is_distributed() == True | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_get_no_sync_context(self): | |||||
""" | |||||
Test the get_model_no_sync_context function. | |||||
""" | |||||
res = self.driver.get_model_no_sync_context() | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_is_global_zero(self): | |||||
""" | |||||
Test the is_global_zero function. | |||||
""" | |||||
self.driver.is_global_zero() | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_unwrap_model(self): | |||||
""" | |||||
Test the unwrap_model function. | |||||
""" | |||||
self.driver.unwrap_model() | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_get_local_rank(self): | |||||
""" | |||||
Test the get_local_rank function. | |||||
""" | |||||
self.driver.get_local_rank() | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
def test_all_gather(self): | |||||
""" | |||||
Test the all_gather function. | |||||
The detailed tests are in test_dist_utils.py. | |||||
""" | |||||
obj = { | |||||
"rank": self.driver.global_rank | |||||
} | |||||
obj_list = self.driver.all_gather(obj, group=None) | |||||
for i, res in enumerate(obj_list): | |||||
assert res["rank"] == i | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("src_rank", ([0, 1])) | |||||
def test_broadcast_object(self, src_rank): | |||||
""" | |||||
Test the broadcast_object function. | |||||
The detailed tests are in test_dist_utils.py. | |||||
""" | |||||
if self.driver.global_rank == src_rank: | |||||
obj = { | |||||
"rank": self.driver.global_rank | |||||
} | |||||
else: | |||||
obj = None | |||||
res = self.driver.broadcast_object(obj, src=src_rank) | |||||
assert res["rank"] == src_rank | |||||
############################################################################ | |||||
# | |||||
# Tests for the set_dist_repro_dataloader function | |||||
# | |||||
############################################################################ | |||||
class TestSetDistReproDataloader: | |||||
@classmethod | |||||
def setup_class(cls): | |||||
cls.device = [0, 1] | |||||
cls.driver = generate_driver(10, 10, device=cls.device) | |||||
def setup_method(self): | |||||
self.dataset = TorchNormalDataset(40) | |||||
""" | |||||
Cases where the `dist` argument is a concrete ReproducibleSampler or ReproducibleBatchSampler. | |||||
This corresponds to the situation handled in driver.load. | |||||
""" | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_batch_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is a BucketedBatchSampler. | |||||
The dataloader's batch_sampler should be replaced by the BucketedBatchSampler passed as dist. | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | |||||
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert replaced_loader.batch_sampler is batch_sampler | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler) | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is a RandomSampler. | |||||
batch_sampler.sampler should be replaced by the RandomSampler passed as dist. | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | |||||
sampler = RandomSampler(self.dataset, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.sampler is sampler | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | |||||
""" | |||||
Cases where the `dist` argument is None. This happens while the trainer and evaluator are being initialized and the | |||||
user has set `use_dist_sampler` to False. The function then behaves differently depending on `reproducible`: | |||||
when `reproducible` is False, whether the dataloader is re-instantiated depends on whether its batch_sampler or | |||||
sampler is already Reproducible. | |||||
""" | |||||
@magic_argv_env_context | |||||
def test_with_dist_none_reproducible_true(self): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is None and reproducible is True. | |||||
When the user initialized the distributed environment outside the driver, fastNLP does not support checkpoint resumption, so an error should be raised. | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | |||||
with pytest.raises(RuntimeError): | |||||
# a RuntimeError is expected here | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is None, reproducible is False and the dataloader already uses a | |||||
BucketedBatchSampler. | |||||
The incoming dataloader's batch_sampler is assumed to have already called set_distributed; a new dataloader | |||||
should be produced whose batch_sampler is equivalent to the original one. | |||||
""" | |||||
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | |||||
dataloader.batch_sampler.set_distributed( | |||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank, | |||||
pad=True | |||||
) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
self.check_distributed_sampler(dataloader.batch_sampler) | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is None, reproducible is False and the dataloader uses a RandomSampler. | |||||
The incoming dataloader's batch_sampler.sampler is assumed to have already called set_distributed; a new | |||||
dataloader should be produced whose batch_sampler.sampler is equivalent to the original one. | |||||
""" | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | |||||
dataloader.batch_sampler.sampler.set_distributed( | |||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank | |||||
) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.batch_sampler.drop_last == False | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is None, reproducible is False and the dataloader is an ordinary one. | |||||
The original dataloader should be returned unchanged. | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert replaced_loader is dataloader | |||||
dist.barrier() | |||||
""" | |||||
Cases where the `dist` argument is 'dist'. This happens while the trainer is being initialized and the user has set | |||||
`use_dist_sampler` to True. The function decides how to re-instantiate the dataloader based on whether its | |||||
batch_sampler or sampler is already Reproducible. | |||||
""" | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler is a ReproducibleBatchSampler. | |||||
A new dataloader should be returned whose batch_sampler is equivalent to the original one and whose | |||||
distributed attributes are set correctly. | |||||
""" | |||||
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler.sampler is a ReproducibleSampler. | |||||
A new dataloader should be returned whose batch_sampler.sampler is equivalent to the original one and whose | |||||
distributed attributes are set correctly. | |||||
""" | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_dist_dataloader_normal(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is 'dist' and the dataloader is an ordinary one. | |||||
A new dataloader should be returned with its batch_sampler.sampler replaced by a RandomSampler and the | |||||
distributed attributes set correctly. | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
""" | |||||
Cases where the `dist` argument is 'unrepeatdist'. This happens while the evaluator is being initialized and the | |||||
user has set `use_dist_sampler` to True. The function decides how to re-instantiate the dataloader based on whether | |||||
its sampler is Unrepeated and/or Reproducible. | |||||
""" | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is 'unrepeatdist' and dataloader.batch_sampler.sampler is a | |||||
ReproducibleSampler. | |||||
A new dataloader should be returned with the original sampler replaced by an UnrepeatedRandomSampler and the | |||||
distributed attributes set correctly. | |||||
""" | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is 'unrepeatdist' and dataloader.batch_sampler.sampler is an | |||||
UnrepeatedSampler. | |||||
A new dataloader should be returned whose sampler is re-instantiated with the same settings. | |||||
""" | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_unrepeat_dataloader_normal(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is 'unrepeatdist' and the dataloader is an ordinary one. | |||||
A new dataloader should be returned with its sampler replaced by an UnrepeatedSequentialSampler and the | |||||
distributed attributes set correctly. | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedSequentialSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | |||||
dist.barrier() | |||||
def check_distributed_sampler(self, sampler): | |||||
""" | |||||
Check that the distributed settings of the replaced sampler or batch_sampler are correct. | |||||
""" | |||||
assert sampler.num_replicas == dist.get_world_size() | |||||
assert sampler.rank == dist.get_rank() | |||||
if not isinstance(sampler, UnrepeatedSampler): | |||||
assert sampler.pad == True | |||||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): | |||||
""" | |||||
Check that set_dist_repro_dataloader behaves correctly in the multi-card setting. | |||||
""" | |||||
# iterate over two batches | |||||
num_replicas = len(self.device) | |||||
num_consumed_batches = 2 | |||||
already_seen_idx = set() | |||||
for idx, batch in enumerate(replaced_loader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_idx.update(batch) | |||||
dist.barrier() | |||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.state_dict() | |||||
else: | |||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | |||||
# After reloading, the remaining samples should come out; for TorchNormalDataset the union of indices, once sorted, should form a contiguous range | |||||
left_idxes = set() | |||||
if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): | |||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | |||||
# rebuild the dataloader | |||||
new_loader = dataloader_with_bucketedbatchsampler( | |||||
replaced_loader.dataset, | |||||
length=replaced_loader.dataset._data, | |||||
batch_size=batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=False, | |||||
) | |||||
new_loader.batch_sampler.set_distributed( | |||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank, | |||||
pad=True | |||||
) | |||||
new_loader.batch_sampler.load_state_dict(sampler_states) | |||||
else: | |||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas | |||||
# rebuild the dataloader | |||||
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) | |||||
new_loader.batch_sampler.sampler.set_distributed( | |||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank | |||||
) | |||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||||
for idx, batch in enumerate(new_loader): | |||||
left_idxes.update(batch) | |||||
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas | |||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas | |||||
############################################################################ | |||||
# | |||||
# Tests for save- and load-related functionality | |||||
# | |||||
############################################################################ | |||||
class TestSaveLoad: | |||||
""" | |||||
Test the behaviour of the save/load related functions in the multi-card setting. | |||||
""" | |||||
@classmethod | |||||
def setup_class(cls): | |||||
# an error is raised if the driver is not set up here | |||||
cls.driver = generate_driver(10, 10) | |||||
def setup_method(self): | |||||
self.dataset = TorchArgMaxDataset(10, 20) | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
def test_save_and_load_model(self, only_state_dict): | |||||
""" | |||||
Test the save_model and load_model functions. | |||||
""" | |||||
try: | |||||
path = "model" | |||||
dataloader = DataLoader(self.dataset, batch_size=2) | |||||
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||||
self.driver1.save_model(path, only_state_dict) | |||||
# synchronize | |||||
dist.barrier() | |||||
self.driver2.load_model(path, only_state_dict) | |||||
for idx, batch in enumerate(dataloader): | |||||
batch = self.driver1.move_data_to_device(batch) | |||||
res1 = self.driver1.model( | |||||
batch, | |||||
fastnlp_fn=self.driver1.model.module.model.evaluate_step, | |||||
# Driver.model (DistributedDataParallel) -> .module (fastNLP's wrapping model) -> .model (the user model) | |||||
fastnlp_signature_fn=None, | |||||
wo_auto_param_call=False, | |||||
) | |||||
res2 = self.driver2.model( | |||||
batch, | |||||
fastnlp_fn=self.driver2.model.module.model.evaluate_step, | |||||
fastnlp_signature_fn=None, | |||||
wo_auto_param_call=False, | |||||
) | |||||
assert torch.equal(res1["preds"], res2["preds"]) | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | |||||
@pytest.mark.parametrize("device", ([[0,1]])) | |||||
def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16): | |||||
""" | |||||
Test the save and load functions, mainly for the case where the dataloader's batch_sampler has been replaced. | |||||
""" | |||||
try: | |||||
path = "model.ckp" | |||||
num_replicas = len(device) | |||||
self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||||
generate_driver(10, 10, device=device, fp16=False) | |||||
dataloader = dataloader_with_bucketedbatchsampler( | |||||
self.dataset, | |||||
length=[10 for i in range(len(self.dataset))], | |||||
batch_size=4, | |||||
shuffle=True, | |||||
drop_last=False | |||||
) | |||||
dataloader.batch_sampler.set_distributed( | |||||
num_replicas=self.driver1.world_size, | |||||
rank=self.driver1.global_rank, | |||||
pad=True | |||||
) | |||||
num_consumed_batches = 2 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
# synchronize | |||||
dist.barrier() | |||||
# save the training state | |||||
sampler_states = dataloader.batch_sampler.state_dict() | |||||
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# load | |||||
# change the batch_size | |||||
dataloader = dataloader_with_bucketedbatchsampler( | |||||
self.dataset, | |||||
length=[10 for i in range(len(self.dataset))], | |||||
batch_size=2, | |||||
shuffle=True, | |||||
drop_last=False | |||||
) | |||||
dataloader.batch_sampler.set_distributed( | |||||
num_replicas=self.driver2.world_size, | |||||
rank=self.driver2.global_rank, | |||||
pad=True | |||||
) | |||||
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. check the optimizer state | |||||
# TODO the optimizer state_dict is always empty | |||||
# 2. check that the batch_sampler was correctly loaded and replaced | |||||
assert not (replaced_loader is dataloader) | |||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler | |||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | |||||
assert replaced_loader.batch_sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas | |||||
# 3. check that the fp16 state was loaded | |||||
if fp16: | |||||
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
# 4. check that the model parameters are correct | |||||
# 5. check batch_idx | |||||
start_batch = load_states.pop('batch_idx_in_epoch') | |||||
assert start_batch == 2 * num_consumed_batches | |||||
left_x_batches = set() | |||||
left_y_batches = set() | |||||
for idx, batch in enumerate(replaced_loader): | |||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
res1 = self.driver1.model( | |||||
batch, | |||||
fastnlp_fn=self.driver1.model.module.model.evaluate_step, | |||||
# Driver.model (DistributedDataParallel) -> .module (fastNLP's wrapping model) -> .model (the user model) | |||||
fastnlp_signature_fn=None, | |||||
wo_auto_param_call=False, | |||||
) | |||||
res2 = self.driver2.model( | |||||
batch, | |||||
fastnlp_fn=self.driver2.model.module.model.evaluate_step, | |||||
fastnlp_signature_fn=None, | |||||
wo_auto_param_call=False, | |||||
) | |||||
assert torch.equal(res1["preds"], res2["preds"]) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas | |||||
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas | |||||
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | |||||
@pytest.mark.parametrize("device", ([[0,1]])) | |||||
def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16): | |||||
""" | |||||
Test the save and load functions, mainly for the case where the dataloader's sampler has been replaced. | |||||
""" | |||||
try: | |||||
path = "model.ckp" | |||||
num_replicas = len(device) | |||||
self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||||
self.driver2 = generate_driver(10, 10, device=device, fp16=False) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | |||||
dataloader.batch_sampler.sampler.set_distributed( | |||||
num_replicas=self.driver1.world_size, | |||||
rank=self.driver1.global_rank, | |||||
pad=True | |||||
) | |||||
num_consumed_batches = 2 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
# synchronize | |||||
dist.barrier() | |||||
# save the training state | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||||
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
if only_state_dict: | |||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | |||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) | |||||
# load | |||||
# change the batch_size | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | |||||
dataloader.batch_sampler.sampler.set_distributed( | |||||
num_replicas=self.driver2.world_size, | |||||
rank=self.driver2.global_rank, | |||||
pad=True | |||||
) | |||||
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. check the optimizer state | |||||
# TODO the optimizer state_dict is always empty | |||||
# 2. check that the sampler was correctly loaded and replaced | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||||
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas | |||||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||||
# 3. check that the fp16 state was loaded | |||||
if fp16: | |||||
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
# 4. check that the model parameters are correct | |||||
# 5. check batch_idx | |||||
start_batch = load_states.pop('batch_idx_in_epoch') | |||||
assert start_batch == 2 * num_consumed_batches | |||||
left_x_batches = set() | |||||
left_y_batches = set() | |||||
for idx, batch in enumerate(replaced_loader): | |||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
res1 = self.driver1.model( | |||||
batch, | |||||
fastnlp_fn=self.driver1.model.module.model.evaluate_step, | |||||
# Driver.model (DistributedDataParallel) -> .module (fastNLP's wrapping model) -> .model (the user model) | |||||
fastnlp_signature_fn=None, | |||||
wo_auto_param_call=False, | |||||
) | |||||
res2 = self.driver2.model( | |||||
batch, | |||||
fastnlp_fn=self.driver2.model.module.model.evaluate_step, | |||||
fastnlp_signature_fn=None, | |||||
wo_auto_param_call=False, | |||||
) | |||||
assert torch.equal(res1["preds"], res2["preds"]) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas | |||||
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas | |||||
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas | |||||
finally: | |||||
rank_zero_rm(path) |
@@ -0,0 +1,103 @@ | |||||
import os | |||||
import pytest | |||||
os.environ["FASTNLP_BACKEND"] = "torch" | |||||
from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver | |||||
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | |||||
from fastNLP.envs import get_gpu_count | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
import torch | |||||
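# Rough dispatch exercised by the tests below (comments only): | |||||
#   initialize_torch_driver("torch", "cuda:0", model)    -> TorchSingleDriver | |||||
#   initialize_torch_driver("torch", [0, 2, 3], model)   -> TorchDDPDriver | |||||
#   initialize_torch_driver("torch_ddp", "cpu", model)   -> ValueError | |||||
#   initialize_torch_driver("paddle", 0, model)          -> ValueError | |||||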
def test_incorrect_driver(): | |||||
model = TorchNormalModel_Classification_1(2, 100) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_torch_driver("paddle", 0, model) | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
["cpu", "cuda:0", 0, torch.device("cuda:0")] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["torch"] | |||||
) | |||||
def test_get_single_device(driver, device): | |||||
""" | |||||
Test initializing a TorchSingleDriver under normal conditions. | |||||
""" | |||||
model = TorchNormalModel_Classification_1(2, 100) | |||||
driver = initialize_torch_driver(driver, device, model) | |||||
assert isinstance(driver, TorchSingleDriver) | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[0, 1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["torch_ddp"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_ddp_2(driver, device): | |||||
""" | |||||
Test ddp initialization when only a single gpu is passed in. | |||||
""" | |||||
model = TorchNormalModel_Classification_1(64, 10) | |||||
driver = initialize_torch_driver(driver, device, model) | |||||
assert isinstance(driver, TorchDDPDriver) | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[[0, 2, 3], -1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["torch", "torch_ddp"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_ddp(driver, device): | |||||
""" | |||||
Test multi-card ddp initialization. | |||||
""" | |||||
model = TorchNormalModel_Classification_1(64, 10) | |||||
driver = initialize_torch_driver(driver, device, model) | |||||
assert isinstance(driver, TorchDDPDriver) | |||||
@pytest.mark.parametrize( | |||||
("driver", "device"), | |||||
[("torch_ddp", "cpu")] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_ddp_cpu(driver, device): | |||||
""" | |||||
Test attempting to initialize distributed training on cpu. | |||||
""" | |||||
model = TorchNormalModel_Classification_1(64, 10) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_torch_driver(driver, device, model) | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["torch", "torch_ddp"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_device_out_of_range(driver, device): | |||||
""" | |||||
Test the case where the given device is out of range. | |||||
""" | |||||
model = TorchNormalModel_Classification_1(2, 100) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_torch_driver(driver, device, model) |
@@ -0,0 +1,697 @@ | |||||
import os | |||||
os.environ["FASTNLP_BACKEND"] = "torch" | |||||
import pytest | |||||
from pathlib import Path | |||||
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver | |||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from fastNLP.core import rank_zero_rm | |||||
import torch | |||||
from torch.utils.data import DataLoader, BatchSampler | |||||
import paddle | |||||
def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): | |||||
""" | |||||
Build a DataLoader whose batch_sampler is a RandomBatchSampler. | |||||
""" | |||||
if shuffle: | |||||
sampler = torch.utils.data.RandomSampler(dataset) | |||||
else: | |||||
sampler = torch.utils.data.SequentialSampler(dataset) | |||||
dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_sampler=RandomBatchSampler( | |||||
BatchSampler( | |||||
sampler, batch_size=batch_size, drop_last=drop_last | |||||
), | |||||
batch_size=batch_size, | |||||
drop_last=drop_last, | |||||
), | |||||
) | |||||
return dataloader | |||||
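# Usage sketch (comments only): RandomBatchSampler here wraps a regular torch BatchSampler so the resulting | |||||
# loader can report and restore its consumption state: | |||||
#   loader = dataloader_with_randombatchsampler(dataset, batch_size=16, shuffle=True, drop_last=False) | |||||
#   states = loader.batch_sampler.state_dict() | |||||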
def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0): | |||||
""" | |||||
Build a DataLoader whose sampler is a RandomSampler. | |||||
""" | |||||
dataloader = DataLoader( | |||||
dataset, | |||||
sampler=RandomSampler(dataset, shuffle, seed=seed), | |||||
drop_last=drop_last, | |||||
batch_size=batch_size | |||||
) | |||||
return dataloader | |||||
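# Usage sketch (comments only): | |||||
#   loader = dataloader_with_randomsampler(dataset, batch_size=16, shuffle=True, drop_last=False) | |||||
#   states = loader.batch_sampler.sampler.state_dict() | |||||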
############################################################################ | |||||
# | |||||
# Tests for some simple functions of the base TorchDriver class | |||||
# | |||||
############################################################################ | |||||
class TestTorchDriverFunctions: | |||||
""" | |||||
Use TorchSingleDriver to test the base-class functions. | |||||
""" | |||||
@classmethod | |||||
def setup_class(cls): | |||||
model = TorchNormalModel_Classification_1(10, 32) | |||||
cls.driver = TorchSingleDriver(model, device="cpu") | |||||
def test_check_single_optimizer_legality(self): | |||||
""" | |||||
Test the behaviour when a single optimizer is passed in. | |||||
""" | |||||
optimizer = torch.optim.Adam( | |||||
params=self.driver.model.parameters(), | |||||
lr=0.01 | |||||
) | |||||
self.driver.set_optimizers(optimizer) | |||||
optimizer = paddle.optimizer.Adam( | |||||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||||
learning_rate=0.01, | |||||
) | |||||
# passing a paddle optimizer should raise a ValueError | |||||
with pytest.raises(ValueError): | |||||
self.driver.set_optimizers(optimizer) | |||||
def test_check_optimizers_legality(self): | |||||
""" | |||||
Test the behaviour when a list of optimizers is passed in. | |||||
""" | |||||
optimizers = [ | |||||
torch.optim.Adam( | |||||
params=self.driver.model.parameters(), | |||||
lr=0.01 | |||||
) for i in range(10) | |||||
] | |||||
self.driver.set_optimizers(optimizers) | |||||
optimizers += [ | |||||
paddle.optimizer.Adam( | |||||
parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), | |||||
learning_rate=0.01, | |||||
) | |||||
] | |||||
with pytest.raises(ValueError): | |||||
self.driver.set_optimizers(optimizers) | |||||
def test_check_dataloader_legality_in_train(self): | |||||
""" | |||||
Test the behaviour of _check_dataloader_legality when `is_train` is True. | |||||
""" | |||||
dataloader = DataLoader(TorchNormalDataset()) | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||||
# create a paddle dataloader, which should be rejected | |||||
dataloader = paddle.io.DataLoader( | |||||
PaddleNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
with pytest.raises(ValueError): | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | |||||
def test_check_dataloader_legality_in_test(self): | |||||
""" | |||||
Test the behaviour of _check_dataloader_legality when `is_train` is False. | |||||
""" | |||||
# a dict should be passed in this case | |||||
dataloader = { | |||||
"train": DataLoader(TorchNormalDataset()), | |||||
"test": DataLoader(TorchNormalDataset()) | |||||
} | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
# passing something other than a dict should raise an error | |||||
dataloader = DataLoader(TorchNormalDataset()) | |||||
with pytest.raises(ValueError): | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
# create paddle dataloaders, which should be rejected | |||||
train_loader = paddle.io.DataLoader( | |||||
PaddleNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
test_loader = paddle.io.DataLoader( | |||||
PaddleNormalDataset(), | |||||
batch_size=32, shuffle=True | |||||
) | |||||
dataloader = {"train": train_loader, "test": test_loader} | |||||
with pytest.raises(ValueError): | |||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | |||||
def test_tensor_to_numeric(self): | |||||
""" | |||||
Test the tensor_to_numeric function. | |||||
""" | |||||
# a single tensor | |||||
tensor = torch.tensor(3) | |||||
res = TorchSingleDriver.tensor_to_numeric(tensor) | |||||
assert res == 3 | |||||
tensor = torch.rand((3, 4)) | |||||
res = TorchSingleDriver.tensor_to_numeric(tensor) | |||||
assert res == tensor.tolist() | |||||
# a list of tensors | |||||
tensor_list = [torch.rand((6, 4, 2)) for i in range(10)] | |||||
res = TorchSingleDriver.tensor_to_numeric(tensor_list) | |||||
assert isinstance(res, list) | |||||
tensor_list = [t.tolist() for t in tensor_list] | |||||
assert res == tensor_list | |||||
# a tuple of tensors | |||||
tensor_tuple = tuple([torch.rand((6, 4, 2)) for i in range(10)]) | |||||
res = TorchSingleDriver.tensor_to_numeric(tensor_tuple) | |||||
assert isinstance(res, tuple) | |||||
tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) | |||||
assert res == tensor_tuple | |||||
# a dict of tensors | |||||
tensor_dict = { | |||||
"tensor": torch.rand((3, 4)), | |||||
"list": [torch.rand((6, 4, 2)) for i in range(10)], | |||||
"dict":{ | |||||
"list": [torch.rand((6, 4, 2)) for i in range(10)], | |||||
"tensor": torch.rand((3, 4)) | |||||
}, | |||||
"int": 2, | |||||
"string": "test string" | |||||
} | |||||
res = TorchSingleDriver.tensor_to_numeric(tensor_dict) | |||||
assert isinstance(res, dict) | |||||
assert res["tensor"] == tensor_dict["tensor"].tolist() | |||||
assert isinstance(res["list"], list) | |||||
for r, d in zip(res["list"], tensor_dict["list"]): | |||||
assert r == d.tolist() | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): | |||||
assert r == d.tolist() | |||||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | |||||
def test_set_model_mode(self): | |||||
""" | |||||
Test the set_model_mode function. | |||||
""" | |||||
self.driver.set_model_mode("train") | |||||
assert self.driver.model.training | |||||
self.driver.set_model_mode("eval") | |||||
assert not self.driver.model.training | |||||
# an invalid mode should raise an error | |||||
with pytest.raises(AssertionError): | |||||
self.driver.set_model_mode("test") | |||||
def test_move_model_to_device_cpu(self): | |||||
""" | |||||
Test the move_model_to_device function with a cpu target. | |||||
""" | |||||
TorchSingleDriver.move_model_to_device(self.driver.model, "cpu") | |||||
assert self.driver.model.linear1.weight.device.type == "cpu" | |||||
def test_move_model_to_device_gpu(self): | |||||
""" | |||||
Test the move_model_to_device function with a gpu target. | |||||
""" | |||||
TorchSingleDriver.move_model_to_device(self.driver.model, "cuda") | |||||
assert self.driver.model.linear1.weight.device.type == "cuda" | |||||
assert self.driver.model.linear1.weight.device.index == 0 | |||||
def test_worker_init_function(self): | |||||
""" | |||||
Test worker_init_function. | |||||
""" | |||||
# for now, just make sure it does not break execution | |||||
# TODO: verify correctness | |||||
TorchSingleDriver.worker_init_function(0) | |||||
def test_set_deterministic_dataloader(self): | |||||
""" | |||||
Test set_deterministic_dataloader. | |||||
""" | |||||
# for now, just make sure it does not break execution | |||||
# TODO: verify correctness | |||||
dataloader = DataLoader(TorchNormalDataset()) | |||||
self.driver.set_deterministic_dataloader(dataloader) | |||||
def test_set_sampler_epoch(self): | |||||
""" | |||||
Test set_sampler_epoch. | |||||
""" | |||||
# for now, just make sure it does not break execution | |||||
# TODO: verify correctness | |||||
dataloader = DataLoader(TorchNormalDataset()) | |||||
self.driver.set_sampler_epoch(dataloader, 0) | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
def test_get_dataloader_args(self, batch_size, shuffle, drop_last): | |||||
""" | |||||
Test get_dataloader_args under normal conditions. | |||||
""" | |||||
dataloader = DataLoader( | |||||
TorchNormalDataset(), | |||||
batch_size=batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=drop_last, | |||||
) | |||||
res = TorchSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, TorchNormalDataset) | |||||
assert isinstance(res.batch_sampler, BatchSampler) | |||||
if shuffle: | |||||
assert isinstance(res.sampler, torch.utils.data.RandomSampler) | |||||
else: | |||||
assert isinstance(res.sampler, torch.utils.data.SequentialSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last): | |||||
""" | |||||
Test get_dataloader_args after the batch_sampler has been replaced. | |||||
""" | |||||
dataset = TorchNormalDataset() | |||||
dataloader = dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last) | |||||
res = TorchSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, TorchNormalDataset) | |||||
assert isinstance(res.batch_sampler, RandomBatchSampler) | |||||
if shuffle: | |||||
assert isinstance(res.sampler, torch.utils.data.RandomSampler) | |||||
else: | |||||
assert isinstance(res.sampler, torch.utils.data.SequentialSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
@pytest.mark.parametrize("batch_size", [16]) | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
@pytest.mark.parametrize("drop_last", [True, False]) | |||||
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last): | |||||
""" | |||||
Test get_dataloader_args after the sampler has been replaced. | |||||
""" | |||||
dataset = TorchNormalDataset() | |||||
dataloader = dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last) | |||||
res = TorchSingleDriver.get_dataloader_args(dataloader) | |||||
assert isinstance(res.dataset, TorchNormalDataset) | |||||
assert isinstance(res.batch_sampler, BatchSampler) | |||||
assert isinstance(res.sampler, RandomSampler) | |||||
assert res.shuffle == shuffle | |||||
assert res.batch_size == batch_size | |||||
assert res.drop_last == drop_last | |||||
############################################################################ | |||||
# | |||||
# Tests for some simple functions of TorchSingleDriver | |||||
# | |||||
############################################################################ | |||||
class TestSingleDeviceFunction: | |||||
""" | |||||
Test cases for the remaining functions. | |||||
""" | |||||
@classmethod | |||||
def setup_class(cls): | |||||
model = TorchNormalModel_Classification_1(10, 784) | |||||
cls.driver = TorchSingleDriver(model, device="cpu") | |||||
def test_unwrap_model(self): | |||||
""" | |||||
Just check that it runs. | |||||
""" | |||||
res = self.driver.unwrap_model() | |||||
assert res is self.driver.model | |||||
def test_is_distributed(self): | |||||
assert self.driver.is_distributed() == False | |||||
def test_move_data_to_device(self): | |||||
""" | |||||
This function only calls torch_move_data_to_device; its test cases live in tests/core/utils/test_torch_utils.py, | |||||
so they are not repeated here. | |||||
""" | |||||
self.driver.move_data_to_device(torch.rand((32, 64))) | |||||
############################################################################ | |||||
# | |||||
# Tests for the set_dist_repro_dataloader function | |||||
# | |||||
############################################################################ | |||||
class TestSetDistReproDataloader: | |||||
""" | |||||
A class dedicated to testing the set_dist_repro_dataloader function. | |||||
""" | |||||
def setup_method(self): | |||||
self.dataset = TorchNormalDataset(20) | |||||
model = TorchNormalModel_Classification_1(10, 32) | |||||
self.driver = TorchSingleDriver(model, device="cpu") | |||||
def test_with_reproducible_false(self): | |||||
""" | |||||
Test set_dist_repro_dataloader when `reproducible` is False. | |||||
When dist is a string, the original dataloader should be returned unchanged. | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert replaced_loader is dataloader | |||||
@pytest.mark.parametrize("shuffle", [True, False]) | |||||
def test_with_reproducible_true(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when `reproducible` is True. | |||||
When dist is a string, a new dataloader should be returned: if the original sampler is a | |||||
torch.utils.data.RandomSampler (shuffle=True), only the sampler is replaced by RandomSampler; otherwise the | |||||
batch_sampler is replaced by RandomBatchSampler. | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | |||||
assert not (replaced_loader is dataloader) | |||||
if shuffle: | |||||
# in this case the sampler is replaced | |||||
assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
else: | |||||
# in this case the batch_sampler is replaced | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_batch_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is not a string but a ReproducibleBatchSampler. | |||||
A new dataloader should be returned, with its batch_sampler replaced by the batch sampler passed as dist. | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | |||||
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert replaced_loader.batch_sampler is dist | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dist_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when dist is not a string but a ReproducibleSampler. | |||||
A new dataloader should be returned, with batch_sampler.sampler replaced by the sampler passed as dist. | |||||
""" | |||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) | |||||
dist = RandomSampler(self.dataset, shuffle=shuffle) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.sampler is dist | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dataloader_reproducible_batch_sampler(self, shuffle): | |||||
""" | |||||
Test set_dist_repro_dataloader when the dataloader already supports checkpoint resumption via a RandomBatchSampler. | |||||
A new dataloader should be returned with all other settings kept the same as the original. | |||||
""" | |||||
dataloader = dataloader_with_randombatchsampler(self.dataset, 4, shuffle, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | |||||
assert replaced_loader.drop_last == dataloader.drop_last | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | |||||
def test_with_dataloader_reproducible_sampler(self, shuffle): | |||||
""" | |||||
Test the behavior of set_dist_repro_dataloader when the dataloader already supports resumable training (here via a RandomSampler).
It should return a new dataloader with all other settings unchanged.
""" | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 2, shuffle, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | |||||
assert not (replaced_loader is dataloader) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) | |||||
assert replaced_loader.batch_sampler.batch_size == 2 | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): | |||||
""" | |||||
Check that the dataloader returned by set_dist_repro_dataloader behaves correctly in the single-card case, i.e. that its sampler state can be saved and used to resume over the remaining samples.
""" | |||||
# consume two batches
num_consumed_batches = 2 | |||||
already_seen_idx = set() | |||||
for idx, batch in enumerate(replaced_loader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_idx.update(batch) | |||||
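# the reproducible state lives on the batch_sampler when a RandomBatchSampler was injected, and on the inner sampler when a RandomSampler was injected, so read the state_dict from whichever component carries it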
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.state_dict() | |||||
else: | |||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | |||||
# after reloading the state, the loader should yield the remaining samples; for TorchNormalDataset the indices, once sorted, form a contiguous range
left_idxes = set() | |||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||||
# rebuild the dataloader
new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) | |||||
new_loader.batch_sampler.load_state_dict(sampler_states) | |||||
else: | |||||
batch_size = replaced_loader.batch_sampler.batch_size | |||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | |||||
# rebuild the dataloader
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) | |||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | |||||
for idx, batch in enumerate(new_loader): | |||||
left_idxes.update(batch) | |||||
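# together these two checks verify that the resumed loader yields exactly the samples not consumed before the checkpoint: the two index sets are disjoint (their sizes add up) and their union covers the whole dataset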
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) | |||||
assert len(left_idxes | already_seen_idx) == len(self.dataset) | |||||
############################################################################ | |||||
# | |||||
# Tests for the save- and load-related functionality
# | |||||
############################################################################ | |||||
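# The tests below exercise the checkpointing API of the torch drivers. A rough, non-normative
# sketch of the round-trip they cover (reusing helpers from this file; the path and sizes are
# illustrative only):
#
#     driver1 = generate_random_driver(10, 10)           # "training" driver
#     driver2 = generate_random_driver(10, 10)           # "resuming" driver
#     driver1.save_model("model", only_state_dict=True)  # persist the weights
#     driver2.load_model("model", only_state_dict=True)  # restore them into a fresh driver
#     rank_zero_rm("model")                               # clean up, as the tests do
#
# driver.save / driver.load additionally round-trip the reproducible sampler state so that
# training can resume mid-epoch, which is what the randombatchsampler/randomsampler tests check.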
def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||||
""" | |||||
Build a TorchSingleDriver around a randomly initialized classification model, with an Adam optimizer attached.
""" | |||||
model = TorchNormalModel_Classification_1(labels, features) | |||||
opt = torch.optim.Adam(params=model.parameters(), lr=0.01) | |||||
driver = TorchSingleDriver(model, device=device, fp16=fp16) | |||||
driver.set_optimizers(opt) | |||||
driver.setup() | |||||
return driver | |||||
@pytest.fixture | |||||
def prepare_test_save_load(): | |||||
dataset = TorchArgMaxDataset(10, 40) | |||||
dataloader = DataLoader(dataset, batch_size=4) | |||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||||
return driver1, driver2, dataloader | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
def test_save_and_load_model(prepare_test_save_load, only_state_dict): | |||||
""" | |||||
Test the save_model and load_model functions.
""" | |||||
try: | |||||
path = "model" | |||||
driver1, driver2, dataloader = prepare_test_save_load | |||||
driver1.save_model(path, only_state_dict) | |||||
driver2.load_model(path, only_state_dict) | |||||
for batch in dataloader: | |||||
batch = driver1.move_data_to_device(batch) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert torch.equal(res1["preds"], res2["preds"]) | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | |||||
def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
""" | |||||
Test the save and load functions, mainly for the case where the dataloader's batch_sampler has been replaced (here with a RandomBatchSampler).
""" | |||||
try: | |||||
path = "model.ckp" | |||||
dataset = TorchArgMaxDataset(10, 40) | |||||
dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) | |||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") | |||||
num_consumed_batches = 2 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
sampler_states = dataloader.batch_sampler.state_dict() | |||||
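# captured here so it can be compared, after driver2.load below, with the state of the restored batch_sampler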
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# reload
# change the batch_size (4 -> 2)
dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False) | |||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. check the optimizer state
# TODO the optimizer's state_dict is always empty
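# once that is resolved, the test could additionally compare the restored optimizer state against the one saved by driver1 (not asserted here)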
# 2. check that the batch_sampler has been correctly loaded and substituted
assert not (replaced_loader is dataloader) | |||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | |||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 | |||||
# 3. check that the fp16 state is loaded
if fp16: | |||||
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
# 4. check that the model parameters are correct
# 5. check batch_idx
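# the reloaded dataloader uses batch_size=2 while the original used 4, so the 8 consumed samples correspond to twice as many batches after resumption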
start_batch = load_states.pop('batch_idx_in_epoch') | |||||
assert start_batch == 2 * num_consumed_batches | |||||
left_x_batches = set() | |||||
left_y_batches = set() | |||||
for idx, batch in enumerate(replaced_loader): | |||||
batch = driver2.move_data_to_device(batch) | |||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert torch.equal(res1["preds"], res2["preds"]) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) | |||||
assert len(left_x_batches | already_seen_x_set) == len(dataset) | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | |||||
assert len(left_y_batches | already_seen_y_set) == len(dataset) | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | |||||
@pytest.mark.parametrize("fp16", ([True, False])) | |||||
def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
""" | |||||
Test the save and load functions, mainly for the case where the dataloader's sampler has been replaced (here with a RandomSampler).
""" | |||||
try: | |||||
path = "model.ckp" | |||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") | |||||
dataset = TorchArgMaxDataset(10, 40) | |||||
dataloader = dataloader_with_randomsampler(dataset, 4, True, False) | |||||
num_consumed_batches = 2 | |||||
already_seen_x_set = set() | |||||
already_seen_y_set = set() | |||||
for idx, batch in enumerate(dataloader): | |||||
if idx >= num_consumed_batches: | |||||
break | |||||
already_seen_x_set.update(batch["x"]) | |||||
already_seen_y_set.update(batch["y"]) | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | |||||
save_states = {"num_consumed_batches": num_consumed_batches} | |||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# reload
# change the batch_size (4 -> 2)
dataloader = dataloader_with_randomsampler(dataset, 2, True, False) | |||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | |||||
# 1. check the optimizer state
# TODO the optimizer's state_dict is always empty
# 2. check that the sampler has been correctly loaded and substituted
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] | |||||
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] | |||||
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches | |||||
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | |||||
# 3. check that the fp16 state is loaded
if fp16: | |||||
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
# 4. check that the model parameters are correct
# 5. check batch_idx
start_batch = load_states.pop('batch_idx_in_epoch') | |||||
assert start_batch == 2 * num_consumed_batches | |||||
left_x_batches = set() | |||||
left_y_batches = set() | |||||
for idx, batch in enumerate(replaced_loader): | |||||
batch = driver2.move_data_to_device(batch) | |||||
left_x_batches.update(batch["x"]) | |||||
left_y_batches.update(batch["y"]) | |||||
res1 = driver1.model.evaluate_step(**batch) | |||||
res2 = driver2.model.evaluate_step(**batch) | |||||
assert torch.equal(res1["preds"], res2["preds"]) | |||||
assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) | |||||
assert len(left_x_batches | already_seen_x_set) == len(dataset) | |||||
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) | |||||
assert len(left_y_batches | already_seen_y_set) == len(dataset) | |||||
finally: | |||||
rank_zero_rm(path) |
@@ -1,35 +1,36 @@
from torch.utils.data.sampler import SequentialSampler, RandomSampler | |||||
from fastNLP.core.samplers.sampler import ReproduceSampler | |||||
from tests.helpers.datasets.normal_data import NormalIterator | |||||
class TestReproduceSampler: | |||||
def test_sequentialsampler(self): | |||||
normal_iterator = NormalIterator(num_of_data=20) | |||||
sequential_sampler = SequentialSampler(normal_iterator) | |||||
reproduce_sampler = ReproduceSampler(sequential_sampler) | |||||
# iter_seq_sampler = iter(sequential_sampler) | |||||
# for each in iter_seq_sampler: | |||||
# print(each) | |||||
iter_reproduce_sampler = iter(reproduce_sampler) | |||||
forward_step = 3 | |||||
for _ in range(forward_step): | |||||
next(iter_reproduce_sampler) | |||||
state = reproduce_sampler.save_state() | |||||
assert state["current_batch_idx"] == forward_step | |||||
new_repro_sampler = ReproduceSampler(sequential_sampler) | |||||
assert new_repro_sampler.save_state()["current_batch_idx"] == 0 | |||||
new_repro_sampler.load_state(state) | |||||
iter_new_repro_sampler = iter(new_repro_sampler) | |||||
new_index_list = [] | |||||
for each in iter_new_repro_sampler: | |||||
new_index_list.append(each) | |||||
assert new_index_list == list(range(3, 20)) | |||||
import os | |||||
import pytest | |||||
os.environ["FASTNLP_BACKEND"] = "torch" | |||||
from fastNLP.core.drivers.torch_driver.utils import ( | |||||
replace_batch_sampler, | |||||
replace_sampler, | |||||
) | |||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from torch.utils.data import DataLoader, BatchSampler | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
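# replace_batch_sampler and replace_sampler rebuild the DataLoader around the given component:
# the former installs the new batch_sampler directly (so its batch_size of 16 applies below),
# while the latter wraps the new sampler in a BatchSampler, as the assertions verify.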
def test_replace_batch_sampler(): | |||||
dataset = TorchNormalDataset(10) | |||||
dataloader = DataLoader(dataset, batch_size=32) | |||||
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||||
replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.dataset, TorchNormalDataset) | |||||
assert len(replaced_loader.dataset) == len(dataset) | |||||
assert replaced_loader.batch_sampler.batch_size == 16 | |||||
def test_replace_sampler(): | |||||
dataset = TorchNormalDataset(10) | |||||
dataloader = DataLoader(dataset, batch_size=32) | |||||
sampler = RandomSampler(dataset) | |||||
replaced_loader = replace_sampler(dataloader, sampler) | |||||
assert not (replaced_loader is dataloader) | |||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) |
@@ -38,7 +38,7 @@ class TorchNormalDataset_Classification(Dataset):
return {"x": self.x[item], "y": self.y[item]}
class TorchArgMaxDatset(Dataset):
class TorchArgMaxDataset(Dataset):
def __init__(self, feature_dimension=10, data_num=1000, seed=0):
self.num_labels = feature_dimension
self.feature_dimension = feature_dimension