Browse Source

删除了driver的replace_sampler替换为set_dist_repro_dataloader; 同时修改 driver.load/driver.save 函数

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
8e4abf2aa5
15 changed files with 148 additions and 132 deletions
  1. +1
    -5
      fastNLP/core/controllers/evaluator.py
  2. +17
    -55
      fastNLP/core/controllers/trainer.py
  3. +47
    -34
      fastNLP/core/drivers/driver.py
  4. +2
    -1
      fastNLP/core/drivers/jittor_driver/mpi.py
  5. +5
    -4
      fastNLP/core/drivers/jittor_driver/single_device.py
  6. +7
    -6
      fastNLP/core/drivers/paddle_driver/fleet.py
  7. +6
    -5
      fastNLP/core/drivers/paddle_driver/single_device.py
  8. +8
    -7
      fastNLP/core/drivers/torch_driver/ddp.py
  9. +6
    -6
      fastNLP/core/drivers/torch_driver/single_device.py
  10. +17
    -0
      fastNLP/core/samplers/reproducible_sampler.py
  11. +1
    -1
      fastNLP/envs/set_env_on_import.py
  12. +1
    -1
      requirements.txt
  13. +8
    -5
      tests/core/drivers/paddle_driver/test_fleet.py
  14. +1
    -1
      tests/core/drivers/paddle_driver/test_single_device.py
  15. +21
    -1
      tests/core/drivers/torch_driver/test_torch_replace_sampler.py

+ 1
- 5
fastNLP/core/controllers/evaluator.py View File

@@ -124,11 +124,7 @@ class Evaluator:

self.dataloaders = {}
for name, dl in dataloaders.items(): # 替换为正确的 sampler
dl = self.driver.replace_sampler(
dataloader=dl,
dist_sampler=self._dist_sampler,
reproducible=False
)
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False)
self.dataloaders[name] = dl

self.progress_bar = kwargs.get('progress_bar', 'auto')


+ 17
- 55
fastNLP/core/controllers/trainer.py View File

@@ -250,11 +250,8 @@ class Trainer(TrainerEventTrigger):
self.dataloader = self.train_dataloader
self.driver.set_deterministic_dataloader(self.dataloader)

self.dataloader = self.driver.replace_sampler(
dataloader=self.train_dataloader,
dist_sampler=_dist_sampler,
reproducible=self.callback_manager.has_trainer_chechpoint
)
self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
reproducible=self.callback_manager.has_trainer_chechpoint)

self.set_grad_to_none = kwargs.get("set_grad_to_none", True)
self.on_after_trainer_initialized(self.driver)
@@ -578,22 +575,6 @@ class Trainer(TrainerEventTrigger):
else:
states["val_filter_state"] = None

# 4. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch;
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的
# sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`;
dataloader_args = self.driver.get_dataloader_args(self.dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif dataloader_args.sampler:
sampler = dataloader_args.sampler
else:
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")

if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
states['sampler_states'] = sampler.state_dict()
else:
raise RuntimeError(
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
if isinstance(folder, str):
folder = Path(folder)

@@ -601,9 +582,9 @@ class Trainer(TrainerEventTrigger):
if not callable(model_save_fn):
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
rank_zero_call(model_save_fn)(folder)
self.driver.save(folder=folder, states=states, should_save_model=False, **kwargs)
self.driver.save(folder=folder, dataloader=self.dataloader, states=states, should_save_model=False, **kwargs)
else:
self.driver.save(folder=folder, states=states,
self.driver.save(folder=folder, dataloader=self.dataloader, states=states,
only_state_dict=only_state_dict, should_save_model=True, **kwargs)

self.driver.barrier()
@@ -616,9 +597,6 @@ class Trainer(TrainerEventTrigger):
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleIterator;

注意我们目前不支持单卡到多卡的断点重训;
TODO:注意我们目前不支持 RandomSampler、BucketedSampler 或者 SortedSampler 之间的断点重训;
因此如果用户自己需要使用 BucketedSampler,那么其需要自己在 Trainer 之前初始化 BucketedSampler,然后替换原始 Dataloader 中的
sampler,不管其是第一次断点重训,还是之后的加载的重新训练;

:param folder: 保存断点重训 states 的文件地址;
:param resume_training: 是否从上次的 batch 开始训练,或者只从最近的 epoch 开始训练;注意如果 resume_training=True,那么我们
@@ -627,33 +605,23 @@ class Trainer(TrainerEventTrigger):
self.driver.barrier()
if isinstance(folder, str):
folder = Path(folder)

dataloader = self.dataloader
if not resume_training:
dataloader = None

if model_load_fn is not None:
if not callable(model_load_fn):
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
raise ValueError("Parameter `model_save_fn` should be `Callable`.")
rank_zero_call(model_load_fn)(folder)
states = self.driver.load(folder=folder, should_load_model=False, **kwargs)
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
else:
states = self.driver.load(folder=folder, only_state_dict=only_state_dict, should_load_model=True, **kwargs)
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)

if not resume_training:
return

# 1. 恢复 sampler 的状态;
dataloader_args = self.driver.get_dataloader_args(self.dataloader)

sampler = dataloader_args.sampler
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
# 说明这里需要使用 ReproduceSampler 来弄一下了
if self.driver.is_distributed():
raise RuntimeError("It is not allowed to use single device checkpoint retraining before but ddp now.")
sampler = ReproducibleBatchSampler(
batch_sampler=sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(states['sampler_states'])

self.driver.replace_sampler(self.dataloader, sampler)
self.dataloader = states.pop('dataloader')

# 2. validate filter state;
if self.evaluator is not None:
@@ -668,22 +636,16 @@ class Trainer(TrainerEventTrigger):

# 4. 修改 trainer_state.batch_idx_in_epoch
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler;
if not isinstance(sampler, ReproducibleBatchSampler):
if dataloader_args.drop_last:
self.trainer_state.batch_idx_in_epoch = len(sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
else:
self.trainer_state.batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
# sampler 是 batch_sampler;
else:
self.trainer_state.batch_idx_in_epoch = sampler.batch_idx_in_epoch
# 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于
# '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')

# 5. 恢复所有 callback 的状态;
self.on_load_checkpoint(states["callback_states"])

self.driver.barrier()

""" 这四个函数是用来方便用户定制自己的 batch_step_fn(用于替换 train_batch_loop 当中的 step 函数) 的 """
""" 这四个函数是用来方便用户定制自己的 batch_step_fn(用于替换 train_batch_loop 当中的 batch_step_fn 函数) 的 """

def train_step(self, batch):
with self.driver.auto_cast():


+ 47
- 34
fastNLP/core/drivers/driver.py View File

@@ -2,7 +2,7 @@ import os
import signal
import sys
from typing import Any, Sequence, List, Optional, Callable, Dict, Union
from abc import ABC
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from io import BytesIO
@@ -14,7 +14,6 @@ __all__ = [
from fastNLP.core.utils import nullcontext


# todo 航总 check 一下哪一些方法需要 @abstractmethod;
class Driver(ABC):
r"""
用来初始化 `Driver` 的基类,所有定制的 `driver` 都需要继承此类;
@@ -32,29 +31,33 @@ class Driver(ABC):
# self._consensus_file: Optional[Union[str, Path]] = None
self._pids: Optional[List[int]] = None

@abstractmethod
def setup(self):
r"""
该函数用来初始化训练环境,例如将模型迁移到对应的设备上等;
多卡的 driver 的该函数要更为复杂一些,例如其可能需要开启多进程之间的通信环境,以及设置一些环境变量和其余所需要的变量值;
"""

def replace_sampler(self, dataloader, dist_sampler: Optional[str], reproducible: bool = False):
def set_dist_repro_dataloader(self, dataloader, dist=None, reproducible: bool = False):
r"""
因为一些特殊的情况需要替换 dataloader 的 sampler,而每一个 driver 中的该函数会提供该功能;例如在多卡训练的中,我们
需要将 sampler 替换为 distributed sampler;以及如果用户在 Trainer 中加入了断点重训的 callback,那么我们就需要将 sampler 替换
为 reproducible sampler;

:param dataloader: 由 trainer 中传入的原始的 dataloader;
:param dist_sampler: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];用于指定使用怎样的 sampler;
目前该参数被定制为分布式训练服务,其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist",否则为 None;
evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None;
:param reproducible: 用于在 `Trainer` 中指定是否替换为断点重训的 sampler(多卡) 或者 batch_sampler(单卡);如果是单卡的 Driver,
并且该参数为 True,表示当前正在断点重训,那么我们就会使用我们的 `ReproducibleBatchSampler` 来替换 dataloader 原本的 batch_sampler;
如果是多卡的 Driver,那么我们就会用 `RandomSampler` 替换 dataloader 原本的 sampler;

:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;
"""
raise NotImplementedError("Each specific driver should implemented its own `replace_sampler` function.")
根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。

:param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本
:param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader
切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist";
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None;
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得
可以可以加载。
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外,
如果传入的 dataloader 中是 ReproducibleIterator 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。
"""
if dist is None and reproducible is False:
return dataloader
raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `set_dist_repro_dataloader` "
f"function.")

def set_deterministic_dataloader(self, dataloader):
r"""
@@ -68,7 +71,7 @@ class Driver(ABC):

:param cur_epoch_idx: 当前是第几个 epoch;
"""
@abstractmethod
def train_step(self, batch):
"""
通过调用模型自带的 `train_step` 或者 `forward` 方法来实现训练的前向过程;
@@ -103,7 +106,7 @@ class Driver(ABC):
因此如果用户的 evaluator mode 是 validate,但是传入的 model 却没有实现 validate_step 函数,而是实现了 test_step 函数,那么
我们应当提醒用户这一行为;
"""
raise NotImplementedError("Each specific driver should implemented its own `predict_step` function.")
raise NotImplementedError("Each specific driver should implemented its own `check_evaluator_mode` function.")

@property
def model(self):
@@ -234,6 +237,7 @@ class Driver(ABC):
"""
self.optimizers = optimizers

@abstractmethod
def backward(self, loss):
"""
实现深度学习中的反向传播过程;
@@ -242,12 +246,14 @@ class Driver(ABC):
"""
raise NotImplementedError("Each specific driver should implemented its own `backward` function.")

@abstractmethod
def step(self):
r"""
实现深度学习中的参数的优化更新过程,应当直接通过优化器 optimizers 来更新参数;
"""
raise NotImplementedError("Each specific driver should implemented its own `step` function.")

@abstractmethod
def zero_grad(self, set_to_none: bool = False):
r"""
实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零;
@@ -286,6 +292,7 @@ class Driver(ABC):
def auto_cast(self, auto_cast):
self._auto_cast = auto_cast

@abstractmethod
def save_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = True, **kwargs):
r"""
保存模型的函数;注意函数 `save` 是用来进行断点重训的函数;
@@ -296,6 +303,7 @@ class Driver(ABC):
"""
raise NotImplementedError("Each specific driver should implemented its own `save_model` function.")

@abstractmethod
def load_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = False, **kwargs):
r"""
加载模型的函数;将 filepath 中的模型加载并赋值给当前 model 。
@@ -307,7 +315,8 @@ class Driver(ABC):
"""
raise NotImplementedError("Each specific driver should implemented its own `load_model` function.")

def save(self, folder, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
@abstractmethod
def save(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
r"""
断点重训的保存函数,该函数会负责保存模型和 optimizers, fp16 的 state_dict;以及模型的保存(若 should_save_model 为 True)

@@ -317,12 +326,14 @@ class Driver(ABC):
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load() 返回的值与这里的
传入的值保持一致。
:param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。
:param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。
:param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。
"""
raise NotImplementedError("Each specific driver should implemented its own `save` function.")

def load(self, folder: Union[str, Path], only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict:
@abstractmethod
def load(self, folder: Union[str, Path], dataloader, only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict:
r"""
断点重训的加载函数,注意该函数会负责读取数据,并且恢复 optimizers , fp16 的 state_dict 和 模型(根据 should_load_model )和;
其它在 Driver.save() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save() 接受到的 states )。
@@ -331,11 +342,22 @@ class Driver(ABC):

:param folder: 读取该 folder 下的 FASTNLP_CHECKPOINT_FILENAME 文件与 FASTNLP_MODEL_FILENAME
(如果 should_load_model 为True)。
:param dataloader: 当前给定 dataloader,需要根据 save 的 dataloader 状态合理设置。若该值为 None ,是不需要返回 'dataloader'
以及 'batch_idx_in_epoch' 这两个值。
:param only_state_dict: 读取的,当 should_save_model 为 False ,该参数无效。如果为 True ,说明保存的内容为权重;如果为
False 说明保存的是模型,但也是通过当前 Driver 的模型去加载保存的模型的权重,而不是使用保存的模型替换当前模型。
:param should_load_model: 是否应该加载模型,如果为False,Driver 将不负责加载模型。若该参数为 True ,但在保存的状态中没有
找到对应的模型状态,则报错。
:return: 需要返回 save 函数输入的 states 内容;
:return: 需要返回 save 函数输入的 states 内容
'dataloader',返回的是根据传入的 dataloader 与 保存的状态一起设置为合理的状态,可以返回的对象与传入的dataloader是同一个。
在保存与当前传入 data sample 数目不一致时报错。
'batch_idx_in_epoch': int 类型的数据,表明当前 epoch 进行到了进行到了第几个 batch 了。 请注意,该值不能是只能通过保存的
数据中读取的,因为前后两次运行 batch_size 可能由变化。该数字的原则应该符合以下等式
'返回 dataloader 还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数' 。
由于 '返回 dataloader 还会产生的batch数量' 这个数量在 batch_size 与 drop_last 参数给定的情况下,无法改变,因此
只能通过调整 batch_idx_in_epoch 这个值来使等式成立。一个简单的计算原则如下
当drop_last为True,等同于 floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size);
当drop_last为False,等同于 ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size)。
"""
raise NotImplementedError("Each specific driver should implemented its own `load` function.")

@@ -352,6 +374,7 @@ class Driver(ABC):
"""
raise NotImplementedError("Each specific driver should implemented its own `tensor_to_numeric` function.")

@abstractmethod
def set_model_mode(self, mode: str):
r"""
设置模型为 `train` / `eval` 的模式;目的是为切换模型训练和推理(会关闭dropout等)模式;
@@ -378,6 +401,7 @@ class Driver(ABC):
中,我们需要先将模型移到 cpu 后,又再移到 gpu 上,因此不适宜在该函数内部调用 `unwrap_model`,而是将 model 作为该函数的参数;
"""

@abstractmethod
def move_data_to_device(self, batch):
r"""
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。
@@ -399,17 +423,6 @@ class Driver(ABC):
仅在多分布式训练场景中有使用。
"""

@staticmethod
def get_dataloader_args(dataloader):
"""
用于从 dataloader 中抽取一些属性的值,返回的dataclass中必须包含以下的key:
sampler, batch_sampler, batch_size, drop_last;

:param dataloader:
:return: 返回一个 dataclass,其实例属性应当包括以上的各个属性,并且其名字也应当与这些属性相同,从而方便 trainer 或者其它对象调用;
"""
raise NotImplementedError("Each specific driver should implemented its own `get_dataloader_args` function.")

def is_distributed(self) -> bool:
"""
当前的 driver 实例是否是分布式的;


+ 2
- 1
fastNLP/core/drivers/jittor_driver/mpi.py View File

@@ -70,7 +70,8 @@ class JittorMPIDriver(JittorDriver):
def test_step(self, batch):
return self._test_step(batch)

def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
reproducible: bool = False, sampler_or_batch_sampler=None):
pass

def backward(self, loss):


+ 5
- 4
fastNLP/core/drivers/jittor_driver/single_device.py View File

@@ -99,14 +99,15 @@ class JittorSingleDriver(JittorDriver):
def is_distributed(self):
return False

def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
reproducible: bool = False, sampler_or_batch_sampler=None):
# reproducible 的相关功能暂时没有实现
if isinstance(dist_sampler, ReproducibleBatchSampler):
if isinstance(dist, ReproducibleBatchSampler):
raise NotImplementedError
dataloader.batch_sampler = dist_sample
if isinstance(dist_sampler, ReproducibleIterator):
if isinstance(dist, ReproducibleIterator):
raise NotImplementedError
dataloader.batch_sampler.sampler = dist_sampler
dataloader.batch_sampler.sampler = dist

if reproducible:
raise NotImplementedError


+ 7
- 6
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -316,13 +316,14 @@ class PaddleFleetDriver(PaddleDriver):
def test_step(self, batch):
return self._test_step(batch)

def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
reproducible: bool = False, sampler_or_batch_sampler=None):
# 暂时不支持iterableDataset
assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IteratorDataset` now."
if isinstance(dist_sampler, ReproducibleIterator):
dataloader.batch_sampler.sampler = dist_sampler
if isinstance(dist, ReproducibleIterator):
dataloader.batch_sampler.sampler = dist
return dataloader

# paddle 的 BatchSampler 和 DataLoader 没有 shuffle 成员,只能根据 sampler 判断
@@ -334,14 +335,14 @@ class PaddleFleetDriver(PaddleDriver):
shuffle = dataloader.batch_sampler.shuffle

# trainer, evaluator
if dist_sampler is None:
if dist is None:
if reproducible:
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
"control.")
else:
return dataloader
# trainer
elif dist_sampler == "dist":
elif dist == "dist":
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为;
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
dataloader.batch_sampler.sampler.set_distributed(
@@ -364,7 +365,7 @@ class PaddleFleetDriver(PaddleDriver):
dataloader.batch_sampler.sampler = sampler
return dataloader
# evaluator
elif dist_sampler == "unrepeatdist":
elif dist == "unrepeatdist":
sampler = UnrepeatedDistributedSampler(
dataset=dataloader.dataset,
shuffle=shuffle,


+ 6
- 5
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -133,15 +133,16 @@ class PaddleSingleDriver(PaddleDriver):
"""
return paddle_move_data_to_device(batch, "gpu:0")

def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
reproducible: bool = False, sampler_or_batch_sampler=None):
# 暂时不支持IteratorDataset
assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IteratorDataset` now."
if isinstance(dist_sampler, ReproducibleBatchSampler):
dataloader.batch_sampler = dist_sampler
if isinstance(dist, ReproducibleBatchSampler):
dataloader.batch_sampler = dist
return dataloader
if isinstance(dist_sampler, ReproducibleIterator):
dataloader.batch_sampler.sampler = dist_sampler
if isinstance(dist, ReproducibleIterator):
dataloader.batch_sampler.sampler = dist
return dataloader

if reproducible:


+ 8
- 7
fastNLP/core/drivers/torch_driver/ddp.py View File

@@ -445,21 +445,22 @@ class TorchDDPDriver(TorchDriver):
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
return self._test_step(batch)

def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
if isinstance(dist_sampler, ReproducibleIterator):
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
reproducible: bool = False, sampler_or_batch_sampler=None):
if isinstance(dist, ReproducibleIterator):
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数;
dist_sampler = re_instantiate_sampler(dist_sampler)
return replace_sampler(dataloader, dist_sampler)
dist = re_instantiate_sampler(dist)
return replace_sampler(dataloader, dist)

# trainer, evaluator
if dist_sampler is None:
if dist is None:
if reproducible:
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
"control.")
else:
return dataloader
# trainer
elif dist_sampler == "dist":
elif dist == "dist":
args = self.get_dataloader_args(dataloader)
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为;
if isinstance(args.sampler, ReproducibleIterator):
@@ -485,7 +486,7 @@ class TorchDDPDriver(TorchDriver):
return replace_sampler(dataloader, sampler)

