diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index a47ab998..fc5d9d5b 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -10,7 +10,8 @@ __all__ = [ 'ProgressCallback', 'RichCallback', "LRSchedCallback", - 'LoadBestModelCallback' + 'LoadBestModelCallback', + "EarlyStopCallback" ] @@ -21,4 +22,5 @@ from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallb from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback from .lr_scheduler_callback import LRSchedCallback from .load_best_model_callback import LoadBestModelCallback +from .early_stop_callback import EarlyStopCallback diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index b2d99b51..4b553a1f 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -1,11 +1,15 @@ -from typing import Union, Callable, Dict, Optional +from typing import Union, Callable, Dict, Optional, Any +from abc import ABC __all__ = [ 'Callback', ] from .callback_events import Events, EventsList, Filter +from .utils import _get_monitor_value from fastNLP.core.callbacks.callback_events import _SingleEventState +from fastNLP.core.log import logger +from fastNLP.core.utils import apply_to_collection class Callback: @@ -150,4 +154,82 @@ class _CallbackWrapper(Callback): return self.fn.__name__ +class CanItemDataType(ABC): + """ + 检测可以进行传输的对象。 + + """ + + @classmethod + def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: + if cls is CanItemDataType: + item = getattr(subclass, 'item', None) + return callable(item) + return NotImplemented + + +class HasMonitorCallback(Callback): + def __init__(self, monitor, larger_better, must_have_monitor=False): + self.set_monitor(monitor, larger_better) + self.must_have_moinitor = must_have_monitor + + def set_monitor(self, monitor, larger_better): + self.monitor = str(monitor) if monitor is not None else 
None + self.larger_better = bool(larger_better) + if larger_better: + self.monitor_value = float('-inf') + else: + self.monitor_value = float('inf') + self._real_monitor = self.monitor + + def on_after_trainer_initialized(self, trainer, driver): + """ + 如果本身的 monitor 没有设置,则根据 Trainer 中的 monitor 设置 monitor 。 + 同时对于必须要有 monitor 设置的 callback ,该函数会进行检查。 + + :param trainer: + :param driver: + :return: + """ + if self.monitor is None and trainer.monitor is not None: + self.set_monitor(monitor=trainer.monitor, larger_better=trainer.larger_better) + if self.must_have_moinitor and self.monitor is None: + raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. " + f"You can set it in the initialization or through Trainer.") + + def get_monitor_value(self, results:Dict)->float: + """ + 获取 monitor 的值,如果 monitor 没有直接找到,会尝试使用匹配的方式寻找,并把匹配到的设置到 self._real_monitor 属性上。 + :param results: + :return: + """ + if len(results)==0: + return 0 + # 保证所有的 tensor 都被转换为了 python 特定的类型 + results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item()) + use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, + real_monitor=self._real_monitor, + res=results) + if self._real_monitor != use_monitor: # 发生了替换需要打印 + logger.warning( + f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " + f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") + self._real_monitor = use_monitor + return monitor_value + + def is_better_monitor_value(self, monitor_value: float, keep_if_better=True): + """ + 检测 monitor_value 是否是更好的 + + :param monitor_value: + :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 + :return: + """ + better = False + if (self.larger_better and monitor_value > self.monitor_value) or \ + (not self.larger_better and monitor_value < self.monitor_value): + better = True + if keep_if_better: + self.monitor_value = monitor_value + return better \ No newline at end of file 
diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index d3a3b52d..839a9522 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -5,12 +5,12 @@ __all__ = [ import os from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping from pathlib import Path -from abc import ABC import sys +from copy import deepcopy import fastNLP -from .callback import Callback, Filter +from .callback import Callback, HasMonitorCallback from fastNLP.core.callbacks.utils import _get_monitor_value from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_LAUNCH_TIME @@ -18,22 +18,7 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir from fastNLP.core.utils import apply_to_collection -class CanItemDataType(ABC): - """ - 检测可以进行传输的对象。 - - """ - - @classmethod - def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: - if cls is CanItemDataType: - item = getattr(subclass, 'item', None) - return callable(item) - return NotImplemented - - - -class CheckpointCallback(Callback): +class CheckpointCallback(HasMonitorCallback): def __init__( self, monitor, @@ -48,12 +33,8 @@ class CheckpointCallback(Callback): model_save_fn: Optional[Callable] = None, **kwargs, ): - if monitor is None and save_topk is not None: - raise ValueError("Parameter `monitor` must be set when you want to use 'save_topk'.") - - if monitor is not None and not isinstance(monitor, str): - raise ValueError("Parameter `monitor` should be of 'str' type.") - + super().__init__(monitor=monitor, larger_better=larger_better, + must_have_monitor=save_topk is not None) if save_folder is None: logger.warning( "Parameter `path` is None, and we will use the current work directory to find and load your model.") @@ -91,13 +72,12 @@ class CheckpointCallback(Callback): "`BaseException` type.") else: save_on_exception = [] - self.monitor = monitor + self.save_folder = 
Path(save_folder) self.save_every_n_epochs = save_every_n_epochs self.save_every_n_batches = save_every_n_batches self.save_last = save_last self.save_topk = save_topk - self.larger_better = larger_better self.only_state_dict = only_state_dict self.model_save_fn = model_save_fn self.save_on_exception = save_on_exception @@ -107,20 +87,22 @@ class CheckpointCallback(Callback): self._topk_model = {} self._topn = 0 # 表示目前已经保存了几个最好的模型; - # 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的 - # key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次 - # 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为; - # 因此我们通过该变量来表示我们通过模糊匹配拿到的 key; - self._real_monitor = self.monitor - # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) # 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; synchronize_mkdir(self.timestamp_path) - def on_validate_end(self, trainer, validate_res): - self._save_topk(trainer, validate_res) + def on_after_trainer_initialized(self, trainer, driver): + if self.save_topk is not None: + super().on_after_trainer_initialized(trainer, driver) + if self.save_topk is not None and trainer.evaluator is None: + logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.") + + def on_validate_end(self, trainer, results): + if len(results) == 0: + return + self._save_topk(trainer, results) def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: @@ -143,7 +125,7 @@ class CheckpointCallback(Callback): def on_sanity_check_end(self, trainer, sanity_check_res): # 主要核对一下 monitor 是否存在。 - self._get_validate_metric(sanity_check_res) + self.get_monitor_value(results=sanity_check_res) def on_save_checkpoint(self, trainer) -> Dict: """ @@ -154,8 +136,7 @@ class 
CheckpointCallback(Callback): states = {} states['timestamp_path'] = str(self.timestamp_path.absolute()) - states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType, - function=lambda x:x.item()) + states['_topk_model'] = deepcopy(self._topk_model) states['save_topk'] = 0 if self.save_topk is None else self.save_topk states['_real_monitor'] = self._real_monitor return states @@ -176,30 +157,30 @@ class CheckpointCallback(Callback): self._topk_model.update(self._topk_model) self._real_monitor = states["real_monitor"] - def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict): + def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): """ 根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 :param trainer: - :param validate_res: + :param results: :return: """ if self.save_topk is not None: - _metric_value = self._get_validate_metric(validate_res) + monitor_value = self.get_monitor_value(results=results) folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ - f"-{self._real_monitor}_{_metric_value}" + f"-{self._real_monitor}_{monitor_value}" _should_save = False if self._topn < self.save_topk: - self._topk_model[folder_name] = _metric_value + self._topk_model[folder_name] = monitor_value self._topn += 1 _should_save = True else: _least_valuable_model = (min if self.larger_better else max)(self._topk_model, key=lambda x: self._topk_model[x]) - if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \ - (self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]): - self._topk_model[folder_name] = _metric_value + if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \ + (self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]): + self._topk_model[folder_name] = monitor_value _should_save = True self._topk_model.pop(_least_valuable_model) 
synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) @@ -235,7 +216,11 @@ class CheckpointCallback(Callback): :return: """ use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) + if self._real_monitor != use_monitor: + logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), " + f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") self._real_monitor = use_monitor + return value @property @@ -263,7 +248,7 @@ class ModelCheckpointCallback(CheckpointCallback): 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。 + 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param save_every_n_epochs: 多少个 epoch 保存一次。 @@ -310,7 +295,7 @@ class TrainerCheckpointCallback(CheckpointCallback): 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。 + 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。 :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param save_every_n_epochs: 多少个 epoch 保存一次。 diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py new file mode 100644 index 00000000..602236f7 --- /dev/null +++ b/fastNLP/core/callbacks/early_stop_callback.py @@ -0,0 +1,61 @@ +__all__ = [ + 'EarlyStopCallback' +] + +from typing import Dict + +from .callback import HasMonitorCallback +from fastNLP.core.utils.exceptions import EarlyStopException + + +class EarlyStopCallback(HasMonitorCallback): + def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10): + """ + + :param str 
monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 + :param larger_better: monitor 的值是否是越大越好。 + :param patience: 多少次 validate 不没有提升就停止。 + """ + super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) + self.wait = 0 + self.patience = patience + + def on_validate_end(self, trainer, results): + if len(results)==0: + return + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): + self.wait = 0 + else: + self.wait += 1 + + def on_fetch_data_begin(self, trainer): + # 当是 step validate 的时候,下一步执行的就是这个, 所以在这里检查。 + if self.wait >= self.patience: + raise EarlyStopException(f"After {self.wait} validations, no improvement for " + f"metric `{self._real_monitor}`") + + def on_train_epoch_begin(self, trainer): + # 当是 epoch validate 的时候,下一步执行的就是这个, 所以在这里检查。 + if self.wait >= self.patience: + raise EarlyStopException(f"After {self.wait} validations, no improvement for " + f"metric `{self._real_monitor}`(best value: {self.monitor_value})") + + def on_save_checkpoint(self, trainer) -> Dict: + states = { + 'patience': self.patience, + 'wait': self.wait, + 'monitor': self.monitor, + 'monitor_value': self.monitor_value + } + return states + + def on_load_checkpoint(self, trainer, states): + self.patience = states['patience'] + self.wait = states['wait'] + self.monitor = states['monitor'] + self.monitor_value = float(states['monitor_value']) + + def callback_name(self): + return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}' + diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index e7b94f8c..9a4bb65f 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -4,8 +4,7 @@ __all__ = [ import os from typing import Optional, Callable -from .callback import Callback -from .utils import _get_monitor_value +from 
.callback import HasMonitorCallback from io import BytesIO import shutil @@ -14,15 +13,15 @@ from fastNLP.core.log import logger from fastNLP.envs import all_rank_call -class LoadBestModelCallback(Callback): - def __init__(self, monitor:str, larger_better:bool = True, only_state_dict:bool = True, +class LoadBestModelCallback(HasMonitorCallback): + def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True, save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None, model_load_fn:Optional[Callable] = None, delete_after_train:bool = True): """ 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 - :param str monitor: 监控的 metric 值。 + :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。 :param larger_better: 该 metric 值是否是越大越好。 :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 @@ -33,6 +32,7 @@ class LoadBestModelCallback(Callback): 请在函数内完成对模型的加载。 :param delete_after_train: 在训练结束后是否删掉模型。 """ + super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) if model_load_fn is not None: assert callable(model_load_fn), "`model_load_fn` must be a callable object." assert model_save_fn is not None, "`model_load_fn` and `model_save_fn` must be passed at the same time." 
@@ -56,15 +56,11 @@ class LoadBestModelCallback(Callback): self.real_save_folder = None self.buffer = BytesIO() - self.monitor = monitor - self.larger_better = larger_better self.save_folder = save_folder self.only_state_dict = only_state_dict self.model_save_fn = model_save_fn self.model_load_fn = model_load_fn self.delete_after_after = delete_after_train - self._real_monitor = None - self.monitor_value = float('-inf') if larger_better else float('inf') def on_after_trainer_initialized(self, trainer, driver): if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: @@ -76,13 +72,16 @@ class LoadBestModelCallback(Callback): raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " f"save best model when launch using script.") + super().on_after_trainer_initialized(trainer, driver) + + def on_sanity_check_end(self, trainer, sanity_check_res): + self.get_monitor_value(sanity_check_res) + def on_validate_end(self, trainer, results): - self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, - real_monitor=self._real_monitor, - res=results) - if (monitor_value < self.monitor_value and self.larger_better is False) or \ - (monitor_value > self.monitor_value and self.larger_better): - self.monitor_value = monitor_value + if len(results)==0: + return + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): if self.real_save_folder: trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, model_save_fn=self.model_save_fn) diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 633fbb09..756d236b 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -8,7 +8,7 @@ __all__ = [ 'RichCallback' ] -from .callback import Callback +from .callback import 
HasMonitorCallback from fastNLP.core.callbacks.utils import _get_monitor_value from fastNLP.core.utils import f_rich_progress from fastNLP.core.log import logger @@ -28,15 +28,13 @@ def choose_progress_callback(progress_bar:str): return None -class ProgressCallback(Callback): +class ProgressCallback(HasMonitorCallback): def on_train_end(self, trainer): f_rich_progress.stop() def on_sanity_check_end(self, trainer, sanity_check_res): if len(sanity_check_res) and getattr(self, 'monitor', None) is not None: - self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, - real_monitor=self._real_monitor, - res=sanity_check_res) + self.get_monitor_value(sanity_check_res) class RichCallback(ProgressCallback): @@ -46,28 +44,22 @@ class RichCallback(ProgressCallback): :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 - :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 + :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 :param larger_better: 是否是monitor的结果越大越好。 :param format_json: 是否format json再打印 """ - super().__init__() + super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) self.print_every = print_every self.progress_bar = f_rich_progress self.task2id = {} self.loss = 0 self.loss_round_ndigit = loss_round_ndigit - self.monitor = monitor - self.larger_better = larger_better - if larger_better: - self.monitor_value = float('-inf') - else: - self.monitor_value = float('inf') - self._real_monitor = monitor self.format_json = format_json def on_after_trainer_initialized(self, trainer, driver): if not self.progress_bar.disable: self.progress_bar.set_disable(flag=trainer.driver.get_local_rank() != 0) + super(RichCallback, self).on_after_trainer_initialized(trainer, driver) def on_train_begin(self, trainer): self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, @@ -109,16 +101,12 @@ class RichCallback(ProgressCallback): 
text_style = '' characters = '-' if self.monitor is not None: - self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, - real_monitor=self._real_monitor, - res=results) - if (self.larger_better and monitor_value > self.monitor_value) or \ - (not self.larger_better and monitor_value < self.monitor_value): + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): if abs(self.monitor_value) != float('inf'): rule_style = 'spring_green3' text_style = '[bold]' characters = '+' - self.monitor_value = monitor_value self.progress_bar.print() self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " f"Batch:{trainer.batch_idx_in_epoch}", @@ -151,18 +139,12 @@ class RawTextCallback(ProgressCallback): :param larger_better: 是否是monitor的结果越大越好。 :param format_json: 是否format json再打印 """ - super().__init__() + super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) self.print_every = print_every self.task2id = {} self.loss = 0 self.loss_round_ndigit = loss_round_ndigit - self.monitor = monitor - self.larger_better = larger_better - if larger_better: - self.monitor_value = float('-inf') - else: - self.monitor_value = float('inf') - self._real_monitor = monitor + self.set_monitor(monitor, larger_better) self.format_json = format_json self.num_signs = 10 @@ -189,14 +171,10 @@ class RawTextCallback(ProgressCallback): base_text = f'Eval. 
results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' text = '' if self.monitor is not None: - self._real_monitor, monitor_value = _get_monitor_value(monitor=self.monitor, - real_monitor=self._real_monitor, - res=results) - if (self.larger_better and monitor_value > self.monitor_value) or \ - (not self.larger_better and monitor_value < self.monitor_value): + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): if abs(self.monitor_value) != float('inf'): text = '+'*self.num_signs + base_text + '+'*self.num_signs - self.monitor_value = monitor_value if len(text) == 0: text = '-'*self.num_signs + base_text + '-'*self.num_signs diff --git a/fastNLP/core/callbacks/utils.py b/fastNLP/core/callbacks/utils.py index 900aebf6..2720ba3f 100644 --- a/fastNLP/core/callbacks/utils.py +++ b/fastNLP/core/callbacks/utils.py @@ -19,23 +19,31 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->( if monitor in res: return monitor, res[monitor] + if real_monitor in res: + return real_monitor, res[real_monitor] + pairs = [] for idx, (key, value) in enumerate(res.items()): - match = SequenceMatcher(None, key, monitor).find_longest_match(0, len(key), 0, len(monitor)) - pairs.append((key, value, match.size, idx)) + match_size = _match_length(monitor, key) + pairs.append((key, value, match_size, idx)) pairs.sort(key=lambda pair: (pair[2], -pair[3]), reverse=True) key, value, match_size = pairs[0][:3] - if real_monitor is not None and real_monitor in res and real_monitor != key: - # 如果 real_monitor 比新找的更长就继续用之前的。 - match = SequenceMatcher(None, real_monitor, monitor).find_longest_match(0, len(real_monitor), 0, len(monitor)) - if match.size > match_size: - return real_monitor, res[real_monitor] + return key, value + - logger.warning(f"We can not find `{monitor}` in the evaluation result (with keys as {list(res.keys())}), " - f"we use the `{key}` as the monitor.") - 
real_monitor = key - return real_monitor, value +def _match_length(a:str, b:str)->int: + """ + 需要把长度短的放在前面 + + :param a: + :param b: + :return: + """ + short = a if len(a) < len(b) else b + long = a if len(a)>=len(b) else b + match = SequenceMatcher(None, short, long).find_longest_match(0, len(short), 0, len(long)) + return match.size diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index bd66d0a0..865acc89 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -219,6 +219,7 @@ class Evaluator: def remove_progress_bar(self, dataloader_name): if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): f_rich_progress.destroy_task(self._rich_task_id) + f_rich_progress.refresh() # 使得最终的bar可以消失 delattr(self, '_rich_task_id') elif self.progress_bar == 'raw': desc = 'Evaluation ends' @@ -229,6 +230,7 @@ class Evaluator: def finally_progress_bar(self): if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): f_rich_progress.destroy_task(self._rich_task_id) + f_rich_progress.refresh() delattr(self, '_rich_task_id') @property diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index b7456b61..b360c6a0 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -23,9 +23,9 @@ from fastNLP.core.drivers import Driver from fastNLP.core.drivers.utils import choose_driver from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext from fastNLP.envs import rank_zero_call -from fastNLP.core.samplers import ReproducibleIterator, ReproducibleBatchSampler from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_MODEL_FILENAME +from fastNLP.core.utils.exceptions import EarlyStopException class Trainer(TrainerEventTrigger): @@ -50,6 +50,8 @@ class Trainer(TrainerEventTrigger): output_mapping: Optional[Union[Callable, Dict]] = None, accumulation_steps: 
int = 1, fp16: bool = False, + monitor: str = None, + larger_better: bool = True, marker: Optional[str] = None, **kwargs ): @@ -103,6 +105,10 @@ class Trainer(TrainerEventTrigger): 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; :param fp16: 是否开启混合精度训练;默认为 False; + :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 + 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。 + :param larger_better: monitor 的值是否是越大越好。 :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; :param kwargs: 一些其它的可能需要的参数; torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; @@ -211,6 +217,8 @@ class Trainer(TrainerEventTrigger): self.evaluator = None self.epoch_validate = lambda *args, **kwargs: ... self.step_validate = lambda *args, **kwargs: ... + self.monitor = monitor + self.larger_better = larger_better if metrics is not None and validate_dataloaders is not None: if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") @@ -240,6 +248,7 @@ class Trainer(TrainerEventTrigger): else: # validate_every > 0 self._step_validate_filter = Filter(every=validate_every) + self.metrics = metrics self.validate_every = validate_every @@ -321,6 +330,10 @@ class Trainer(TrainerEventTrigger): self.driver.barrier() self.on_train_end() self.driver.barrier() + + except EarlyStopException as e: + logger.info(f"Catch early stop exception: {e.msg}.") + self.on_exception(e) except KeyboardInterrupt as e: self.driver.on_exception() self.on_exception(e) @@ -610,7 +623,7 @@ class Trainer(TrainerEventTrigger): r""" 用于断点重训的加载函数; 注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 - 保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 
ReproducibleIterator; + 保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler; 注意我们目前不支持单卡到多卡的断点重训; diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py index d9d66970..019e6fad 100644 --- a/fastNLP/core/drivers/driver.py +++ b/fastNLP/core/drivers/driver.py @@ -49,13 +49,13 @@ class Driver(ABC): 不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; - 注意当 dist 为 ReproducibleIterator, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; + 注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; :param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 可以可以加载。 :return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, - 如果传入的 dataloader 中是 ReproducibleIterator 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 + 如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 """ if dist is None and reproducible is False: diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py index 596148bc..c467b868 100644 --- a/fastNLP/core/drivers/jittor_driver/mpi.py +++ b/fastNLP/core/drivers/jittor_driver/mpi.py @@ -3,7 +3,7 @@ from typing import Optional, Union from .jittor_driver import JittorDriver from fastNLP.envs.imports import _NEED_IMPORT_JITTOR -from fastNLP.core.samplers import ReproducibleIterator +from fastNLP.core.samplers import ReproducibleSampler if _NEED_IMPORT_JITTOR: import jittor @@ -70,7 +70,7 @@ class JittorMPIDriver(JittorDriver): def test_step(self, batch): return self._test_step(batch) - def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], 
+ def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], reproducible: bool = False, sampler_or_batch_sampler=None): pass diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py index f39053d3..4c99a2f5 100644 --- a/fastNLP/core/drivers/jittor_driver/single_device.py +++ b/fastNLP/core/drivers/jittor_driver/single_device.py @@ -3,7 +3,7 @@ from typing import Dict, Union from .jittor_driver import JittorDriver from fastNLP.core.utils import auto_param_call from fastNLP.envs.imports import _NEED_IMPORT_JITTOR -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler if _NEED_IMPORT_JITTOR: import jittor @@ -99,25 +99,25 @@ class JittorSingleDriver(JittorDriver): def is_distributed(self): return False - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], + def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], reproducible: bool = False, sampler_or_batch_sampler=None): # reproducible 的相关功能暂时没有实现 if isinstance(dist, ReproducibleBatchSampler): raise NotImplementedError dataloader.batch_sampler = dist_sample - if isinstance(dist, ReproducibleIterator): + if isinstance(dist, ReproducibleSampler): raise NotImplementedError dataloader.batch_sampler.sampler = dist if reproducible: raise NotImplementedError - if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): + if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): return dataloader - elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): + elif isinstance(dataloader.batch_sampler, RandomBatchSampler): return dataloader else: # TODO - batch_sampler = ReproducibleBatchSampler( + batch_sampler = RandomBatchSampler( batch_sampler=dataloader.batch_sampler, 
batch_size=dataloader.batch_sampler.batch_size, drop_last=dataloader.drop_last diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index d2d548f5..65af48a1 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -19,7 +19,7 @@ from fastNLP.core.utils import ( paddle_move_data_to_device, is_in_paddle_dist, ) -from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler +from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedRandomSampler from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES from fastNLP.core.log import logger @@ -312,13 +312,13 @@ class PaddleFleetDriver(PaddleDriver): def test_step(self, batch): return self._test_step(batch) - def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], + def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], reproducible: bool = False, sampler_or_batch_sampler=None): # 暂时不支持iterableDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ "FastNLP does not support `IteratorDataset` now." 
- if isinstance(dist, ReproducibleIterator): + if isinstance(dist, ReproducibleSampler): dataloader.batch_sampler.sampler = dist return dataloader @@ -340,7 +340,7 @@ class PaddleFleetDriver(PaddleDriver): # trainer elif dist == "dist": # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; - if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): + if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): dataloader.batch_sampler.sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank, @@ -362,7 +362,7 @@ class PaddleFleetDriver(PaddleDriver): return dataloader # evaluator elif dist == "unrepeatdist": - sampler = UnrepeatedSampler( + sampler = UnrepeatedRandomSampler( dataset=dataloader.dataset, shuffle=shuffle, seed=int(os.environ.get("FASTNLP_SEED", 0)) diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 97f14bb6..c57ba14d 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -10,7 +10,7 @@ from fastNLP.core.utils import ( get_paddle_device_id, paddle_move_data_to_device, ) -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: @@ -139,7 +139,7 @@ class PaddleSingleDriver(PaddleDriver): """ return paddle_move_data_to_device(batch, "gpu:0") - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], + def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], reproducible: bool = False, sampler_or_batch_sampler=None): # 暂时不支持IteratorDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ @@ -147,12 +147,12 @@ class PaddleSingleDriver(PaddleDriver): if isinstance(dist, 
ReproducibleBatchSampler): dataloader.batch_sampler = dist return dataloader - if isinstance(dist, ReproducibleIterator): + if isinstance(dist, ReproducibleSampler): dataloader.batch_sampler.sampler = dist return dataloader if reproducible: - if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): + if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): return dataloader elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): return dataloader diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 9e5e16fd..44cabcf4 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -28,11 +28,11 @@ from fastNLP.core.drivers.torch_driver.utils import ( ) from fastNLP.core.drivers.utils import distributed_open_proc from fastNLP.core.utils import auto_param_call, check_user_specific_params -from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler, ReproducibleBatchSampler +from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \ + re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED from fastNLP.core.log import logger from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object -from fastNLP.core.samplers import re_instantiate_sampler class TorchDDPDriver(TorchDriver): @@ -446,13 +446,23 @@ class TorchDDPDriver(TorchDriver): # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) return self._test_step(batch) - def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None, + def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, 
ReproducibleBatchSampler]]=None, reproducible: bool = False): - # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; + # 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; if isinstance(dist, ReproducibleBatchSampler): + dist.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) return replace_batch_sampler(dataloader, dist) - if isinstance(dist, ReproducibleIterator): + if isinstance(dist, ReproducibleSampler): + dist.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) return replace_sampler(dataloader, dist) # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; @@ -465,7 +475,7 @@ class TorchDDPDriver(TorchDriver): if isinstance(dist, ReproducibleBatchSampler): dist = re_instantiate_sampler(dist) return replace_batch_sampler(dataloader, dist) - if isinstance(dist, ReproducibleIterator): + if isinstance(dist, ReproducibleSampler): dist = re_instantiate_sampler(dist) return replace_sampler(dataloader, dist) return dataloader @@ -481,7 +491,7 @@ class TorchDDPDriver(TorchDriver): pad=True ) return replace_batch_sampler(dataloader, batch_sampler) - elif isinstance(args.sampler, ReproducibleIterator): + elif isinstance(args.sampler, ReproducibleSampler): sampler = re_instantiate_sampler(args.sampler) sampler.set_distributed( num_replicas=self.world_size, @@ -503,14 +513,15 @@ class TorchDDPDriver(TorchDriver): return replace_sampler(dataloader, sampler) # evaluator elif dist == "unrepeatdist": - # todo @yh,补充 unrepeatdist 相关内容; args = self.get_dataloader_args(dataloader) - - # todo 判断 batch_sampler; - sampler = UnrepeatedSampler( - dataset=args.dataset, - shuffle=args.shuffle, - ) + if isinstance(args.sampler, ReproducibleSampler): + sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) + elif not 
isinstance(args.sampler, UnrepeatedSampler): + sampler = UnrepeatedSequentialSampler( + dataset=args.dataset + ) + else: + sampler = re_instantiate_sampler(args.sampler) sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank @@ -588,7 +599,7 @@ class TorchDDPDriver(TorchDriver): :param group: :return: """ - return fastnlp_torch_all_gather(obj, device=self.data_device, group=group) + return fastnlp_torch_all_gather(obj, group=group) def find_free_network_port() -> str: diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py b/fastNLP/core/drivers/torch_driver/dist_utils.py index 37717f54..5e3819e7 100644 --- a/fastNLP/core/drivers/torch_driver/dist_utils.py +++ b/fastNLP/core/drivers/torch_driver/dist_utils.py @@ -1,11 +1,8 @@ import io import pickle -from typing import Mapping _pickler = pickle.Pickler _unpickler = pickle.Unpickler -from abc import ABC -from typing import Any, Union, List -import numpy as np +from typing import Any, List from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 @@ -13,103 +10,25 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch from torch import distributed as dist + try: + from torch._C._distributed_c10d import ProcessGroupMPI + except ImportError: + _MPI_AVAILABLE = False + + try: + from torch._C._distributed_c10d import ProcessGroupNCCL + except ImportError: + _NCCL_AVAILABLE = False + + try: + from torch._C._distributed_c10d import ProcessGroupGloo + from torch._C._distributed_c10d import _ProcessGroupWrapper + except ImportError: + _GLOO_AVAILABLE = False from fastNLP.core.utils import apply_to_collection - -def all_gather_object(object_list, obj, group=None): - """ - Gathers picklable objects from the whole group into a list. Similar to - :func:`all_gather`, but Python objects can be passed in. Note that the object - must be picklable in order to be gathered. - - Args: - object_list (list[Any]): Output list. 
It should be correctly sized as the - size of the group for this collective and will contain the output. - object (Any): Pickable Python object to be broadcast from current process. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Default is ``None``. - - Returns: - None. If the calling rank is part of this group, the output of the - collective will be populated into the input ``object_list``. If the - calling rank is not part of the group, the passed in ``object_list`` will - be unmodified. - - .. note:: Note that this API differs slightly from the :func:`all_gather` - collective since it does not provide an ``async_op`` handle and thus - will be a blocking call. - - .. note:: For NCCL-based processed groups, internal tensor representations - of objects must be moved to the GPU device before communication takes - place. In this case, the device used is given by - ``torch.cuda.current_device()`` and it is the user's responsiblity to - ensure that this is set so that each rank has an individual GPU, via - ``torch.cuda.set_device()``. - - .. warning:: - :func:`all_gather_object` uses ``pickle`` module implicitly, which is - known to be insecure. It is possible to construct malicious pickle data - which will execute arbitrary code during unpickling. Only call this - function with data you trust. - - Example:: - >>> # Note: Process group initialization omitted on each rank. - >>> import torch.distributed as dist - >>> # Assumes world_size of 3. 
- >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object - >>> output = [None for _ in gather_objects] - >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) - >>> output - ['foo', 12, {1: 2}] - """ - if dist.distributed_c10d._rank_not_in_group(group): - return - - input_tensor, local_size = _object_to_tensor(obj) - current_device = torch.device("cpu") - if dist.is_nccl_available() and isinstance( - group or dist.distributed_c10d._get_default_group(), dist.ProcessGroupNCCL - ): - # See note about using torch.cuda.current_device() here in docstring. - # We cannot simply use my_rank since rank == device is not necessarily - # true. - current_device = torch.device("cuda", torch.cuda.current_device()) - input_tensor = input_tensor.to(current_device) - local_size = local_size.to(current_device) - # Gather all local sizes. This is so that we can find the max size, and index - # until the correct size when deserializing the tensors. - group_size = dist.get_world_size(group=group) - object_sizes_tensor = torch.zeros( - group_size, dtype=torch.long, device=current_device - ) - object_size_list = [ - object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) - ] - # Allgather tensor sizes - dist.all_gather(object_size_list, local_size, group=group) - max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] - # Resize tensor to max size across all ranks. - input_tensor.resize_(max_object_size) - coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8, device=current_device - ) - # Output tensors are nonoverlapping views of coalesced_output_tensor - output_tensors = [ - coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] - for i in range(group_size) - ] - dist.all_gather(output_tensors, input_tensor, group=group) - # Deserialize outputs back to object. 
- for i, tensor in enumerate(output_tensors): - tensor = tensor.type(torch.uint8) - if tensor.device != torch.device("cpu"): - tensor = tensor.cpu() - tensor_size = object_size_list[i] - object_list[i] = _tensor_to_object(tensor, tensor_size) - - def _validate_output_list_for_rank(my_rank, dst, gather_list): if dst == my_rank: if not gather_list: @@ -123,8 +42,10 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): ) -def gather_object(obj, object_gather_list=None, dst=0, group=None): +def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None): """ + 从其它 rank gather 东西到 dst rank 。 + Gathers picklable objects from the whole group in a single process. Similar to :func:`gather`, but Python objects can be passed in. Note that the object must be picklable in order to be gathered. @@ -176,6 +97,8 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None): # Ensure object_gather_list is specified appopriately. my_rank = dist.get_rank() _validate_output_list_for_rank(my_rank, dst, object_gather_list) + # 防止 unpickle 的时候出现在了发送的 gpu 上。 + obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) input_tensor, local_size = _object_to_tensor(obj) group_backend = dist.get_backend(group) current_device = torch.device("cpu") @@ -266,113 +189,11 @@ def send_recv_object(obj, src, cur_rank, device, group=None, tag=0): return _tensor_to_object(tensor.cpu(), size) -def _all_gather(obj, **kwargs): - group = kwargs.get('group', None) - if isinstance(obj, torch.Tensor): - gathered_tensor = [torch.zeros_like(obj) for _ in - range(torch.distributed.get_world_size(group=group))] - - torch.distributed.all_gather(gathered_tensor, obj, group=group) - - return gathered_tensor - - elif isinstance(obj, tuple) and isinstance(obj[1], torch.Tensor): - tensor, size = obj - # 首先需要同步 size 吧? 
- group_size = dist.get_world_size(group=group) - object_sizes_tensor = torch.zeros( - group_size, dtype=torch.long, device=tensor.device - ) - object_size_list = [ - object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) - ] - dist.all_gather(object_size_list, size, group=group) - max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] - # Resize tensor to max size across all ranks. - tensor.resize_(max_object_size) - coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8, device=tensor.device - ) - - # Output tensors are nonoverlapping views of coalesced_output_tensor - output_tensors = [ - coalesced_output_tensor[max_object_size * i: max_object_size * (i + 1)] - for i in range(group_size) - ] - dist.all_gather(output_tensors, tensor, group=group) - object_list = [] - for i, tensor in enumerate(output_tensors): - tensor = tensor.type(torch.uint8) - tensor_size = object_size_list[i] - object_list.append(_tensor_to_object(tensor, tensor_size)) - return object_list - elif isinstance(obj, tuple) and len(obj) == 2: - obj, _type = obj - gathered_tensor = [torch.zeros_like(obj) for _ in - range(torch.distributed.get_world_size(group=group))] - - torch.distributed.all_gather(gathered_tensor, obj, group=group) - - if _type == np.ndarray: - gathered_tensor = [t.detach().cpu().numpy() for t in gathered_tensor] - else: - gathered_tensor = [_type(t.item()) for t in gathered_tensor] - - return gathered_tensor - else: - raise RuntimeError("Unsupported types to implement all_gather.") - - -class CanTransferDataType(ABC): - """ - 检测可以进行传输的对象。 - - """ - - @classmethod - def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: - if cls is CanTransferDataType: - if issubclass(subclass, Mapping): - return False - if subclass in (torch.Tensor, tuple, list, str, int, float, bool, np.ndarray): - return True - return False - return NotImplemented - - -def _tensorize(obj, device=None): - if isinstance(obj, 
torch.Tensor): - return obj - if isinstance(obj, bool): - return torch.tensor(obj, dtype=torch.uint8, device=device), bool - if isinstance(obj, float): - return torch.tensor(obj, dtype=torch.float, device=device), float - if isinstance(obj, int): - return torch.tensor(obj, dtype=torch.int, device=device), int - if isinstance(obj, np.ndarray): - return torch.from_numpy(obj), np.ndarray - return _object_to_tensor(obj, device) - - def _to_device(tensor, device): return tensor.contiguous().to(device) -def convert_to_tensors(data: Any, device=None) -> Any: - data = apply_to_collection(data, CanTransferDataType, _tensorize) - def _move_to_device_and_make_contiguous(t: Union[torch.Tensor, tuple], device: Union[str, torch.device]): - if isinstance(t, tuple): - if isinstance(t[1], torch.Tensor): # 说明是 object 转的 - return t[0].to(device).contiguous(), t[1].to(device) - else: # 说明第二个元素是type,见 to_dtype_tensor 函数 - return t[0].to(device).contiguous(), t[1] - return t.to(device).contiguous() - - data = apply_to_collection(data, (torch.Tensor, tuple), _move_to_device_and_make_contiguous, device=device) - return data - - -def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: +def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: """ 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 @@ -390,36 +211,28 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: {'a': 1, 'b':[1, 2], 'c':{'d': 2}} ] - :param obj: 任意结构的数据,所有的 value 都会变成 list ,其长度为 world_size ,依次为每个 rank 上的对象值 - :param device: 当前 rank 使用的 device 是哪个。为 None 的话默认使用 torch.cuda.current_device() 获取。 + :param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行 + 序列化之后进行传输。 + :param device: 当前该参数无意义。 :param group: :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 """ # # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 - # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) - if device is None: - 
device = torch.cuda.current_device() - if _TORCH_GREATER_EQUAL_1_8: + if isinstance(obj, torch.Tensor): + objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] + dist.all_gather(objs, obj, group=group) + else: objs = [None for _ in range(dist.get_world_size(group))] - dist.all_gather_object(objs, obj) - objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 - return objs - group = group if group is not None else torch.distributed.group.WORLD - data = convert_to_tensors(obj, device=device) - data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group) - - objs = [] - - def _get_obj_on_idx(obj, idx): - return obj[idx] - - for i in range(dist.get_world_size(group)): - objs.append(apply_to_collection(data, dtype=list, function=_get_obj_on_idx, idx=i)) - + # 防止 unpickle 的时候弄到发送的 gpu 上了 + obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) + if _TORCH_GREATER_EQUAL_1_8: + dist.all_gather_object(objs, obj, group=group) + else: + objs = all_gather_object(objs, obj, group=group) return objs -def fastnlp_torch_broadcast_object(obj, src, device, group=None): +def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): """ 将 src 上的 obj 对象广播到其它 rank 上。 @@ -430,10 +243,9 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): :return: """ cur_rank = dist.get_rank(group) - # if cur_rank == src: - # # 如果有 tensor 全部移动到 cpu 上,方便 pickle - # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) - + if cur_rank == src: + # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 + obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) if _TORCH_GREATER_EQUAL_1_8: if cur_rank!=src: get_obj = [None] @@ -442,6 +254,8 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): else: dist.broadcast_object_list([obj], src=src, group=group) return obj + if device 
is None: + device = torch.cuda.current_device() if cur_rank == src: tensor, size = _object_to_tensor(obj, device=device) @@ -460,3 +274,107 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): return _tensor_to_object(tensor, tensor_size=size.item()) +def _check_for_nccl_backend(group): + pg = group or dist.distributed_c10d._get_default_group() + # It is not expected for PG to be wrapped many times, but support it just + # in case + while isinstance(pg, _ProcessGroupWrapper): + pg = pg.wrapped_pg + + return ( + dist.is_nccl_available() and + isinstance(pg, dist.ProcessGroupNCCL) + ) + + +def all_gather_object(object_list, obj, group=None): + """ + 复制 pytorch 的代码,使得可以版本兼容低版本的 pytorch 。 + + Gathers picklable objects from the whole group into a list. Similar to + :func:`all_gather`, but Python objects can be passed in. Note that the object + must be picklable in order to be gathered. + + Args: + object_list (list[Any]): Output list. It should be correctly sized as the + size of the group for this collective and will contain the output. + object (Any): Pickable Python object to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + + Returns: + None. If the calling rank is part of this group, the output of the + collective will be populated into the input ``object_list``. If the + calling rank is not part of the group, the passed in ``object_list`` will + be unmodified. + + .. note:: Note that this API differs slightly from the :func:`all_gather` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. + + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. 
In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + :func:`all_gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) + >>> output + ['foo', 12, {1: 2}] + """ + if dist._rank_not_in_group(group): + return + + input_tensor, local_size = _object_to_tensor(obj) + current_device = torch.device("cpu") + is_nccl_backend = _check_for_nccl_backend(group) + if is_nccl_backend: + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device("cuda", torch.cuda.current_device()) + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = dist.get_world_size(group=group) + object_sizes_tensor = torch.zeros( + group_size, dtype=torch.long, device=current_device + ) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes + dist.all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. 
+ input_tensor.resize_(max_object_size) + coalesced_output_tensor = torch.empty( + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + dist.all_gather(output_tensors, input_tensor, group=group) + # Deserialize outputs back to object. + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + if tensor.device != torch.device("cpu"): + tensor = tensor.cpu() + tensor_size = object_size_list[i] + object_list[i] = _tensor_to_object(tensor, tensor_size) diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index 14a135ee..19e687b8 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -13,9 +13,8 @@ __all__ = [ from .torch_driver import TorchDriver from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler from fastNLP.core.utils import auto_param_call -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler from fastNLP.core.log import logger -from fastNLP.core.samplers import re_instantiate_sampler class TorchSingleDriver(TorchDriver): @@ -130,13 +129,13 @@ class TorchSingleDriver(TorchDriver): else: return self._test_step(batch) - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, + def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, reproducible: bool = False): # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; if isinstance(dist, ReproducibleBatchSampler): return 
replace_batch_sampler(dataloader, dist) - elif isinstance(dist, ReproducibleIterator): + elif isinstance(dist, ReproducibleSampler): return replace_sampler(dataloader, dist) # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; @@ -144,7 +143,7 @@ class TorchSingleDriver(TorchDriver): if isinstance(args.batch_sampler, ReproducibleBatchSampler): batch_sampler = re_instantiate_sampler(args.batch_sampler) return replace_batch_sampler(dataloader, batch_sampler) - elif isinstance(args.sampler, ReproducibleIterator): + elif isinstance(args.sampler, ReproducibleSampler): sampler = re_instantiate_sampler(args.sampler) return replace_sampler(dataloader, sampler) diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index ce1bff14..b200f1fd 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler class TorchDriver(Driver): @@ -182,8 +182,8 @@ class TorchDriver(Driver): # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; # 1. 
sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; - # 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的 - # sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; + # 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的 + # sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; dataloader_args = self.get_dataloader_args(dataloader) if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): sampler = dataloader_args.batch_sampler @@ -247,11 +247,10 @@ class TorchDriver(Driver): dataloader_args = self.get_dataloader_args(dataloader) if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): sampler = dataloader_args.batch_sampler - elif isinstance(dataloader_args.sampler, ReproducibleIterator): + elif isinstance(dataloader_args.sampler, ReproducibleSampler): sampler = dataloader_args.sampler elif self.is_distributed(): - raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our " - "`ReproducibleBatchSampler` or `ReproducibleIterator`.") + raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") else: sampler = ReproducibleBatchSampler( batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, @@ -291,7 +290,7 @@ class TorchDriver(Driver): @staticmethod def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover - """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed + """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with ``seed_everything(seed, workers=True)``. 
See also the PyTorch documentation on diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py index bb2ee661..c3cc2d39 100644 --- a/fastNLP/core/samplers/__init__.py +++ b/fastNLP/core/samplers/__init__.py @@ -9,18 +9,28 @@ __all__ = [ 'MixSequentialSampler', 'PollingSampler', - 'ReproducibleIterator', + 'ReproducibleSampler', 'RandomSampler', - - 're_instantiate_sampler', + "SequentialSampler", + "SortedSampler", 'UnrepeatedSampler', - "UnrepeatedSortedSampler" + 'UnrepeatedRandomSampler', + "UnrepeatedSortedSampler", + "UnrepeatedSequentialSampler", + + "RandomBatchSampler", + "BucketedBatchSampler", + "ReproducibleBatchSampler", + + "re_instantiate_sampler", + "conversion_between_reproducible_and_unrepeated_sampler" ] from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler -from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedSortedSampler +from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler -from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler -from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler +from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler +from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler +from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 3e39aca5..c4116e24 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -1,6 +1,6 @@ __all__ = [ 'BucketedBatchSampler', - "ReproducibleBatchSampler" + 
"RandomBatchSampler" ] import math @@ -16,7 +16,10 @@ from fastNLP.core.log import logger from abc import abstractmethod -class ReproducibleBatchIterator: +class ReproducibleBatchSampler: + def __init__(self, **kwargs): + pass + @abstractmethod def set_distributed(self, num_replicas, rank, pad=True): raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") @@ -41,19 +44,25 @@ class ReproducibleBatchIterator: def set_epoch(self, epoch): pass + @property + def batch_idx_in_epoch(self): + raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") -class ReproducibleBatchSampler(ReproducibleBatchIterator): + +class RandomBatchSampler(ReproducibleBatchSampler): # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): """ 可以使得 batch_sampler 对象状态恢复的 wrapper 。 - :param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproducibleBatchSampler 将首先遍历一边该对象,然后将迭代 + :param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。RandomBatchSampler 将首先遍历一边该对象,然后将迭代 出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 :param batch_size: 每个 batch 的大小是多少。 :param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 :param kwargs: fastNLP 内部使用。 """ + super().__init__() + self.batch_sampler = batch_sampler self.batch_size = batch_size self.drop_last = drop_last @@ -138,7 +147,7 @@ class ReproducibleBatchSampler(ReproducibleBatchIterator): (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size -class BucketedBatchSampler(ReproducibleBatchIterator): +class BucketedBatchSampler(ReproducibleBatchSampler): def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): """ diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index 6d2c8246..1dc226a5 100644 
--- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -1,24 +1,21 @@ -from typing import Dict, List +from typing import Dict, List, Union import math import numpy as np from fastNLP.core.log import logger +from fastNLP.core.dataset import DataSet __all__ = [ - 'ReproducibleIterator', + 'ReproducibleSampler', 'RandomSampler', - 're_instantiate_sampler' + "SortedSampler", + "SequentialSampler" ] -def re_instantiate_sampler(sampler): - all_attributes = vars(sampler) - return type(sampler)(**all_attributes) - - -class ReproducibleIterator: +class ReproducibleSampler: """ - 注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler + 注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 """ @@ -46,7 +43,7 @@ class ReproducibleIterator: pass -class RandomSampler(ReproducibleIterator): +class RandomSampler(ReproducibleSampler): def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): """ @@ -156,8 +153,8 @@ class RandomSampler(ReproducibleIterator): f"we cannot use {self.__class__.__name__} to load it." length = states['length'] - assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ - "and current dataset." + assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ + f"and current dataset({len(self.dataset)})." 
self.seed = states['seed'] self.epoch = states['epoch'] self.num_consumed_samples = states['num_consumed_samples'] @@ -214,9 +211,132 @@ class RandomSampler(ReproducibleIterator): self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) +class SequentialSampler(RandomSampler): + def __init__(self, dataset, dist_mode:str='interval', **kwargs): + """ + 按照顺序读取 dataset 。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 + + :param dataset: 实现了 __len__ 方法的数据容器。 + :param kwargs: + """ + super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) + + def __iter__(self): + if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 + self.num_consumed_samples = 0 + self.during_iter = True + indices = self.generate_indices() + + if self.pad: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. 
+ indices = indices[:self.total_size] + + assert len(indices) == self.total_size + + # subsample + indices = indices[self.num_consumed_samples:] + indices = indices[self.rank:len(indices):self.num_replicas] + assert len(indices) == self.num_left_samples + for index in indices: + self.num_consumed_samples += self.num_replicas + yield index + self.during_iter = False + self.num_consumed_samples = 0 + def generate_indices(self) -> List[int]: + """ + 生成随机序列 + :return: + """ + return list(range(len(self.dataset))) + def state_dict(self) -> Dict: + states = { + 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; + 'sampler_type': self.__class__.__name__, + 'length': len(self.dataset), + } + return states + def load_state_dict(self, states: Dict): + # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; + assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ + "during an unfinished iteration." + + assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ + f"we cannot use {self.__class__.__name__} to load it." + + length = states['length'] + assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ + f"and current dataset({len(self.dataset)})." 
+ self.num_consumed_samples = states['num_consumed_samples'] + if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 + self.num_consumed_samples = 0 + + +class SortedSampler(SequentialSampler): + def __init__(self, dataset, length:Union[str, List], **kwargs): + """ + 将 dataset 中的数据根据 length 从长到短进行迭代。在多卡情况下,由于padding 最后一个 sample 可能是最长的那个 sample。 + + :param dataset: 实现了 __len__ 方法的数据容器。 + :param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 + DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 + :param seed: 设置的随机数种子 + :param kwargs: fastNLP 保留使用 + """ + super().__init__(dataset=dataset, **kwargs) + if isinstance(dataset, DataSet): + length = dataset.get_field(length) + if not isinstance(length[0], int): + length = list(map(len, length)) + else: + assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \ + "the length parameter can only be List[int]" + + assert len(length) == len(dataset), "The length of `data` and `length` should be equal." + + self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 + self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 + + def generate_indices(self) -> List[int]: + return self.sorted_indices + + def __iter__(self): + if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 + self.num_consumed_samples = 0 + self.during_iter = True + indices = self.generate_indices() + + if self.pad: + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. 
+ indices = indices[:self.total_size] + + assert len(indices) == self.total_size + + # subsample + indices = indices[self.num_consumed_samples:] + indices = indices[self.rank:len(indices):self.num_replicas] + assert len(indices) == self.num_left_samples + + for index in indices: + self.num_consumed_samples += self.num_replicas + yield index + self.during_iter = False + self.num_consumed_samples = 0 diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py index 18ae16db..d7913d20 100644 --- a/fastNLP/core/samplers/unrepeated_sampler.py +++ b/fastNLP/core/samplers/unrepeated_sampler.py @@ -1,6 +1,8 @@ __all__ = [ + 'UnrepeatedSampler', 'UnrepeatedSortedSampler', - 'UnrepeatedSampler' + 'UnrepeatedRandomSampler', + "UnrepeatedSequentialSampler" ] from typing import List, Union @@ -10,13 +12,21 @@ import numpy as np class UnrepeatedSampler: + """ + 在多卡场景下保证 indice 不重复的 sampler + """ + pass + + +class UnrepeatedRandomSampler(UnrepeatedSampler): def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): """ 考虑在多卡evaluate的场景下,不能重复sample。 - :param dataset: - :param shuffle: - :param seed: + :param dataset: 实现了 __len__ 方法的数据容器。 + :param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 + :param seed: 设置的随机数种子 + :param kwargs: fastNLP 保留使用 """ self.dataset = dataset self.shuffle = shuffle @@ -33,8 +43,8 @@ class UnrepeatedSampler: :return: """ num_common = len(self.dataset)//self.num_replicas - self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) - return self.num_samples + num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) + return num_samples def __iter__(self): indices = self.generate_indices() @@ -83,8 +93,8 @@ class UnrepeatedSampler: return self -class UnrepeatedSortedSampler(UnrepeatedSampler): - def __init__(self, dataset, length:Union[str, List], seed: int = 0): +class UnrepeatedSortedSampler(UnrepeatedRandomSampler): 
+ def __init__(self, dataset, length:Union[str, List], **kwargs): """ 将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 batch 数量不完全一致。 @@ -92,11 +102,9 @@ class UnrepeatedSortedSampler(UnrepeatedSampler): :param dataset: 实现了 __len__ 方法的数据容器。 :param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 - :param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 - :param seed: 设置的随机数种子 :param kwargs: fastNLP 保留使用 """ - super().__init__(dataset=dataset, shuffle=False, seed=seed) + super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) if isinstance(dataset, DataSet): length = dataset.get_field(length) if not isinstance(length[0], int): @@ -107,8 +115,29 @@ class UnrepeatedSortedSampler(UnrepeatedSampler): assert len(length) == len(dataset), "The length of `data` and `length` should be equal." - self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 - self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 + length = np.array(length, dtype=int) # 按照长到短排列的序号。 + self.sorted_indices = np.argsort(length)[::-1].tolist() # 按长度从高到低排序的 def generate_indices(self) -> List[int]: return self.sorted_indices + + +class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): + def __init__(self, dataset, **kwargs): + """ + 按照顺序读取 dataset。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 + + :param dataset: 实现了 __len__ 方法的数据容器。 + :param kwargs: + """ + super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs) + + def __iter__(self): + indices = self.generate_indices() + indices = indices[self.rank:len(indices):self.num_replicas] + for index in indices: + yield index + + def generate_indices(self) -> List[int]: + return list(range(len(self.dataset))) + diff --git a/fastNLP/core/samplers/utils.py b/fastNLP/core/samplers/utils.py new file mode 100644 index 00000000..dd90fe7c --- /dev/null +++ 
b/fastNLP/core/samplers/utils.py @@ -0,0 +1,42 @@ +__all__ = [ + 're_instantiate_sampler', + 'conversion_between_reproducible_and_unrepeated_sampler' +] + +from fastNLP.core.samplers.unrepeated_sampler import * +from fastNLP.core.samplers.reproducible_sampler import * + + +def conversion_between_reproducible_and_unrepeated_sampler(sampler): + """ + 将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 + ReproducibleSampler, + + :param sampler: + :return: + """ + assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ + "The sampler must be UnrepeatedSampler or ReproducibleSampler" + if isinstance(sampler, UnrepeatedSampler): + if isinstance(sampler, UnrepeatedRandomSampler): + return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) + elif isinstance(sampler, UnrepeatedSequentialSampler): + return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) + elif isinstance(sampler, UnrepeatedSortedSampler): + return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) + raise TypeError(f"{sampler.__class__} has no unrepeated version.") + else: + if isinstance(sampler, RandomSampler): + return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) + elif isinstance(sampler, SequentialSampler): + return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) + elif isinstance(sampler, SortedSampler): + return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) + raise TypeError(f"{sampler.__class__} has no reproducible version.") + + +def re_instantiate_sampler(sampler, new_sampler_class=None): + all_attributes = vars(sampler) + if new_sampler_class is not None: + return new_sampler_class(**all_attributes) + return type(sampler)(**all_attributes) \ No newline at end of file diff --git a/fastNLP/core/utils/exceptions.py b/fastNLP/core/utils/exceptions.py new file mode 100644 index 00000000..afedbcba --- 
/dev/null +++ b/fastNLP/core/utils/exceptions.py @@ -0,0 +1,10 @@ + +class EarlyStopException(BaseException): + r""" + 用于EarlyStop时从Trainer训练循环中跳出。 + + """ + + def __init__(self, msg): + super(EarlyStopException, self).__init__(msg) + self.msg = msg diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py index 256cc906..a865f4c1 100644 --- a/fastNLP/core/utils/rich_progress.py +++ b/fastNLP/core/utils/rich_progress.py @@ -94,9 +94,6 @@ class FRichProgress(Progress, metaclass=Singleton): self.print = self.console.print self.log = self.console.log - # start new - self.start() - self.console.show_cursor(show=True) return self def set_transient(self, transient: bool = True): @@ -154,6 +151,7 @@ class FRichProgress(Progress, metaclass=Singleton): super().start() self.console.show_cursor(show=True) + if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: f_rich_progress = FRichProgress().new_progess( "[progress.description]{task.description}", diff --git a/tests/core/callbacks/test_utils.py b/tests/core/callbacks/test_utils.py index 10aba0e0..fdec93e0 100644 --- a/tests/core/callbacks/test_utils.py +++ b/tests/core/callbacks/test_utils.py @@ -12,32 +12,27 @@ def test_get_monitor_value(): with Capturing() as output: monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) assert monitor == 'f1' and value==0.2 - assert 'We can not find' not in output[0] # 测试可以匹配,且选择更靠前的 res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} with Capturing() as output: monitor, value = _get_monitor_value(monitor='f1', real_monitor=None, res=res) assert monitor=='acc#f1' and value==0.2 - assert 'We can not find' in output[0] # 测试monitor匹配不上,使用real_monitor res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} with Capturing() as output: - monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#rec', res=res) + monitor, value = _get_monitor_value(monitor='acc', real_monitor='acc#rec', res=res) assert monitor=='acc#rec' and 
value==0.3 - assert 'We can not find' not in output[0] # 测试monitor/real_monitor匹配不上, 重新选择 res = {'acc#f1': 0.2, 'acc#rec': 0.3, 'add#f':0.4} with Capturing() as output: monitor, value = _get_monitor_value(monitor='acc#f', real_monitor='acc#r', res=res) assert monitor=='acc#f1' and value==0.2 - assert 'We can not find' in output[0] # 测试partial的位置 res = {"acc#acc": 0.52, "loss#loss": 2} with Capturing() as output: monitor, value = _get_monitor_value(monitor='-loss', real_monitor=None, res=res) assert monitor=='loss#loss' and value==2 - assert 'We can not find' in output[0] diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 33662d7f..b2f5864b 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -10,7 +10,7 @@ from paddle.io import DataLoader, BatchSampler from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver from fastNLP.core.samplers.reproducible_sampler import RandomSampler -from fastNLP.core.samplers import ReproducibleBatchSampler +from fastNLP.core.samplers import RandomBatchSampler from tests.helpers.models.paddle_model import PaddleNormalModel_Classification from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset from fastNLP.core import synchronize_safe_rm @@ -153,7 +153,7 @@ class TestSingleDeviceFunction: @pytest.mark.parametrize( "dist_sampler", - ["dist", ReproducibleBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] + ["dist", RandomBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] ) @pytest.mark.parametrize( "reproducible", diff --git a/tests/core/drivers/torch_driver/test_dist_utils.py b/tests/core/drivers/torch_driver/test_dist_utils.py index 8fb7eb34..2d2145c8 100644 --- 
a/tests/core/drivers/torch_driver/test_dist_utils.py +++ b/tests/core/drivers/torch_driver/test_dist_utils.py @@ -7,38 +7,10 @@ import numpy as np # print(isinstance((1,), tuple)) # exit() -from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, convert_to_tensors, fastnlp_torch_broadcast_object +from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context - -def test_convert_to_tensors(): - local_rank = 0 - obj = { - 'tensor': torch.full(size=(2,), fill_value=local_rank), - 'numpy': np.full(shape=(1,), fill_value=local_rank), - 'bool': local_rank % 2 == 0, - 'float': local_rank + 0.1, - 'int': local_rank, - 'dict': { - 'rank': local_rank - }, - 'list': [local_rank] * 2, - 'str': 'xxx' - } - data = convert_to_tensors(obj) - assert len(data) == len(obj) - assert (data['tensor'] == obj['tensor']).sum() == 2 - for name in ['list', 'str']: - assert len(data[name])==2 and isinstance(data[name][0], torch.Tensor) and \ - isinstance(data[name][1], torch.Tensor) and data[name][1].ndim==1 - - for name in ['numpy', 'bool', 'float', 'int']: - assert isinstance(data[name][0], torch.Tensor) and data[name][0].numel()==1 - - assert isinstance(data['dict']['rank'][0], torch.Tensor) and data[name][0].numel() == 1 - - @magic_argv_env_context def test_fastnlp_torch_all_gather(): os.environ['MASTER_ADDR'] = '127.0.0.1' @@ -66,7 +38,7 @@ def test_fastnlp_torch_all_gather(): 'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(), torch.full(size=(2,), fill_value=local_rank).cuda()] } - data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) + data = fastnlp_torch_all_gather(obj) world_size = int(os.environ['WORLD_SIZE']) assert len(data) == world_size for i in range(world_size): @@ -81,10 +53,12 @@ def test_fastnlp_torch_all_gather(): assert data[i]['tensors'][0][0] == i for obj in [1, True, 'xxx']: 
- data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) + data = fastnlp_torch_all_gather(obj) assert len(data)==world_size assert data[0]==data[1] + dist.destroy_process_group() + @magic_argv_env_context def test_fastnlp_torch_broadcast_object(): os.environ['MASTER_ADDR'] = '127.0.0.1' @@ -130,3 +104,4 @@ def test_fastnlp_torch_broadcast_object(): for obj in [int(os.environ['LOCAL_RANK']), bool(os.environ['LOCAL_RANK']=='1'), os.environ['LOCAL_RANK']]: data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device()) assert int(data)==0 + dist.destroy_process_group() diff --git a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py index 81d693fc..161bbfe8 100644 --- a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py +++ b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py @@ -30,7 +30,7 @@ class SequenceDataSet: def check_replace_sampler(driver): - # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproducibleBatchSampler + # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler # reproducible 是 True 和 False # 需要 check 返回的 sampler 和 dataloader 都不同了 diff --git a/tests/core/samplers/test_reproducible_batch_sampler.py b/tests/core/samplers/test_reproducible_batch_sampler.py index edc7b86b..d51dd912 100644 --- a/tests/core/samplers/test_reproducible_batch_sampler.py +++ b/tests/core/samplers/test_reproducible_batch_sampler.py @@ -4,7 +4,7 @@ import numpy as np import pytest from itertools import chain -from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler +from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler from tests.helpers.datasets.torch_data import TorchNormalDataset @@ -18,7 +18,7 @@ class TestReproducibleBatchSampler: before_batch_size = 7 
dataset = TorchNormalDataset(num_of_data=100) dataloader = DataLoader(dataset, batch_size=before_batch_size) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) dataloader = replace_batch_sampler(dataloader, re_batchsampler) forward_steps = 3 @@ -28,15 +28,15 @@ class TestReproducibleBatchSampler: # 1. 保存状态 _get_re_batchsampler = dataloader.batch_sampler - assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) + assert isinstance(_get_re_batchsampler, RandomBatchSampler) state = _get_re_batchsampler.state_dict() assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, - "sampler_type": "ReproducibleBatchSampler"} + "sampler_type": "RandomBatchSampler"} # 2. 断点重训,重新生成一个 dataloader; # 不改变 batch_size; dataloader = DataLoader(dataset, batch_size=before_batch_size) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) re_batchsampler.load_state_dict(state) dataloader = replace_batch_sampler(dataloader, re_batchsampler) @@ -53,7 +53,7 @@ class TestReproducibleBatchSampler: # 改变 batch_size; after_batch_size = 3 dataloader = DataLoader(dataset, batch_size=after_batch_size) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) re_batchsampler.load_state_dict(state) dataloader = replace_batch_sampler(dataloader, re_batchsampler) @@ -99,7 +99,7 @@ class TestReproducibleBatchSampler: dataset = TorchNormalDataset(num_of_data=100) # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; dataloader = DataLoader(dataset, batch_size=before_batch_size, 
shuffle=True) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) dataloader = replace_batch_sampler(dataloader, re_batchsampler) # 将一轮的所有数据保存下来,看是否恢复的是正确的; @@ -111,13 +111,13 @@ class TestReproducibleBatchSampler: # 1. 保存状态 _get_re_batchsampler = dataloader.batch_sampler - assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) + assert isinstance(_get_re_batchsampler, RandomBatchSampler) state = _get_re_batchsampler.state_dict() # 2. 断点重训,重新生成一个 dataloader; # 不改变 batch_size; dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) re_batchsampler.load_state_dict(state) dataloader = replace_batch_sampler(dataloader, re_batchsampler) diff --git a/tests/core/samplers/test_reproducible_sampler.py b/tests/core/samplers/test_reproducible_sampler.py index 0a3697d3..981d6a03 100644 --- a/tests/core/samplers/test_reproducible_sampler.py +++ b/tests/core/samplers/test_reproducible_sampler.py @@ -1,18 +1,14 @@ -import unittest - -from itertools import product import numpy as np +import pytest from functools import partial -from array import array +from itertools import chain -from fastNLP.core.samplers.reproducible_sampler import RandomSampler -from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler +from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler from tests.helpers.datasets.torch_data import TorchNormalDataset - -class TestRandomSamplerYh(unittest.TestCase): +class TestRandomSamplerYh: def test_init(self): # 测试能否正确初始化 dataset = TorchNormalDataset(num_of_data=100) @@ -24,7 +20,7 @@ class 
TestRandomSamplerYh(unittest.TestCase): dataset = TorchNormalDataset(num_of_data=100) sampler = RandomSampler(dataset) for i in sampler: - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): sampler.set_distributed(1, 0) break @@ -37,39 +33,39 @@ class TestRandomSamplerYh(unittest.TestCase): dataset = TorchNormalDataset(num_of_data=100) sampler = RandomSampler(dataset, shuffle=False) sampler.set_distributed(num_replicas=2, rank=0, pad=False) - self.assertEqual(len(sampler), 50) + assert len(sampler)==50 count = 0 for i in sampler: - self.assertEqual(i%2, 0) + assert i%2==0 count += 1 - self.assertEqual(count, 50) + assert count == 50 sampler.set_distributed(num_replicas=2, rank=1, pad=False) - self.assertEqual(len(sampler), 50) + assert len(sampler)==50 count = 0 for i in sampler: - self.assertEqual(i%2, 1) + assert i%2==1 count += 1 - self.assertEqual(count, 50) + assert count==50 dataset = TorchNormalDataset(num_of_data=101) sampler = RandomSampler(dataset, shuffle=False) sampler.set_distributed(num_replicas=2, rank=0, pad=True) - self.assertEqual(len(sampler), 51) + assert len(sampler)==51 count = 0 for i in sampler: - self.assertEqual(i%2, 0) + assert i%2==0 count += 1 - self.assertEqual(count, 51) + assert count == 51 sampler.set_distributed(num_replicas=2, rank=1, pad=True) - self.assertEqual(len(sampler), 51) + assert len(sampler) == 51 count = 0 for i in sampler: if i!=0: - self.assertEqual(i%2, 1) + assert i%2==1 count += 1 - self.assertEqual(count, 51) + assert count == 51 def test_state_dict_check_length(self): dataset = TorchNormalDataset(num_of_data=100) @@ -77,7 +73,7 @@ class TestRandomSamplerYh(unittest.TestCase): states = sampler.state_dict() new_ds = TorchNormalDataset(num_of_data=10) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): new_sampler = RandomSampler(new_ds) new_sampler.load_state_dict(states) @@ -85,99 +81,107 @@ class TestRandomSamplerYh(unittest.TestCase): new_sampler = 
RandomSampler(new_ds) new_sampler.load_state_dict(states) - def test_state_dict(self): + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('pre_shuffle', [True, False]) + @pytest.mark.parametrize('post_shuffle', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) + def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): num_samples = 100 dataset = TorchNormalDataset(num_of_data=num_samples) # 测试使用 前后shuffle不一致的load操作 - lst = [0]+np.random.randint(1, num_samples, size=3).tolist() - for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], - lst): - with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): - sampler = RandomSampler(dataset, shuffle=pre_shuffle) - sampler.set_epoch(0) - already_numbers = set() - if num_consumed_samples>0: - for i, j in enumerate(sampler, start=1): - already_numbers.add(j) - if i == num_consumed_samples: - break - self.assertEqual(len(already_numbers), num_consumed_samples) - - states = sampler.state_dict() - - new_sampler = RandomSampler(dataset, shuffle=post_shuffle) - new_sampler.load_state_dict(states) - new_sampler.set_epoch(0) - for i in new_sampler: - self.assertNotIn(i, already_numbers) - - # 测试切换成多卡也没有问题 - other_rank_number = set() - for rank in range(3): - new_sampler = RandomSampler(dataset, shuffle=post_shuffle) - new_sampler.load_state_dict(states) - new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) - new_sampler.set_epoch(0) - count = 0 - for i in new_sampler: - self.assertNotIn(i, other_rank_number) - other_rank_number.add(i) - self.assertNotIn(i, already_numbers) - count += 1 - - def test_state_dict_2(self): + sampler = RandomSampler(dataset, shuffle=pre_shuffle) + sampler.set_epoch(0) + already_numbers = set() + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) 
+ if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples + + states = sampler.state_dict() + + new_sampler = RandomSampler(dataset, shuffle=post_shuffle) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = RandomSampler(dataset, shuffle=post_shuffle) + new_sampler.load_state_dict(states) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + new_sampler.set_epoch(0) + count = 0 + seen = 0 + seen_in_other_rank = 0 + for i in new_sampler: + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=1 # 因为pad可能重复 + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('pre_shuffle', [True, False]) + @pytest.mark.parametrize('post_shuffle', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) + def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 num_samples = 100 dataset = TorchNormalDataset(num_of_data=num_samples) # 测试使用 前后shuffle不一致的load操作 - lst = [0]+np.random.randint(1, num_samples//2, size=3).tolist() # lst = [30] - for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], - lst): - with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): - already_numbers = set() - sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) - sampler.set_distributed(num_replicas=2, rank=0) - sampler.set_epoch(0) - if num_consumed_samples>0: - for i, j in enumerate(sampler, start=1): - already_numbers.add(j) - if i == num_consumed_samples: - break - sampler = RandomSampler(dataset, shuffle=pre_shuffle, 
seed=0) - sampler.set_epoch(0) - sampler.set_distributed(num_replicas=2, rank=1) - if num_consumed_samples>0: - for i, j in enumerate(sampler, start=1): - already_numbers.add(j) - if i == num_consumed_samples: - break - self.assertEqual(len(already_numbers), num_consumed_samples*2) - - states = sampler.state_dict() - - new_sampler = RandomSampler(dataset, shuffle=post_shuffle) - new_sampler.load_state_dict(states) - new_sampler.set_epoch(0) - for i in new_sampler: - self.assertNotIn(i, already_numbers) - - # 测试切换成多卡也没有问题 - other_rank_number = set() - for rank in range(3): - new_sampler = RandomSampler(dataset, shuffle=post_shuffle) - new_sampler.load_state_dict(states) - new_sampler.set_epoch(0) - new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) - count = 0 - for i in new_sampler: - self.assertNotIn(i, other_rank_number) - other_rank_number.add(i) - self.assertNotIn(i, already_numbers) - count += 1 - - -class TestRandomSampler(unittest.TestCase): + already_numbers = set() + sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) + sampler.set_distributed(num_replicas=2, rank=0) + sampler.set_epoch(0) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) + sampler.set_epoch(0) + sampler.set_distributed(num_replicas=2, rank=1) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples*2 + + states = sampler.state_dict() + + new_sampler = RandomSampler(dataset, shuffle=post_shuffle) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = RandomSampler(dataset, shuffle=post_shuffle) + new_sampler.load_state_dict(states) + 
new_sampler.set_epoch(0) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + count = 0 + seen = 0 + seen_in_other_rank = 0 + for i in new_sampler: + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=1 # 因为pad可能重复 + + +class TestRandomSampler: # 测试单卡; def test_seed_work_when_shuffle_is_true(self): data_length = 100 @@ -360,4 +364,324 @@ class TestRandomSampler(unittest.TestCase): ... +class DatasetWithVaryLength: + def __init__(self, num_of_data=100, reverse=False): + self.data = np.arange(num_of_data) + if reverse: + self.data = self.data[::-1] + + def __getitem__(self, item): + return self.data[item] + + def __len__(self): + return len(self.data) + + +class TestSortedSampler: + def test_single(self): + num_of_data = 100 + data = DatasetWithVaryLength(num_of_data) + sampler = SortedSampler(data, length=data.data) + indexes = list(sampler) + assert indexes==list(range(num_of_data-1, -1, -1)) + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_replica', [2, 3]) + @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) + def test_multi(self, pad, num_replica, num_of_data): + data = DatasetWithVaryLength(num_of_data=num_of_data) + samplers = [] + for i in range(num_replica): + sampler = SortedSampler(dataset=data, length=data.data) + sampler.set_distributed(num_replica, rank=i, pad=pad) + samplers.append(sampler) + + # 保证顺序是没乱的 + already_seen_index = set() + for sampler in samplers: + larger_count = 0 # 这里为 0 就可以,因为最后补充的index一定是比较大的数。 + prev_index = float('inf') + cur_set = set() + seen_in_other_rank = 0 + for index in sampler: + seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉 + cur_set.add(index) + larger_count += int(index <= prev_index) + prev_index = index + assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序 + assert seen_in_other_rank <= 1 if 
pad else seen_in_other_rank == 0 + already_seen_index.update(cur_set) + + indexes = list(chain(*samplers)) + indexes = set(indexes) + if pad: + assert indexes == set(range(num_of_data)) + else: + assert len(indexes) <= num_of_data + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) + def test_state_dict(self, pad, num_consumed_samples): + num_samples = 100 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + # 测试使用 前后shuffle不一致的load操作 + sampler = SortedSampler(dataset, length=dataset.data) + sampler.set_epoch(0) + already_numbers = set() + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + if already_numbers: + assert j<max(already_numbers) + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples + + states = sampler.state_dict() + + new_sampler = SortedSampler(dataset, length=dataset.data) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + if already_numbers: + assert i < max(already_numbers) + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = SortedSampler(dataset, length=dataset.data) + new_sampler.load_state_dict(states) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + new_sampler.set_epoch(0) + count = 0 + seen = 0 + seen_in_other_rank = 0 + smaller = 0 + for i in new_sampler: + if already_numbers: + smaller += int(i >= max(already_numbers)) + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=1 # 因为pad可能重复 + assert smaller<=1 if pad else smaller==0 + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) + def test_state_dict_2(self, pad, num_consumed_samples): + # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 + num_samples = 100 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + # 测试使用 前后shuffle不一致的load操作 + # lst = [30] + already_numbers = set() + sampler = SortedSampler(dataset, length=dataset.data) + sampler.set_distributed(num_replicas=2, rank=0) + sampler.set_epoch(0) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + if already_numbers: + assert j<=max(already_numbers) + already_numbers.add(j) + if i == num_consumed_samples: + break + sampler = SortedSampler(dataset, length=dataset.data) + sampler.set_epoch(0) + sampler.set_distributed(num_replicas=2, rank=1) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == 
num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples*2 + + states = sampler.state_dict() + + new_sampler = SortedSampler(dataset, length=dataset.data) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + if already_numbers: + assert i < max(already_numbers) + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = SortedSampler(dataset, length=dataset.data) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + count = 0 + seen = 0 + seen_in_other_rank = 0 + smaller = 0 + for i in new_sampler: + if already_numbers: + smaller += int(i>=max(already_numbers)) + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=1 # 因为pad可能重复 + assert smaller <= 1 if pad else smaller == 0 + + +class TestSequentialSampler: + def test_single(self): + num_of_data = 100 + data = DatasetWithVaryLength(num_of_data) + sampler = SequentialSampler(data) + indexes = list(sampler) + assert indexes==list(range(num_of_data)) + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_replica', [2, 3]) + @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) + def test_multi(self, pad, num_replica, num_of_data): + data = DatasetWithVaryLength(num_of_data=num_of_data) + samplers = [] + for i in range(num_replica): + sampler = SequentialSampler(dataset=data) + sampler.set_distributed(num_replica, rank=i, pad=pad) + samplers.append(sampler) + + # 保证顺序是没乱的 + already_seen_index = set() + for idx, sampler in enumerate(samplers): + larger_count = 1 + prev_index = float('inf') + cur_set = set() + seen_in_other_rank = 0 + for index in sampler: + seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉 + cur_set.add(index) + 
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize(
    'num_consumed_samples',
    # Parametrize values are evaluated at collection time.  A seeded generator
    # keeps the generated test ids identical on every run (and on every
    # parallel worker), so a failing case can be reproduced.
    [0] + np.random.RandomState(0).randint(1, 100, size=3).tolist(),
)
def test_state_dict(self, pad, num_consumed_samples):
    """Save a partially consumed ``SequentialSampler`` and resume it on 1 and on 3 replicas.

    A single sampler draws ``num_consumed_samples`` indices; its state is
    then loaded into a fresh single-replica sampler and into each rank of a
    3-replica setup.

    :param pad: forwarded to ``set_distributed`` for the 3-replica samplers;
        when true, the assertions tolerate repeated samples.
    :param num_consumed_samples: how many indices are drawn before the state
        is saved.
    """
    num_samples = 100
    dataset = DatasetWithVaryLength(num_of_data=num_samples)
    sampler = SequentialSampler(dataset=dataset)
    sampler.set_epoch(0)

    # Consume a prefix; indices must arrive in strictly increasing order.
    already_numbers = set()
    if num_consumed_samples > 0:
        for step, index in enumerate(sampler, start=1):
            if already_numbers:
                assert index > max(already_numbers)
            already_numbers.add(index)
            if step == num_consumed_samples:
                break
    assert len(already_numbers) == num_consumed_samples

    states = sampler.state_dict()

    # ``already_numbers`` is frozen from here on, so its maximum is computed once.
    largest_consumed = max(already_numbers) if already_numbers else None

    # Resume on a single replica: only unseen, larger indices may be produced.
    new_sampler = SequentialSampler(dataset=dataset)
    new_sampler.load_state_dict(states)
    new_sampler.set_epoch(0)
    for index in new_sampler:
        if largest_consumed is not None:
            assert index > largest_consumed
        assert index not in already_numbers

    # Resuming into a multi-replica setup must work as well.
    other_rank_number = set()
    for rank in range(3):
        new_sampler = SequentialSampler(dataset=dataset)
        new_sampler.load_state_dict(states)
        new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
        new_sampler.set_epoch(0)
        seen = 0                # indices already consumed before the checkpoint
        seen_in_other_rank = 0  # indices already produced by an earlier rank
        not_larger = 0          # indices <= the largest consumed index
        for index in new_sampler:
            if largest_consumed is not None:
                not_larger += int(index <= largest_consumed)
            seen_in_other_rank += int(index in other_rank_number)
            other_rank_number.add(index)
            seen += int(index in already_numbers)
        assert (seen <= 1) if pad else (seen == 0)
        assert seen_in_other_rank <= rank  # padding may repeat a sample
        assert (not_larger <= 1) if pad else (not_larger == 0)
@pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) + def test_state_dict_2(self, pad, num_consumed_samples): + # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 + num_samples = 100 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + # 测试使用 前后shuffle不一致的load操作 + # lst = [30] + already_numbers = set() + sampler = SequentialSampler(dataset=dataset) + sampler.set_distributed(num_replicas=2, rank=0) + sampler.set_epoch(0) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + if already_numbers: + assert j>max(already_numbers) + already_numbers.add(j) + if i == num_consumed_samples: + break + sampler = SequentialSampler(dataset=dataset) + sampler.set_epoch(0) + sampler.set_distributed(num_replicas=2, rank=1) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples*2 + + states = sampler.state_dict() + + new_sampler = SequentialSampler(dataset=dataset) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + if already_numbers: + assert i > max(already_numbers) + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = SequentialSampler(dataset=dataset) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + count = 0 + seen = 0 + seen_in_other_rank = 0 + smaller = 0 + for i in new_sampler: + if already_numbers: + smaller += int(i=prev_index + prev_index = index + + indexes = list(chain(*samplers)) + assert len(indexes) == num_of_data + indexes = set(indexes) + assert indexes == set(range(num_of_data))