@@ -10,7 +10,8 @@ __all__ = [ | |||||
'ProgressCallback', | 'ProgressCallback', | ||||
'RichCallback', | 'RichCallback', | ||||
"LRSchedCallback", | "LRSchedCallback", | ||||
'LoadBestModelCallback' | |||||
'LoadBestModelCallback', | |||||
"EarlyStopCallback" | |||||
] | ] | ||||
@@ -21,4 +22,5 @@ from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallb | |||||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | ||||
from .lr_scheduler_callback import LRSchedCallback | from .lr_scheduler_callback import LRSchedCallback | ||||
from .load_best_model_callback import LoadBestModelCallback | from .load_best_model_callback import LoadBestModelCallback | ||||
from .early_stop_callback import EarlyStopCallback | |||||
@@ -1,11 +1,15 @@ | |||||
from typing import Union, Callable, Dict, Optional | |||||
from typing import Union, Callable, Dict, Optional, Any | |||||
from abc import ABC | |||||
__all__ = [ | __all__ = [ | ||||
'Callback', | 'Callback', | ||||
] | ] | ||||
from .callback_events import Events, EventsList, Filter | from .callback_events import Events, EventsList, Filter | ||||
from .utils import _get_monitor_value | |||||
from fastNLP.core.callbacks.callback_events import _SingleEventState | from fastNLP.core.callbacks.callback_events import _SingleEventState | ||||
from fastNLP.core.log import logger | |||||
from fastNLP.core.utils import apply_to_collection | |||||
class Callback: | class Callback: | ||||
@@ -150,4 +154,82 @@ class _CallbackWrapper(Callback): | |||||
return self.fn.__name__ | return self.fn.__name__ | ||||
class CanItemDataType(ABC):
    """
    Structural marker type for objects that expose a callable ``item`` attribute.

    ``isinstance(obj, CanItemDataType)`` / ``issubclass(tp, CanItemDataType)`` succeed for
    any type whose ``item`` attribute is callable, without that type having to inherit
    from this class.
    """

    @classmethod
    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
        # Only the marker class itself performs the structural check; classes derived
        # from it fall back to the regular ABC subclass machinery.
        if cls is not CanItemDataType:
            return NotImplemented
        return callable(getattr(subclass, 'item', None))
class HasMonitorCallback(Callback):
    """
    Base class for callbacks that follow one monitored metric.

    It stores the monitor name, the direction of improvement and the best value seen so
    far, and offers helpers to read the monitored value out of an evaluation result and
    to compare it against the stored best value.
    """

    def __init__(self, monitor, larger_better, must_have_monitor=False):
        """
        :param monitor: name of the metric to monitor; may be ``None``, in which case the
            monitor configured on the Trainer is adopted in ``on_after_trainer_initialized``.
        :param larger_better: whether a larger monitored value counts as better.
        :param must_have_monitor: if True, ``on_after_trainer_initialized`` raises when no
            monitor is available from either this callback or the Trainer.
        """
        self.set_monitor(monitor, larger_better)
        self.must_have_monitor = must_have_monitor

    @property
    def must_have_moinitor(self):
        # Backward-compatible alias for the historical misspelling of `must_have_monitor`.
        return self.must_have_monitor

    @must_have_moinitor.setter
    def must_have_moinitor(self, value):
        self.must_have_monitor = value

    def set_monitor(self, monitor, larger_better):
        """
        (Re)configure the monitored metric and reset the best value seen so far.

        :param monitor: metric name; stored as ``str`` unless it is ``None``.
        :param larger_better: whether a larger monitored value counts as better.
        """
        self.monitor = str(monitor) if monitor is not None else None
        self.larger_better = bool(larger_better)
        # Start from the worst possible value so that the first real value is an improvement.
        if larger_better:
            self.monitor_value = float('-inf')
        else:
            self.monitor_value = float('inf')
        # Key actually used to read the results; replaced when only a fuzzy match is found.
        self._real_monitor = self.monitor

    def on_after_trainer_initialized(self, trainer, driver):
        """
        Adopt the Trainer's monitor when this callback has none of its own, then verify
        that a monitor exists if this callback requires one.

        :param trainer: the Trainer; its ``monitor`` and ``larger_better`` attributes are read.
        :param driver: unused here.
        :raises RuntimeError: if a monitor is required but none is set.
        """
        if self.monitor is None and trainer.monitor is not None:
            self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better)
        if self.must_have_monitor and self.monitor is None:
            raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
                               f"You can set it in the initialization or through Trainer.")

    def get_monitor_value(self, results:Dict)->float:
        """
        Read the monitored value from an evaluation result.

        If the monitor name is not found directly, ``_get_monitor_value`` is asked to pick
        a matching key; the key it chose is remembered in ``self._real_monitor`` and a
        warning is logged when that key changes.

        :param results: evaluation results keyed by metric name.
        :return: the monitored value; ``0`` when ``results`` is empty or no monitor is set.
        """
        # Without any results or without a monitor name there is nothing to look up.
        if len(results) == 0 or self.monitor is None:
            return 0
        # Values exposing a callable `.item()` (e.g. tensors) are replaced by `.item()`'s result.
        results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
        use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
                                                        real_monitor=self._real_monitor,
                                                        res=results)
        if self._real_monitor != use_monitor:  # a different key was substituted, tell the user
            logger.warning(
                f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
                f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.")
            self._real_monitor = use_monitor
        return monitor_value

    def is_better_monitor_value(self, monitor_value: float, keep_if_better=True):
        """
        Check whether ``monitor_value`` improves on the best value stored so far.

        :param monitor_value: the newly observed value.
        :param keep_if_better: if True and the value is better, store it as the new best value.
        :return: True if ``monitor_value`` is strictly better than the stored best value.
        """
        if self.larger_better:
            better = bool(monitor_value > self.monitor_value)
        else:
            better = bool(monitor_value < self.monitor_value)
        if better and keep_if_better:
            self.monitor_value = monitor_value
        return better
@@ -5,12 +5,12 @@ __all__ = [ | |||||
import os | import os | ||||
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | ||||
from pathlib import Path | from pathlib import Path | ||||
from abc import ABC | |||||
import sys | import sys | ||||
from copy import deepcopy | |||||
import fastNLP | import fastNLP | ||||
from .callback import Callback, Filter | |||||
from .callback import Callback, HasMonitorCallback | |||||
from fastNLP.core.callbacks.utils import _get_monitor_value | from fastNLP.core.callbacks.utils import _get_monitor_value | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME | from fastNLP.envs import FASTNLP_LAUNCH_TIME | ||||
@@ -18,22 +18,7 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||||
from fastNLP.core.utils import apply_to_collection | from fastNLP.core.utils import apply_to_collection | ||||
class CanItemDataType(ABC): | |||||
""" | |||||
检测可以进行传输的对象。 | |||||
""" | |||||
@classmethod | |||||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||||
if cls is CanItemDataType: | |||||
item = getattr(subclass, 'item', None) | |||||
return callable(item) | |||||
return NotImplemented | |||||
class CheckpointCallback(Callback): | |||||
class CheckpointCallback(HasMonitorCallback): | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
monitor, | monitor, | ||||
@@ -48,12 +33,8 @@ class CheckpointCallback(Callback): | |||||
model_save_fn: Optional[Callable] = None, | model_save_fn: Optional[Callable] = None, | ||||
**kwargs, | **kwargs, | ||||
): | ): | ||||
if monitor is None and save_topk is not None: | |||||
raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.") | |||||
if monitor is not None and not isinstance(monitor, str): | |||||
raise ValueError("Parameter `monitor` should be of 'str' type.") | |||||
super().__init__(monitor=monitor, larger_better=larger_better, | |||||
must_have_monitor=save_topk is not None) | |||||
if save_folder is None: | if save_folder is None: | ||||
logger.warning( | logger.warning( | ||||
"Parameter `path` is None, and we will use the current work directory to find and load your model.") | "Parameter `path` is None, and we will use the current work directory to find and load your model.") | ||||
@@ -91,13 +72,12 @@ class CheckpointCallback(Callback): | |||||
"`BaseException` type.") | "`BaseException` type.") | ||||
else: | else: | ||||
save_on_exception = [] | save_on_exception = [] | ||||
self.monitor = monitor | |||||
self.save_folder = Path(save_folder) | self.save_folder = Path(save_folder) | ||||
self.save_every_n_epochs = save_every_n_epochs | self.save_every_n_epochs = save_every_n_epochs | ||||
self.save_every_n_batches = save_every_n_batches | self.save_every_n_batches = save_every_n_batches | ||||
self.save_last = save_last | self.save_last = save_last | ||||
self.save_topk = save_topk | self.save_topk = save_topk | ||||
self.larger_better = larger_better | |||||
self.only_state_dict = only_state_dict | self.only_state_dict = only_state_dict | ||||
self.model_save_fn = model_save_fn | self.model_save_fn = model_save_fn | ||||
self.save_on_exception = save_on_exception | self.save_on_exception = save_on_exception | ||||
@@ -107,20 +87,22 @@ class CheckpointCallback(Callback): | |||||
self._topk_model = {} | self._topk_model = {} | ||||
self._topn = 0 # 表示目前已经保存了几个最好的模型; | self._topn = 0 # 表示目前已经保存了几个最好的模型; | ||||
# 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的 | |||||
# key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次 | |||||
# 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为; | |||||
# 因此我们通过该变量来表示我们通过模糊匹配拿到的 key; | |||||
self._real_monitor = self.monitor | |||||
# 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, | ||||
# FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; | ||||
self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) | ||||
# 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; | # 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; | ||||
synchronize_mkdir(self.timestamp_path) | synchronize_mkdir(self.timestamp_path) | ||||
def on_validate_end(self, trainer, validate_res): | |||||
self._save_topk(trainer, validate_res) | |||||
def on_after_trainer_initialized(self, trainer, driver):
    """
    Resolve the monitor through the parent class and warn about a missing evaluator,
    but only when ``save_topk`` is in use.

    :param trainer: the Trainer; its ``evaluator`` attribute is inspected.
    :param driver: forwarded unchanged to the parent implementation.
    """
    # Nothing to resolve or check when top-k saving is disabled.
    if self.save_topk is None:
        return
    super().on_after_trainer_initialized(trainer, driver)
    if trainer.evaluator is None:
        logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.")
def on_validate_end(self, trainer, results):
    """
    Hand non-empty validation results to ``_save_topk``; empty results are ignored.

    :param trainer: forwarded to ``_save_topk``.
    :param results: validation results.
    """
    if len(results) != 0:
        self._save_topk(trainer, results)
def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): | ||||
if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: | if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: | ||||
@@ -143,7 +125,7 @@ class CheckpointCallback(Callback): | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | def on_sanity_check_end(self, trainer, sanity_check_res): | ||||
# 主要核对一下 monitor 是否存在。 | # 主要核对一下 monitor 是否存在。 | ||||
self._get_validate_metric(sanity_check_res) | |||||
self.get_monitor_value(results=sanity_check_res) | |||||
def on_save_checkpoint(self, trainer) -> Dict: | def on_save_checkpoint(self, trainer) -> Dict: | ||||
""" | """ | ||||
@@ -154,8 +136,7 @@ class CheckpointCallback(Callback): | |||||
states = {} | states = {} | ||||
states['timestamp_path'] = str(self.timestamp_path.absolute()) | states['timestamp_path'] = str(self.timestamp_path.absolute()) | ||||
states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType, | |||||
function=lambda x:x.item()) | |||||
states['_topk_model'] = deepcopy(self._topk_model) | |||||
states['save_topk'] = 0 if self.save_topk is None else self.save_topk | states['save_topk'] = 0 if self.save_topk is None else self.save_topk | ||||
states['_real_monitor'] = self._real_monitor | states['_real_monitor'] = self._real_monitor | ||||
return states | return states | ||||
@@ -176,30 +157,30 @@ class CheckpointCallback(Callback): | |||||
self._topk_model.update(self._topk_model) | self._topk_model.update(self._topk_model) | ||||
self._real_monitor = states["real_monitor"] | self._real_monitor = states["real_monitor"] | ||||
def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict): | |||||
def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): | |||||
""" | """ | ||||
根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 | 根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 | ||||
:param trainer: | :param trainer: | ||||
:param validate_res: | |||||
:param results: | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.save_topk is not None: | if self.save_topk is not None: | ||||
_metric_value = self._get_validate_metric(validate_res) | |||||
monitor_value = self.get_monitor_value(results=results) | |||||
folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ | ||||
f"-{self._real_monitor}_{_metric_value}" | |||||
f"-{self._real_monitor}_{monitor_value}" | |||||
_should_save = False | _should_save = False | ||||
if self._topn < self.save_topk: | if self._topn < self.save_topk: | ||||
self._topk_model[folder_name] = _metric_value | |||||
self._topk_model[folder_name] = monitor_value | |||||
self._topn += 1 | self._topn += 1 | ||||
_should_save = True | _should_save = True | ||||
else: | else: | ||||
_least_valuable_model = (min if self.larger_better else max)(self._topk_model, | _least_valuable_model = (min if self.larger_better else max)(self._topk_model, | ||||
key=lambda x: self._topk_model[x]) | key=lambda x: self._topk_model[x]) | ||||
if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \ | |||||
(self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]): | |||||
self._topk_model[folder_name] = _metric_value | |||||
if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \ | |||||
(self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]): | |||||
self._topk_model[folder_name] = monitor_value | |||||
_should_save = True | _should_save = True | ||||
self._topk_model.pop(_least_valuable_model) | self._topk_model.pop(_least_valuable_model) | ||||
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) | synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) | ||||
@@ -235,7 +216,11 @@ class CheckpointCallback(Callback): | |||||
:return: | :return: | ||||
""" | """ | ||||
use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) | use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) | ||||
if self._real_monitor != use_monitor: | |||||
logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), " | |||||
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") | |||||
self._real_monitor = use_monitor | self._real_monitor = use_monitor | ||||
return value | return value | ||||
@property | @property | ||||
@@ -263,7 +248,7 @@ class ModelCheckpointCallback(CheckpointCallback): | |||||
若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 | ||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
的那个作为 monitor 。 | |||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | :param save_every_n_epochs: 多少个 epoch 保存一次。 | ||||
@@ -310,7 +295,7 @@ class TrainerCheckpointCallback(CheckpointCallback): | |||||
若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 | ||||
:param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
的那个作为 monitor 。 | |||||
的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 | |||||
:param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param save_every_n_epochs: 多少个 epoch 保存一次。 | :param save_every_n_epochs: 多少个 epoch 保存一次。 | ||||
@@ -0,0 +1,61 @@ | |||||
__all__ = [ | |||||
'EarlyStopCallback' | |||||
] | |||||
from typing import Dict | |||||
from .callback import HasMonitorCallback | |||||
from fastNLP.core.utils.exceptions import EarlyStopException | |||||
class EarlyStopCallback(HasMonitorCallback):
    """
    Stop training once the monitored metric has failed to improve for ``patience``
    validations in a row.
    """

    def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10):
        """
        :param str monitor: name of the metric to monitor. If None, the monitor configured
            on the Trainer is used.
        :param larger_better: whether a larger monitored value counts as better.
        :param patience: number of validations without improvement after which training stops.
        """
        super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True)
        # Number of validations since the last improvement.
        self.wait = 0
        self.patience = patience

    def on_validate_end(self, trainer, results):
        """
        Reset the wait counter on improvement, otherwise increase it. Empty results are ignored.
        """
        if len(results)==0:
            return
        monitor_value = self.get_monitor_value(results)
        if self.is_better_monitor_value(monitor_value, keep_if_better=True):
            self.wait = 0
        else:
            self.wait += 1

    def on_fetch_data_begin(self, trainer):
        """
        :raises EarlyStopException: when ``wait`` has reached ``patience``.
        """
        # With step-based validation this is the next hook to run, so check here.
        if self.wait >= self.patience:
            raise EarlyStopException(f"After {self.wait} validations, no improvement for "
                                     f"metric `{self._real_monitor}`")

    def on_train_epoch_begin(self, trainer):
        """
        :raises EarlyStopException: when ``wait`` has reached ``patience``.
        """
        # With epoch-based validation this is the next hook to run, so check here.
        if self.wait >= self.patience:
            raise EarlyStopException(f"After {self.wait} validations, no improvement for "
                                     f"metric `{self._real_monitor}`(best value: {self.monitor_value})")

    def on_save_checkpoint(self, trainer) -> Dict:
        """
        :return: the state needed to resume early stopping, including the key that was
            actually used to read the monitored value.
        """
        states = {
            'patience': self.patience,
            'wait': self.wait,
            'monitor': self.monitor,
            '_real_monitor': self._real_monitor,
            'monitor_value': self.monitor_value
        }
        return states

    def on_load_checkpoint(self, trainer, states):
        """
        Restore the state produced by ``on_save_checkpoint``.

        :param states: the saved state dict.
        """
        self.patience = states['patience']
        self.wait = states['wait']
        self.monitor = states['monitor']
        # Keep the lookup key in sync with the restored monitor. Checkpoints written
        # before `_real_monitor` was stored fall back to the restored monitor name.
        self._real_monitor = states.get('_real_monitor', self.monitor)
        self.monitor_value = float(states['monitor_value'])

    # NOTE(review): defined as a plain method here; confirm whether the base Callback
    # exposes `callback_name` as a property and whether this override should match it.
    def callback_name(self):
        return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}'
