Browse Source

BucketedBatchSampler的batch_id_in_epoch实现

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
cb01a661f1
1 changed files with 9 additions and 1 deletions
  1. +9
    -1
      fastNLP/core/samplers/reproducible_batch_sampler.py

+ 9
- 1
fastNLP/core/samplers/reproducible_batch_sampler.py View File

@@ -411,4 +411,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
self.old_num_replicas = states['num_replicas']

def set_epoch(self, epoch):
self.epoch = epoch
self.epoch = epoch

@property
def batch_idx_in_epoch(self):
if self.drop_last:
return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size
else:
return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \
(len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size

Loading…
Cancel
Save