Browse Source

微调 reproducible sampler 的初始化

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
3dbb3677f0
2 changed files with 4 additions and 18 deletions
  1. +2
    -10
      fastNLP/core/samplers/reproducible_batch_sampler.py
  2. +2
    -8
      fastNLP/core/samplers/reproducible_sampler.py

+ 2
- 10
fastNLP/core/samplers/reproducible_batch_sampler.py View File

@@ -19,7 +19,7 @@ from abc import abstractmethod

class ReproducibleBatchSampler:
def __init__(self, **kwargs):
pass
self.num_replicas = 1

@abstractmethod
def set_distributed(self, num_replicas, rank, pad=True):
@@ -53,14 +53,6 @@ class ReproducibleBatchSampler:
def batch_idx_in_epoch(self):
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")

@property
def num_replicas(self):
return self._num_replicas

@num_replicas.setter
def num_replicas(self, value):
self._num_replicas = value


class RandomBatchSampler(ReproducibleBatchSampler):
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿;
@@ -322,7 +314,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
if len(batches[-1])==0:
batches.pop(-1)

assert len(list(chain(*batches))) == self.num_left_samples
assert sum(map(len, batches)) == self.num_left_samples

if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
batches = batches[:-1]


+ 2
- 8
fastNLP/core/samplers/reproducible_sampler.py View File

@@ -20,6 +20,8 @@ class ReproducibleSampler:
或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。

"""
def __init__(self, **kwargs):
self.num_replicas = 1

def set_distributed(self, num_replicas, rank, pad=True):
raise NotImplementedError("Each specific sampler should implement its own `set_distributed` method.")
@@ -47,14 +49,6 @@ class ReproducibleSampler:
def set_epoch(self, epoch):
pass

@property
def num_repliacs(self):
return self._num_replicas

@num_repliacs.setter
def num_repliacs(self, value):
self._num_replicas = value


class RandomSampler(ReproducibleSampler):
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):


Loading…
Cancel
Save