@@ -4,8 +4,7 @@ __all__ = [ | |||||
import os | import os | ||||
from typing import Optional, Callable | from typing import Optional, Callable | ||||
from .callback import Callback | |||||
from .utils import _get_monitor_value | |||||
from .callback import HasMonitorCallback | |||||
from io import BytesIO | from io import BytesIO | ||||
import shutil | import shutil | ||||
@@ -14,15 +13,15 @@ from fastNLP.core.log import logger | |||||
from fastNLP.envs import all_rank_call | from fastNLP.envs import all_rank_call | ||||
class LoadBestModelCallback(Callback): | |||||
def __init__(self, monitor:str, larger_better:bool = True, only_state_dict:bool = True, | |||||
class LoadBestModelCallback(HasMonitorCallback): | |||||
def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True, | |||||
save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, | ||||
model_load_fn:Optional[Callable] = None, | model_load_fn:Optional[Callable] = None, | ||||
delete_after_train:bool = True): | delete_after_train:bool = True): | ||||
""" | """ | ||||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | ||||
:param str monitor: 监控的 metric 值。 | |||||
:param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 | |||||
:param larger_better: 该 metric 值是否是越大越好。 | :param larger_better: 该 metric 值是否是越大越好。 | ||||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | ||||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | ||||
@@ -33,6 +32,7 @@ class LoadBestModelCallback(Callback): | |||||
请在函数内完成对模型的加载。 | 请在函数内完成对模型的加载。 | ||||
:param delete_after_train: 在训练结束后是否删掉模型。 | :param delete_after_train: 在训练结束后是否删掉模型。 | ||||
""" | """ | ||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) | |||||
if model_load_fn is not None: | if model_load_fn is not None: | ||||
assert callable(model_load_fn), "`model_load_fn` must be a callable object." | assert callable(model_load_fn), "`model_load_fn` must be a callable object." | ||||
assert model_save_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time." | assert model_save_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time." | ||||
@@ -56,15 +56,11 @@ class LoadBestModelCallback(Callback): | |||||
self.real_save_folder = None | self.real_save_folder = None | ||||
self.buffer = BytesIO() | self.buffer = BytesIO() | ||||
self.monitor = monitor | |||||
self.larger_better = larger_better | |||||
self.save_folder = save_folder | self.save_folder = save_folder | ||||
self.only_state_dict = only_state_dict | self.only_state_dict = only_state_dict | ||||
self.model_save_fn = model_save_fn | self.model_save_fn = model_save_fn | ||||
self.model_load_fn = model_load_fn | self.model_load_fn = model_load_fn | ||||
self.delete_after_after = delete_after_train | self.delete_after_after = delete_after_train | ||||
self._real_monitor = None | |||||
self.monitor_value = float('-inf') if larger_better else float('inf') | |||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: | if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: | ||||
@@ -76,13 +72,16 @@ class LoadBestModelCallback(Callback): | |||||
raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " | raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " | ||||
f"save best model when launch using script.") | f"save best model when launch using script.") | ||||
super().on_after_trainer_initialized(trainer, driver) | |||||
def on_sanity_check_end(self, trainer, sanity_check_res):
    """
    Run the monitor lookup once on the sanity-check results; the returned value is discarded.

    :param trainer: unused here.
    :param sanity_check_res: results produced by the sanity check.
    """
    self.get_monitor_value(results=sanity_check_res)
