@@ -67,6 +67,7 @@ __all__ = [ | |||
"BucketSampler", | |||
"RandomSampler", | |||
"SortedSampler", | |||
"ConstantTokenNumSampler", | |||
"LossFunc", | |||
"CrossEntropyLoss", | |||
@@ -84,7 +84,8 @@ __all__ = [ | |||
"BucketSampler", | |||
"RandomSampler", | |||
"Sampler", | |||
"SortedSampler" | |||
"SortedSampler", | |||
"ConstantTokenNumSampler" | |||
] | |||
from ._logger import logger, init_logger_dist | |||
@@ -101,7 +102,7 @@ from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \ | |||
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\ | |||
ConfusionMatrixMetric | |||
from .optimizer import Optimizer, SGD, Adam, AdamW | |||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler | |||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler, ConstantTokenNumSampler | |||
from .tester import Tester | |||
from .trainer import Trainer | |||
from .utils import cache_results, seq_len_to_mask, get_seq_len | |||
@@ -6,7 +6,8 @@ __all__ = [ | |||
"BucketSampler", | |||
"SequentialSampler", | |||
"RandomSampler", | |||
"SortedSampler" | |||
"SortedSampler", | |||
"ConstantTokenNumSampler" | |||
] | |||
from itertools import chain | |||
@@ -111,6 +112,108 @@ class BucketSampler(Sampler): | |||
return list(chain(*batchs)) | |||
class ConstantTokenNumSampler: | |||
""" | |||
尽量保证每个batch的输入token数量是接近的。 | |||
使用示例 | |||
>>> # 假设已经有了tr_data并有一个field叫做seq_len保存了每个instance的token数量 | |||
>>> from fastNLP import DataSetIter, Trainer | |||
>>> sampler = BatchSampler(tr_data.get_field('seq_len').content, max_token=4096) | |||
>>> tr_iter = DataSetIter(tr_data, | |||
>>> batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False, | |||
>>> drop_last=False, timeout=0, worker_init_fn=None, | |||
>>> batch_sampler=sampler) | |||
>>> | |||
>>> # 直接将tr_iter传入Trainer中,此时batch_size参数的值会被忽略 | |||
>>> trainer = Trainer(tr_iter, model, optimizer=optimizer, loss=TranslationLoss(), | |||
>>> batch_size=1, sampler=None, drop_last=False, update_every=1) | |||
""" | |||
def __init__(self, seq_len, max_token=4096, max_sentence=-1, need_be_multiple_of=1, num_bucket=-1): | |||
""" | |||
:param List[int] seq_len: list[int], 是每个sample的长度。一般可以通过dataset.get_field('seq_len').content传入 | |||
:param int max_token: 每个batch的最大的token数量 | |||
:param int max_sentence: 每个batch最多多少个instance, -1表示根据max_token决定 | |||
:param int need_be_multiple_of: 生成的batch的instance的数量需要是几的倍数,在DataParallel场景下会用到 | |||
:param int num_bucket: 将数据按长度拆分为num_bucket个bucket,batch中的sample尽量在bucket之中进行组合,这样可以减少padding。 | |||
""" | |||
assert (max_sentence!=-1 and max_sentence>=need_be_multiple_of) or max_sentence<1 | |||
assert len(seq_len)>num_bucket, "The number of samples should be larger than buckets." | |||
self.seq_len = seq_len | |||
self.max_token = max_token | |||
self._max_sentence = max_sentence | |||
self.need_be_multiple_of = need_be_multiple_of | |||
seq_len_indice = [(length, i) for i, length in enumerate(seq_len)] | |||
seq_len_indice.sort(key=lambda x: x[0]) | |||
indice_in_buckets = [] | |||
if num_bucket>0: | |||
sample_per_bucket = len(seq_len_indice)//num_bucket | |||
i = 0 | |||
while len(indice_in_buckets)<len(seq_len_indice): | |||
indice_in_buckets.append(seq_len_indice[i*sample_per_bucket:(i+1)*sample_per_bucket]) | |||
i += 1 | |||
else: | |||
indice_in_buckets = [seq_len_indice] | |||
self.indice_in_buckets = indice_in_buckets | |||
self.get_new_order() | |||
@property | |||
def max_sentence(self): | |||
if self._max_sentence<1: | |||
return 100000000 | |||
return self._max_sentence | |||
@max_sentence.setter | |||
def max_sentence(self, max_sentence): | |||
self._max_sentence = max_sentence | |||
def get_new_order(self): | |||
np.random.shuffle(self.indice_in_buckets) | |||
for bucket in self.indice_in_buckets: | |||
np.random.shuffle(bucket) | |||
indices = list(chain(*self.indice_in_buckets)) | |||
batches = [] | |||
cur_max_len = 0 | |||
batch = [] | |||
for length, i in indices: | |||
max_len = max(length, cur_max_len) | |||
if max_len*(len(batch)+1)>self.max_token or len(batch)>=self.max_sentence: | |||
left_sample = len(batch) % self.need_be_multiple_of | |||
add_samples = batch.copy() | |||
cur_max_len =length | |||
if left_sample!=0: | |||
add_samples = add_samples[:-left_sample] | |||
batch = batch[-left_sample:] | |||
cur_max_len = max(cur_max_len, max(batch)) | |||
else: | |||
batch = [] | |||
if len(add_samples)==0: | |||
raise RuntimeError(f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.") | |||
batches.append(add_samples) | |||
else: | |||
cur_max_len = max_len | |||
batch.append(i) | |||
if batch: | |||
left_sample = len(batch) % self.need_be_multiple_of | |||
add_samples = batch.copy() | |||
if left_sample != 0: | |||
add_samples = add_samples[:-left_sample].copy() | |||
if add_samples: | |||
batches.append(add_samples) | |||
np.random.shuffle(batches) | |||
self.batches = batches | |||
def __iter__(self): | |||
for batch in self.batches: | |||
yield batch | |||
self.get_new_order() | |||
def __len__(self): | |||
return len(self.batches) | |||
class SortedSampler(Sampler): | |||
r""" | |||
按照sample的长度进行排序,主要在测试的时候使用,可以加速测试(因为减少了padding) | |||
@@ -12,7 +12,8 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False): | |||
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | |||
:param torch.FloatTensor transitions: n_tags x n_tags,[i, j]位置的值认为是从tag i到tag j的转换; 或者(n_tags+2) x | |||
(n_tags+2), 其中n_tag是start的index, n_tags+1是end的index; | |||
(n_tags+2), 其中n_tag是start的index, n_tags+1是end的index; 如果要i->j之间不允许越迁,就把transitions中(i,j)设置为很小的 | |||
负数,例如-10000000.0 | |||
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 | |||
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 | |||
@@ -6,7 +6,7 @@ import torch | |||
from fastNLP import DataSetIter, TorchLoaderIter | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import SequentialSampler | |||
from fastNLP import SequentialSampler, ConstantTokenNumSampler | |||
from fastNLP import ConcatCollateFn | |||
@@ -397,6 +397,57 @@ class TestCase1(unittest.TestCase): | |||
for idx, (batch_x, batch_y) in enumerate(data_iter): | |||
self.assertEqual(num_samples[idx], len(batch_x['x'])) | |||
def test_ConstantTokenNumSampler(self): | |||
num_samples = 100 | |||
ds = generate_fake_dataset(num_samples) | |||
ds.set_input('1') | |||
ds.add_seq_len('1', 'seq_len') | |||
ds.set_input('seq_len') | |||
# 测试token数量不超过 | |||
batch_sampler = ConstantTokenNumSampler(ds.get_field('seq_len'), max_token=120) | |||
data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, num_workers=0, | |||
pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, | |||
batch_sampler=batch_sampler) | |||
sample_count = 0 | |||
for batch_x, batch_y in data_iter: | |||
self.assertTrue(sum(batch_x['seq_len'])<120) | |||
sample_count += len(batch_x['seq_len']) | |||
self.assertEqual(sample_count, num_samples) | |||
# 测试句子数量不超过 | |||
batch_sampler = ConstantTokenNumSampler(ds.get_field('seq_len'), max_token=120, max_sentence=1) | |||
data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, num_workers=0, | |||
pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, | |||
batch_sampler=batch_sampler) | |||
sample_count = 0 | |||
for batch_x, batch_y in data_iter: | |||
sample_count += len(batch_x['seq_len']) | |||
self.assertTrue(sum(batch_x['seq_len'])<120 and len(batch_x['seq_len'])==1) | |||
self.assertEqual(sample_count, num_samples) | |||
# 测试need_be_multiple_of | |||
sample_count = 0 | |||
batch_sampler = ConstantTokenNumSampler(ds.get_field('seq_len'), max_token=120, max_sentence=2, need_be_multiple_of=2) | |||
data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, num_workers=0, | |||
pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, | |||
batch_sampler=batch_sampler) | |||
for batch_x, batch_y in data_iter: | |||
sample_count += len(batch_x['seq_len']) | |||
self.assertTrue(sum(batch_x['seq_len'])<120 and len(batch_x['seq_len'])==2) | |||
self.assertEqual(sample_count, num_samples) | |||
# 测试token数量不超过, bucket尽量接近 | |||
batch_sampler = ConstantTokenNumSampler(ds.get_field('seq_len'), max_token=120, num_bucket=10) | |||
data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, num_workers=0, | |||
pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, | |||
batch_sampler=batch_sampler) | |||
sample_count = 0 | |||
for batch_x, batch_y in data_iter: | |||
sample_count += len(batch_x['seq_len']) | |||
self.assertTrue(sum(batch_x['seq_len'])<120) | |||
self.assertEqual(sample_count, num_samples) | |||
""" | |||
def test_multi_workers_batch(self): | |||
batch_size = 32 | |||