# evaluator
elif dist_sampler == "unrepeatdist":
elif dist == "unrepeatdist":
args = self.get_dataloader_args(dataloader)
sampler = UnrepeatedDistributedSampler(
dataset=args.dataset,


+ 6
- 6
fastNLP/core/drivers/torch_driver/single_device.py View File

@@ -130,12 +130,12 @@ class TorchSingleDriver(TorchDriver):
else:
return self._test_step(batch)

def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
reproducible: bool = False):
if isinstance(dist_sampler, ReproducibleBatchSampler):
return replace_batch_sampler(dataloader, dist_sampler)
elif isinstance(dist_sampler, ReproducibleIterator):
return replace_sampler(dataloader, dist_sampler)
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
reproducible: bool = False, sampler_or_batch_sampler=None):
if isinstance(dist, ReproducibleBatchSampler):
return replace_batch_sampler(dataloader, dist)
elif isinstance(dist, ReproducibleIterator):
return replace_sampler(dataloader, dist)

if reproducible:
args = self.get_dataloader_args(dataloader)


+ 17
- 0
fastNLP/core/samplers/reproducible_sampler.py View File

@@ -50,6 +50,14 @@ class ReproducibleIterator:

class RandomSampler(ReproducibleIterator):
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
"""


:param dataset: 实现了 __len__ 方法的数据容器
:param shuffle: 是否在每次 iterate 的时候打乱顺序。
:param seed: 随机数种子。
:param kwargs: 用户不需要使用,fastNLP 内部使用
"""

