|
|
@@ -19,7 +19,7 @@ from abc import abstractmethod |
|
|
|
|
|
|
|
class ReproducibleBatchSampler: |
|
|
|
def __init__(self, **kwargs): |
|
|
|
pass |
|
|
|
self.num_replicas = 1 |
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
def set_distributed(self, num_replicas, rank, pad=True): |
|
|
@@ -53,14 +53,6 @@ class ReproducibleBatchSampler: |
|
|
|
def batch_idx_in_epoch(self):
    """Placeholder that concrete batch samplers are expected to override.

    Raises:
        NotImplementedError: always, when called on this implementation.
    """
    message = (
        "Each specific batch_sampler should implement its own "
        "`batch_idx_in_epoch` property."
    )
    raise NotImplementedError(message)
|
|
|
|
|
|
|
@property
def num_replicas(self):
    """Return the value stored in ``self._num_replicas``."""
    return self._num_replicas


@num_replicas.setter
def num_replicas(self, value):
    """Store ``value`` in ``self._num_replicas`` without validation."""
    self._num_replicas = value
|
|
|
|
|
|
|
|
|
|
|
class RandomBatchSampler(ReproducibleBatchSampler): |
|
|
|
# The values of these two parameters should be obtained by the driver's get_dataloader_args function; |
|
|
@@ -322,7 +314,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): |
|
|
|
if len(batches[-1])==0: |
|
|
|
batches.pop(-1) |
|
|
|
|
|
|
|
assert len(list(chain(*batches))) == self.num_left_samples |
|
|
|
assert sum(map(len, batches)) == self.num_left_samples |
|
|
|
|
|
|
|
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: |
|
|
|
batches = batches[:-1] |
|
|
|