def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=results) | |||||
if (monitor_value < self.monitor_value and self.larger_better is False) or \ | |||||
(monitor_value > self.monitor_value and self.larger_better): | |||||
self.monitor_value = monitor_value | |||||
if len(results)==0: | |||||
return | |||||
monitor_value = self.get_monitor_value(results) | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
if self.real_save_folder: | if self.real_save_folder: | ||||
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | ||||
model_save_fn=self.model_save_fn) | model_save_fn=self.model_save_fn) | ||||
@@ -8,7 +8,7 @@ __all__ = [ | |||||
'RichCallback' | 'RichCallback' | ||||
] | ] | ||||
from .callback import Callback | |||||
from .callback import HasMonitorCallback | |||||
from fastNLP.core.callbacks.utils import _get_monitor_value | from fastNLP.core.callbacks.utils import _get_monitor_value | ||||
from fastNLP.core.utils import f_rich_progress | from fastNLP.core.utils import f_rich_progress | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -28,15 +28,13 @@ def choose_progress_callback(progress_bar:str): | |||||
return None | return None | ||||
class ProgressCallback(Callback): | |||||
class ProgressCallback(HasMonitorCallback): | |||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
f_rich_progress.stop() | f_rich_progress.stop() | ||||
def on_sanity_check_end(self, trainer, sanity_check_res): | def on_sanity_check_end(self, trainer, sanity_check_res): | ||||
if len(sanity_check_res) and getattr(self, 'monitor', None) is not None: | if len(sanity_check_res) and getattr(self, 'monitor', None) is not None: | ||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=sanity_check_res) | |||||
self.get_monitor_value(sanity_check_res) | |||||
class RichCallback(ProgressCallback): | class RichCallback(ProgressCallback): | ||||
@@ -46,28 +44,22 @@ class RichCallback(ProgressCallback): | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | :param larger_better: 是否是monitor的结果越大越好。 | ||||
:param format_json: 是否format json再打印 | :param format_json: 是否format json再打印 | ||||
""" | """ | ||||
super().__init__() | |||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | |||||
self.print_every = print_every | self.print_every = print_every | ||||
self.progress_bar = f_rich_progress | self.progress_bar = f_rich_progress | ||||
self.task2id = {} | self.task2id = {} | ||||
self.loss = 0 | self.loss = 0 | ||||
self.loss_round_ndigit = loss_round_ndigit | self.loss_round_ndigit = loss_round_ndigit | ||||
self.monitor = monitor | |||||
self.larger_better = larger_better | |||||
if larger_better: | |||||
self.monitor_value = float('-inf') | |||||
else: | |||||
self.monitor_value = float('inf') | |||||
self._real_monitor = monitor | |||||
self.format_json = format_json | self.format_json = format_json | ||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
if not self.progress_bar.disable: | if not self.progress_bar.disable: | ||||
self.progress_bar.set_disable(flag=trainer.driver.get_local_rank() != 0) | self.progress_bar.set_disable(flag=trainer.driver.get_local_rank() != 0) | ||||
super(RichCallback, self).on_after_trainer_initialized(trainer, driver) | |||||
def on_train_begin(self, trainer): | def on_train_begin(self, trainer): | ||||
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | ||||
@@ -109,16 +101,12 @@ class RichCallback(ProgressCallback): | |||||
text_style = '' | text_style = '' | ||||
characters = '-' | characters = '-' | ||||
if self.monitor is not None: | if self.monitor is not None: | ||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=results) | |||||
if (self.larger_better and monitor_value > self.monitor_value) or \ | |||||
(not self.larger_better and monitor_value < self.monitor_value): | |||||
monitor_value = self.get_monitor_value(results) | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
rule_style = 'spring_green3' | rule_style = 'spring_green3' | ||||
text_style = '[bold]' | text_style = '[bold]' | ||||
characters = '+' | characters = '+' | ||||
self.monitor_value = monitor_value | |||||
self.progress_bar.print() | self.progress_bar.print() | ||||
self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " | self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " | ||||
f"Batch:{trainer.batch_idx_in_epoch}", | f"Batch:{trainer.batch_idx_in_epoch}", | ||||
@@ -151,18 +139,12 @@ class RawTextCallback(ProgressCallback): | |||||
:param larger_better: 是否是monitor的结果越大越好。 | :param larger_better: 是否是monitor的结果越大越好。 | ||||
:param format_json: 是否format json再打印 | :param format_json: 是否format json再打印 | ||||
""" | """ | ||||
super().__init__() | |||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | |||||
self.print_every = print_every | self.print_every = print_every | ||||
self.task2id = {} | self.task2id = {} | ||||
self.loss = 0 | self.loss = 0 | ||||
self.loss_round_ndigit = loss_round_ndigit | self.loss_round_ndigit = loss_round_ndigit | ||||
self.monitor = monitor | |||||
self.larger_better = larger_better | |||||
if larger_better: | |||||
self.monitor_value = float('-inf') | |||||
else: | |||||
self.monitor_value = float('inf') | |||||
self._real_monitor = monitor | |||||
self.set_monitor(monitor, larger_better) | |||||
self.format_json = format_json | self.format_json = format_json | ||||
self.num_signs = 10 | self.num_signs = 10 | ||||
@@ -189,14 +171,10 @@ class RawTextCallback(ProgressCallback): | |||||
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | ||||
text = '' | text = '' | ||||
if self.monitor is not None: | if self.monitor is not None: | ||||
self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, | |||||
real_monitor=self._real_monitor, | |||||
res=results) | |||||
if (self.larger_better and monitor_value > self.monitor_value) or \ | |||||
(not self.larger_better and monitor_value < self.monitor_value): | |||||
monitor_value = self.get_monitor_value(results) | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
if abs(self.monitor_value) != float('inf'): | if abs(self.monitor_value) != float('inf'): | ||||
text = '+'*self.num_signs + base_text + '+'*self.num_signs | text = '+'*self.num_signs + base_text + '+'*self.num_signs | ||||
self.monitor_value = monitor_value | |||||
if len(text) == 0: | if len(text) == 0: | ||||
text = '-'*self.num_signs + base_text + '-'*self.num_signs | text = '-'*self.num_signs + base_text + '-'*self.num_signs | ||||
@@ -19,23 +19,31 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->( | |||||
if monitor in res: | if monitor in res: | ||||
return monitor, res[monitor] | return monitor, res[monitor] | ||||
if real_monitor in res: | |||||
return real_monitor, res[real_monitor] | |||||
pairs = [] | pairs = [] | ||||
for idx, (key, value) in enumerate(res.items()): | for idx, (key, value) in enumerate(res.items()): | ||||
match = SequenceMatcher(None, key, monitor).find_longest_match(0, len(key), 0, len(monitor)) | |||||
pairs.append((key, value, match.size, idx)) | |||||
match_size = _match_length(monitor, key) | |||||
pairs.append((key, value, match_size, idx)) | |||||
pairs.sort(key=lambda pair: (pair[2], -pair[3]), reverse=True) | pairs.sort(key=lambda pair: (pair[2], -pair[3]), reverse=True) | ||||
key, value, match_size = pairs[0][:3] | key, value, match_size = pairs[0][:3] | ||||
if real_monitor is not None and real_monitor in res and real_monitor != key: | |||||
# 如果 real_monitor 比新找的更长就继续用之前的。 | |||||
match = SequenceMatcher(None, real_monitor, monitor).find_longest_match(0, len(real_monitor), 0, len(monitor)) | |||||
if match.size > match_size: | |||||
return real_monitor, res[real_monitor] | |||||
return key, value | |||||
logger.warning(f"We can not find `{monitor}` in the evaluation result (with keys as {list(res.keys())}), " | |||||
f"we use the `{key}` as the monitor.") | |||||
real_monitor = key | |||||
return real_monitor, value | |||||
def _match_length(a:str, b:str)->int: | |||||
""" | |||||
需要把长度短的放在前面 | |||||
:param a: | |||||
:param b: | |||||
:return: | |||||
""" | |||||
short = a if len(a) < len(b) else b | |||||
long = a if len(a)>=len(b) else b | |||||
match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long)) | |||||
return match.size | |||||
@@ -219,6 +219,7 @@ class Evaluator: | |||||
def remove_progress_bar(self, dataloader_name): | def remove_progress_bar(self, dataloader_name): | ||||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | ||||
f_rich_progress.destroy_task(self._rich_task_id) | f_rich_progress.destroy_task(self._rich_task_id) | ||||
f_rich_progress.refresh() # 使得最终的bar可以消失 | |||||
delattr(self, '_rich_task_id') | delattr(self, '_rich_task_id') | ||||
elif self.progress_bar == 'raw': | elif self.progress_bar == 'raw': | ||||
desc = 'Evaluation ends' | desc = 'Evaluation ends' | ||||
@@ -229,6 +230,7 @@ class Evaluator: | |||||
def finally_progress_bar(self): | def finally_progress_bar(self): | ||||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | ||||
f_rich_progress.destroy_task(self._rich_task_id) | f_rich_progress.destroy_task(self._rich_task_id) | ||||
f_rich_progress.refresh() | |||||
delattr(self, '_rich_task_id') | delattr(self, '_rich_task_id') | ||||
@property | @property | ||||
@@ -23,9 +23,9 @@ from fastNLP.core.drivers import Driver | |||||
from fastNLP.core.drivers.utils import choose_driver | from fastNLP.core.drivers.utils import choose_driver | ||||
from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext | from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext | ||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.core.samplers import ReproducibleIterator, ReproducibleBatchSampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME | from fastNLP.envs import FASTNLP_MODEL_FILENAME | ||||
from fastNLP.core.utils.exceptions import EarlyStopException | |||||
class Trainer(TrainerEventTrigger): | class Trainer(TrainerEventTrigger): | ||||
@@ -50,6 +50,8 @@ class Trainer(TrainerEventTrigger): | |||||
output_mapping: Optional[Union[Callable, Dict]] = None, | output_mapping: Optional[Union[Callable, Dict]] = None, | ||||
accumulation_steps: int = 1, | accumulation_steps: int = 1, | ||||
fp16: bool = False, | fp16: bool = False, | ||||
monitor: str = None, | |||||
larger_better: bool = True, | |||||
marker: Optional[str] = None, | marker: Optional[str] = None, | ||||
**kwargs | **kwargs | ||||
): | ): | ||||
@@ -103,6 +105,10 @@ class Trainer(TrainerEventTrigger): | |||||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 | 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 | ||||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | ||||
:param fp16: 是否开启混合精度训练;默认为 False; | :param fp16: 是否开启混合精度训练;默认为 False; | ||||
:param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | |||||
在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。 | |||||
:param larger_better: monitor 的值是否是越大越好。 | |||||
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | ||||
:param kwargs: 一些其它的可能需要的参数; | :param kwargs: 一些其它的可能需要的参数; | ||||
torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | ||||
@@ -211,6 +217,8 @@ class Trainer(TrainerEventTrigger): | |||||
self.evaluator = None | self.evaluator = None | ||||
self.epoch_validate = lambda *args, **kwargs: ... | self.epoch_validate = lambda *args, **kwargs: ... | ||||
self.step_validate = lambda *args, **kwargs: ... | self.step_validate = lambda *args, **kwargs: ... | ||||
self.monitor = monitor | |||||
self.larger_better = larger_better | |||||
if metrics is not None and validate_dataloaders is not None: | if metrics is not None and validate_dataloaders is not None: | ||||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | ||||
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | ||||
@@ -240,6 +248,7 @@ class Trainer(TrainerEventTrigger): | |||||
else: | else: | ||||
# validate_every > 0 | # validate_every > 0 | ||||
self._step_validate_filter = Filter(every=validate_every) | self._step_validate_filter = Filter(every=validate_every) | ||||
self.metrics = metrics | self.metrics = metrics | ||||
self.validate_every = validate_every | self.validate_every = validate_every | ||||
@@ -321,6 +330,10 @@ class Trainer(TrainerEventTrigger): | |||||
self.driver.barrier() | self.driver.barrier() | ||||
self.on_train_end() | self.on_train_end() | ||||
self.driver.barrier() | self.driver.barrier() | ||||
except EarlyStopException as e: | |||||
logger.info(f"Catch early stop exception: {e.msg}.") | |||||
self.on_exception(e) | |||||
except KeyboardInterrupt as e: | except KeyboardInterrupt as e: | ||||
self.driver.on_exception() | self.driver.on_exception() | ||||
self.on_exception(e) | self.on_exception(e) | ||||
@@ -610,7 +623,7 @@ class Trainer(TrainerEventTrigger): | |||||
r""" | r""" | ||||
用于断点重训的加载函数; | 用于断点重训的加载函数; | ||||
注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 | 注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 | ||||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleIterator; | |||||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler; | |||||
注意我们目前不支持单卡到多卡的断点重训; | 注意我们目前不支持单卡到多卡的断点重训; | ||||
@@ -49,13 +49,13 @@ class Driver(ABC): | |||||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | 不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | ||||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | ||||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | ||||
注意当 dist 为 ReproducibleIterator, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | ||||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | :param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | ||||
可以可以加载。 | 可以可以加载。 | ||||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | :return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | ||||
如果传入的 dataloader 中是 ReproducibleIterator 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||||
如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||||
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | ||||
""" | """ | ||||
if dist is None and reproducible is False: | if dist is None and reproducible is False: | ||||
@@ -3,7 +3,7 @@ from typing import Optional, Union | |||||
from .jittor_driver import JittorDriver | from .jittor_driver import JittorDriver | ||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
from fastNLP.core.samplers import ReproducibleIterator | |||||
from fastNLP.core.samplers import ReproducibleSampler | |||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
import jittor | import jittor | ||||
@@ -70,7 +70,7 @@ class JittorMPIDriver(JittorDriver): | |||||
def test_step(self, batch): | def test_step(self, batch): | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | reproducible: bool = False, sampler_or_batch_sampler=None): | ||||
pass | pass | ||||
@@ -3,7 +3,7 @@ from typing import Dict, Union | |||||
from .jittor_driver import JittorDriver | from .jittor_driver import JittorDriver | ||||
from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
import jittor | import jittor | ||||
@@ -99,25 +99,25 @@ class JittorSingleDriver(JittorDriver): | |||||
def is_distributed(self): | def is_distributed(self): | ||||
return False | return False | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | reproducible: bool = False, sampler_or_batch_sampler=None): | ||||
# reproducible 的相关功能暂时没有实现 | # reproducible 的相关功能暂时没有实现 | ||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
dataloader.batch_sampler = dist_sample | dataloader.batch_sampler = dist_sample | ||||
if isinstance(dist, ReproducibleIterator): | |||||
if isinstance(dist, ReproducibleSampler): | |||||
raise NotImplementedError | raise NotImplementedError | ||||
dataloader.batch_sampler.sampler = dist | dataloader.batch_sampler.sampler = dist | ||||
if reproducible: | if reproducible: | ||||
raise NotImplementedError | raise NotImplementedError | ||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
return dataloader | return dataloader | ||||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||||
elif isinstance(dataloader.batch_sampler, RandomBatchSampler): | |||||
return dataloader | return dataloader | ||||
else: | else: | ||||
# TODO | # TODO | ||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=dataloader.batch_sampler, | batch_sampler=dataloader.batch_sampler, | ||||
batch_size=dataloader.batch_sampler.batch_size, | batch_size=dataloader.batch_sampler.batch_size, | ||||
drop_last=dataloader.drop_last | drop_last=dataloader.drop_last | ||||
@@ -19,7 +19,7 @@ from fastNLP.core.utils import ( | |||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
is_in_paddle_dist, | is_in_paddle_dist, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler | |||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedRandomSampler | |||||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES | from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -312,13 +312,13 @@ class PaddleFleetDriver(PaddleDriver): | |||||
def test_step(self, batch): | def test_step(self, batch): | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | reproducible: bool = False, sampler_or_batch_sampler=None): | ||||
# 暂时不支持iterableDataset | # 暂时不支持iterableDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | "FastNLP does not support `IteratorDataset` now." | ||||
if isinstance(dist, ReproducibleIterator): | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dataloader.batch_sampler.sampler = dist | dataloader.batch_sampler.sampler = dist | ||||
return dataloader | return dataloader | ||||
@@ -340,7 +340,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# trainer | # trainer | ||||
elif dist == "dist": | elif dist == "dist": | ||||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | ||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
dataloader.batch_sampler.sampler.set_distributed( | dataloader.batch_sampler.sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank, | rank=self.global_rank, | ||||
@@ -362,7 +362,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
return dataloader | return dataloader | ||||
# evaluator | # evaluator | ||||
elif dist == "unrepeatdist": | elif dist == "unrepeatdist": | ||||
sampler = UnrepeatedSampler( | |||||
sampler = UnrepeatedRandomSampler( | |||||
dataset=dataloader.dataset, | dataset=dataloader.dataset, | ||||
shuffle=shuffle, | shuffle=shuffle, | ||||
seed=int(os.environ.get("FASTNLP_SEED", 0)) | seed=int(os.environ.get("FASTNLP_SEED", 0)) | ||||
@@ -10,7 +10,7 @@ from fastNLP.core.utils import ( | |||||
get_paddle_device_id, | get_paddle_device_id, | ||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
@@ -139,7 +139,7 @@ class PaddleSingleDriver(PaddleDriver): | |||||
""" | """ | ||||
return paddle_move_data_to_device(batch, "gpu:0") | return paddle_move_data_to_device(batch, "gpu:0") | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | reproducible: bool = False, sampler_or_batch_sampler=None): | ||||
# 暂时不支持IteratorDataset | # 暂时不支持IteratorDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
@@ -147,12 +147,12 @@ class PaddleSingleDriver(PaddleDriver): | |||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
dataloader.batch_sampler = dist | dataloader.batch_sampler = dist | ||||
return dataloader | return dataloader | ||||
if isinstance(dist, ReproducibleIterator): | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dataloader.batch_sampler.sampler = dist | dataloader.batch_sampler.sampler = dist | ||||
return dataloader | return dataloader | ||||
if reproducible: | if reproducible: | ||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
return dataloader | return dataloader | ||||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | ||||
return dataloader | return dataloader | ||||
@@ -28,11 +28,11 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||||
) | ) | ||||
from fastNLP.core.drivers.utils import distributed_open_proc | from fastNLP.core.drivers.utils import distributed_open_proc | ||||
from fastNLP.core.utils import auto_param_call, check_user_specific_params | from fastNLP.core.utils import auto_param_call, check_user_specific_params | ||||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler, ReproducibleBatchSampler | |||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \ | |||||
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | |||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED | from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | ||||
from fastNLP.core.samplers import re_instantiate_sampler | |||||
class TorchDDPDriver(TorchDriver): | class TorchDDPDriver(TorchDriver): | ||||
@@ -446,13 +446,23 @@ class TorchDDPDriver(TorchDriver): | |||||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None, | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None, | |||||
reproducible: bool = False): | reproducible: bool = False): | ||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; | |||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | ||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_batch_sampler(dataloader, dist) | return replace_batch_sampler(dataloader, dist) | ||||
if isinstance(dist, ReproducibleIterator): | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, dist) | return replace_sampler(dataloader, dist) | ||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | ||||
@@ -465,7 +475,7 @@ class TorchDDPDriver(TorchDriver): | |||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
dist = re_instantiate_sampler(dist) | dist = re_instantiate_sampler(dist) | ||||
return replace_batch_sampler(dataloader, dist) | return replace_batch_sampler(dataloader, dist) | ||||
if isinstance(dist, ReproducibleIterator): | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dist = re_instantiate_sampler(dist) | dist = re_instantiate_sampler(dist) | ||||
return replace_sampler(dataloader, dist) | return replace_sampler(dataloader, dist) | ||||
return dataloader | return dataloader | ||||
@@ -481,7 +491,7 @@ class TorchDDPDriver(TorchDriver): | |||||
pad=True | pad=True | ||||
) | ) | ||||
return replace_batch_sampler(dataloader, batch_sampler) | return replace_batch_sampler(dataloader, batch_sampler) | ||||
elif isinstance(args.sampler, ReproducibleIterator): | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | sampler = re_instantiate_sampler(args.sampler) | ||||
sampler.set_distributed( | sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
@@ -503,14 +513,15 @@ class TorchDDPDriver(TorchDriver): | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
# evaluator | # evaluator | ||||
elif dist == "unrepeatdist": | elif dist == "unrepeatdist": | ||||
# todo @yh,补充 unrepeatdist 相关内容; | |||||
args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
# todo 判断 batch_sampler; | |||||
sampler = UnrepeatedSampler( | |||||
dataset=args.dataset, | |||||
shuffle=args.shuffle, | |||||
) | |||||
if isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | |||||
elif not isinstance(args.sampler, UnrepeatedSampler): | |||||
sampler = UnrepeatedSequentialSampler( | |||||
dataset=args.dataset | |||||
) | |||||
else: | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
sampler.set_distributed( | sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank | rank=self.global_rank | ||||
@@ -588,7 +599,7 @@ class TorchDDPDriver(TorchDriver): | |||||
:param group: | :param group: | ||||
:return: | :return: | ||||
""" | """ | ||||
return fastnlp_torch_all_gather(obj, device=self.data_device, group=group) | |||||
return fastnlp_torch_all_gather(obj, group=group) | |||||
def find_free_network_port() -> str: | def find_free_network_port() -> str: | ||||
@@ -1,11 +1,8 @@ | |||||
import io | import io | ||||
import pickle | import pickle | ||||
from typing import Mapping | |||||
_pickler = pickle.Pickler | _pickler = pickle.Pickler | ||||
_unpickler = pickle.Unpickler | _unpickler = pickle.Unpickler | ||||
from abc import ABC | |||||
from typing import Any, Union, List | |||||
import numpy as np | |||||
from typing import Any, List | |||||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | ||||
@@ -13,103 +10,25 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
from torch import distributed as dist | from torch import distributed as dist | ||||
try: | |||||
from torch._C._distributed_c10d import ProcessGroupMPI | |||||
except ImportError: | |||||
_MPI_AVAILABLE = False | |||||
try: | |||||
from torch._C._distributed_c10d import ProcessGroupNCCL | |||||
except ImportError: | |||||
_NCCL_AVAILABLE = False | |||||
try: | |||||
from torch._C._distributed_c10d import ProcessGroupGloo | |||||
from torch._C._distributed_c10d import _ProcessGroupWrapper | |||||
except ImportError: | |||||
_GLOO_AVAILABLE = False | |||||
from fastNLP.core.utils import apply_to_collection | from fastNLP.core.utils import apply_to_collection | ||||
def all_gather_object(object_list, obj, group=None): | |||||
""" | |||||
Gathers picklable objects from the whole group into a list. Similar to | |||||
:func:`all_gather`, but Python objects can be passed in. Note that the object | |||||
must be picklable in order to be gathered. | |||||
Args: | |||||
object_list (list[Any]): Output list. It should be correctly sized as the | |||||
size of the group for this collective and will contain the output. | |||||
object (Any): Pickable Python object to be broadcast from current process. | |||||
group (ProcessGroup, optional): The process group to work on. If None, | |||||
the default process group will be used. Default is ``None``. | |||||
Returns: | |||||
None. If the calling rank is part of this group, the output of the | |||||
collective will be populated into the input ``object_list``. If the | |||||
calling rank is not part of the group, the passed in ``object_list`` will | |||||
be unmodified. | |||||
.. note:: Note that this API differs slightly from the :func:`all_gather` | |||||
collective since it does not provide an ``async_op`` handle and thus | |||||
will be a blocking call. | |||||
.. note:: For NCCL-based processed groups, internal tensor representations | |||||
of objects must be moved to the GPU device before communication takes | |||||
place. In this case, the device used is given by | |||||
``torch.cuda.current_device()`` and it is the user's responsiblity to | |||||
ensure that this is set so that each rank has an individual GPU, via | |||||
``torch.cuda.set_device()``. | |||||
.. warning:: | |||||
:func:`all_gather_object` uses ``pickle`` module implicitly, which is | |||||
known to be insecure. It is possible to construct malicious pickle data | |||||
which will execute arbitrary code during unpickling. Only call this | |||||
function with data you trust. | |||||
Example:: | |||||
>>> # Note: Process group initialization omitted on each rank. | |||||
>>> import torch.distributed as dist | |||||
>>> # Assumes world_size of 3. | |||||
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||||
>>> output = [None for _ in gather_objects] | |||||
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) | |||||
>>> output | |||||
['foo', 12, {1: 2}] | |||||
""" | |||||
if dist.distributed_c10d._rank_not_in_group(group): | |||||
return | |||||
input_tensor, local_size = _object_to_tensor(obj) | |||||
current_device = torch.device("cpu") | |||||
if dist.is_nccl_available() and isinstance( | |||||
group or dist.distributed_c10d._get_default_group(), dist.ProcessGroupNCCL | |||||
): | |||||
# See note about using torch.cuda.current_device() here in docstring. | |||||
# We cannot simply use my_rank since rank == device is not necessarily | |||||
# true. | |||||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||||
input_tensor = input_tensor.to(current_device) | |||||
local_size = local_size.to(current_device) | |||||
# Gather all local sizes. This is so that we can find the max size, and index | |||||
# until the correct size when deserializing the tensors. | |||||
group_size = dist.get_world_size(group=group) | |||||
object_sizes_tensor = torch.zeros( | |||||
group_size, dtype=torch.long, device=current_device | |||||
) | |||||
object_size_list = [ | |||||
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||||
] | |||||
# Allgather tensor sizes | |||||
dist.all_gather(object_size_list, local_size, group=group) | |||||
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||||
# Resize tensor to max size across all ranks. | |||||
input_tensor.resize_(max_object_size) | |||||
coalesced_output_tensor = torch.empty( | |||||
max_object_size * group_size, dtype=torch.uint8, device=current_device | |||||
) | |||||
# Output tensors are nonoverlapping views of coalesced_output_tensor | |||||
output_tensors = [ | |||||
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] | |||||
for i in range(group_size) | |||||
] | |||||
dist.all_gather(output_tensors, input_tensor, group=group) | |||||
# Deserialize outputs back to object. | |||||
for i, tensor in enumerate(output_tensors): | |||||
tensor = tensor.type(torch.uint8) | |||||
if tensor.device != torch.device("cpu"): | |||||
tensor = tensor.cpu() | |||||
tensor_size = object_size_list[i] | |||||
object_list[i] = _tensor_to_object(tensor, tensor_size) | |||||
def _validate_output_list_for_rank(my_rank, dst, gather_list): | def _validate_output_list_for_rank(my_rank, dst, gather_list): | ||||
if dst == my_rank: | if dst == my_rank: | ||||
if not gather_list: | if not gather_list: | ||||
@@ -123,8 +42,10 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||||
) | ) | ||||
def gather_object(obj, object_gather_list=None, dst=0, group=None): | |||||
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None): | |||||
""" | """ | ||||
从其它 rank gather 东西到 dst rank 。 | |||||
Gathers picklable objects from the whole group in a single process. | Gathers picklable objects from the whole group in a single process. | ||||
Similar to :func:`gather`, but Python objects can be passed in. Note that the | Similar to :func:`gather`, but Python objects can be passed in. Note that the | ||||
object must be picklable in order to be gathered. | object must be picklable in order to be gathered. | ||||
@@ -176,6 +97,8 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None): | |||||
# Ensure object_gather_list is specified appopriately. | # Ensure object_gather_list is specified appopriately. | ||||
my_rank = dist.get_rank() | my_rank = dist.get_rank() | ||||
_validate_output_list_for_rank(my_rank, dst, object_gather_list) | _validate_output_list_for_rank(my_rank, dst, object_gather_list) | ||||
# 防止 unpickle 的时候出现在了发送的 gpu 上。 | |||||
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||||
input_tensor, local_size = _object_to_tensor(obj) | input_tensor, local_size = _object_to_tensor(obj) | ||||
group_backend = dist.get_backend(group) | group_backend = dist.get_backend(group) | ||||
current_device = torch.device("cpu") | current_device = torch.device("cpu") | ||||
@@ -266,113 +189,11 @@ def send_recv_object(obj, src, cur_rank, device, group=None, tag=0): | |||||
return _tensor_to_object(tensor.cpu(), size) | return _tensor_to_object(tensor.cpu(), size) | ||||
def _all_gather(obj, **kwargs): | |||||
group = kwargs.get('group', None) | |||||
if isinstance(obj, torch.Tensor): | |||||
gathered_tensor = [torch.zeros_like(obj) for _ in | |||||
range(torch.distributed.get_world_size(group=group))] | |||||
torch.distributed.all_gather(gathered_tensor, obj, group=group) | |||||
return gathered_tensor | |||||
elif isinstance(obj, tuple) and isinstance(obj[1], torch.Tensor): | |||||
tensor, size = obj | |||||
# 首先需要同步 size 吧? | |||||
group_size = dist.get_world_size(group=group) | |||||
object_sizes_tensor = torch.zeros( | |||||
group_size, dtype=torch.long, device=tensor.device | |||||
) | |||||
object_size_list = [ | |||||
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||||
] | |||||
dist.all_gather(object_size_list, size, group=group) | |||||
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||||
# Resize tensor to max size across all ranks. | |||||
tensor.resize_(max_object_size) | |||||
coalesced_output_tensor = torch.empty( | |||||
max_object_size * group_size, dtype=torch.uint8, device=tensor.device | |||||
) | |||||
# Output tensors are nonoverlapping views of coalesced_output_tensor | |||||
output_tensors = [ | |||||
coalesced_output_tensor[max_object_size * i: max_object_size * (i + 1)] | |||||
for i in range(group_size) | |||||
] | |||||
dist.all_gather(output_tensors, tensor, group=group) | |||||
object_list = [] | |||||
for i, tensor in enumerate(output_tensors): | |||||
tensor = tensor.type(torch.uint8) | |||||
tensor_size = object_size_list[i] | |||||
object_list.append(_tensor_to_object(tensor, tensor_size)) | |||||
return object_list | |||||
elif isinstance(obj, tuple) and len(obj) == 2: | |||||
obj, _type = obj | |||||
gathered_tensor = [torch.zeros_like(obj) for _ in | |||||
range(torch.distributed.get_world_size(group=group))] | |||||
torch.distributed.all_gather(gathered_tensor, obj, group=group) | |||||
if _type == np.ndarray: | |||||
gathered_tensor = [t.detach().cpu().numpy() for t in gathered_tensor] | |||||
else: | |||||
gathered_tensor = [_type(t.item()) for t in gathered_tensor] | |||||
return gathered_tensor | |||||
else: | |||||
raise RuntimeError("Unsupported types to implement all_gather.") | |||||
class CanTransferDataType(ABC):
    """
    Virtual base class used to detect objects that can be transferred.

    A type is accepted when it is one of ``torch.Tensor``, ``tuple``, ``list``,
    ``str``, ``int``, ``float``, ``bool`` or ``np.ndarray``. Every ``Mapping``
    subclass is rejected, as is any other type.
    """
    @classmethod
    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
        # Only answer for this exact class; defer to the default ABC logic otherwise.
        if cls is not CanTransferDataType:
            return NotImplemented
        transferable_types = (torch.Tensor, tuple, list, str, int, float, bool, np.ndarray)
        return (not issubclass(subclass, Mapping)) and subclass in transferable_types
