From 3dbb3677f0d839f918001f528a8410f88c150401 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 16 Apr 2022 08:39:07 +0000 Subject: [PATCH] =?UTF-8?q?=E5=BE=AE=E8=B0=83=20reproducible=20sampler=20?= =?UTF-8?q?=E7=9A=84=E5=88=9D=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/samplers/reproducible_batch_sampler.py | 12 ++---------- fastNLP/core/samplers/reproducible_sampler.py | 10 ++-------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index be43bc74..171a784b 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -19,7 +19,7 @@ from abc import abstractmethod class ReproducibleBatchSampler: def __init__(self, **kwargs): - pass + self.num_replicas = 1 @abstractmethod def set_distributed(self, num_replicas, rank, pad=True): @@ -53,14 +53,6 @@ class ReproducibleBatchSampler: def batch_idx_in_epoch(self): raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") - @property - def num_replicas(self): - return self._num_replicas - - @num_replicas.setter - def num_replicas(self, value): - self._num_replicas = value - class RandomBatchSampler(ReproducibleBatchSampler): # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; @@ -322,7 +314,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): if len(batches[-1])==0: batches.pop(-1) - assert len(list(chain(*batches))) == self.num_left_samples + assert sum(map(len, batches)) == self.num_left_samples if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: batches = batches[:-1] diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index c3facbb9..c8425dc7 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -20,6 +20,8 @@ class ReproducibleSampler: 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 """ + def __init__(self, **kwargs): + self.num_replicas = 1 def set_distributed(self, num_replicas, rank, pad=True): raise NotImplementedError("Each specific sampler should implement its own `set_distributed` method.") @@ -47,14 +49,6 @@ class ReproducibleSampler: def set_epoch(self, epoch): pass - @property - def num_repliacs(self): - return self._num_replicas - - @num_repliacs.setter - def num_repliacs(self, value): - self._num_replicas = value - class RandomSampler(ReproducibleSampler): def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):