@@ -124,11 +124,7 @@ class Evaluator:
self.dataloaders = {}
for name, dl in dataloaders.items():  # replace with the correct sampler
dl = self.driver.replace_sampler(
dataloader=dl,
dist_sampler=self._dist_sampler,
reproducible=False
)
dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False)
self.dataloaders[name] = dl
self.progress_bar = kwargs.get('progress_bar', 'auto')
@@ -250,11 +250,8 @@ class Trainer(TrainerEventTrigger):
self.dataloader = self.train_dataloader
self.driver.set_deterministic_dataloader(self.dataloader)
self.dataloader = self.driver.replace_sampler(
dataloader=self.train_dataloader,
dist_sampler=_dist_sampler,
reproducible=self.callback_manager.has_trainer_chechpoint
)
self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
reproducible=self.callback_manager.has_trainer_chechpoint)
self.set_grad_to_none = kwargs.get("set_grad_to_none", True)
self.on_after_trainer_initialized(self.driver)
@@ -263,7 +260,7 @@ class Trainer(TrainerEventTrigger):
def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1,
num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True,
catch_KeyboardInterrupt=True):
catch_KeyboardInterrupt=None):
"""
Note that on the very first run of resumable training, i.e. before any checkpoint files have been saved, resume_from should be set to None,
and a ModelCheckpoint callback should be used to save the resumable-training files;
@@ -273,15 +270,17 @@ class Trainer(TrainerEventTrigger):
:param resume_from: the path from which to restore the trainer's state;
:param resume_training: whether to restore the training state recorded in the checkpoint. If False, only the states of the model and optimizers are restored.
:param catch_KeyboardInterrupt: whether to catch KeyboardInterrupt; if caught, no exception is raised, and the code after trainer.run() continues to
run.
run. By default a non-distributed driver will catch it, while a distributed one will not (it cannot catch it).
:return:
"""
if self.driver.is_distributed():
if catch_KeyboardInterrupt:
logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using a multi-device "
"driver. We will set it to False.")
catch_KeyboardInterrupt = False
if catch_KeyboardInterrupt is None:
catch_KeyboardInterrupt = not self.driver.is_distributed()
else:
if self.driver.is_distributed():
if catch_KeyboardInterrupt:
logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using a multi-device "
"driver. We will set it to False.")
catch_KeyboardInterrupt = False
self._set_num_eval_batch_per_dl(num_eval_batch_per_dl)
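A minimal sketch of the resume and interrupt semantics described above (the model, dataloader, and folder names are placeholders, not part of this diff):

    trainer = Trainer(model=model, driver="torch", device=0, train_dataloader=train_dl)
    # First run: nothing saved yet, so resume_from stays None and a checkpoint
    # callback is responsible for writing the resumable states.
    trainer.run()
    # A later run restores the trainer state from a saved folder; with
    # resume_training=True it continues from the exact batch, with False it
    # only restores the model and optimizers.
    trainer.run(resume_from="checkpoints/epoch_3", resume_training=True)
    # catch_KeyboardInterrupt left as None is resolved automatically:
    # single-device drivers catch it, distributed drivers cannot.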
@@ -576,22 +575,6 @@ class Trainer(TrainerEventTrigger):
else:
states["val_filter_state"] = None
# 4. the sampler's state; we support resume training, i.e. restoring precisely to a specific batch;
# first, a pytorch DataLoader always has a sampler; second, when resuming we always replace the dataloader's sampler with a
# `ReproducibleIterator` in `replace_sampler`, or, in the single-device case, replace its batch_sampler with a `ReproducibleBatchSampler`;
dataloader_args = self.driver.get_dataloader_args(self.dataloader)
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
sampler = dataloader_args.batch_sampler
elif dataloader_args.sampler:
sampler = dataloader_args.sampler
else:
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
states['sampler_states'] = sampler.state_dict()
else:
raise RuntimeError(
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
if isinstance(folder, str):
folder = Path(folder)
@@ -599,9 +582,9 @@ class Trainer(TrainerEventTrigger):
if not callable(model_save_fn):
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
rank_zero_call(model_save_fn)(folder)
self.driver.save(folder=folder, states=states, should_save_model=False, **kwargs)
self.driver.save(folder=folder, dataloader=self.dataloader, states=states, should_save_model=False, **kwargs)
else:
self.driver.save(folder=folder, states=states,
self.driver.save(folder=folder, dataloader=self.dataloader, states=states,
only_state_dict=only_state_dict, should_save_model=True, **kwargs)
self.driver.barrier()
@@ -614,9 +597,6 @@ class Trainer(TrainerEventTrigger):
saved; in that case the dataloader's sampler will not necessarily have been replaced with our ReproducibleIterator;
note that we currently do not support resuming a single-device checkpoint on multiple devices;
TODO: note that we currently do not support resuming across RandomSampler, BucketedSampler, or SortedSampler;
therefore, if users need a BucketedSampler, they must initialize it themselves before the Trainer and replace the sampler of the
original DataLoader, whether it is the first resumable run or a later reload;
:param folder: the path where the resumable-training states are saved;
:param resume_training: whether to resume training from the last batch, or only from the most recent epoch; note that if resume_training=True, we
@@ -625,33 +605,23 @@ class Trainer(TrainerEventTrigger):
self.driver.barrier()
if isinstance(folder, str):
folder = Path(folder)
dataloader = self.dataloader
if not resume_training:
dataloader = None
if model_load_fn is not None:
if not callable(model_load_fn):
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
raise ValueError("Parameter `model_load_fn` should be `Callable`.")
rank_zero_call(model_load_fn)(folder)
states = self.driver.load(folder=folder, should_load_model=False, **kwargs)
states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
else:
states = self.driver.load(folder=folder, only_state_dict=only_state_dict, should_load_model=True, **kwargs)
states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)
if not resume_training:
return
# 1. restore the sampler's state;
dataloader_args = self.driver.get_dataloader_args(self.dataloader)
sampler = dataloader_args.sampler
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
# this means we need to fall back to a reproducible sampler here
if self.driver.is_distributed():
raise RuntimeError("It is not allowed to resume a single-device checkpoint with ddp.")
sampler = ReproducibleBatchSampler(
batch_sampler=sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
)
sampler.load_state_dict(states['sampler_states'])
self.driver.replace_sampler(self.dataloader, sampler)
self.dataloader = states.pop('dataloader')
# 2. validate filter state;
if self.evaluator is not None:
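For orientation, the load-side contract introduced above can be sketched in a few lines (a simplified, hypothetical flow; the key names match the code in this diff):

    # driver.load() now receives the dataloader and hands back a restored one
    # inside the returned states dict.
    states = self.driver.load(folder=folder, dataloader=self.dataloader,
                              should_load_model=True)
    self.dataloader = states.pop('dataloader')            # restored iteration state
    self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
    self.on_load_checkpoint(states["callback_states"])    # callbacks come last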
@@ -666,22 +636,16 @@ class Trainer(TrainerEventTrigger):
# 4. adjust trainer_state.batch_idx_in_epoch
# the sampler is a RandomSampler-like sampler, not a batch_sampler;
if not isinstance(sampler, ReproducibleBatchSampler):
if dataloader_args.drop_last:
self.trainer_state.batch_idx_in_epoch = len(sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
else:
self.trainer_state.batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
# the sampler is a batch_sampler;
else:
self.trainer_state.batch_idx_in_epoch = sampler.batch_idx_in_epoch
# the principle here is that 'number of batches still to be produced' + 'batch_idx_in_epoch' should equal 'the total number of
# batches of an uninterrupted run'. Since 'number of batches still to be produced' is determined by how many samples are left,
# the equation can only be satisfied by adjusting 'batch_idx_in_epoch'
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
# 5. restore the states of all callbacks;
self.on_load_checkpoint(states["callback_states"])
self.driver.barrier()
""" These four functions make it convenient for users to customize their own batch_step_fn (which replaces the step function in train_batch_loop) """
""" These four functions make it convenient for users to customize their own batch_step_fn (which replaces the batch_step_fn function in train_batch_loop) """
def train_step(self, batch):
with self.driver.auto_cast():
@@ -2,7 +2,7 @@ import os
import signal
import sys
from typing import Any, Sequence, List, Optional, Callable, Dict, Union
from abc import ABC
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from io import BytesIO
@@ -14,7 +14,6 @@ __all__ = [
from fastNLP.core.utils import nullcontext
# todo: 航总, please check which methods need @abstractmethod;
class Driver(ABC):
r"""
The base class for initializing a `Driver`; every customized driver must inherit from this class;
@@ -32,29 +31,33 @@ class Driver(ABC):
# self._consensus_file: Optional[Union[str, Path]] = None
self._pids: Optional[List[int]] = None
@abstractmethod
def setup(self):
r"""
This function initializes the training environment, e.g. moving the model to the corresponding device;
for multi-device drivers it is more involved, e.g. it may need to set up inter-process communication, plus some environment variables and other required values;
"""
def replace_sampler(self, dataloader, dist_sampler: Optional[str], reproducible: bool = False):
def set_dist_repro_dataloader(self, dataloader, dist=None, reproducible: bool = False):
r"""
Some special situations require replacing the dataloader's sampler, and this function in each driver provides that capability; for example,
in multi-device training we need to replace the sampler with a distributed sampler; and if the user adds a resumable-training callback to
the Trainer, we need to replace the sampler with a reproducible sampler;
:param dataloader: the original dataloader passed in by the trainer;
:param dist_sampler: should be a string, one of [None, "dist", "unrepeatdist"]; specifies which kind of sampler to use;
currently this parameter serves distributed training: when the trainer kwarg `use_dist_sampler` is True its value is "dist", otherwise None;
when the evaluator kwarg `use_dist_sampler` is True its value is "unrepeatdist", otherwise None;
:param reproducible: used by the `Trainer` to specify whether to switch to a resumable sampler (multi-device) or batch_sampler (single-device); for a single-device Driver,
True means we are resuming training, so we use our `ReproducibleBatchSampler` to replace the dataloader's original batch_sampler;
for a multi-device Driver, we replace the dataloader's original sampler with a `RandomSampler`;
:return: should return a new dataloader object with the sampler replaced (note that a new dataloader object must be returned here);
"""
raise NotImplementedError("Each specific driver should implement its own `replace_sampler` function.")
Given the input dataloader, produce a dataloader that supports distributed and reproducible iteration.
:param dataloader: the dataloader from which the distributed and reproducible version is derived
:param dist: should be a string, one of [None, "dist", "unrepeatdist"]; None means the dataloader does not need to be
switched to a distributed state; 'dist' means the dataloader should guarantee that every gpu receives the same number of batches,
allowing a small number of samples to be duplicated across gpus; 'unrepeatdist' means that the data iterated across all
gpus, taken together, should be exactly the original data, allowing the number of batches to differ across gpus. When the
trainer kwarg `use_dist_sampler` is True this value is "dist", otherwise None; when the evaluator kwarg `use_dist_sampler`
is True this value is "unrepeatdist", otherwise None;
:param reproducible: if False, nothing special is required; if True, the returned dataloader must be able to save its
current iteration state, so that it can be loaded again later.
:return: should return a new dataloader object with the sampler replaced (note that a new dataloader object must be returned here); moreover,
if the incoming dataloader carries a ReproducibleIterator or ReproducibleBatchSampler, a re-initialized copy must be placed into the
returned dataloader. If dist is None and reproducible is False, the original object may be returned directly.
"""
if dist is None and reproducible is False:
return dataloader
raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `set_dist_repro_dataloader` "
f"function.")
def set_deterministic_dataloader(self, dataloader):
r"""
@@ -68,7 +71,7 @@ class Driver(ABC):
:param cur_epoch_idx: the index of the current epoch;
"""
@abstractmethod
def train_step(self, batch):
"""
Implements the forward pass of training by calling the model's own `train_step` or `forward` method;
@@ -103,7 +106,7 @@ class Driver(ABC):
so if the user's evaluator mode is validate but the given model implements test_step rather than validate_step, we
should warn the user about this behavior;
"""
raise NotImplementedError("Each specific driver should implement its own `predict_step` function.")
raise NotImplementedError("Each specific driver should implement its own `check_evaluator_mode` function.")
@property
def model(self):
@@ -234,6 +237,7 @@ class Driver(ABC):
"""
self.optimizers = optimizers
@abstractmethod
def backward(self, loss):
"""
Implements the backpropagation step of deep learning;
@@ -242,12 +246,14 @@ class Driver(ABC):
"""
raise NotImplementedError("Each specific driver should implement its own `backward` function.")
@abstractmethod
def step(self):
r"""
Implements the parameter-update step of deep learning; parameters should be updated directly through the optimizers;
"""
raise NotImplementedError("Each specific driver should implement its own `step` function.")
@abstractmethod
def zero_grad(self, set_to_none: bool = False):
r"""
Implements the zeroing of gradients; gradients should be zeroed directly through the optimizers;
@@ -286,6 +292,7 @@ class Driver(ABC):
def auto_cast(self, auto_cast):
self._auto_cast = auto_cast
@abstractmethod
def save_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = True, **kwargs):
r"""
Function for saving the model; note that the function `save` is the one used for resumable training;
@@ -296,6 +303,7 @@ class Driver(ABC):
"""
raise NotImplementedError("Each specific driver should implement its own `save_model` function.")
@abstractmethod
def load_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = False, **kwargs):
r"""
Function for loading the model; loads the model from filepath and assigns it to the current model.
@@ -307,7 +315,8 @@ class Driver(ABC):
"""
raise NotImplementedError("Each specific driver should implement its own `load_model` function.")
def save(self, folder, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
@abstractmethod
def save(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
r"""
The save function for resumable training; it is responsible for saving the optimizers' and fp16 state_dicts, as well as the model itself (if should_save_model is True)
@@ -317,12 +326,14 @@ class Driver(ABC):
:param states: a dict passed in by the trainer that already contains the states of the other objects that need to be saved for
resumable training; the Driver only needs to save this object as-is, without understanding it, and must return states at
driver.load() time, so that load()'s return value matches what was passed in here.
:param dataloader: the dataloader currently in use; its state must be saved so that iteration can later resume from the current position.
:param only_state_dict: whether to save only the model's parameters; has no effect when should_save_model is False.
:param should_save_model: whether the model should be saved; if False, the Driver is not responsible for saving the model.
"""
raise NotImplementedError("Each specific driver should implement its own `save` function.")
def load(self, folder: Union[str, Path], only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict:
@abstractmethod
def load(self, folder: Union[str, Path], dataloader, only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict:
r"""
The load function for resumable training; note that this function reads the saved data and restores the state_dicts of the
optimizers and fp16, the model (according to should_load_model), and whatever else Driver.save() stored, and then returns a states
dict to the trainer (containing the states that Driver.save() received).
@@ -331,11 +342,22 @@ class Driver(ABC):
:param folder: reads the FASTNLP_CHECKPOINT_FILENAME file, and the FASTNLP_MODEL_FILENAME file
(if should_load_model is True), under this folder.
:param dataloader: the dataloader currently given; it must be brought into a sensible state according to the saved dataloader state. If this value is None,
the 'dataloader' and 'batch_idx_in_epoch' values need not be returned.
:param only_state_dict: has no effect when should_load_model is False. If True, what was saved is the weights; if False, what was
saved is the whole model, but the saved model's weights are still loaded through the current Driver's model rather than replacing the current model with the saved one.
:param should_load_model: whether the model should be loaded; if False, the Driver is not responsible for loading the model. If True
but no matching model state is found among the saved states, an error is raised.
:return: must return the states content that was passed into the save function;
:return: must return the states content that was passed into the save function, plus
'dataloader': the incoming dataloader brought into a sensible state together with the saved state; the returned object may be the
same one that was passed in. An error is raised when the number of data samples at save time does not match the current one.
'batch_idx_in_epoch': an int indicating which batch the current epoch has progressed to. Note that this value cannot simply be
read from the saved data, because the batch_size may have changed between the two runs. The number should obey the equation
'number of batches the returned dataloader will still produce' + 'batch_idx_in_epoch' = 'total number of batches of an uninterrupted run'.
Since 'number of batches the returned dataloader will still produce' cannot change once batch_size and drop_last are given, the
equation can only be satisfied by adjusting batch_idx_in_epoch. A simple way to compute it:
when drop_last is True, it equals floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size);
when drop_last is False, it equals ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size).
"""
raise NotImplementedError("Each specific driver should implement its own `load` function.")
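The batch_idx_in_epoch rule above can be checked with a few lines of standalone Python (sample_in_this_rank and num_left_samples are the quantities named in the docstring):

    import math

    def batch_idx_in_epoch(sample_in_this_rank, num_left_samples, batch_size, drop_last):
        # batches already consumed = total batches - batches still to come
        if drop_last:
            return sample_in_this_rank // batch_size - num_left_samples // batch_size
        return (math.ceil(sample_in_this_rank / batch_size)
                - math.ceil(num_left_samples / batch_size))

    # e.g. 10 samples on this rank, batch_size=3, 4 samples left:
    # drop_last=True  -> 3 - 1 = 2 batches already produced
    # drop_last=False -> 4 - 2 = 2 batches already produced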
@@ -352,6 +374,7 @@ class Driver(ABC):
"""
raise NotImplementedError("Each specific driver should implement its own `tensor_to_numeric` function.")
@abstractmethod
def set_model_mode(self, mode: str):
r"""
Sets the model to `train` / `eval` mode, i.e. switches between training and inference (which turns off dropout etc.);
@@ -378,6 +401,7 @@ class Driver(ABC):
we need to move the model to the cpu first and then back to the gpu, so it is not appropriate to call `unwrap_model` inside this function; instead, model is passed in as a parameter;
"""
@abstractmethod
def move_data_to_device(self, batch):
r"""
Moves the data to the specified device; batch may be a list or a dict, or a nested structure of these.
@@ -399,17 +423,6 @@ class Driver(ABC):
Only used in distributed training scenarios.
"""
@staticmethod
def get_dataloader_args(dataloader):
"""
Extracts the values of certain attributes from a dataloader; the returned dataclass must contain the following keys:
sampler, batch_sampler, batch_size, drop_last;
:param dataloader:
:return: a dataclass whose instance attributes should include each of the attributes above, under the same names, for convenient use by the trainer or other objects;
"""
raise NotImplementedError("Each specific driver should implement its own `get_dataloader_args` function.")
def is_distributed(self) -> bool:
"""
Whether the current driver instance is distributed;
@@ -70,7 +70,8 @@ class JittorMPIDriver(JittorDriver):
def test_step(self, batch):
return self._test_step(batch)
def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
reproducible: bool = False, sampler_or_batch_sampler=None):
pass
def backward(self, loss):
@@ -99,14 +99,15 @@ class JittorSingleDriver(JittorDriver):
def is_distributed(self):
return False
def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
reproducible: bool = False, sampler_or_batch_sampler=None):
# the reproducible-related functionality is not implemented yet
if isinstance(dist_sampler, ReproducibleBatchSampler):
if isinstance(dist, ReproducibleBatchSampler):
raise NotImplementedError
dataloader.batch_sampler = dist
if isinstance(dist_sampler, ReproducibleIterator):
if isinstance(dist, ReproducibleIterator):
raise NotImplementedError
dataloader.batch_sampler.sampler = dist
if reproducible:
raise NotImplementedError
@@ -8,7 +8,6 @@ from .utils import (
_FleetWrappingModel,
ForwardState,
_MODE_PARAMETER,
get_host_name_ip,
get_device_from_visible,
reset_seed,
)
@@ -81,9 +80,9 @@ class PaddleFleetDriver(PaddleDriver):
# if the user initialized the parallel model themselves outside;
self.outside_fleet = False
# check paddle's distributed environment variables
if parallel_helper._is_parallel_ctx_initialized():
# if the user initialized DDP themselves outside, we require that the model passed in has already been wrapped by DistributedDataParallel;
if parallel_helper._is_parallel_ctx_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \
"fastnlp_paddle_launch_not_fleet" not in os.environ:
# if the user initialized Fleet themselves outside, we require that the model passed in has already been wrapped by DataParallel;
if not isinstance(model, DataParallel):
raise RuntimeError(
"It is not allowed to input a normal model instead of `paddle.DataParallel` when"
@@ -125,11 +124,11 @@ class PaddleFleetDriver(PaddleDriver):
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
# when the parameter `device` is None and this parameter is not None, the corresponding data is moved to the specified device;
self._data_device = kwargs.get("_data_device", None)
self._data_device = kwargs.get("data_device", None)
if self._data_device is not None:
if isinstance(self._data_device, int):
if self._data_device < 0:
raise ValueError("Parameter `_data_device` can not be smaller than 0.")
raise ValueError("Parameter `data_device` can not be smaller than 0.")
_could_use_device_num = paddle.device.cuda.device_count()
if self._data_device >= _could_use_device_num:
raise ValueError("The gpu device that parameter `device` specifies does not exist.")
@@ -140,18 +139,6 @@ class PaddleFleetDriver(PaddleDriver):
logger.warning("`Parameter data_device` is not equal to paddle.device.get_device(), "
"please keep them equal to avoid some potential bugs.")
if not self.outside_fleet and parallel_device is None:
raise ValueError("Parameter `parallel_device` can not be None when using `PaddleFleetDriver`. This error is caused "
"when your value of parameter `device` is `None` in your `Trainer` instance.")
# may need to be moved into the parameters
self.strategy = kwargs.get("strategy", fleet.DistributedStrategy())
self.is_collective = kwargs.get("is_collective", True)
if not self.is_collective:
raise NotImplementedError("FastNLP does not support `parameter server` for distributed training now.")
self.role_maker = kwargs.get("role_maker", None)
self._master_port = None
self.world_size = None
self.global_rank = 0
self._configured = False  # prevents repeated calls to configure_ddp()
@@ -159,7 +146,11 @@ class PaddleFleetDriver(PaddleDriver):
self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {})
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__)
# TODO: validation of these parameters
self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy())
self.is_collective = self._fleet_kwargs.get("is_collective", True)
if not self.is_collective:
raise NotImplementedError("FastNLP only supports `collective` for distributed training now.")
self.role_maker = self._fleet_kwargs.get("role_maker", None)
if self.local_rank == 0 and not is_in_paddle_dist():
# since the model is always initialized when a driver is used, the program occupies some gpu memory for the model from the start, yet this memory has not
@@ -193,14 +184,16 @@ class PaddleFleetDriver(PaddleDriver):
self.world_size = int(os.environ.get("PADDLE_TRAINERS_NUM"))
self.global_rank = int(os.environ.get("PADDLE_TRAINER_ID"))
reset_seed()
logger.warning(f"\nworld size, global rank: {self.world_size}, {self.global_rank}\n")
fleet.init(self.role_maker, self.is_collective, self.strategy)
logger.info(f"\nworld size, global rank: {self.world_size}, {self.global_rank}\n")
if not parallel_helper._is_parallel_ctx_initialized():
fleet.init(self.role_maker, self.is_collective, self.strategy)
os.environ["fastnlp_paddle_launch_not_fleet"] = "yes"
else:
# in the case where the user uses only a single distributed trainer
# at this point parallel_helper._is_parallel_ctx_initialized() must be False
# parallel_device is a list,
# if self.local_rank == 0 and FASTNLP_DISTRIBUTED_CHECK not in os.environ:
if not parallel_helper._is_parallel_ctx_initialized():
# the distributed environment has not been initialized, and this is the main process
self.init_fleet_and_set()
@@ -212,11 +205,15 @@ class PaddleFleetDriver(PaddleDriver):
if sorted(pre_gpus) != sorted(self.parallel_device):
raise RuntimeError("Notice you are using `PaddleFleetDriver` after one instantiated `PaddleFleetDriver`, it is not"
"allowed that your second `PaddleFleetDriver` has a new setting of parameters `parallel_device`.")
self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
if not self.outside_fleet:
# self.model.to(self.model_device)
self.configure_fleet()
self.barrier()
# initialize self._pids so that every process can receive rank0's send operations;
# TODO: what happens if .to is not used?
self._pids = []
@@ -238,10 +235,10 @@ class PaddleFleetDriver(PaddleDriver):
"""
if self.local_rank == 0:
# if this is rank0, spawn the other child processes
print("in launcher")
launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc)
launcher.launch()
# set parameters and initialize the distributed environment
reset_seed()
fleet.init(self.role_maker, self.is_collective, self.strategy)
self.global_rank = int(os.getenv("PADDLE_TRAINER_ID"))
self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))
@@ -256,6 +253,7 @@ class PaddleFleetDriver(PaddleDriver):
When the user launches with `python -m paddle.distributed.launch xxx.py`, we need to
obtain the various attributes from the environment variables set by paddle
"""
print("set_from_env")
self.world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
@@ -297,8 +295,6 @@ class PaddleFleetDriver(PaddleDriver):
@property
def model_device(self):
# I think the two device properties here should return the real values; the conversion for CUDA_VISIBLE_DEVICES should happen in the corresponding `to` functions,
# otherwise users will be confused
return self._model_device
@property
@@ -316,13 +312,14 @@ class PaddleFleetDriver(PaddleDriver):
def test_step(self, batch):
return self._test_step(batch)
def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
reproducible: bool = False, sampler_or_batch_sampler=None):
# IterableDataset is not supported for now
assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IterableDataset` now."
if isinstance(dist_sampler, ReproducibleIterator):
dataloader.batch_sampler.sampler = dist_sampler
if isinstance(dist, ReproducibleIterator):
dataloader.batch_sampler.sampler = dist
return dataloader
# paddle's BatchSampler and DataLoader have no shuffle attribute; it can only be inferred from the sampler
@@ -334,14 +331,14 @@ class PaddleFleetDriver(PaddleDriver):
shuffle = dataloader.batch_sampler.shuffle
# trainer, evaluator
if dist_sampler is None:
if dist is None:
if reproducible:
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
"control.")
else:
return dataloader
# trainer
elif dist_sampler == "dist":
elif dist == "dist":
# if the user's trainer.use_dist_sampler is True, whether they are resuming from a checkpoint does not affect the behavior here;
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
dataloader.batch_sampler.sampler.set_distributed(
@@ -364,7 +361,7 @@ class PaddleFleetDriver(PaddleDriver):
dataloader.batch_sampler.sampler = sampler
return dataloader
# evaluator
elif dist_sampler == "unrepeatdist":
elif dist == "unrepeatdist":
sampler = UnrepeatedDistributedSampler(
dataset=dataloader.dataset,
shuffle=shuffle,
@@ -408,9 +405,8 @@ class PaddleFleetDriver(PaddleDriver):
def move_data_to_device(self, batch: 'paddle.Tensor'):
device = self.data_device
# since CUDA_VISIBLE_DEVICES is set, errors may occur in child processes
if FASTNLP_DISTRIBUTED_CHECK in os.environ:
device = get_device_from_visible(device)
# since CUDA_VISIBLE_DEVICES is set, errors may occur
device = get_device_from_visible(device)
return paddle_move_data_to_device(batch, device)
@staticmethod
@@ -7,7 +7,7 @@ from .single_device import PaddleSingleDriver
from .fleet import PaddleFleetDriver
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK
from fastNLP.core.utils import is_in_paddle_launch_dist
from fastNLP.core.log import logger
if _NEED_IMPORT_PADDLE:
@@ -26,13 +26,14 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
:return: a tuple whose first element is the concrete `Driver` instance and whose second element is the driver's name (used to
check that successive drivers appear in the correct order within one script);
"""
if "PADDLE_TRAINERS_NUM" in os.environ and "PADDLE_RANK_IN_NODE" in os.environ and FASTNLP_DISTRIBUTED_CHECK not in os.environ:
if is_in_paddle_launch_dist():
if device is not None:
logger.warning("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull "
"up your script. And we will directly get the local device via "
"`f'gpu:{os.environ['FLAGS_selected_gpus']}')`.")
device = [int(g) for g in os.environ["FLAGS_selected_gpus"].split(",")]
return PaddleFleetDriver(model, f"gpu:{os.environ['PADDLE_RANK_IN_NODE']}", True, **kwargs)
"and `os.environ['CUDA_VISIBLE_DEVICES']``.")
device = [int(g) for g in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
# TODO: currently one process corresponds to exactly one card, so an int is passed for now
return PaddleFleetDriver(model, device[0], True, **kwargs)
if driver not in {"paddle", "fleet"}:
raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].")
@@ -3,6 +3,7 @@ from typing import Optional, Dict, Union
from .paddle_driver import PaddleDriver
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.utils import (
auto_param_call,
get_paddle_gpu_str,
@@ -92,7 +93,12 @@ class PaddleSingleDriver(PaddleDriver):
self._test_signature_fn = model.forward
def setup(self):
os.environ["CUDA_VISIBLE_DEVICES"] = str(get_paddle_device_id(self.model_device))
user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES]
device_id = get_paddle_device_id(self.model_device)
if user_visible_devices is not None and user_visible_devices != "":
# not empty, meaning the user set CUDA_VISIBLE_DEVICES themselves
device_id = user_visible_devices.split(",")[device_id]
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
paddle.device.set_device("gpu:0")
self.model.to("gpu:0")
@@ -133,15 +139,16 @@ class PaddleSingleDriver(PaddleDriver):
"""
return paddle_move_data_to_device(batch, "gpu:0")
def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
reproducible: bool = False, sampler_or_batch_sampler=None):
# IterableDataset is not supported for now
assert dataloader.dataset_kind != _DatasetKind.ITER, \
"FastNLP does not support `IterableDataset` now."
if isinstance(dist_sampler, ReproducibleBatchSampler):
dataloader.batch_sampler = dist_sampler
if isinstance(dist, ReproducibleBatchSampler):
dataloader.batch_sampler = dist
return dataloader
if isinstance(dist_sampler, ReproducibleIterator):
dataloader.batch_sampler.sampler = dist_sampler
if isinstance(dist, ReproducibleIterator):
dataloader.batch_sampler.sampler = dist
return dataloader
if reproducible:
@@ -271,10 +271,10 @@ def get_device_from_visible(device: Union[str, int]):
return idx
else:
# use USER_CUDA_VISIBLE_DEVICES to get the device the user expects
user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if user_visiblde_devices is not None and user_visiblde_devices != "":
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if user_visible_devices is not None and user_visible_devices != "":
# not empty, meaning the user set CUDA_VISIBLE_DEVICES themselves
idx = user_visiblde_devices.split(",")[idx]
idx = user_visible_devices.split(",")[idx]
else:
idx = str(idx)
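Both of the hunks above implement the same index remapping; a standalone sketch of the idea (the literal env-var name below is an assumption — in the code it is the value of the USER_CUDA_VISIBLE_DEVICES constant):

    import os

    def remap_device_index(idx: int) -> str:
        # `idx` is a device index relative to the restricted CUDA_VISIBLE_DEVICES;
        # recover the id in the user's original numbering.
        user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES")  # e.g. "2,5,7"
        if user_visible_devices is not None and user_visible_devices != "":
            # not empty, meaning the user set CUDA_VISIBLE_DEVICES themselves
            return user_visible_devices.split(",")[idx]
        return str(idx)

    # with USER_CUDA_VISIBLE_DEVICES="2,5,7", remap_device_index(1) == "5"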
@@ -445,21 +445,22 @@ class TorchDDPDriver(TorchDriver):
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
return self._test_step(batch)
def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
if isinstance(dist_sampler, ReproducibleIterator):
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
reproducible: bool = False, sampler_or_batch_sampler=None):
if isinstance(dist, ReproducibleIterator):
# note that there is no need to call dist_sampler.set_distributed here; if the user is using TorchDDPDriver, that function was already called during Trainer initialization;
dist_sampler = re_instantiate_sampler(dist_sampler)
return replace_sampler(dataloader, dist_sampler)
dist = re_instantiate_sampler(dist)
return replace_sampler(dataloader, dist)
# trainer, evaluator
if dist_sampler is None:
if dist is None:
if reproducible:
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
"control.")
else:
return dataloader
# trainer
elif dist_sampler == "dist":
elif dist == "dist":
args = self.get_dataloader_args(dataloader)
# if the user's trainer.use_dist_sampler is True, whether they are resuming from a checkpoint does not affect the behavior here;
if isinstance(args.sampler, ReproducibleIterator):
@@ -485,7 +486,7 @@ class TorchDDPDriver(TorchDriver):
return replace_sampler(dataloader, sampler)
# evaluator
elif dist_sampler == "unrepeatdist":
elif dist == "unrepeatdist":
args = self.get_dataloader_args(dataloader)
sampler = UnrepeatedDistributedSampler(
dataset=args.dataset,
@@ -130,12 +130,12 @@ class TorchSingleDriver(TorchDriver):
else:
return self._test_step(batch)
def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
reproducible: bool = False):
if isinstance(dist_sampler, ReproducibleBatchSampler):
return replace_batch_sampler(dataloader, dist_sampler)
elif isinstance(dist_sampler, ReproducibleIterator):
return replace_sampler(dataloader, dist_sampler)
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
reproducible: bool = False, sampler_or_batch_sampler=None):
if isinstance(dist, ReproducibleBatchSampler):
return replace_batch_sampler(dataloader, dist)
elif isinstance(dist, ReproducibleIterator):
return replace_sampler(dataloader, dist)
if reproducible:
args = self.get_dataloader_args(dataloader)
@@ -50,6 +50,14 @@ class ReproducibleIterator:
class RandomSampler(ReproducibleIterator):
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
"""
:param dataset: a data container that implements the __len__ method
:param shuffle: whether to shuffle the order on each iteration.
:param seed: the random seed.
:param kwargs: not meant for users; used internally by fastNLP
"""
self.dataset = dataset
self.shuffle = shuffle
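A short, hypothetical usage sketch of this reproducible RandomSampler (the state_dict/load_state_dict round trip is the contract the checkpoint code earlier in this diff relies on; the exact internal iteration behavior is assumed):

    sampler = RandomSampler(dataset, shuffle=True, seed=0)
    it = iter(sampler)
    consumed = [next(it) for _ in range(4)]  # consume a few indices

    state = sampler.state_dict()             # checkpoint mid-epoch

    new_sampler = RandomSampler(dataset, shuffle=True, seed=0)
    new_sampler.load_state_dict(state)       # resumes after the consumed indices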
@@ -208,6 +216,15 @@ class RandomSampler(ReproducibleIterator):
class ReproducibleBatchSampler:
# the values of these two parameters should be obtained via the driver's get_dataloader_args function;
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
"""
A wrapper that makes the state of a batch_sampler object recoverable.
:param batch_sampler: an iterable that yields numbers or lists of numbers. ReproducibleBatchSampler first traverses this object
once, buffers the indices it yields, and then emits them as lists of indices in batches of size batch_size.
:param batch_size: the size of each batch.
:param drop_last: whether to drop the last batch if it cannot be filled with batch_size samples.
:param kwargs: used internally by fastNLP.
"""
self.batch_sampler = batch_sampler
self.batch_size = batch_size
self.drop_last = drop_last
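For example, wrapping an ordinary torch sampler might look like the following (a sketch under the docstring's assumptions; the wrapped object only needs to be an iterable of indices):

    from torch.utils.data import DataLoader, SequentialSampler

    dataset = list(range(10))
    batch_sampler = ReproducibleBatchSampler(
        batch_sampler=SequentialSampler(dataset),  # any iterable of indices works
        batch_size=4,
        drop_last=False,
    )
    # yields [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]; state_dict()/load_state_dict()
    # can checkpoint and restore the position between those batches
    dl = DataLoader(dataset, batch_sampler=batch_sampler)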
@@ -13,7 +13,7 @@ import re
from typing import Any, Optional, Union
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_BACKEND_LAUNCH
if _NEED_IMPORT_PADDLE:
import paddle
@@ -94,6 +94,4 @@ def is_in_paddle_launch_dist():
"""
Determine whether we are in a distributed process started by launch
"""
return 'PADDLE_RANK_IN_NODE' in os.environ and \
'FLAGS_selected_gpus' in os.environ and \
FASTNLP_DISTRIBUTED_CHECK not in os.environ
return FASTNLP_BACKEND_LAUNCH in os.environ
@@ -3,4 +3,4 @@ prettytable>=0.7.2
requests
regex!=2019.12.17
rich==11.2.0
# fsspec[http]>=2021.05.0, !=2021.06.0
packaging
@@ -1,12 +1,9 @@
import pytest
import sys
import os
import numpy as np
from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle
set_env_on_import_paddle()
set_env("paddle")
import paddle
import paddle.distributed as dist
from paddle.io import DataLoader
@@ -54,6 +51,7 @@ def test_move_data_to_device():
dist.barrier()
@magic_argv_env_context
def test_is_distributed():
print(os.getenv("CUDA_VISIBLE_DEVICES"))
@@ -64,6 +62,7 @@ def test_is_distributed():
driver = PaddleFleetDriver(
model=paddle_model,
parallel_device=[0,1],
output_from_new_proc='all'
)
driver.set_optimizers(paddle_opt)
# distinguish between launch and child-process setup
@@ -79,6 +78,7 @@ def test_is_distributed():
synchronize_safe_rm("log")
dist.barrier()
@magic_argv_env_context
def test_get_no_sync_context():
"""
@@ -105,6 +105,7 @@ def test_get_no_sync_context():
synchronize_safe_rm("log")
dist.barrier()
@magic_argv_env_context
def test_is_global_zero():
try:
@@ -128,6 +129,8 @@ def test_is_global_zero():
synchronize_safe_rm("log")
dist.barrier()
@magic_argv_env_context
def test_unwrap_model():
try:
@@ -204,7 +207,7 @@ def test_replace_sampler(dist_sampler, reproducible):
else:
driver.setup()
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
driver.replace_sampler(dataloader, dist_sampler, reproducible)
driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
finally:
synchronize_safe_rm("log")
dist.barrier()
@@ -243,7 +246,7 @@ class SingleMachineMultiGPUTrainingTestCase:
parallel_device=gpus,
)
driver.set_optimizers(paddle_opt)
dataloader = driver.replace_sampler(dataloader)
dataloader = driver.set_dist_repro_dataloader(dataloader)
driver.setup()
# check model_device
self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}")
@@ -164,4 +164,4 @@ class TestSingleDeviceFunction:
"""
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
res = self.driver.replace_sampler(dataloader, dist_sampler, reproducible)
res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
@@ -33,11 +33,15 @@ def check_replace_sampler(driver):
# dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler or ReproducibleBatchSampler
# reproducible is True or False
# need to check that both the returned sampler and the returned dataloader are new objects
assert driver.is_distributed() is False, "This test is only for non-distributed samplers."
ds = SequenceDataSet(10)
dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True)
dl1 = driver.replace_sampler(dataloader, dist_sampler='dist', reproducible=True)
dl1 = driver.set_dist_repro_dataloader(dataloader, dist='dist', reproducible=True)
assert not (dl1.sampler is dataloader.sampler), "The sampler should not be the same one."
assert not (dl1 is dataloader), "The dataloader should not be the same one."
# iterate two batches
already_seen_idx = set()
@@ -68,6 +72,22 @@ def check_replace_sampler(driver):
assert b not in already_seen_idx
assert b in left_idxes
# need to check that switching to unrepeatdist works: (1) no extra padding; (2) no overlap between cards
ds = SequenceDataSet(11)
dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True)
dl1 = driver.set_dist_repro_dataloader(dataloader, dist='unrepeatdist', reproducible=True)
world_size = 3
indices = []
for i in range(world_size):
dl1.sampler.set_distributed(num_replicas=world_size, rank=i)
for idx, batch in dl1:
indices.extend(batch)
assert len(indices)==len(ds)  # there should be no padding at all
assert len(set(indices))==len(indices)  # all indices should be distinct
@@ -15,7 +15,7 @@ class PaddleNormalDataset(Dataset):
return self._data[item]
class PaddleRandomDataset(Dataset):
class PaddleRandomMaxDataset(Dataset):
def __init__(self, num_samples, num_features):
self.x = paddle.randn((num_samples, num_features))
self.y = self.x.argmax(axis=-1)
@@ -25,23 +25,3 @@ class PaddleRandomDataset(Dataset):
def __getitem__(self, item):
return {"x": self.x[item], "y": self.y[item]}
class PaddleDataset_MNIST(Dataset):
def __init__(self, mode="train"):
self.dataset = [
(
np.array(img).astype('float32').reshape(-1),
label
) for img, label in paddle.vision.datasets.MNIST(mode=mode)
]
def __getitem__(self, idx):
return {"x": self.dataset[idx][0], "y": self.dataset[idx][1]}
def __len__(self):
return len(self.dataset)
@@ -1,12 +1,12 @@
import paddle
import paddle.nn as nn
class PaddleNormalModel_Classification(paddle.nn.Layer):
class PaddleNormalModel_Classification_1(paddle.nn.Layer):
"""
A basic paddle classification model
"""
def __init__(self, num_labels, feature_dimension):
super(PaddleNormalModel_Classification, self).__init__()
super(PaddleNormalModel_Classification_1, self).__init__()
self.num_labels = num_labels
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
@@ -30,3 +30,26 @@ class PaddleNormalModel_Classification(paddle.nn.Layer):
x = self(x)
return {"pred": x, "target": y.reshape((-1,))}
class PaddleNormalModel_Classification_2(paddle.nn.Layer):
"""
A basic paddle classification model that only implements the forward function, used to test the scenario where the user initializes the distributed environment themselves
"""
def __init__(self, num_labels, feature_dimension):
super(PaddleNormalModel_Classification_2, self).__init__()
self.num_labels = num_labels
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
self.ac1 = nn.ReLU()
self.linear2 = nn.Linear(in_features=64, out_features=32)
self.ac2 = nn.ReLU()
self.output = nn.Linear(in_features=32, out_features=num_labels)
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, x, y):
x = self.ac1(self.linear1(x))
x = self.ac2(self.linear2(x))
x = self.output(x)
loss = self.loss_fn(x, y)
return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))}