def _tensorize(obj, device=None): | |||||
if isinstance(obj, torch.Tensor): | |||||
return obj | |||||
if isinstance(obj, bool): | |||||
return torch.tensor(obj, dtype=torch.uint8, device=device), bool | |||||
if isinstance(obj, float): | |||||
return torch.tensor(obj, dtype=torch.float, device=device), float | |||||
if isinstance(obj, int): | |||||
return torch.tensor(obj, dtype=torch.int, device=device), int | |||||
if isinstance(obj, np.ndarray): | |||||
return torch.from_numpy(obj), np.ndarray | |||||
return _object_to_tensor(obj, device) | |||||
def _to_device(tensor, device): | def _to_device(tensor, device): | ||||
return tensor.contiguous().to(device) | return tensor.contiguous().to(device) | ||||
def convert_to_tensors(data: Any, device=None) -> Any:
    """
    Turn every transferable leaf of ``data`` into tensor form and move it to ``device``.

    :param data: arbitrarily nested data; leaves matching ``CanTransferDataType`` are
        converted with ``_tensorize``.
    :param device: target device for every resulting tensor.
    :return: ``data`` with the same structure, whose converted leaves are either a
        contiguous tensor or a tuple whose first element is a contiguous tensor.
    """
    def _move_to_device_and_make_contiguous(t: Union[torch.Tensor, tuple], device: Union[str, torch.device]):
        if not isinstance(t, tuple):
            return t.to(device).contiguous()
        # A tensor in second position means the pair was produced from a serialized
        # object; otherwise the second element is the original python type (see
        # ``_tensorize``) and is passed through untouched.
        second_is_tensor = isinstance(t[1], torch.Tensor)
        payload = t[0].to(device).contiguous()
        companion = t[1].to(device) if second_is_tensor else t[1]
        return payload, companion

    tensorized = apply_to_collection(data, CanTransferDataType, _tensorize)
    return apply_to_collection(tensorized, (torch.Tensor, tuple), _move_to_device_and_make_contiguous, device=device)
def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: | |||||
def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||||
""" | """ | ||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | ||||
@@ -390,36 +211,28 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: | |||||
{'a': 1, 'b':[1, 2], 'c':{'d': 2}} | {'a': 1, 'b':[1, 2], 'c':{'d': 2}} | ||||
] | ] | ||||
:param obj: 任意结构的数据,所有的 value 都会变成 list ,其长度为 world_size ,依次为每个 rank 上的对象值 | |||||
:param device: 当前 rank 使用的 device 是哪个。为 None 的话默认使用 torch.cuda.current_device() 获取。 | |||||
:param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行 | |||||
序列化之后进行传输。 | |||||
:param device: 当前该参数无意义。 | |||||
:param group: | :param group: | ||||
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | ||||
""" | """ | ||||
# # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 | # # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 | ||||
# obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||||
if device is None: | |||||
device = torch.cuda.current_device() | |||||
if _TORCH_GREATER_EQUAL_1_8: | |||||
if isinstance(obj, torch.Tensor): | |||||
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] | |||||
dist.all_gather(objs, obj, group=group) | |||||
else: | |||||
objs = [None for _ in range(dist.get_world_size(group))] | objs = [None for _ in range(dist.get_world_size(group))] | ||||
dist.all_gather_object(objs, obj) | |||||
objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 | |||||
return objs | |||||
group = group if group is not None else torch.distributed.group.WORLD | |||||
data = convert_to_tensors(obj, device=device) | |||||
data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group) | |||||
objs = [] | |||||
def _get_obj_on_idx(obj, idx): | |||||
return obj[idx] | |||||
for i in range(dist.get_world_size(group)): | |||||
objs.append(apply_to_collection(data, dtype=list, function=_get_obj_on_idx, idx=i)) | |||||
# 防止 unpickle 的时候弄到发送的 gpu 上了 | |||||
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||||
if _TORCH_GREATER_EQUAL_1_8: | |||||
dist.all_gather_object(objs, obj, group=group) | |||||
else: | |||||
objs = all_gather_object(objs, obj, group=group) | |||||
return objs | return objs | ||||
def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||||
def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||||
""" | """ | ||||
将 src 上的 obj 对象广播到其它 rank 上。 | 将 src 上的 obj 对象广播到其它 rank 上。 | ||||
@@ -430,10 +243,9 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||||
:return: | :return: | ||||
""" | """ | ||||
cur_rank = dist.get_rank(group) | cur_rank = dist.get_rank(group) | ||||
# if cur_rank == src: | |||||
# # 如果有 tensor 全部移动到 cpu 上,方便 pickle | |||||
# obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||||
if cur_rank == src: | |||||
# 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 | |||||
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||||
if _TORCH_GREATER_EQUAL_1_8: | if _TORCH_GREATER_EQUAL_1_8: | ||||
if cur_rank!=src: | if cur_rank!=src: | ||||
get_obj = [None] | get_obj = [None] | ||||
@@ -442,6 +254,8 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||||
else: | else: | ||||
dist.broadcast_object_list([obj], src=src, group=group) | dist.broadcast_object_list([obj], src=src, group=group) | ||||
return obj | return obj | ||||
if device is None: | |||||
device = torch.cuda.current_device() | |||||
if cur_rank == src: | if cur_rank == src: | ||||
tensor, size = _object_to_tensor(obj, device=device) | tensor, size = _object_to_tensor(obj, device=device) | ||||
@@ -460,3 +274,107 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||||
return _tensor_to_object(tensor, tensor_size=size.item()) | return _tensor_to_object(tensor, tensor_size=size.item()) | ||||
def _check_for_nccl_backend(group):
    """
    Tell whether ``group`` (or the default process group when ``group`` is falsy) is
    backed by NCCL.

    :param group: the process group to inspect, or ``None`` for the default group.
    :return: ``True`` only if NCCL is available and the unwrapped group is a
        ``ProcessGroupNCCL``.
    """
    if group:
        pg = group
    else:
        pg = dist.distributed_c10d._get_default_group()
    # A process group is not expected to be wrapped many times, but peel off every
    # wrapper layer just in case.
    while isinstance(pg, _ProcessGroupWrapper):
        pg = pg.wrapped_pg
    if not dist.is_nccl_available():
        return False
    return isinstance(pg, dist.ProcessGroupNCCL)
