@@ -364,16 +364,16 @@ class _MetricsWrapper: | |||
else: | |||
args.append(batch) | |||
if not isinstance(outputs, dict): | |||
raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly" | |||
raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" | |||
f" return a dict from your model or use `output_mapping` to convert it into dict type.") | |||
if isinstance(metric, Metric): | |||
auto_param_call(metric.update, batch, *args) | |||
auto_param_call(metric.update, outputs, *args) | |||
elif _is_torchmetrics_metric(metric): | |||
auto_param_call(metric.update, batch, *args) | |||
auto_param_call(metric.update, outputs, *args) | |||
elif _is_allennlp_metric(metric): | |||
auto_param_call(metric.__call__, batch, *args) | |||
auto_param_call(metric.__call__, outputs, *args) | |||
elif _is_paddle_metric(metric): | |||
res = auto_param_call(metric.compute, batch, *args) | |||
res = auto_param_call(metric.compute, outputs, *args) | |||
metric.update(res) | |||
def reset(self): | |||
@@ -106,7 +106,7 @@ class Trainer(TrainerEventTrigger): | |||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | |||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | |||
如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||
为 True,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; | |||
为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `validate_step` 和 `test_step`; | |||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | |||
:param fp16: 是否开启混合精度训练;默认为 False; | |||
:param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | |||
@@ -326,6 +326,8 @@ class Trainer(TrainerEventTrigger): | |||
try: | |||
while self.cur_epoch_idx < self.n_epochs: | |||
# 这个是防止在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||
self.driver.set_model_mode("train") | |||
self.on_train_epoch_begin() | |||
self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx) | |||
@@ -603,7 +605,9 @@ class Trainer(TrainerEventTrigger): | |||
# 1. callback states 和 每一个callback的具体 callback 函数的 filter 的状态; | |||
# 2. trainer_state; | |||
states = {"callback_states": self.on_save_checkpoint(), | |||
"trainer_state": self.trainer_state.state_dict()} | |||
"trainer_state": self.trainer_state.state_dict(), | |||
'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0) | |||
} | |||
# 3. validate filter state; | |||
if self.evaluator is not None: | |||
@@ -680,9 +684,13 @@ class Trainer(TrainerEventTrigger): | |||
# 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | |||
# '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | |||
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | |||
self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \ | |||
self.batch_idx_in_epoch | |||
# 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||
# 5. 恢复所有 callback 的状态; | |||
self.train_stepeckpoint(states["callback_states"]) | |||
self.on_load_checkpoint(states["callback_states"]) | |||
self.driver.barrier() | |||
@@ -60,7 +60,7 @@ class TrainerState: | |||
cur_epoch_idx: 当前正在运行第几个 epoch; | |||
global_forward_batches: 当前模型总共 forward 了多少个 step; | |||
batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | |||
total_batches: 每一个 epoch 会 forward 多少个 step; | |||
num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | |||
total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs; | |||
""" | |||
n_epochs: Optional[int] = None # 无论如何重新算 | |||
@@ -202,9 +202,20 @@ class TorchDriver(Driver): | |||
sampler = dataloader_args.sampler | |||
else: | |||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||
num_consumed_batches = states.pop('num_consumed_batches') | |||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
states['sampler_states'] = sampler.state_dict() | |||
sampler_states = sampler.state_dict() | |||
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
# 会造成多余实际消耗的问题。 | |||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
if num_consumed_samples_array is not None: | |||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||
try: | |||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
except: # 有可能 batch_size 为 None,就只有损失精度了 | |||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
else: | |||
raise RuntimeError( | |||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||
@@ -4,16 +4,18 @@ __all__ = [ | |||
] | |||
import math | |||
from array import array | |||
from copy import deepcopy | |||
from typing import Dict, Union, List | |||
from itertools import chain | |||
import os | |||
import numpy as np | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.log import logger | |||
from .utils import create_array, NumConsumedSamplesArray | |||
from abc import abstractmethod | |||
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||
class ReproducibleBatchSampler: | |||
@@ -34,6 +36,13 @@ class ReproducibleBatchSampler: | |||
@abstractmethod | |||
def state_dict(self): | |||
""" | |||
由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomBatchSampler 中 states 里面 num_consumed_samples_array 的实现 | |||
正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward | |||
了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||
:return: | |||
""" | |||
raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") | |||
@abstractmethod | |||
@@ -67,7 +76,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
self.batch_size = batch_size | |||
self.drop_last = drop_last | |||
self.data_idx = kwargs.get("data_idx", 0) | |||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) | |||
self.index_list = kwargs.get("index_list", self._iterate_sampler()) | |||
self.need_reinitialize = kwargs.get("need_reinitialize", False) | |||
@@ -80,36 +89,40 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
# 说明是在初始化时传入的是一个 sampler,理论上对应于 dataloader 在初始化时没有 batch_size,也没有 batch_sampler 的情况; | |||
else: | |||
_index_lst.append(idx) | |||
# 64 位机器的 unsigned int 为 4 个字节,能表示的最大大小为 4294967295; | |||
if len(_index_lst) > 4294967295: | |||
# 注意 self.index_list 内存放的是全部数据的 index; | |||
# unsigned long | |||
_index_lst = array("L", _index_lst) | |||
else: | |||
# unsigned int | |||
_index_lst = array("I", _index_lst) | |||
_index_lst = create_array(len(_index_lst), _index_lst) | |||
return _index_lst | |||
def __iter__(self): | |||
if self.need_reinitialize: | |||
self.index_list = self._iterate_sampler() | |||
self.data_idx = 0 | |||
self.num_consumed_samples = 0 | |||
else: | |||
self.need_reinitialize = True | |||
batch = [] | |||
if self.data_idx: | |||
index_list = self.index_list[self.data_idx:] | |||
if self.num_consumed_samples: | |||
index_list = self.index_list[self.num_consumed_samples:] | |||
else: | |||
index_list = self.index_list | |||
# 记住每个 batch 对应的 consumed_samples, 需要这个原因是由于现在的 dataloader 都存在预取数据的设计,需要再结合Trainer中 | |||
# batch_idx_in_epoch 才能最终确定实际消耗的数据。这个变量需要记录每次yield出去时的真实 num_consumed_samples 的数值。 | |||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||
num_consumed_samples=self.num_consumed_samples) | |||
for idx in index_list: | |||
batch.append(idx) | |||
self.data_idx += 1 | |||
if len(batch) == self.batch_size: | |||
self.num_consumed_samples += self.batch_size # [16, 32, 48, 64,..., ] | |||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
yield batch | |||
batch = [] | |||
if len(batch) > 0 and not self.drop_last: | |||
self.num_consumed_samples += len(batch) | |||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
yield batch | |||
# 需要重置防止边界条件问题 | |||
self.num_consumed_samples = 0 | |||
delattr(self, 'num_consumed_samples_array') | |||
def __len__(self) -> int: | |||
if self.drop_last: | |||
@@ -118,7 +131,13 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
return (len(self.index_list) + self.batch_size - 1) // self.batch_size | |||
def state_dict(self) -> Dict: | |||
return {"index_list": deepcopy(self.index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__} | |||
states = { | |||
"index_list": deepcopy(self.index_list), | |||
"num_consumed_samples": self.num_consumed_samples, | |||
'sampler_type': self.__class__.__name__ | |||
} | |||
states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||
return states | |||
def load_state_dict(self, states: Dict): | |||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | |||
@@ -128,7 +147,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \ | |||
"record and current dataset." | |||
self.index_list = _index_list | |||
self.data_idx = states["data_idx"] | |||
self.num_consumed_samples = states["num_consumed_samples"] | |||
self.need_reinitialize = False | |||
def set_distributed(self, num_replicas, rank, pad=True): | |||
@@ -141,10 +160,10 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
@property | |||
def batch_idx_in_epoch(self): | |||
if self.drop_last: | |||
return len(self.index_list) // self.batch_size - (len(self.index_list) - self.data_idx) // self.batch_size | |||
return len(self.index_list) // self.batch_size - (len(self.index_list) - self.num_consumed_samples) // self.batch_size | |||
else: | |||
return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \ | |||
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | |||
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size | |||
class BucketedBatchSampler(ReproducibleBatchSampler): | |||
@@ -180,7 +199,6 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||
self.sorted_indices = np.argsort(self.length)[::-1] # 按长度从高到低排序的 | |||
self.batch_size = batch_size | |||
self.num_batch_per_bucket = num_batch_per_bucket | |||
self.shuffle = shuffle | |||
@@ -212,13 +230,13 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
self.rank = rank | |||
self.pad = pad | |||
num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ | |||
else len(self.dataset) | |||
if self.drop_last: | |||
assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \ | |||
"than the number of replicates multiplied " \ | |||
"with batch_size when drop_last=True." | |||
# num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ | |||
# else len(self.dataset) | |||
# | |||
# if self.drop_last: | |||
# assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \ | |||
# "than the number of replicates multiplied " \ | |||
# "with batch_size when drop_last=True." | |||
return self | |||
@@ -243,7 +261,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||
def __len__(self): | |||
def __len__(self)->int: | |||
""" | |||
返回当前 sampler 还会返回多少个 batch 的数据 | |||
@@ -309,11 +327,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | |||
batches = batches[:-1] | |||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||
num_consumed_samples=self.num_consumed_samples) | |||
for batch in batches: | |||
self.num_consumed_samples += self.num_replicas * len(batch) | |||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
yield list(map(int, batch)) | |||
self.during_iter = False | |||
self.num_consumed_samples = 0 | |||
delattr(self, 'num_consumed_samples_array') | |||
self.old_batch_size = self.batch_size | |||
self.old_num_batch_per_bucket = self.num_batch_per_bucket | |||
self.old_num_replicas = self.num_replicas | |||
@@ -376,10 +398,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
'num_batch_per_bucket': self.num_batch_per_bucket, | |||
'num_replicas': self.num_replicas | |||
} | |||
states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||
return states | |||
def load_state_dict(self, states: Dict): | |||
# 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||
# 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; | |||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||
"during an unfinished iteration." | |||
@@ -1,9 +1,14 @@ | |||
from typing import Dict, List, Union | |||
import math | |||
import os | |||
import numpy as np | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||
from .utils import NumConsumedSamplesArray | |||
__all__ = [ | |||
'ReproducibleSampler', | |||
@@ -30,6 +35,13 @@ class ReproducibleSampler: | |||
raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.") | |||
def state_dict(self): | |||
""" | |||
由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomSampler 中 states 里面 num_consumed_samples_array 的实现 | |||
正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward | |||
了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||
:return: | |||
""" | |||
raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.") | |||
def load_state_dict(self, states): | |||
@@ -109,12 +121,15 @@ class RandomSampler(ReproducibleSampler): | |||
indices = indices[self.num_consumed_samples:] | |||
indices = indices[self.rank:len(indices):self.num_replicas] | |||
assert len(indices) == self.num_left_samples | |||
for index in indices: | |||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||
num_consumed_samples=self.num_consumed_samples) | |||
for idx, index in enumerate(indices, start=1): | |||
self.num_consumed_samples += self.num_replicas | |||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
yield index | |||
self.during_iter = False | |||
self.num_consumed_samples = 0 | |||
delattr(self, 'num_consumed_samples_array') | |||
def generate_indices(self) -> List[int]: | |||
""" | |||
@@ -134,18 +149,13 @@ class RandomSampler(ReproducibleSampler): | |||
return indices | |||
def state_dict(self) -> Dict: | |||
states = { | |||
'seed': self.seed, | |||
'epoch': self.epoch, | |||
'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||
'sampler_type': self.__class__.__name__, | |||
'length': len(self.dataset), | |||
'shuffle': self.shuffle | |||
} | |||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||
return states | |||
def load_state_dict(self, states: Dict): | |||
# 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||
# 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; | |||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||
"during an unfinished iteration." | |||
@@ -158,7 +168,7 @@ class RandomSampler(ReproducibleSampler): | |||
self.seed = states['seed'] | |||
self.epoch = states['epoch'] | |||
self.num_consumed_samples = states['num_consumed_samples'] | |||
if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||
if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||
self.num_consumed_samples = 0 | |||
if self.shuffle != states['shuffle']: | |||
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | |||
@@ -245,11 +255,15 @@ class SequentialSampler(RandomSampler): | |||
indices = indices[self.rank:len(indices):self.num_replicas] | |||
assert len(indices) == self.num_left_samples | |||
for index in indices: | |||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||
num_consumed_samples=self.num_consumed_samples) | |||
for idx, index in enumerate(indices, start=1): | |||
self.num_consumed_samples += self.num_replicas | |||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
yield index | |||
self.during_iter = False | |||
self.num_consumed_samples = 0 | |||
delattr(self, 'num_consumed_samples_array') | |||
def generate_indices(self) -> List[int]: | |||
""" | |||
@@ -260,15 +274,13 @@ class SequentialSampler(RandomSampler): | |||
return list(range(len(self.dataset))) | |||
def state_dict(self) -> Dict: | |||
states = { | |||
'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||
'sampler_type': self.__class__.__name__, | |||
'length': len(self.dataset), | |||
} | |||
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | |||
'length': len(self.dataset), | |||
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||
return states | |||
def load_state_dict(self, states: Dict): | |||
# 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||
# 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; | |||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||
"during an unfinished iteration." | |||
@@ -334,9 +346,13 @@ class SortedSampler(SequentialSampler): | |||
indices = indices[self.rank:len(indices):self.num_replicas] | |||
assert len(indices) == self.num_left_samples | |||
for index in indices: | |||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||
num_consumed_samples=self.num_consumed_samples) | |||
for idx, index in enumerate(indices, start=1): | |||
self.num_consumed_samples += self.num_replicas | |||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
yield index | |||
self.during_iter = False | |||
self.num_consumed_samples = 0 | |||
delattr(self, 'num_consumed_samples_array') | |||
@@ -2,6 +2,9 @@ __all__ = [ | |||
're_instantiate_sampler', | |||
'conversion_between_reproducible_and_unrepeated_sampler' | |||
] | |||
from array import array | |||
from typing import Sequence | |||
from collections import deque | |||
from fastNLP.core.samplers.unrepeated_sampler import * | |||
from fastNLP.core.samplers.reproducible_sampler import * | |||
@@ -39,4 +42,56 @@ def re_instantiate_sampler(sampler, new_sampler_class=None): | |||
all_attributes = vars(sampler) | |||
if new_sampler_class is not None: | |||
return new_sampler_class(**all_attributes) | |||
return type(sampler)(**all_attributes) | |||
return type(sampler)(**all_attributes) | |||
def create_array(length, fill_value) -> array: | |||
""" | |||
根据长度自动创建 array ,超过 4294967295 需要使用 'L', 否则使用 'I' | |||
:param length: | |||
:param fill_value: | |||
:return: | |||
""" | |||
if not isinstance(fill_value, Sequence): | |||
fill_value = [fill_value]*length | |||
if length > 4294967295: | |||
_index_lst = array("L", fill_value) | |||
else: | |||
_index_lst = array("I", fill_value) | |||
return _index_lst | |||
class NumConsumedSamplesArray: | |||
def __init__(self, buffer_size=2000, num_consumed_samples=0): | |||
""" | |||
保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少 | |||
ex: | |||
array = NumConsumedSamplesArray(buffer_size=3) | |||
for i in range(10): | |||
array.push(i) | |||
array[9] # 输出为9,表示这个位置真实的 num_consumed_samples 是多少。 | |||
array[6] # 报错,因为只保留了3个最近的数据,6超过了最大buffer的记录了,即 [7, 8, 9] | |||
:param buffer_size: 报错多少个历史。 | |||
:param num_consumed_samples: 第一个 num_consumed_samples 是多少。 | |||
""" | |||
self.count = 0 | |||
self.deque = deque(maxlen=buffer_size) | |||
if num_consumed_samples is not None: | |||
self.push(num_consumed_samples) | |||
self.buffer_size = buffer_size | |||
def __getitem__(self, item): | |||
if len(self.deque) == 0: # 如果没有任何缓存的内容,说明还没有写入,直接返回0 | |||
return 0 | |||
assert isinstance(item, int), "Only int index allowed." | |||
assert self.count-len(self.deque)<=item<self.count, f"Only keep {len(self.deque)} history index." | |||
index = len(self.deque) - (self.count - item) | |||
return self.deque[index] | |||
def push(self, num_consumed_samples): | |||
self.deque.append(num_consumed_samples) | |||
self.count += 1 |
@@ -45,6 +45,8 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' | |||
# todo 注释 | |||
FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | |||
# fastNLP 中初始化deque的默认大小 | |||
FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE' | |||
# todo 注释 直接使用的变量 | |||
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | |||
@@ -3,6 +3,7 @@ from array import array | |||
import numpy as np | |||
import pytest | |||
from itertools import chain | |||
from copy import deepcopy | |||
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | |||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||
@@ -30,7 +31,7 @@ class TestReproducibleBatchSampler: | |||
_get_re_batchsampler = dataloader.batch_sampler | |||
assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||
state = _get_re_batchsampler.state_dict() | |||
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | |||
assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, | |||
"sampler_type": "RandomBatchSampler"} | |||
# 2. 断点重训,重新生成一个 dataloader; | |||
@@ -413,26 +414,102 @@ class TestBucketedBatchSampler: | |||
@pytest.mark.parametrize('drop_last', [True, False]) | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica): | |||
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2): | |||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2): | |||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
batch_size = 6 | |||
if num_replica*batch_size > num_samples: | |||
if num_replicas*batch_size > num_samples: | |||
return | |||
num_batch_per_bucket = 10 | |||
samplers = [] | |||
lengths = [] | |||
for i in range(num_replica): | |||
for i in range(num_replicas): | |||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | |||
sampler.set_distributed(num_replica, rank=i, pad=pad) | |||
sampler.set_distributed(num_replicas, rank=i, pad=pad) | |||
sampler.set_epoch(0) | |||
samplers.append(sampler) | |||
lengths.append(len(list(iter(sampler)))) | |||
assert len(set(lengths))==1 | |||
bucket_diff = batch_size * num_batch_per_bucket * num_replica | |||
bucket_diff = batch_size * num_batch_per_bucket * num_replicas | |||
for bs in zip(*samplers): | |||
diff = max(chain(*bs)) - min(chain(*bs)) | |||
assert diff <= bucket_diff | |||
@pytest.mark.parametrize('shuffle', [True, False]) | |||
@pytest.mark.parametrize('drop_last', [True, False]) | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||
@pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||
""" | |||
测试是否能够正确地恢复使用过的(forward)数据,由于 DataLoader 存在预取,所以 Sampler 自身的 num_consumed_samples 可能 | |||
偏多 | |||
:return: | |||
""" | |||
batch_size = 6 | |||
num_batch_per_bucket = 10 | |||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
samplers = [] | |||
for i in range(num_replicas): | |||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | |||
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) | |||
samplers.append(sampler) | |||
count = 0 | |||
already_seen_sets = [set()] | |||
already_seen_set = set() | |||
for batchs in zip(*samplers): | |||
batch = chain(*batchs) | |||
already_seen_set.update(batch) | |||
already_seen_sets.append(deepcopy(already_seen_set)) | |||
count += 1 | |||
if count > 3: | |||
break | |||
states = samplers[0].state_dict() | |||
for i in range(len(already_seen_sets)): | |||
if states['num_consumed_samples_array'] is not None: | |||
states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1, | |||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||
drop_last=drop_last) | |||
sampler.set_epoch(0) | |||
already_seen_set = deepcopy(already_seen_sets[i]) | |||
for batch in sampler: | |||
already_seen_set.update(batch) | |||
assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len( | |||
dataset) | |||
# 测试保存之后再次保存 | |||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1, | |||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||
drop_last=drop_last) | |||
sampler.set_epoch(0) | |||
if states['num_consumed_samples_array'] is not None: | |||
states['num_consumed_samples'] = states['num_consumed_samples_array'][2] | |||
if len(already_seen_sets)<3: | |||
return | |||
already_seen_set = already_seen_sets[2] | |||
count = 0 | |||
for batch in sampler: | |||
already_seen_set.update(batch) | |||
count += 1 | |||
if count > 6: | |||
break | |||
states = sampler.state_dict() | |||
if states['num_consumed_samples_array'] is not None: | |||
states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2, | |||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||
drop_last=drop_last) | |||
sampler.load_state_dict(states) | |||
sampler.set_epoch(0) | |||
for batch in sampler: | |||
already_seen_set.update(batch) | |||
assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) |
@@ -3,6 +3,7 @@ import pytest | |||
from functools import partial | |||
from itertools import chain | |||
from copy import deepcopy | |||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
@@ -180,6 +181,63 @@ class TestRandomSamplerYh: | |||
assert seen <= 1 if pad else seen == 0 | |||
assert seen_in_other_rank<=1 # 因为pad可能重复 | |||
@pytest.mark.parametrize('shuffle', [True, False]) | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||
@pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||
def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas): | |||
# 测试在 sampler 多生成的时候,可以仍然可以恢复 | |||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
samplers = [] | |||
for i in range(num_replicas): | |||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||
sampler.set_epoch(0) | |||
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) | |||
samplers.append(sampler) | |||
count = 0 | |||
already_seen_sets = [set()] | |||
already_seen_set = set() | |||
for idxes in zip(*samplers): | |||
already_seen_set.update(idxes) | |||
already_seen_sets.append(deepcopy(already_seen_set)) | |||
count += 1 | |||
if count > 3: | |||
break | |||
states = samplers[0].state_dict() | |||
for i in range(len(already_seen_sets)): | |||
if states['num_consumed_samples_array'] is not None: | |||
states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||
already_seen_set = deepcopy(already_seen_sets[i]) | |||
for batch in sampler: | |||
already_seen_set.add(batch) | |||
assert len(already_seen_set) == len(dataset) | |||
# 测试保存之后再次保存 | |||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||
sampler.set_epoch(0) | |||
if states['num_consumed_samples_array'] is not None: | |||
states['num_consumed_samples'] = states['num_consumed_samples_array'][2] | |||
if len(already_seen_sets)<3: | |||
return | |||
already_seen_set = already_seen_sets[2] | |||
count = 0 | |||
for idx in sampler: | |||
already_seen_set.add(idx) | |||
count += 1 | |||
if count > 6: | |||
break | |||
states = sampler.state_dict() | |||
if states['num_consumed_samples_array'] is not None: | |||
states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||
sampler.load_state_dict(states) | |||
sampler.set_epoch(0) | |||
for idx in sampler: | |||
already_seen_set.add(idx) | |||
assert len(already_seen_set)==len(dataset) | |||
class TestRandomSampler: | |||
# 测试单卡; | |||
@@ -386,7 +444,7 @@ class TestSortedSampler: | |||
assert indexes==list(range(num_of_data-1, -1, -1)) | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
def test_multi(self, pad, num_replica, num_of_data): | |||
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
@@ -540,7 +598,7 @@ class TestSequentialSampler: | |||
assert indexes==list(range(num_of_data)) | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
def test_multi(self, pad, num_replica, num_of_data): | |||
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
@@ -25,7 +25,7 @@ class TestUnrepeatedSampler: | |||
indexes = set(sampler) | |||
assert indexes==set(range(num_of_data)) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
@pytest.mark.parametrize('shuffle', [False, True]) | |||
def test_multi(self, num_replica, num_of_data, shuffle): | |||
@@ -50,7 +50,7 @@ class TestUnrepeatedSortedSampler: | |||
indexes = list(sampler) | |||
assert indexes==list(range(num_of_data-1, -1, -1)) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
def test_multi(self, num_replica, num_of_data): | |||
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
@@ -81,7 +81,7 @@ class TestUnrepeatedSequentialSampler: | |||
indexes = list(sampler) | |||
assert indexes==list(range(num_of_data)) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
def test_multi(self, num_replica, num_of_data): | |||
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||