From d813f31f9ab7b1ea93ab6867ced1288ded05207c Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Sat, 16 Apr 2022 15:22:44 +0800
Subject: [PATCH] 1. Remove num_consumed_samples_array from the samplers and
 derive the value from batch_size and related fields instead. 2. Add an
 all_gather_object interface to Metric to make evaluation easier.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../core/drivers/torch_driver/torch_driver.py  | 18 +++++++---
 .../metrics/classify_f1_pre_rec_metric.py      | 15 ++++----
 fastNLP/core/metrics/metric.py                 | 14 +++++++-
 .../core/metrics/span_f1_pre_rec_metric.py     | 15 ++++----
 .../samplers/reproducible_batch_sampler.py     | 35 +++++++++----------
 fastNLP/core/samplers/reproducible_sampler.py  | 34 ++++++------------
 fastNLP/core/samplers/utils.py                 |  2 ++
 fastNLP/envs/env.py                            |  3 --
 .../test_reproducible_batch_sampler.py         | 14 ++++----
 .../samplers/test_reproducible_sampler.py      | 23 ++++++------
 10 files changed, 86 insertions(+), 87 deletions(-)

diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index 2a04e62f..5638b4c6 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -188,17 +188,27 @@ class TorchDriver(Driver):
         num_consumed_batches = states.pop('num_consumed_batches')
         if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
             sampler_states = sampler.state_dict()
-            # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples
-            # 会造成多余实际消耗的问题。
+            # 需要针对 num_consumed_samples 做特殊的处理。因为 DataLoader 存在预取行为,直接使用 sampler 中的 num_consumed_samples
+            # 会造成多于实际消耗的问题。
             num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)
             if num_consumed_samples_array is not None:
                 if isinstance(sampler, ReproducibleSampler):  # 如果是 sampler 的话,需要考虑 batch_size 。
-                    try:
+                    if dataloader_args.batch_size is not None:
                         num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
-                    except:  # 有可能 batch_size 为 None,就只有损失精度了
+                    else:  # 有可能 batch_size 为 None,就只有损失精度了
+                        logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                       "it may cause missing some samples when reload.")
                         num_consumed_batches = sampler_states['num_consumed_samples']
                 sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
                 assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report."
+            else:
+                if dataloader_args.batch_size is not None:
+                    sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
+                                                             * num_consumed_batches
+                else:
+                    logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
+                                   "it may cause missing some samples when reload.")
+
             states['sampler_states'] = sampler_states
         else:
             raise RuntimeError(
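
The torch_driver hunk above stops looking values up in num_consumed_samples_array and instead
recomputes the consumed-sample count from the number of batches the Trainer actually forwarded.
A minimal sketch of the arithmetic in the new else branch, assuming a fixed batch size (the
helper name estimate_consumed_samples is illustrative only, not a fastNLP API):

    def estimate_consumed_samples(num_consumed_batches, batch_size, num_replicas, fallback):
        # DataLoader prefetching lets the sampler's own counter run ahead of what the model
        # really consumed, so recompute from the truly consumed batch count instead.
        if batch_size is not None:
            return num_replicas * batch_size * num_consumed_batches
        # batch_size unknown: fall back to the sampler's counter and accept the over-count.
        return fallback

    # e.g. 2 replicas, batch_size=16, 10 batches actually forwarded -> 320 samples
    assert estimate_consumed_samples(10, 16, 2, fallback=345) == 320
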
diff --git a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py
index 87b022c9..2c71602d 100644
--- a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py
+++ b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py
@@ -63,15 +63,12 @@ class ClassifyFPreRecMetric(Metric):
         evaluate_result = {}
 
         # 通过 all_gather_object 将各个卡上的结果收集过来,并加和。
-        if self.aggregate_when_get_metric:
-            ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
-            tps, fps, fns = zip(*ls)
-            _tp, _fp, _fn = Counter(), Counter(), Counter()
-            for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
-                for _c in cs:
-                    c.update(_c)
-        else:
-            _tp, _fp, _fn = self._tp, self._fp, self._tp
+        ls = self.all_gather_object([self._tp, self._fp, self._fn])
+        tps, fps, fns = zip(*ls)
+        _tp, _fp, _fn = Counter(), Counter(), Counter()
+        for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
+            for _c in cs:
+                c.update(_c)
 
         if not self.only_gross or self.f_type == 'macro':
             tags = set(_fn.keys())
diff --git a/fastNLP/core/metrics/metric.py b/fastNLP/core/metrics/metric.py
index ef4839df..b5fc44dd 100644
--- a/fastNLP/core/metrics/metric.py
+++ b/fastNLP/core/metrics/metric.py
@@ -4,7 +4,7 @@ __all__ = [
 
 from abc import abstractmethod
 
-from typing import Union
+from typing import Union, List
 import functools
 from contextlib import contextmanager
 import numpy as np
@@ -180,3 +180,15 @@ class Metric:
         """
         for element in self.elements.values():
             element.to(device)
+
+    def all_gather_object(self, obj, group=None)->List:
+        """
+        给定 obj 将各个 rank 上的 obj 汇总到每个 obj 上。返回一个 list 对象,里面依次为各个 rank 对应的 obj 。
+
+        :param obj: 需要汇总的对象,必须是个 picklable 的对象。
+        :param group: 在哪个 group 下进行汇总,默认为 None 。
+        :return: -> List[obj0, obj1, ...] 其中 obj0 是 rank 0 上的 obj;obj1 是 rank 1 上的 obj...
+        """
+        if self.aggregate_when_get_metric:
+            return self.backend.all_gather_object(obj, group=group)
+        return [obj]
\ No newline at end of file
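
The two F1 metrics now call self.all_gather_object(...) unconditionally; as the new Metric helper
above shows, it returns the per-rank list when aggregate_when_get_metric is True and just [obj]
otherwise, so the same merge loop also works on a single process. A small sketch of that merge,
with plain Counters standing in for the metric state (merge_counters is illustrative, not fastNLP
code):

    from collections import Counter

    def merge_counters(gathered):
        # `gathered` mimics all_gather_object([tp, fp, fn]): one [tp, fp, fn] triple per rank.
        tp, fp, fn = Counter(), Counter(), Counter()
        for rank_tp, rank_fp, rank_fn in gathered:
            tp.update(rank_tp)
            fp.update(rank_fp)
            fn.update(rank_fn)
        return tp, fp, fn

    # Single-process case: the helper returns a one-element list holding the local counters.
    local = [Counter({'PER': 3}), Counter({'PER': 1}), Counter({'PER': 2})]
    tp, fp, fn = merge_counters([local])
    assert (tp['PER'], fp['PER'], fn['PER']) == (3, 1, 2)
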
diff --git a/fastNLP/core/metrics/span_f1_pre_rec_metric.py b/fastNLP/core/metrics/span_f1_pre_rec_metric.py
index d847da41..12d86a31 100644
--- a/fastNLP/core/metrics/span_f1_pre_rec_metric.py
+++ b/fastNLP/core/metrics/span_f1_pre_rec_metric.py
@@ -264,15 +264,12 @@ class SpanFPreRecMetric(Metric):
         evaluate_result = {}
 
         # 通过 all_gather_object 将各个卡上的结果收集过来,并加和。
-        if self.aggregate_when_get_metric:
-            ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
-            tps, fps, fns = zip(*ls)
-            _tp, _fp, _fn = Counter(), Counter(), Counter()
-            for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
-                for _c in cs:
-                    c.update(_c)
-        else:
-            _tp, _fp, _fn = self._tp, self._fp, self._tp
+        ls = self.all_gather_object([self._tp, self._fp, self._fn])
+        tps, fps, fns = zip(*ls)
+        _tp, _fp, _fn = Counter(), Counter(), Counter()
+        for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
+            for _c in cs:
+                c.update(_c)
 
         if not self.only_gross or self.f_type == 'macro':
             tags = set(_fn.keys())
diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py
index 1bb0a628..be43bc74 100644
--- a/fastNLP/core/samplers/reproducible_batch_sampler.py
+++ b/fastNLP/core/samplers/reproducible_batch_sampler.py
@@ -13,9 +13,8 @@ import numpy as np
 
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.log import logger
-from .utils import create_array, NumConsumedSamplesArray
+from .utils import create_array
 from abc import abstractmethod
-from fastNLP.envs.env import FASTNLP_DEQUE_SIZE
 
 
 class ReproducibleBatchSampler:
@@ -37,9 +36,6 @@ class ReproducibleBatchSampler:
     @abstractmethod
     def state_dict(self):
         """
-        由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomBatchSampler 中 states 里面 num_consumed_samples_array 的实现
-        正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward
-        了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。
 
         :return:
         """
@@ -57,6 +53,14 @@ class ReproducibleBatchSampler:
     def batch_idx_in_epoch(self):
         raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
 
+    @property
+    def num_replicas(self):
+        return self._num_replicas
+
+    @num_replicas.setter
+    def num_replicas(self, value):
+        self._num_replicas = value
+
 
 class RandomBatchSampler(ReproducibleBatchSampler):
     # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿;
@@ -105,24 +109,21 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         else:
             index_list = self.index_list
 
-        # 记住每个 batch 对应的 consumed_samples, 需要这个原因是由于现在的 dataloader 都存在预取数据的设计,需要再结合Trainer中
+        # 暂时弃用。记住每个 batch 对应的 consumed_samples, 需要这个原因是由于现在的 dataloader 都存在预取数据的设计,需要再结合Trainer中
         # batch_idx_in_epoch 才能最终确定实际消耗的数据。这个变量需要记录每次yield出去时的真实 num_consumed_samples 的数值。
-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
-                                                                  num_consumed_samples=self.num_consumed_samples)
+        # self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
+        #                                                           num_consumed_samples=self.num_consumed_samples)
         for idx in index_list:
             batch.append(idx)
             if len(batch) == self.batch_size:
                 self.num_consumed_samples += self.batch_size  # [16, 32, 48, 64,..., ]
-                self.num_consumed_samples_array.push(self.num_consumed_samples)
                 yield batch
                 batch = []
         if len(batch) > 0 and not self.drop_last:
             self.num_consumed_samples += len(batch)
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield batch
         # 需要重置防止边界条件问题
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')
 
     def __len__(self) -> int:
         if self.drop_last:
@@ -136,7 +137,6 @@ class RandomBatchSampler(ReproducibleBatchSampler):
             "num_consumed_samples": self.num_consumed_samples,
             'sampler_type': self.__class__.__name__
         }
-        states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None)
         return states
 
     def load_state_dict(self, states: Dict):
@@ -327,15 +327,14 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
             batches = batches[:-1]
 
-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
-                                                                  num_consumed_samples=self.num_consumed_samples)
+        # 暂时弃用
+        # self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30),
+        #                                                           num_consumed_samples=self.num_consumed_samples)
         for batch in batches:
             self.num_consumed_samples += self.num_replicas * len(batch)
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield list(map(int, batch))
         self.during_iter = False
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')
         self.old_batch_size = self.batch_size
         self.old_num_batch_per_bucket = self.num_batch_per_bucket
         self.old_num_replicas = self.num_replicas
@@ -390,8 +389,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
                   'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
                   'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket,
-                  'num_replicas': self.num_replicas,
-                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
+                  'num_replicas': self.num_replicas
+                  }
         return states
 
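
With the per-batch bookkeeping array gone, resuming relies only on the scalar num_consumed_samples
that the driver now recomputes. The samplers in the following diff use it directly: they drop that
many indices and then stride over the rest by num_replicas. A small sketch of that resume logic
(remaining_indices is illustrative; a plain list stands in for the shuffled index list):

    def remaining_indices(all_indices, num_consumed_samples, rank, num_replicas):
        rest = all_indices[num_consumed_samples:]    # skip what was already forwarded
        return rest[rank:len(rest):num_replicas]     # this rank's strided share

    indices = list(range(10))
    # resume after 4 consumed samples with 2 replicas
    assert remaining_indices(indices, 4, rank=0, num_replicas=2) == [4, 6, 8]
    assert remaining_indices(indices, 4, rank=1, num_replicas=2) == [5, 7, 9]
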
diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py
index 43017098..c3facbb9 100644
--- a/fastNLP/core/samplers/reproducible_sampler.py
+++ b/fastNLP/core/samplers/reproducible_sampler.py
@@ -7,15 +7,11 @@ __all__ = [
 
 from typing import Dict, List, Union
 import math
-import os
 
 import numpy as np
 
 from fastNLP.core.log import logger
 from fastNLP.core.dataset import DataSet
-from fastNLP.envs.env import FASTNLP_DEQUE_SIZE
-from .utils import NumConsumedSamplesArray
-
 
 class ReproducibleSampler:
@@ -36,9 +32,6 @@ class ReproducibleSampler:
 
     def state_dict(self):
         """
-        由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomSampler 中 states 里面 num_consumed_samples_array 的实现
-        正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward
-        了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。
 
         :return:
         """
@@ -54,6 +47,14 @@ class ReproducibleSampler:
     def set_epoch(self, epoch):
         pass
 
+    @property
+    def num_replicas(self):
+        return self._num_replicas
+
+    @num_replicas.setter
+    def num_replicas(self, value):
+        self._num_replicas = value
+
 
 class RandomSampler(ReproducibleSampler):
     def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
@@ -121,15 +122,11 @@ class RandomSampler(ReproducibleSampler):
             indices = indices[self.num_consumed_samples:]
             indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples
-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
-                                                                  num_consumed_samples=self.num_consumed_samples)
         for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')
 
     def generate_indices(self) -> List[int]:
         """
@@ -150,8 +147,7 @@ class RandomSampler(ReproducibleSampler):
 
     def state_dict(self) -> Dict:
         states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
-                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
-                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
+                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle}
         return states
 
     def load_state_dict(self, states: Dict):
@@ -255,15 +251,11 @@ class SequentialSampler(RandomSampler):
             indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples
 
-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
-                                                                  num_consumed_samples=self.num_consumed_samples)
         for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')
 
     def generate_indices(self) -> List[int]:
         """
@@ -275,8 +267,8 @@ class SequentialSampler(RandomSampler):
 
     def state_dict(self) -> Dict:
         states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__,
-                  'length': len(self.dataset),
-                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}
+                  'length': len(self.dataset)
+                  }
         return states
 
     def load_state_dict(self, states: Dict):
@@ -346,13 +338,9 @@ class SortedSampler(SequentialSampler):
             indices = indices[self.rank:len(indices):self.num_replicas]
         assert len(indices) == self.num_left_samples
 
-        self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000),
-                                                                  num_consumed_samples=self.num_consumed_samples)
         for idx, index in enumerate(indices, start=1):
             self.num_consumed_samples += self.num_replicas
-            self.num_consumed_samples_array.push(self.num_consumed_samples)
             yield index
         self.during_iter = False
         self.num_consumed_samples = 0
-        delattr(self, 'num_consumed_samples_array')
diff --git a/fastNLP/core/samplers/utils.py b/fastNLP/core/samplers/utils.py
index ddcff37f..92514c3a 100644
--- a/fastNLP/core/samplers/utils.py
+++ b/fastNLP/core/samplers/utils.py
@@ -43,6 +43,8 @@ class NumConsumedSamplesArray:
         array[9]  # 输出为9,表示这个位置真实的 num_consumed_samples 是多少。
         array[6]  # 报错,因为只保留了3个最近的数据,6超过了最大buffer的记录了,即 [7, 8, 9]
 
+    暂时由于 sampler 的 batch 都是规整的,先保留
+
     :param buffer_size: 报错多少个历史。
     :param num_consumed_samples: 第一个 num_consumed_samples 是多少。
     """
diff --git a/fastNLP/envs/env.py b/fastNLP/envs/env.py
index a943de1f..3d1fd738 100644
--- a/fastNLP/envs/env.py
+++ b/fastNLP/envs/env.py
@@ -45,9 +45,6 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK'
 # todo 注释
 FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH"
 
-# fastNLP 中初始化deque的默认大小
-FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE'
-
 # fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行;
 # 为 '2' 表示 barrier 与 gather/broadcast 都关闭。
 FASTNLP_NO_SYNC = 'FASTNLP_NO_SYNC'
diff --git a/tests/core/samplers/test_reproducible_batch_sampler.py b/tests/core/samplers/test_reproducible_batch_sampler.py
index 5af971a0..3514c331 100644
--- a/tests/core/samplers/test_reproducible_batch_sampler.py
+++ b/tests/core/samplers/test_reproducible_batch_sampler.py
@@ -445,8 +445,7 @@ class TestBucketedBatchSampler:
     @pytest.mark.parametrize('num_replicas', [1, 2, 3])
     def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
         """
-        测试是否能够正确地恢复使用过的(forward)数据,由于 DataLoader 存在预取,所以 Sampler 自身的 num_consumed_samples 可能
-        偏多
+        测试是否能够正确地恢复使用过的(forward)数据
 
         :return:
         """
@@ -454,6 +453,7 @@ class TestBucketedBatchSampler:
         num_batch_per_bucket = 10
         dataset = DatasetWithVaryLength(num_of_data=num_samples)
         samplers = []
+        num_consumed_samples_array = list(range(0, num_samples+num_replicas, num_replicas))
         for i in range(num_replicas):
             sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                            num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)
@@ -472,8 +472,7 @@ class TestBucketedBatchSampler:
                 break
         states = samplers[0].state_dict()
         for i in range(len(already_seen_sets)):
-            if states['num_consumed_samples_array'] is not None:
-                states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
+            states['num_consumed_samples'] = num_consumed_samples_array[i]
             sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
                                            num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
                                            drop_last=drop_last)
@@ -489,8 +488,7 @@ class TestBucketedBatchSampler:
                                        num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
                                        drop_last=drop_last)
         sampler.set_epoch(0)
-        if states['num_consumed_samples_array'] is not None:
-            states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
+        states['num_consumed_samples'] = num_consumed_samples_array[2]
         if len(already_seen_sets)<3:
             return
         already_seen_set = already_seen_sets[2]
@@ -502,8 +500,8 @@ class TestBucketedBatchSampler:
                 break
 
         states = sampler.state_dict()
-        if states['num_consumed_samples_array'] is not None:
-            states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
+        num_consumed_samples_array = list(range(len(dataset)))
+        states['num_consumed_samples'] = num_consumed_samples_array[count]
         sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2,
                                        num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle,
                                        drop_last=drop_last)
diff --git a/tests/core/samplers/test_reproducible_sampler.py b/tests/core/samplers/test_reproducible_sampler.py
index ddf52bcb..6c0ba7a8 100644
--- a/tests/core/samplers/test_reproducible_sampler.py
+++ b/tests/core/samplers/test_reproducible_sampler.py
@@ -186,9 +186,11 @@ class TestRandomSamplerYh:
     @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
     @pytest.mark.parametrize('num_replicas', [1, 2, 3])
     def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas):
+    # def test_num_consumed_samples_array(self, shuffle=True, pad=True, num_samples=100, num_replicas=2):
         # 测试在 sampler 多生成的时候,可以仍然可以恢复
         dataset = DatasetWithVaryLength(num_of_data=num_samples)
         samplers = []
+        num_consumed_samples_array = list(range(0, len(dataset)+num_replicas, num_replicas))
         for i in range(num_replicas):
             sampler = RandomSampler(dataset, shuffle=shuffle)
             sampler.set_epoch(0)
@@ -205,8 +207,7 @@ class TestRandomSamplerYh:
                 break
         states = samplers[0].state_dict()
         for i in range(len(already_seen_sets)):
-            if states['num_consumed_samples_array'] is not None:
-                states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
+            states['num_consumed_samples'] = num_consumed_samples_array[i]
             sampler = RandomSampler(dataset, shuffle=shuffle)
             already_seen_set = deepcopy(already_seen_sets[i])
             for batch in sampler:
@@ -215,12 +216,11 @@ class TestRandomSamplerYh:
         # 测试保存之后再次保存
         sampler = RandomSampler(dataset, shuffle=shuffle)
         sampler.set_epoch(0)
-        if states['num_consumed_samples_array'] is not None:
-            states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
         if len(already_seen_sets)<3:
             return
         already_seen_set = already_seen_sets[2]
         count = 0
+        num_consumed_samples_array = list(range(0, num_samples))
         for idx in sampler:
             already_seen_set.add(idx)
             count += 1
@@ -228,8 +228,7 @@ class TestRandomSamplerYh:
                 break
 
         states = sampler.state_dict()
-        if states['num_consumed_samples_array'] is not None:
-            states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
+        states['num_consumed_samples'] = num_consumed_samples_array[count]
         sampler = RandomSampler(dataset, shuffle=shuffle)
         sampler.load_state_dict(states)
         sampler.set_epoch(0)
@@ -446,12 +445,12 @@ class TestSortedSampler:
     @pytest.mark.parametrize('pad', [True, False])
     @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
-    def test_multi(self, pad, num_replica, num_of_data):
+    def test_multi(self, pad, num_replicas, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
         samplers = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = SortedSampler(dataset=data, length=data.data)
-            sampler.set_distributed(num_replica, rank=i, pad=pad)
+            sampler.set_distributed(num_replicas, rank=i, pad=pad)
             samplers.append(sampler)
 
         # 保证顺序是没乱的
@@ -600,12 +599,12 @@ class TestSequentialSampler:
     @pytest.mark.parametrize('pad', [True, False])
     @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
-    def test_multi(self, pad, num_replica, num_of_data):
+    def test_multi(self, pad, num_replicas, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
         samplers = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = SequentialSampler(dataset=data)
-            sampler.set_distributed(num_replica, rank=i, pad=pad)
+            sampler.set_distributed(num_replicas, rank=i, pad=pad)
             samplers.append(sampler)
 
         # 保证顺序是没乱的
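
With num_consumed_samples_array removed from state_dict(), the tests above enumerate the possible
counter values themselves: each index yielded per rank advances the global counter by num_replicas,
which is exactly what list(range(0, num_samples + num_replicas, num_replicas)) spells out. A short
check of that bookkeeping (the values are chosen only for illustration):

    num_samples, num_replicas = 10, 2
    num_consumed_samples_array = list(range(0, num_samples + num_replicas, num_replicas))
    assert num_consumed_samples_array == [0, 2, 4, 6, 8, 10]
    # after 3 indices have been yielded on each rank, the global counter is 3 * num_replicas
    assert num_consumed_samples_array[3] == 3 * num_replicas
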