@@ -67,6 +67,7 @@ __all__ = [ | |||||
"BucketSampler", | "BucketSampler", | ||||
"RandomSampler", | "RandomSampler", | ||||
"SortedSampler", | "SortedSampler", | ||||
"ConstantTokenNumSampler", | |||||
"LossFunc", | "LossFunc", | ||||
"CrossEntropyLoss", | "CrossEntropyLoss", | ||||
@@ -84,7 +84,8 @@ __all__ = [ | |||||
"BucketSampler", | "BucketSampler", | ||||
"RandomSampler", | "RandomSampler", | ||||
"Sampler", | "Sampler", | ||||
"SortedSampler" | |||||
"SortedSampler", | |||||
"ConstantTokenNumSampler" | |||||
] | ] | ||||
from ._logger import logger, init_logger_dist | from ._logger import logger, init_logger_dist | ||||
@@ -101,7 +102,7 @@ from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \ | |||||
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\ | from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\ | ||||
ConfusionMatrixMetric | ConfusionMatrixMetric | ||||
from .optimizer import Optimizer, SGD, Adam, AdamW | from .optimizer import Optimizer, SGD, Adam, AdamW | ||||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler | |||||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler, ConstantTokenNumSampler | |||||
from .tester import Tester | from .tester import Tester | ||||
from .trainer import Trainer | from .trainer import Trainer | ||||
from .utils import cache_results, seq_len_to_mask, get_seq_len | from .utils import cache_results, seq_len_to_mask, get_seq_len | ||||
@@ -6,7 +6,8 @@ __all__ = [ | |||||
"BucketSampler", | "BucketSampler", | ||||
"SequentialSampler", | "SequentialSampler", | ||||
"RandomSampler", | "RandomSampler", | ||||
"SortedSampler" | |||||
"SortedSampler", | |||||
"ConstantTokenNumSampler" | |||||
] | ] | ||||
from itertools import chain | from itertools import chain | ||||
@@ -111,6 +112,108 @@ class BucketSampler(Sampler): | |||||
return list(chain(*batchs)) | return list(chain(*batchs)) | ||||
class ConstantTokenNumSampler:
    """
    Batch sampler that tries to keep the token count of every batch close to ``max_token``.

    The cost of a batch is estimated as ``longest sample in the batch * number of samples``.
    Samples are sorted by length, optionally split into buckets, shuffled inside each
    bucket, and then greedily packed into batches.

    Example:

        >>> # assume tr_data exists and has a field named seq_len holding each instance's token count
        >>> from fastNLP import DataSetIter, Trainer
        >>> sampler = ConstantTokenNumSampler(tr_data.get_field('seq_len').content, max_token=4096)
        >>> tr_iter = DataSetIter(tr_data,
        >>>                      batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False,
        >>>                      drop_last=False, timeout=0, worker_init_fn=None,
        >>>                      batch_sampler=sampler)
        >>>
        >>> # pass tr_iter straight to Trainer; the batch_size argument is ignored in this case
        >>> trainer = Trainer(tr_iter, model, optimizer=optimizer, loss=TranslationLoss(),
        >>>                   batch_size=1, sampler=None, drop_last=False, update_every=1)
    """

    def __init__(self, seq_len, max_token=4096, max_sentence=-1, need_be_multiple_of=1, num_bucket=-1):
        """
        :param List[int] seq_len: length of every sample, usually obtained via
            dataset.get_field('seq_len').content
        :param int max_token: maximum number of tokens per batch
        :param int max_sentence: maximum number of instances per batch; -1 means only max_token limits it
        :param int need_be_multiple_of: the number of instances in every batch must be a multiple of
            this value (useful with DataParallel)
        :param int num_bucket: split the length-sorted data into buckets of
            ``len(seq_len) // num_bucket`` samples and combine samples inside a bucket where
            possible, which reduces padding. Values < 1 put everything into one bucket.
        """
        assert (max_sentence != -1 and max_sentence >= need_be_multiple_of) or max_sentence < 1, \
            "max_sentence should be no smaller than need_be_multiple_of."
        assert len(seq_len) > num_bucket, "The number of samples should be larger than buckets."
        self.seq_len = seq_len
        self.max_token = max_token
        self._max_sentence = max_sentence
        self.need_be_multiple_of = need_be_multiple_of

        # (length, original index) pairs, ordered by length so neighbours have similar lengths.
        seq_len_indice = [(length, i) for i, length in enumerate(seq_len)]
        seq_len_indice.sort(key=lambda x: x[0])

        if num_bucket > 0:
            sample_per_bucket = len(seq_len_indice) // num_bucket
            # Step through the sorted list bucket by bucket; a shorter trailing bucket holds
            # the remainder. No empty buckets are created.
            indice_in_buckets = [seq_len_indice[start:start + sample_per_bucket]
                                 for start in range(0, len(seq_len_indice), sample_per_bucket)]
        else:
            indice_in_buckets = [seq_len_indice]
        self.indice_in_buckets = indice_in_buckets

        self.get_new_order()

    @property
    def max_sentence(self):
        """Effective per-batch instance limit; a value < 1 is treated as 100000000."""
        if self._max_sentence < 1:
            return 100000000
        return self._max_sentence

    @max_sentence.setter
    def max_sentence(self, max_sentence):
        self._max_sentence = max_sentence

    def get_new_order(self):
        """
        Reshuffle the buckets and rebuild ``self.batches`` (a list of lists of sample indices).

        :raises RuntimeError: when a batch has to be closed but contains fewer than
            ``need_be_multiple_of`` samples.
        """
        np.random.shuffle(self.indice_in_buckets)
        for bucket in self.indice_in_buckets:
            np.random.shuffle(bucket)
        indices = list(chain(*self.indice_in_buckets))

        batches = []
        cur_max_len = 0
        pending = []  # (length, index) pairs of the batch currently being filled
        for length, i in indices:
            max_len = max(length, cur_max_len)
            if max_len * (len(pending) + 1) > self.max_token or len(pending) >= self.max_sentence:
                # Close the current batch, keeping only a multiple of need_be_multiple_of samples.
                left_sample = len(pending) % self.need_be_multiple_of
                cut = len(pending) - left_sample
                add_samples = [index for _, index in pending[:cut]]
                if len(add_samples) == 0:
                    raise RuntimeError(f"The sample `{i}` is too long to make a batch with {self.need_be_multiple_of} samples.")
                batches.append(add_samples)
                # The remainder starts the next batch; its longest *length* (not its largest
                # index) must be taken into account together with the incoming sample.
                pending = pending[cut:]
                cur_max_len = max([length] + [kept_len for kept_len, _ in pending])
            else:
                cur_max_len = max_len
            pending.append((length, i))

        if pending:
            # Flush the tail; samples beyond the last full multiple are dropped for this epoch.
            cut = len(pending) - len(pending) % self.need_be_multiple_of
            add_samples = [index for _, index in pending[:cut]]
            if add_samples:
                batches.append(add_samples)

        np.random.shuffle(batches)
        self.batches = batches

    def __iter__(self):
        """Yield every batch once, then draw a fresh order for the next pass."""
        for batch in self.batches:
            yield batch
        self.get_new_order()

    def __len__(self):
        """Number of batches in the current order."""
        return len(self.batches)
