From 9d71170bef82d01344684c4f3d40bf16b5be9e82 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Wed, 13 Apr 2022 17:04:33 +0800
Subject: [PATCH] Fix Trainer failing to save and load state accurately when
 resuming training from a checkpoint
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/controllers/evaluator.py         | 10 +-
 fastNLP/core/controllers/trainer.py           | 16 +++-
 fastNLP/core/controllers/utils/state.py       |  2 +-
 .../core/drivers/torch_driver/torch_driver.py | 15 ++-
 .../samplers/reproducible_batch_sampler.py    | 80 ++++++++++------
 fastNLP/core/samplers/reproducible_sampler.py | 56 +++++++----
 fastNLP/core/samplers/utils.py                | 57 +++++++++++-
 fastNLP/envs/env.py                           |  2 +
 .../test_reproducible_batch_sampler.py        | 93 +++++++++++++++++--
 .../samplers/test_reproducible_sampler.py     | 62 ++++++++++++-
 .../core/samplers/test_unrepeated_sampler.py  |  6 +-
 11 files changed, 325 insertions(+), 74 deletions(-)

diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py
index 479686e1..2e3678d3 100644
--- a/fastNLP/core/controllers/evaluator.py
+++ b/fastNLP/core/controllers/evaluator.py
@@ -364,16 +364,16 @@ class _MetricsWrapper:
         else:
             args.append(batch)
         if not isinstance(outputs, dict):
-            raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly"
+            raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
                                f" return a dict from your model or use `output_mapping` to convert it into dict type.")
         if isinstance(metric, Metric):
-            auto_param_call(metric.update, batch, *args)
+            auto_param_call(metric.update, outputs, *args)
         elif _is_torchmetrics_metric(metric):
-            auto_param_call(metric.update, batch, *args)
+            auto_param_call(metric.update, outputs, *args)
         elif _is_allennlp_metric(metric):
-            auto_param_call(metric.__call__, batch, *args)
+            auto_param_call(metric.__call__, outputs, *args)
         elif _is_paddle_metric(metric):
-            res = auto_param_call(metric.compute, batch, *args)
+            res = auto_param_call(metric.compute, outputs, *args)
             metric.update(res)
 
     def reset(self):
diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index 5daee856..6931ed3c 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -105,8 +105,8 @@ class Trainer(TrainerEventTrigger):
         if batch is a `Dict`, we rename the keys of batch that also appear in output_mapping to the value of the corresponding key in output_mapping;
         if batch is a `dataclass`, we first convert that dataclass into a Dict and then apply the transformation above;
     :param model_wo_auto_param_call: whether to disable calling our auto_param_call during training to automatically match the batch against the parameters of the forward function;
-        if this value is True, and when batch is a dict, we extract the objects that forward needs from the batch and pass them into the forward function; if this value
-        is False, we pass the batch through to the forward function unchanged. Note that the same logic also applies to `train_step`, `validate_step` and `test_step`;
+        if this value is False, and when batch is a dict, we extract the objects that forward needs from the batch and pass them into the forward function; if this value
+        is True, we pass the batch through to the model unchanged. Note that this parameter applies to `train_step`, `validate_step` and `test_step`;
     :param accumulation_steps: the number of gradient accumulation steps, i.e. the optimizer takes one step every this many batches; defaults to 1;
     :param fp16: whether to enable mixed-precision training; defaults to False;
     :param monitor: the name of the default monitor metric when validate_dataloaders exists. If a callback passed in has a monitor argument and does not
@@ -325,6 +325,8 @@ class Trainer(TrainerEventTrigger):
 
         try:
             while self.cur_epoch_idx < self.n_epochs:
+                # guard against saving again before the current epoch has finished after a Trainer.load
+                self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
                 self.driver.set_model_mode("train")
                 self.on_train_epoch_begin()
                 self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx)
@@ -598,7 +600,9 @@
         # 1. callback states and the state of every callback function's filter;
         # 2. trainer_state;
         states = {"callback_states": self.on_save_checkpoint(),
-                  "trainer_state": self.trainer_state.state_dict()}
+                  "trainer_state": self.trainer_state.state_dict(),
+                  'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0)
+                  }
 
         # 3. validate filter state;
         if self.evaluator is not None:
@@ -675,9 +679,13 @@
         # The principle here is that 'the number of batches still to be produced' + 'batch_idx_in_epoch' should equal 'the total
         # number of batches of the original, uninterrupted training'. Since 'the number of batches still to be produced' is
         # determined by how many samples are left, the equation can only be satisfied by adjusting 'batch_idx_in_epoch'.
         self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
+        self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \
+                                                    self.batch_idx_in_epoch
+        # guard against the user saving again before the current epoch has finished after a Trainer.load
+        self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
 
         # 5. restore the state of all callbacks;
-        self.train_stepeckpoint(states["callback_states"])
+        self.on_load_checkpoint(states["callback_states"])
 
         self.driver.barrier()
diff --git a/fastNLP/core/controllers/utils/state.py b/fastNLP/core/controllers/utils/state.py
index fed9292c..2327c1e5 100644
--- a/fastNLP/core/controllers/utils/state.py
+++ b/fastNLP/core/controllers/utils/state.py
@@ -60,7 +60,7 @@ class TrainerState:
     cur_epoch_idx: which epoch is currently running;
     global_forward_batches: how many steps the model has forwarded in total so far;
    batch_idx_in_epoch: which step within the current epoch the training is at;
-    total_batches: how many steps each epoch forwards;
+    num_batches_per_epoch: how many steps each epoch forwards;
     total_batches: the number of steps forwarded over the complete training run; note that total_batches = num_batches_per_epoch * n_epochs;
     """
     n_epochs: Optional[int] = None  # recomputed no matter what
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index d2ffbac1..c79ecd0b 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -194,9 +194,20 @@ class TorchDriver(Driver):
             sampler = dataloader_args.sampler
         else:
             raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
-
+        num_consumed_batches = states.pop('num_consumed_batches')
         if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
-            states['sampler_states'] = sampler.state_dict()
+            sampler_states = sampler.state_dict()
+            # If present, num_consumed_samples needs special handling: because the DataLoader prefetches data, using
+            # the num_consumed_samples stored in the sampler directly would over-count what was actually consumed.
+            num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
+            if num_consumed_samples_array is not None:
+                if isinstance(sampler, ReproducibleSampler):  # for a plain sampler, the batch_size must be taken into account
+                    try:
+                        num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
+                    except TypeError:  # batch_size may be None, in which case some precision is unavoidably lost
+                        num_consumed_batches = sampler_states['num_consumed_samples']
+                sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
+                assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
         else:
             raise RuntimeError(
                 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
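The driver hunk above is the consuming side of the new bookkeeping: num_consumed_batches counts the batches the Trainer really forwarded, while the sampler's own counter also includes whatever the DataLoader has already prefetched. A minimal self-contained sketch of that mismatch (plain Python, not fastNLP API; CountingSampler and its history list are hypothetical stand-ins for a reproducible sampler and NumConsumedSamplesArray):

class CountingSampler:
    """Yields batches of indices and counts the samples handed out so far."""
    def __init__(self, n, batch_size):
        self.n, self.batch_size = n, batch_size
        self.num_consumed_samples = 0
        self.history = [0]  # num_consumed_samples recorded after each yielded batch

    def __iter__(self):
        for start in range(0, self.n, self.batch_size):
            batch = list(range(start, min(start + self.batch_size, self.n)))
            self.num_consumed_samples += len(batch)
            self.history.append(self.num_consumed_samples)
            yield batch

sampler = CountingSampler(n=100, batch_size=16)
it = iter(sampler)
consumed = [next(it) for _ in range(3)]    # the trainer really forwarded 3 batches...
prefetched = [next(it) for _ in range(2)]  # ...but the loader already pulled 2 more

print(sampler.num_consumed_samples)  # 80 -- overshoots because of the prefetch
print(sampler.history[3])            # 48 -- the correct resume point for 3 consumed batches

Indexing the recorded history by the number of truly consumed batches is exactly what num_consumed_samples_array[num_consumed_batches] does in the hunk above.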
diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py
index d1041f08..d4535bae 100644
--- a/fastNLP/core/samplers/reproducible_batch_sampler.py
+++ b/fastNLP/core/samplers/reproducible_batch_sampler.py
@@ -4,16 +4,18 @@ __all__ = [
 ]
 
 import math
-from array import array
 from copy import deepcopy
 from typing import Dict, Union, List
 from itertools import chain
+import os
 
 import numpy as np
 
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.log import logger
+from .utils import create_array, NumConsumedSamplesArray
 from abc import abstractmethod
+from fastNLP.envs.env import FASTNLP_DEQUE_SIZE
 
 
 class ReproducibleBatchSampler:
@@ -34,6 +36,13 @@ class ReproducibleBatchSampler:
 
     @abstractmethod
     def state_dict(self):
+        """
+        Because today's DataLoaders all prefetch data, please refer to the implementation of num_consumed_samples_array
+        in the states of RandomBatchSampler to set this value correctly. The idea is to record the num_consumed_samples
+        that corresponds to each index; at Trainer.save time, the num_consumed_samples matching the number of samples
+        the Trainer actually forwarded is taken from num_consumed_samples_array and stored.
+
+        :return:
+        """
         raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.")
 
     @abstractmethod
@@ -67,7 +76,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         self.batch_size = batch_size
         self.drop_last = drop_last
 
-        self.data_idx = kwargs.get("data_idx", 0)
+        self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)
 
         self.index_list = kwargs.get("index_list", self._iterate_sampler())
         self.need_reinitialize = kwargs.get("need_reinitialize", False)
@@ -80,36 +89,40 @@ class RandomBatchSampler(ReproducibleBatchSampler):
             # this means a sampler was passed in at initialization, which in principle corresponds to the dataloader having neither batch_size nor batch_sampler at initialization;
             else:
                 _index_lst.append(idx)
-        # an unsigned int is 4 bytes on 64-bit machines, so the largest representable value is 4294967295;
-        if len(_index_lst) > 4294967295:
-            # note that self.index_list holds the indices of all the data;
-            # unsigned long
-            _index_lst = array("L", _index_lst)
-        else:
-            # unsigned int
-            _index_lst = array("I", _index_lst)
+        _index_lst = create_array(len(_index_lst), _index_lst)
         return _index_lst
 
     def __iter__(self):
         if self.need_reinitialize:
             self.index_list = self._iterate_sampler()
-            self.data_idx = 0
+            self.num_consumed_samples = 0
         else:
             self.need_reinitialize = True
 
         batch = []
-        if self.data_idx:
-            index_list = self.index_list[self.data_idx:]
+        if self.num_consumed_samples:
+            index_list = self.index_list[self.num_consumed_samples:]
         else:
             index_list = self.index_list
+
+        # Remember the consumed_samples belonging to each batch. This is needed because today's dataloaders are all
+        # designed to prefetch data, so the data actually consumed can only be determined by additionally combining
+        # this with batch_idx_in_epoch in the Trainer. This variable records the true num_consumed_samples at the
+        # moment of each yield.
+        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=int(os.environ.get(FASTNLP_DEQUE_SIZE, 30)),
+                                                                  num_consumed_samples=self.num_consumed_samples)
         for idx in index_list:
             batch.append(idx)
-            self.data_idx += 1
             if len(batch) == self.batch_size:
+                self.num_consumed_samples += self.batch_size  # [16, 32, 48, 64, ...]
+                self.num_consumed_samples_array.push(self.num_consumed_samples)
                 yield batch
                 batch = []
 
         if len(batch) > 0 and not self.drop_last:
+            self.num_consumed_samples += len(batch)
+            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield batch
+        # reset, to guard against boundary-condition problems
+        self.num_consumed_samples = 0
+        delattr(self, 'num_consumed_samples_array')
 
     def __len__(self) -> int:
         if self.drop_last:
             return len(self.index_list) // self.batch_size
         else:
             return (len(self.index_list) + self.batch_size - 1) // self.batch_size
 
     def state_dict(self) -> Dict:
-        return {"index_list": deepcopy(self.index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__}
+        states = {
+            "index_list": deepcopy(self.index_list),
+            "num_consumed_samples": self.num_consumed_samples,
+            'sampler_type': self.__class__.__name__
+        }
+        states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None)
+        return states
 
     def load_state_dict(self, states: Dict):
         assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
@@ -128,7 +147,7 @@
                                                                   f" but the sampler type in the current training is {self.__class__.__name__}."
         _index_list = states["index_list"]
         assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \
                                                          "record and current dataset."
         self.index_list = _index_list
-        self.data_idx = states["data_idx"]
+        self.num_consumed_samples = states["num_consumed_samples"]
         self.need_reinitialize = False
 
     def set_distributed(self, num_replicas, rank, pad=True):
@@ -141,10 +160,10 @@
     @property
     def batch_idx_in_epoch(self):
         if self.drop_last:
-            return len(self.index_list) // self.batch_size - (len(self.index_list) - self.data_idx) // self.batch_size
+            return len(self.index_list) // self.batch_size - (len(self.index_list) - self.num_consumed_samples) // self.batch_size
         else:
             return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \
-                   (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size
+                   (len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
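The batch_idx_in_epoch property recovers the batch position within the epoch purely from sample counts. A quick hand-check of that arithmetic, re-expressed as a plain function with illustrative numbers (independent of fastNLP):

index_list_len, batch_size = 10, 3

def batch_idx_in_epoch(num_consumed_samples, drop_last):
    # mirrors the property above: total batches minus batches still to come
    if drop_last:
        return index_list_len // batch_size - (index_list_len - num_consumed_samples) // batch_size
    return (index_list_len + batch_size - 1) // batch_size - \
           (index_list_len - num_consumed_samples + batch_size - 1) // batch_size

assert batch_idx_in_epoch(6, drop_last=False) == 2   # two full batches consumed
assert batch_idx_in_epoch(10, drop_last=False) == 4  # ...plus the final short batch
assert batch_idx_in_epoch(6, drop_last=True) == 2    # the short batch is never counted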
@@ -180,7 +199,6 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         self.length = np.array(length, dtype=int)  # the ordinal numbers, arranged from long to short.
         self.sorted_indices = np.argsort(self.length)[::-1]  # sorted by length from high to low
 
-        self.batch_size = batch_size
         self.num_batch_per_bucket = num_batch_per_bucket
         self.shuffle = shuffle
@@ -212,13 +230,13 @@
         self.rank = rank
         self.pad = pad
 
-        num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \
-            else len(self.dataset)
-
-        if self.drop_last:
-            assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \
-                                                                   "than the number of replicates multiplied " \
-                                                                   "with batch_size when drop_last=True."
+        # num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \
+        #     else len(self.dataset)
+        #
+        # if self.drop_last:
+        #     assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \
+        #                                                            "than the number of replicates multiplied " \
+        #                                                            "with batch_size when drop_last=True."
 
         return self
 
@@ -243,7 +261,7 @@
         return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
             self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))
 
-    def __len__(self):
+    def __len__(self) -> int:
         """
         Returns how many more batches of data this sampler will yield
 
@@ -309,11 +327,15 @@
         if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
             batches = batches[:-1]
 
+        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=int(os.environ.get(FASTNLP_DEQUE_SIZE, 30)),
+                                                                  num_consumed_samples=self.num_consumed_samples)
         for batch in batches:
             self.num_consumed_samples += self.num_replicas * len(batch)
+            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield list(map(int, batch))
         self.during_iter = False
         self.num_consumed_samples = 0
+        delattr(self, 'num_consumed_samples_array')
         self.old_batch_size = self.batch_size
         self.old_num_batch_per_bucket = self.num_batch_per_bucket
         self.old_num_replicas = self.num_replicas
@@ -376,10 +398,12 @@
             'num_batch_per_bucket': self.num_batch_per_bucket,
             'num_replicas': self.num_replicas
         }
+
+        states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None)
         return states
 
     def load_state_dict(self, states: Dict):
-        # if self.during_iter is True, then data_idx must be 0;
+        # if self.during_iter is True, then num_consumed_samples must be 0;
         assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                           "during an unfinished iteration."
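Taken together, the state_dict/load_state_dict changes above define a small save protocol: the saver pops the live num_consumed_samples_array and replaces the possibly overshooting counter with the entry recorded for the last batch that was really consumed. A condensed sketch of that saver side (checkpointable_sampler_states is a hypothetical helper written for illustration, not fastNLP API):

def checkpointable_sampler_states(sampler, num_consumed_batches):
    # Mirror of the driver-side logic: swap the possibly-overshooting live
    # counter for the value recorded when the last consumed batch was yielded.
    states = sampler.state_dict()
    array = states.pop('num_consumed_samples_array', None)
    if array is not None:
        states['num_consumed_samples'] = array[num_consumed_batches]
    return states

This is also why state_dict stores the array object itself rather than a snapshot: the save can happen at any point during iteration, and the lookup must reflect the batch count known only to the Trainer.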
diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py
index f48e2fc6..396e69b2 100644
--- a/fastNLP/core/samplers/reproducible_sampler.py
+++ b/fastNLP/core/samplers/reproducible_sampler.py
@@ -1,9 +1,14 @@
 from typing import Dict, List, Union
 import math
+import os
+
 import numpy as np
 
 from fastNLP.core.log import logger
 from fastNLP.core.dataset import DataSet
+from fastNLP.envs.env import FASTNLP_DEQUE_SIZE
+from .utils import NumConsumedSamplesArray
+
 
 __all__ = [
     'ReproducibleSampler',
@@ -30,6 +35,13 @@ class ReproducibleSampler:
         raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.")
 
     def state_dict(self):
+        """
+        Because today's DataLoaders all prefetch data, please refer to the implementation of num_consumed_samples_array
+        in the states of RandomSampler to set this value correctly. The idea is to record the num_consumed_samples that
+        corresponds to each index; at Trainer.save time, the num_consumed_samples matching the number of samples the
+        Trainer actually forwarded is taken from num_consumed_samples_array and stored.
+
+        :return:
+        """
         raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.")
 
     def load_state_dict(self, states):
@@ -109,12 +121,15 @@ class RandomSampler(ReproducibleSampler):
         indices = indices[self.num_consumed_samples:]
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples
-
-        for index in indices:
+        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=int(os.environ.get(FASTNLP_DEQUE_SIZE, 2000)),
+                                                                  num_consumed_samples=self.num_consumed_samples)
+        for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
+            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
+        delattr(self, 'num_consumed_samples_array')
 
     def generate_indices(self) -> List[int]:
         """
@@ -134,18 +149,13 @@
         return indices
 
     def state_dict(self) -> Dict:
-        states = {
-            'seed': self.seed,
-            'epoch': self.epoch,
-            'num_consumed_samples': self.num_consumed_samples,  # note this value counts the data trained across all ranks;
-            'sampler_type': self.__class__.__name__,
-            'length': len(self.dataset),
-            'shuffle': self.shuffle
-        }
+        states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
+                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
+                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
         return states
 
     def load_state_dict(self, states: Dict):
-        # if self.during_iter is True, then data_idx must be 0;
+        # if self.during_iter is True, then num_consumed_samples must be 0;
         assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                           "during an unfinished iteration."
 
@@ -158,7 +168,7 @@ class RandomSampler(ReproducibleSampler):
         self.seed = states['seed']
         self.epoch = states['epoch']
         self.num_consumed_samples = states['num_consumed_samples']
-        if self.num_consumed_samples>=length:  # if the last sample had already been reached at save time, simply reset the result to 0
+        if self.num_consumed_samples >= length:  # if the last sample had already been reached at save time, simply reset the result to 0
             self.num_consumed_samples = 0
         if self.shuffle != states['shuffle']:
             logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
@@ -245,11 +255,15 @@ class SequentialSampler(RandomSampler):
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples
 
-        for index in indices:
+        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=int(os.environ.get(FASTNLP_DEQUE_SIZE, 2000)),
+                                                                  num_consumed_samples=self.num_consumed_samples)
+        for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
+            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
+        delattr(self, 'num_consumed_samples_array')
 
     def generate_indices(self) -> List[int]:
         """
@@ -260,15 +274,13 @@ class SequentialSampler(RandomSampler):
         return list(range(len(self.dataset)))
 
     def state_dict(self) -> Dict:
-        states = {
-            'num_consumed_samples': self.num_consumed_samples,  # note this value counts the data trained across all ranks;
-            'sampler_type': self.__class__.__name__,
-            'length': len(self.dataset),
-        }
+        states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__,
+                  'length': len(self.dataset),
+                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
        return states
 
     def load_state_dict(self, states: Dict):
-        # if self.during_iter is True, then data_idx must be 0;
+        # if self.during_iter is True, then num_consumed_samples must be 0;
         assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                           "during an unfinished iteration."
@@ -334,9 +346,13 @@ class SortedSampler(SequentialSampler):
         indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples
 
-        for index in indices:
+        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=int(os.environ.get(FASTNLP_DEQUE_SIZE, 2000)),
+                                                                  num_consumed_samples=self.num_consumed_samples)
+        for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
+            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
+        delattr(self, 'num_consumed_samples_array')
diff --git a/fastNLP/core/samplers/utils.py b/fastNLP/core/samplers/utils.py
index dd90fe7c..80af1787 100644
--- a/fastNLP/core/samplers/utils.py
+++ b/fastNLP/core/samplers/utils.py
@@ -2,6 +2,9 @@ __all__ = [
     're_instantiate_sampler',
     'conversion_between_reproducible_and_unrepeated_sampler'
 ]
+from array import array
+from typing import Sequence
+from collections import deque
 
 from fastNLP.core.samplers.unrepeated_sampler import *
 from fastNLP.core.samplers.reproducible_sampler import *
@@ -39,4 +42,56 @@ def re_instantiate_sampler(sampler, new_sampler_class=None):
     all_attributes = vars(sampler)
     if new_sampler_class is not None:
         return new_sampler_class(**all_attributes)
-    return type(sampler)(**all_attributes)
\ No newline at end of file
+    return type(sampler)(**all_attributes)
+
+
+def create_array(length, fill_value) -> array:
+    """
+    Automatically create an array based on length: beyond 4294967295 'L' has to be used, otherwise 'I'.
+
+    :param length:
+    :param fill_value:
+    :return:
+    """
+    if not isinstance(fill_value, Sequence):
+        fill_value = [fill_value]*length
+
+    if length > 4294967295:
+        _index_lst = array("L", fill_value)
+    else:
+        _index_lst = array("I", fill_value)
+    return _index_lst
+
+
+class NumConsumedSamplesArray:
+    def __init__(self, buffer_size=2000, num_consumed_samples=0):
+        """
+        Keeps the last buffer_size num_consumed_samples values, so that the num_consumed_samples at a given index can
+        be looked up, e.g.:
+            array = NumConsumedSamplesArray(buffer_size=3)
+            for i in range(10):
+                array.push(i)
+
+            array[9]  # outputs 9, the true num_consumed_samples at this position.
+            array[6]  # raises, because only the 3 most recent records are kept, i.e. [7, 8, 9], and 6 is beyond the buffer.
+
+        :param buffer_size: how many history entries to keep.
+        :param num_consumed_samples: the value of the first num_consumed_samples.
+        """
+        self.count = 0
+        self.deque = deque(maxlen=buffer_size)
+        if num_consumed_samples is not None:
+            self.push(num_consumed_samples)
+        self.buffer_size = buffer_size
+
+    def __getitem__(self, item):
+        if len(self.deque) == 0:  # nothing cached yet means nothing has been written; just return 0
+            return 0
+        assert isinstance(item, int), "Only int index allowed."
+        assert self.count-len(self.deque)<=item<self.count, f"Only the last {len(self.deque)} entries are kept."
+        return self.deque[len(self.deque)-(self.count-item)]
+
+    def push(self, num_consumed_samples):
+        """
+        Record one num_consumed_samples value.
+        """
+        self.deque.append(num_consumed_samples)
+        self.count += 1
diff --git a/fastNLP/envs/env.py b/fastNLP/envs/env.py
--- a/fastNLP/envs/env.py
+++ b/fastNLP/envs/env.py
+# controls how much num_consumed_samples history NumConsumedSamplesArray keeps
+FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE'
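A quick sanity check of the buffer window implemented above, assuming this patch is applied (the import path matches the file being modified; the printed values follow from the constructor's initial push counting as index 0):

from fastNLP.core.samplers.utils import NumConsumedSamplesArray

arr = NumConsumedSamplesArray(buffer_size=3, num_consumed_samples=0)  # recorded as index 0
for consumed in (16, 32, 48, 64):                                     # recorded as indices 1..4
    arr.push(consumed)

print(arr[4])  # 64: num_consumed_samples after the 4th yielded batch
print(arr[2])  # 32: still inside the 3-entry window, which covers indices 2..4
# arr[1] would raise: only the last buffer_size entries are retained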
diff --git a/tests/core/samplers/test_reproducible_batch_sampler.py b/tests/core/samplers/test_reproducible_batch_sampler.py
--- a/tests/core/samplers/test_reproducible_batch_sampler.py
+++ b/tests/core/samplers/test_reproducible_batch_sampler.py
-        if num_replica*batch_size > num_samples:
+        if num_replicas*batch_size > num_samples:
             return
         num_batch_per_bucket = 10
         samplers = []
         lengths = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                            num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)
-            sampler.set_distributed(num_replica, rank=i, pad=pad)
+            sampler.set_distributed(num_replicas, rank=i, pad=pad)
             sampler.set_epoch(0)
             samplers.append(sampler)
             lengths.append(len(list(iter(sampler))))
         assert len(set(lengths))==1
-        bucket_diff = batch_size * num_batch_per_bucket * num_replica
+        bucket_diff = batch_size * num_batch_per_bucket * num_replicas
 
         for bs in zip(*samplers):
             diff = max(chain(*bs)) - min(chain(*bs))
             assert diff <= bucket_diff
+
+    @pytest.mark.parametrize('shuffle', [True, False])
+    @pytest.mark.parametrize('drop_last', [True, False])
+    @pytest.mark.parametrize('pad', [True, False])
+    @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
+    @pytest.mark.parametrize('num_replicas', [1, 2, 3])
+    def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
+        """
+        Test that data which has already been used (forwarded) is recovered correctly; because the DataLoader
+        prefetches, the Sampler's own num_consumed_samples may run ahead of what was actually consumed.
+
+        :return:
+        """
+        batch_size = 6
+        num_batch_per_bucket = 10
+        dataset = DatasetWithVaryLength(num_of_data=num_samples)
+        samplers = []
+        for i in range(num_replicas):
+            sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
+                                           num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)
+
+            sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
+            samplers.append(sampler)
+        count = 0
+        already_seen_sets = [set()]
+        already_seen_set = set()
+        for batchs in zip(*samplers):
+            batch = chain(*batchs)
+            already_seen_set.update(batch)
+            already_seen_sets.append(deepcopy(already_seen_set))
+            count += 1
+            if count > 3:
+                break
+        states = samplers[0].state_dict()
+        for i in range(len(already_seen_sets)):
+            if states['num_consumed_samples_array'] is not None:
+                states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
+            sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
+                                           num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
+                                           drop_last=drop_last)
+            sampler.load_state_dict(states)
+            sampler.set_epoch(0)
+            already_seen_set = deepcopy(already_seen_sets[i])
+            for batch in sampler:
+                already_seen_set.update(batch)
+            assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(
+                dataset)
+
+        # test saving again after a previous save
+        sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
+                                       num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
+                                       drop_last=drop_last)
+        sampler.set_epoch(0)
+        if states['num_consumed_samples_array'] is not None:
+            states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
+        if len(already_seen_sets)<3:
+            return
+        sampler.load_state_dict(states)
+        already_seen_set = already_seen_sets[2]
+        count = 0
+        for batch in sampler:
+            already_seen_set.update(batch)
+            count += 1
+            if count > 6:
+                break
+
+        states = sampler.state_dict()
+        if states['num_consumed_samples_array'] is not None:
+            states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
+        sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2,
+                                       num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
+                                       drop_last=drop_last)
+        sampler.load_state_dict(states)
+        sampler.set_epoch(0)
+        for batch in sampler:
+            already_seen_set.update(batch)
+
+        assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)
diff --git a/tests/core/samplers/test_reproducible_sampler.py b/tests/core/samplers/test_reproducible_sampler.py
index 981d6a03..ddf52bcb 100644
--- a/tests/core/samplers/test_reproducible_sampler.py
+++ b/tests/core/samplers/test_reproducible_sampler.py
@@ -3,6 +3,7 @@ import pytest
 
 from functools import partial
 from itertools import chain
+from copy import deepcopy
 
 from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler
 from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -180,6 +181,63 @@ class TestRandomSamplerYh:
                 assert seen <= 1 if pad else seen == 0
             assert seen_in_other_rank<=1  # duplicates are possible because of pad
 
+    @pytest.mark.parametrize('shuffle', [True, False])
+    @pytest.mark.parametrize('pad', [True, False])
+    @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
+    @pytest.mark.parametrize('num_replicas', [1, 2, 3])
+    def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas):
+        # test that recovery still works when the sampler has generated more than was consumed
+        dataset = DatasetWithVaryLength(num_of_data=num_samples)
+        samplers = []
+        for i in range(num_replicas):
+            sampler = RandomSampler(dataset, shuffle=shuffle)
+            sampler.set_epoch(0)
+            sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
+            samplers.append(sampler)
+        count = 0
+        already_seen_sets = [set()]
+        already_seen_set = set()
+        for idxes in zip(*samplers):
+            already_seen_set.update(idxes)
+            already_seen_sets.append(deepcopy(already_seen_set))
+            count += 1
+            if count > 3:
+                break
+        states = samplers[0].state_dict()
+        for i in range(len(already_seen_sets)):
+            if states['num_consumed_samples_array'] is not None:
+                states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
+            sampler = RandomSampler(dataset, shuffle=shuffle)
+            sampler.load_state_dict(states)
+            sampler.set_epoch(0)
+            already_seen_set = deepcopy(already_seen_sets[i])
+            for batch in sampler:
+                already_seen_set.add(batch)
+            assert len(already_seen_set) == len(dataset)
+
+        # test saving again after a previous save
+        sampler = RandomSampler(dataset, shuffle=shuffle)
+        sampler.set_epoch(0)
+        if states['num_consumed_samples_array'] is not None:
+            states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
+        if len(already_seen_sets)<3:
+            return
+        sampler.load_state_dict(states)
+        already_seen_set = already_seen_sets[2]
+        count = 0
+        for idx in sampler:
+            already_seen_set.add(idx)
+            count += 1
+            if count > 6:
+                break
+
+        states = sampler.state_dict()
+        if states['num_consumed_samples_array'] is not None:
+            states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
+        sampler = RandomSampler(dataset, shuffle=shuffle)
+        sampler.load_state_dict(states)
+        sampler.set_epoch(0)
+        for idx in sampler:
+            already_seen_set.add(idx)
+
+        assert len(already_seen_set)==len(dataset)
+
 
 class TestRandomSampler:
     # single-card tests;
@@ -386,7 +444,7 @@ class TestSortedSampler:
         assert indexes==list(range(num_of_data-1, -1, -1))
 
     @pytest.mark.parametrize('pad', [True, False])
-    @pytest.mark.parametrize('num_replica', [2, 3])
+    @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
     def test_multi(self, pad, num_replica, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
@@ -540,7 +598,7 @@ class TestSequentialSampler:
         assert indexes==list(range(num_of_data))
 
     @pytest.mark.parametrize('pad', [True, False])
-    @pytest.mark.parametrize('num_replica', [2, 3])
+    @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
     def test_multi(self, pad, num_replica, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
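The tests above all follow the same resume pattern; condensed to its core, it looks like the sketch below (assuming this patch is applied; range(100) stands in for a real dataset, and the array lookup at index 10 mirrors what the driver does with num_consumed_batches):

from fastNLP.core.samplers.reproducible_sampler import RandomSampler

sampler = RandomSampler(list(range(100)), shuffle=True)
sampler.set_epoch(0)

it = iter(sampler)
seen = {next(it) for _ in range(10)}  # 10 indices actually consumed

states = sampler.state_dict()  # the live counter may overshoot under prefetch
states['num_consumed_samples'] = states.pop('num_consumed_samples_array')[10]

resumed = RandomSampler(list(range(100)), shuffle=True)
resumed.load_state_dict(states)
resumed.set_epoch(0)
seen.update(resumed)  # the rest of the epoch: no repeats, no holes
assert len(seen) == 100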
diff --git a/tests/core/samplers/test_unrepeated_sampler.py b/tests/core/samplers/test_unrepeated_sampler.py
index 09601d2c..4a271f41 100644
--- a/tests/core/samplers/test_unrepeated_sampler.py
+++ b/tests/core/samplers/test_unrepeated_sampler.py
@@ -25,7 +25,7 @@ class TestUnrepeatedSampler:
         indexes = set(sampler)
         assert indexes==set(range(num_of_data))
 
-    @pytest.mark.parametrize('num_replica', [2, 3])
+    @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
     @pytest.mark.parametrize('shuffle', [False, True])
     def test_multi(self, num_replica, num_of_data, shuffle):
@@ -50,7 +50,7 @@ class TestUnrepeatedSortedSampler:
         indexes = list(sampler)
         assert indexes==list(range(num_of_data-1, -1, -1))
 
-    @pytest.mark.parametrize('num_replica', [2, 3])
+    @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
     def test_multi(self, num_replica, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
@@ -81,7 +81,7 @@ class TestUnrepeatedSequentialSampler:
         indexes = list(sampler)
         assert indexes==list(range(num_of_data))
 
-    @pytest.mark.parametrize('num_replica', [2, 3])
+    @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
     def test_multi(self, num_replica, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)