
Add ConstantTokenNumSampler, which keeps the total number of tokens in each batch roughly constant so that GPU memory can be used to the fullest (a simplified sketch of the batching rule follows the change list below).

tags/v0.6.0
yh_cc 4 years ago
commit edd3022bde
5 changed files with 162 additions and 5 deletions
  1. fastNLP/__init__.py (+1, -0)
  2. fastNLP/core/__init__.py (+3, -2)
  3. fastNLP/core/sampler.py (+104, -1)
  4. fastNLP/modules/decoder/utils.py (+2, -1)
  5. test/core/test_batch.py (+52, -1)
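
A simplified sketch of the batching rule the new sampler applies (an illustration only, not the library code: the real ConstantTokenNumSampler additionally handles buckets, shuffling, max_sentence and need_be_multiple_of, and the names token_budget_batches and seq_lens below are made up for this example). Indices are sorted by length, and a new batch is cut whenever padding the next sample in would push max_len_in_batch * batch_size over the token budget:

def token_budget_batches(seq_lens, max_token=4096):
    # Sort sample indices by length so that samples of similar length end up
    # in the same batch and padding is minimized.
    order = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i])
    batches, batch, cur_max_len = [], [], 0
    for i in order:
        new_max = max(cur_max_len, seq_lens[i])
        # After padding, the batch would occupy new_max * (len(batch) + 1) tokens.
        if new_max * (len(batch) + 1) > max_token and batch:
            batches.append(batch)
            batch = []
            new_max = seq_lens[i]
        batch.append(i)
        cur_max_len = new_max
    if batch:
        batches.append(batch)
    return batches

# Example: with max_token=10, lengths [2, 3, 3, 5, 9] are grouped as
# [[0, 1, 2], [3], [4]], so the padded size max_len * n_samples never exceeds 10.
print(token_budget_batches([2, 3, 3, 5, 9], max_token=10))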

fastNLP/__init__.py (+1, -0)

@@ -67,6 +67,7 @@ __all__ = [
"BucketSampler", "BucketSampler",
"RandomSampler", "RandomSampler",
"SortedSampler", "SortedSampler",
"ConstantTokenNumSampler",
"LossFunc", "LossFunc",
"CrossEntropyLoss", "CrossEntropyLoss",


fastNLP/core/__init__.py (+3, -2)

@@ -84,7 +84,8 @@ __all__ = [
"BucketSampler", "BucketSampler",
"RandomSampler", "RandomSampler",
"Sampler", "Sampler",
"SortedSampler"
"SortedSampler",
"ConstantTokenNumSampler"
] ]


from ._logger import logger, init_logger_dist
@@ -101,7 +102,7 @@ from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\
    ConfusionMatrixMetric
from .optimizer import Optimizer, SGD, Adam, AdamW
- from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler
+ from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler, ConstantTokenNumSampler
from .tester import Tester
from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len


fastNLP/core/sampler.py (+104, -1)

@@ -6,7 +6,8 @@ __all__ = [
"BucketSampler", "BucketSampler",
"SequentialSampler", "SequentialSampler",
"RandomSampler", "RandomSampler",
"SortedSampler"
"SortedSampler",
"ConstantTokenNumSampler"
] ]


from itertools import chain
@@ -111,6 +112,108 @@ class BucketSampler(Sampler):
        return list(chain(*batchs))




class ConstantTokenNumSampler:
    """
    Tries to keep the number of input tokens in each batch roughly constant.

    Usage example:
        >>> # Assume tr_data already exists and has a field named seq_len that stores the token count of each instance
        >>> from fastNLP import DataSetIter, Trainer
        >>> sampler = ConstantTokenNumSampler(tr_data.get_field('seq_len').content, max_token=4096)
        >>> tr_iter = DataSetIter(tr_data,
        >>>                       batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False,
        >>>                       drop_last=False, timeout=0, worker_init_fn=None,
        >>>                       batch_sampler=sampler)
        >>>
        >>> # Pass tr_iter directly to Trainer; the batch_size argument is ignored in this case
        >>> trainer = Trainer(tr_iter, model, optimizer=optimizer, loss=TranslationLoss(),
        >>>                   batch_size=1, sampler=None, drop_last=False, update_every=1)

    """
    def __init__(self, seq_len, max_token=4096, max_sentence=-1, need_be_multiple_of=1, num_bucket=-1):
        """

        :param List[int] seq_len: list[int], the length of each sample. Usually obtained via dataset.get_field('seq_len').content
        :param int max_token: maximum number of tokens in each batch
        :param int max_sentence: maximum number of instances in each batch; -1 means it is decided by max_token alone
        :param int need_be_multiple_of: the number of instances in each generated batch must be a multiple of this value; useful under DataParallel
        :param int num_bucket: split the data by length into num_bucket buckets; samples in a batch are combined within a bucket as far as possible, which reduces padding.
        """
        assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1
        assert len(seq_len) > num_bucket, "The number of samples should be larger than buckets."
        self.seq_len = seq_len
        self.max_token = max_token
        self._max_sentence = max_sentence
        self.need_be_multiple_of = need_be_multiple_of
        seq_len_indice = [(length, i) for i, length in enumerate(seq_len)]
        seq_len_indice.sort(key=lambda x: x[0])
        indice_in_buckets = []
        if num_bucket > 0:
            sample_per_bucket = len(seq_len_indice) // num_bucket
            i = 0
            while len(indice_in_buckets) < len(seq_len_indice):
                indice_in_buckets.append(seq_len_indice[i * sample_per_bucket:(i + 1) * sample_per_bucket])
                i += 1
        else:
            indice_in_buckets = [seq_len_indice]
        self.indice_in_buckets = indice_in_buckets
        self.get_new_order()

    @property
    def max_sentence(self):
        if self._max_sentence < 1:
            return 100000000
        return self._max_sentence

    @max_sentence.setter
    def max_sentence(self, max_sentence):
        self._max_sentence = max_sentence

    def get_new_order(self):
        np.random.shuffle(self.indice_in_buckets)
        for bucket in self.indice_in_buckets:
            np.random.shuffle(bucket)
        indices = list(chain(*self.indice_in_buckets))
        batches = []
        cur_max_len = 0
        batch = []
        for length, i in indices:
            max_len = max(length, cur_max_len)
            if max_len * (len(batch) + 1) > self.max_token or len(batch) >= self.max_sentence:
                left_sample = len(batch) % self.need_be_multiple_of
                add_samples = batch.copy()
                cur_max_len = length
                if left_sample != 0:
                    add_samples = add_samples[:-left_sample]
                    batch = batch[-left_sample:]
                    # the leftover samples are carried over into the next batch, so its
                    # running max length must cover the longest of those samples
                    cur_max_len = max(cur_max_len, max(self.seq_len[j] for j in batch))
                else:
                    batch = []
                if len(add_samples) == 0:
                    raise RuntimeError(
                        f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.")
                batches.append(add_samples)
            else:
                cur_max_len = max_len
            batch.append(i)
        if batch:
            left_sample = len(batch) % self.need_be_multiple_of
            add_samples = batch.copy()
            if left_sample != 0:
                add_samples = add_samples[:-left_sample].copy()
            if add_samples:
                batches.append(add_samples)
        np.random.shuffle(batches)
        self.batches = batches

    def __iter__(self):
        for batch in self.batches:
            yield batch
        self.get_new_order()

    def __len__(self):
        return len(self.batches)


class SortedSampler(Sampler):
    r"""
    Sorts samples by length; mainly used at evaluation time, where it speeds things up by reducing padding.


fastNLP/modules/decoder/utils.py (+2, -1)

@@ -12,7 +12,8 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):


    :param torch.FloatTensor logits: batch_size x max_len x num_tags, the feature matrix.
    :param torch.FloatTensor transitions: n_tags x n_tags, where the value at [i, j] is taken as the transition score from tag i to tag j; or (n_tags+2) x
-       (n_tags+2), where n_tags is the index of start and n_tags+1 is the index of end;
+       (n_tags+2), where n_tags is the index of start and n_tags+1 is the index of end. To forbid the transition i->j, set position (i, j) of
+       transitions to a very small (large negative) number, e.g. -10000000.0.
    :param torch.ByteTensor mask: batch_size x max_len, positions with value 0 are treated as pad; if None, the input is assumed to have no padding.
    :param bool unpad: whether to remove padding from the result. If False, a batch_size x max_len tensor is returned; if True, a
        List[List[int]] is returned, where each inner List[int] holds the labels of one sequence with the pad part removed, i.e. the length of each List[int] is that
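
A small, hedged illustration of the docstring note above (the tensor shapes follow the docstring; n_tags, the random tensors and the forbidden tag pair 0 -> 2 are made-up values for this sketch): to rule out a particular tag transition, overwrite that entry of transitions with a very large negative score before calling viterbi_decode.

import torch
from fastNLP.modules.decoder.utils import viterbi_decode

n_tags = 4
logits = torch.randn(2, 6, n_tags)           # batch_size x max_len x num_tags
transitions = torch.randn(n_tags, n_tags)    # [i, j] = score of moving from tag i to tag j

# Forbid the transition 0 -> 2 by giving it a very small (large negative) score,
# as the updated docstring suggests.
transitions[0, 2] = -10000000.0

result = viterbi_decode(logits, transitions, mask=None, unpad=True)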


test/core/test_batch.py (+52, -1)

@@ -6,7 +6,7 @@ import torch
from fastNLP import DataSetIter, TorchLoaderIter
from fastNLP import DataSet
from fastNLP import Instance
- from fastNLP import SequentialSampler
+ from fastNLP import SequentialSampler, ConstantTokenNumSampler
from fastNLP import ConcatCollateFn




@@ -397,6 +397,57 @@ class TestCase1(unittest.TestCase):
        for idx, (batch_x, batch_y) in enumerate(data_iter):
            self.assertEqual(num_samples[idx], len(batch_x['x']))


    def test_ConstantTokenNumSampler(self):
        num_samples = 100
        ds = generate_fake_dataset(num_samples)
        ds.set_input('1')
        ds.add_seq_len('1', 'seq_len')
        ds.set_input('seq_len')

        # check that the token count per batch is not exceeded
        batch_sampler = ConstantTokenNumSampler(ds.get_field('seq_len'), max_token=120)
        data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, num_workers=0,
                                pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
                                batch_sampler=batch_sampler)
        sample_count = 0
        for batch_x, batch_y in data_iter:
            self.assertTrue(sum(batch_x['seq_len']) < 120)
            sample_count += len(batch_x['seq_len'])
        self.assertEqual(sample_count, num_samples)

        # check that the sentence count per batch is not exceeded
        batch_sampler = ConstantTokenNumSampler(ds.get_field('seq_len'), max_token=120, max_sentence=1)
        data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, num_workers=0,
                                pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
                                batch_sampler=batch_sampler)
        sample_count = 0
        for batch_x, batch_y in data_iter:
            sample_count += len(batch_x['seq_len'])
            self.assertTrue(sum(batch_x['seq_len']) < 120 and len(batch_x['seq_len']) == 1)
        self.assertEqual(sample_count, num_samples)

        # check need_be_multiple_of
        sample_count = 0
        batch_sampler = ConstantTokenNumSampler(ds.get_field('seq_len'), max_token=120, max_sentence=2, need_be_multiple_of=2)
        data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, num_workers=0,
                                pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
                                batch_sampler=batch_sampler)
        for batch_x, batch_y in data_iter:
            sample_count += len(batch_x['seq_len'])
            self.assertTrue(sum(batch_x['seq_len']) < 120 and len(batch_x['seq_len']) == 2)
        self.assertEqual(sample_count, num_samples)

        # check that the token count is not exceeded while buckets keep lengths close
        batch_sampler = ConstantTokenNumSampler(ds.get_field('seq_len'), max_token=120, num_bucket=10)
        data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, num_workers=0,
                                pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
                                batch_sampler=batch_sampler)
        sample_count = 0
        for batch_x, batch_y in data_iter:
            sample_count += len(batch_x['seq_len'])
            self.assertTrue(sum(batch_x['seq_len']) < 120)
        self.assertEqual(sample_count, num_samples)

""" """
def test_multi_workers_batch(self): def test_multi_workers_batch(self):
batch_size = 32 batch_size = 32

