| @@ -188,17 +188,27 @@ class TorchDriver(Driver): | |||
| num_consumed_batches = states.pop('num_consumed_batches') | |||
| if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
| sampler_states = sampler.state_dict() | |||
| # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
| # 会造成多余实际消耗的问题。 | |||
| # 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
| # 会造成多余实际消耗的问题。因为 | |||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
| if num_consumed_samples_array is not None: | |||
| if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||
| try: | |||
| if dataloader_args.batch_size is not None: | |||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
| except: # 有可能 batch_size 为 None,就只有损失精度了 | |||
| else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
| logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
| else: | |||
| if dataloader_args.batch_size is not None: | |||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
| * num_consumed_batches | |||
| else: | |||
| logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| states['sampler_states'] = sampler_states | |||
| else: | |||
| raise RuntimeError( | |||
| @@ -63,15 +63,12 @@ class ClassifyFPreRecMetric(Metric): | |||
| evaluate_result = {} | |||
| # 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 | |||
| if self.aggregate_when_get_metric: | |||
| ls = self.backend.all_gather_object([self._tp, self._fp, self._fn]) | |||
| tps, fps, fns = zip(*ls) | |||
| _tp, _fp, _fn = Counter(), Counter(), Counter() | |||
| for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): | |||
| for _c in cs: | |||
| c.update(_c) | |||
| else: | |||
| _tp, _fp, _fn = self._tp, self._fp, self._tp | |||
| ls = self.all_gather_object([self._tp, self._fp, self._fn]) | |||
| tps, fps, fns = zip(*ls) | |||
| _tp, _fp, _fn = Counter(), Counter(), Counter() | |||
| for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): | |||
| for _c in cs: | |||
| c.update(_c) | |||
| if not self.only_gross or self.f_type == 'macro': | |||
| tags = set(_fn.keys()) | |||
| @@ -4,7 +4,7 @@ __all__ = [ | |||
| from abc import abstractmethod | |||
| from typing import Union | |||
| from typing import Union, List | |||
| import functools | |||
| from contextlib import contextmanager | |||
| import numpy as np | |||
| @@ -180,3 +180,15 @@ class Metric: | |||
| """ | |||
| for element in self.elements.values(): | |||
| element.to(device) | |||
| def all_gather_object(self, obj, group=None)->List: | |||
| """ | |||
| 给定 obj 将各个 rank 上的 obj 汇总到每个 obj 上。返回一个 list 对象,里面依次为各个 rank 对应的 obj 。 | |||
| :param obj: 需要汇总的对象,必须是个 pickable 的对象。 | |||
| :param group: | |||
| :return: -> List[obj0, obj1, ...] 其中 obj0 是rank 0 上的 obj;obj1 是 rank 1 上的 obj... | |||
| """ | |||
| if self.aggregate_when_get_metric: | |||
| return self.backend.all_gather_object(obj, group=group) | |||
| return [obj] | |||
| @@ -264,15 +264,12 @@ class SpanFPreRecMetric(Metric): | |||
| evaluate_result = {} | |||
| # 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 | |||
| if self.aggregate_when_get_metric: | |||
| ls = self.backend.all_gather_object([self._tp, self._fp, self._fn]) | |||
| tps, fps, fns = zip(*ls) | |||
| _tp, _fp, _fn = Counter(), Counter(), Counter() | |||
| for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): | |||
| for _c in cs: | |||
| c.update(_c) | |||
| else: | |||
| _tp, _fp, _fn = self._tp, self._fp, self._tp | |||
| ls = self.all_gather_object([self._tp, self._fp, self._fn]) | |||
| tps, fps, fns = zip(*ls) | |||
| _tp, _fp, _fn = Counter(), Counter(), Counter() | |||
| for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): | |||
| for _c in cs: | |||
| c.update(_c) | |||
| if not self.only_gross or self.f_type == 'macro': | |||
| tags = set(_fn.keys()) | |||
| @@ -13,9 +13,8 @@ import numpy as np | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.log import logger | |||
| from .utils import create_array, NumConsumedSamplesArray | |||
| from .utils import create_array | |||
| from abc import abstractmethod | |||
| from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||
| class ReproducibleBatchSampler: | |||
| @@ -37,9 +36,6 @@ class ReproducibleBatchSampler: | |||
| @abstractmethod | |||
| def state_dict(self): | |||
| """ | |||
| 由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomBatchSampler 中 states 里面 num_consumed_samples_array 的实现 | |||
| 正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward | |||
| 了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||
| :return: | |||
| """ | |||
| @@ -57,6 +53,14 @@ class ReproducibleBatchSampler: | |||
| def batch_idx_in_epoch(self): | |||
| raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") | |||
| @property | |||
| def num_replicas(self): | |||
| return self._num_replicas | |||
| @num_replicas.setter | |||
| def num_replicas(self, value): | |||
| self._num_replicas = value | |||
| class RandomBatchSampler(ReproducibleBatchSampler): | |||
| # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | |||
| @@ -105,24 +109,21 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| else: | |||
| index_list = self.index_list | |||
| # 记住每个 batch 对应的 consumed_samples, 需要这个原因是由于现在的 dataloader 都存在预取数据的设计,需要再结合Trainer中 | |||
| # 暂时弃用。记住每个 batch 对应的 consumed_samples, 需要这个原因是由于现在的 dataloader 都存在预取数据的设计,需要再结合Trainer中 | |||
| # batch_idx_in_epoch 才能最终确定实际消耗的数据。这个变量需要记录每次yield出去时的真实 num_consumed_samples 的数值。 | |||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||
| num_consumed_samples=self.num_consumed_samples) | |||
| # self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||
| # num_consumed_samples=self.num_consumed_samples) | |||
| for idx in index_list: | |||
| batch.append(idx) | |||
| if len(batch) == self.batch_size: | |||
| self.num_consumed_samples += self.batch_size # [16, 32, 48, 64,..., ] | |||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
| yield batch | |||
| batch = [] | |||
| if len(batch) > 0 and not self.drop_last: | |||
| self.num_consumed_samples += len(batch) | |||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
| yield batch | |||
| # 需要重置防止边界条件问题 | |||
| self.num_consumed_samples = 0 | |||
| delattr(self, 'num_consumed_samples_array') | |||
| def __len__(self) -> int: | |||
| if self.drop_last: | |||
| @@ -136,7 +137,6 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| "num_consumed_samples": self.num_consumed_samples, | |||
| 'sampler_type': self.__class__.__name__ | |||
| } | |||
| states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||
| return states | |||
| def load_state_dict(self, states: Dict): | |||
| @@ -327,15 +327,14 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | |||
| batches = batches[:-1] | |||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||
| num_consumed_samples=self.num_consumed_samples) | |||
| # 暂时弃用 | |||
| # self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||
| # num_consumed_samples=self.num_consumed_samples) | |||
| for batch in batches: | |||
| self.num_consumed_samples += self.num_replicas * len(batch) | |||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
| yield list(map(int, batch)) | |||
| self.during_iter = False | |||
| self.num_consumed_samples = 0 | |||
| delattr(self, 'num_consumed_samples_array') | |||
| self.old_batch_size = self.batch_size | |||
| self.old_num_batch_per_bucket = self.num_batch_per_bucket | |||
| self.old_num_replicas = self.num_replicas | |||
| @@ -390,8 +389,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
| 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||
| 'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket, | |||
| 'num_replicas': self.num_replicas, | |||
| 'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||
| 'num_replicas': self.num_replicas | |||
| } | |||
| return states | |||
| @@ -7,15 +7,11 @@ __all__ = [ | |||
| from typing import Dict, List, Union | |||
| import math | |||
| import os | |||
| import numpy as np | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||
| from .utils import NumConsumedSamplesArray | |||
| class ReproducibleSampler: | |||
| @@ -36,9 +32,6 @@ class ReproducibleSampler: | |||
| def state_dict(self): | |||
| """ | |||
| 由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomSampler 中 states 里面 num_consumed_samples_array 的实现 | |||
| 正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward | |||
| 了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||
| :return: | |||
| """ | |||
| @@ -54,6 +47,14 @@ class ReproducibleSampler: | |||
| def set_epoch(self, epoch): | |||
| pass | |||
| @property | |||
| def num_repliacs(self): | |||
| return self._num_replicas | |||
| @num_repliacs.setter | |||
| def num_repliacs(self, value): | |||
| self._num_replicas = value | |||
| class RandomSampler(ReproducibleSampler): | |||
| def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | |||
| @@ -121,15 +122,11 @@ class RandomSampler(ReproducibleSampler): | |||
| indices = indices[self.num_consumed_samples:] | |||
| indices = indices[self.rank:len(indices):self.num_replicas] | |||
| assert len(indices) == self.num_left_samples | |||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||
| num_consumed_samples=self.num_consumed_samples) | |||
| for idx, index in enumerate(indices, start=1): | |||
| self.num_consumed_samples += self.num_replicas | |||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
| yield index | |||
| self.during_iter = False | |||
| self.num_consumed_samples = 0 | |||
| delattr(self, 'num_consumed_samples_array') | |||
| def generate_indices(self) -> List[int]: | |||
| """ | |||
| @@ -150,8 +147,7 @@ class RandomSampler(ReproducibleSampler): | |||
| def state_dict(self) -> Dict: | |||
| states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||
| 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||
| 'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||
| 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle} | |||
| return states | |||
| def load_state_dict(self, states: Dict): | |||
| @@ -255,15 +251,11 @@ class SequentialSampler(RandomSampler): | |||
| indices = indices[self.rank:len(indices):self.num_replicas] | |||
| assert len(indices) == self.num_left_samples | |||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||
| num_consumed_samples=self.num_consumed_samples) | |||
| for idx, index in enumerate(indices, start=1): | |||
| self.num_consumed_samples += self.num_replicas | |||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
| yield index | |||
| self.during_iter = False | |||
| self.num_consumed_samples = 0 | |||
| delattr(self, 'num_consumed_samples_array') | |||
| def generate_indices(self) -> List[int]: | |||
| """ | |||
| @@ -275,8 +267,8 @@ class SequentialSampler(RandomSampler): | |||
| def state_dict(self) -> Dict: | |||
| states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | |||
| 'length': len(self.dataset), | |||
| 'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||
| 'length': len(self.dataset) | |||
| } | |||
| return states | |||
| def load_state_dict(self, states: Dict): | |||
| @@ -346,13 +338,9 @@ class SortedSampler(SequentialSampler): | |||
| indices = indices[self.rank:len(indices):self.num_replicas] | |||
| assert len(indices) == self.num_left_samples | |||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||
| num_consumed_samples=self.num_consumed_samples) | |||
| for idx, index in enumerate(indices, start=1): | |||
| self.num_consumed_samples += self.num_replicas | |||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||
| yield index | |||
| self.during_iter = False | |||
| self.num_consumed_samples = 0 | |||
| delattr(self, 'num_consumed_samples_array') | |||
| @@ -43,6 +43,8 @@ class NumConsumedSamplesArray: | |||
| array[9] # 输出为9,表示这个位置真实的 num_consumed_samples 是多少。 | |||
| array[6] # 报错,因为只保留了3个最近的数据,6超过了最大buffer的记录了,即 [7, 8, 9] | |||
| 暂时由于 sampler 的 batch 都是规整的,先保留 | |||
| :param buffer_size: 报错多少个历史。 | |||
| :param num_consumed_samples: 第一个 num_consumed_samples 是多少。 | |||
| """ | |||
| @@ -45,9 +45,6 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' | |||
| # todo 注释 | |||
| FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | |||
| # fastNLP 中初始化deque的默认大小 | |||
| FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE' | |||
| # fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行; | |||
| # 为 '2' 表示 barrier 与 gather/broadcast 都关闭。 | |||
| FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC' | |||
| @@ -445,8 +445,7 @@ class TestBucketedBatchSampler: | |||
| @pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||
| def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||
| """ | |||
| 测试是否能够正确地恢复使用过的(forward)数据,由于 DataLoader 存在预取,所以 Sampler 自身的 num_consumed_samples 可能 | |||
| 偏多 | |||
| 测试是否能够正确地恢复使用过的(forward)数据 | |||
| :return: | |||
| """ | |||
| @@ -454,6 +453,7 @@ class TestBucketedBatchSampler: | |||
| num_batch_per_bucket = 10 | |||
| dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
| samplers = [] | |||
| num_consumed_samples_array = list(range(0, num_samples+num_replicas, num_replicas)) | |||
| for i in range(num_replicas): | |||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | |||
| @@ -472,8 +472,7 @@ class TestBucketedBatchSampler: | |||
| break | |||
| states = samplers[0].state_dict() | |||
| for i in range(len(already_seen_sets)): | |||
| if states['num_consumed_samples_array'] is not None: | |||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||
| states['num_consumed_samples'] = num_consumed_samples_array[i] | |||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1, | |||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||
| drop_last=drop_last) | |||
| @@ -489,8 +488,7 @@ class TestBucketedBatchSampler: | |||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||
| drop_last=drop_last) | |||
| sampler.set_epoch(0) | |||
| if states['num_consumed_samples_array'] is not None: | |||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][2] | |||
| states['num_consumed_samples'] = num_consumed_samples_array[2] | |||
| if len(already_seen_sets)<3: | |||
| return | |||
| already_seen_set = already_seen_sets[2] | |||
| @@ -502,8 +500,8 @@ class TestBucketedBatchSampler: | |||
| break | |||
| states = sampler.state_dict() | |||
| if states['num_consumed_samples_array'] is not None: | |||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||
| num_consumed_samples_array = list(range(len(dataset))) | |||
| states['num_consumed_samples'] = num_consumed_samples_array[count] | |||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2, | |||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||
| drop_last=drop_last) | |||
| @@ -186,9 +186,11 @@ class TestRandomSamplerYh: | |||
| @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||
| @pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||
| def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas): | |||
| # def test_num_consumed_samples_array(self, shuffle=True, pad=True, num_samples=100, num_replicas=2): | |||
| # 测试在 sampler 多生成的时候,可以仍然可以恢复 | |||
| dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
| samplers = [] | |||
| num_consumed_samples_array = list(range(0, len(dataset)+num_replicas, num_replicas)) | |||
| for i in range(num_replicas): | |||
| sampler = RandomSampler(dataset, shuffle=shuffle) | |||
| sampler.set_epoch(0) | |||
| @@ -205,8 +207,7 @@ class TestRandomSamplerYh: | |||
| break | |||
| states = samplers[0].state_dict() | |||
| for i in range(len(already_seen_sets)): | |||
| if states['num_consumed_samples_array'] is not None: | |||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||
| states['num_consumed_samples'] = num_consumed_samples_array[i] | |||
| sampler = RandomSampler(dataset, shuffle=shuffle) | |||
| already_seen_set = deepcopy(already_seen_sets[i]) | |||
| for batch in sampler: | |||
| @@ -215,12 +216,11 @@ class TestRandomSamplerYh: | |||
| # 测试保存之后再次保存 | |||
| sampler = RandomSampler(dataset, shuffle=shuffle) | |||
| sampler.set_epoch(0) | |||
| if states['num_consumed_samples_array'] is not None: | |||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][2] | |||
| if len(already_seen_sets)<3: | |||
| return | |||
| already_seen_set = already_seen_sets[2] | |||
| count = 0 | |||
| num_consumed_samples_array = list(range(0, num_samples)) | |||
| for idx in sampler: | |||
| already_seen_set.add(idx) | |||
| count += 1 | |||
| @@ -228,8 +228,7 @@ class TestRandomSamplerYh: | |||
| break | |||
| states = sampler.state_dict() | |||
| if states['num_consumed_samples_array'] is not None: | |||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||
| states['num_consumed_samples'] = num_consumed_samples_array[count] | |||
| sampler = RandomSampler(dataset, shuffle=shuffle) | |||
| sampler.load_state_dict(states) | |||
| sampler.set_epoch(0) | |||
| @@ -446,12 +445,12 @@ class TestSortedSampler: | |||
| @pytest.mark.parametrize('pad', [True, False]) | |||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
| def test_multi(self, pad, num_replica, num_of_data): | |||
| def test_multi(self, pad, num_replicas, num_of_data): | |||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
| samplers = [] | |||
| for i in range(num_replica): | |||
| for i in range(num_replicas): | |||
| sampler = SortedSampler(dataset=data, length=data.data) | |||
| sampler.set_distributed(num_replica, rank=i, pad=pad) | |||
| sampler.set_distributed(num_replicas, rank=i, pad=pad) | |||
| samplers.append(sampler) | |||
| # 保证顺序是没乱的 | |||
| @@ -600,12 +599,12 @@ class TestSequentialSampler: | |||
| @pytest.mark.parametrize('pad', [True, False]) | |||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
| def test_multi(self, pad, num_replica, num_of_data): | |||
| def test_multi(self, pad, num_replicas, num_of_data): | |||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
| samplers = [] | |||
| for i in range(num_replica): | |||
| for i in range(num_replicas): | |||
| sampler = SequentialSampler(dataset=data) | |||
| sampler.set_distributed(num_replica, rank=i, pad=pad) | |||
| sampler.set_distributed(num_replicas, rank=i, pad=pad) | |||
| samplers.append(sampler) | |||
| # 保证顺序是没乱的 | |||