diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py index 67ec2a8d..4a523f10 100644 --- a/fastNLP/core/sampler.py +++ b/fastNLP/core/sampler.py @@ -73,6 +73,7 @@ class BucketSampler(BaseSampler): total_sample_num = len(seq_lens) bucket_indexes = [] + assert total_sample_num>=self.num_buckets, "The number of samples is smaller than the number of buckets." num_sample_per_bucket = total_sample_num // self.num_buckets for i in range(self.num_buckets): bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)])