|
|
@@ -73,6 +73,7 @@ class BucketSampler(BaseSampler): |
|
|
|
total_sample_num = len(seq_lens) |
|
|
|
|
|
|
|
bucket_indexes = [] |
|
|
|
assert total_sample_num>=self.num_buckets, "The number of samples is smaller than the number of buckets." |
|
|
|
num_sample_per_bucket = total_sample_num // self.num_buckets |
|
|
|
for i in range(self.num_buckets): |
|
|
|
bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)]) |
|
|
|