Browse Source

BucketSampler增加一条错误检测

tags/v0.4.10
yh 5 years ago
parent
commit
5eb126dbcd
1 changed files with 1 additions and 0 deletions
  1. +1
    -0
      fastNLP/core/sampler.py

+ 1
- 0
fastNLP/core/sampler.py View File

@@ -73,6 +73,7 @@ class BucketSampler(BaseSampler):
total_sample_num = len(seq_lens)

bucket_indexes = []
assert total_sample_num>=self.num_buckets, "The number of samples is smaller than the number of buckets."
num_sample_per_bucket = total_sample_num // self.num_buckets
for i in range(self.num_buckets):
bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)])


Loading…
Cancel
Save