def all_gather_object(object_list, obj, group=None):
    """
    Gather picklable objects from the whole group into a list.

    This is a copy of the PyTorch implementation, kept here so that older PyTorch
    versions can be supported. It is similar to :func:`all_gather`, but Python objects
    can be passed in. Note that the object must be picklable in order to be gathered.

    Args:
        object_list (list[Any]): Output list. It should be correctly sized as the
            size of the group for this collective and will contain the output.
        obj (Any): Picklable Python object to be broadcast from the current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.

    Returns:
        The ``object_list`` that was passed in, so that callers may either rely on the
        in-place update or use the return value. If the calling rank is part of this
        group, the output of the collective has been populated into it. If the calling
        rank is not part of the group, it is returned unmodified.

    .. note:: Note that this API differs slightly from the :func:`all_gather`
        collective since it does not provide an ``async_op`` handle and thus
        will be a blocking call.

    .. note:: For NCCL-based process groups, internal tensor representations
        of objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.

    .. warning::
        :func:`all_gather_object` uses the ``pickle`` module implicitly, which is
        known to be insecure. It is possible to construct malicious pickle data
        which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    Example::
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
        >>> output
        ['foo', 12, {1: 2}]
    """
    if dist._rank_not_in_group(group):
        return object_list

    input_tensor, local_size = _object_to_tensor(obj)
    current_device = torch.device("cpu")
    is_nccl_backend = _check_for_nccl_backend(group)
    if is_nccl_backend:
        # See note about using torch.cuda.current_device() here in docstring.
        # We cannot simply use my_rank since rank == device is not necessarily
        # true.
        current_device = torch.device("cuda", torch.cuda.current_device())
        input_tensor = input_tensor.to(current_device)
        local_size = local_size.to(current_device)
    # Gather all local sizes. This is so that we can find the max size, and index
    # until the correct size when deserializing the tensors.
    group_size = dist.get_world_size(group=group)
    object_sizes_tensor = torch.zeros(
        group_size, dtype=torch.long, device=current_device
    )
    object_size_list = [
        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
    ]
    # Allgather tensor sizes
    dist.all_gather(object_size_list, local_size, group=group)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # Resize tensor to max size across all ranks.
    input_tensor.resize_(max_object_size)
    coalesced_output_tensor = torch.empty(
        max_object_size * group_size, dtype=torch.uint8, device=current_device
    )
    # Output tensors are nonoverlapping views of coalesced_output_tensor
    output_tensors = [
        coalesced_output_tensor[max_object_size * i: max_object_size * (i + 1)]
        for i in range(group_size)
    ]
    dist.all_gather(output_tensors, input_tensor, group=group)
    # Deserialize outputs back to object.
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.type(torch.uint8)
        if tensor.device != torch.device("cpu"):
            tensor = tensor.cpu()
        tensor_size = object_size_list[i]
        object_list[i] = _tensor_to_object(tensor, tensor_size)
    # Returning the list (instead of the implicit None) keeps callers that write
    # ``objs = all_gather_object(objs, obj, group=group)`` from losing the result.
    return object_list
@@ -13,9 +13,8 @@ __all__ = [ | |||||
from .torch_driver import TorchDriver | from .torch_driver import TorchDriver | ||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | ||||
from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import re_instantiate_sampler | |||||
class TorchSingleDriver(TorchDriver): | class TorchSingleDriver(TorchDriver): | ||||
@@ -130,13 +129,13 @@ class TorchSingleDriver(TorchDriver): | |||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | |||||
reproducible: bool = False): | reproducible: bool = False): | ||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | ||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
return replace_batch_sampler(dataloader, dist) | return replace_batch_sampler(dataloader, dist) | ||||
elif isinstance(dist, ReproducibleIterator): | |||||
elif isinstance(dist, ReproducibleSampler): | |||||
return replace_sampler(dataloader, dist) | return replace_sampler(dataloader, dist) | ||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | ||||
@@ -144,7 +143,7 @@ class TorchSingleDriver(TorchDriver): | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | if isinstance(args.batch_sampler, ReproducibleBatchSampler): | ||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | batch_sampler = re_instantiate_sampler(args.batch_sampler) | ||||
return replace_batch_sampler(dataloader, batch_sampler) | return replace_batch_sampler(dataloader, batch_sampler) | ||||
elif isinstance(args.sampler, ReproducibleIterator): | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | sampler = re_instantiate_sampler(args.sampler) | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
class TorchDriver(Driver): | class TorchDriver(Driver): | ||||
@@ -182,8 +182,8 @@ class TorchDriver(Driver): | |||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | ||||
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | # 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | ||||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的 | |||||
# sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; | |||||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的 | |||||
# sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; | |||||
dataloader_args = self.get_dataloader_args(dataloader) | dataloader_args = self.get_dataloader_args(dataloader) | ||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | ||||
sampler = dataloader_args.batch_sampler | sampler = dataloader_args.batch_sampler | ||||
@@ -247,11 +247,10 @@ class TorchDriver(Driver): | |||||
dataloader_args = self.get_dataloader_args(dataloader) | dataloader_args = self.get_dataloader_args(dataloader) | ||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | ||||
sampler = dataloader_args.batch_sampler | sampler = dataloader_args.batch_sampler | ||||
elif isinstance(dataloader_args.sampler, ReproducibleIterator): | |||||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||||
sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
elif self.is_distributed(): | elif self.is_distributed(): | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our " | |||||
"`ReproducibleBatchSampler` or `ReproducibleIterator`.") | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | |||||
else: | else: | ||||
sampler = ReproducibleBatchSampler( | sampler = ReproducibleBatchSampler( | ||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | ||||
@@ -291,7 +290,7 @@ class TorchDriver(Driver): | |||||
@staticmethod | @staticmethod | ||||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | ||||
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed | |||||
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed | |||||
with ``seed_everything(seed, workers=True)``. | with ``seed_everything(seed, workers=True)``. | ||||
See also the PyTorch documentation on | See also the PyTorch documentation on | ||||
@@ -9,18 +9,28 @@ __all__ = [ | |||||
'MixSequentialSampler', | 'MixSequentialSampler', | ||||
'PollingSampler', | 'PollingSampler', | ||||
'ReproducibleIterator', | |||||
'ReproducibleSampler', | |||||
'RandomSampler', | 'RandomSampler', | ||||
're_instantiate_sampler', | |||||
"SequentialSampler", | |||||
"SortedSampler", | |||||
'UnrepeatedSampler', | 'UnrepeatedSampler', | ||||
"UnrepeatedSortedSampler" | |||||
'UnrepeatedRandomSampler', | |||||
"UnrepeatedSortedSampler", | |||||
"UnrepeatedSequentialSampler", | |||||
"RandomBatchSampler", | |||||
"BucketedBatchSampler", | |||||
"ReproducibleBatchSampler", | |||||
"re_instantiate_sampler", | |||||
"conversion_between_reproducible_and_unrepeated_sampler" | |||||
] | ] | ||||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | ||||
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedSortedSampler | |||||
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | |||||
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | ||||
from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler | |||||
from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler | |||||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | |||||
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler | |||||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | |||||
@@ -1,6 +1,6 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'BucketedBatchSampler', | 'BucketedBatchSampler', | ||||
"ReproducibleBatchSampler" | |||||
"RandomBatchSampler" | |||||
] | ] | ||||
import math | import math | ||||
@@ -16,7 +16,10 @@ from fastNLP.core.log import logger | |||||
from abc import abstractmethod | from abc import abstractmethod | ||||
class ReproducibleBatchIterator: | |||||
class ReproducibleBatchSampler: | |||||
def __init__(self, **kwargs): | |||||
pass | |||||
@abstractmethod | @abstractmethod | ||||
def set_distributed(self, num_replicas, rank, pad=True): | def set_distributed(self, num_replicas, rank, pad=True): | ||||
raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") | raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") | ||||
@@ -41,19 +44,25 @@ class ReproducibleBatchIterator: | |||||
def set_epoch(self, epoch): | def set_epoch(self, epoch): | ||||
pass | pass | ||||
@property | |||||
def batch_idx_in_epoch(self): | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") | |||||
class ReproducibleBatchSampler(ReproducibleBatchIterator): | |||||
class RandomBatchSampler(ReproducibleBatchSampler): | |||||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | ||||
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | ||||
""" | """ | ||||
可以使得 batch_sampler 对象状态恢复的 wrapper 。 | 可以使得 batch_sampler 对象状态恢复的 wrapper 。 | ||||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproducibleBatchSampler 将首先遍历一边该对象,然后将迭代 | |||||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。RandomBatchSampler 将首先遍历一边该对象,然后将迭代 | |||||
出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 | 出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 | ||||
:param batch_size: 每个 batch 的大小是多少。 | :param batch_size: 每个 batch 的大小是多少。 | ||||
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | :param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | ||||
:param kwargs: fastNLP 内部使用。 | :param kwargs: fastNLP 内部使用。 | ||||
""" | """ | ||||
super().__init__() | |||||
self.batch_sampler = batch_sampler | self.batch_sampler = batch_sampler | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.drop_last = drop_last | self.drop_last = drop_last | ||||
@@ -138,7 +147,7 @@ class ReproducibleBatchSampler(ReproducibleBatchIterator): | |||||
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | ||||
class BucketedBatchSampler(ReproducibleBatchIterator): | |||||
class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | ||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | ||||
""" | """ | ||||
@@ -1,24 +1,21 @@ | |||||
from typing import Dict, List | |||||
from typing import Dict, List, Union | |||||
import math | import math | ||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.dataset import DataSet | |||||
__all__ = [ | __all__ = [ | ||||
'ReproducibleIterator', | |||||
'ReproducibleSampler', | |||||
'RandomSampler', | 'RandomSampler', | ||||
're_instantiate_sampler' | |||||
"SortedSampler", | |||||
"SequentialSampler" | |||||
] | ] | ||||
def re_instantiate_sampler(sampler):
    """
    Build a fresh instance of ``sampler``'s class from its current instance attributes.

    :param sampler: the object to re-create; every entry of ``vars(sampler)`` is passed
        to its class as a keyword argument.
    :return: the newly constructed instance.
    """
    sampler_cls = type(sampler)
    init_kwargs = vars(sampler)
    return sampler_cls(**init_kwargs)
class ReproducibleIterator: | |||||
class ReproducibleSampler: | |||||
""" | """ | ||||
注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | |||||
注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | |||||
或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | ||||
""" | """ | ||||
@@ -46,7 +43,7 @@ class ReproducibleIterator: | |||||
pass | pass | ||||
class RandomSampler(ReproducibleIterator): | |||||
class RandomSampler(ReproducibleSampler): | |||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | ||||
""" | """ | ||||
@@ -156,8 +153,8 @@ class RandomSampler(ReproducibleIterator): | |||||
f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
length = states['length'] | length = states['length'] | ||||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ | |||||
"and current dataset." | |||||
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ | |||||
f"and current dataset({len(self.dataset)})." | |||||
self.seed = states['seed'] | self.seed = states['seed'] | ||||
self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
@@ -214,9 +211,132 @@ class RandomSampler(ReproducibleIterator): | |||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | ||||
class SequentialSampler(RandomSampler):
    def __init__(self, dataset, dist_mode:str='interval', **kwargs):
        """
        Read ``dataset`` in order. With multiple replicas the indices are interleaved,
        e.g. with two replicas rank 0 takes [0, 2, 4, ...] and rank 1 takes [1, 3, 5, ...].

        :param dataset: a data container that implements ``__len__``.
        :param dist_mode: accepted for interface compatibility; not used by this class.
        :param kwargs: forwarded to ``RandomSampler``. ``shuffle`` and ``seed`` entries
            are ignored, because this sampler always passes ``shuffle=False`` and ``seed=0``.
        """
        # ``shuffle`` and ``seed`` are passed explicitly below. When the sampler is
        # rebuilt from its own attributes they also show up in ``kwargs``; dropping them
        # avoids a "multiple values for keyword argument" TypeError.
        kwargs.pop('shuffle', None)
        kwargs.pop('seed', None)
        super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)

    def __iter__(self):
        # If a previous iteration did not finish, the only option is to start over.
        if self.during_iter:
            self.num_consumed_samples = 0
        self.during_iter = True
        indices = self.generate_indices()

        if self.pad:
            # Add extra samples to make the total evenly divisible. A new list is built
            # rather than using ``+=`` so that a list owned by ``generate_indices`` (for
            # example one stored on a subclass) is never mutated.
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices = indices + indices[:padding_size]
            else:
                indices = indices + (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # Remove the tail of the data to make it evenly divisible.
            indices = indices[:self.total_size]

        assert len(indices) == self.total_size

        # Skip what has already been consumed, then take this rank's interleaved share.
        indices = indices[self.num_consumed_samples:]
        indices = indices[self.rank:len(indices):self.num_replicas]
        assert len(indices) == self.num_left_samples

        for index in indices:
            # One step on this rank is counted as one sample on every replica.
            self.num_consumed_samples += self.num_replicas
            yield index
        self.during_iter = False
        self.num_consumed_samples = 0

    def generate_indices(self) -> List[int]:
        """
        Generate the sequential index list ``[0, 1, ..., len(dataset) - 1]``.

        :return: a new list on every call.
        """
        return list(range(len(self.dataset)))

    def state_dict(self) -> Dict:
        """
        Return the state needed to resume iteration.

        :return: a dict with ``num_consumed_samples``, ``sampler_type`` and ``length``.
        """
        states = {
            # This value is incremented by ``num_replicas`` per yielded index, i.e. it
            # counts the samples consumed across all ranks.
            'num_consumed_samples': self.num_consumed_samples,
            'sampler_type': self.__class__.__name__,
            'length': len(self.dataset),
        }
        return states

    def load_state_dict(self, states: Dict):
        """
        Restore the state produced by ``state_dict``.

        :param states: must contain ``sampler_type``, ``length`` and ``num_consumed_samples``.
        :raises AssertionError: if an iteration is in progress, if the sampler type
            differs, or if the recorded length differs from the current dataset length.
        """
        # Restoring state in the middle of an unfinished iteration is not supported.
        assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                          "during an unfinished iteration."

        assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
                                                                  f"we cannot use {self.__class__.__name__} to load it."

        length = states['length']
        assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
                                            f"and current dataset({len(self.dataset)})."
        self.num_consumed_samples = states['num_consumed_samples']
        # If the checkpoint was saved after the last sample was reached, start from 0.
        if self.num_consumed_samples >= length:
            self.num_consumed_samples = 0
class SortedSampler(SequentialSampler):
    def __init__(self, dataset, length:Union[str, List], **kwargs):
        """
        Iterate over ``dataset`` from the longest sample to the shortest according to
        ``length``. With multiple replicas, padding may make the last sample the
        longest one.

        :param dataset: a data container that implements ``__len__``.
        :param length: if a list, it must have the same length as ``dataset`` and gives
            the length of each element. A ``str`` is only supported when ``dataset`` is
            a fastNLP ``DataSet``; it is then treated as a field name. If the elements of
            that field are ``int`` they are used as the sample lengths, otherwise ``len``
            of each element is used.
        :param kwargs: reserved for fastNLP. ``shuffle`` and ``seed`` entries are ignored.
        """
        # The parent fixes ``shuffle``/``seed`` itself; stale copies (e.g. from rebuilding
        # the sampler out of its own attributes) must not be forwarded a second time.
        kwargs.pop('shuffle', None)
        kwargs.pop('seed', None)
        super().__init__(dataset=dataset, **kwargs)
        if isinstance(dataset, DataSet) and isinstance(length, str):
            # Only a field *name* is looked up; an explicit list of lengths is used as is.
            length = dataset.get_field(length)
            if not isinstance(length[0], int):
                length = list(map(len, length))
        else:
            assert not isinstance(length, str), "When the dataset is not fastNLP.DataSet, " \
                                                "the length parameter can only be List[int]"

        assert len(length) == len(dataset), "The length of `data` and `length` should be equal."

        self.length = np.array(length, dtype=int)  # the length of every sample
        self.sorted_indices = np.argsort(self.length)[::-1].tolist()  # indices ordered from longest to shortest

    def generate_indices(self) -> List[int]:
        """
        Return the indices ordered from the longest sample to the shortest.

        :return: a copy of ``self.sorted_indices``. A copy is required because the caller
            pads the list; handing out the stored list would let it grow on every epoch.
        """
        return list(self.sorted_indices)

    # ``__iter__`` is inherited from ``SequentialSampler``: the implementation that used
    # to live here was a line-for-line duplicate of the parent's.
@@ -1,6 +1,8 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'UnrepeatedSampler', | |||||
'UnrepeatedSortedSampler', | 'UnrepeatedSortedSampler', | ||||
'UnrepeatedSampler' | |||||
'UnrepeatedRandomSampler', | |||||
"UnrepeatedSequentialSampler" | |||||
] | ] | ||||
from typing import List, Union | from typing import List, Union | ||||
@@ -10,13 +12,21 @@ import numpy as np | |||||
class UnrepeatedSampler:
    """
    Base class for samplers that guarantee indices are not repeated in the multi-card setting.
    """
