From edd3022bde518a0025c1e4461e7e53f26fe6dd8f Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 28 Sep 2020 10:21:39 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9EConstantTokenNumSampler?= =?UTF-8?q?=E4=BD=BF=E5=BE=97=E6=AF=8F=E4=B8=AAsample=E4=B8=ADtoken?= =?UTF-8?q?=E6=95=B0=E9=87=8F=E4=B8=80=E8=87=B4=EF=BC=8C=E5=8F=AF=E4=BB=A5?= =?UTF-8?q?=E6=9C=80=E5=A4=A7=E5=8C=96=E5=88=A9=E7=94=A8GPU=E7=9A=84?= =?UTF-8?q?=E6=98=BE=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/__init__.py | 1 + fastNLP/core/__init__.py | 5 +- fastNLP/core/sampler.py | 105 ++++++++++++++++++++++++++++++- fastNLP/modules/decoder/utils.py | 3 +- test/core/test_batch.py | 53 +++++++++++++++- 5 files changed, 162 insertions(+), 5 deletions(-) diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index 59e4b67c..be6b744e 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -67,6 +67,7 @@ __all__ = [ "BucketSampler", "RandomSampler", "SortedSampler", + "ConstantTokenNumSampler", "LossFunc", "CrossEntropyLoss", diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 629dd786..56808bff 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -84,7 +84,8 @@ __all__ = [ "BucketSampler", "RandomSampler", "Sampler", - "SortedSampler" + "SortedSampler", + "ConstantTokenNumSampler" ] from ._logger import logger, init_logger_dist @@ -101,7 +102,7 @@ from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \ from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\ ConfusionMatrixMetric from .optimizer import Optimizer, SGD, Adam, AdamW -from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler +from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler, ConstantTokenNumSampler from .tester import Tester from .trainer import Trainer from .utils import cache_results, 
class ConstantTokenNumSampler:
    """
    Batch sampler that keeps the (padded) token count of every batch close to a fixed budget.

    Samples are grouped so that ``longest_sample * number_of_samples`` never exceeds
    ``max_token`` when a sample is admitted to a batch, which lets the GPU memory be
    used as fully as possible.

    Example::

        >>> # assume tr_data has a field named seq_len holding the token count of each instance
        >>> from fastNLP import DataSetIter, Trainer
        >>> sampler = ConstantTokenNumSampler(tr_data.get_field('seq_len').content, max_token=4096)
        >>> tr_iter = DataSetIter(tr_data,
        >>>                       batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False,
        >>>                       drop_last=False, timeout=0, worker_init_fn=None,
        >>>                       batch_sampler=sampler)
        >>>
        >>> # pass tr_iter straight to Trainer; the batch_size argument is then ignored
        >>> trainer = Trainer(tr_iter, model, optimizer=optimizer, loss=TranslationLoss(),
        >>>                   batch_size=1, sampler=None, drop_last=False, update_every=1)

    """
    # NOTE(review): in the reviewed copy of this patch the bucket-splitting loop, the
    # ``max_sentence`` property and the first lines of ``get_new_order`` were unreadable
    # (text between a '<' and the next '>' had been stripped). They are re-derived here
    # from the attributes and locals that the readable code uses -- confirm against upstream.

    def __init__(self, seq_len, max_token=4096, max_sentence=-1, need_be_multiple_of=1, num_bucket=-1):
        """

        :param List[int] seq_len: length of every sample; usually ``dataset.get_field('seq_len').content``
        :param int max_token: maximum number of (padded) tokens in one batch
        :param int max_sentence: maximum number of instances in one batch; values below 1 (default -1)
            mean the size is decided by ``max_token`` only
        :param int need_be_multiple_of: the number of instances in every produced batch is a multiple
            of this value (useful with DataParallel)
        :param int num_bucket: split the length-sorted data into ``num_bucket`` buckets and build batches
            from samples of the same bucket where possible, which reduces padding; values below 1
            disable bucketing
        """
        assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1
        assert len(seq_len) > num_bucket, "The number of samples should be larger than buckets."

        self.seq_len = seq_len
        self.max_token = max_token
        self._max_sentence = max_sentence
        self.need_be_multiple_of = need_be_multiple_of

        # (length, index) pairs ordered by length so that a bucket holds similar lengths
        seq_len_indice = [(length, i) for i, length in enumerate(seq_len)]
        seq_len_indice.sort(key=lambda x: x[0])

        if num_bucket > 0:
            sample_per_bucket = len(seq_len_indice) // num_bucket
            indice_in_buckets = [
                seq_len_indice[k * sample_per_bucket:(k + 1) * sample_per_bucket]
                for k in range(num_bucket - 1)
            ]
            # the last bucket also takes the remainder of the integer division
            indice_in_buckets.append(seq_len_indice[(num_bucket - 1) * sample_per_bucket:])
        else:
            indice_in_buckets = [seq_len_indice]
        self.indice_in_buckets = indice_in_buckets

        self.get_new_order()

    @property
    def max_sentence(self):
        """Effective per-batch instance limit; a huge number when no limit was requested."""
        if self._max_sentence < 1:
            return 100000000
        return self._max_sentence

    @max_sentence.setter
    def max_sentence(self, max_sentence):
        self._max_sentence = max_sentence

    def get_new_order(self):
        """
        Re-shuffle the samples and rebuild ``self.batches``.

        Buckets are shuffled among themselves and internally, then samples are packed greedily:
        a batch is closed as soon as admitting the next sample would exceed ``max_token`` padded
        tokens or ``max_sentence`` instances. Only a multiple of ``need_be_multiple_of`` samples
        is emitted per batch; the rest is carried over to the next batch, and whatever cannot
        complete a multiple at the very end is dropped.

        :raises RuntimeError: if a batch has to be closed but holds fewer than
            ``need_be_multiple_of`` samples
        """
        np.random.shuffle(self.indice_in_buckets)
        for bucket in self.indice_in_buckets:
            np.random.shuffle(bucket)
        indices = list(chain(*self.indice_in_buckets))

        batches = []
        batch = []       # sample indices of the batch under construction
        batch_lens = []  # their lengths, kept in the same order as ``batch``
        cur_max_len = 0
        for length, i in indices:
            max_len = max(length, cur_max_len)
            if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence:
                left_sample = len(batch) % self.need_be_multiple_of
                add_samples = batch.copy()
                cur_max_len = length
                if left_sample != 0:
                    add_samples = add_samples[:-left_sample]
                    batch = batch[-left_sample:]
                    batch_lens = batch_lens[-left_sample:]
                    # The carried-over samples count towards the next batch's padded length.
                    # This must look at their lengths, not at their indices.
                    cur_max_len = max(cur_max_len, max(batch_lens))
                else:
                    batch = []
                    batch_lens = []
                if len(add_samples) == 0:
                    raise RuntimeError(f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.")
                batches.append(add_samples)
            else:
                cur_max_len = max_len
            batch.append(i)
            batch_lens.append(length)

        if batch:
            left_sample = len(batch) % self.need_be_multiple_of
            # slicing already produces a new list, no extra copy needed
            add_samples = batch[:-left_sample] if left_sample != 0 else batch
            if add_samples:
                batches.append(add_samples)

        np.random.shuffle(batches)
        self.batches = batches

    def __iter__(self):
        """Yield every batch (a list of sample indices), then prepare a new order for the next pass."""
        for batch in self.batches:
            yield batch
        self.get_new_order()

    def __len__(self):
        """Number of batches in the current order."""
        return len(self.batches)
def test_ConstantTokenNumSampler(self):
    """Check the token budget, sentence limit, batch-size multiple and bucketing of ConstantTokenNumSampler."""
    num_samples = 100
    max_token = 120
    ds = generate_fake_dataset(num_samples)
    ds.set_input('1')
    ds.add_seq_len('1', 'seq_len')
    ds.set_input('seq_len')

    def run_case(expected_batch_size=None, **sampler_kwargs):
        # Build a sampler/iterator pair and walk over one full epoch.
        batch_sampler = ConstantTokenNumSampler(ds.get_field('seq_len'), max_token=max_token, **sampler_kwargs)
        data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, num_workers=0,
                                pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
                                batch_sampler=batch_sampler)
        sample_count = 0
        for batch_x, batch_y in data_iter:
            sample_count += len(batch_x['seq_len'])
            # A batch is allowed to use the budget completely, so the bound is inclusive.
            self.assertLessEqual(int(sum(batch_x['seq_len'])), max_token)
            if expected_batch_size is not None:
                self.assertEqual(len(batch_x['seq_len']), expected_batch_size)
        # every sample must be served exactly once per epoch
        self.assertEqual(sample_count, num_samples)

    # token count stays within the budget
    run_case()

    # sentence count stays within the limit
    run_case(expected_batch_size=1, max_sentence=1)

    # need_be_multiple_of
    run_case(expected_batch_size=2, max_sentence=2, need_be_multiple_of=2)

    # token count stays within the budget while batching inside buckets
    run_case(num_bucket=10)