@@ -364,16 +364,16 @@ class _MetricsWrapper: | |||||
else: | else: | ||||
args.append(batch) | args.append(batch) | ||||
if not isinstance(outputs, dict): | if not isinstance(outputs, dict): | ||||
raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly" | |||||
raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" | |||||
f" return a dict from your model or use `output_mapping` to convert it into dict type.") | f" return a dict from your model or use `output_mapping` to convert it into dict type.") | ||||
if isinstance(metric, Metric): | if isinstance(metric, Metric): | ||||
auto_param_call(metric.update, batch, *args) | |||||
auto_param_call(metric.update, outputs, *args) | |||||
elif _is_torchmetrics_metric(metric): | elif _is_torchmetrics_metric(metric): | ||||
auto_param_call(metric.update, batch, *args) | |||||
auto_param_call(metric.update, outputs, *args) | |||||
elif _is_allennlp_metric(metric): | elif _is_allennlp_metric(metric): | ||||
auto_param_call(metric.__call__, batch, *args) | |||||
auto_param_call(metric.__call__, outputs, *args) | |||||
elif _is_paddle_metric(metric): | elif _is_paddle_metric(metric): | ||||
res = auto_param_call(metric.compute, batch, *args) | |||||
res = auto_param_call(metric.compute, outputs, *args) | |||||
metric.update(res) | metric.update(res) | ||||
def reset(self): | def reset(self): | ||||
@@ -105,8 +105,8 @@ class Trainer(TrainerEventTrigger): | |||||
如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; | 如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; | ||||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | ||||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | ||||
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||||
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; | |||||
如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||||
为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `validate_step` 和 `test_step`; | |||||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | ||||
:param fp16: 是否开启混合精度训练;默认为 False; | :param fp16: 是否开启混合精度训练;默认为 False; | ||||
:param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | ||||
@@ -325,6 +325,8 @@ class Trainer(TrainerEventTrigger): | |||||
try: | try: | ||||
while self.cur_epoch_idx < self.n_epochs: | while self.cur_epoch_idx < self.n_epochs: | ||||
# 这个是防止在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||||
self.driver.set_model_mode("train") | self.driver.set_model_mode("train") | ||||
self.on_train_epoch_begin() | self.on_train_epoch_begin() | ||||
self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx) | self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx) | ||||
@@ -598,7 +600,9 @@ class Trainer(TrainerEventTrigger): | |||||
# 1. callback states 和 每一个callback的具体 callback 函数的 filter 的状态; | # 1. callback states 和 每一个callback的具体 callback 函数的 filter 的状态; | ||||
# 2. trainer_state; | # 2. trainer_state; | ||||
states = {"callback_states": self.on_save_checkpoint(), | states = {"callback_states": self.on_save_checkpoint(), | ||||
"trainer_state": self.trainer_state.state_dict()} | |||||
"trainer_state": self.trainer_state.state_dict(), | |||||
'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0) | |||||
} | |||||
# 3. validate filter state; | # 3. validate filter state; | ||||
if self.evaluator is not None: | if self.evaluator is not None: | ||||
@@ -675,9 +679,13 @@ class Trainer(TrainerEventTrigger): | |||||
# 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | ||||
# '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | ||||
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | ||||
self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \ | |||||
self.batch_idx_in_epoch | |||||
# 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||||
# 5. 恢复所有 callback 的状态; | # 5. 恢复所有 callback 的状态; | ||||
self.train_stepeckpoint(states["callback_states"]) | |||||
self.on_load_checkpoint(states["callback_states"]) | |||||
self.driver.barrier() | self.driver.barrier() | ||||
@@ -60,7 +60,7 @@ class TrainerState: | |||||
cur_epoch_idx: 当前正在运行第几个 epoch; | cur_epoch_idx: 当前正在运行第几个 epoch; | ||||
global_forward_batches: 当前模型总共 forward 了多少个 step; | global_forward_batches: 当前模型总共 forward 了多少个 step; | ||||
batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | ||||
total_batches: 每一个 epoch 会 forward 多少个 step; | |||||
num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | |||||
total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs; | total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs; | ||||
""" | """ | ||||
n_epochs: Optional[int] = None # 无论如何重新算 | n_epochs: Optional[int] = None # 无论如何重新算 | ||||
@@ -194,9 +194,20 @@ class TorchDriver(Driver): | |||||
sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
else: | else: | ||||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | ||||
num_consumed_batches = states.pop('num_consumed_batches') | |||||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | ||||
states['sampler_states'] = sampler.state_dict() | |||||
sampler_states = sampler.state_dict() | |||||
# 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
# 会造成多余实际消耗的问题。 | |||||
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
if num_consumed_samples_array is not None: | |||||
if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
try: | |||||
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
except: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
else: | else: | ||||
raise RuntimeError( | raise RuntimeError( | ||||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | ||||
@@ -4,16 +4,18 @@ __all__ = [ | |||||
] | ] | ||||
import math | import math | ||||
from array import array | |||||
from copy import deepcopy | from copy import deepcopy | ||||
from typing import Dict, Union, List | from typing import Dict, Union, List | ||||
from itertools import chain | from itertools import chain | ||||
import os | |||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from .utils import create_array, NumConsumedSamplesArray | |||||
from abc import abstractmethod | from abc import abstractmethod | ||||
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||||
class ReproducibleBatchSampler: | class ReproducibleBatchSampler: | ||||
@@ -34,6 +36,13 @@ class ReproducibleBatchSampler: | |||||
@abstractmethod | @abstractmethod | ||||
def state_dict(self): | def state_dict(self): | ||||
""" | |||||
由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomBatchSampler 中 states 里面 num_consumed_samples_array 的实现 | |||||
正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward | |||||
了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||||
:return: | |||||
""" | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") | raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") | ||||
@abstractmethod | @abstractmethod | ||||
@@ -67,7 +76,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.drop_last = drop_last | self.drop_last = drop_last | ||||
self.data_idx = kwargs.get("data_idx", 0) | |||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) | |||||
self.index_list = kwargs.get("index_list", self._iterate_sampler()) | self.index_list = kwargs.get("index_list", self._iterate_sampler()) | ||||
self.need_reinitialize = kwargs.get("need_reinitialize", False) | self.need_reinitialize = kwargs.get("need_reinitialize", False) | ||||
@@ -80,36 +89,40 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
# 说明是在初始化时传入的是一个 sampler,理论上对应于 dataloader 在初始化时没有 batch_size,也没有 batch_sampler 的情况; | # 说明是在初始化时传入的是一个 sampler,理论上对应于 dataloader 在初始化时没有 batch_size,也没有 batch_sampler 的情况; | ||||
else: | else: | ||||
_index_lst.append(idx) | _index_lst.append(idx) | ||||
# 64 位机器的 unsigned int 为 4 个字节,能表示的最大大小为 4294967295; | |||||
if len(_index_lst) > 4294967295: | |||||
# 注意 self.index_list 内存放的是全部数据的 index; | |||||
# unsigned long | |||||
_index_lst = array("L", _index_lst) | |||||
else: | |||||
# unsigned int | |||||
_index_lst = array("I", _index_lst) | |||||
_index_lst = create_array(len(_index_lst), _index_lst) | |||||
return _index_lst | return _index_lst | ||||
def __iter__(self): | def __iter__(self): | ||||
if self.need_reinitialize: | if self.need_reinitialize: | ||||
self.index_list = self._iterate_sampler() | self.index_list = self._iterate_sampler() | ||||
self.data_idx = 0 | |||||
self.num_consumed_samples = 0 | |||||
else: | else: | ||||
self.need_reinitialize = True | self.need_reinitialize = True | ||||
batch = [] | batch = [] | ||||
if self.data_idx: | |||||
index_list = self.index_list[self.data_idx:] | |||||
if self.num_consumed_samples: | |||||
index_list = self.index_list[self.num_consumed_samples:] | |||||
else: | else: | ||||
index_list = self.index_list | index_list = self.index_list | ||||
# 记住每个 batch 对应的 consumed_samples, 需要这个原因是由于现在的 dataloader 都存在预取数据的设计,需要再结合Trainer中 | |||||
# batch_idx_in_epoch 才能最终确定实际消耗的数据。这个变量需要记录每次yield出去时的真实 num_consumed_samples 的数值。 | |||||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||||
num_consumed_samples=self.num_consumed_samples) | |||||
for idx in index_list: | for idx in index_list: | ||||
batch.append(idx) | batch.append(idx) | ||||
self.data_idx += 1 | |||||
if len(batch) == self.batch_size: | if len(batch) == self.batch_size: | ||||
self.num_consumed_samples += self.batch_size # [16, 32, 48, 64,..., ] | |||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield batch | yield batch | ||||
batch = [] | batch = [] | ||||
if len(batch) > 0 and not self.drop_last: | if len(batch) > 0 and not self.drop_last: | ||||
self.num_consumed_samples += len(batch) | |||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield batch | yield batch | ||||
# 需要重置防止边界条件问题 | |||||
self.num_consumed_samples = 0 | |||||
delattr(self, 'num_consumed_samples_array') | |||||
def __len__(self) -> int: | def __len__(self) -> int: | ||||
if self.drop_last: | if self.drop_last: | ||||
@@ -118,7 +131,13 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
return (len(self.index_list) + self.batch_size - 1) // self.batch_size | return (len(self.index_list) + self.batch_size - 1) // self.batch_size | ||||
def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
return {"index_list": deepcopy(self.index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__} | |||||
states = { | |||||
"index_list": deepcopy(self.index_list), | |||||
"num_consumed_samples": self.num_consumed_samples, | |||||
'sampler_type': self.__class__.__name__ | |||||
} | |||||
states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||||
return states | |||||
def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | ||||
@@ -128,7 +147,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \ | assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \ | ||||
"record and current dataset." | "record and current dataset." | ||||
self.index_list = _index_list | self.index_list = _index_list | ||||
self.data_idx = states["data_idx"] | |||||
self.num_consumed_samples = states["num_consumed_samples"] | |||||
self.need_reinitialize = False | self.need_reinitialize = False | ||||
def set_distributed(self, num_replicas, rank, pad=True): | def set_distributed(self, num_replicas, rank, pad=True): | ||||
@@ -141,10 +160,10 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
@property | @property | ||||
def batch_idx_in_epoch(self): | def batch_idx_in_epoch(self): | ||||
if self.drop_last: | if self.drop_last: | ||||
return len(self.index_list) // self.batch_size - (len(self.index_list) - self.data_idx) // self.batch_size | |||||
return len(self.index_list) // self.batch_size - (len(self.index_list) - self.num_consumed_samples) // self.batch_size | |||||
else: | else: | ||||
return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \ | return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \ | ||||
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | |||||
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size | |||||
class BucketedBatchSampler(ReproducibleBatchSampler): | class BucketedBatchSampler(ReproducibleBatchSampler): | ||||
@@ -180,7 +199,6 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | ||||
self.sorted_indices = np.argsort(self.length)[::-1] # 按长度从高到低排序的 | self.sorted_indices = np.argsort(self.length)[::-1] # 按长度从高到低排序的 | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.num_batch_per_bucket = num_batch_per_bucket | self.num_batch_per_bucket = num_batch_per_bucket | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
@@ -212,13 +230,13 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
self.rank = rank | self.rank = rank | ||||
self.pad = pad | self.pad = pad | ||||
num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ | |||||
else len(self.dataset) | |||||
if self.drop_last: | |||||
assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \ | |||||
"than the number of replicates multiplied " \ | |||||
"with batch_size when drop_last=True." | |||||
# num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ | |||||
# else len(self.dataset) | |||||
# | |||||
# if self.drop_last: | |||||
# assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \ | |||||
# "than the number of replicates multiplied " \ | |||||
# "with batch_size when drop_last=True." | |||||
return self | return self | ||||
@@ -243,7 +261,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | ||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | ||||
def __len__(self): | |||||
def __len__(self)->int: | |||||
""" | """ | ||||
返回当前 sampler 还会返回多少个 batch 的数据 | 返回当前 sampler 还会返回多少个 batch 的数据 | ||||
@@ -309,11 +327,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | ||||
batches = batches[:-1] | batches = batches[:-1] | ||||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||||
num_consumed_samples=self.num_consumed_samples) | |||||
for batch in batches: | for batch in batches: | ||||
self.num_consumed_samples += self.num_replicas * len(batch) | self.num_consumed_samples += self.num_replicas * len(batch) | ||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield list(map(int, batch)) | yield list(map(int, batch)) | ||||
self.during_iter = False | self.during_iter = False | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
delattr(self, 'num_consumed_samples_array') | |||||
self.old_batch_size = self.batch_size | self.old_batch_size = self.batch_size | ||||
self.old_num_batch_per_bucket = self.num_batch_per_bucket | self.old_num_batch_per_bucket = self.num_batch_per_bucket | ||||
self.old_num_replicas = self.num_replicas | self.old_num_replicas = self.num_replicas | ||||
@@ -376,10 +398,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
'num_batch_per_bucket': self.num_batch_per_bucket, | 'num_batch_per_bucket': self.num_batch_per_bucket, | ||||
'num_replicas': self.num_replicas | 'num_replicas': self.num_replicas | ||||
} | } | ||||
states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||||
return states | return states | ||||
def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
# 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||||
# 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; | |||||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | ||||
"during an unfinished iteration." | "during an unfinished iteration." | ||||
@@ -1,9 +1,14 @@ | |||||
from typing import Dict, List, Union | from typing import Dict, List, Union | ||||
import math | import math | ||||
import os | |||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||||
from .utils import NumConsumedSamplesArray | |||||
__all__ = [ | __all__ = [ | ||||
'ReproducibleSampler', | 'ReproducibleSampler', | ||||
@@ -30,6 +35,13 @@ class ReproducibleSampler: | |||||
raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.") | raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.") | ||||
def state_dict(self): | def state_dict(self): | ||||
""" | |||||
由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomSampler 中 states 里面 num_consumed_samples_array 的实现 | |||||
正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward | |||||
了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||||
:return: | |||||
""" | |||||
raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.") | raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.") | ||||
def load_state_dict(self, states): | def load_state_dict(self, states): | ||||
@@ -109,12 +121,15 @@ class RandomSampler(ReproducibleSampler): | |||||
indices = indices[self.num_consumed_samples:] | indices = indices[self.num_consumed_samples:] | ||||
indices = indices[self.rank:len(indices):self.num_replicas] | indices = indices[self.rank:len(indices):self.num_replicas] | ||||
assert len(indices) == self.num_left_samples | assert len(indices) == self.num_left_samples | ||||
for index in indices: | |||||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||||
num_consumed_samples=self.num_consumed_samples) | |||||
for idx, index in enumerate(indices, start=1): | |||||
self.num_consumed_samples += self.num_replicas | self.num_consumed_samples += self.num_replicas | ||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield index | yield index | ||||
self.during_iter = False | self.during_iter = False | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
delattr(self, 'num_consumed_samples_array') | |||||
def generate_indices(self) -> List[int]: | def generate_indices(self) -> List[int]: | ||||
""" | """ | ||||
@@ -134,18 +149,13 @@ class RandomSampler(ReproducibleSampler): | |||||
return indices | return indices | ||||
def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
states = { | |||||
'seed': self.seed, | |||||
'epoch': self.epoch, | |||||
'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||||
'sampler_type': self.__class__.__name__, | |||||
'length': len(self.dataset), | |||||
'shuffle': self.shuffle | |||||
} | |||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||||
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||||
return states | return states | ||||
def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
# 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||||
# 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; | |||||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | ||||
"during an unfinished iteration." | "during an unfinished iteration." | ||||
@@ -158,7 +168,7 @@ class RandomSampler(ReproducibleSampler): | |||||
self.seed = states['seed'] | self.seed = states['seed'] | ||||
self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||||
if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
if self.shuffle != states['shuffle']: | if self.shuffle != states['shuffle']: | ||||
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | ||||
@@ -245,11 +255,15 @@ class SequentialSampler(RandomSampler): | |||||
indices = indices[self.rank:len(indices):self.num_replicas] | indices = indices[self.rank:len(indices):self.num_replicas] | ||||
assert len(indices) == self.num_left_samples | assert len(indices) == self.num_left_samples | ||||
for index in indices: | |||||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||||
num_consumed_samples=self.num_consumed_samples) | |||||
for idx, index in enumerate(indices, start=1): | |||||
self.num_consumed_samples += self.num_replicas | self.num_consumed_samples += self.num_replicas | ||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield index | yield index | ||||
self.during_iter = False | self.during_iter = False | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
delattr(self, 'num_consumed_samples_array') | |||||
def generate_indices(self) -> List[int]: | def generate_indices(self) -> List[int]: | ||||
""" | """ | ||||
@@ -260,15 +274,13 @@ class SequentialSampler(RandomSampler): | |||||
return list(range(len(self.dataset))) | return list(range(len(self.dataset))) | ||||
def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
states = { | |||||
'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||||
'sampler_type': self.__class__.__name__, | |||||
'length': len(self.dataset), | |||||
} | |||||
states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | |||||
'length': len(self.dataset), | |||||
'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||||
return states | return states | ||||
def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
# 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||||
# 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; | |||||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | ||||
"during an unfinished iteration." | "during an unfinished iteration." | ||||
@@ -334,9 +346,13 @@ class SortedSampler(SequentialSampler): | |||||
indices = indices[self.rank:len(indices):self.num_replicas] | indices = indices[self.rank:len(indices):self.num_replicas] | ||||
assert len(indices) == self.num_left_samples | assert len(indices) == self.num_left_samples | ||||
for index in indices: | |||||
self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||||
num_consumed_samples=self.num_consumed_samples) | |||||
for idx, index in enumerate(indices, start=1): | |||||
self.num_consumed_samples += self.num_replicas | self.num_consumed_samples += self.num_replicas | ||||
self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
yield index | yield index | ||||
self.during_iter = False | self.during_iter = False | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
delattr(self, 'num_consumed_samples_array') | |||||
@@ -2,6 +2,9 @@ __all__ = [ | |||||
're_instantiate_sampler', | 're_instantiate_sampler', | ||||
'conversion_between_reproducible_and_unrepeated_sampler' | 'conversion_between_reproducible_and_unrepeated_sampler' | ||||
] | ] | ||||
from array import array | |||||
from typing import Sequence | |||||
from collections import deque | |||||
from fastNLP.core.samplers.unrepeated_sampler import * | from fastNLP.core.samplers.unrepeated_sampler import * | ||||
from fastNLP.core.samplers.reproducible_sampler import * | from fastNLP.core.samplers.reproducible_sampler import * | ||||
@@ -39,4 +42,56 @@ def re_instantiate_sampler(sampler, new_sampler_class=None): | |||||
all_attributes = vars(sampler) | all_attributes = vars(sampler) | ||||
if new_sampler_class is not None: | if new_sampler_class is not None: | ||||
return new_sampler_class(**all_attributes) | return new_sampler_class(**all_attributes) | ||||
return type(sampler)(**all_attributes) | |||||
return type(sampler)(**all_attributes) | |||||
def create_array(length, fill_value) -> array: | |||||
""" | |||||
根据长度自动创建 array ,超过 4294967295 需要使用 'L', 否则使用 'I' | |||||
:param length: | |||||
:param fill_value: | |||||
:return: | |||||
""" | |||||
if not isinstance(fill_value, Sequence): | |||||
fill_value = [fill_value]*length | |||||
if length > 4294967295: | |||||
_index_lst = array("L", fill_value) | |||||
else: | |||||
_index_lst = array("I", fill_value) | |||||
return _index_lst | |||||
class NumConsumedSamplesArray: | |||||
def __init__(self, buffer_size=2000, num_consumed_samples=0): | |||||
""" | |||||
保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少 | |||||
ex: | |||||
array = NumConsumedSamplesArray(buffer_size=3) | |||||
for i in range(10): | |||||
array.push(i) | |||||
array[9] # 输出为9,表示这个位置真实的 num_consumed_samples 是多少。 | |||||
array[6] # 报错,因为只保留了3个最近的数据,6超过了最大buffer的记录了,即 [7, 8, 9] | |||||
:param buffer_size: 报错多少个历史。 | |||||
:param num_consumed_samples: 第一个 num_consumed_samples 是多少。 | |||||
""" | |||||
self.count = 0 | |||||
self.deque = deque(maxlen=buffer_size) | |||||
if num_consumed_samples is not None: | |||||
self.push(num_consumed_samples) | |||||
self.buffer_size = buffer_size | |||||
def __getitem__(self, item): | |||||
if len(self.deque) == 0: # 如果没有任何缓存的内容,说明还没有写入,直接返回0 | |||||
return 0 | |||||
assert isinstance(item, int), "Only int index allowed." | |||||
assert self.count-len(self.deque)<=item<self.count, f"Only keep {len(self.deque)} history index." | |||||
index = len(self.deque) - (self.count - item) | |||||
return self.deque[index] | |||||
def push(self, num_consumed_samples): | |||||
self.deque.append(num_consumed_samples) | |||||
self.count += 1 |
@@ -45,6 +45,8 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' | |||||
# todo 注释 | # todo 注释 | ||||
FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | ||||
# fastNLP 中初始化deque的默认大小 | |||||
FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE' | |||||
# todo 注释 直接使用的变量 | # todo 注释 直接使用的变量 | ||||
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | ||||
@@ -3,6 +3,7 @@ from array import array | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
from itertools import chain | from itertools import chain | ||||
from copy import deepcopy | |||||
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | ||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | ||||
@@ -30,7 +31,7 @@ class TestReproducibleBatchSampler: | |||||
_get_re_batchsampler = dataloader.batch_sampler | _get_re_batchsampler = dataloader.batch_sampler | ||||
assert isinstance(_get_re_batchsampler, RandomBatchSampler) | assert isinstance(_get_re_batchsampler, RandomBatchSampler) | ||||
state = _get_re_batchsampler.state_dict() | state = _get_re_batchsampler.state_dict() | ||||
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | |||||
assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, | |||||
"sampler_type": "RandomBatchSampler"} | "sampler_type": "RandomBatchSampler"} | ||||
# 2. 断点重训,重新生成一个 dataloader; | # 2. 断点重训,重新生成一个 dataloader; | ||||
@@ -413,26 +414,102 @@ class TestBucketedBatchSampler: | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | @pytest.mark.parametrize('drop_last', [True, False]) | ||||
@pytest.mark.parametrize('pad', [True, False]) | @pytest.mark.parametrize('pad', [True, False]) | ||||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | ||||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||||
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica): | |||||
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2): | |||||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||||
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2): | |||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | dataset = DatasetWithVaryLength(num_of_data=num_samples) | ||||
batch_size = 6 | batch_size = 6 | ||||
if num_replica*batch_size > num_samples: | |||||
if num_replicas*batch_size > num_samples: | |||||
return | return | ||||
num_batch_per_bucket = 10 | num_batch_per_bucket = 10 | ||||
samplers = [] | samplers = [] | ||||
lengths = [] | lengths = [] | ||||
for i in range(num_replica): | |||||
for i in range(num_replicas): | |||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | ||||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | ||||
sampler.set_distributed(num_replica, rank=i, pad=pad) | |||||
sampler.set_distributed(num_replicas, rank=i, pad=pad) | |||||
sampler.set_epoch(0) | sampler.set_epoch(0) | ||||
samplers.append(sampler) | samplers.append(sampler) | ||||
lengths.append(len(list(iter(sampler)))) | lengths.append(len(list(iter(sampler)))) | ||||
assert len(set(lengths))==1 | assert len(set(lengths))==1 | ||||
bucket_diff = batch_size * num_batch_per_bucket * num_replica | |||||
bucket_diff = batch_size * num_batch_per_bucket * num_replicas | |||||
for bs in zip(*samplers): | for bs in zip(*samplers): | ||||
diff = max(chain(*bs)) - min(chain(*bs)) | diff = max(chain(*bs)) - min(chain(*bs)) | ||||
assert diff <= bucket_diff | assert diff <= bucket_diff | ||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||||
@pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||||
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||||
""" | |||||
测试是否能够正确地恢复使用过的(forward)数据,由于 DataLoader 存在预取,所以 Sampler 自身的 num_consumed_samples 可能 | |||||
偏多 | |||||
:return: | |||||
""" | |||||
batch_size = 6 | |||||
num_batch_per_bucket = 10 | |||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||||
samplers = [] | |||||
for i in range(num_replicas): | |||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | |||||
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) | |||||
samplers.append(sampler) | |||||
count = 0 | |||||
already_seen_sets = [set()] | |||||
already_seen_set = set() | |||||
for batchs in zip(*samplers): | |||||
batch = chain(*batchs) | |||||
already_seen_set.update(batch) | |||||
already_seen_sets.append(deepcopy(already_seen_set)) | |||||
count += 1 | |||||
if count > 3: | |||||
break | |||||
states = samplers[0].state_dict() | |||||
for i in range(len(already_seen_sets)): | |||||
if states['num_consumed_samples_array'] is not None: | |||||
states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1, | |||||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||||
drop_last=drop_last) | |||||
sampler.set_epoch(0) | |||||
already_seen_set = deepcopy(already_seen_sets[i]) | |||||
for batch in sampler: | |||||
already_seen_set.update(batch) | |||||
assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len( | |||||
dataset) | |||||
# 测试保存之后再次保存 | |||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1, | |||||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||||
drop_last=drop_last) | |||||
sampler.set_epoch(0) | |||||
if states['num_consumed_samples_array'] is not None: | |||||
states['num_consumed_samples'] = states['num_consumed_samples_array'][2] | |||||
if len(already_seen_sets)<3: | |||||
return | |||||
already_seen_set = already_seen_sets[2] | |||||
count = 0 | |||||
for batch in sampler: | |||||
already_seen_set.update(batch) | |||||
count += 1 | |||||
if count > 6: | |||||
break | |||||
states = sampler.state_dict() | |||||
if states['num_consumed_samples_array'] is not None: | |||||
states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2, | |||||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||||
drop_last=drop_last) | |||||
sampler.load_state_dict(states) | |||||
sampler.set_epoch(0) | |||||
for batch in sampler: | |||||
already_seen_set.update(batch) | |||||
assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) |
@@ -3,6 +3,7 @@ import pytest | |||||
from functools import partial | from functools import partial | ||||
from itertools import chain | from itertools import chain | ||||
from copy import deepcopy | |||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler | from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
@@ -180,6 +181,63 @@ class TestRandomSamplerYh: | |||||
assert seen <= 1 if pad else seen == 0 | assert seen <= 1 if pad else seen == 0 | ||||
assert seen_in_other_rank<=1 # 因为pad可能重复 | assert seen_in_other_rank<=1 # 因为pad可能重复 | ||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||||
@pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||||
def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas): | |||||
# 测试在 sampler 多生成的时候,可以仍然可以恢复 | |||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||||
samplers = [] | |||||
for i in range(num_replicas): | |||||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||||
sampler.set_epoch(0) | |||||
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) | |||||
samplers.append(sampler) | |||||
count = 0 | |||||
already_seen_sets = [set()] | |||||
already_seen_set = set() | |||||
for idxes in zip(*samplers): | |||||
already_seen_set.update(idxes) | |||||
already_seen_sets.append(deepcopy(already_seen_set)) | |||||
count += 1 | |||||
if count > 3: | |||||
break | |||||
states = samplers[0].state_dict() | |||||
for i in range(len(already_seen_sets)): | |||||
if states['num_consumed_samples_array'] is not None: | |||||
states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||||
already_seen_set = deepcopy(already_seen_sets[i]) | |||||
for batch in sampler: | |||||
already_seen_set.add(batch) | |||||
assert len(already_seen_set) == len(dataset) | |||||
# 测试保存之后再次保存 | |||||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||||
sampler.set_epoch(0) | |||||
if states['num_consumed_samples_array'] is not None: | |||||
states['num_consumed_samples'] = states['num_consumed_samples_array'][2] | |||||
if len(already_seen_sets)<3: | |||||
return | |||||
already_seen_set = already_seen_sets[2] | |||||
count = 0 | |||||
for idx in sampler: | |||||
already_seen_set.add(idx) | |||||
count += 1 | |||||
if count > 6: | |||||
break | |||||
states = sampler.state_dict() | |||||
if states['num_consumed_samples_array'] is not None: | |||||
states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||||
sampler.load_state_dict(states) | |||||
sampler.set_epoch(0) | |||||
for idx in sampler: | |||||
already_seen_set.add(idx) | |||||
assert len(already_seen_set)==len(dataset) | |||||
class TestRandomSampler: | class TestRandomSampler: | ||||
# 测试单卡; | # 测试单卡; | ||||
@@ -386,7 +444,7 @@ class TestSortedSampler: | |||||
assert indexes==list(range(num_of_data-1, -1, -1)) | assert indexes==list(range(num_of_data-1, -1, -1)) | ||||
@pytest.mark.parametrize('pad', [True, False]) | @pytest.mark.parametrize('pad', [True, False]) | ||||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
def test_multi(self, pad, num_replica, num_of_data): | def test_multi(self, pad, num_replica, num_of_data): | ||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
@@ -540,7 +598,7 @@ class TestSequentialSampler: | |||||
assert indexes==list(range(num_of_data)) | assert indexes==list(range(num_of_data)) | ||||
@pytest.mark.parametrize('pad', [True, False]) | @pytest.mark.parametrize('pad', [True, False]) | ||||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
def test_multi(self, pad, num_replica, num_of_data): | def test_multi(self, pad, num_replica, num_of_data): | ||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
@@ -25,7 +25,7 @@ class TestUnrepeatedSampler: | |||||
indexes = set(sampler) | indexes = set(sampler) | ||||
assert indexes==set(range(num_of_data)) | assert indexes==set(range(num_of_data)) | ||||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
@pytest.mark.parametrize('shuffle', [False, True]) | @pytest.mark.parametrize('shuffle', [False, True]) | ||||
def test_multi(self, num_replica, num_of_data, shuffle): | def test_multi(self, num_replica, num_of_data, shuffle): | ||||
@@ -50,7 +50,7 @@ class TestUnrepeatedSortedSampler: | |||||
indexes = list(sampler) | indexes = list(sampler) | ||||
assert indexes==list(range(num_of_data-1, -1, -1)) | assert indexes==list(range(num_of_data-1, -1, -1)) | ||||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
def test_multi(self, num_replica, num_of_data): | def test_multi(self, num_replica, num_of_data): | ||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
@@ -81,7 +81,7 @@ class TestUnrepeatedSequentialSampler: | |||||
indexes = list(sampler) | indexes = list(sampler) | ||||
assert indexes==list(range(num_of_data)) | assert indexes==list(range(num_of_data)) | ||||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
def test_multi(self, num_replica, num_of_data): | def test_multi(self, num_replica, num_of_data): | ||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||