From 5eb126dbcd300650bd4effccc9061fa67abe2c9c Mon Sep 17 00:00:00 2001 From: yh Date: Sat, 9 Feb 2019 13:47:13 +0800 Subject: [PATCH] =?UTF-8?q?BucketSampler=E5=A2=9E=E5=8A=A0=E4=B8=80?= =?UTF-8?q?=E6=9D=A1=E9=94=99=E8=AF=AF=E6=A3=80=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py index 67ec2a8d..4a523f10 100644 --- a/fastNLP/core/sampler.py +++ b/fastNLP/core/sampler.py @@ -73,6 +73,7 @@ class BucketSampler(BaseSampler): total_sample_num = len(seq_lens) bucket_indexes = [] + assert total_sample_num>=self.num_buckets, "The number of samples is smaller than the number of buckets." num_sample_per_bucket = total_sample_num // self.num_buckets for i in range(self.num_buckets): bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)])