class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | ||||
""" | """ | ||||
考虑在多卡evaluate的场景下,不能重复sample。 | 考虑在多卡evaluate的场景下,不能重复sample。 | ||||
:param dataset: | |||||
:param shuffle: | |||||
:param seed: | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | |||||
""" | """ | ||||
self.dataset = dataset | self.dataset = dataset | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
@@ -33,8 +43,8 @@ class UnrepeatedSampler: | |||||
:return: | :return: | ||||
""" | """ | ||||
num_common = len(self.dataset)//self.num_replicas | num_common = len(self.dataset)//self.num_replicas | ||||
self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||||
return self.num_samples | |||||
num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||||
return num_samples | |||||
def __iter__(self): | def __iter__(self): | ||||
indices = self.generate_indices() | indices = self.generate_indices() | ||||
@@ -83,8 +93,8 @@ class UnrepeatedSampler: | |||||
return self | return self | ||||
class UnrepeatedSortedSampler(UnrepeatedSampler): | |||||
def __init__(self, dataset, length:Union[str, List], seed: int = 0): | |||||
class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||||
def __init__(self, dataset, length:Union[str, List], **kwargs): | |||||
""" | """ | ||||
将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 | 将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 | ||||
batch 数量不完全一致。 | batch 数量不完全一致。 | ||||
@@ -92,11 +102,9 @@ class UnrepeatedSortedSampler(UnrepeatedSampler): | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | :param dataset: 实现了 __len__ 方法的数据容器。 | ||||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | :param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | ||||
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | ||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
super().__init__(dataset=dataset, shuffle=False, seed=seed) | |||||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||||
if isinstance(dataset, DataSet): | if isinstance(dataset, DataSet): | ||||
length = dataset.get_field(length) | length = dataset.get_field(length) | ||||
if not isinstance(length[0], int): | if not isinstance(length[0], int): | ||||
@@ -107,8 +115,29 @@ class UnrepeatedSortedSampler(UnrepeatedSampler): | |||||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | ||||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||||
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | |||||
length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||||
self.sorted_indices = np.argsort(length)[::-1].tolist() # 按长度从高到低排序的 | |||||
def generate_indices(self) -> List[int]: | def generate_indices(self) -> List[int]: | ||||
return self.sorted_indices | return self.sorted_indices | ||||
class UnrepeatedSequentialSampler(UnrepeatedRandomSampler):
    def __init__(self, dataset, **kwargs):
        """
        Read ``dataset`` in its original order. With several replicas the indices are
        interleaved: with two replicas, rank 0 takes [0, 2, 4, ...] and rank 1 takes [1, 3, 5, ...].

        :param dataset: a data container that implements ``__len__``.
        :param kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(dataset, shuffle=False, seed=0, **kwargs)

    def __iter__(self):
        # Build the full index list, then keep every ``num_replicas``-th entry
        # starting at this replica's ``rank``.
        # NOTE(review): ``rank`` and ``num_replicas`` are not set in this class;
        # presumably the parent class provides them -- confirm.
        all_indices = self.generate_indices()
        yield from all_indices[self.rank:len(all_indices):self.num_replicas]

    def generate_indices(self) -> List[int]:
        """Return the indices ``0 .. len(dataset) - 1`` in ascending order."""
        return list(range(len(self.dataset)))
@@ -0,0 +1,42 @@ | |||||
__all__ = [ | |||||
're_instantiate_sampler', | |||||
'conversion_between_reproducible_and_unrepeated_sampler' | |||||
] | |||||
from fastNLP.core.samplers.unrepeated_sampler import * | |||||
from fastNLP.core.samplers.reproducible_sampler import * | |||||
def conversion_between_reproducible_and_unrepeated_sampler(sampler):
    """
    Convert ``sampler`` into its counterpart from the other sampler family: an
    UnrepeatedSampler becomes the matching ReproducibleSampler, and a
    ReproducibleSampler becomes the matching UnrepeatedSampler.

    The new sampler is created by ``re_instantiate_sampler`` from the attributes
    of ``sampler``.

    :param sampler: an instance of UnrepeatedSampler or ReproducibleSampler.
    :return: a new sampler instance of the counterpart class.
    :raises AssertionError: if ``sampler`` belongs to neither family.
    :raises TypeError: if no counterpart class is known for ``sampler``.
    """
    assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
        "The sampler must be UnrepeatedSampler or ReproducibleSampler"

    # NOTE(review): re_instantiate_sampler passes every instance attribute of
    # ``sampler`` as a keyword argument; confirm each target constructor accepts them.
    if isinstance(sampler, UnrepeatedSampler):
        # UnrepeatedSortedSampler and UnrepeatedSequentialSampler both derive from
        # UnrepeatedRandomSampler, so the subclasses have to be tested first;
        # otherwise every unrepeated sampler would be mapped to RandomSampler.
        if isinstance(sampler, UnrepeatedSortedSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
        elif isinstance(sampler, UnrepeatedSequentialSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
        elif isinstance(sampler, UnrepeatedRandomSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
        raise TypeError(f"{sampler.__class__} has no reproducible version.")
    else:
        # Same most-specific-first order as above.
        # NOTE(review): the reproducible class hierarchy is not visible here;
        # presumably SortedSampler/SequentialSampler may derive from RandomSampler -- confirm.
        if isinstance(sampler, SortedSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
        elif isinstance(sampler, SequentialSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
        elif isinstance(sampler, RandomSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
        raise TypeError(f"{sampler.__class__} has no unrepeated version.")
def re_instantiate_sampler(sampler, new_sampler_class=None):
    """
    Build a new sampler from the instance attributes of ``sampler``.

    :param sampler: the object whose instance attributes (``vars(sampler)``) are
        passed as keyword arguments to the constructor.
    :param new_sampler_class: the class to instantiate; when ``None`` the class of
        ``sampler`` itself is used.
    :return: the newly constructed instance.
    """
    target_class = type(sampler) if new_sampler_class is None else new_sampler_class
    return target_class(**vars(sampler))
@@ -0,0 +1,10 @@ | |||||
class EarlyStopException(BaseException):
    r"""
    Raised to break out of the Trainer's training loop when early stopping is triggered.

    The class derives from ``BaseException``, so ``except Exception`` handlers do not catch it.
    """

    def __init__(self, msg):
        # Keep the message available both as ``args[0]`` and as ``msg``.
        self.msg = msg
        super().__init__(msg)
@@ -94,9 +94,6 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
self.print = self.console.print | self.print = self.console.print | ||||
self.log = self.console.log | self.log = self.console.log | ||||
# start new | |||||
self.start() | |||||
self.console.show_cursor(show=True) | |||||
return self | return self | ||||
def set_transient(self, transient: bool = True): | def set_transient(self, transient: bool = True): | ||||
@@ -154,6 +151,7 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
super().start() | super().start() | ||||
self.console.show_cursor(show=True) | self.console.show_cursor(show=True) | ||||
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | ||||
f_rich_progress = FRichProgress().new_progess( | f_rich_progress = FRichProgress().new_progess( | ||||
"[progress.description]{task.description}", | "[progress.description]{task.description}", | ||||
@@ -12,32 +12,27 @@ def test_get_monitor_value(): | |||||
with Capturing() as output: | with Capturing() as output: | ||||
monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) | monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) | ||||
assert monitor == 'f1' and value==0.2 | assert monitor == 'f1' and value==0.2 | ||||
assert 'We can not find' not in output[0] | |||||
# 测试可以匹配,且选择更靠前的 | # 测试可以匹配,且选择更靠前的 | ||||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | ||||
with Capturing() as output: | with Capturing() as output: | ||||
monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) | monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) | ||||
assert monitor=='acc#f1' and value==0.2 | assert monitor=='acc#f1' and value==0.2 | ||||
assert 'We can not find' in output[0] | |||||
# 测试monitor匹配不上,使用real_monitor | # 测试monitor匹配不上,使用real_monitor | ||||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | ||||
with Capturing() as output: | with Capturing() as output: | ||||
monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#rec', res=res) | |||||
monitor, value = _get_monitor_value(monitor='acc', real_monitor='acc#rec', res=res) | |||||
assert monitor=='acc#rec' and value==0.3 | assert monitor=='acc#rec' and value==0.3 | ||||
assert 'We can not find' not in output[0] | |||||
# 测试monitor/real_monitor匹配不上, 重新选择 | # 测试monitor/real_monitor匹配不上, 重新选择 | ||||
res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} | ||||
with Capturing() as output: | with Capturing() as output: | ||||
monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res) | monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res) | ||||
assert monitor=='acc#f1' and value==0.2 | assert monitor=='acc#f1' and value==0.2 | ||||
assert 'We can not find' in output[0] | |||||
# 测试partial的位置 | # 测试partial的位置 | ||||
res = {"acc#acc": 0.52, "loss#loss": 2} | res = {"acc#acc": 0.52, "loss#loss": 2} | ||||
with Capturing() as output: | with Capturing() as output: | ||||
monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res) | monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res) | ||||
assert monitor=='loss#loss' and value==2 | assert monitor=='loss#loss' and value==2 | ||||
assert 'We can not find' in output[0] |
@@ -10,7 +10,7 @@ from paddle.io import DataLoader, BatchSampler | |||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | ||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | from fastNLP.core.samplers.reproducible_sampler import RandomSampler | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||||
from fastNLP.core.samplers import RandomBatchSampler | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | ||||
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset | from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset | ||||
from fastNLP.core import synchronize_safe_rm | from fastNLP.core import synchronize_safe_rm | ||||
@@ -153,7 +153,7 @@ class TestSingleDeviceFunction: | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"dist_sampler", | "dist_sampler", | ||||
["dist", ReproducibleBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] | |||||
["dist", RandomBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"reproducible", | "reproducible", | ||||
@@ -7,38 +7,10 @@ import numpy as np | |||||
# print(isinstance((1,), tuple)) | # print(isinstance((1,), tuple)) | ||||
# exit() | # exit() | ||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, convert_to_tensors, fastnlp_torch_broadcast_object | |||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | |||||
from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context | from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context | ||||
def test_convert_to_tensors(): | |||||
local_rank = 0 | |||||
obj = { | |||||
'tensor': torch.full(size=(2,), fill_value=local_rank), | |||||
'numpy': np.full(shape=(1,), fill_value=local_rank), | |||||
'bool': local_rank % 2 == 0, | |||||
'float': local_rank + 0.1, | |||||
'int': local_rank, | |||||
'dict': { | |||||
'rank': local_rank | |||||
}, | |||||
'list': [local_rank] * 2, | |||||
'str': 'xxx' | |||||
} | |||||
data = convert_to_tensors(obj) | |||||
assert len(data) == len(obj) | |||||
assert (data['tensor'] == obj['tensor']).sum() == 2 | |||||
for name in ['list', 'str']: | |||||
assert len(data[name])==2 and isinstance(data[name][0], torch.Tensor) and \ | |||||
isinstance(data[name][1], torch.Tensor) and data[name][1].ndim==1 | |||||
for name in ['numpy', 'bool', 'float', 'int']: | |||||
assert isinstance(data[name][0], torch.Tensor) and data[name][0].numel()==1 | |||||
assert isinstance(data['dict']['rank'][0], torch.Tensor) and data[name][0].numel() == 1 | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_fastnlp_torch_all_gather(): | def test_fastnlp_torch_all_gather(): | ||||
os.environ['MASTER_ADDR'] = '127.0.0.1' | os.environ['MASTER_ADDR'] = '127.0.0.1' | ||||
@@ -66,7 +38,7 @@ def test_fastnlp_torch_all_gather(): | |||||
'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(), | 'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(), | ||||
torch.full(size=(2,), fill_value=local_rank).cuda()] | torch.full(size=(2,), fill_value=local_rank).cuda()] | ||||
} | } | ||||
data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) | |||||
data = fastnlp_torch_all_gather(obj) | |||||
world_size = int(os.environ['WORLD_SIZE']) | world_size = int(os.environ['WORLD_SIZE']) | ||||
assert len(data) == world_size | assert len(data) == world_size | ||||
for i in range(world_size): | for i in range(world_size): | ||||
@@ -81,10 +53,12 @@ def test_fastnlp_torch_all_gather(): | |||||
assert data[i]['tensors'][0][0] == i | assert data[i]['tensors'][0][0] == i | ||||
for obj in [1, True, 'xxx']: | for obj in [1, True, 'xxx']: | ||||
data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) | |||||
data = fastnlp_torch_all_gather(obj) | |||||
assert len(data)==world_size | assert len(data)==world_size | ||||
assert data[0]==data[1] | assert data[0]==data[1] | ||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_fastnlp_torch_broadcast_object(): | def test_fastnlp_torch_broadcast_object(): | ||||
os.environ['MASTER_ADDR'] = '127.0.0.1' | os.environ['MASTER_ADDR'] = '127.0.0.1' | ||||
@@ -130,3 +104,4 @@ def test_fastnlp_torch_broadcast_object(): | |||||
for obj in [int(os.environ['LOCAL_RANK']), bool(os.environ['LOCAL_RANK']=='1'), os.environ['LOCAL_RANK']]: | for obj in [int(os.environ['LOCAL_RANK']), bool(os.environ['LOCAL_RANK']=='1'), os.environ['LOCAL_RANK']]: | ||||
data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device()) | data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device()) | ||||
assert int(data)==0 | assert int(data)==0 | ||||
dist.destroy_process_group() |
@@ -30,7 +30,7 @@ class SequenceDataSet: | |||||
def check_replace_sampler(driver): | def check_replace_sampler(driver): | ||||
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproducibleBatchSampler | |||||
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler | |||||
# reproducible 是 True 和 False | # reproducible 是 True 和 False | ||||
# 需要 check 返回的 sampler 和 dataloader 都不同了 | # 需要 check 返回的 sampler 和 dataloader 都不同了 | ||||
@@ -4,7 +4,7 @@ import numpy as np | |||||
import pytest | import pytest | ||||
from itertools import chain | from itertools import chain | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler | |||||
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
@@ -18,7 +18,7 @@ class TestReproducibleBatchSampler: | |||||
before_batch_size = 7 | before_batch_size = 7 | ||||
dataset = TorchNormalDataset(num_of_data=100) | dataset = TorchNormalDataset(num_of_data=100) | ||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | dataloader = DataLoader(dataset, batch_size=before_batch_size) | ||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | dataloader = replace_batch_sampler(dataloader, re_batchsampler) | ||||
forward_steps = 3 | forward_steps = 3 | ||||
@@ -28,15 +28,15 @@ class TestReproducibleBatchSampler: | |||||
# 1. 保存状态 | # 1. 保存状态 | ||||
_get_re_batchsampler = dataloader.batch_sampler | _get_re_batchsampler = dataloader.batch_sampler | ||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||||
assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | state = _get_re_batchsampler.state_dict() | ||||
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | ||||
"sampler_type": "ReproducibleBatchSampler"} | |||||
"sampler_type": "RandomBatchSampler"} | |||||
# 2. 断点重训,重新生成一个 dataloader; | # 2. 断点重训,重新生成一个 dataloader; | ||||
# 不改变 batch_size; | # 不改变 batch_size; | ||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | dataloader = DataLoader(dataset, batch_size=before_batch_size) | ||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | re_batchsampler.load_state_dict(state) | ||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | dataloader = replace_batch_sampler(dataloader, re_batchsampler) | ||||
@@ -53,7 +53,7 @@ class TestReproducibleBatchSampler: | |||||
# 改变 batch_size; | # 改变 batch_size; | ||||
after_batch_size = 3 | after_batch_size = 3 | ||||
dataloader = DataLoader(dataset, batch_size=after_batch_size) | dataloader = DataLoader(dataset, batch_size=after_batch_size) | ||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | re_batchsampler.load_state_dict(state) | ||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | dataloader = replace_batch_sampler(dataloader, re_batchsampler) | ||||
@@ -99,7 +99,7 @@ class TestReproducibleBatchSampler: | |||||
dataset = TorchNormalDataset(num_of_data=100) | dataset = TorchNormalDataset(num_of_data=100) | ||||
# 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | ||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | ||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | dataloader = replace_batch_sampler(dataloader, re_batchsampler) | ||||
# 将一轮的所有数据保存下来,看是否恢复的是正确的; | # 将一轮的所有数据保存下来,看是否恢复的是正确的; | ||||
@@ -111,13 +111,13 @@ class TestReproducibleBatchSampler: | |||||
# 1. 保存状态 | # 1. 保存状态 | ||||
_get_re_batchsampler = dataloader.batch_sampler | _get_re_batchsampler = dataloader.batch_sampler | ||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||||
assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | state = _get_re_batchsampler.state_dict() | ||||
# 2. 断点重训,重新生成一个 dataloader; | # 2. 断点重训,重新生成一个 dataloader; | ||||
# 不改变 batch_size; | # 不改变 batch_size; | ||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | ||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | re_batchsampler.load_state_dict(state) | ||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | dataloader = replace_batch_sampler(dataloader, re_batchsampler) | ||||
@@ -1,18 +1,14 @@ | |||||
import unittest | |||||
from itertools import product | |||||
import numpy as np | import numpy as np | ||||
import pytest | |||||
from functools import partial | from functools import partial | ||||
from array import array | |||||
from itertools import chain | |||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
class TestRandomSamplerYh(unittest.TestCase): | |||||
class TestRandomSamplerYh: | |||||
def test_init(self): | def test_init(self): | ||||
# 测试能否正确初始化 | # 测试能否正确初始化 | ||||
dataset = TorchNormalDataset(num_of_data=100) | dataset = TorchNormalDataset(num_of_data=100) | ||||
@@ -24,7 +20,7 @@ class TestRandomSamplerYh(unittest.TestCase): | |||||
dataset = TorchNormalDataset(num_of_data=100) | dataset = TorchNormalDataset(num_of_data=100) | ||||
sampler = RandomSampler(dataset) | sampler = RandomSampler(dataset) | ||||
for i in sampler: | for i in sampler: | ||||
with self.assertRaises(AssertionError): | |||||
with pytest.raises(AssertionError): | |||||
sampler.set_distributed(1, 0) | sampler.set_distributed(1, 0) | ||||
break | break | ||||
@@ -37,39 +33,39 @@ class TestRandomSamplerYh(unittest.TestCase): | |||||
dataset = TorchNormalDataset(num_of_data=100) | dataset = TorchNormalDataset(num_of_data=100) | ||||
sampler = RandomSampler(dataset, shuffle=False) | sampler = RandomSampler(dataset, shuffle=False) | ||||
sampler.set_distributed(num_replicas=2, rank=0, pad=False) | sampler.set_distributed(num_replicas=2, rank=0, pad=False) | ||||
self.assertEqual(len(sampler), 50) | |||||
assert len(sampler)==50 | |||||
count = 0 | count = 0 | ||||
for i in sampler: | for i in sampler: | ||||
self.assertEqual(i%2, 0) | |||||
assert i%2==0 | |||||
count += 1 | count += 1 | ||||
self.assertEqual(count, 50) | |||||
assert count == 50 | |||||
sampler.set_distributed(num_replicas=2, rank=1, pad=False) | sampler.set_distributed(num_replicas=2, rank=1, pad=False) | ||||
self.assertEqual(len(sampler), 50) | |||||
assert len(sampler)==50 | |||||
count = 0 | count = 0 | ||||
for i in sampler: | for i in sampler: | ||||
self.assertEqual(i%2, 1) | |||||
assert i%2==1 | |||||
count += 1 | count += 1 | ||||
self.assertEqual(count, 50) | |||||
assert count==50 | |||||
dataset = TorchNormalDataset(num_of_data=101) | dataset = TorchNormalDataset(num_of_data=101) | ||||
sampler = RandomSampler(dataset, shuffle=False) | sampler = RandomSampler(dataset, shuffle=False) | ||||
sampler.set_distributed(num_replicas=2, rank=0, pad=True) | sampler.set_distributed(num_replicas=2, rank=0, pad=True) | ||||
self.assertEqual(len(sampler), 51) | |||||
assert len(sampler)==51 | |||||
count = 0 | count = 0 | ||||
for i in sampler: | for i in sampler: | ||||
self.assertEqual(i%2, 0) | |||||
assert i%2==0 | |||||
count += 1 | count += 1 | ||||
self.assertEqual(count, 51) | |||||
assert count == 51 | |||||
sampler.set_distributed(num_replicas=2, rank=1, pad=True) | sampler.set_distributed(num_replicas=2, rank=1, pad=True) | ||||
self.assertEqual(len(sampler), 51) | |||||
assert len(sampler) == 51 | |||||
count = 0 | count = 0 | ||||
for i in sampler: | for i in sampler: | ||||
if i!=0: | if i!=0: | ||||
self.assertEqual(i%2, 1) | |||||
assert i%2==1 | |||||
count += 1 | count += 1 | ||||
self.assertEqual(count, 51) | |||||
assert count == 51 | |||||
def test_state_dict_check_length(self): | def test_state_dict_check_length(self): | ||||
dataset = TorchNormalDataset(num_of_data=100) | dataset = TorchNormalDataset(num_of_data=100) | ||||
@@ -77,7 +73,7 @@ class TestRandomSamplerYh(unittest.TestCase): | |||||
states = sampler.state_dict() | states = sampler.state_dict() | ||||
new_ds = TorchNormalDataset(num_of_data=10) | new_ds = TorchNormalDataset(num_of_data=10) | ||||
with self.assertRaises(AssertionError): | |||||
with pytest.raises(AssertionError): | |||||
new_sampler = RandomSampler(new_ds) | new_sampler = RandomSampler(new_ds) | ||||
new_sampler.load_state_dict(states) | new_sampler.load_state_dict(states) | ||||
@@ -85,99 +81,107 @@ class TestRandomSamplerYh(unittest.TestCase): | |||||
new_sampler = RandomSampler(new_ds) | new_sampler = RandomSampler(new_ds) | ||||
new_sampler.load_state_dict(states) | new_sampler.load_state_dict(states) | ||||
def test_state_dict(self): | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('pre_shuffle', [True, False]) | |||||
@pytest.mark.parametrize('post_shuffle', [True, False]) | |||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) | |||||
def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): | |||||
num_samples = 100 | num_samples = 100 | ||||
dataset = TorchNormalDataset(num_of_data=num_samples) | dataset = TorchNormalDataset(num_of_data=num_samples) | ||||
# 测试使用 前后shuffle不一致的load操作 | # 测试使用 前后shuffle不一致的load操作 | ||||
lst = [0]+np.random.randint(1, num_samples, size=3).tolist() | |||||
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], | |||||
lst): | |||||
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle) | |||||
sampler.set_epoch(0) | |||||
already_numbers = set() | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
self.assertEqual(len(already_numbers), num_consumed_samples) | |||||
states = sampler.state_dict() | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
self.assertNotIn(i, already_numbers) | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) | |||||
new_sampler.set_epoch(0) | |||||
count = 0 | |||||
for i in new_sampler: | |||||
self.assertNotIn(i, other_rank_number) | |||||
other_rank_number.add(i) | |||||
self.assertNotIn(i, already_numbers) | |||||
count += 1 | |||||
def test_state_dict_2(self): | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle) | |||||
sampler.set_epoch(0) | |||||
already_numbers = set() | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
assert len(already_numbers) == num_consumed_samples | |||||
states = sampler.state_dict() | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
assert i not in already_numbers | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||||
new_sampler.set_epoch(0) | |||||
count = 0 | |||||
seen = 0 | |||||
seen_in_other_rank = 0 | |||||
for i in new_sampler: | |||||
seen_in_other_rank += int(i in other_rank_number) | |||||
other_rank_number.add(i) | |||||
seen += int(i in already_numbers) | |||||
count += 1 | |||||
assert seen <= 1 if pad else seen == 0 | |||||
assert seen_in_other_rank<=1 # 因为pad可能重复 | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('pre_shuffle', [True, False]) | |||||
@pytest.mark.parametrize('post_shuffle', [True, False]) | |||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) | |||||
def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): | |||||
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 | # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 | ||||
num_samples = 100 | num_samples = 100 | ||||
dataset = TorchNormalDataset(num_of_data=num_samples) | dataset = TorchNormalDataset(num_of_data=num_samples) | ||||
# 测试使用 前后shuffle不一致的load操作 | # 测试使用 前后shuffle不一致的load操作 | ||||
lst = [0]+np.random.randint(1, num_samples//2, size=3).tolist() | |||||
# lst = [30] | # lst = [30] | ||||
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], | |||||
lst): | |||||
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): | |||||
already_numbers = set() | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||||
sampler.set_distributed(num_replicas=2, rank=0) | |||||
sampler.set_epoch(0) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||||
sampler.set_epoch(0) | |||||
sampler.set_distributed(num_replicas=2, rank=1) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
self.assertEqual(len(already_numbers), num_consumed_samples*2) | |||||
states = sampler.state_dict() | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
self.assertNotIn(i, already_numbers) | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) | |||||
count = 0 | |||||
for i in new_sampler: | |||||
self.assertNotIn(i, other_rank_number) | |||||
other_rank_number.add(i) | |||||
self.assertNotIn(i, already_numbers) | |||||
count += 1 | |||||
class TestRandomSampler(unittest.TestCase): | |||||
already_numbers = set() | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||||
sampler.set_distributed(num_replicas=2, rank=0) | |||||
sampler.set_epoch(0) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||||
sampler.set_epoch(0) | |||||
sampler.set_distributed(num_replicas=2, rank=1) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
assert len(already_numbers) == num_consumed_samples*2 | |||||
states = sampler.state_dict() | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
assert i not in already_numbers | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||||
count = 0 | |||||
seen = 0 | |||||
seen_in_other_rank = 0 | |||||
for i in new_sampler: | |||||
seen_in_other_rank += int(i in other_rank_number) | |||||
other_rank_number.add(i) | |||||
seen += int(i in already_numbers) | |||||
count += 1 | |||||
assert seen <= 1 if pad else seen == 0 | |||||
assert seen_in_other_rank<=1 # 因为pad可能重复 | |||||
class TestRandomSampler: | |||||
# 测试单卡; | # 测试单卡; | ||||
def test_seed_work_when_shuffle_is_true(self): | def test_seed_work_when_shuffle_is_true(self): | ||||
data_length = 100 | data_length = 100 | ||||
@@ -360,4 +364,324 @@ class TestRandomSampler(unittest.TestCase): | |||||
... | ... | ||||
class DatasetWithVaryLength:
    """Tiny indexable dataset backed by ``np.arange(num_of_data)``.

    The tests pass ``self.data`` to the samplers as the per-sample length,
    so every sample has a distinct "length".

    :param num_of_data: number of samples to create.
    :param reverse: when true, store the values in descending order.
    """

    def __init__(self, num_of_data=100, reverse=False):
        values = np.arange(num_of_data)
        # ``[::-1]`` flips the order so that index 0 holds the largest value.
        self.data = values[::-1] if reverse else values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]
class TestSortedSampler:
    """Tests for ``SortedSampler``: plain iteration, distributed splitting and
    restoring from ``state_dict`` after a partially consumed epoch."""

    def test_single(self):
        """With lengths 0..n-1 the sampler must yield indices n-1 down to 0."""
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = SortedSampler(data, length=data.data)
        assert list(sampler) == list(range(num_of_data - 1, -1, -1))

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_replica', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, pad, num_replica, num_of_data):
        """Check per-rank ordering, cross-rank overlap and total coverage."""
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for rank in range(num_replica):
            sampler = SortedSampler(dataset=data, length=data.data)
            sampler.set_distributed(num_replica, rank=rank, pad=pad)
            samplers.append(sampler)

        # With padding a rank may repeat one index of an earlier rank.
        max_overlap = 1 if pad else 0

        # Make sure the order within every rank is not scrambled.
        already_seen_index = set()
        for sampler in samplers:
            # Starting at 0 is enough: an index padded at the end is always
            # one of the larger numbers.
            in_order_count = 0
            prev_index = float('inf')
            cur_set = set()
            seen_in_other_rank = 0
            for index in sampler:
                # Different ranks are expected not to share indices.
                seen_in_other_rank += int(index in already_seen_index)
                cur_set.add(index)
                in_order_count += int(index <= prev_index)
                prev_index = index
            # Apart from the last (possibly padded) element, the descending
            # order must hold everywhere.
            assert in_order_count + 1 >= len(sampler)
            assert seen_in_other_rank <= max_overlap
            already_seen_index.update(cur_set)

        indexes = set(chain(*samplers))
        if pad:
            assert indexes == set(range(num_of_data))
        else:
            assert len(indexes) <= num_of_data

    # A seeded generator keeps the parametrization identical on every
    # collection (reproducible test IDs; required by pytest-xdist workers).
    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize(
        'num_consumed_samples',
        [0] + np.random.RandomState(0).randint(1, 100, size=3).tolist(),
    )
    def test_state_dict(self, pad, num_consumed_samples):
        """Consume part of an epoch on one card, then restore the state into
        a single-card sampler and into a 3-card setup."""
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)

        sampler = SortedSampler(dataset, length=dataset.data)
        sampler.set_epoch(0)
        already_numbers = set()
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    assert j < max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples

        states = sampler.state_dict()
        # The consumed set no longer changes, so its maximum is computed once.
        consumed_max = max(already_numbers) if already_numbers else None

        new_sampler = SortedSampler(dataset, length=dataset.data)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if consumed_max is not None:
                assert i < consumed_max
            assert i not in already_numbers

        # Switching to multiple cards after loading must work as well.
        max_repeat = 1 if pad else 0
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SortedSampler(dataset, length=dataset.data)
            new_sampler.load_state_dict(states)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            new_sampler.set_epoch(0)
            seen = 0
            seen_in_other_rank = 0
            not_below_consumed = 0
            for i in new_sampler:
                if consumed_max is not None:
                    not_below_consumed += int(i >= consumed_max)
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
            assert seen <= max_repeat
            assert seen_in_other_rank <= 1  # padding may repeat one index
            assert not_below_consumed <= max_repeat

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize(
        'num_consumed_samples',
        [0] + np.random.RandomState(0).randint(1, 100 // 2, size=3).tolist(),
    )
    def test_state_dict_2(self, pad, num_consumed_samples):
        """Consume on two cards, then restore on a single card and on a
        different number of cards (3)."""
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        already_numbers = set()

        sampler = SortedSampler(dataset, length=dataset.data)
        sampler.set_distributed(num_replicas=2, rank=0)
        sampler.set_epoch(0)
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    assert j <= max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break

        sampler = SortedSampler(dataset, length=dataset.data)
        sampler.set_epoch(0)
        sampler.set_distributed(num_replicas=2, rank=1)
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples * 2

        # The state is taken from the rank-1 sampler created last.
        states = sampler.state_dict()
        # The consumed set no longer changes, so its maximum is computed once.
        consumed_max = max(already_numbers) if already_numbers else None

        new_sampler = SortedSampler(dataset, length=dataset.data)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if consumed_max is not None:
                assert i < consumed_max
            assert i not in already_numbers

        # Switching to multiple cards after loading must work as well.
        max_repeat = 1 if pad else 0
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SortedSampler(dataset, length=dataset.data)
            new_sampler.load_state_dict(states)
            new_sampler.set_epoch(0)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            seen = 0
            seen_in_other_rank = 0
            not_below_consumed = 0
            for i in new_sampler:
                if consumed_max is not None:
                    not_below_consumed += int(i >= consumed_max)
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
            assert seen <= max_repeat
            assert seen_in_other_rank <= 1  # padding may repeat one index
            assert not_below_consumed <= max_repeat
class TestSequentialSampler:
    """Tests for ``SequentialSampler``: plain iteration, distributed splitting
    and restoring from ``state_dict`` after a partially consumed epoch."""

    def test_single(self):
        """A non-distributed sampler must yield 0..n-1 in order."""
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = SequentialSampler(data)
        assert list(sampler) == list(range(num_of_data))

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_replica', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, pad, num_replica, num_of_data):
        """Check per-rank ordering, cross-rank overlap and total coverage."""
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for rank in range(num_replica):
            sampler = SequentialSampler(dataset=data)
            sampler.set_distributed(num_replica, rank=rank, pad=pad)
            samplers.append(sampler)

        # Make sure the order within every rank is not scrambled.
        already_seen_index = set()
        for idx, sampler in enumerate(samplers):
            # Starting from -inf counts the first element as "in order",
            # which gives the same total as the former inf/1 initialisation.
            in_order_count = 0
            prev_index = float('-inf')
            cur_set = set()
            seen_in_other_rank = 0
            for index in sampler:
                # Different ranks are expected not to share indices.
                seen_in_other_rank += int(index in already_seen_index)
                cur_set.add(index)
                in_order_count += int(index >= prev_index)
                prev_index = index
            # Apart from the last (possibly padded) element, the ascending
            # order must hold everywhere.
            assert in_order_count + 1 >= len(sampler)
            # With padding, the allowed overlap grows with the rank position.
            assert seen_in_other_rank <= (idx if pad else 0)
            already_seen_index.update(cur_set)

        indexes = set(chain(*samplers))
        if pad:
            assert indexes == set(range(num_of_data))
        else:
            assert len(indexes) <= num_of_data

    # A seeded generator keeps the parametrization identical on every
    # collection (reproducible test IDs; required by pytest-xdist workers).
    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize(
        'num_consumed_samples',
        [0] + np.random.RandomState(0).randint(1, 100, size=3).tolist(),
    )
    def test_state_dict(self, pad, num_consumed_samples):
        """Consume part of an epoch on one card, then restore the state into
        a single-card sampler and into a 3-card setup."""
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)

        sampler = SequentialSampler(dataset=dataset)
        sampler.set_epoch(0)
        already_numbers = set()
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    assert j > max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples

        states = sampler.state_dict()
        # The consumed set no longer changes, so its maximum is computed once.
        consumed_max = max(already_numbers) if already_numbers else None

        new_sampler = SequentialSampler(dataset=dataset)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if consumed_max is not None:
                assert i > consumed_max
            assert i not in already_numbers

        # Switching to multiple cards after loading must work as well.
        max_repeat = 1 if pad else 0
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SequentialSampler(dataset=dataset)
            new_sampler.load_state_dict(states)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            new_sampler.set_epoch(0)
            seen = 0
            seen_in_other_rank = 0
            not_above_consumed = 0
            for i in new_sampler:
                if consumed_max is not None:
                    not_above_consumed += int(i <= consumed_max)
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
            assert seen <= max_repeat
            assert seen_in_other_rank <= rank  # padding may repeat indices
            assert not_above_consumed <= max_repeat

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize(
        'num_consumed_samples',
        [0] + np.random.RandomState(0).randint(1, 100 // 2, size=3).tolist(),
    )
    def test_state_dict_2(self, pad, num_consumed_samples):
        """Consume on two cards, then restore on a single card and on a
        different number of cards (3)."""
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        already_numbers = set()

        sampler = SequentialSampler(dataset=dataset)
        sampler.set_distributed(num_replicas=2, rank=0)
        sampler.set_epoch(0)
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    assert j > max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break

        sampler = SequentialSampler(dataset=dataset)
        sampler.set_epoch(0)
        sampler.set_distributed(num_replicas=2, rank=1)
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples * 2

        # The state is taken from the rank-1 sampler created last.
        states = sampler.state_dict()
        # The consumed set no longer changes, so its maximum is computed once.
        consumed_max = max(already_numbers) if already_numbers else None

        new_sampler = SequentialSampler(dataset=dataset)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if consumed_max is not None:
                assert i > consumed_max
            assert i not in already_numbers

        # Switching to multiple cards after loading must work as well.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SequentialSampler(dataset=dataset)
            new_sampler.load_state_dict(states)
            new_sampler.set_epoch(0)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            seen = 0
            seen_in_other_rank = 0
            below_consumed = 0
            for i in new_sampler:
                if consumed_max is not None:
                    below_consumed += int(i < consumed_max)
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
            assert seen <= (1 if pad else 0)
            assert seen_in_other_rank <= 1  # padding may repeat one index
            assert below_consumed <= (rank if pad else 0)
@@ -2,7 +2,7 @@ from itertools import chain | |||||
import pytest | import pytest | ||||
from fastNLP.core.samplers import UnrepeatedSampler, UnrepeatedSortedSampler | |||||
from fastNLP.core.samplers import UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | |||||
class DatasetWithVaryLength: | class DatasetWithVaryLength: | ||||
@@ -21,7 +21,7 @@ class TestUnrepeatedSampler: | |||||
def test_single(self, shuffle): | def test_single(self, shuffle): | ||||
num_of_data = 100 | num_of_data = 100 | ||||
data = DatasetWithVaryLength(num_of_data) | data = DatasetWithVaryLength(num_of_data) | ||||
sampler = UnrepeatedSampler(data, shuffle) | |||||
sampler = UnrepeatedRandomSampler(data, shuffle) | |||||
indexes = set(sampler) | indexes = set(sampler) | ||||
assert indexes==set(range(num_of_data)) | assert indexes==set(range(num_of_data)) | ||||
@@ -32,17 +32,18 @@ class TestUnrepeatedSampler: | |||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
samplers = [] | samplers = [] | ||||
for i in range(num_replica): | for i in range(num_replica): | ||||
sampler = UnrepeatedSampler(dataset=data, shuffle=shuffle) | |||||
sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle) | |||||
sampler.set_distributed(num_replica, rank=i) | sampler.set_distributed(num_replica, rank=i) | ||||
samplers.append(sampler) | samplers.append(sampler) | ||||
indexes = set(chain(*samplers)) | |||||
indexes = list(chain(*samplers)) | |||||
assert len(indexes) == num_of_data | |||||
indexes = set(indexes) | |||||
assert indexes==set(range(num_of_data)) | assert indexes==set(range(num_of_data)) | ||||
class TestUnrepeatedSortedSampler: | class TestUnrepeatedSortedSampler: | ||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
def test_single(self, shuffle): | |||||
def test_single(self): | |||||
num_of_data = 100 | num_of_data = 100 | ||||
data = DatasetWithVaryLength(num_of_data) | data = DatasetWithVaryLength(num_of_data) | ||||
sampler = UnrepeatedSortedSampler(data, length=data.data) | sampler = UnrepeatedSortedSampler(data, length=data.data) | ||||
@@ -51,8 +52,7 @@ class TestUnrepeatedSortedSampler: | |||||
@pytest.mark.parametrize('num_replica', [2, 3]) | @pytest.mark.parametrize('num_replica', [2, 3]) | ||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
@pytest.mark.parametrize('shuffle', [False, True]) | |||||
def test_multi(self, num_replica, num_of_data, shuffle): | |||||
def test_multi(self, num_replica, num_of_data): | |||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
samplers = [] | samplers = [] | ||||
for i in range(num_replica): | for i in range(num_replica): | ||||
@@ -60,5 +60,45 @@ class TestUnrepeatedSortedSampler: | |||||
sampler.set_distributed(num_replica, rank=i) | sampler.set_distributed(num_replica, rank=i) | ||||
samplers.append(sampler) | samplers.append(sampler) | ||||
indexes = set(chain(*samplers)) | |||||
# 保证顺序是没乱的 | |||||
for sampler in samplers: | |||||
prev_index = float('inf') | |||||
for index in sampler: | |||||
assert index <= prev_index | |||||
prev_index = index | |||||
indexes = list(chain(*samplers)) | |||||
assert len(indexes) == num_of_data # 不同卡之间没有交叉 | |||||
indexes = set(indexes) | |||||
assert indexes==set(range(num_of_data)) | assert indexes==set(range(num_of_data)) | ||||
class TestUnrepeatedSequentialSampler:
    """Tests for ``UnrepeatedSequentialSampler`` with and without distribution."""

    def test_single(self):
        """A non-distributed sampler must yield 0..n-1 in order."""
        total = 100
        dataset = DatasetWithVaryLength(total)
        produced = list(UnrepeatedSequentialSampler(dataset, length=dataset.data))
        assert produced == list(range(total))

    @pytest.mark.parametrize('num_replica', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, num_replica, num_of_data):
        """Every rank stays in ascending order; together the ranks cover each
        index exactly once."""
        dataset = DatasetWithVaryLength(num_of_data=num_of_data)
        per_rank = []
        for rank in range(num_replica):
            rank_sampler = UnrepeatedSequentialSampler(dataset=dataset, length=dataset.data)
            rank_sampler.set_distributed(num_replica, rank=rank)
            per_rank.append(rank_sampler)

        # Make sure the order within every rank is not scrambled.
        for rank_sampler in per_rank:
            last = float('-inf')
            for current in rank_sampler:
                assert current >= last
                last = current

        combined = list(chain.from_iterable(per_rank))
        # Equal length plus equal set means no index is shared between ranks.
        assert len(combined) == num_of_data
        assert set(combined) == set(range(num_of_data))