class SortedSampler(Sampler): | class SortedSampler(Sampler): | ||||
r""" | r""" | ||||
按照sample的长度进行排序,主要在测试的时候使用,可以加速测试(因为减少了padding) | 按照sample的长度进行排序,主要在测试的时候使用,可以加速测试(因为减少了padding) | ||||
@@ -12,7 +12,8 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False): | |||||
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | :param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | ||||
:param torch.FloatTensor transitions: n_tags x n_tags,[i, j]位置的值认为是从tag i到tag j的转换; 或者(n_tags+2) x | :param torch.FloatTensor transitions: n_tags x n_tags,[i, j]位置的值认为是从tag i到tag j的转换; 或者(n_tags+2) x | ||||
(n_tags+2), 其中n_tag是start的index, n_tags+1是end的index; | |||||
(n_tags+2), 其中n_tag是start的index, n_tags+1是end的index; 如果要i->j之间不允许越迁,就把transitions中(i,j)设置为很小的 | |||||
负数,例如-10000000.0 | |||||
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | :param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | ||||
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 | :param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 | ||||
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 | List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 | ||||
@@ -6,7 +6,7 @@ import torch | |||||
from fastNLP import DataSetIter, TorchLoaderIter | from fastNLP import DataSetIter, TorchLoaderIter | ||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP import Instance | from fastNLP import Instance | ||||
from fastNLP import SequentialSampler | |||||
from fastNLP import SequentialSampler, ConstantTokenNumSampler | |||||
from fastNLP import ConcatCollateFn | from fastNLP import ConcatCollateFn | ||||
@@ -397,6 +397,57 @@ class TestCase1(unittest.TestCase): | |||||
for idx, (batch_x, batch_y) in enumerate(data_iter): | for idx, (batch_x, batch_y) in enumerate(data_iter): | ||||
self.assertEqual(num_samples[idx], len(batch_x['x'])) | self.assertEqual(num_samples[idx], len(batch_x['x'])) | ||||
def test_ConstantTokenNumSampler(self):
    # Drives ConstantTokenNumSampler through DataSetIter under four configurations.
    num_samples = 100
    ds = generate_fake_dataset(num_samples)
    ds.set_input('1')
    ds.add_seq_len('1', 'seq_len')
    ds.set_input('seq_len')

    def build_iter(**sampler_kwargs):
        # Every scenario shares max_token=120 and identical DataSetIter settings.
        token_sampler = ConstantTokenNumSampler(ds.get_field('seq_len'), max_token=120, **sampler_kwargs)
        return DataSetIter(ds, batch_size=10, sampler=token_sampler, as_numpy=False, num_workers=0,
                           pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None,
                           batch_sampler=token_sampler)

    # token budget is respected
    seen = 0
    for batch_x, batch_y in build_iter():
        self.assertTrue(sum(batch_x['seq_len'])<120)
        seen += len(batch_x['seq_len'])
    self.assertEqual(seen, num_samples)

    # sentence limit is respected
    seen = 0
    for batch_x, batch_y in build_iter(max_sentence=1):
        seen += len(batch_x['seq_len'])
        self.assertTrue(sum(batch_x['seq_len'])<120 and len(batch_x['seq_len'])==1)
    self.assertEqual(seen, num_samples)

    # need_be_multiple_of
    seen = 0
    for batch_x, batch_y in build_iter(max_sentence=2, need_be_multiple_of=2):
        seen += len(batch_x['seq_len'])
        self.assertTrue(sum(batch_x['seq_len'])<120 and len(batch_x['seq_len'])==2)
    self.assertEqual(seen, num_samples)

    # token budget is respected when samples are grouped into buckets
    seen = 0
    for batch_x, batch_y in build_iter(num_bucket=10):
        seen += len(batch_x['seq_len'])
        self.assertTrue(sum(batch_x['seq_len'])<120)
    self.assertEqual(seen, num_samples)
""" | """ | ||||
def test_multi_workers_batch(self): | def test_multi_workers_batch(self): | ||||
batch_size = 32 | batch_size = 32 | ||||