self.dataset = dataset
self.shuffle = shuffle
@@ -208,6 +216,15 @@ class RandomSampler(ReproducibleIterator):
class ReproducibleBatchSampler:
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿;
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
"""
可以使得 batch_sampler 对象状态恢复的 wrapper 。

:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproducibleBatchSampler 将首先遍历一边该对象,然后将迭代
出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。
:param batch_size: 每个 batch 的大小是多少。
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。
:param kwargs: fastNLP 内部使用。
"""
self.batch_sampler = batch_sampler
self.batch_size = batch_size
self.drop_last = drop_last


+ 1
- 1
fastNLP/envs/set_env_on_import.py View File

@@ -15,7 +15,7 @@ def remove_local_rank_in_argv():
"""
index = -1
for i, v in enumerate(sys.argv):
if v.startswith('--rank='):
if v.startswith('--local_rank='):
os.environ['LOCAL_RANK'] = v.split('=')[1]
index = i
break


+ 1
- 1
requirements.txt View File

@@ -3,4 +3,4 @@ prettytable>=0.7.2
requests
regex!=2019.12.17
rich==11.2.0
# fsspec[http]>=2021.05.0, !=2021.06.0
packaging

+ 8
- 5
tests/core/drivers/paddle_driver/test_fleet.py View File

@@ -1,12 +1,9 @@
import pytest
import sys
import os
import numpy as np
from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

set_env_on_import_paddle()
set_env("paddle")
import paddle
import paddle.distributed as dist
from paddle.io import DataLoader
@@ -54,6 +51,7 @@ def test_move_data_to_device():

dist.barrier()


@magic_argv_env_context
def test_is_distributed():
print(os.getenv("CUDA_VISIBLE_DEVICES"))
@@ -64,6 +62,7 @@ def test_is_distributed():
driver = PaddleFleetDriver(
model=paddle_model,
parallel_device=[0,1],
output_from_new_proc='all'
)
driver.set_optimizers(paddle_opt)
# 区分launch和子进程setup的时候
@@ -79,6 +78,7 @@ def test_is_distributed():
synchronize_safe_rm("log")
dist.barrier()


@magic_argv_env_context
def test_get_no_sync_context():
"""
@@ -105,6 +105,7 @@ def test_get_no_sync_context():
synchronize_safe_rm("log")
dist.barrier()


