
Fix: Trainer could not load and save state accurately when resuming training from a checkpoint

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
9d71170bef
11 changed files with 325 additions and 74 deletions
  1. +5 -5  fastNLP/core/controllers/evaluator.py
  2. +12 -4  fastNLP/core/controllers/trainer.py
  3. +1 -1  fastNLP/core/controllers/utils/state.py
  4. +13 -2  fastNLP/core/drivers/torch_driver/torch_driver.py
  5. +52 -28  fastNLP/core/samplers/reproducible_batch_sampler.py
  6. +36 -20  fastNLP/core/samplers/reproducible_sampler.py
  7. +56 -1  fastNLP/core/samplers/utils.py
  8. +2 -0  fastNLP/envs/env.py
  9. +85 -8  tests/core/samplers/test_reproducible_batch_sampler.py
  10. +60 -2  tests/core/samplers/test_reproducible_sampler.py
  11. +3 -3  tests/core/samplers/test_unrepeated_sampler.py

+ 5 - 5  fastNLP/core/controllers/evaluator.py

@@ -364,16 +364,16 @@ class _MetricsWrapper:
else:
args.append(batch)
if not isinstance(outputs, dict):
raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly"
raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
f" return a dict from your model or use `output_mapping` to convert it into dict type.")
if isinstance(metric, Metric):
auto_param_call(metric.update, batch, *args)
auto_param_call(metric.update, outputs, *args)
elif _is_torchmetrics_metric(metric):
auto_param_call(metric.update, batch, *args)
auto_param_call(metric.update, outputs, *args)
elif _is_allennlp_metric(metric):
auto_param_call(metric.__call__, batch, *args)
auto_param_call(metric.__call__, outputs, *args)
elif _is_paddle_metric(metric):
res = auto_param_call(metric.compute, batch, *args)
res = auto_param_call(metric.compute, outputs, *args)
metric.update(res)

def reset(self):
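
The fix above feeds the model `outputs` (rather than the raw input `batch`) into every metric. `auto_param_call` then matches the keys of that dict against the parameters of `metric.update` by name; a minimal sketch of that matching idea, as a simplified illustration rather than fastNLP's actual implementation:

import inspect

def call_with_matched_params(fn, *dict_sources):
    # later sources win on key collisions, mirroring "outputs first, then extra args"
    pool = {}
    for src in dict_sources:
        if isinstance(src, dict):
            pool.update(src)
    accepted = inspect.signature(fn).parameters
    return fn(**{name: pool[name] for name in accepted if name in pool})

# e.g. if the model returns {"pred": logits, "target": labels}, then
# call_with_matched_params(metric.update, outputs) forwards pred/target by parameter name.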


+ 12 - 4  fastNLP/core/controllers/trainer.py

@@ -105,8 +105,8 @@ class Trainer(TrainerEventTrigger):
If the batch is a `Dict`, each key of the batch that also appears in output_mapping is renamed to the value that output_mapping maps it to;
if the batch is a `dataclass`, it is first converted to a Dict and then the conversion above is applied;
:param model_wo_auto_param_call: whether to disable calling our auto_param_call during training to automatically match the batch against the parameters of the forward function;
if this value is True, then when the batch is a dict we extract the objects the forward function needs from the batch and pass them into forward; if this value is
False, we pass the batch straight through to the forward function. Note that the same logic also applies to `train_step`, `validate_step` and `test_step`;
if this value is False, then when the batch is a dict we extract the objects the forward function needs from the batch and pass them into forward; if this value is
True, we pass the batch straight through to the model. Note that this parameter applies to `train_step`, `validate_step` and `test_step`;
:param accumulation_steps: the number of gradient accumulation steps, i.e. the optimizer steps once every this many batches; defaults to 1;
:param fp16: whether to enable mixed-precision training; defaults to False;
:param monitor: the default monitor metric name when validate_dataloaders is present. If a passed-in callback has a monitor parameter and has not
@@ -325,6 +325,8 @@ class Trainer(TrainerEventTrigger):

try:
while self.cur_epoch_idx < self.n_epochs:
# this prevents another save, after Trainer.load, before the current epoch has finished
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
self.driver.set_model_mode("train")
self.on_train_epoch_begin()
self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx)
@@ -598,7 +600,9 @@ class Trainer(TrainerEventTrigger):
# 1. callback states and the filter state of each callback's individual callback functions;
# 2. trainer_state;
states = {"callback_states": self.on_save_checkpoint(),
"trainer_state": self.trainer_state.state_dict()}
"trainer_state": self.trainer_state.state_dict(),
'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0)
}

# 3. validate filter state;
if self.evaluator is not None:
@@ -675,9 +679,13 @@ class Trainer(TrainerEventTrigger):
# The principle here is that 'the number of batches still to come' + 'batch_idx_in_epoch' = 'the total number of batches of an uninterrupted run'. Since
# 'the number of batches still to come' is determined by how many samples are left, the equation can only be made to hold by adjusting 'batch_idx_in_epoch'
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \
self.batch_idx_in_epoch
# this prevents the user from saving again, after Trainer.load, before the current epoch has finished
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch

# 5. restore the state of all callbacks;
self.train_stepeckpoint(states["callback_states"])
self.on_load_checkpoint(states["callback_states"])

self.driver.barrier()
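
The two `start_batch_idx_in_epoch` assignments exist so that the `num_consumed_batches` written into the checkpoint only counts batches that the restored sampler actually yielded in the current run. A small bookkeeping sketch under that assumption (names mirror the diff; the surrounding Trainer machinery is omitted):

class _ResumeBookkeeping:
    def __init__(self):
        self.batch_idx_in_epoch = 0
        self.start_batch_idx_in_epoch = 0  # reset at every epoch start, or set by load()

    def load(self, restored_batch_idx_in_epoch: int):
        self.batch_idx_in_epoch = restored_batch_idx_in_epoch
        self.start_batch_idx_in_epoch = restored_batch_idx_in_epoch

    def num_consumed_batches(self) -> int:
        # batches consumed since the current epoch's iterator was (re)started
        return self.batch_idx_in_epoch - self.start_batch_idx_in_epoch

# resume at batch 80 of an epoch, then save again at batch 95 of the same epoch:
bk = _ResumeBookkeeping()
bk.load(80)
bk.batch_idx_in_epoch = 95
assert bk.num_consumed_batches() == 15  # not 95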



+ 1 - 1  fastNLP/core/controllers/utils/state.py

@@ -60,7 +60,7 @@ class TrainerState:
cur_epoch_idx: the index of the epoch currently running;
global_forward_batches: the total number of steps the model has forwarded so far;
batch_idx_in_epoch: the step index within the current epoch;
total_batches: how many steps each epoch forwards;
num_batches_per_epoch: how many steps each epoch forwards;
total_batches: the number of steps forwarded over the whole training run; note that total_batches = num_batches_per_epoch * n_epochs;
"""
n_epochs: Optional[int] = None  # always recomputed regardless


+ 13 - 2  fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -194,9 +194,20 @@ class TorchDriver(Driver):
sampler = dataloader_args.sampler
else:
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.")
num_consumed_batches = states.pop('num_consumed_batches')
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
states['sampler_states'] = sampler.state_dict()
sampler_states = sampler.state_dict()
# If present, num_consumed_samples needs special handling: because the DataLoader prefetches, using the sampler's own num_consumed_samples
# directly would overstate what has actually been consumed.
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
if num_consumed_samples_array is not None:
if isinstance(sampler, ReproducibleSampler):  # for a plain sampler, batch_size has to be taken into account.
try:
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
except:  # batch_size may be None, in which case some precision is unavoidably lost
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
else:
raise RuntimeError(
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.')
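
Because the DataLoader prefetches, the sampler's own num_consumed_samples is usually ahead of what the Trainer has really processed, so the code above looks the true value up in num_consumed_samples_array instead. A hedged sketch of that correction (the helper name and fallbacks are illustrative, following the logic of the diff):

def corrected_num_consumed_samples(sampler_states: dict, num_consumed_batches: int,
                                   batch_size=None, sample_level: bool = False) -> int:
    arr = sampler_states.get('num_consumed_samples_array')
    if arr is None:
        # no history recorded: fall back to the sampler's (possibly inflated) counter
        return sampler_states['num_consumed_samples']
    key = num_consumed_batches
    if sample_level:
        # a plain ReproducibleSampler records one entry per sample, so scale by batch_size
        if batch_size is None:
            return sampler_states['num_consumed_samples']  # batch_size unknown, precision is lost
        key = num_consumed_batches * batch_size
    return arr[key]  # what the Trainer had actually consumed at that point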


+ 52 - 28  fastNLP/core/samplers/reproducible_batch_sampler.py

@@ -4,16 +4,18 @@ __all__ = [
]

import math
from array import array
from copy import deepcopy
from typing import Dict, Union, List
from itertools import chain
import os

import numpy as np

from fastNLP.core.dataset import DataSet
from fastNLP.core.log import logger
from .utils import create_array, NumConsumedSamplesArray
from abc import abstractmethod
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE


class ReproducibleBatchSampler:
@@ -34,6 +36,13 @@ class ReproducibleBatchSampler:

@abstractmethod
def state_dict(self):
"""
Since today's DataLoaders all prefetch data, please refer to how RandomBatchSampler fills num_consumed_samples_array into its states
to set this value correctly. The idea is to record the num_consumed_samples corresponding to each index; at Trainer.save time the
num_consumed_samples to store is looked up from num_consumed_samples_array according to how many samples the Trainer really forwarded.

:return:
"""
raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.")

@abstractmethod
@@ -67,7 +76,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
self.batch_size = batch_size
self.drop_last = drop_last

self.data_idx = kwargs.get("data_idx", 0)
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)

self.index_list = kwargs.get("index_list", self._iterate_sampler())
self.need_reinitialize = kwargs.get("need_reinitialize", False)
@@ -80,36 +89,40 @@ class RandomBatchSampler(ReproducibleBatchSampler):
# this means a sampler was passed in at construction, which in theory corresponds to a dataloader created with neither batch_size nor batch_sampler;
else:
_index_lst.append(idx)
# on a 64-bit machine an unsigned int is 4 bytes, so the largest representable value is 4294967295;
if len(_index_lst) > 4294967295:
# note that self.index_list holds the indices of the entire dataset;
# unsigned long
_index_lst = array("L", _index_lst)
else:
# unsigned int
_index_lst = array("I", _index_lst)
_index_lst = create_array(len(_index_lst), _index_lst)
return _index_lst

def __iter__(self):
if self.need_reinitialize:
self.index_list = self._iterate_sampler()
self.data_idx = 0
self.num_consumed_samples = 0
else:
self.need_reinitialize = True

batch = []
if self.data_idx:
index_list = self.index_list[self.data_idx:]
if self.num_consumed_samples:
index_list = self.index_list[self.num_consumed_samples:]
else:
index_list = self.index_list

# Remember the consumed_samples of every batch. This is needed because today's dataloaders all prefetch data, so the data actually
# consumed can only be pinned down together with the Trainer's batch_idx_in_epoch. This variable records the real num_consumed_samples at every yield.
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
num_consumed_samples=self.num_consumed_samples)
for idx in index_list:
batch.append(idx)
self.data_idx += 1
if len(batch) == self.batch_size:
self.num_consumed_samples += self.batch_size # [16, 32, 48, 64,..., ]
self.num_consumed_samples_array.push(self.num_consumed_samples)
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
self.num_consumed_samples += len(batch)
self.num_consumed_samples_array.push(self.num_consumed_samples)
yield batch
# must be reset to avoid edge-case problems
self.num_consumed_samples = 0
delattr(self, 'num_consumed_samples_array')

def __len__(self) -> int:
if self.drop_last:
@@ -118,7 +131,13 @@ class RandomBatchSampler(ReproducibleBatchSampler):
return (len(self.index_list) + self.batch_size - 1) // self.batch_size

def state_dict(self) -> Dict:
return {"index_list": deepcopy(self.index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__}
states = {
"index_list": deepcopy(self.index_list),
"num_consumed_samples": self.num_consumed_samples,
'sampler_type': self.__class__.__name__
}
states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None)
return states

def load_state_dict(self, states: Dict):
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
@@ -128,7 +147,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \
"record and current dataset."
self.index_list = _index_list
self.data_idx = states["data_idx"]
self.num_consumed_samples = states["num_consumed_samples"]
self.need_reinitialize = False

def set_distributed(self, num_replicas, rank, pad=True):
@@ -141,10 +160,10 @@ class RandomBatchSampler(ReproducibleBatchSampler):
@property
def batch_idx_in_epoch(self):
if self.drop_last:
return len(self.index_list) // self.batch_size - (len(self.index_list) - self.data_idx) // self.batch_size
return len(self.index_list) // self.batch_size - (len(self.index_list) - self.num_consumed_samples) // self.batch_size
else:
return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
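
A quick check of the renamed arithmetic: with 100 indices, batch_size 16 and num_consumed_samples 48, three batches have been consumed regardless of drop_last:

index_list_len, batch_size, num_consumed_samples = 100, 16, 48
# drop_last=True:  100 // 16 - (100 - 48) // 16 = 6 - 3 = 3
assert index_list_len // batch_size - (index_list_len - num_consumed_samples) // batch_size == 3
# drop_last=False: ceil(100 / 16) - ceil((100 - 48) / 16) = 7 - 4 = 3
assert ((index_list_len + batch_size - 1) // batch_size
        - (index_list_len - num_consumed_samples + batch_size - 1) // batch_size) == 3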


class BucketedBatchSampler(ReproducibleBatchSampler):
@@ -180,7 +199,6 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
self.length = np.array(length, dtype=int)  # ordinals arranged from longest to shortest.
self.sorted_indices = np.argsort(self.length)[::-1]  # sorted by length from high to low


self.batch_size = batch_size
self.num_batch_per_bucket = num_batch_per_bucket
self.shuffle = shuffle
@@ -212,13 +230,13 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
self.rank = rank
self.pad = pad

num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \
else len(self.dataset)
if self.drop_last:
assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \
"than the number of replicates multiplied " \
"with batch_size when drop_last=True."
# num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \
# else len(self.dataset)
#
# if self.drop_last:
# assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \
# "than the number of replicates multiplied " \
# "with batch_size when drop_last=True."

return self

@@ -243,7 +261,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))

def __len__(self):
def __len__(self)->int:
"""
Return how many more batches of data this sampler will still yield

@@ -309,11 +327,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
batches = batches[:-1]

self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
num_consumed_samples=self.num_consumed_samples)
for batch in batches:
self.num_consumed_samples += self.num_replicas * len(batch)
self.num_consumed_samples_array.push(self.num_consumed_samples)
yield list(map(int, batch))
self.during_iter = False
self.num_consumed_samples = 0
delattr(self, 'num_consumed_samples_array')
self.old_batch_size = self.batch_size
self.old_num_batch_per_bucket = self.num_batch_per_bucket
self.old_num_replicas = self.num_replicas
@@ -376,10 +398,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
'num_batch_per_bucket': self.num_batch_per_bucket,
'num_replicas': self.num_replicas
}

states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None)
return states

def load_state_dict(self, states: Dict):
# if self.during_iter is True, then data_idx must be 0;
# if self.during_iter is True, then num_consumed_samples must be 0;
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
"during an unfinished iteration."



+ 36 - 20  fastNLP/core/samplers/reproducible_sampler.py

@@ -1,9 +1,14 @@
from typing import Dict, List, Union
import math
import os

import numpy as np

from fastNLP.core.log import logger
from fastNLP.core.dataset import DataSet
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE
from .utils import NumConsumedSamplesArray


__all__ = [
'ReproducibleSampler',
@@ -30,6 +35,13 @@ class ReproducibleSampler:
raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.")

def state_dict(self):
"""
Since today's DataLoaders all prefetch data, please refer to how RandomSampler fills num_consumed_samples_array into its states
to set this value correctly. The idea is to record the num_consumed_samples corresponding to each index; at Trainer.save time the
num_consumed_samples to store is looked up from num_consumed_samples_array according to how many samples the Trainer really forwarded.

:return:
"""
raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.")

def load_state_dict(self, states):
@@ -109,12 +121,15 @@ class RandomSampler(ReproducibleSampler):
indices = indices[self.num_consumed_samples:]
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == self.num_left_samples

for index in indices:
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
num_consumed_samples=self.num_consumed_samples)
for idx, index in enumerate(indices, start=1):
self.num_consumed_samples += self.num_replicas
self.num_consumed_samples_array.push(self.num_consumed_samples)
yield index
self.during_iter = False
self.num_consumed_samples = 0
delattr(self, 'num_consumed_samples_array')

def generate_indices(self) -> List[int]:
"""
@@ -134,18 +149,13 @@ class RandomSampler(ReproducibleSampler):
return indices

def state_dict(self) -> Dict:
states = {
'seed': self.seed,
'epoch': self.epoch,
'num_consumed_samples': self.num_consumed_samples,  # note this value counts the data trained across all ranks;
'sampler_type': self.__class__.__name__,
'length': len(self.dataset),
'shuffle': self.shuffle
}
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
return states

def load_state_dict(self, states: Dict):
# if self.during_iter is True, then data_idx must be 0;
# if self.during_iter is True, then num_consumed_samples must be 0;
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
"during an unfinished iteration."

@@ -158,7 +168,7 @@ class RandomSampler(ReproducibleSampler):
self.seed = states['seed']
self.epoch = states['epoch']
self.num_consumed_samples = states['num_consumed_samples']
if self.num_consumed_samples>=length:  # if the last sample had already been reached at save time, simply reset the value to 0
if self.num_consumed_samples >= length:  # if the last sample had already been reached at save time, simply reset the value to 0
self.num_consumed_samples = 0
if self.shuffle != states['shuffle']:
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
@@ -245,11 +255,15 @@ class SequentialSampler(RandomSampler):
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == self.num_left_samples

for index in indices:
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
num_consumed_samples=self.num_consumed_samples)
for idx, index in enumerate(indices, start=1):
self.num_consumed_samples += self.num_replicas
self.num_consumed_samples_array.push(self.num_consumed_samples)
yield index
self.during_iter = False
self.num_consumed_samples = 0
delattr(self, 'num_consumed_samples_array')

def generate_indices(self) -> List[int]:
"""
@@ -260,15 +274,13 @@ class SequentialSampler(RandomSampler):
return list(range(len(self.dataset)))

def state_dict(self) -> Dict:
states = {
'num_consumed_samples': self.num_consumed_samples,  # note this value counts the data trained across all ranks;
'sampler_type': self.__class__.__name__,
'length': len(self.dataset),
}
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__,
'length': len(self.dataset),
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
return states

def load_state_dict(self, states: Dict):
# if self.during_iter is True, then data_idx must be 0;
# if self.during_iter is True, then num_consumed_samples must be 0;
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
"during an unfinished iteration."

@@ -334,9 +346,13 @@ class SortedSampler(SequentialSampler):
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == self.num_left_samples

for index in indices:
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
num_consumed_samples=self.num_consumed_samples)
for idx, index in enumerate(indices, start=1):
self.num_consumed_samples += self.num_replicas
self.num_consumed_samples_array.push(self.num_consumed_samples)
yield index
self.during_iter = False
self.num_consumed_samples = 0
delattr(self, 'num_consumed_samples_array')


+ 56 - 1  fastNLP/core/samplers/utils.py

@@ -2,6 +2,9 @@ __all__ = [
're_instantiate_sampler',
'conversion_between_reproducible_and_unrepeated_sampler'
]
from array import array
from typing import Sequence
from collections import deque

from fastNLP.core.samplers.unrepeated_sampler import *
from fastNLP.core.samplers.reproducible_sampler import *
@@ -39,4 +42,56 @@ def re_instantiate_sampler(sampler, new_sampler_class=None):
all_attributes = vars(sampler)
if new_sampler_class is not None:
return new_sampler_class(**all_attributes)
return type(sampler)(**all_attributes)
return type(sampler)(**all_attributes)


def create_array(length, fill_value) -> array:
"""
Create an array automatically based on the length: above 4294967295 the 'L' typecode is required, otherwise 'I' is used

:param length:
:param fill_value:
:return:
"""
if not isinstance(fill_value, Sequence):
fill_value = [fill_value]*length

if length > 4294967295:
_index_lst = array("L", fill_value)
else:
_index_lst = array("I", fill_value)
return _index_lst
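
A brief usage check of create_array as defined above (import path taken from the diff; a scalar fill_value is broadcast to the requested length, and the typecode is chosen from the length):

from fastNLP.core.samplers.utils import create_array

idx = create_array(5, [0, 1, 2, 3, 4])
assert idx.typecode == 'I' and list(idx) == [0, 1, 2, 3, 4]  # small lengths fit in unsigned int
assert list(create_array(3, 0)) == [0, 0, 0]                 # scalar fill_value repeated 3 times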


class NumConsumedSamplesArray:
def __init__(self, buffer_size=2000, num_consumed_samples=0):
"""
Keeps buffer_size num_consumed_samples entries; indexing returns the num_consumed_samples recorded at a given index
ex:
array = NumConsumedSamplesArray(buffer_size=3)
for i in range(10):
array.push(i)

array[9]  # returns 9, i.e. the real num_consumed_samples at this position.
array[6]  # raises, because only the 3 most recent entries are kept, namely [7, 8, 9], and 6 is outside that window

:param buffer_size: how many history entries to keep.
:param num_consumed_samples: the value of the first num_consumed_samples.
"""
self.count = 0
self.deque = deque(maxlen=buffer_size)
if num_consumed_samples is not None:
self.push(num_consumed_samples)
self.buffer_size = buffer_size

def __getitem__(self, item):
if len(self.deque) == 0:  # nothing cached yet means nothing has been written, so just return 0
return 0
assert isinstance(item, int), "Only int index allowed."
assert self.count-len(self.deque)<=item<self.count, f"Only keep {len(self.deque)} history index."
index = len(self.deque) - (self.count - item)
return self.deque[index]

def push(self, num_consumed_samples):
self.deque.append(num_consumed_samples)
self.count += 1
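
Following the docstring's example, a small sanity check of the ring-buffer semantics (note that the constructor already pushes the initial num_consumed_samples, so positions count from that first entry):

from fastNLP.core.samplers.utils import NumConsumedSamplesArray

arr = NumConsumedSamplesArray(buffer_size=3, num_consumed_samples=0)
for consumed in (16, 32, 48, 64):   # e.g. one push per yielded batch of size 16
    arr.push(consumed)
# five values were pushed in total (0, 16, 32, 48, 64); only the last three are retained
assert arr[4] == 64 and arr[3] == 48 and arr[2] == 32
# arr[1] would raise an AssertionError: that position has fallen out of the 3-entry window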

+ 2 - 0  fastNLP/envs/env.py

@@ -45,6 +45,8 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK'
# todo: add explanatory comment
FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH"

# default size of the deque initialized in fastNLP
FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE'

# todo: add explanatory comment; variables used directly
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar"
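
The samplers above read this variable to size their num_consumed_samples history buffer. A hedged usage sketch; note that environment values are strings, so casting to int before handing the value to a deque is advisable:

import os
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE

os.environ[FASTNLP_DEQUE_SIZE] = "100"   # opt into a larger history window before building samplers
buffer_size = int(os.environ.get(FASTNLP_DEQUE_SIZE, 30))
assert buffer_size == 100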


+ 85 - 8  tests/core/samplers/test_reproducible_batch_sampler.py

@@ -3,6 +3,7 @@ from array import array
import numpy as np
import pytest
from itertools import chain
from copy import deepcopy

from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
@@ -30,7 +31,7 @@ class TestReproducibleBatchSampler:
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, RandomBatchSampler)
state = _get_re_batchsampler.state_dict()
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size,
assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
"sampler_type": "RandomBatchSampler"}

# 2. resume from the checkpoint and recreate a dataloader;
@@ -413,26 +414,102 @@ class TestBucketedBatchSampler:
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replica', [2, 3])
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica):
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2):
@pytest.mark.parametrize('num_replicas', [2, 3])
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas):
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2):
dataset = DatasetWithVaryLength(num_of_data=num_samples)
batch_size = 6
if num_replica*batch_size > num_samples:
if num_replicas*batch_size > num_samples:
return
num_batch_per_bucket = 10
samplers = []
lengths = []
for i in range(num_replica):
for i in range(num_replicas):
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)
sampler.set_distributed(num_replica, rank=i, pad=pad)
sampler.set_distributed(num_replicas, rank=i, pad=pad)
sampler.set_epoch(0)
samplers.append(sampler)
lengths.append(len(list(iter(sampler))))
assert len(set(lengths))==1
bucket_diff = batch_size * num_batch_per_bucket * num_replica
bucket_diff = batch_size * num_batch_per_bucket * num_replicas

for bs in zip(*samplers):
diff = max(chain(*bs)) - min(chain(*bs))
assert diff <= bucket_diff

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [1, 2, 3])
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
"""
Test whether the data already used (forwarded) can be recovered correctly; since the DataLoader prefetches, the Sampler's own
num_consumed_samples may run high

:return:
"""
batch_size = 6
num_batch_per_bucket = 10
dataset = DatasetWithVaryLength(num_of_data=num_samples)
samplers = []
for i in range(num_replicas):
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)

sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
samplers.append(sampler)
count = 0
already_seen_sets = [set()]
already_seen_set = set()
for batchs in zip(*samplers):
batch = chain(*batchs)
already_seen_set.update(batch)
already_seen_sets.append(deepcopy(already_seen_set))
count += 1
if count > 3:
break
states = samplers[0].state_dict()
for i in range(len(already_seen_sets)):
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
drop_last=drop_last)
sampler.set_epoch(0)
already_seen_set = deepcopy(already_seen_sets[i])
for batch in sampler:
already_seen_set.update(batch)
assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(
dataset)

# test saving again after a save
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
drop_last=drop_last)
sampler.set_epoch(0)
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
if len(already_seen_sets)<3:
return
already_seen_set = already_seen_sets[2]
count = 0
for batch in sampler:
already_seen_set.update(batch)
count += 1
if count > 6:
break

states = sampler.state_dict()
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
drop_last=drop_last)
sampler.load_state_dict(states)
sampler.set_epoch(0)
for batch in sampler:
already_seen_set.update(batch)

assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)

+ 60 - 2  tests/core/samplers/test_reproducible_sampler.py

@@ -3,6 +3,7 @@ import pytest

from functools import partial
from itertools import chain
from copy import deepcopy

from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -180,6 +181,63 @@ class TestRandomSamplerYh:
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1  # padding may cause duplicates

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [1, 2, 3])
def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas):
# test that recovery still works when the sampler has produced more than was actually consumed
dataset = DatasetWithVaryLength(num_of_data=num_samples)
samplers = []
for i in range(num_replicas):
sampler = RandomSampler(dataset, shuffle=shuffle)
sampler.set_epoch(0)
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
samplers.append(sampler)
count = 0
already_seen_sets = [set()]
already_seen_set = set()
for idxes in zip(*samplers):
already_seen_set.update(idxes)
already_seen_sets.append(deepcopy(already_seen_set))
count += 1
if count > 3:
break
states = samplers[0].state_dict()
for i in range(len(already_seen_sets)):
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
sampler = RandomSampler(dataset, shuffle=shuffle)
already_seen_set = deepcopy(already_seen_sets[i])
for batch in sampler:
already_seen_set.add(batch)
assert len(already_seen_set) == len(dataset)
# test saving again after a save
sampler = RandomSampler(dataset, shuffle=shuffle)
sampler.set_epoch(0)
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
if len(already_seen_sets)<3:
return
already_seen_set = already_seen_sets[2]
count = 0
for idx in sampler:
already_seen_set.add(idx)
count += 1
if count > 6:
break

states = sampler.state_dict()
if states['num_consumed_samples_array'] is not None:
states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
sampler = RandomSampler(dataset, shuffle=shuffle)
sampler.load_state_dict(states)
sampler.set_epoch(0)
for idx in sampler:
already_seen_set.add(idx)

assert len(already_seen_set)==len(dataset)


class TestRandomSampler:
# single-GPU tests;
@@ -386,7 +444,7 @@ class TestSortedSampler:
assert indexes==list(range(num_of_data-1, -1, -1))

@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, pad, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
@@ -540,7 +598,7 @@ class TestSequentialSampler:
assert indexes==list(range(num_of_data))

@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, pad, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)


+ 3 - 3  tests/core/samplers/test_unrepeated_sampler.py

@@ -25,7 +25,7 @@ class TestUnrepeatedSampler:
indexes = set(sampler)
assert indexes==set(range(num_of_data))

@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
@pytest.mark.parametrize('shuffle', [False, True])
def test_multi(self, num_replica, num_of_data, shuffle):
@@ -50,7 +50,7 @@ class TestUnrepeatedSortedSampler:
indexes = list(sampler)
assert indexes==list(range(num_of_data-1, -1, -1))

@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
@@ -81,7 +81,7 @@ class TestUnrepeatedSequentialSampler:
indexes = list(sampler)
assert indexes==list(range(num_of_data))

@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_replicas', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)

