Browse Source

修改BucketedBatchSampler batch_idx_in_epoch 的计算方式,使其在分布式条件下可以正确地反应迭代次数

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
d04f49a835
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      fastNLP/core/samplers/reproducible_batch_sampler.py

+ 3
- 3
fastNLP/core/samplers/reproducible_batch_sampler.py View File

@@ -416,7 +416,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
@property @property
def batch_idx_in_epoch(self): def batch_idx_in_epoch(self):
if self.drop_last: if self.drop_last:
return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
else: else:
return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \
(len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
(self.num_left_samples + self.batch_size - 1) // self.batch_size

Loading…
Cancel
Save