@magic_argv_env_context
def test_is_global_zero():
try:
@@ -128,6 +129,8 @@ def test_is_global_zero():
synchronize_safe_rm("log")
dist.barrier()



@magic_argv_env_context
def test_unwrap_model():
try:
@@ -204,7 +207,7 @@ def test_replace_sampler(dist_sampler, reproducible):
else:
driver.setup()
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
driver.replace_sampler(dataloader, dist_sampler, reproducible)
driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
finally:
synchronize_safe_rm("log")
dist.barrier()
@@ -243,7 +246,7 @@ class SingleMachineMultiGPUTrainingTestCase:
parallel_device=gpus,
)
driver.set_optimizers(paddle_opt)
dataloader = driver.replace_sampler(dataloader)
dataloader = driver.set_dist_repro_dataloader(dataloader, )
driver.setup()
# 检查model_device
self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}")


+ 1
- 1
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -164,4 +164,4 @@ class TestSingleDeviceFunction:
"""
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)

res = self.driver.replace_sampler(dataloader, dist_sampler, reproducible)
res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)

+ 21
- 1
tests/core/drivers/torch_driver/test_torch_replace_sampler.py View File

@@ -33,11 +33,15 @@ def check_replace_sampler(driver):
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproducibleBatchSampler
# reproducible 是 True 和 False

# 需要 check 返回的 sampler 和 dataloader 都不同了
assert driver.is_distributed() is False, "This test only for non distributed sampler."
ds = SequenceDataSet(10)
dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True)

dl1 = driver.replace_sampler(dataloader, dist_sampler='dist', reproducible=True)
dl1 = driver.set_dist_repro_dataloader(dataloader, dist='dist', reproducible=True)

assert not (dl1.sampler is dataloader.sampler), "The sampler should not the same one."
assert not (dl1 is dataloader), "The dataloader should not the same one."

# 迭代两个 batch
already_seen_idx = set()
@@ -68,6 +72,22 @@ def check_replace_sampler(driver):
assert b not in already_seen_idx
assert b in left_idxes

# 需要 check 替换为 unrepeatdist 的时候没有问题:(1) 不会多pad;(2)所有卡互相不重复
ds = SequenceDataSet(11)
dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True)
dl1 = driver.set_dist_repro_dataloader(dataloader, dist='unrepeatdist', reproducible=True)
world_size = 3
indices = []
for i in range(world_size):
dl1.sampler.set_distributed(num_replicas=world_size, rank=i)
for idx, batch in dl1:
indices.extend(batch)
assert len(indices)==len(ds) # 应该没有任何重复
assert len(set(indices))==len(indices) # 应该全是不一样的indice









Loading…
Cancel
Save