| @@ -364,16 +364,16 @@ class _MetricsWrapper: | |||
| else: | |||
| args.append(batch) | |||
| if not isinstance(outputs, dict): | |||
| raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly" | |||
| raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" | |||
| f" return a dict from your model or use `output_mapping` to convert it into dict type.") | |||
| if isinstance(metric, Metric): | |||
| auto_param_call(metric.update, batch, *args) | |||
| auto_param_call(metric.update, outputs, *args) | |||
| elif _is_torchmetrics_metric(metric): | |||
| auto_param_call(metric.update, batch, *args) | |||
| auto_param_call(metric.update, outputs, *args) | |||
| elif _is_allennlp_metric(metric): | |||
| auto_param_call(metric.__call__, batch, *args) | |||
| auto_param_call(metric.__call__, outputs, *args) | |||
| elif _is_paddle_metric(metric): | |||
| res = auto_param_call(metric.compute, batch, *args) | |||
| res = auto_param_call(metric.compute, outputs, *args) | |||
| metric.update(res) | |||
| def reset(self): | |||
| @@ -106,7 +106,7 @@ class Trainer(TrainerEventTrigger): | |||
| 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | |||
| :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | |||
| 如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||
| 为 True,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; | |||
| 为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `validate_step` 和 `test_step`; | |||
| :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | |||
| :param fp16: 是否开启混合精度训练;默认为 False; | |||
| :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | |||
| @@ -326,6 +326,8 @@ class Trainer(TrainerEventTrigger): | |||
| try: | |||
| while self.cur_epoch_idx < self.n_epochs: | |||
| # 这个是防止在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||
| self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||
| self.driver.set_model_mode("train") | |||
| self.on_train_epoch_begin() | |||
| self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx) | |||
| @@ -603,7 +605,9 @@ class Trainer(TrainerEventTrigger): | |||
| # 1. callback states 和 每一个callback的具体 callback 函数的 filter 的状态; | |||
| # 2. trainer_state; | |||
| states = {"callback_states": self.on_save_checkpoint(), | |||
| "trainer_state": self.trainer_state.state_dict()} | |||
| "trainer_state": self.trainer_state.state_dict(), | |||
| 'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0) | |||
| } | |||
| # 3. validate filter state; | |||
| if self.evaluator is not None: | |||
| @@ -680,9 +684,13 @@ class Trainer(TrainerEventTrigger): | |||
| # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | |||
| # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | |||
| self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | |||
| self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \ | |||
| self.batch_idx_in_epoch | |||
| # 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||
| self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||
| # 5. 恢复所有 callback 的状态; | |||
| self.train_stepeckpoint(states["callback_states"]) | |||
| self.on_load_checkpoint(states["callback_states"]) | |||
| self.driver.barrier() | |||
| @@ -60,7 +60,7 @@ class TrainerState: | |||
| cur_epoch_idx: 当前正在运行第几个 epoch; | |||
| global_forward_batches: 当前模型总共 forward 了多少个 step; | |||
| batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | |||
| total_batches: 每一个 epoch 会 forward 多少个 step; | |||
| num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | |||
| total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs; | |||
| """ | |||
| n_epochs: Optional[int] = None # 无论如何重新算 | |||
| @@ -202,9 +202,20 @@ class TorchDriver(Driver): | |||
| sampler = dataloader_args.sampler | |||
| else: | |||
| raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||
| num_consumed_batches = states.pop('num_consumed_batches') | |||
| if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
| states['sampler_states'] = sampler.state_dict() | |||
| sampler_states = sampler.state_dict() | |||
| # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
| # 会造成多余实际消耗的问题。 | |||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
| if num_consumed_samples_array is not None: | |||
| if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||
| try: | |||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
| except: # 有可能 batch_size 为 None,就只有损失精度了 | |||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
| else: | |||
| raise RuntimeError( | |||
| 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||
| @@ -4,16 +4,18 @@ __all__ = [ | |||
| ] | |||
| import math | |||
| from array import array | |||
| from copy import deepcopy | |||
| from typing import Dict, Union, List | |||
| from itertools import chain | |||
| import os | |||
| import numpy as np | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.log import logger | |||
| from .utils import create_array, NumConsumedSamplesArray | |||
| from abc import abstractmethod | |||
| from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||
| class ReproducibleBatchSampler: | |||
| @@ -34,6 +36,13 @@ class ReproducibleBatchSampler: | |||
| @abstractmethod | |||
| def state_dict(self): | |||
| """ | |||
| 由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomBatchSampler 中 states 里面 num_consumed_samples_array 的实现 | |||
| 正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward | |||
| 了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||
| :return: | |||
| """ | |||
| raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") | |||
| @abstractmethod | |||
| @@ -67,7 +76,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| self.batch_size = batch_size | |||
| self.drop_last = drop_last | |||
| self.data_idx = kwargs.get("data_idx", 0) | |||
| self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) | |||
| self.index_list = kwargs.get("index_list", self._iterate_sampler()) | |||
| self.need_reinitialize = kwargs.get("need_reinitialize", False) | |||
| @@ -80,36 +89,40 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| # 说明是在初始化时传入的是一个 sampler,理论上对应于 dataloader 在初始化时没有 batch_size,也没有 batch_sampler 的情况; | |||
| else: | |||
| _index_lst.append(idx) | |||
| # 64 位机器的 unsigned int 为 4 个字节,能表示的最大大小为 4294967295; | |||
| if len(_index_lst) > 4294967295: | |||
| # 注意 self.index_list 内存放的是全部数据的 index; | |||
| # unsigned long | |||
| _index_lst = array("L", _index_lst) | |||
| else: | |||
| # unsigned int | |||
| _index_lst = array("I", _index_lst) | |||
| _index_lst = create_array(len(_index_lst), _index_lst) | |||
| return _index_lst | |||
| def __iter__(self): | |||
| if self.need_reinitialize: | |||
| self.index_list = self._iterate_sampler() | |||
| self.data_idx = 0 | |||
| self.num_consumed_samples = 0 | |||
| else: | |||
| self.need_reinitialize = True | |||
| batch = [] | |||
| if self.data_idx: | |||
| index_list = self.index_list[self.data_idx:] | |||
| if self.num_consumed_samples: | |||
| index_list = self.index_list[self.num_consumed_samples:] | |||
| else: | |||
| index_list = self.index_list | |||
| # 记住每个 batch 对应的 consumed_samples, 需要这个原因是由于现在的 dataloader 都存在预取数据的设计,需要再结合Trainer中 | |||
| # batch_idx_in_epoch 才能最终确定实际消耗的数据。这个变量需要记录每次yield出去时的真实 num_consumed_samples 的数值。 | |||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||
| num_consumed_samples=self.num_consumed_samples) | |||
| for idx in index_list: | |||
| batch.append(idx) | |||
| self.data_idx += 1 | |||
| if len(batch) == self.batch_size: | |||
| self.num_consumed_samples += self.batch_size # [16, 32, 48, 64,..., ] | |||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
| yield batch | |||
| batch = [] | |||
| if len(batch) > 0 and not self.drop_last: | |||
| self.num_consumed_samples += len(batch) | |||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
| yield batch | |||
| # 需要重置防止边界条件问题 | |||
| self.num_consumed_samples = 0 | |||
| delattr(self, 'num_consumed_samples_array') | |||
| def __len__(self) -> int: | |||
| if self.drop_last: | |||
| @@ -118,7 +131,13 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| return (len(self.index_list) + self.batch_size - 1) // self.batch_size | |||
| def state_dict(self) -> Dict: | |||
| return {"index_list": deepcopy(self.index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__} | |||
| states = { | |||
| "index_list": deepcopy(self.index_list), | |||
| "num_consumed_samples": self.num_consumed_samples, | |||
| 'sampler_type': self.__class__.__name__ | |||
| } | |||
| states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||
| return states | |||
| def load_state_dict(self, states: Dict): | |||
| assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | |||
| @@ -128,7 +147,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \ | |||
| "record and current dataset." | |||
| self.index_list = _index_list | |||
| self.data_idx = states["data_idx"] | |||
| self.num_consumed_samples = states["num_consumed_samples"] | |||
| self.need_reinitialize = False | |||
| def set_distributed(self, num_replicas, rank, pad=True): | |||
| @@ -141,10 +160,10 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| @property | |||
| def batch_idx_in_epoch(self): | |||
| if self.drop_last: | |||
| return len(self.index_list) // self.batch_size - (len(self.index_list) - self.data_idx) // self.batch_size | |||
| return len(self.index_list) // self.batch_size - (len(self.index_list) - self.num_consumed_samples) // self.batch_size | |||
| else: | |||
| return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \ | |||
| (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | |||
| (len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size | |||
| class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| @@ -180,7 +199,6 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||
| self.sorted_indices = np.argsort(self.length)[::-1] # 按长度从高到低排序的 | |||
| self.batch_size = batch_size | |||
| self.num_batch_per_bucket = num_batch_per_bucket | |||
| self.shuffle = shuffle | |||
| @@ -212,13 +230,13 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| self.rank = rank | |||
| self.pad = pad | |||
| num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ | |||
| else len(self.dataset) | |||
| if self.drop_last: | |||
| assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \ | |||
| "than the number of replicates multiplied " \ | |||
| "with batch_size when drop_last=True." | |||
| # num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ | |||
| # else len(self.dataset) | |||
| # | |||
| # if self.drop_last: | |||
| # assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \ | |||
| # "than the number of replicates multiplied " \ | |||
| # "with batch_size when drop_last=True." | |||
| return self | |||
| @@ -243,7 +261,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||
| self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||
| def __len__(self): | |||
| def __len__(self)->int: | |||
| """ | |||
| 返回当前 sampler 还会返回多少个 batch 的数据 | |||
| @@ -309,11 +327,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | |||
| batches = batches[:-1] | |||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||
| num_consumed_samples=self.num_consumed_samples) | |||
| for batch in batches: | |||
| self.num_consumed_samples += self.num_replicas * len(batch) | |||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
| yield list(map(int, batch)) | |||
| self.during_iter = False | |||
| self.num_consumed_samples = 0 | |||
| delattr(self, 'num_consumed_samples_array') | |||
| self.old_batch_size = self.batch_size | |||
| self.old_num_batch_per_bucket = self.num_batch_per_bucket | |||
| self.old_num_replicas = self.num_replicas | |||
| @@ -376,10 +398,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| 'num_batch_per_bucket': self.num_batch_per_bucket, | |||
| 'num_replicas': self.num_replicas | |||
| } | |||
| states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||
| return states | |||
| def load_state_dict(self, states: Dict): | |||
| # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||
| # 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; | |||
| assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||
| "during an unfinished iteration." | |||
| @@ -1,9 +1,14 @@ | |||
| from typing import Dict, List, Union | |||
| import math | |||
| import os | |||
| import numpy as np | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||
| from .utils import NumConsumedSamplesArray | |||
| __all__ = [ | |||
| 'ReproducibleSampler', | |||
| @@ -30,6 +35,13 @@ class ReproducibleSampler: | |||
| raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.") | |||
| def state_dict(self): | |||
| """ | |||
| 由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomSampler 中 states 里面 num_consumed_samples_array 的实现 | |||
| 正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward | |||
| 了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||
| :return: | |||
| """ | |||
| raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.") | |||
| def load_state_dict(self, states): | |||
| @@ -109,12 +121,15 @@ class RandomSampler(ReproducibleSampler): | |||
| indices = indices[self.num_consumed_samples:] | |||
| indices = indices[self.rank:len(indices):self.num_replicas] | |||
| assert len(indices) == self.num_left_samples | |||
| for index in indices: | |||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||
| num_consumed_samples=self.num_consumed_samples) | |||
| for idx, index in enumerate(indices, start=1): | |||
| self.num_consumed_samples += self.num_replicas | |||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
| yield index | |||
| self.during_iter = False | |||
| self.num_consumed_samples = 0 | |||
| delattr(self, 'num_consumed_samples_array') | |||
| def generate_indices(self) -> List[int]: | |||
| """ | |||
| @@ -134,18 +149,13 @@ class RandomSampler(ReproducibleSampler): | |||
| return indices | |||
| def state_dict(self) -> Dict: | |||
| states = { | |||
| 'seed': self.seed, | |||
| 'epoch': self.epoch, | |||
| 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||
| 'sampler_type': self.__class__.__name__, | |||
| 'length': len(self.dataset), | |||
| 'shuffle': self.shuffle | |||
| } | |||
| states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
| 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||
| 'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||
| return states | |||
| def load_state_dict(self, states: Dict): | |||
| # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||
| # 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; | |||
| assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||
| "during an unfinished iteration." | |||
| @@ -158,7 +168,7 @@ class RandomSampler(ReproducibleSampler): | |||
| self.seed = states['seed'] | |||
| self.epoch = states['epoch'] | |||
| self.num_consumed_samples = states['num_consumed_samples'] | |||
| if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||
| if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||
| self.num_consumed_samples = 0 | |||
| if self.shuffle != states['shuffle']: | |||
| logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | |||
| @@ -245,11 +255,15 @@ class SequentialSampler(RandomSampler): | |||
| indices = indices[self.rank:len(indices):self.num_replicas] | |||
| assert len(indices) == self.num_left_samples | |||
| for index in indices: | |||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||
| num_consumed_samples=self.num_consumed_samples) | |||
| for idx, index in enumerate(indices, start=1): | |||
| self.num_consumed_samples += self.num_replicas | |||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
| yield index | |||
| self.during_iter = False | |||
| self.num_consumed_samples = 0 | |||
| delattr(self, 'num_consumed_samples_array') | |||
| def generate_indices(self) -> List[int]: | |||
| """ | |||
| @@ -260,15 +274,13 @@ class SequentialSampler(RandomSampler): | |||
| return list(range(len(self.dataset))) | |||
| def state_dict(self) -> Dict: | |||
| states = { | |||
| 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||
| 'sampler_type': self.__class__.__name__, | |||
| 'length': len(self.dataset), | |||
| } | |||
| states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | |||
| 'length': len(self.dataset), | |||
| 'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||
| return states | |||
| def load_state_dict(self, states: Dict): | |||
| # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||
| # 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; | |||
| assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||
| "during an unfinished iteration." | |||
| @@ -334,9 +346,13 @@ class SortedSampler(SequentialSampler): | |||
| indices = indices[self.rank:len(indices):self.num_replicas] | |||
| assert len(indices) == self.num_left_samples | |||
| for index in indices: | |||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||
| num_consumed_samples=self.num_consumed_samples) | |||
| for idx, index in enumerate(indices, start=1): | |||
| self.num_consumed_samples += self.num_replicas | |||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
| yield index | |||
| self.during_iter = False | |||
| self.num_consumed_samples = 0 | |||
| delattr(self, 'num_consumed_samples_array') | |||
| @@ -2,6 +2,9 @@ __all__ = [ | |||
| 're_instantiate_sampler', | |||
| 'conversion_between_reproducible_and_unrepeated_sampler' | |||
| ] | |||
| from array import array | |||
| from typing import Sequence | |||
| from collections import deque | |||
| from fastNLP.core.samplers.unrepeated_sampler import * | |||
| from fastNLP.core.samplers.reproducible_sampler import * | |||
| @@ -39,4 +42,56 @@ def re_instantiate_sampler(sampler, new_sampler_class=None): | |||
| all_attributes = vars(sampler) | |||
| if new_sampler_class is not None: | |||
| return new_sampler_class(**all_attributes) | |||
| return type(sampler)(**all_attributes) | |||
| return type(sampler)(**all_attributes) | |||
| def create_array(length, fill_value) -> array: | |||
| """ | |||
| 根据长度自动创建 array ,超过 4294967295 需要使用 'L', 否则使用 'I' | |||
| :param length: | |||
| :param fill_value: | |||
| :return: | |||
| """ | |||
| if not isinstance(fill_value, Sequence): | |||
| fill_value = [fill_value]*length | |||
| if length > 4294967295: | |||
| _index_lst = array("L", fill_value) | |||
| else: | |||
| _index_lst = array("I", fill_value) | |||
| return _index_lst | |||
| class NumConsumedSamplesArray: | |||
| def __init__(self, buffer_size=2000, num_consumed_samples=0): | |||
| """ | |||
| 保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少 | |||
| ex: | |||
| array = NumConsumedSamplesArray(buffer_size=3) | |||
| for i in range(10): | |||
| array.push(i) | |||
| array[9] # 输出为9,表示这个位置真实的 num_consumed_samples 是多少。 | |||
| array[6] # 报错,因为只保留了3个最近的数据,6超过了最大buffer的记录了,即 [7, 8, 9] | |||
| :param buffer_size: 报错多少个历史。 | |||
| :param num_consumed_samples: 第一个 num_consumed_samples 是多少。 | |||
| """ | |||
| self.count = 0 | |||
| self.deque = deque(maxlen=buffer_size) | |||
| if num_consumed_samples is not None: | |||
| self.push(num_consumed_samples) | |||
| self.buffer_size = buffer_size | |||
| def __getitem__(self, item): | |||
| if len(self.deque) == 0: # 如果没有任何缓存的内容,说明还没有写入,直接返回0 | |||
| return 0 | |||
| assert isinstance(item, int), "Only int index allowed." | |||
| assert self.count-len(self.deque)<=item<self.count, f"Only keep {len(self.deque)} history index." | |||
| index = len(self.deque) - (self.count - item) | |||
| return self.deque[index] | |||
| def push(self, num_consumed_samples): | |||
| self.deque.append(num_consumed_samples) | |||
| self.count += 1 | |||
| @@ -45,6 +45,8 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' | |||
| # todo 注释 | |||
| FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | |||
| # fastNLP 中初始化deque的默认大小 | |||
| FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE' | |||
| # todo 注释 直接使用的变量 | |||
| FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | |||
| @@ -3,6 +3,7 @@ from array import array | |||
| import numpy as np | |||
| import pytest | |||
| from itertools import chain | |||
| from copy import deepcopy | |||
| from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | |||
| from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||
| @@ -30,7 +31,7 @@ class TestReproducibleBatchSampler: | |||
| _get_re_batchsampler = dataloader.batch_sampler | |||
| assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||
| state = _get_re_batchsampler.state_dict() | |||
| assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | |||
| assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, | |||
| "sampler_type": "RandomBatchSampler"} | |||
| # 2. 断点重训,重新生成一个 dataloader; | |||
| @@ -413,26 +414,102 @@ class TestBucketedBatchSampler: | |||
| @pytest.mark.parametrize('drop_last', [True, False]) | |||
| @pytest.mark.parametrize('pad', [True, False]) | |||
| @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||
| def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica): | |||
| # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2): | |||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||
| def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||
| # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2): | |||
| dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
| batch_size = 6 | |||
| if num_replica*batch_size > num_samples: | |||
| if num_replicas*batch_size > num_samples: | |||
| return | |||
| num_batch_per_bucket = 10 | |||
| samplers = [] | |||
| lengths = [] | |||
| for i in range(num_replica): | |||
| for i in range(num_replicas): | |||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | |||
| sampler.set_distributed(num_replica, rank=i, pad=pad) | |||
| sampler.set_distributed(num_replicas, rank=i, pad=pad) | |||
| sampler.set_epoch(0) | |||
| samplers.append(sampler) | |||
| lengths.append(len(list(iter(sampler)))) | |||
| assert len(set(lengths))==1 | |||
| bucket_diff = batch_size * num_batch_per_bucket * num_replica | |||
| bucket_diff = batch_size * num_batch_per_bucket * num_replicas | |||
| for bs in zip(*samplers): | |||
| diff = max(chain(*bs)) - min(chain(*bs)) | |||
| assert diff <= bucket_diff | |||
| @pytest.mark.parametrize('shuffle', [True, False]) | |||
| @pytest.mark.parametrize('drop_last', [True, False]) | |||
| @pytest.mark.parametrize('pad', [True, False]) | |||
| @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||
| @pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||
| def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||
| """ | |||
| 测试是否能够正确地恢复使用过的(forward)数据,由于 DataLoader 存在预取,所以 Sampler 自身的 num_consumed_samples 可能 | |||
| 偏多 | |||
| :return: | |||
| """ | |||
| batch_size = 6 | |||
| num_batch_per_bucket = 10 | |||
| dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
| samplers = [] | |||
| for i in range(num_replicas): | |||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | |||
| sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) | |||
| samplers.append(sampler) | |||
| count = 0 | |||
| already_seen_sets = [set()] | |||
| already_seen_set = set() | |||
| for batchs in zip(*samplers): | |||
| batch = chain(*batchs) | |||
| already_seen_set.update(batch) | |||
| already_seen_sets.append(deepcopy(already_seen_set)) | |||
| count += 1 | |||
| if count > 3: | |||
| break | |||
| states = samplers[0].state_dict() | |||
| for i in range(len(already_seen_sets)): | |||
| if states['num_consumed_samples_array'] is not None: | |||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1, | |||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||
| drop_last=drop_last) | |||
| sampler.set_epoch(0) | |||
| already_seen_set = deepcopy(already_seen_sets[i]) | |||
| for batch in sampler: | |||
| already_seen_set.update(batch) | |||
| assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len( | |||
| dataset) | |||
| # 测试保存之后再次保存 | |||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1, | |||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||
| drop_last=drop_last) | |||
| sampler.set_epoch(0) | |||
| if states['num_consumed_samples_array'] is not None: | |||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][2] | |||
| if len(already_seen_sets)<3: | |||
| return | |||
| already_seen_set = already_seen_sets[2] | |||
| count = 0 | |||
| for batch in sampler: | |||
| already_seen_set.update(batch) | |||
| count += 1 | |||
| if count > 6: | |||
| break | |||
| states = sampler.state_dict() | |||
| if states['num_consumed_samples_array'] is not None: | |||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2, | |||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||
| drop_last=drop_last) | |||
| sampler.load_state_dict(states) | |||
| sampler.set_epoch(0) | |||
| for batch in sampler: | |||
| already_seen_set.update(batch) | |||
| assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) | |||
| @@ -3,6 +3,7 @@ import pytest | |||
| from functools import partial | |||
| from itertools import chain | |||
| from copy import deepcopy | |||
| from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler | |||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
| @@ -180,6 +181,63 @@ class TestRandomSamplerYh: | |||
| assert seen <= 1 if pad else seen == 0 | |||
| assert seen_in_other_rank<=1 # 因为pad可能重复 | |||
| @pytest.mark.parametrize('shuffle', [True, False]) | |||
| @pytest.mark.parametrize('pad', [True, False]) | |||
| @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||
| @pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||
| def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas): | |||
| # 测试在 sampler 多生成的时候,可以仍然可以恢复 | |||
| dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
| samplers = [] | |||
| for i in range(num_replicas): | |||
| sampler = RandomSampler(dataset, shuffle=shuffle) | |||
| sampler.set_epoch(0) | |||
| sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) | |||
| samplers.append(sampler) | |||
| count = 0 | |||
| already_seen_sets = [set()] | |||
| already_seen_set = set() | |||
| for idxes in zip(*samplers): | |||
| already_seen_set.update(idxes) | |||
| already_seen_sets.append(deepcopy(already_seen_set)) | |||
| count += 1 | |||
| if count > 3: | |||
| break | |||
| states = samplers[0].state_dict() | |||
| for i in range(len(already_seen_sets)): | |||
| if states['num_consumed_samples_array'] is not None: | |||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||
| sampler = RandomSampler(dataset, shuffle=shuffle) | |||
| already_seen_set = deepcopy(already_seen_sets[i]) | |||
| for batch in sampler: | |||
| already_seen_set.add(batch) | |||
| assert len(already_seen_set) == len(dataset) | |||
| # 测试保存之后再次保存 | |||
| sampler = RandomSampler(dataset, shuffle=shuffle) | |||
| sampler.set_epoch(0) | |||
| if states['num_consumed_samples_array'] is not None: | |||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][2] | |||
| if len(already_seen_sets)<3: | |||
| return | |||
| already_seen_set = already_seen_sets[2] | |||
| count = 0 | |||
| for idx in sampler: | |||
| already_seen_set.add(idx) | |||
| count += 1 | |||
| if count > 6: | |||
| break | |||
| states = sampler.state_dict() | |||
| if states['num_consumed_samples_array'] is not None: | |||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||
| sampler = RandomSampler(dataset, shuffle=shuffle) | |||
| sampler.load_state_dict(states) | |||
| sampler.set_epoch(0) | |||
| for idx in sampler: | |||
| already_seen_set.add(idx) | |||
| assert len(already_seen_set)==len(dataset) | |||
| class TestRandomSampler: | |||
| # 测试单卡; | |||
| @@ -386,7 +444,7 @@ class TestSortedSampler: | |||
| assert indexes==list(range(num_of_data-1, -1, -1)) | |||
| @pytest.mark.parametrize('pad', [True, False]) | |||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
| def test_multi(self, pad, num_replica, num_of_data): | |||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
| @@ -540,7 +598,7 @@ class TestSequentialSampler: | |||
| assert indexes==list(range(num_of_data)) | |||
| @pytest.mark.parametrize('pad', [True, False]) | |||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
| def test_multi(self, pad, num_replica, num_of_data): | |||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
| @@ -25,7 +25,7 @@ class TestUnrepeatedSampler: | |||
| indexes = set(sampler) | |||
| assert indexes==set(range(num_of_data)) | |||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
| @pytest.mark.parametrize('shuffle', [False, True]) | |||
| def test_multi(self, num_replica, num_of_data, shuffle): | |||
| @@ -50,7 +50,7 @@ class TestUnrepeatedSortedSampler: | |||
| indexes = list(sampler) | |||
| assert indexes==list(range(num_of_data-1, -1, -1)) | |||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
| def test_multi(self, num_replica, num_of_data): | |||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
| @@ -81,7 +81,7 @@ class TestUnrepeatedSequentialSampler: | |||
| indexes = list(sampler) | |||
| assert indexes==list(range(num_of_data)) | |||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
| def test_multi(self, num_replica, num_of_data): | |||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | |||