@@ -81,7 +81,7 @@ class LoadBestModelCallback(Callback):
                                 real_monitor=self._real_monitor,
                                 res=results)
         if (monitor_value < self.monitor_value and self.larger_better is False) or \
-                (monitor_value > self.monitor_value and self.larger_better):
+                (monitor_value > self.monitor_value and self.larger_better):
             self.monitor_value = monitor_value
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
@@ -124,11 +124,7 @@ class Evaluator:
         self.dataloaders = {}
         for name, dl in dataloaders.items():  # substitute the correct sampler
-            dl = self.driver.replace_sampler(
-                dataloader=dl,
-                dist_sampler=self._dist_sampler,
-                reproducible=False
-            )
+            dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False)
             self.dataloaders[name] = dl

         self.progress_bar = kwargs.get('progress_bar', 'auto')
@@ -250,11 +250,8 @@ class Trainer(TrainerEventTrigger):
             self.dataloader = self.train_dataloader
             self.driver.set_deterministic_dataloader(self.dataloader)

-            self.dataloader = self.driver.replace_sampler(
-                dataloader=self.train_dataloader,
-                dist_sampler=_dist_sampler,
-                reproducible=self.callback_manager.has_trainer_chechpoint
-            )
+            self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
+                                                                    reproducible=self.callback_manager.has_trainer_chechpoint)

         self.set_grad_to_none = kwargs.get("set_grad_to_none", True)
         self.on_after_trainer_initialized(self.driver)
@@ -263,7 +260,7 @@ class Trainer(TrainerEventTrigger):
     def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1,
             num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True,
-            catch_KeyboardInterrupt=True):
+            catch_KeyboardInterrupt=None):
        """
        Note that on the very first run of checkpoint retraining, i.e. before any checkpoint file has been saved, resume_from
        should be left as None, and ModelCheckpoint should be used to save the checkpoint files;
@@ -273,15 +270,17 @@ class Trainer(TrainerEventTrigger):
        :param resume_from: the path from which to restore the trainer's state
        :param resume_training: whether to restore the training state recorded in the checkpoint. If False, only the states
         of the model and the optimizers are restored.
        :param catch_KeyboardInterrupt: whether to catch KeyboardInterrupt; if it is caught, the exception is not re-raised
-        and the code after trainer.run() continues to run.
+        and the code after trainer.run() continues to run. By default a non-distributed driver catches it, while a
+        distributed one does not (it cannot).
        :return:
        """
-        if self.driver.is_distributed():
-            if catch_KeyboardInterrupt:
-                logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using a multi-device "
-                               "driver, so we are going to set it to False.")
-                catch_KeyboardInterrupt = False
+        if catch_KeyboardInterrupt is None:
+            catch_KeyboardInterrupt = not self.driver.is_distributed()
+        else:
+            if self.driver.is_distributed():
+                if catch_KeyboardInterrupt:
+                    logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using a multi-device "
+                                   "driver, so we are going to set it to False.")
+                    catch_KeyboardInterrupt = False

         self._set_num_eval_batch_per_dl(num_eval_batch_per_dl)
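With the new default above, the resolved behaviour can be illustrated as follows (a sketch; the surrounding Trainer construction is assumed):

    # with catch_KeyboardInterrupt left as None:
    #   non-distributed driver -> resolved to True  (Ctrl-C is swallowed, code after run() continues)
    #   distributed driver     -> resolved to False (it cannot be caught reliably across processes)
    trainer.run(resume_from=None, catch_KeyboardInterrupt=None)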
@@ -576,22 +575,6 @@ class Trainer(TrainerEventTrigger):
         else:
             states["val_filter_state"] = None

-        # 4. the state of the sampler; since we support resume training, i.e. restoring precisely to one specific batch:
-        # first, a pytorch DataLoader always has a sampler; on the other hand, during checkpoint retraining we always replace
-        # the dataloader's sampler with a `ReproducibleIterator` inside `replace_sampler`, or, on a single device, replace the
-        # batch_sampler with a `ReproducibleBatchSampler`;
-        dataloader_args = self.driver.get_dataloader_args(self.dataloader)
-        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
-            sampler = dataloader_args.batch_sampler
-        elif dataloader_args.sampler:
-            sampler = dataloader_args.sampler
-        else:
-            raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
-
-        if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
-            states['sampler_states'] = sampler.state_dict()
-        else:
-            raise RuntimeError(
-                'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
-
         if isinstance(folder, str):
             folder = Path(folder)
@@ -599,9 +582,9 @@ class Trainer(TrainerEventTrigger):
             if not callable(model_save_fn):
                 raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
             rank_zero_call(model_save_fn)(folder)
-            self.driver.save(folder=folder, states=states, should_save_model=False, **kwargs)
+            self.driver.save(folder=folder, dataloader=self.dataloader, states=states, should_save_model=False, **kwargs)
         else:
-            self.driver.save(folder=folder, states=states,
+            self.driver.save(folder=folder, dataloader=self.dataloader, states=states,
                              only_state_dict=only_state_dict, should_save_model=True, **kwargs)
         self.driver.barrier()
@@ -614,9 +597,6 @@ class Trainer(TrainerEventTrigger):
         saved; in that case the dataloader's sampler will not necessarily have been replaced with our ReproducibleIterator;

         note that we currently do not support checkpoint retraining from a single device to multiple devices;
-        TODO: note that we currently do not support checkpoint retraining across RandomSampler, BucketedSampler or
-        SortedSampler; so users who need a BucketedSampler must initialize it themselves before the Trainer and substitute
-        it for the original Dataloader's sampler, both on the first checkpoint run and on every subsequent resumed run;

         :param folder: the path of the checkpoint-retraining states to load;
         :param resume_training: whether to resume training from the last batch, or only from the most recent epoch; note that if resume_training=True, we
@@ -625,33 +605,23 @@ class Trainer(TrainerEventTrigger):
         self.driver.barrier()
         if isinstance(folder, str):
             folder = Path(folder)

+        dataloader = self.dataloader
+        if not resume_training:
+            dataloader = None
+
         if model_load_fn is not None:
             if not callable(model_load_fn):
-                raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
+                raise ValueError("Parameter `model_load_fn` should be `Callable`.")
             rank_zero_call(model_load_fn)(folder)
-            states = self.driver.load(folder=folder, should_load_model=False, **kwargs)
+            states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
         else:
-            states = self.driver.load(folder=folder, only_state_dict=only_state_dict, should_load_model=True, **kwargs)
+            states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)

         if not resume_training:
             return

-        # 1. restore the state of the sampler;
-        dataloader_args = self.driver.get_dataloader_args(self.dataloader)
-        sampler = dataloader_args.sampler
-        if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
-            # this means we have to fall back to a ReproducibleBatchSampler here
-            if self.driver.is_distributed():
-                raise RuntimeError("It is not allowed to use single device checkpoint retraining before but ddp now.")
-            sampler = ReproducibleBatchSampler(
-                batch_sampler=sampler,
-                batch_size=dataloader_args.batch_size,
-                drop_last=dataloader_args.drop_last
-            )
-        sampler.load_state_dict(states['sampler_states'])
-        self.driver.replace_sampler(self.dataloader, sampler)
+        self.dataloader = states.pop('dataloader')

         # 2. validate filter state;
         if self.evaluator is not None:
@@ -666,22 +636,16 @@ class Trainer(TrainerEventTrigger):

         # 4. adjust trainer_state.batch_idx_in_epoch
         # the sampler is a RandomSampler-like sampler, not a batch_sampler;
-        if not isinstance(sampler, ReproducibleBatchSampler):
-            if dataloader_args.drop_last:
-                self.trainer_state.batch_idx_in_epoch = len(sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
-            else:
-                self.trainer_state.batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
-                    (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
-        # the sampler is a batch_sampler;
-        else:
-            self.trainer_state.batch_idx_in_epoch = sampler.batch_idx_in_epoch
+        # the principle here is that 'the number of batches still to be produced' + 'batch_idx_in_epoch' must equal 'the
+        # total number of batches of an uninterrupted run'. Since 'the number of batches still to be produced' is determined
+        # by how many samples are left, the equality can only be satisfied by adjusting 'batch_idx_in_epoch'
+        self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')

         # 5. restore the state of every callback;
         self.on_load_checkpoint(states["callback_states"])

         self.driver.barrier()

-    """ these four functions make it convenient for users to customize their own batch_step_fn (to replace the step function inside train_batch_loop) """
+    """ these four functions make it convenient for users to customize their own batch_step_fn (to replace the batch_step_fn function inside train_batch_loop) """

     def train_step(self, batch):
         with self.driver.auto_cast():
@@ -2,7 +2,7 @@ import os
 import signal
 import sys
 from typing import Any, Sequence, List, Optional, Callable, Dict, Union
-from abc import ABC
+from abc import ABC, abstractmethod
 from datetime import datetime
 from pathlib import Path
 from io import BytesIO
@@ -14,7 +14,6 @@ __all__ = [
 from fastNLP.core.utils import nullcontext

-# todo 航总: check which of these methods need @abstractmethod;
 class Driver(ABC):
     r"""
     The base class used to initialize a `Driver`; every custom `driver` must inherit from this class;
@@ -32,29 +31,33 @@ class Driver(ABC):
         # self._consensus_file: Optional[Union[str, Path]] = None
         self._pids: Optional[List[int]] = None

+    @abstractmethod
     def setup(self):
         r"""
         This function initializes the training environment, e.g. it moves the model to the corresponding device;
         for a multi-device driver this function is more involved: it may need to open the communication environment between
         processes and set some environment variables and other required values;
         """

-    def replace_sampler(self, dataloader, dist_sampler: Optional[str], reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist=None, reproducible: bool = False):
         r"""
-        In some special situations the dataloader's sampler has to be replaced, and this function in every driver provides
-        that capability; for example, in multi-device training we need to replace the sampler with a distributed sampler,
-        and if the user adds a checkpoint-retraining callback to the Trainer, we need to replace the sampler with a
-        reproducible sampler;
-
-        :param dist_sampler: should be a string, one of [None, "dist", "unrepeatdist"], specifying which sampler to use;
-         currently this parameter is tailored to distributed training: when the trainer kwarg `use_dist_sampler` is True this
-         value is "dist", otherwise None; when the evaluator kwarg `use_dist_sampler` is True this value is "unrepeatdist",
-         otherwise None;
-        :param reproducible: used inside the `Trainer` to indicate whether to substitute a checkpoint-retraining sampler
-         (multi-device) or batch_sampler (single device); if the Driver is single-device and this parameter is True, meaning
-         checkpoint retraining is in progress, we replace the dataloader's original batch_sampler with our
-         `ReproducibleBatchSampler`; if the Driver is multi-device, we replace the dataloader's original sampler with a
-         `RandomSampler`;
-        :return: should return a new dataloader object with its sampler replaced (note that a new dataloader object must be
-         returned here);
-        """
-        raise NotImplementedError("Each specific driver should implement its own `replace_sampler` function.")
+        Derives from the given dataloader a dataloader that is distributed and reproducible.

+        :param dataloader: the dataloader whose distributed and reproducible counterpart is to be configured
+        :param dist: should be a string, one of [None, "dist", "unrepeatdist"]; None means the dataloader does not need to be
+         switched to a distributed state; "dist" means the dataloader must guarantee that every gpu receives the same number
+         of batches, allowing a small number of samples to be duplicated across gpus; "unrepeatdist" means the data iterated
+         on all gpus, merged together, must be exactly the original data, allowing different gpus to produce different
+         numbers of batches. When the trainer kwarg `use_dist_sampler` is True this value is "dist", otherwise None;
+         when the evaluator kwarg `use_dist_sampler` is True this value is "unrepeatdist", otherwise None;
+        :param reproducible: if False, nothing needs to be considered; if True, the returned dataloader must be able to save
+         its current iteration state, so that it can be restored later.
+        :return: should return a new dataloader object with its sampler replaced (note that a new dataloader object must be
+         returned here); moreover, if the given dataloader contains a ReproducibleIterator or a ReproducibleBatchSampler, a
+         re-initialized one must be placed into the returned dataloader. If dist is None and reproducible is False, the
+         original object may be returned directly.
+        """
+        if dist is None and reproducible is False:
+            return dataloader
+        raise NotImplementedError(f"Driver:{self.__class__.__name__} does not support `set_dist_repro_dataloader` "
+                                  f"function.")
     def set_deterministic_dataloader(self, dataloader):
         r"""
@@ -68,7 +71,7 @@ class Driver(ABC):
         :param cur_epoch_idx: the index of the current epoch;
         """

+    @abstractmethod
     def train_step(self, batch):
         """
         Implements the forward pass of training by calling the model's own `train_step` or `forward` method;
@@ -103,7 +106,7 @@ class Driver(ABC):
         so if the user's evaluator mode is validate but the model passed in implements a test_step function instead of a
         validate_step function, we should warn the user about this behaviour;
         """
-        raise NotImplementedError("Each specific driver should implement its own `predict_step` function.")
+        raise NotImplementedError("Each specific driver should implement its own `check_evaluator_mode` function.")

     @property
     def model(self):
@@ -234,6 +237,7 @@ class Driver(ABC):
         """
         self.optimizers = optimizers

+    @abstractmethod
     def backward(self, loss):
         """
         Implements the backward pass of deep learning;
@@ -242,12 +246,14 @@ class Driver(ABC):
         """
         raise NotImplementedError("Each specific driver should implement its own `backward` function.")

+    @abstractmethod
     def step(self):
         r"""
         Implements the parameter update of deep learning; the parameters should be updated directly through the optimizers;
         """
         raise NotImplementedError("Each specific driver should implement its own `step` function.")

+    @abstractmethod
     def zero_grad(self, set_to_none: bool = False):
         r"""
         Implements zeroing the gradients; the gradients should be zeroed directly through the optimizers;
@@ -286,6 +292,7 @@ class Driver(ABC):
     def auto_cast(self, auto_cast):
         self._auto_cast = auto_cast

+    @abstractmethod
     def save_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = True, **kwargs):
         r"""
         Saves the model; note that the function used for checkpoint retraining is `save`;
@@ -296,6 +303,7 @@ class Driver(ABC):
         """
         raise NotImplementedError("Each specific driver should implement its own `save_model` function.")

+    @abstractmethod
     def load_model(self, filepath: Union[str, Path, BytesIO], only_state_dict: bool = False, **kwargs):
         r"""
         Loads the model; loads the model from filepath and assigns it to the current model.
@@ -307,7 +315,8 @@ class Driver(ABC):
         """
         raise NotImplementedError("Each specific driver should implement its own `load_model` function.")

-    def save(self, folder, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+    @abstractmethod
+    def save(self, folder, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
         r"""
         The save function for checkpoint retraining; it is responsible for saving the state_dicts of the optimizers and
         fp16, as well as for saving the model (if should_save_model is True)
@@ -317,12 +326,14 @@ class Driver(ABC):
         :param states: a dict passed in by the trainer that already contains the states of every other object that must be
         saved for checkpoint retraining; the Driver only needs to save this object and does not need to understand it; at
        driver.load() time the states must be handed back, and the value returned by load() must stay identical to the one
        passed in here.
+        :param dataloader: the dataloader currently in use; its state must be saved so that training can later resume from
+         the current iteration position.
         :param only_state_dict: whether to save only the model's parameters; ignored when should_save_model is False.
         :param should_save_model: whether the model should be saved; if False, the Driver is not responsible for saving the model.
         """
         raise NotImplementedError("Each specific driver should implement its own `save` function.")
-    def load(self, folder: Union[str, Path], only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
+    @abstractmethod
+    def load(self, folder: Union[str, Path], dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
         r"""
         The load function for checkpoint retraining; note that this function is responsible for reading the data and
         restoring the state_dicts of the optimizers and fp16, the model (depending on should_load_model), and everything
         else saved in Driver.save(), and then returning a state dict to the trainer (its content being the states that
         Driver.save() received).
@@ -331,11 +342,22 @@ class Driver(ABC):
         :param folder: reads the FASTNLP_CHECKPOINT_FILENAME file under this folder, and the FASTNLP_MODEL_FILENAME file
         (if should_load_model is True).
+        :param dataloader: the currently given dataloader; it must be configured sensibly according to the saved dataloader
+         state. If this value is None, the 'dataloader' and 'batch_idx_in_epoch' entries need not be returned.
         :param only_state_dict: ignored when should_save_model is False. If True, what was saved are the weights; if
         False, the whole model was saved, but the saved model's weights are still loaded through the current Driver's model
        rather than replacing the current model with the saved one.
         :param should_load_model: whether the model should be loaded; if False, the Driver is not responsible for loading
        the model. If this parameter is True but no matching model state is found in the saved state, an error is raised.
-        :return: must return the states content that was passed into the save function;
+        :return: must return the states content that was passed into the save function
+         'dataloader': the given dataloader, configured into a sensible state together with the saved state; the returned
+          object may be the same object as the dataloader that was passed in. An error is raised if the number of data
+          samples at save time does not match the current one.
+         'batch_idx_in_epoch': an int indicating which batch the current epoch has progressed to. Note that this value
+          cannot simply be read from the saved data, because the batch_size may have changed between the two runs. The
+          number must satisfy the equation
+          'number of batches the returned dataloader will still produce' + 'batch_idx_in_epoch' = 'total number of batches
+          of an uninterrupted run'.
+          Since 'the number of batches the returned dataloader will still produce' cannot be changed once batch_size and
+          drop_last are given, the equation can only be satisfied by adjusting batch_idx_in_epoch. A simple rule of
+          computation:
+          when drop_last is True, it equals floor(sample_in_this_rank/batch_size) - floor(num_left_samples/batch_size);
+          when drop_last is False, it equals ceil(sample_in_this_rank/batch_size) - ceil(num_left_samples/batch_size).
         """
         raise NotImplementedError("Each specific driver should implement its own `load` function.")
@@ -352,6 +374,7 @@ class Driver(ABC):
         """
         raise NotImplementedError("Each specific driver should implement its own `tensor_to_numeric` function.")

+    @abstractmethod
     def set_model_mode(self, mode: str):
         r"""
         Puts the model into `train` / `eval` mode, in order to switch the model between training and inference (which
         disables dropout etc.);
@@ -378,6 +401,7 @@ class Driver(ABC):
         we need to move the model to cpu first and then back to gpu, so it is not appropriate to call `unwrap_model` inside
         this function; the model is passed in as a parameter instead;
         """

+    @abstractmethod
     def move_data_to_device(self, batch):
         r"""
         Moves the data to the specified machine; batch may be a list or a dict, or a nested structure of them.
@@ -399,17 +423,6 @@ class Driver(ABC):
         Used only in multi-process distributed training scenarios.
         """

-    @staticmethod
-    def get_dataloader_args(dataloader):
-        """
-        Extracts the values of several attributes from a dataloader; the returned dataclass must contain the keys:
-        sampler, batch_sampler, batch_size, drop_last;
-
-        :param dataloader:
-        :return: returns a dataclass whose instance attributes include all the attributes above, under the same names, so
-        that the trainer and other objects can access them conveniently;
-        """
-        raise NotImplementedError("Each specific driver should implement its own `get_dataloader_args` function.")

     def is_distributed(self) -> bool:
         """
         Whether the current driver instance is distributed;
@@ -70,7 +70,8 @@ class JittorMPIDriver(JittorDriver):
     def test_step(self, batch):
         return self._test_step(batch)

-    def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
         pass

     def backward(self, loss):
@@ -99,14 +99,15 @@ class JittorSingleDriver(JittorDriver):
     def is_distributed(self):
         return False

-    def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
         # the reproducible-related features are not implemented yet
-        if isinstance(dist_sampler, ReproducibleBatchSampler):
+        if isinstance(dist, ReproducibleBatchSampler):
             raise NotImplementedError
             dataloader.batch_sampler = dist
-        if isinstance(dist_sampler, ReproducibleIterator):
+        if isinstance(dist, ReproducibleIterator):
             raise NotImplementedError
-            dataloader.batch_sampler.sampler = dist_sampler
+            dataloader.batch_sampler.sampler = dist
         if reproducible:
             raise NotImplementedError
@@ -8,7 +8,6 @@ from .utils import (
     _FleetWrappingModel,
     ForwardState,
     _MODE_PARAMETER,
-    get_host_name_ip,
     get_device_from_visible,
     reset_seed,
 )
@@ -81,9 +80,9 @@ class PaddleFleetDriver(PaddleDriver):
         # if the user initialized the parallel model outside by themselves;
         self.outside_fleet = False
-        # detect paddle's distributed environment variables
-        if parallel_helper._is_parallel_ctx_initialized():
-            # if the user initialized DDP outside by themselves, we require the model they pass in to be one already wrapped by DistributedDataParallel;
+        if parallel_helper._is_parallel_ctx_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \
+                "fastnlp_paddle_launch_not_fleet" not in os.environ:
+            # if the user initialized Fleet outside by themselves, we require the model they pass in to be one already wrapped by DistributedDataParallel;
             if not isinstance(model, DataParallel):
                 raise RuntimeError(
                     "It is not allowed to input a normal model instead of `paddle.DataParallel` when"
@@ -125,11 +124,11 @@ class PaddleFleetDriver(PaddleDriver):
             self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)

         # when the parameter `device` is None and this parameter is not None, the data is moved to the specified machine;
-        self._data_device = kwargs.get("_data_device", None)
+        self._data_device = kwargs.get("data_device", None)
         if self._data_device is not None:
             if isinstance(self._data_device, int):
                 if self._data_device < 0:
-                    raise ValueError("Parameter `_data_device` can not be smaller than 0.")
+                    raise ValueError("Parameter `data_device` can not be smaller than 0.")
                 _could_use_device_num = paddle.device.cuda.device_count()
                 if self._data_device >= _could_use_device_num:
                     raise ValueError("The gpu device that parameter `device` specifies does not exist.")
@@ -140,18 +139,6 @@ class PaddleFleetDriver(PaddleDriver):
                 logger.warning("Parameter `data_device` is not equal to paddle.device.get_device(), "
                                "please keep them equal to avoid some potential bugs.")

-        if not self.outside_fleet and parallel_device is None:
-            raise ValueError("Parameter `parallel_device` can not be None when using `PaddleFleetDriver`. This error is caused "
-                             "when your value of parameter `device` is `None` in your `Trainer` instance.")
-
-        # may need to be moved into the parameters
-        self.strategy = kwargs.get("strategy", fleet.DistributedStrategy())
-        self.is_collective = kwargs.get("is_collective", True)
-        if not self.is_collective:
-            raise NotImplementedError("FastNLP does not support `parameters server` for distributed training now.")
-        self.role_maker = kwargs.get("role_maker", None)
-
-        self._master_port = None
         self.world_size = None
         self.global_rank = 0
         self._configured = False  # prevents configure_ddp() from being called repeatedly
@@ -159,7 +146,11 @@ class PaddleFleetDriver(PaddleDriver):
         self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {})
         check_user_specific_params(self._fleet_kwargs, DataParallel.__init__)
+        # TODO validate these parameters
+        self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy())
+        self.is_collective = self._fleet_kwargs.get("is_collective", True)
+        if not self.is_collective:
+            raise NotImplementedError("FastNLP only supports `collective` for distributed training now.")
+        self.role_maker = self._fleet_kwargs.get("role_maker", None)

         if self.local_rank == 0 and not is_in_paddle_dist():
             # since the model is always initialized when the driver is used, the program occupies some gpu memory for the model from the very start; however, this memory has not
self.world_size = int(os.environ.get("PADDLE_TRAINERS_NUM")) | self.world_size = int(os.environ.get("PADDLE_TRAINERS_NUM")) | ||||
self.global_rank = int(os.environ.get("PADDLE_TRAINER_ID")) | self.global_rank = int(os.environ.get("PADDLE_TRAINER_ID")) | ||||
reset_seed() | reset_seed() | ||||
logger.warning(f"\nworld size, global rank: {self.world_size}, {self.global_rank}\n") | |||||
fleet.init(self.role_maker, self.is_collective, self.strategy) | |||||
logger.info(f"\nworld size, global rank: {self.world_size}, {self.global_rank}\n") | |||||
if not parallel_helper._is_parallel_ctx_initialized(): | |||||
fleet.init(self.role_maker, self.is_collective, self.strategy) | |||||
os.environ["fastnlp_paddle_launch_not_fleet"] = "yes" | |||||
else: | else: | ||||
# 在用户只使用了一个分布式 trainer 的情况下 | # 在用户只使用了一个分布式 trainer 的情况下 | ||||
# 此时 parallel_helper._is_parallel_ctx_initialized() 一定为 False | # 此时 parallel_helper._is_parallel_ctx_initialized() 一定为 False | ||||
# parallel_device 是 list, | # parallel_device 是 list, | ||||
# if self.local_rank == 0 and FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
if not parallel_helper._is_parallel_ctx_initialized(): | if not parallel_helper._is_parallel_ctx_initialized(): | ||||
# 没有初始化分布式环境,且是主进程 | # 没有初始化分布式环境,且是主进程 | ||||
self.init_fleet_and_set() | self.init_fleet_and_set() | ||||
@@ -212,11 +205,15 @@ class PaddleFleetDriver(PaddleDriver):
             if sorted(pre_gpus) != sorted(self.parallel_device):
                 raise RuntimeError("Notice you are using `PaddleFleetDriver` after one instantiated `PaddleFleetDriver`, it is not"
                                    "allowed that your second `PaddleFleetDriver` has a new setting of parameters `parallel_device`.")

+        self.world_size = dist.get_world_size()
+        self.global_rank = dist.get_rank()
+
         if not self.outside_fleet:
             # self.model.to(self.model_device)
             self.configure_fleet()

+        self.barrier()
+
         # initialize self._pids so that every process can receive rank0's send operation;
         # TODO what happens without .to?
         self._pids = []
@@ -238,10 +235,10 @@ class PaddleFleetDriver(PaddleDriver):
         """
         if self.local_rank == 0:
             # rank0 pulls up the other child processes
-            print("in launcher")
             launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc)
             launcher.launch()
         # set the parameters and initialize the distributed environment
+        reset_seed()
         fleet.init(self.role_maker, self.is_collective, self.strategy)
         self.global_rank = int(os.getenv("PADDLE_TRAINER_ID"))
         self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))
@@ -256,6 +253,7 @@ class PaddleFleetDriver(PaddleDriver):
         When the user launches with `python -m paddle.distributed.launch xxx.py`, we need to obtain
         the various attributes from the environment variables that paddle has set
         """
+        print("set_from_env")
         self.world_size = dist.get_world_size()
         self.global_rank = dist.get_rank()
@@ -297,8 +295,6 @@ class PaddleFleetDriver(PaddleDriver):
     @property
     def model_device(self):
-        # I think these two devices should return the real values; the conversion for CUDA_VISIBLE_DEVICES should happen
-        # in the corresponding `to` functions, otherwise users will be confused
         return self._model_device

     @property
@@ -316,13 +312,14 @@ class PaddleFleetDriver(PaddleDriver):
     def test_step(self, batch):
         return self._test_step(batch)

-    def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False):
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
         # IterableDataset is not supported for now
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."
-        if isinstance(dist_sampler, ReproducibleIterator):
-            dataloader.batch_sampler.sampler = dist_sampler
+        if isinstance(dist, ReproducibleIterator):
+            dataloader.batch_sampler.sampler = dist
             return dataloader

         # paddle's BatchSampler and DataLoader have no shuffle member; we can only infer it from the sampler
@@ -334,14 +331,14 @@ class PaddleFleetDriver(PaddleDriver):
             shuffle = dataloader.batch_sampler.shuffle

         # trainer, evaluator
-        if dist_sampler is None:
+        if dist is None:
             if reproducible:
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
                                    "control.")
             else:
                 return dataloader
         # trainer
-        elif dist_sampler == "dist":
+        elif dist == "dist":
             # if the user's trainer.use_dist_sampler is True, whether checkpoint retraining is in progress does not affect the behaviour here;
             if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
                 dataloader.batch_sampler.sampler.set_distributed(
@@ -364,7 +361,7 @@ class PaddleFleetDriver(PaddleDriver):
             dataloader.batch_sampler.sampler = sampler
             return dataloader
         # evaluator
-        elif dist_sampler == "unrepeatdist":
+        elif dist == "unrepeatdist":
             sampler = UnrepeatedDistributedSampler(
                 dataset=dataloader.dataset,
                 shuffle=shuffle,
@@ -408,9 +405,8 @@ class PaddleFleetDriver(PaddleDriver):
     def move_data_to_device(self, batch: 'paddle.Tensor'):
         device = self.data_device
-        # since CUDA_VISIBLE_DEVICES is set, errors may occur in child processes
-        if FASTNLP_DISTRIBUTED_CHECK in os.environ:
-            device = get_device_from_visible(device)
+        # since CUDA_VISIBLE_DEVICES is set, errors may occur
+        device = get_device_from_visible(device)
         return paddle_move_data_to_device(batch, device)

     @staticmethod
@@ -7,7 +7,7 @@ from .single_device import PaddleSingleDriver
 from .fleet import PaddleFleetDriver

 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
-from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK
+from fastNLP.core.utils import is_in_paddle_launch_dist
 from fastNLP.core.log import logger

 if _NEED_IMPORT_PADDLE:
@@ -26,13 +26,14 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
     :return: returns a tuple whose first element is the concrete paddle-based `Driver` instance and whose second element is
     the driver's name (used to detect ordering problems between successive drivers within one script);
     """
-    if "PADDLE_TRAINERS_NUM" in os.environ and "PADDLE_RANK_IN_NODE" in os.environ and FASTNLP_DISTRIBUTED_CHECK not in os.environ:
+    if is_in_paddle_launch_dist():
         if device is not None:
             logger.warning("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull "
                            "up your script. And we will directly get the local device via "
-                           "`f'gpu:{os.environ['FLAGS_selected_gpus']}')`.")
-        device = [int(g) for g in os.environ["FLAGS_selected_gpus"].split(",")]
-        return PaddleFleetDriver(model, f"gpu:{os.environ['PADDLE_RANK_IN_NODE']}", True, **kwargs)
+                           "and `os.environ['CUDA_VISIBLE_DEVICES']``.")
+        device = [int(g) for g in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
+        # TODO one process currently maps to exactly one card, so pass a single int for now
+        return PaddleFleetDriver(model, device[0], True, **kwargs)
if driver not in {"paddle", "fleet"}: | if driver not in {"paddle", "fleet"}: | ||||
raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].") | raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].") | ||||
@@ -42,7 +43,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
# 优先级 user > cuda | # 优先级 user > cuda | ||||
# 判断单机情况 device 的合法性 | # 判断单机情况 device 的合法性 | ||||
# 分布式情况下通过 world_device 判断 | # 分布式情况下通过 world_device 判断 | ||||
if user_visible_devices is not None: | |||||
if user_visible_devices != "": | |||||
_could_use_device_num = len(user_visible_devices.split(",")) | _could_use_device_num = len(user_visible_devices.split(",")) | ||||
elif cuda_visible_devices is not None: | elif cuda_visible_devices is not None: | ||||
_could_use_device_num = len(cuda_visible_devices.split(",")) | _could_use_device_num = len(cuda_visible_devices.split(",")) | ||||
@@ -51,8 +52,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
if isinstance(device, int): | if isinstance(device, int): | ||||
if device < 0 and device != -1: | if device < 0 and device != -1: | ||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | ||||
if device >= _could_use_device_num: | |||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
# if device >= _could_use_device_num: | |||||
# raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
device = f"gpu:{device}" | device = f"gpu:{device}" | ||||
elif isinstance(device, Sequence) and not isinstance(device, str): | elif isinstance(device, Sequence) and not isinstance(device, str): | ||||
device = list(set(device)) | device = list(set(device)) | ||||
@@ -1,8 +1,15 @@
+import os
 from typing import Optional, Dict, Union

 from .paddle_driver import PaddleDriver
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
-from fastNLP.core.utils import auto_param_call, get_paddle_gpu_str
+from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
+from fastNLP.core.utils import (
+    auto_param_call,
+    get_paddle_gpu_str,
+    get_paddle_device_id,
+    paddle_move_data_to_device,
+)
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
 from fastNLP.core.log import logger
@@ -86,8 +93,14 @@ class PaddleSingleDriver(PaddleDriver):
             self._test_signature_fn = model.forward

     def setup(self):
-        paddle.device.set_device(self.model_device)
-        self.model.to(self.model_device)
+        user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES]
+        device_id = get_paddle_device_id(self.model_device)
+        if user_visible_devices is not None and user_visible_devices != "":
+            # not empty, meaning the user has set CUDA_VISIBLE_DEVICES
+            device_id = user_visible_devices.split(",")[device_id]
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
+        paddle.device.set_device("gpu:0")
+        self.model.to("gpu:0")
     def train_step(self, batch) -> Dict:
         # if the batch is a Dict we do parameter matching for it by default; otherwise it is passed straight into the `train_step` function for the user to handle themselves;
@@ -116,15 +129,26 @@ class PaddleSingleDriver(PaddleDriver):
         else:
             return self._test_step(batch)

-    def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
+    def move_data_to_device(self, batch: 'paddle.Tensor'):
+        r"""
+        Moves the data to the specified machine; batch may be a list or a dict, or a nested structure of them.
+        In Paddle this may cause problems when the target differs from the configured device; please take note.
+        On a single card, since CUDA_VISIBLE_DEVICES is always restricted to one device, the data is in fact only ever moved to `gpu:0`
+
+        :return: returns the batch object moved to the specified machine;
+        """
+        return paddle_move_data_to_device(batch, "gpu:0")
+
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
+                                  reproducible: bool = False, sampler_or_batch_sampler=None):
         # IteratorDataset is not supported for now
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."
-        if isinstance(dist_sampler, ReproducibleBatchSampler):
-            dataloader.batch_sampler = dist_sampler
+        if isinstance(dist, ReproducibleBatchSampler):
+            dataloader.batch_sampler = dist
             return dataloader
-        if isinstance(dist_sampler, ReproducibleIterator):
-            dataloader.batch_sampler.sampler = dist_sampler
+        if isinstance(dist, ReproducibleIterator):
+            dataloader.batch_sampler.sampler = dist
             return dataloader

         if reproducible:
@@ -271,10 +271,10 @@ def get_device_from_visible(device: Union[str, int]):
         return idx
     else:
         # use USER_CUDA_VISIBLE_DEVICES to obtain the device the user expects
-        user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
-        if user_visiblde_devices is None or user_visiblde_devices != "":
+        user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
+        if user_visible_devices is not None and user_visible_devices != "":
             # not empty, meaning the user has set CUDA_VISIBLE_DEVICES
-            idx = user_visiblde_devices.split(",")[idx]
+            idx = user_visible_devices.split(",")[idx]
         else:
             idx = str(idx)
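For example (the values and the literal env-var name are hypothetical; the code computing `idx` sits above this hunk):

    import os
    os.environ["USER_CUDA_VISIBLE_DEVICES"] = "3,5,7"   # what the user originally exported
    idx = 1                                             # computed above this hunk
    user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES")
    if user_visible_devices is not None and user_visible_devices != "":
        idx = user_visible_devices.split(",")[idx]      # -> "5", the card the user meant
    else:
        idx = str(idx)
    # the old test `is None or != ""` took the first branch even when the variable was
    # unset, crashing on None.split(","); the corrected test avoids that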
@@ -7,7 +7,6 @@ from time import sleep
 from typing import List, Optional, Union, Dict
 from functools import partial

-# todo change this to `from fastNLP.env import` once everyone's __all__ is done;
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 if _NEED_IMPORT_TORCH:
     import torch
@@ -44,20 +43,128 @@ class TorchDDPDriver(TorchDriver):
             fp16: bool = False,
             **kwargs
     ):
-        """
-        The three launch modes DDP currently considers supporting:
-        1. the user performs no ddp operations themselves and directly uses our Trainer, running only a single main script;
-           here we use open_subprocesses ourselves to pull up multiple processes, and TorchDDPDriver calls init_process_group itself;
-        2. otherwise the same as 1, but the user pulls the processes up themselves with python -m torch.distributed.launch;
-        3. the user initializes DDP outside by themselves, and pulls the processes up with python -m torch.distributed.launch;
-        note that multi-machine launching strictly requires the user to launch with python -m torch.distributed.launch on every machine;
-        if the user initialized ddp outside by themselves, then
-            parallel_device is None;
-            data_device is a parameter denoting a single card;
-            dist.is_initialized is true;
+        r"""
+        The three launch modes `TorchDDPDriver` currently supports:
+        1. the user performs no ddp operations themselves and directly uses our Trainer; here we use `open_subprocesses`
+           ourselves to pull up multiple processes, and then `TorchDDPDriver` initializes ddp's communication group itself
+           by calling `dist.init_process_group`; (case A)
+        2. the user likewise does not initialize ddp outside the Trainer, but pulls the multiple processes up themselves
+           with python -m torch.distributed.launch; we still initialize ddp's communication group by calling
+           `dist.init_process_group`; (case B)
+        3. the user initializes DDP outside by themselves and pulls the processes up with python -m torch.distributed.launch;
+           here both pulling up the processes and establishing ddp's communication group are handled by the user, and at
+           driver.setup time we only set a few necessary attribute values on `TorchDDPDriver`; (case C)
+
+        Note that multi-machine launching strictly requires the user to launch with python -m torch.distributed.launch on
+        every machine; we therefore never store in `TorchDDPDriver` any information about how many machines there currently
+        are (num_nodes, not the number of gpus);
+
+        Part 1: a detailed analysis of the three launch modes:
+        (1) when `driver.setup` is called only once in the user's script (meaning the script initializes only one
+         trainer/evaluator), what `TorchDDPDriver` does during initialization and inside the `setup` function is as follows:
+
+            -> case A: here the model the user passes in must be a plain model (a model not wrapped by
+             `DistributedDataParallel`), because using `DistributedDataParallel` requires init_process_group to have been
+             called already to establish the current ddp communication group; but this implies that if the user needs to use
+             2 or more gpus that way, they necessarily have to launch with torch.distributed.launch, which means it is no
+             longer case A;
+             here we first call the `TorchDDPDriver.open_subprocess` function to pull up multiple processes, where the
+             number of processes equals the number of gpus the user passed to the trainer (e.g. with the `Trainer` parameter
+             device=[0, 1, 6, 7] we use gpus 0, 1, 6 and 7 to pull up 4 processes);
+             next we call `dist.init_process_group` to initialize the communication group between the processes;
+             note here that each newly pulled-up process runs the user's launch script (e.g. main.py) in full from start to
+             end, so it runs both of these functions as well, but note that only process 0 actually runs
+             `TorchDDPDriver.open_subprocess`; when process 0 reaches `dist.init_process_group`, pytorch blocks process 0
+             from running further until the other processes also reach that point;
+             finally we set the device corresponding to this process, move the model to the corresponding machine, and wrap
+             it with `DistributedDataParallel`;
+             at this point the ddp environment configuration is fully complete;
+
+            -> case B: note that here we directly assume the user launched with torch.distributed.launch and did not
+             establish the ddp communication group themselves. During the initialization of `TorchDDPDriver` and the call of
+             its setup function, the primary difference from case A is that the parameter device the user passes to the
+             trainer no longer has any effect; the gpu each process uses is configured directly by us via
+             `torch.device("cuda:{local_rank}")`; so if the user wants to use specific gpu devices, they can do so by
+             setting environment variables themselves (e.g. via os.environ["CUDA_VISIBLE_DEVICE"]); the remaining
+             operations are similar to case A;
+
+            -> case C: note that here we assume the user launched with torch.distributed.launch and also established the
+             ddp communication group themselves. Essentially every related operation should then be done by the user,
+             including moving the model to the corresponding gpu and wrapping it with `DistributedDataParallel`.
+
+        (2) when the `driver.setup` function is called two or more times in the script (meaning the script initializes two
+         or more trainers/evaluators):
+
+            note that in this case we guarantee consistency between the `TorchDDPDriver` used by the successive
+            trainers/evaluators and its initialization mode; in other words, if trainer1 detects launch mode 'case A', we
+            guarantee that trainer2 also detects launch mode 'case A' (even if that requires some extra handling); so here
+            we mainly discuss the operations by which we ensure the launch mode detected by trainer2/3/... matches the one
+            trainer1 detected; simply put, we achieve this by using environment variables to mark each distinct launch mode:
+
+            we use `FASTNLP_DISTRIBUTED_CHECK` to mark 'case A' and `fastnlp_torch_launch_not_ddp` to mark 'case B'; this
+            means that when `TorchDDPDriver` is launched via 'case A' we inject the string `FASTNLP_DISTRIBUTED_CHECK` into
+            the environment variables, whereas 'case B' injects the string `fastnlp_torch_launch_not_ddp`. So during the
+            initialization and setup of trainer2's `TorchDDPDriver`, if these special environment variables are detected we
+            switch the launch mode to the corresponding one, even if the other parameter features belong to a different
+            launch mode.
+
+        Part 2: the corresponding implementation details:
+        1. how to decide that the communication group between processes has already been established (ddp has been
+         initialized):
+            dist.is_initialized();
+        2. how to decide whether the processes were pulled up by `python -m torch.distributed.launch` or by our own
+         `TorchDDPDriver.open_subprocess` function:
+            when the user script does `import fastNLP` we check whether the current environment variables contain
+            'LOCAL_RANK' and 'WORLD_SIZE' but not `FASTNLP_DISTRIBUTED_CHECK`; if so, we inject the special value
+            'FASTNLP_BACKEND_LAUNCH' into the environment variables to mark that the user used
+            `python -m torch.distributed.launch` to pull up multiple processes;
+        3. the overall decision flow:
+
+              ___________________________________
+             | enter TorchDDPDriver's __init__   |
+              ———————————————————————————————————
+                            ↓
+              ___________________________________________________
+             | were the different processes pulled up by         |
+             | torch.distributed.launch (or by our own           | ---------------→
+             | open_subprocess function)?                        |                |
+              ———————————————————————————————————————————————————                |
+                            ↓ pulled up by torch.distributed.launch              | our own open_subprocess
+              ____________________________________                               | pulled up the processes
+      ←←←←←  | did the user initialize ddp        |                              ↓
+      ↓      | themselves?                        |                           ________
+      ↓       ————————————————————————————————————                           | case A |
+      ↓                     ↓ yes                                              ————————
+      ↓ no              ________
+      ↓                | case C |
+      ↓                 ————————
+      ↓
+      ↓                 ________
+      ↓ -----------→   | case B |
+                        ————————
+
+        4. the operations all three cases must perform to fully establish ddp, and who is responsible for each:
+                                 case A                          |  case B                  |  case C
+          ________________________________________________________________________________________________________
+          configure the env     | TorchDDPDriver.open_subprocess | torch.distributed.launch | torch.distributed.launch
+          vars ddp needs        |                                |                          |
+          ——————————————————————————————————————————————————————————————————————————————————————————————————————
+          open multiple         | TorchDDPDriver.open_subprocess | torch.distributed.launch | torch.distributed.launch
+          processes             |                                |                          |
+          ——————————————————————————————————————————————————————————————————————————————————————————————————————
+          call the dist.        |                                |                          |
+          init_process_group    | TorchDDPDriver.setup           | TorchDDPDriver.setup     | called by the user
+          function              |                                |                          |
+          ——————————————————————————————————————————————————————————————————————————————————————————————————————
+          set TorchDDPDriver's  |                                |                          |
+          world_size and        | TorchDDPDriver.setup           | TorchDDPDriver.setup     | TorchDDPDriver.setup
+          global_rank attrs     |                                |                          |
+          ——————————————————————————————————————————————————————————————————————————————————————————————————————
+
+        Part 3: other handling details:
+        1. environment variables;
+           the environment variables fastNLP's `TorchDDPDriver` needs at runtime fall into two kinds: the ones torch's ddp
+           itself needs to run, and fastNLP's own. The configuration of the former is shown in the table above; most of the
+           latter are already set when the user does `import fastNLP`;
+        2. the relationship between parallel_device, model_device and data_device;
+           parallel_device is a parameter of `TorchDDPDriver`; model_device and data_device are both attributes of the
+           driver; data_device is specified by the user only in case C; if it is not None, the data is moved onto
+           data_device at the model's forward time;
+           model_device is always a single torch.device;
+
+                                 case A                            |  case B                |  case C
+          ________________________________________________________________________________________________________
+          parallel_device       | determined by the parameter      | torch.device(          | torch.device(
+                                | device the user passes to the    |  "cuda:{local_rank}")  |  "cuda:{local_rank}")
+                                | trainer; must be a list whose    |                        |
+                                | elements are all torch.device    |                        |
+          ——————————————————————————————————————————————————————————————————————————————————————————————————————
+          model_device          | parallel_device[local_rank]      | parallel_device        | None
+          ——————————————————————————————————————————————————————————————————————————————————————————————————————
+          data_device           | model_device                     | model_device           | determined by the parameter
+                                |                                  |                        | data_device the user passes
+                                |                                  |                        | to the trainer
+          ——————————————————————————————————————————————————————————————————————————————————————————————————————
+
+        3. the role of _DDPWrappingModel;
+           because we both need to call the model's `train_step`, `validate_step` and `test_step` methods and need
+           `DistributedDataParallel`'s forward function to synchronize the gradients across devices for us, we have to wrap
+           the model in an extra layer first; at forward time the call first goes through `DistributedDataParallel`'s
+           forward method and then through `_DDPWrappingModel`'s forward method, inside which we decide whether to call the
+           model's own forward function or its `train_step`, `validate_step` or `test_step` method.
+        4. how `TorchDDPDriver` handles an exception occurring in one of the processes;
+           in every case, at the very end of the `setup` function `TorchDDPDriver` proactively records the pids of all
+           processes, so that when one process raises an exception, the driver's on_exception function is called by the
+           trainer and issues os.kill to kill the other processes;
         """
         super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs)
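For orientation, a minimal sketch of how the three cases above are typically reached (the script name, the model construction, and the exact Trainer arguments are illustrative assumptions, not fastNLP's documented API):

    # main.py
    from fastNLP import Trainer

    # case A: plain `python main.py`; the driver itself pulls up one process per listed gpu
    trainer = Trainer(model=model, device=[0, 1])  # plus your usual arguments
    trainer.run()

    # case B: the same script with no ddp code of your own, launched externally:
    #     python -m torch.distributed.launch --nproc_per_node 2 main.py
    # case C: launched as in case B, but you call dist.init_process_group yourself and pass
    # in a model already wrapped in DistributedDataParallel before constructing the Trainer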
@@ -81,7 +188,8 @@ class TorchDDPDriver(TorchDriver):
         # if the user initialized DDP outside by themselves;
         self.outside_ddp = False
-        if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and "fastnlp_special" not in os.environ:
+        if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \
+                "fastnlp_torch_launch_not_ddp" not in os.environ:
             # if the user initialized DDP outside by themselves, we require the model they pass in to be one already wrapped by DistributedDataParallel;
             if not isinstance(model, DistributedDataParallel):
                 raise RuntimeError(
@@ -97,7 +205,7 @@ class TorchDDPDriver(TorchDriver):
                 if isinstance(batch, Dict):
                     return auto_param_call(step_fn, batch, signature_fn=signature_fn)
                 else:
-                    return self._validate_step(batch)
+                    return step_fn(batch)

             model = model.module
             if hasattr(model, "train_step"):
@@ -185,7 +293,7 @@ class TorchDDPDriver(TorchDriver): | |||||
backend="nccl", rank=self.global_rank, world_size=self.world_size | backend="nccl", rank=self.global_rank, world_size=self.world_size | ||||
) | ) | ||||
os.environ["fastnlp_special"] = "yes" | |||||
os.environ["fastnlp_torch_launch_not_ddp"] = "yes" | |||||
# when we reach this point: | # when we reach this point: | ||||
# dist.is_initialized must be False; | # dist.is_initialized must be False; | ||||
@@ -337,21 +445,22 @@ class TorchDDPDriver(TorchDriver): | |||||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False): | |||||
if isinstance(dist_sampler, ReproducibleIterator): | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||||
if isinstance(dist, ReproducibleIterator): | |||||
# note that dist_sampler.set_distributed need not be called here: if the user is using TorchDDPDriver, it was already called during Trainer initialization; | # note that dist_sampler.set_distributed need not be called here: if the user is using TorchDDPDriver, it was already called during Trainer initialization; | ||||
dist_sampler = re_instantiate_sampler(dist_sampler) | |||||
return replace_sampler(dataloader, dist_sampler) | |||||
dist = re_instantiate_sampler(dist) | |||||
return replace_sampler(dataloader, dist) | |||||
# trainer, evaluator | # trainer, evaluator | ||||
if dist_sampler is None: | |||||
if dist is None: | |||||
if reproducible: | if reproducible: | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | ||||
"control.") | "control.") | ||||
else: | else: | ||||
return dataloader | return dataloader | ||||
# trainer | # trainer | ||||
elif dist_sampler == "dist": | |||||
elif dist == "dist": | |||||
args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
# if the user's trainer.use_dist_sampler is True, whether checkpoint resuming is enabled does not affect the behavior here; | # if the user's trainer.use_dist_sampler is True, whether checkpoint resuming is enabled does not affect the behavior here; | ||||
if isinstance(args.sampler, ReproducibleIterator): | if isinstance(args.sampler, ReproducibleIterator): | ||||
@@ -377,7 +486,7 @@ class TorchDDPDriver(TorchDriver): | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
# evaluator | # evaluator | ||||
elif dist_sampler == "unrepeatdist": | |||||
elif dist == "unrepeatdist": | |||||
args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
sampler = UnrepeatedDistributedSampler( | sampler = UnrepeatedDistributedSampler( | ||||
dataset=args.dataset, | dataset=args.dataset, | ||||
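Taken together, this method implements the following dispatch (a usage sketch of the behavior above; the
driver construction is elided, and RandomSampler here is fastNLP's reproducible sampler):

    # dist is already a ReproducibleIterator -> re-instantiate it and swap it into the dataloader
    new_dl = driver.set_dist_repro_dataloader(dl, dist=RandomSampler(dataset), reproducible=False)
    # dist=None           -> keep the dataloader as-is; reproducible=True is rejected because
    #                        ddp was initialized outside fastNLP's control
    # dist="dist"         -> trainer path: a (padded, repeating) distributed reproducible sampler
    # dist="unrepeatdist" -> evaluator path: an UnrepeatedDistributedSampler, so that every
    #                        sample is evaluated exactly once across the ranks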
@@ -397,12 +397,13 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: | |||||
""" | """ | ||||
# # first move everything to cpu and make it contiguous, to guard against pickle issues | # # first move everything to cpu and make it contiguous, to guard against pickle issues | ||||
# obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | ||||
if device is None: | |||||
device = torch.cuda.current_device() | |||||
if _TORCH_GREATER_EQUAL_1_8: | if _TORCH_GREATER_EQUAL_1_8: | ||||
objs = [None for _ in range(dist.get_world_size(group))] | objs = [None for _ in range(dist.get_world_size(group))] | ||||
dist.all_gather_object(objs, obj) | dist.all_gather_object(objs, obj) | ||||
objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device)  # if there are any tensors, make sure they all end up on the current device | |||||
return objs | return objs | ||||
if device is None: | |||||
device = torch.cuda.current_device() | |||||
group = group if group is not None else torch.distributed.group.WORLD | group = group if group is not None else torch.distributed.group.WORLD | ||||
data = convert_to_tensors(obj, device=device) | data = convert_to_tensors(obj, device=device) | ||||
data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group) | data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group) | ||||
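For example, under torch >= 1.8 every rank can contribute an arbitrary picklable object (a sketch with two
ranks; the dict contents are made up):

    # rank 0 calls fastnlp_torch_all_gather({'loss': torch.tensor(0.3)})
    # rank 1 calls fastnlp_torch_all_gather({'loss': torch.tensor(0.5)})
    # -> both ranks receive [{'loss': tensor(0.3)}, {'loss': tensor(0.5)}],
    #    with any contained tensors moved onto the current cuda device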
@@ -130,12 +130,12 @@ class TorchSingleDriver(TorchDriver): | |||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||||
reproducible: bool = False): | |||||
if isinstance(dist_sampler, ReproducibleBatchSampler): | |||||
return replace_batch_sampler(dataloader, dist_sampler) | |||||
elif isinstance(dist_sampler, ReproducibleIterator): | |||||
return replace_sampler(dataloader, dist_sampler) | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, | |||||
reproducible: bool = False): | |||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
return replace_batch_sampler(dataloader, dist) | |||||
elif isinstance(dist, ReproducibleIterator): | |||||
return replace_sampler(dataloader, dist) | |||||
if reproducible: | if reproducible: | ||||
args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
@@ -30,6 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||||
class TorchDriver(Driver): | class TorchDriver(Driver): | ||||
@@ -178,8 +179,28 @@ class TorchDriver(Driver): | |||||
model.load_state_dict(res.state_dict()) | model.load_state_dict(res.state_dict()) | ||||
@rank_zero_call | @rank_zero_call | ||||
def save(self, folder: Path, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
# 1. 保存模型的状态; | |||||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
# the dataloader passed in is the trainer's dataloader attribute: we never mutate the driver's own dataloaders, but instead | |||||
# change the dataloader's state by replacing trainer.dataloader, adapting it to the training or evaluation environment; | |||||
# 1. the sampler's state, since we support resume training, i.e. recovering precisely to a specific batch; | |||||
# first, a pytorch DataLoader always has a sampler; moreover, for checkpoint resuming we will have replaced the dataloader's | |||||
# sampler with a `ReproducibleIterator` in `set_dist_repro_dataloader`, or, on a single card, its batch_sampler with a `ReproducibleBatchSampler`; | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif dataloader_args.sampler: | |||||
sampler = dataloader_args.sampler | |||||
else: | |||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||||
states['sampler_states'] = sampler.state_dict() | |||||
else: | |||||
raise RuntimeError( | |||||
'The sampler has no `state_dict()` method, so training cannot be resumed from the exact batch.') | |||||
# 2. save the model's state; | |||||
if should_save_model: | if should_save_model: | ||||
model = self.unwrap_model() | model = self.unwrap_model() | ||||
if only_state_dict: | if only_state_dict: | ||||
@@ -191,7 +212,7 @@ class TorchDriver(Driver): | |||||
torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME)) | torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME)) | ||||
logger.debug("Save model") | logger.debug("Save model") | ||||
# 2. 保存 optimizers 的状态; | |||||
# 3. save the optimizers' state; | |||||
optimizers_state_dict = {} | optimizers_state_dict = {} | ||||
for i in range(len(self.optimizers)): | for i in range(len(self.optimizers)): | ||||
optimizer: torch.optim.Optimizer = self.optimizers[i] | optimizer: torch.optim.Optimizer = self.optimizers[i] | ||||
@@ -203,7 +224,7 @@ class TorchDriver(Driver): | |||||
states["optimizers_state_dict"] = optimizers_state_dict | states["optimizers_state_dict"] = optimizers_state_dict | ||||
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | ||||
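# the resulting checkpoint folder then roughly looks like this (the file names are the values of the
# FASTNLP_MODEL_FILENAME / FASTNLP_CHECKPOINT_FILENAME constants imported above):
#   folder/<FASTNLP_MODEL_FILENAME>       <- the model (state dict or the whole module)
#   folder/<FASTNLP_CHECKPOINT_FILENAME>  <- states: sampler_states, optimizers_state_dict, ...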
def load(self, folder: Path, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||||
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | ||||
# 1. load the optimizers' state; | # 1. load the optimizers' state; | ||||
@@ -224,6 +245,39 @@ class TorchDriver(Driver): | |||||
model.load_state_dict(res.state_dict()) | model.load_state_dict(res.state_dict()) | ||||
logger.debug("Load model.") | logger.debug("Load model.") | ||||
# 3. restore the sampler's state; | |||||
dataloader_args = self.get_dataloader_args(dataloader) | |||||
sampler = dataloader_args.sampler | |||||
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): | |||||
# this means we have to fall back to a ReproducibleBatchSampler here | |||||
if self.is_distributed(): | |||||
raise RuntimeError( | |||||
"Cannot resume under ddp from a checkpoint that was saved during single-device training.") | |||||
sampler = ReproducibleBatchSampler( | |||||
batch_sampler=sampler, | |||||
batch_size=dataloader_args.batch_size, | |||||
drop_last=dataloader_args.drop_last | |||||
) | |||||
sampler.load_state_dict(states['sampler_states']) | |||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||||
# 4. adjust trainer_state.batch_idx_in_epoch | |||||
# the sampler is a RandomSampler-like sampler, not a batch_sampler; | |||||
if not isinstance(sampler, ReproducibleBatchSampler): | |||||
if dataloader_args.drop_last: | |||||
batch_idx_in_epoch = len( | |||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||||
else: | |||||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||||
# the sampler is a batch_sampler; | |||||
else: | |||||
batch_idx_in_epoch = sampler.batch_idx_in_epoch | |||||
states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||||
return states | return states | ||||
def get_evaluate_context(self): | def get_evaluate_context(self): | ||||
@@ -34,20 +34,13 @@ def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> | |||||
def torch_seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: | def torch_seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: | ||||
"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition, | |||||
sets the following environment variables: | |||||
r""" | |||||
为伪随机数生成器设置种子的函数:pytorch、numpy、python.random 另外, | |||||
设置以下环境变量: | |||||
- `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend). | |||||
- `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``. | |||||
Args: | |||||
seed: the integer value seed for global random state in Lightning. | |||||
If `None`, will read seed from `PL_GLOBAL_SEED` env variable | |||||
or select it randomly. | |||||
workers: if set to ``True``, will properly configure all dataloaders passed to the | |||||
Trainer with a ``worker_init_fn``. If the user already provides such a function | |||||
for their dataloaders, setting this argument will have no influence. See also: | |||||
:func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`. | |||||
:param seed: the integer seed for the global random state. If None, the seed is read from the | |||||
"FASTNLP_GLOBAL_SEED" environment variable, or chosen randomly if that is unset. | |||||
:param workers: if True, every dataloader passed to the Trainer is configured with a proper `worker_init_fn`; if the user already provides such a function for their dataloaders, this argument has no effect; | |||||
""" | """ | ||||
max_seed_value = np.iinfo(np.uint32).max | max_seed_value = np.iinfo(np.uint32).max | ||||
min_seed_value = np.iinfo(np.uint32).min | min_seed_value = np.iinfo(np.uint32).min | ||||
@@ -56,7 +49,6 @@ def torch_seed_everything(seed: Optional[int] = None, workers: bool = False) -> | |||||
env_seed = os.environ.get(FASTNLP_GLOBAL_SEED) | env_seed = os.environ.get(FASTNLP_GLOBAL_SEED) | ||||
if env_seed is None: | if env_seed is None: | ||||
seed = _select_seed_randomly(min_seed_value, max_seed_value) | seed = _select_seed_randomly(min_seed_value, max_seed_value) | ||||
# rank_zero_warn(f"No seed found, seed set to {seed}") | |||||
else: | else: | ||||
try: | try: | ||||
seed = int(env_seed) | seed = int(env_seed) | ||||
@@ -69,12 +61,8 @@ def torch_seed_everything(seed: Optional[int] = None, workers: bool = False) -> | |||||
if not (min_seed_value <= seed <= max_seed_value): | if not (min_seed_value <= seed <= max_seed_value): | ||||
logger.warning("Your seed value is two big or two small for numpy, we will choose a random seed for you.") | logger.warning("Your seed value is two big or two small for numpy, we will choose a random seed for you.") | ||||
# rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") | |||||
seed = _select_seed_randomly(min_seed_value, max_seed_value) | seed = _select_seed_randomly(min_seed_value, max_seed_value) | ||||
# using `log.info` instead of `rank_zero_info`, | |||||
# so users can verify the seed is properly set in distributed training. | |||||
# log.info(f"Global seed set to {seed}") | |||||
random.seed(seed) | random.seed(seed) | ||||
np.random.seed(seed) | np.random.seed(seed) | ||||
torch.manual_seed(seed) | torch.manual_seed(seed) | ||||
@@ -84,11 +72,9 @@ def torch_seed_everything(seed: Optional[int] = None, workers: bool = False) -> | |||||
def reset_seed() -> None: | def reset_seed() -> None: | ||||
""" | |||||
r""" | |||||
This function exists mainly for ddp: since ddp launches multiple processes, when the user calls seed_everything in the script, | This function exists mainly for ddp: since ddp launches multiple processes, when the user calls seed_everything in the script, | ||||
the random seeds are re-applied inside each spawned process; | the random seeds are re-applied inside each spawned process; | ||||
If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing. | |||||
""" | """ | ||||
seed = os.environ.get(FASTNLP_GLOBAL_SEED, None) | seed = os.environ.get(FASTNLP_GLOBAL_SEED, None) | ||||
workers = os.environ.get(FASTNLP_SEED_WORKERS, "0") | workers = os.environ.get(FASTNLP_SEED_WORKERS, "0") | ||||
@@ -11,11 +11,12 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
import paddle.distributed as dist | |||||
from paddle.fluid.dygraph import parallel_helper | from paddle.fluid.dygraph import parallel_helper | ||||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | ||||
gathered_result = [paddle.zeros_like(result) for _ in range(world_size)] | gathered_result = [paddle.zeros_like(result) for _ in range(world_size)] | ||||
paddle.distributed.all_gather(gathered_result, result, group) | |||||
dist.all_gather(gathered_result, result, group) | |||||
return gathered_result | return gathered_result | ||||
class PaddleBackend(Backend): | class PaddleBackend(Backend): | ||||
@@ -36,13 +37,13 @@ class PaddleBackend(Backend): | |||||
tensor = paddle.stack(tensor) | tensor = paddle.stack(tensor) | ||||
# step 1: aggregate the results | # step 1: aggregate the results | ||||
if method == 'sum': | if method == 'sum': | ||||
tensor = paddle.sum(tensor, dim=0) | |||||
tensor = paddle.sum(tensor, axis=0) | |||||
elif method == 'mean': | elif method == 'mean': | ||||
tensor = paddle.mean(tensor, dim=0) | |||||
tensor = paddle.mean(tensor, axis=0) | |||||
elif method == 'max': | elif method == 'max': | ||||
tensor, _ = paddle.max(tensor, dim=0) | |||||
tensor, _ = paddle.max(tensor, axis=0) | |||||
elif method == 'min': | elif method == 'min': | ||||
tensor, _ = paddle.min(tensor, dim=0) | |||||
tensor, _ = paddle.min(tensor, axis=0) | |||||
else: | else: | ||||
raise AggregateMethodError(should_have_aggregate_method=False) | raise AggregateMethodError(should_have_aggregate_method=False) | ||||
@@ -80,11 +81,12 @@ class PaddleBackend(Backend): | |||||
Aggregate all results within the group; since results on different ranks may differ in size, padding is applied when necessary | Aggregate all results within the group; since results on different ranks may differ in size, padding is applied when necessary | ||||
""" | """ | ||||
# TODO check correctness | # TODO check correctness | ||||
if group is None: | |||||
group = paddle.distributed.get_group(0) | |||||
# a bug on paddle's side, fixed in paddle 2.3; revisit this then | |||||
# if group is None: | |||||
# group = dist.get_group(0) | |||||
world_size = group.nranks | |||||
paddle.distributed.barrier(group=group) | |||||
world_size = group.nranks if group is not None else dist.get_world_size() | |||||
dist.barrier(group=group) | |||||
# scalar tensors can simply be gathered directly | # scalar tensors can simply be gathered directly | ||||
if result.ndim == 0: | if result.ndim == 0: | ||||
@@ -93,10 +95,10 @@ class PaddleBackend(Backend): | |||||
# get the shape of result | # get the shape of result | ||||
local_size = paddle.to_tensor(result.shape) | local_size = paddle.to_tensor(result.shape) | ||||
# gather the sizes of all results in the group | # gather the sizes of all results in the group | ||||
local_sizes = [paddle.zeros_like(local_size) for _ in range(world_size)] | |||||
paddle.distributed.all_gather(local_sizes, local_size, group=group) | |||||
local_sizes = [] | |||||
dist.all_gather(local_sizes, local_size, group=group) | |||||
# after stacking, compute the maximum of each shape dimension | # after stacking, compute the maximum of each shape dimension | ||||
max_size = paddle.stack(local_sizes).max(axis=0).values | |||||
max_size = paddle.stack(local_sizes).max(axis=0) | |||||
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) | all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) | ||||
# if all results have the same size, they can be gathered directly | # if all results have the same size, they can be gathered directly | ||||
@@ -111,16 +113,15 @@ class PaddleBackend(Backend): | |||||
pad_dims.append(val.item()) | pad_dims.append(val.item()) | ||||
result_padded = paddle.nn.functional.pad(result, pad_dims) | result_padded = paddle.nn.functional.pad(result, pad_dims) | ||||
# gather again | # gather again | ||||
gathered_result = [paddle.zeros_like(result_padded) for _ in range(world_size)] | |||||
paddle.distributed.all_gather(gathered_result, result_padded, group) | |||||
gathered_result = [] | |||||
dist.all_gather(gathered_result, result_padded, group) | |||||
for idx, item_size in enumerate(local_sizes): | for idx, item_size in enumerate(local_sizes): | ||||
slice_param = [slice(dim_size) for dim_size in item_size] | |||||
slice_param = [slice(dim_size) for dim_size in item_size.tolist()] | |||||
gathered_result[idx] = gathered_result[idx][slice_param] | gathered_result[idx] = gathered_result[idx][slice_param] | ||||
return gathered_result | return gathered_result | ||||
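# worked example with made-up shapes: two ranks hold tensors of shape [2, 3] and [4, 3];
# max_size becomes [4, 3], the first rank pads its result to [4, 3], the second all_gather
# returns two [4, 3] tensors, and the final slicing restores the original [2, 3] and [4, 3]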
def move_tensor_to_device(self, tensor, device): | def move_tensor_to_device(self, tensor, device): | ||||
# TODO if we handle this here, could it cause bugs elsewhere? | # TODO if we handle this here, could it cause bugs elsewhere? | ||||
if is_in_paddle_dist(): | |||||
device = get_device_from_visible(device) | |||||
device = get_device_from_visible(device) | |||||
return paddle_to(tensor, device) | return paddle_to(tensor, device) | ||||
@@ -4,17 +4,18 @@ __all__ = [ | |||||
from typing import Any | from typing import Any | ||||
from functools import wraps | from functools import wraps | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||||
from fastNLP.envs.utils import _module_available | from fastNLP.envs.utils import _module_available | ||||
_IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics') | _IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics') | ||||
if _IS_TORCHMETRICS_AVAILABLE: | |||||
from torchmetrics import Metric as torchmetrics_Metric | |||||
_IS_ALLENNLP_AVAILABLE = _module_available('allennlp') | _IS_ALLENNLP_AVAILABLE = _module_available('allennlp') | ||||
if _IS_ALLENNLP_AVAILABLE: | if _IS_ALLENNLP_AVAILABLE: | ||||
from allennlp.training.metrics import Metric as allennlp_Metric | from allennlp.training.metrics import Metric as allennlp_Metric | ||||
if _NEED_IMPORT_TORCH and _IS_TORCHMETRICS_AVAILABLE: | |||||
if _IS_TORCHMETRICS_AVAILABLE: | |||||
from torchmetrics import Metric as torchmetrics_Metric | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
from paddle.metric import Metric as paddle_Metric | from paddle.metric import Metric as paddle_Metric | ||||
@@ -11,11 +11,11 @@ __all__ = [ | |||||
'PollingSampler', | 'PollingSampler', | ||||
'ReproducibleIterator', | 'ReproducibleIterator', | ||||
'RandomSampler', | 'RandomSampler', | ||||
'ReproducibleBatchSampler', | |||||
're_instantiate_sampler' | 're_instantiate_sampler' | ||||
] | ] | ||||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler | from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler | ||||
from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler | from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler | ||||
from .reproducible_sampler import ReproducibleIterator, RandomSampler, ReproducibleBatchSampler, re_instantiate_sampler | |||||
from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler | |||||
from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler | |||||
@@ -0,0 +1,397 @@ | |||||
__all__ = [ | |||||
'BucketedBatchSampler', | |||||
"ReproducibleBatchSampler" | |||||
] | |||||
import math | |||||
from array import array | |||||
from copy import deepcopy | |||||
from typing import Dict, Union, List | |||||
from itertools import chain | |||||
import numpy as np | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.log import logger | |||||
from abc import abstractmethod | |||||
class ReproducibleBatchIterator: | |||||
@abstractmethod | |||||
def set_distributed(self, num_replicas, rank, pad=True): | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") | |||||
@abstractmethod | |||||
def __len__(self): | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `__len__` method.") | |||||
@abstractmethod | |||||
def __iter__(self): | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `__iter__` method.") | |||||
@abstractmethod | |||||
def state_dict(self): | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") | |||||
@abstractmethod | |||||
def load_state_dict(self, states): | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `load_state_dict` method.") | |||||
@abstractmethod | |||||
def set_epoch(self, epoch): | |||||
pass | |||||
class ReproducibleBatchSampler(ReproducibleBatchIterator): | |||||
# the values of these two parameters should be obtained from the driver's get_dataloader_args function; | |||||
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | |||||
""" | |||||
可以使得 batch_sampler 对象状态恢复的 wrapper 。 | |||||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproducibleBatchSampler 将首先遍历一边该对象,然后将迭代 | |||||
出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 | |||||
:param batch_size: 每个 batch 的大小是多少。 | |||||
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | |||||
:param kwargs: fastNLP 内部使用。 | |||||
""" | |||||
self.batch_sampler = batch_sampler | |||||
self.batch_size = batch_size | |||||
self.drop_last = drop_last | |||||
self.data_idx = kwargs.get("data_idx", 0) | |||||
self.index_list = kwargs.get("index_list", self._iterate_sampler()) | |||||
self.need_reinitialize = kwargs.get("need_reinitialize", False) | |||||
def _iterate_sampler(self): | |||||
_index_lst = [] | |||||
for idx in self.batch_sampler: | |||||
if isinstance(idx, list): | |||||
_index_lst.extend(idx) | |||||
# a plain sampler was passed at initialization, which corresponds to a dataloader created without batch_size or batch_sampler; | |||||
else: | |||||
_index_lst.append(idx) | |||||
# an unsigned int is 4 bytes on 64-bit machines and can represent at most 4294967295; | |||||
if len(_index_lst) > 4294967295: | |||||
# note that self.index_list holds the indices of all data; | |||||
# unsigned long | |||||
_index_lst = array("L", _index_lst) | |||||
else: | |||||
# unsigned int | |||||
_index_lst = array("I", _index_lst) | |||||
return _index_lst | |||||
def __iter__(self): | |||||
if self.need_reinitialize: | |||||
self.index_list = self._iterate_sampler() | |||||
self.data_idx = 0 | |||||
else: | |||||
self.need_reinitialize = True | |||||
batch = [] | |||||
if self.data_idx: | |||||
index_list = self.index_list[self.data_idx:] | |||||
else: | |||||
index_list = self.index_list | |||||
for idx in index_list: | |||||
batch.append(idx) | |||||
self.data_idx += 1 | |||||
if len(batch) == self.batch_size: | |||||
yield batch | |||||
batch = [] | |||||
if len(batch) > 0 and not self.drop_last: | |||||
yield batch | |||||
def __len__(self) -> int: | |||||
if self.drop_last: | |||||
return len(self.index_list) // self.batch_size | |||||
else: | |||||
return (len(self.index_list) + self.batch_size - 1) // self.batch_size | |||||
def state_dict(self) -> Dict: | |||||
return {"index_list": deepcopy(self.index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__} | |||||
def load_state_dict(self, states: Dict): | |||||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}, " \ | |||||
f"so we cannot use {self.__class__.__name__} to load it." | |||||
_index_list = states["index_list"] | |||||
assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \ | |||||
"record and current dataset." | |||||
self.index_list = _index_list | |||||
self.data_idx = states["data_idx"] | |||||
self.need_reinitialize = False | |||||
def set_distributed(self, num_replicas, rank, pad=True): | |||||
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.") | |||||
def set_epoch(self, epoch): | |||||
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch): | |||||
self.batch_sampler.sampler.set_epoch(epoch) | |||||
@property | |||||
def batch_idx_in_epoch(self): | |||||
if self.drop_last: | |||||
return len(self.index_list) // self.batch_size - (len(self.index_list) - self.data_idx) // self.batch_size | |||||
else: | |||||
return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \ | |||||
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | |||||
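A usage sketch for this wrapper (the dataset and sizes are made up; any pytorch DataLoader works):

    from torch.utils.data import DataLoader

    dl = DataLoader(list(range(10)), batch_size=4)
    sampler = ReproducibleBatchSampler(dl.batch_sampler, batch_size=4, drop_last=False)
    it = iter(sampler)
    next(it)                           # first batch: [0, 1, 2, 3]
    states = sampler.state_dict()      # records data_idx == 4 together with the index list
    restored = ReproducibleBatchSampler(dl.batch_sampler, batch_size=4, drop_last=False)
    restored.load_state_dict(states)
    list(restored)                     # resumes precisely: [[4, 5, 6, 7], [8, 9]]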
class BucketedBatchSampler(ReproducibleBatchIterator): | |||||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | |||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | |||||
""" | |||||
首先按照 sample 的长度排序,然后按照 batch_size*num_batch_per_bucket 为一个桶的大小,sample 只会在这个桶内进行组合,这样 | |||||
每个 batch 中的 padding 数量会比较少 (因为桶内的数据的长度都接近)。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | |||||
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | |||||
如果否则使用 len() 函数得到每个 sample 中这个 field 的长度。 | |||||
:param batch_size: 每个 batch 的大小 | |||||
:param num_batch_per_bucket: 多少个 batch 组成一个桶,数据只会在一个桶内进行 shuffle 。 | |||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||||
:param drop_last: 如果最后一个 batch 的 sample 数量无法凑齐 batch_size 这么多,是否需要丢掉。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | |||||
""" | |||||
super().__init__() | |||||
if isinstance(dataset, DataSet): | |||||
length = dataset.get_field(length) | |||||
if not isinstance(length[0], int): | |||||
length = list(map(len, length)) | |||||
else: | |||||
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \ | |||||
"the length parameter can only be List[int]" | |||||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | |||||
self.dataset = dataset | |||||
self.length = np.array(length, dtype=int)  # the length of each sample | |||||
self.sorted_indices = np.argsort(self.length)[::-1]  # indices sorted by length, from longest to shortest | |||||
self.batch_size = batch_size | |||||
self.num_batch_per_bucket = num_batch_per_bucket | |||||
self.shuffle = shuffle | |||||
self.drop_last = drop_last | |||||
self.seed = seed | |||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # how many samples have been consumed in total, including those produced on the other cards in the multi-card case | |||||
# multi-card related parameters | |||||
self.num_replicas = kwargs.get("num_replicas", 1) | |||||
self.rank = kwargs.get("rank", 0) | |||||
self.epoch = kwargs.get("epoch", -1) | |||||
self.pad = kwargs.get("pad", False)  # this parameter has no meaning on a single card; | |||||
# whether we are inside an iteration; when True, set_distributed() and load_state_dict() must not be called | |||||
self.during_iter = kwargs.get("during_iter", False) | |||||
# the following variables are used internally to restore state. | |||||
self.old_batch_size = kwargs.get('old_batch_size', self.batch_size) | |||||
self.old_num_batch_per_bucket = kwargs.get('old_num_batch_per_bucket', self.num_batch_per_bucket) | |||||
def set_distributed(self, num_replicas, rank, pad=True): | |||||
assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \ | |||||
"during an unfinished iteration." | |||||
assert num_replicas > 0 and isinstance(num_replicas, int) | |||||
assert isinstance(rank, int) and 0 <= rank < num_replicas | |||||
# note that when this function is called, all state should correspond to the very start of an epoch; | |||||
self.num_replicas = num_replicas | |||||
self.rank = rank | |||||
self.pad = pad | |||||
num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ | |||||
else len(self.dataset) | |||||
if self.drop_last: | |||||
assert self.num_replicas * self.batch_size <= num_samples, "The number of samples should be no less " \ | |||||
"than num_replicas multiplied by " \ | |||||
"batch_size when drop_last=True." | |||||
return self | |||||
@property | |||||
def total_size(self): | |||||
""" | |||||
这个变量代表的含义是当前这个sampler会最终产生出的index数量(包括了其它rank的),因为replica和pad的原因,这个值可能等于、 | |||||
大于或者小于len(dataset) | |||||
:return: | |||||
""" | |||||
return self.num_consumed_samples + self.num_replicas*self.num_left_samples | |||||
@property | |||||
def num_left_samples(self): | |||||
""" | |||||
返回当前 iteration 还有多少个 sample 结束,表示的是当前 rank 的还剩多少。 | |||||
:return: | |||||
""" | |||||
num_consumed_samples = self.num_consumed_samples | |||||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||||
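# e.g. with made-up numbers len(dataset) == 10, num_consumed_samples == 4 and num_replicas == 4:
#   pad=True  -> ceil(6 / 4) = 2 samples left on this rank (indices are repeated to pad)
#   pad=False -> floor(6 / 4) = 1 sample left on this rank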
def __len__(self): | |||||
""" | |||||
返回当前 sampler 还会返回多少个 batch 的数据 | |||||
:return: | |||||
""" | |||||
num_samples_per_rank = self.total_size // self.num_replicas | |||||
num_batches = num_samples_per_rank // self.batch_size if self.drop_last else \ | |||||
(num_samples_per_rank + self.batch_size - 1) // self.batch_size | |||||
return num_batches | |||||
def __iter__(self): | |||||
if self.during_iter:  # if during_iter is True, the previous iteration has not finished; force a re-initialization | |||||
self.num_consumed_samples = 0 | |||||
self.during_iter = True | |||||
sorted_indices = deepcopy(self.sorted_indices).tolist()  # sorted by length, from longest to shortest | |||||
if self.shuffle: | |||||
if self.num_consumed_samples > 0:  # first reproduce the previous ordering and drop the part that has already been consumed | |||||
_batches = [] | |||||
for _i in range(self.old_num_replicas): | |||||
_sorted_indices = sorted_indices[_i:len(sorted_indices):self.old_num_replicas] | |||||
__batches = self.bucketerize(_sorted_indices, self.old_batch_size, self.old_num_batch_per_bucket, | |||||
seed=self.seed+self.epoch) | |||||
_batches.append(__batches) | |||||
batches = list(chain(*[_ for _ in zip(*_batches)])) | |||||
sorted_indices = list(chain(*batches)) | |||||
sorted_indices = sorted_indices[self.num_consumed_samples:] | |||||
# now sort again | |||||
sub_length = self.length[sorted_indices] | |||||
sorted_indices = np.array(sorted_indices)[np.argsort(sub_length)[::-1]]  # sorted by length, from longest to shortest | |||||
# take out this rank's share | |||||
sorted_indices = sorted_indices[self.rank:len(sorted_indices):self.num_replicas] | |||||
batches = self.bucketerize(sorted_indices, self.batch_size, self.num_batch_per_bucket, | |||||
seed=self.seed+self.epoch) | |||||
batches = list(map(list, batches)) | |||||
else: | |||||
sorted_indices = sorted_indices[self.num_consumed_samples:] | |||||
sorted_indices = sorted_indices[self.rank:len(sorted_indices):self.num_replicas] | |||||
_num_batches = len(sorted_indices) // self.batch_size | |||||
if _num_batches == 0: | |||||
batches = [sorted_indices] | |||||
else: | |||||
batches = list(map(list, np.array_split(sorted_indices[:_num_batches*self.batch_size], _num_batches))) | |||||
if len(sorted_indices)%self.batch_size!=0: | |||||
batches.append(sorted_indices[_num_batches*self.batch_size:]) | |||||
need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas | |||||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | |||||
if len(batches) > 0: | |||||
if len(batches[-1])<self.batch_size: | |||||
batches[-1].append(batches[-1][0])  # this keeps the bucket's length distribution intact. | |||||
else: | |||||
batches.append([batches[-1][0]]) | |||||
elif self.pad is False and need_pad_num !=0 and need_pad_num>self.rank: | |||||
if len(batches): | |||||
batches[-1].pop(-1) | |||||
if len(batches[-1])==0: | |||||
batches.pop(-1) | |||||
assert len(list(chain(*batches))) == self.num_left_samples | |||||
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | |||||
batches = batches[:-1] | |||||
for batch in batches: | |||||
self.num_consumed_samples += self.num_replicas * len(batch) | |||||
yield list(map(int, batch)) | |||||
self.during_iter = False | |||||
self.num_consumed_samples = 0 | |||||
self.old_batch_size = self.batch_size | |||||
self.old_num_batch_per_bucket = self.num_batch_per_bucket | |||||
self.old_num_replicas = self.num_replicas | |||||
if self.epoch < 0:  # guard against the user never setting the epoch, which would make every epoch identical | |||||
self.epoch -= 1 | |||||
def bucketerize(self, sorted_indices, batch_size, num_batch_per_bucket, seed): | |||||
""" | |||||
将 indices 分桶 | |||||
:param sorted_indices: List[int] | |||||
:param batch_size: int | |||||
:param num_batch_per_bucket: int | |||||
:param seed: int | |||||
:return: List[List[int]] | |||||
""" | |||||
# the actual bucket size | |||||
bucket_size = min(len(sorted_indices), batch_size * num_batch_per_bucket) | |||||
rng = np.random.default_rng(abs(seed)) | |||||
num_buckets = (len(sorted_indices) + bucket_size - 1) // bucket_size | |||||
batches = [] | |||||
batch_indices = [] | |||||
for i in range(num_buckets): | |||||
bucket = sorted_indices[i * bucket_size:(i + 1) * bucket_size] | |||||
rng.shuffle(bucket)  # shuffle within the bucket | |||||
_num_batches = len(bucket) // batch_size | |||||
if _num_batches == 0: | |||||
_batches = [bucket] | |||||
else: | |||||
_batches = np.array_split(bucket[:_num_batches*batch_size], _num_batches) | |||||
if len(bucket) % batch_size != 0: | |||||
_batches.append(bucket[_num_batches*batch_size:]) | |||||
batch_indices.extend(list(range(len(batches), len(batches) + len(_batches)))) | |||||
batches.extend(_batches) | |||||
last_batches = [] | |||||
# the last batch never takes part in the shuffle: on some ranks it may hold fewer than batch_size samples and | |||||
# must stay at the end, so for simplicity no rank shuffles its last batch. | |||||
if len(batches) >= 1: | |||||
last_batches = [list(batches[-1])] | |||||
batch_indices = list(batch_indices[:-1]) | |||||
rng = np.random.default_rng(abs(seed))  # re-seed here so that differing bucket lengths do not perturb the random state | |||||
rng.shuffle(batch_indices)  # shuffle across batches as well; this scheme keeps the batch lengths on every card close to each other. | |||||
batches = (np.array(batches)[batch_indices]).tolist() | |||||
if last_batches: | |||||
batches = batches + last_batches | |||||
return batches | |||||
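# e.g. with made-up numbers: 17 indices, batch_size=2 and num_batch_per_bucket=4 give a bucket
# size of 8, i.e. buckets of sizes [8, 8, 1]; each bucket is shuffled internally and split into
# batches of 2, and the trailing 1-element batch stays at the end while the order of all other
# batches is shuffled once more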
def state_dict(self) -> Dict: | |||||
if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket: | |||||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | |||||
" consumed. ") | |||||
states = { | |||||
'seed': self.seed, | |||||
'epoch': self.epoch, | |||||
'num_consumed_samples': self.num_consumed_samples,  # note that this value counts the data consumed on all ranks; | |||||
'sampler_type': self.__class__.__name__, | |||||
'length': len(self.dataset), | |||||
'shuffle': self.shuffle, | |||||
'batch_size': self.batch_size, | |||||
'num_batch_per_bucket': self.num_batch_per_bucket, | |||||
'num_replicas': self.num_replicas | |||||
} | |||||
return states | |||||
def load_state_dict(self, states: Dict): | |||||
# load_state_dict may not be called in the middle of an iteration; | |||||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||||
"during an unfinished iteration." | |||||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}, " \ | |||||
f"so we cannot use {self.__class__.__name__} to load it." | |||||
length = states['length'] | |||||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ | |||||
"and current dataset." | |||||
self.seed = states['seed'] | |||||
self.epoch = states['epoch'] | |||||
self.num_consumed_samples = states['num_consumed_samples'] | |||||
if self.num_consumed_samples >= length:  # if the checkpoint was saved after the last sample was reached, simply reset to 0 | |||||
self.num_consumed_samples = 0 | |||||
if self.shuffle != states['shuffle']: | |||||
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | |||||
f"we use shuffle={states['shuffle']}") | |||||
self.shuffle = states["shuffle"] | |||||
self.old_batch_size = states['batch_size'] | |||||
self.old_num_batch_per_bucket = states['num_batch_per_bucket'] | |||||
self.old_num_replicas = states['num_replicas'] | |||||
def set_epoch(self, epoch): | |||||
self.epoch = epoch |
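A distributed usage sketch (the toy dataset and lengths are made up):

    lengths = [5, 9, 3, 7, 4, 8, 6, 2]
    sampler = BucketedBatchSampler(dataset=list(range(8)), length=lengths,
                                   batch_size=2, num_batch_per_bucket=2)
    sampler.set_distributed(num_replicas=2, rank=0, pad=True)  # called once on every rank
    for epoch in range(2):
        sampler.set_epoch(epoch)   # keeps the shuffling different across epochs
        for batch in sampler:      # this rank sees half of the (padded) samples
            pass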
@@ -1,14 +1,12 @@ | |||||
from typing import Dict, List | from typing import Dict, List | ||||
import math | import math | ||||
import numpy as np | import numpy as np | ||||
from array import array | |||||
from copy import deepcopy | |||||
from fastNLP.core.log import logger | |||||
__all__ = [ | __all__ = [ | ||||
'ReproducibleIterator', | 'ReproducibleIterator', | ||||
'RandomSampler', | 'RandomSampler', | ||||
'ReproducibleBatchSampler', | |||||
're_instantiate_sampler' | 're_instantiate_sampler' | ||||
] | ] | ||||
@@ -22,7 +20,8 @@ def re_instantiate_sampler(sampler): | |||||
class ReproducibleIterator: | class ReproducibleIterator: | ||||
""" | """ | ||||
Note that the `__init__` method of every class inheriting from `ReproducibleIterator` must accept `**kwargs`, so that we can re-instantiate the sampler during checkpoint resuming | Note that the `__init__` method of every class inheriting from `ReproducibleIterator` must accept `**kwargs`, so that we can re-instantiate the sampler during checkpoint resuming | ||||
或者 batch_sampler; | |||||
or the batch_sampler; note also that no variable initialized in init may start with an underscore, while every variable not set in init must start with one. | |||||
""" | """ | ||||
def set_distributed(self, num_replicas, rank, pad=True): | def set_distributed(self, num_replicas, rank, pad=True): | ||||
@@ -50,6 +49,14 @@ class ReproducibleIterator: | |||||
class RandomSampler(ReproducibleIterator): | class RandomSampler(ReproducibleIterator): | ||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | ||||
""" | |||||
:param dataset: 实现了 __len__ 方法的数据容器 | |||||
:param shuffle: 是否在每次 iterate 的时候打乱顺序。 | |||||
:param seed: 随机数种子。 | |||||
:param kwargs: 用户不需要使用,fastNLP 内部使用 | |||||
""" | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
@@ -64,7 +71,7 @@ class RandomSampler(ReproducibleIterator): | |||||
self.pad = kwargs.get("pad", False)  # this parameter has no meaning on a single card; | self.pad = kwargs.get("pad", False)  # this parameter has no meaning on a single card; | ||||
# whether we are inside an iteration; when True, set_distributed() and load_state_dict() must not be called | # whether we are inside an iteration; when True, set_distributed() and load_state_dict() must not be called | ||||
self._during_iter = kwargs.get("_during_iter", False) | |||||
self.during_iter = kwargs.get("during_iter", False) | |||||
def __len__(self): | def __len__(self): | ||||
""" | """ | ||||
@@ -84,9 +91,9 @@ class RandomSampler(ReproducibleIterator): | |||||
>>> next(iter2) # 当前num_consumed_samples的数量会发生变化 | >>> next(iter2) # 当前num_consumed_samples的数量会发生变化 | ||||
""" | """ | ||||
if self._during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 | |||||
if self.during_iter:  # if during_iter is True, the previous iteration has not finished; force a re-initialization | |||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
self._during_iter = True | |||||
self.during_iter = True | |||||
indices = self.generate_indices() | indices = self.generate_indices() | ||||
if self.pad: | if self.pad: | ||||
@@ -110,7 +117,7 @@ class RandomSampler(ReproducibleIterator): | |||||
for index in indices: | for index in indices: | ||||
self.num_consumed_samples += self.num_replicas | self.num_consumed_samples += self.num_replicas | ||||
yield index | yield index | ||||
self._during_iter = False | |||||
self.during_iter = False | |||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
def generate_indices(self) -> List[int]: | def generate_indices(self) -> List[int]: | ||||
@@ -142,8 +149,8 @@ class RandomSampler(ReproducibleIterator): | |||||
return states | return states | ||||
def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
# 如果 self._during_iter 是 True,那么 data_idx 一定是 0; | |||||
assert self._during_iter is False, "Cannot call load_state_dict() when it is " \ | |||||
# load_state_dict may not be called in the middle of an iteration; | |||||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||||
"during an unfinished iteration." | "during an unfinished iteration." | ||||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | ||||
@@ -157,6 +164,9 @@ class RandomSampler(ReproducibleIterator): | |||||
self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
if self.num_consumed_samples>=length:  # if the checkpoint was saved after the last sample was reached, simply reset to 0 | if self.num_consumed_samples>=length:  # if the checkpoint was saved after the last sample was reached, simply reset to 0 | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
if self.shuffle != states['shuffle']: | |||||
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | |||||
f"we use shuffle={states['shuffle']}") | |||||
self.shuffle = states["shuffle"] | self.shuffle = states["shuffle"] | ||||
def set_epoch(self, epoch: int) -> None: | def set_epoch(self, epoch: int) -> None: | ||||
@@ -173,7 +183,7 @@ class RandomSampler(ReproducibleIterator): | |||||
:return: | :return: | ||||
""" | """ | ||||
assert self._during_iter is False, "Cannot set the sampler to be distributed when it is " \ | |||||
assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \ | |||||
"during an unfinished iteration." | "during an unfinished iteration." | ||||
assert num_replicas>0 and isinstance(num_replicas, int) | assert num_replicas>0 and isinstance(num_replicas, int) | ||||
assert isinstance(rank, int) and 0<=rank<num_replicas | assert isinstance(rank, int) and 0<=rank<num_replicas | ||||
@@ -196,7 +206,7 @@ class RandomSampler(ReproducibleIterator): | |||||
@property | @property | ||||
def num_left_samples(self): | def num_left_samples(self): | ||||
""" | """ | ||||
返回当前 iteration 还有多少个 sample 结束 | |||||
Returns how many samples are left in the current iteration, counted for the current rank | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -205,110 +215,8 @@ class RandomSampler(ReproducibleIterator): | |||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | ||||
class ReproducibleBatchSampler: | |||||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | |||||
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | |||||
self.batch_sampler = batch_sampler | |||||
self.batch_size = batch_size | |||||
self.drop_last = drop_last | |||||
self.data_idx = kwargs.get("data_idx", 0) | |||||
self._index_list = kwargs.get("_index_list", self._iterate_sampler()) | |||||
self.need_reinitialize = kwargs.get("need_reinitialize", False) | |||||
def _iterate_sampler(self): | |||||
_index_lst = [] | |||||
for idx in self.batch_sampler: | |||||
if isinstance(idx, list): | |||||
_index_lst.extend(idx) | |||||
# 说明是在初始化时传入的是一个 sampler,理论上对应于 dataloader 在初始化时没有 batch_size,也没有 batch_sampler 的情况; | |||||
else: | |||||
_index_lst.append(idx) | |||||
# 64 位机器的 unsigned int 为 4 个字节,能表示的最大大小为 4294967295; | |||||
if len(_index_lst) > 4294967295: | |||||
# 注意 self._index_list 内存放的是全部数据的 index; | |||||
# unsigned long | |||||
_index_lst = array("L", _index_lst) | |||||
else: | |||||
# unsigned int | |||||
_index_lst = array("I", _index_lst) | |||||
return _index_lst | |||||
def __iter__(self): | |||||
if self.need_reinitialize: | |||||
self._index_list = self._iterate_sampler() | |||||
self.data_idx = 0 | |||||
else: | |||||
self.need_reinitialize = True | |||||
batch = [] | |||||
if self.data_idx: | |||||
index_list = self._index_list[self.data_idx:] | |||||
else: | |||||
index_list = self._index_list | |||||
for idx in index_list: | |||||
batch.append(idx) | |||||
self.data_idx += 1 | |||||
if len(batch) == self.batch_size: | |||||
yield batch | |||||
batch = [] | |||||
if len(batch) > 0 and not self.drop_last: | |||||
yield batch | |||||
def __len__(self) -> int: | |||||
if self.drop_last: | |||||
return len(self._index_list) // self.batch_size | |||||
else: | |||||
return (len(self._index_list) + self.batch_size - 1) // self.batch_size | |||||
def state_dict(self) -> Dict: | |||||
return {"index_list": deepcopy(self._index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__} | |||||
def load_state_dict(self, states: Dict): | |||||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | |||||
f"we cannot use {self.__class__.__name__} to load it." | |||||
_index_list = states["index_list"] | |||||
assert len(_index_list) == len(self._index_list), "The number of samples is different between the checkpoint " \ | |||||
"record and current dataset." | |||||
self._index_list = _index_list | |||||
self.data_idx = states["data_idx"] | |||||
self.need_reinitialize = False | |||||
def set_distributed(self): | |||||
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.") | |||||
def set_epoch(self, epoch): | |||||
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch): | |||||
self.batch_sampler.sampler.set_epoch(epoch) | |||||
@property | |||||
def batch_idx_in_epoch(self): | |||||
if self.drop_last: | |||||
return len(self._index_list) // self.batch_size - (len(self._index_list) - self.data_idx) // self.batch_size | |||||
else: | |||||
return (len(self._index_list) + self.batch_size - 1) // self.batch_size - \ | |||||
(len(self._index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | |||||
# todo | |||||
# class SortedSampler(ReproducibleIterator): | |||||
# def __init__(self, dataset, key): | |||||
# pass | |||||
# | |||||
# | |||||
# class BucketedSampler(ReproducibleIterator): | |||||
# def __init__(self, dataset, key): | |||||
# pass | |||||
if __name__ == "__main__": | |||||
sampler = RandomSampler(1) | |||||
print(vars(sampler)) | |||||
batch_sampler = ReproducibleBatchSampler(list(range(3)), 1, True) | |||||
print(vars(batch_sampler)) | |||||
@@ -9,10 +9,11 @@ __all__ = [ | |||||
] | ] | ||||
import os | import os | ||||
import re | |||||
from typing import Any, Optional, Union | from typing import Any, Optional, Union | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | |||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_BACKEND_LAUNCH | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
@@ -42,10 +43,19 @@ def get_paddle_device_id(device: Union[str, int]): | |||||
if isinstance(device, int): | if isinstance(device, int): | ||||
return device | return device | ||||
device = device.lower() | |||||
if device == "cpu": | if device == "cpu": | ||||
raise ValueError("Cannot get device id from `cpu`.") | raise ValueError("Cannot get device id from `cpu`.") | ||||
return paddle.device._convert_to_place(device).get_device_id() | |||||
match_res = re.match(r"gpu:\d+$", device) | |||||
if not match_res: | |||||
raise ValueError( | |||||
"The device must be an int or a string like 'gpu:x' to get a device id from it." | |||||
) | |||||
device_id = device.split(':', 1)[1] | |||||
device_id = int(device_id) | |||||
return device_id | |||||
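# e.g. get_paddle_device_id("gpu:1") -> 1, get_paddle_device_id(3) -> 3,
#      get_paddle_device_id("cpu")   -> ValueError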
def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, | def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, | ||||
data_device: Optional[str] = None) -> Any: | data_device: Optional[str] = None) -> Any: | ||||
@@ -84,6 +94,4 @@ def is_in_paddle_launch_dist(): | |||||
""" | """ | ||||
Determine whether we are inside a distributed process started via launch | Determine whether we are inside a distributed process started via launch | ||||
""" | """ | ||||
return 'PADDLE_RANK_IN_NODE' in os.environ and \ | |||||
'FLAGS_selected_gpus' in os.environ and \ | |||||
FASTNLP_DISTRIBUTED_CHECK not in os.environ | |||||
return FASTNLP_BACKEND_LAUNCH in os.environ |
@@ -52,21 +52,33 @@ def _set_backend(): | |||||
if backend == 'paddle': | if backend == 'paddle': | ||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | ||||
assert 'paddle' not in sys.modules, "You have to use `set_backend()` before `import paddle`." | assert 'paddle' not in sys.modules, "You have to use `set_backend()` before `import paddle`." | ||||
if 'CUDA_VISIBLE_DEVICES' not in os.environ and 'PADDLE_RANK_IN_NODE' not in os.environ \ | |||||
and 'FLAGS_selected_gpus' not in os.environ: | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |||||
os.environ[USER_CUDA_VISIBLE_DEVICES] = '' | |||||
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | |||||
if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ: | |||||
# in a distributed subprocess, derive the devices this process really owns from USER_CUDA_VISIBLE_DEVICES | |||||
selected_gpus = os.environ['FLAGS_selected_gpus'].split(',') | |||||
if user_visible_devices is not None and user_visible_devices != "": | |||||
# the user launched distributed training with CUDA_VISIBLE_DEVICES set; | |||||
# after set_backend, the user's setting is preserved in USER_CUDA_VISIBLE_DEVICES, | |||||
# from which we recover the device ids that are actually in use | |||||
user_visible_devices = user_visible_devices.split(",") | |||||
selected_gpus = [user_visible_devices[int(i)] for i in selected_gpus] | |||||
else: | |||||
# set USER_CUDA_VISIBLE_DEVICES to indicate that, from the user's point of view, all devices are visible | |||||
os.environ[USER_CUDA_VISIBLE_DEVICES] = "" | |||||
# TODO the [0] here may be wrong for multiple cards on a single node | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0] | |||||
os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))]) | |||||
os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))]) | |||||
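# e.g. with made-up values: the user exports CUDA_VISIBLE_DEVICES=3,4,5 and the paddle launcher
# passes FLAGS_selected_gpus=0,1; the subprocess then keeps USER_CUDA_VISIBLE_DEVICES=3,4,5,
# pins CUDA_VISIBLE_DEVICES to the real device 3 (see the TODO above) and renumbers
# FLAGS_selected_gpus to 0,1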
elif 'CUDA_VISIBLE_DEVICES' in os.environ: | elif 'CUDA_VISIBLE_DEVICES' in os.environ: | ||||
# in the main process the user has set CUDA_VISIBLE_DEVICES; | |||||
# hack the user's CUDA_VISIBLE_DEVICES setting down to a single device | |||||
CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES'] | CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES'] | ||||
os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES | os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES | ||||
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0] | os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0] | ||||
elif 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ: | |||||
# TODO 这里由于fastNLP需要hack CUDA_VISIBLE_DEVICES,因此需要相应滴修改FLAGS等paddle变量 @xsh | |||||
CUDA_VISIBLE_DEVICES = os.environ['FLAGS_selected_gpus'] | |||||
os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0] | |||||
os.environ['FLAGS_selected_gpus'] = "0" | |||||
os.environ['FLAGS_selected_accelerators'] = "0" | |||||
else: | |||||
# if nothing is set, restrict the process to one card so multiple processes do not occupy the other cards | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |||||
os.environ[USER_CUDA_VISIBLE_DEVICES] = '' | |||||
elif backend == 'jittor': | elif backend == 'jittor': | ||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | ||||
@@ -15,7 +15,7 @@ def remove_local_rank_in_argv(): | |||||
""" | """ | ||||
index = -1 | index = -1 | ||||
for i, v in enumerate(sys.argv): | for i, v in enumerate(sys.argv): | ||||
if v.startswith('--rank='): | |||||
if v.startswith('--local_rank='): | |||||
os.environ['LOCAL_RANK'] = v.split('=')[1] | os.environ['LOCAL_RANK'] = v.split('=')[1] | ||||
index = i | index = i | ||||
break | break | ||||
@@ -36,8 +36,14 @@ def set_env_on_import_torch(): | |||||
# TODO paddle may need set this | # TODO paddle may need set this | ||||
def set_env_on_import_paddle(): | def set_env_on_import_paddle(): | ||||
# todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH | |||||
pass | |||||
# todo: set FASTNLP_GLOBAL_RANK and FASTNLP_BACKEND_LAUNCH | |||||
if "PADDLE_TRAINERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \ | |||||
and "PADDLE_RANK_IN_NODE" in os.environ: | |||||
# environment variables of a distributed setup detected | |||||
os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"] | |||||
# if the process was not launched by fastNLP itself | |||||
if FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
os.environ[FASTNLP_BACKEND_LAUNCH] = "1" | |||||
# TODO jittor may need set this | # TODO jittor may need set this | ||||
def set_env_on_import_jittor(): | def set_env_on_import_jittor(): | ||||
@@ -3,4 +3,4 @@ prettytable>=0.7.2 | |||||
requests | requests | ||||
regex!=2019.12.17 | regex!=2019.12.17 | ||||
rich==11.2.0 | rich==11.2.0 | ||||
# fsspec[http]>=2021.05.0, !=2021.06.0 | |||||
packaging |
@@ -316,7 +316,7 @@ def test_model_checkpoint_callback_2( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -466,7 +466,7 @@ def test_trainer_checkpoint_callback_1( | |||||
# test saving and loading of huggingface transformers models via hand-written model_save_fn and model_load_fn; | # test saving and loading of huggingface transformers models via hand-written model_save_fn and model_load_fn; | ||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_checkpoint_callback_2( | def test_trainer_checkpoint_callback_2( | ||||
@@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te | |||||
import argparse | import argparse | ||||
import os | import os | ||||
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5" | |||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" | |||||
import sys | import sys | ||||
path = os.path.abspath(__file__) | path = os.path.abspath(__file__) | ||||
@@ -101,7 +101,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
) | ) | ||||
trainer.run() | trainer.run() | ||||
dist.barrier() | |||||
# dist.barrier() | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
@@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te | |||||
import argparse | import argparse | ||||
import os | import os | ||||
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5" | |||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" | |||||
import sys | import sys | ||||
path = os.path.abspath(__file__) | path = os.path.abspath(__file__) | ||||
@@ -0,0 +1,151 @@ | |||||
import pytest | |||||
import os | |||||
from typing import Any | |||||
from dataclasses import dataclass | |||||
from paddle.optimizer import Adam | |||||
from paddle.io import DataLoader | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.metrics.accuracy import Accuracy | |||||
from fastNLP.core.callbacks.progress_callback import RichCallback | |||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||||
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST | |||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
@dataclass | |||||
class MNISTTrainPaddleConfig: | |||||
num_labels: int = 10 | |||||
feature_dimension: int = 784 | |||||
batch_size: int = 32 | |||||
shuffle: bool = True | |||||
validate_every = -5 | |||||
driver: str = "paddle" | |||||
device = "gpu" | |||||
@dataclass | |||||
class MNISTTrainFleetConfig: | |||||
num_labels: int = 10 | |||||
feature_dimension: int = 784 | |||||
batch_size: int = 32 | |||||
shuffle: bool = True | |||||
validate_every = -5 | |||||
@dataclass | |||||
class TrainerParameters: | |||||
model: Any = None | |||||
optimizers: Any = None | |||||
train_dataloader: Any = None | |||||
validate_dataloaders: Any = None | |||||
input_mapping: Any = None | |||||
output_mapping: Any = None | |||||
metrics: Any = None | |||||
# @pytest.fixture(params=[0], autouse=True) | |||||
# def model_and_optimizers(request): | |||||
# """ | |||||
# initialize the model and optimizer for single-card mode | |||||
# """ | |||||
# trainer_params = TrainerParameters() | |||||
# print(paddle.device.get_device()) | |||||
# if request.param == 0: | |||||
# trainer_params.model = PaddleNormalModel_Classification( | |||||
# num_labels=MNISTTrainPaddleConfig.num_labels, | |||||
# feature_dimension=MNISTTrainPaddleConfig.feature_dimension | |||||
# ) | |||||
# trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) | |||||
# train_dataloader = DataLoader( | |||||
# dataset=PaddleDataset_MNIST("train"), | |||||
# batch_size=MNISTTrainPaddleConfig.batch_size, | |||||
# shuffle=True | |||||
# ) | |||||
# val_dataloader = DataLoader( | |||||
# dataset=PaddleDataset_MNIST(mode="test"), | |||||
# batch_size=MNISTTrainPaddleConfig.batch_size, | |||||
# shuffle=True | |||||
# ) | |||||
# trainer_params.train_dataloader = train_dataloader | |||||
# trainer_params.validate_dataloaders = val_dataloader | |||||
# trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every | |||||
# trainer_params.metrics = {"acc": Accuracy()} | |||||
# return trainer_params | |||||
@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)]) | |||||
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | |||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.7, larger_better=True), | |||||
RichCallback(5), RecordLossCallback(loss_threshold=0.3)]]) | |||||
@magic_argv_env_context | |||||
def test_trainer_paddle( | |||||
# model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
callbacks, | |||||
n_epochs=15, | |||||
): | |||||
trainer_params = TrainerParameters() | |||||
trainer_params.model = PaddleNormalModel_Classification( | |||||
num_labels=MNISTTrainPaddleConfig.num_labels, | |||||
feature_dimension=MNISTTrainPaddleConfig.feature_dimension | |||||
) | |||||
trainer_params.optimizers = Adam(parameters=trainer_params.model.parameters(), learning_rate=0.0001) | |||||
train_dataloader = DataLoader( | |||||
dataset=PaddleDataset_MNIST("train"), | |||||
batch_size=MNISTTrainPaddleConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
val_dataloader = DataLoader( | |||||
dataset=PaddleDataset_MNIST(mode="test"), | |||||
batch_size=MNISTTrainPaddleConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
trainer_params.train_dataloader = train_dataloader | |||||
trainer_params.validate_dataloaders = val_dataloader | |||||
trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every | |||||
trainer_params.metrics = {"acc": Accuracy(backend="paddle")} | |||||
if not isinstance(device, (int, str)) and len(device) > 1 and FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
with pytest.raises(SystemExit) as exc: | |||||
trainer = Trainer( | |||||
model=trainer_params.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=trainer_params.optimizers, | |||||
train_dataloader=trainer_params.train_dataloader, | |||||
validate_dataloaders=trainer_params.validate_dataloaders, | |||||
validate_every=trainer_params.validate_every, | |||||
input_mapping=trainer_params.input_mapping, | |||||
output_mapping=trainer_params.output_mapping, | |||||
metrics=trainer_params.metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
) | |||||
assert exc.value.code == 0 | |||||
return | |||||
else: | |||||
trainer = Trainer( | |||||
model=trainer_params.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=trainer_params.optimizers, | |||||
train_dataloader=trainer_params.train_dataloader, | |||||
validate_dataloaders=trainer_params.validate_dataloaders, | |||||
validate_every=trainer_params.validate_every, | |||||
input_mapping=trainer_params.input_mapping, | |||||
output_mapping=trainer_params.output_mapping, | |||||
metrics=trainer_params.metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
) | |||||
trainer.run() |
@@ -132,7 +132,6 @@ def test_trainer_torch_with_evaluator( | |||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) | @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) | ||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)]]) | |||||
@pytest.mark.parametrize("fp16", [True, False]) | @pytest.mark.parametrize("fp16", [True, False]) | ||||
@pytest.mark.parametrize("accumulation_steps", [1, 3]) | @pytest.mark.parametrize("accumulation_steps", [1, 3]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -140,12 +139,11 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
device, | device, | ||||
callbacks, | |||||
fp16, | fp16, | ||||
accumulation_steps, | accumulation_steps, | ||||
n_epochs=6, | n_epochs=6, | ||||
): | ): | ||||
callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.3, larger_better=True)] | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver=driver, | driver=driver, | ||||
@@ -77,15 +77,14 @@ def model_and_optimizers(request): | |||||
# 测试一下 cpu; | # 测试一下 cpu; | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) | @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) | ||||
@pytest.mark.parametrize("callbacks", [[RecordLossCallback(loss_threshold=0.1)]]) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_without_evaluator( | def test_trainer_torch_without_evaluator( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
device, | device, | ||||
callbacks, | |||||
n_epochs=10, | n_epochs=10, | ||||
): | ): | ||||
callbacks = [RecordLossCallback(loss_threshold=0.1)] | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver=driver, | driver=driver, | ||||
@@ -108,8 +107,7 @@ def test_trainer_torch_without_evaluator( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])]) # ("torch", 4), | |||||
@pytest.mark.parametrize("callbacks", [[RecordLossCallback(loss_threshold=0.1)]]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [1, 2])]) # ("torch", 4), | |||||
@pytest.mark.parametrize("fp16", [False, True]) | @pytest.mark.parametrize("fp16", [False, True]) | ||||
@pytest.mark.parametrize("accumulation_steps", [1, 3]) | @pytest.mark.parametrize("accumulation_steps", [1, 3]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -117,11 +115,11 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( | |||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
device, | device, | ||||
callbacks, | |||||
fp16, | fp16, | ||||
accumulation_steps, | accumulation_steps, | ||||
n_epochs=10, | n_epochs=10, | ||||
): | ): | ||||
callbacks = [RecordLossCallback(loss_threshold=0.1)] | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver=driver, | driver=driver, | ||||
@@ -148,7 +146,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( | |||||
# 测试 accumulation_steps; | # 测试 accumulation_steps; | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4), ("torch", [4, 5])]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [1, 2])]) | |||||
@pytest.mark.parametrize("accumulation_steps", [1, 3]) | @pytest.mark.parametrize("accumulation_steps", [1, 3]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_without_evaluator_accumulation_steps( | def test_trainer_torch_without_evaluator_accumulation_steps( | ||||
@@ -181,7 +179,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.parametrize("driver,device", [("torch", [6, 7])]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | |||||
@pytest.mark.parametrize("output_from_new_proc", ["all", "ignore", "only_error", "test_log"]) | @pytest.mark.parametrize("output_from_new_proc", ["all", "ignore", "only_error", "test_log"]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_output_from_new_proc( | def test_trainer_output_from_new_proc( | ||||
@@ -244,7 +242,7 @@ def test_trainer_output_from_new_proc( | |||||
synchronize_safe_rm(path) | synchronize_safe_rm(path) | ||||
@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | |||||
@pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3 | @pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3 | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_on_exception( | def test_trainer_on_exception( | ||||
@@ -1,12 +1,9 @@ | |||||
import pytest | import pytest | ||||
import sys | |||||
import os | import os | ||||
import numpy as np | import numpy as np | ||||
from fastNLP.envs.set_backend import set_env | |||||
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | ||||
set_env_on_import_paddle() | set_env_on_import_paddle() | ||||
set_env("paddle") | |||||
import paddle | import paddle | ||||
import paddle.distributed as dist | import paddle.distributed as dist | ||||
from paddle.io import DataLoader | from paddle.io import DataLoader | ||||
@@ -54,6 +51,7 @@ def test_move_data_to_device(): | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_is_distributed(): | def test_is_distributed(): | ||||
print(os.getenv("CUDA_VISIBLE_DEVICES")) | print(os.getenv("CUDA_VISIBLE_DEVICES")) | ||||
@@ -64,6 +62,7 @@ def test_is_distributed(): | |||||
driver = PaddleFleetDriver( | driver = PaddleFleetDriver( | ||||
model=paddle_model, | model=paddle_model, | ||||
parallel_device=[0,1], | parallel_device=[0,1], | ||||
output_from_new_proc='all' | |||||
) | ) | ||||
driver.set_optimizers(paddle_opt) | driver.set_optimizers(paddle_opt) | ||||
# distinguish between launch time and subprocess setup time | # distinguish between launch time and subprocess setup time | ||||
@@ -79,6 +78,7 @@ def test_is_distributed(): | |||||
synchronize_safe_rm("log") | synchronize_safe_rm("log") | ||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_get_no_sync_context(): | def test_get_no_sync_context(): | ||||
""" | """ | ||||
@@ -105,6 +105,7 @@ def test_get_no_sync_context(): | |||||
synchronize_safe_rm("log") | synchronize_safe_rm("log") | ||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_is_global_zero(): | def test_is_global_zero(): | ||||
try: | try: | ||||
@@ -128,6 +129,8 @@ def test_is_global_zero(): | |||||
synchronize_safe_rm("log") | synchronize_safe_rm("log") | ||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_unwrap_model(): | def test_unwrap_model(): | ||||
try: | try: | ||||
@@ -204,7 +207,7 @@ def test_replace_sampler(dist_sampler, reproducible): | |||||
else: | else: | ||||
driver.setup() | driver.setup() | ||||
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True) | dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True) | ||||
driver.replace_sampler(dataloader, dist_sampler, reproducible) | |||||
driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) | |||||
finally: | finally: | ||||
synchronize_safe_rm("log") | synchronize_safe_rm("log") | ||||
dist.barrier() | dist.barrier() | ||||
@@ -243,7 +246,7 @@ class SingleMachineMultiGPUTrainingTestCase: | |||||
parallel_device=gpus, | parallel_device=gpus, | ||||
) | ) | ||||
driver.set_optimizers(paddle_opt) | driver.set_optimizers(paddle_opt) | ||||
dataloader = driver.replace_sampler(dataloader) | |||||
dataloader = driver.set_dist_repro_dataloader(dataloader) | |||||
driver.setup() | driver.setup() | ||||
# check model_device | # check model_device | ||||
self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}") | self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}") | ||||
@@ -1,17 +1,11 @@ | |||||
import unittest | import unittest | ||||
import torch | import torch | ||||
from fastNLP.envs.set_env import set_env | |||||
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle | |||||
set_env_on_import_paddle() | |||||
set_env("paddle") | |||||
from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver | |||||
import paddle | import paddle | ||||
from paddle.io import Dataset, DataLoader | from paddle.io import Dataset, DataLoader | ||||
from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver | |||||
class Net(paddle.nn.Layer): | class Net(paddle.nn.Layer): | ||||
def __init__(self): | def __init__(self): | ||||
super(Net, self).__init__() | super(Net, self).__init__() | ||||
@@ -9,7 +9,8 @@ import paddle | |||||
from paddle.io import DataLoader, BatchSampler | from paddle.io import DataLoader, BatchSampler | ||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | ||||
from fastNLP.core.samplers.reproducible_sampler import ReproducibleBatchSampler, RandomSampler | |||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | ||||
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset | from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset | ||||
from fastNLP.core import synchronize_safe_rm | from fastNLP.core import synchronize_safe_rm | ||||
@@ -164,4 +165,4 @@ class TestSingleDeviceFunction: | |||||
""" | """ | ||||
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True) | dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True) | ||||
res = self.driver.replace_sampler(dataloader, dist_sampler, reproducible) | |||||
res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) |
@@ -33,11 +33,15 @@ def check_replace_sampler(driver): | |||||
# dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / ReproducibleBatchSampler | # dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / ReproducibleBatchSampler | ||||
# reproducible is either True or False | # reproducible is either True or False | ||||
# need to check that both the returned sampler and the dataloader are new objects | |||||
assert driver.is_distributed() is False, "This test only for non distributed sampler." | assert driver.is_distributed() is False, "This test only for non distributed sampler." | ||||
ds = SequenceDataSet(10) | ds = SequenceDataSet(10) | ||||
dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True) | dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True) | ||||
dl1 = driver.replace_sampler(dataloader, dist_sampler='dist', reproducible=True) | |||||
dl1 = driver.set_dist_repro_dataloader(dataloader, dist='dist', reproducible=True) | |||||
assert not (dl1.sampler is dataloader.sampler), "The sampler should not be the same one." | |||||
assert not (dl1 is dataloader), "The dataloader should not be the same one." | |||||
# iterate two batches | # iterate two batches | ||||
already_seen_idx = set() | already_seen_idx = set() | ||||
@@ -68,6 +72,23 @@ def check_replace_sampler(driver): | |||||
assert b not in already_seen_idx | assert b not in already_seen_idx | ||||
assert b in left_idxes | assert b in left_idxes | ||||
# need to check that replacing with unrepeatdist works: (1) no extra padding; (2) no overlap between cards | |||||
ds = SequenceDataSet(11) | |||||
dataloader = DataLoader(dataset=ds, batch_size=2, collate_fn=lambda x:x, shuffle=True) | |||||
dl1 = driver.set_dist_repro_dataloader(dataloader, dist='unrepeatdist', reproducible=True) | |||||
world_size = 3 | |||||
indices = [] | |||||
for i in range(world_size): | |||||
dl1.sampler.set_distributed(num_replicas=world_size, rank=i) | |||||
for batch in dl1: | |||||
indices.extend(batch) | |||||
assert len(indices)==len(ds) # there should be no duplicates at all | |||||
assert len(set(indices))==len(indices) # every index should be distinct |
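To make the 'unrepeatdist' contract checked above concrete, here is a tiny standalone illustration (a hypothetical helper, not the fastNLP API) of an unpadded, non-overlapping split of 11 samples over 3 ranks:

def split_unrepeated(num_samples: int, world_size: int):
    # round-robin shards: every index lands on exactly one rank, nothing is padded
    return [list(range(rank, num_samples, world_size)) for rank in range(world_size)]

shards = split_unrepeated(11, 3)               # shard sizes 4, 4, 3 -- deliberately unequal
flat = [i for shard in shards for i in shard]
assert sorted(flat) == list(range(11))         # the dataset is covered exactly once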
@@ -0,0 +1,300 @@ | |||||
import os | |||||
import tempfile | |||||
import datetime | |||||
from pathlib import Path | |||||
import logging | |||||
import re | |||||
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP.core import synchronize_safe_rm | |||||
# test TorchDDPDriver; | |||||
@magic_argv_env_context | |||||
def test_add_file_ddp_1(): | |||||
""" | |||||
test the case where path points to a file whose parent folder already exists; | |||||
naming the file after the current time has a serious bug with multiple cards: the processes start with a small time offset, so each of them would write to its own separate log file; | |||||
""" | |||||
import torch | |||||
import torch.distributed as dist | |||||
from fastNLP.core.log.logger import logger | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
driver = TorchDDPDriver( | |||||
model=model, | |||||
parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], | |||||
output_from_new_proc="all" | |||||
) | |||||
driver.setup() | |||||
msg = 'some test log msg' | |||||
path = Path.cwd() | |||||
filepath = path.joinpath('log.txt') | |||||
handler = logger.add_file(filepath, mode="w") | |||||
logger.info(msg) | |||||
logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") | |||||
for h in logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
h.flush() | |||||
dist.barrier() | |||||
with open(filepath, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert msg in line | |||||
assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line | |||||
pattern = re.compile(msg) | |||||
assert len(pattern.findall(line)) == 1 | |||||
synchronize_safe_rm(filepath) | |||||
dist.barrier() | |||||
dist.destroy_process_group() | |||||
logger.removeHandler(handler) | |||||
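The docstring above flags a real race: every rank derives the log file name from its own start time. A minimal sketch of one mitigation, assuming the timestamp is fixed once in the launcher before workers are spawned so all ranks inherit the same value (a hypothetical helper, not the code this PR ships):

import datetime
import os

LAUNCH_TIME_KEY = "FASTNLP_LAUNCH_TIME"  # the key test_add_file_ddp_3 below reads back

def fix_launch_time_before_spawn():
    # Run in the launcher before forking workers: children inherit the
    # environment, so every rank resolves to one shared log file name.
    if LAUNCH_TIME_KEY not in os.environ:
        os.environ[LAUNCH_TIME_KEY] = datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S')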
@magic_argv_env_context | |||||
def test_add_file_ddp_2(): | |||||
""" | |||||
test the case where path points to a file whose parent folder does not exist; | |||||
""" | |||||
import torch | |||||
import torch.distributed as dist | |||||
from fastNLP.core.log.logger import logger | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
driver = TorchDDPDriver( | |||||
model=model, | |||||
parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], | |||||
output_from_new_proc="all" | |||||
) | |||||
driver.setup() | |||||
msg = 'some test log msg' | |||||
origin_path = Path.cwd() | |||||
try: | |||||
path = origin_path.joinpath("not_existed") | |||||
filepath = path.joinpath('log.txt') | |||||
handler = logger.add_file(filepath) | |||||
logger.info(msg) | |||||
logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") | |||||
for h in logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
h.flush() | |||||
dist.barrier() | |||||
with open(filepath, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert msg in line | |||||
assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line | |||||
pattern = re.compile(msg) | |||||
assert len(pattern.findall(line)) == 1 | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
logger.removeHandler(handler) | |||||
dist.barrier() | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | |||||
def test_add_file_ddp_3(): | |||||
""" | |||||
path = None; | |||||
naming the file after the current time has a serious bug with multiple cards: the processes start with a small time offset, so each of them would write to its own separate log file; | |||||
""" | |||||
import torch | |||||
import torch.distributed as dist | |||||
from fastNLP.core.log.logger import logger | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
driver = TorchDDPDriver( | |||||
model=model, | |||||
parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], | |||||
output_from_new_proc="all" | |||||
) | |||||
driver.setup() | |||||
msg = 'some test log msg' | |||||
handler = logger.add_file() | |||||
logger.info(msg) | |||||
logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") | |||||
for h in logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
h.flush() | |||||
dist.barrier() | |||||
file = Path.cwd().joinpath(os.environ.get(FASTNLP_LAUNCH_TIME)+".log") | |||||
with open(file, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
# print(f"\nrank: {driver.get_local_rank()} line, {line}\n") | |||||
assert msg in line | |||||
assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line | |||||
pattern = re.compile(msg) | |||||
assert len(pattern.findall(line)) == 1 | |||||
synchronize_safe_rm(file) | |||||
dist.barrier() | |||||
dist.destroy_process_group() | |||||
logger.removeHandler(handler) | |||||
@magic_argv_env_context | |||||
def test_add_file_ddp_4(): | |||||
""" | |||||
test the case where path is a folder; | |||||
""" | |||||
import torch | |||||
import torch.distributed as dist | |||||
from fastNLP.core.log.logger import logger | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
driver = TorchDDPDriver( | |||||
model=model, | |||||
parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], | |||||
output_from_new_proc="all" | |||||
) | |||||
driver.setup() | |||||
msg = 'some test log msg' | |||||
path = Path.cwd().joinpath("not_existed") | |||||
try: | |||||
handler = logger.add_file(path) | |||||
logger.info(msg) | |||||
logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") | |||||
for h in logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
h.flush() | |||||
dist.barrier() | |||||
file = path.joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + ".log") | |||||
with open(file, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert msg in line | |||||
assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line | |||||
pattern = re.compile(msg) | |||||
assert len(pattern.findall(line)) == 1 | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
logger.removeHandler(handler) | |||||
dist.barrier() | |||||
dist.destroy_process_group() | |||||
class TestLogger: | |||||
msg = 'some test log msg' | |||||
def test_add_file_1(self): | |||||
""" | |||||
test the case where path points to a file whose parent folder already exists; | |||||
""" | |||||
from fastNLP.core.log.logger import logger | |||||
path = Path(tempfile.mkdtemp()) | |||||
try: | |||||
filepath = path.joinpath('log.txt') | |||||
handler = logger.add_file(filepath) | |||||
logger.info(self.msg) | |||||
with open(filepath, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert self.msg in line | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
logger.removeHandler(handler) | |||||
def test_add_file_2(self): | |||||
""" | |||||
test the case where path points to a file whose parent folder does not exist; | |||||
""" | |||||
from fastNLP.core.log.logger import logger | |||||
origin_path = Path(tempfile.mkdtemp()) | |||||
try: | |||||
path = origin_path.joinpath("not_existed") | |||||
path = path.joinpath('log.txt') | |||||
handler = logger.add_file(path) | |||||
logger.info(self.msg) | |||||
with open(path, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert self.msg in line | |||||
finally: | |||||
synchronize_safe_rm(origin_path) | |||||
logger.removeHandler(handler) | |||||
def test_add_file_3(self): | |||||
""" | |||||
test the case where path is None; | |||||
""" | |||||
from fastNLP.core.log.logger import logger | |||||
handler = logger.add_file() | |||||
logger.info(self.msg) | |||||
path = Path.cwd() | |||||
cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d')) | |||||
for file in path.iterdir(): | |||||
if file.name.startswith(cur_datetime): | |||||
with open(file, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert self.msg in line | |||||
file.unlink() | |||||
logger.removeHandler(handler) | |||||
def test_add_file_4(self): | |||||
""" | |||||
test the case where path is a folder; | |||||
""" | |||||
from fastNLP.core.log.logger import logger | |||||
path = Path(tempfile.mkdtemp()) | |||||
try: | |||||
handler = logger.add_file(path) | |||||
logger.info(self.msg) | |||||
cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d')) | |||||
for file in path.iterdir(): | |||||
if file.name.startswith(cur_datetime): | |||||
with open(file, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert self.msg in line | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
logger.removeHandler(handler) | |||||
def test_stdout(self, capsys): | |||||
from fastNLP.core.log.logger import logger | |||||
handler = logger.set_stdout(stdout="raw") | |||||
logger.info(self.msg) | |||||
logger.debug('aabbc') | |||||
captured = capsys.readouterr() | |||||
assert "some test log msg\n" == captured.out | |||||
logger.removeHandler(handler) | |||||
@@ -0,0 +1,438 @@ | |||||
from array import array | |||||
import numpy as np | |||||
import pytest | |||||
from itertools import chain | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
class TestReproducibleBatchSampler: | |||||
# TODO split this up; each test here should check exactly one thing | |||||
def test_torch_dataloader_1(self): | |||||
import torch | |||||
from torch.utils.data import DataLoader | |||||
# no shuffle | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
forward_steps = 3 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
next(iter_dataloader) | |||||
# 1. save the state | |||||
_get_re_batchsampler = dataloader.batch_sampler | |||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | |||||
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | |||||
"sampler_type": "ReproducibleBatchSampler"} | |||||
# 2. resume from the checkpoint and rebuild the dataloader; | |||||
# keep batch_size unchanged; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
real_res = [] | |||||
supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) | |||||
forward_steps = 2 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
real_res.append(next(iter_dataloader)) | |||||
for i in range(forward_steps): | |||||
assert all(real_res[i] == supposed_res[i]) | |||||
# change batch_size; | |||||
after_batch_size = 3 | |||||
dataloader = DataLoader(dataset, batch_size=after_batch_size) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
real_res = [] | |||||
supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) | |||||
forward_steps = 2 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
real_res.append(next(iter_dataloader)) | |||||
for i in range(forward_steps): | |||||
assert all(real_res[i] == supposed_res[i]) | |||||
# check whether the second epoch after resuming is a complete dataloader; | |||||
# first finish the epoch in which training was resumed; | |||||
begin_idx = 27 | |||||
while True: | |||||
try: | |||||
data = next(iter_dataloader) | |||||
_batch_size = len(data) | |||||
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
begin_idx += _batch_size | |||||
except StopIteration: | |||||
break | |||||
# start a new epoch; | |||||
begin_idx = 0 | |||||
iter_dataloader = iter(dataloader) | |||||
while True: | |||||
try: | |||||
data = next(iter_dataloader) | |||||
_batch_size = len(data) | |||||
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
begin_idx += _batch_size | |||||
except StopIteration: | |||||
break | |||||
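The state dict asserted at the start of this test is deliberately small. A toy sketch of that checkpoint contract (shape assumed from the assertion above, not the library implementation):

from array import array

class SketchReproducibleState:
    # Persist the full epoch order plus a consumed-sample counter; resuming
    # replays index_list starting at position data_idx.
    def __init__(self, num_samples):
        self.index_list = array("I", range(num_samples))
        self.data_idx = 0

    def state_dict(self):
        return {"index_list": self.index_list, "data_idx": self.data_idx,
                "sampler_type": "ReproducibleBatchSampler"}

    def load_state_dict(self, state):
        self.index_list = state["index_list"]
        self.data_idx = state["data_idx"]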
def test_torch_dataloader_2(self): | |||||
# test that the new epoch's index list is regenerated instead of reusing the previous epoch's; | |||||
from torch.utils.data import DataLoader | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
# enable shuffle to verify that the second epoch's index list after resuming is regenerated; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# record all data from one epoch to check that recovery is correct; | |||||
all_supposed_data = [] | |||||
forward_steps = 3 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
# 1. save the state | |||||
_get_re_batchsampler = dataloader.batch_sampler | |||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | |||||
# 2. resume from the checkpoint and rebuild the dataloader; | |||||
# keep batch_size unchanged; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# re-create the iterator from the resumed dataloader, then finish the rest of this epoch; | |||||
iter_dataloader = iter(dataloader) | |||||
pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] | |||||
while True: | |||||
try: | |||||
all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
except StopIteration: | |||||
break | |||||
assert all_supposed_data == list(pre_index_list) | |||||
# start fresh epochs again; | |||||
for _ in range(3): | |||||
iter_dataloader = iter(dataloader) | |||||
res = [] | |||||
while True: | |||||
try: | |||||
res.append(next(iter_dataloader)) | |||||
except StopIteration: | |||||
break | |||||
def test_3(self): | |||||
import torch | |||||
from torch.utils.data import DataLoader | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
# partially consume one epoch, then check that a fresh iterator can still run to completion; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
for idx, data in enumerate(dataloader): | |||||
if idx > 3: | |||||
break | |||||
iterator = iter(dataloader) | |||||
for each in iterator: | |||||
pass | |||||
class DatasetWithVaryLength: | |||||
def __init__(self, num_of_data=100): | |||||
self.data = np.arange(num_of_data) | |||||
def __getitem__(self, item): | |||||
return self.data[item] | |||||
def __len__(self): | |||||
return len(self.data) | |||||
class TestBucketedBatchSampler: | |||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
@pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71]) | |||||
def test_single_num_batch(self, shuffle, drop_last, num): | |||||
# too few samples must not raise an error | |||||
for num in [2, 7, 14, 15, 70, 71]: | |||||
dataset = DatasetWithVaryLength(num_of_data=num) | |||||
before_batch_size = 7 | |||||
re_batchsampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, | |||||
num_batch_per_bucket=10, drop_last=drop_last, | |||||
shuffle=shuffle) | |||||
count = len(list(iter(re_batchsampler))) | |||||
if drop_last: | |||||
assert count==num//before_batch_size, num | |||||
else: | |||||
assert count==(num+before_batch_size-1)//before_batch_size, num | |||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
def test_single(self, shuffle, drop_last): | |||||
before_batch_size = 7 | |||||
num_batch_per_bucket = 4 # so the index spread inside any batch should stay within one bucket (batch_size * num_batch_per_bucket) | |||||
dataset = DatasetWithVaryLength(num_of_data=1000) | |||||
re_batchsampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last, | |||||
shuffle=shuffle) | |||||
re_batchsampler.set_epoch(0) | |||||
forward_steps = 10 | |||||
iterator = iter(re_batchsampler) | |||||
already_generate_indices = set() | |||||
for _ in range(forward_steps): | |||||
batch = next(iterator) | |||||
assert max(batch) - min(batch) <= before_batch_size * num_batch_per_bucket | |||||
already_generate_indices.update(batch) | |||||
# 1. save the state | |||||
state = re_batchsampler.state_dict() | |||||
# 2. resume from the checkpoint and continue training | |||||
re_batchsampler2 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last, | |||||
shuffle=shuffle) | |||||
re_batchsampler2.load_state_dict(state) | |||||
re_batchsampler2.set_epoch(0) | |||||
new_already_generate_indices = set() | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_generate_indices)] = 0 | |||||
indices = np.arange(len(dataset))[mask] | |||||
max_diff = -1 | |||||
for i in range(len(indices)-before_batch_size * num_batch_per_bucket): | |||||
max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i]) | |||||
for batch in re_batchsampler2: | |||||
assert max(batch) - min(batch) <= max_diff | |||||
for b in batch: | |||||
assert b not in already_generate_indices | |||||
new_already_generate_indices.update(batch) | |||||
if drop_last is False: | |||||
assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset) | |||||
# change batch_size; | |||||
after_batch_size = 3 | |||||
re_batchsampler3 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last, | |||||
shuffle=shuffle) | |||||
re_batchsampler3.load_state_dict(state) | |||||
re_batchsampler3.set_epoch(0) | |||||
count = 0 | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_generate_indices)] = 0 | |||||
indices = np.arange(len(dataset))[mask] | |||||
max_diff = -1 | |||||
for i in range(len(indices)-after_batch_size * num_batch_per_bucket): | |||||
max_diff = max(max_diff, indices[i+after_batch_size * num_batch_per_bucket]-indices[i]) | |||||
for batch in re_batchsampler3: | |||||
assert max(batch) - min(batch) <= max_diff | |||||
for b in batch: | |||||
assert b not in already_generate_indices | |||||
already_generate_indices.update(batch) | |||||
count += 1 | |||||
if count > 5: | |||||
break | |||||
# save again: saving is not allowed while the previous epoch has not finished sampling | |||||
after_batch_size = 5 | |||||
with pytest.raises(RuntimeError): | |||||
state = re_batchsampler3.state_dict() | |||||
for batch in re_batchsampler3: # consume everything so that saving becomes possible | |||||
pass | |||||
already_generate_indices = set() | |||||
count = 0 | |||||
for batch in re_batchsampler3: # start over | |||||
assert max(batch) - min(batch) <= max_diff | |||||
for b in batch: | |||||
assert b not in already_generate_indices | |||||
already_generate_indices.update(batch) | |||||
count += 1 | |||||
if count > 5: | |||||
break | |||||
state = re_batchsampler3.state_dict() | |||||
# drop_last is False here, so in the end every sample must be covered | |||||
re_batchsampler4 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, drop_last=False, | |||||
shuffle=shuffle) | |||||
re_batchsampler4.load_state_dict(state) | |||||
re_batchsampler4.set_epoch(0) | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_generate_indices)] = 0 | |||||
indices = np.arange(len(dataset))[mask] | |||||
max_diff = -1 | |||||
for i in range(len(indices) - after_batch_size * num_batch_per_bucket): | |||||
max_diff = max(max_diff, indices[i + after_batch_size * num_batch_per_bucket] - indices[i]) | |||||
for batch in re_batchsampler4: | |||||
assert max(batch) - min(batch) <= max_diff | |||||
for b in batch: | |||||
assert b not in already_generate_indices | |||||
already_generate_indices.update(batch) | |||||
assert len(already_generate_indices) == len(dataset) | |||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
def test_multi(self, shuffle, drop_last, pad): | |||||
# def test_multi(self, shuffle=True, drop_last=False, pad=False): | |||||
num_replica = 2 | |||||
dataset = DatasetWithVaryLength(num_of_data=1000) | |||||
batch_size = 5 | |||||
num_batch_per_bucket = 10 | |||||
lengths = [] | |||||
rank0_already_seen_indexes = None | |||||
max_diff = num_batch_per_bucket * batch_size * num_replica | |||||
for rank in range(num_replica): | |||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size = batch_size, | |||||
num_batch_per_bucket = num_batch_per_bucket, | |||||
shuffle = shuffle, drop_last=drop_last) | |||||
sampler.set_epoch(0) | |||||
sampler.set_distributed(num_replica, rank=rank, pad=pad) | |||||
lengths.append(len(sampler)) | |||||
already_seen_indexes = set() | |||||
repeat_count = 0 | |||||
for batch in sampler: | |||||
assert max_diff>=max(batch)-min(batch) | |||||
for b in batch: | |||||
repeat_count += int(b in already_seen_indexes) | |||||
if rank0_already_seen_indexes: # ranks must not overlap | |||||
assert b not in rank0_already_seen_indexes | |||||
already_seen_indexes.update(batch) | |||||
if rank0_already_seen_indexes is None: | |||||
rank0_already_seen_indexes = already_seen_indexes | |||||
if pad: # at most one repeated sample is allowed | |||||
assert repeat_count<=1 | |||||
else: | |||||
assert repeat_count==0 | |||||
assert len(set(lengths))==1, lengths # every process yields the same number of batches | |||||
# saving under multiple processes | |||||
already_seen_indexes = set() | |||||
for rank in range(num_replica): | |||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size = batch_size, | |||||
num_batch_per_bucket = num_batch_per_bucket, | |||||
shuffle = shuffle, drop_last=drop_last) | |||||
sampler.set_epoch(0) | |||||
sampler.set_distributed(num_replica, rank=rank, pad=pad) | |||||
lengths.append(len(sampler)) | |||||
count = 0 | |||||
for batch in sampler: | |||||
assert max_diff>=max(batch)-min(batch) | |||||
already_seen_indexes.update(batch) | |||||
if count>5: | |||||
break | |||||
count += 1 | |||||
state = sampler.state_dict() | |||||
# switch to a single machine | |||||
new_batch_size = 6 | |||||
num_batch_per_bucket = 3 | |||||
new_sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, | |||||
shuffle=shuffle, drop_last=drop_last) | |||||
new_sampler.load_state_dict(state) | |||||
repeat_count = 0 | |||||
new_already_seen_indexes = set(list(already_seen_indexes)) | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_seen_indexes)] = 0 | |||||
indices = np.arange(len(dataset))[mask] | |||||
max_diff = -1 | |||||
for i in range(len(indices)-new_batch_size * num_batch_per_bucket): | |||||
max_diff = max(max_diff, indices[i+new_batch_size * num_batch_per_bucket]-indices[i]) | |||||
for batch in new_sampler: | |||||
assert max_diff>=max(batch)-min(batch) | |||||
for b in batch: | |||||
repeat_count += int(b in new_already_seen_indexes) | |||||
new_already_seen_indexes.update(batch) | |||||
if pad: # at most one repeated sample is allowed | |||||
assert repeat_count <= 1 | |||||
else: | |||||
assert repeat_count == 0 | |||||
if drop_last is False: # without dropping, the counts must match | |||||
assert len(new_already_seen_indexes)==len(dataset) | |||||
# test changing the number of cards. | |||||
num_replica = 3 | |||||
new_sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, | |||||
shuffle=shuffle, drop_last=drop_last) | |||||
new_sampler.set_epoch(0) | |||||
new_sampler.load_state_dict(state) | |||||
new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad) | |||||
repeat_count = 0 | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_seen_indexes)] = 0 | |||||
indices = np.arange(len(dataset))[mask] | |||||
max_diff = -1 | |||||
for i in range(len(indices) - new_batch_size * num_batch_per_bucket*num_replica): | |||||
max_diff = max(max_diff, indices[i + new_batch_size * num_batch_per_bucket*num_replica] - indices[i]) | |||||
for batch in new_sampler: | |||||
assert max_diff>=max(batch)-min(batch) | |||||
for b in batch: | |||||
repeat_count += int(b in already_seen_indexes) | |||||
if pad: # at most one repeated sample is allowed | |||||
assert repeat_count <= 1 | |||||
else: | |||||
assert repeat_count == 0 | |||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||||
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica): | |||||
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2): | |||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||||
batch_size = 6 | |||||
if num_replica*batch_size > num_samples: | |||||
return | |||||
num_batch_per_bucket = 10 | |||||
samplers = [] | |||||
lengths = [] | |||||
for i in range(num_replica): | |||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | |||||
sampler.set_distributed(num_replica, rank=i, pad=pad) | |||||
sampler.set_epoch(0) | |||||
samplers.append(sampler) | |||||
lengths.append(len(list(iter(sampler)))) | |||||
assert len(set(lengths))==1 | |||||
bucket_diff = batch_size * num_batch_per_bucket * num_replica | |||||
for bs in zip(*samplers): | |||||
diff = max(chain(*bs)) - min(chain(*bs)) | |||||
assert diff <= bucket_diff |
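The max(batch) - min(batch) bounds asserted throughout these tests follow from the bucketing scheme itself. A sketch of that scheme under the tests' setup, where a sample's index equals its length (an assumed reconstruction, not the library code):

import numpy as np

def bucketed_batches(lengths, batch_size, num_batch_per_bucket, seed=0):
    # Sort indices by length, cut the sorted order into buckets of
    # batch_size * num_batch_per_bucket consecutive samples, shuffle only
    # within each bucket, then emit fixed-size batches per bucket; any
    # batch therefore spans at most one bucket of sorted positions.
    rng = np.random.default_rng(seed)
    order = np.argsort(lengths)
    bucket = batch_size * num_batch_per_bucket
    batches = []
    for start in range(0, len(order), bucket):
        chunk = order[start:start + bucket].copy()
        rng.shuffle(chunk)
        batches.extend(chunk[i:i + batch_size] for i in range(0, len(chunk), batch_size))
    return batches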
@@ -6,7 +6,7 @@ import numpy as np | |||||
from functools import partial | from functools import partial | ||||
from array import array | from array import array | ||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler, ReproducibleBatchSampler | |||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
@@ -361,148 +361,3 @@ class TestRandomSampler(unittest.TestCase): | |||||
class TestReproducibleBatchSampler: | |||||
def test_torch_dataloader_1(self): | |||||
import torch | |||||
from torch.utils.data import DataLoader | |||||
# no shuffle | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
forward_steps = 3 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
next(iter_dataloader) | |||||
# 1. save the state | |||||
_get_re_batchsampler = dataloader.batch_sampler | |||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | |||||
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | |||||
"sampler_type": "ReproducibleBatchSampler"} | |||||
# 2. resume from the checkpoint and rebuild the dataloader; | |||||
# keep batch_size unchanged; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
real_res = [] | |||||
supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) | |||||
forward_steps = 2 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
real_res.append(next(iter_dataloader)) | |||||
for i in range(forward_steps): | |||||
assert all(real_res[i] == supposed_res[i]) | |||||
# change batch_size; | |||||
after_batch_size = 3 | |||||
dataloader = DataLoader(dataset, batch_size=after_batch_size) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
real_res = [] | |||||
supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) | |||||
forward_steps = 2 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
real_res.append(next(iter_dataloader)) | |||||
for i in range(forward_steps): | |||||
assert all(real_res[i] == supposed_res[i]) | |||||
# check whether the second epoch after resuming is a complete dataloader; | |||||
# first finish the epoch in which training was resumed; | |||||
begin_idx = 27 | |||||
while True: | |||||
try: | |||||
data = next(iter_dataloader) | |||||
_batch_size = len(data) | |||||
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
begin_idx += _batch_size | |||||
except StopIteration: | |||||
break | |||||
# start a new epoch; | |||||
begin_idx = 0 | |||||
iter_dataloader = iter(dataloader) | |||||
while True: | |||||
try: | |||||
data = next(iter_dataloader) | |||||
_batch_size = len(data) | |||||
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
begin_idx += _batch_size | |||||
except StopIteration: | |||||
break | |||||
def test_torch_dataloader_2(self): | |||||
# test that the new epoch's index list is regenerated instead of reusing the previous epoch's; | |||||
from torch.utils.data import DataLoader | |||||
# no shuffle | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
# enable shuffle to verify that the second epoch's index list after resuming is regenerated; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# record all data from one epoch to check that recovery is correct; | |||||
all_supposed_data = [] | |||||
forward_steps = 3 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
# 1. save the state | |||||
_get_re_batchsampler = dataloader.batch_sampler | |||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | |||||
# 2. resume from the checkpoint and rebuild the dataloader; | |||||
# keep batch_size unchanged; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# first finish the rest of this epoch; | |||||
pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] | |||||
while True: | |||||
try: | |||||
all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
except StopIteration: | |||||
break | |||||
assert all_supposed_data == list(pre_index_list) | |||||
# start fresh epochs again; | |||||
for _ in range(3): | |||||
iter_dataloader = iter(dataloader) | |||||
res = [] | |||||
while True: | |||||
try: | |||||
res.append(next(iter_dataloader)) | |||||
except StopIteration: | |||||
break | |||||
def test_3(self): | |||||
import torch | |||||
from torch.utils.data import DataLoader, RandomSampler, BatchSampler | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
# enable shuffle to verify that the second epoch's index list after resuming is regenerated; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
for idx, data in enumerate(dataloader): | |||||
if idx > 3: | |||||
break | |||||
iterator = iter(dataloader) | |||||
for each in iterator: | |||||
pass |
@@ -10,13 +10,6 @@ from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
class SamplerTest(unittest.TestCase): | class SamplerTest(unittest.TestCase): | ||||
def test_sequentialsampler(self): | def test_sequentialsampler(self): | ||||
@@ -15,40 +15,13 @@ class PaddleNormalDataset(Dataset): | |||||
return self._data[item] | return self._data[item] | ||||
class PaddleRandomDataset(Dataset): | |||||
def __init__(self, num_of_data=1000, features=64, labels=10): | |||||
self.num_of_data = num_of_data | |||||
self.x = [ | |||||
paddle.rand((features,)) | |||||
for i in range(num_of_data) | |||||
] | |||||
self.y = [ | |||||
paddle.rand((labels,)) | |||||
for i in range(num_of_data) | |||||
] | |||||
class PaddleRandomMaxDataset(Dataset): | |||||
def __init__(self, num_samples, num_features): | |||||
self.x = paddle.randn((num_samples, num_features)) | |||||
self.y = self.x.argmax(axis=-1) | |||||
def __len__(self): | def __len__(self): | ||||
return self.num_of_data | |||||
return len(self.x) | |||||
def __getitem__(self, item): | def __getitem__(self, item): | ||||
return {"x": self.x[item], "y": self.y[item]} | return {"x": self.x[item], "y": self.y[item]} | ||||
class PaddleDataset_MNIST(Dataset): | |||||
def __init__(self, mode="train"): | |||||
self.dataset = [ | |||||
( | |||||
np.array(img).astype('float32').reshape(-1), | |||||
label | |||||
) for img, label in paddle.vision.datasets.MNIST(mode=mode) | |||||
] | |||||
def __getitem__(self, idx): | |||||
return {"x": self.dataset[idx][0], "y": self.dataset[idx][1]} | |||||
def __len__(self): | |||||
return len(self.dataset) | |||||
@@ -1,12 +1,12 @@ | |||||
import paddle | import paddle | ||||
import paddle.nn as nn | import paddle.nn as nn | ||||
class PaddleNormalModel_Classification(paddle.nn.Layer): | |||||
class PaddleNormalModel_Classification_1(paddle.nn.Layer): | |||||
""" | """ | ||||
a basic paddle classification model | a basic paddle classification model | ||||
""" | """ | ||||
def __init__(self, num_labels, feature_dimension): | def __init__(self, num_labels, feature_dimension): | ||||
super(PaddleNormalModel_Classification, self).__init__() | |||||
super(PaddleNormalModel_Classification_1, self).__init__() | |||||
self.num_labels = num_labels | self.num_labels = num_labels | ||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | ||||
@@ -30,3 +30,26 @@ class PaddleNormalModel_Classification(paddle.nn.Layer): | |||||
x = self(x) | x = self(x) | ||||
return {"pred": x, "target": y.reshape((-1,))} | return {"pred": x, "target": y.reshape((-1,))} | ||||
class PaddleNormalModel_Classification_2(paddle.nn.Layer): | |||||
""" | |||||
a basic paddle classification model that only implements forward, for testing the scenario where the user initializes the distributed environment themselves | |||||
""" | |||||
def __init__(self, num_labels, feature_dimension): | |||||
super(PaddleNormalModel_Classification_2, self).__init__() | |||||
self.num_labels = num_labels | |||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | |||||
self.ac1 = nn.ReLU() | |||||
self.linear2 = nn.Linear(in_features=64, out_features=32) | |||||
self.ac2 = nn.ReLU() | |||||
self.output = nn.Linear(in_features=32, out_features=num_labels) | |||||
self.loss_fn = nn.CrossEntropyLoss() | |||||
def forward(self, x, y): | |||||
x = self.ac1(self.linear1(x)) | |||||
x = self.ac2(self.linear2(x)) | |||||
x = self.output(x) | |||||
loss = self.loss_fn(x, y) | |||||
return {"loss": loss, "pred": x, "target": y.reshape((-1,))} | |||||