
1. Add a SortedSampler, which can be used to speed up evaluation; 2. Fix a bug in BeamSearch; 3. Fix the loop code in Trainer to resolve a problem where evaluation could fire at the wrong time when a BatchSampler is used.

tags/v0.6.0
yh_cc · 5 years ago · commit 206f7758b5
10 changed files with 143 additions and 62 deletions
  1. fastNLP/__init__.py (+1, -0)
  2. fastNLP/core/__init__.py (+2, -1)
  3. fastNLP/core/callback.py (+3, -2)
  4. fastNLP/core/losses.py (+7, -1)
  5. fastNLP/core/sampler.py (+25, -3)
  6. fastNLP/core/tester.py (+12, -2)
  7. fastNLP/core/trainer.py (+13, -4)
  8. fastNLP/models/seq2seq_generator.py (+6, -4)
  9. fastNLP/modules/decoder/seq2seq_decoder.py (+2, -2)
  10. fastNLP/modules/generator/seq2seq_generator.py (+72, -43)

fastNLP/__init__.py (+1, -0)

@@ -66,6 +66,7 @@ __all__ = [
     "SequentialSampler",
     "BucketSampler",
     "RandomSampler",
+    "SortedSampler",
     "LossFunc",
     "CrossEntropyLoss",


fastNLP/core/__init__.py (+2, -1)

@@ -84,6 +84,7 @@ __all__ = [
     "BucketSampler",
     "RandomSampler",
     "Sampler",
+    "SortedSampler"
 ]

 from ._logger import logger, init_logger_dist
@@ -100,7 +101,7 @@ from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \
 from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\
     ConfusionMatrixMetric
 from .optimizer import Optimizer, SGD, Adam, AdamW
-from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
+from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler
 from .tester import Tester
 from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len


fastNLP/core/callback.py (+3, -2)

@@ -521,7 +521,7 @@ class FitlogCallback(Callback):
     the results recorded in fitlog for these datasets come from the third epoch.
     """

-    def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False):
+    def __init__(self, data=None, tester=None, log_loss_every=0, verbose=1, log_exception=False):
         r"""
         :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: pass in a DataSet object; it will be validated with the metrics in the Trainer. If you need
@@ -572,7 +572,8 @@ class FitlogCallback(Callback):
                 batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
                 metrics=self.trainer.metrics,
                 verbose=0,
-                use_tqdm=self.trainer.test_use_tqdm)
+                use_tqdm=self.trainer.test_use_tqdm,
+                sampler=self.trainer.kwargs.get('test_sampler', None))
             self.testers[key] = tester
         fitlog.add_progress(total_steps=self.n_steps)


fastNLP/core/losses.py (+7, -1)

@@ -36,7 +36,7 @@ from ..core.const import Const

 class LossBase(object):
     r"""
-    The base class of all losses. To understand how it works, please read the source code.
+    The base class of all losses. To be used with a Trainer, the get_loss method must be implemented.
     """

     def __init__(self):
@@ -53,6 +53,12 @@ class LossBase(object):
         return self._param_map

     def get_loss(self, *args, **kwargs):
+        """
+
+        :param args:
+        :param kwargs:
+        :return: torch.Tensor
+        """
         raise NotImplementedError

     def _init_param_map(self, key_map=None, **kwargs):
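For illustration, a minimal custom loss under this contract — a sketch only; the class name and the parameter mapping below are hypothetical, not part of this commit:

import torch.nn.functional as F
from fastNLP.core.losses import LossBase

class MyCrossEntropy(LossBase):
    def __init__(self, pred=None, target=None):
        super().__init__()
        # map the names used by the model/dataset onto get_loss's parameter names
        self._init_param_map(pred=pred, target=target)

    def get_loss(self, pred, target):
        # must return a scalar torch.Tensor, as the new docstring states
        return F.cross_entropy(pred, target)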


fastNLP/core/sampler.py (+25, -3)

@@ -5,7 +5,8 @@ __all__ = [
     "Sampler",
     "BucketSampler",
     "SequentialSampler",
-    "RandomSampler"
+    "RandomSampler",
+    "SortedSampler"
 ]

 from itertools import chain
@@ -57,8 +58,8 @@ class BucketSampler(Sampler):
         r"""
         :param int num_buckets: the number of buckets
-        :param int batch_size: the batch size. Defaults to None; when Trainer calls BucketSampler it sets this value
-            correctly, but when used outside a Trainer it must be passed explicitly
+        :param int batch_size: the batch size. Defaults to None; when Trainer/Tester calls BucketSampler it sets this
+            value correctly, but when used outside Trainer/Tester it must be passed explicitly
         :param str seq_len_field_name: name of the `field` holding the sequence lengths
         """
         self.num_buckets = num_buckets
@@ -110,6 +111,27 @@ class BucketSampler(Sampler):
         return list(chain(*batchs))


+class SortedSampler(Sampler):
+    r"""
+    Sorts the samples by length; mainly used at test time, where it can speed up evaluation (by reducing padding)
+    """
+    def __init__(self, seq_len_field_name='seq_len', descending=True):
+        """
+
+        :param str seq_len_field_name: name of the `field` holding the sequence lengths
+        :param bool descending: whether to sort in descending order
+        """
+        self.seq_len_field_name = seq_len_field_name
+        self.descending = descending
+
+    def __call__(self, data_set):
+        seq_lens = data_set.get_field(self.seq_len_field_name).content
+        orders = np.argsort(seq_lens).tolist()  # ascending order
+        if self.descending:
+            orders = orders[::-1]
+        return orders
+
+
 def simple_sort_bucketing(lengths):
     r"""




fastNLP/core/tester.py (+12, -2)

@@ -57,6 +57,7 @@ from ._parallel_utils import _data_parallel_wrapper
 from ._parallel_utils import _model_contains_inner_module
 from functools import partial
 from ._logger import logger
+from .sampler import Sampler

 __all__ = [
     "Tester"
@@ -68,7 +69,8 @@ class Tester(object):
     Tester is the class that runs a performance evaluation given data, a model, and metrics; pass in the model, the data, and the metrics to run validation.
     """

-    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
+    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True,
+                 **kwargs):
         r"""
         :param ~fastNLP.DataSet,~fastNLP.BatchIter data: the dataset to be tested
@@ -91,6 +93,7 @@ class Tester(object):
             If the model runs prediction via predict(), multi-GPU (DataParallel) validation is not possible; only the model on the first GPU is used.
         :param int verbose: if 0, output nothing; if 1, print the validation results.
         :param bool use_tqdm: whether to use tqdm to show test progress; if False, nothing is shown.
+        :param kwargs: supports passing in a sampler to control the evaluation order
         """
         super(Tester, self).__init__()

@@ -107,7 +110,14 @@ class Tester(object):
         self.logger = logger

         if isinstance(data, DataSet):
-            self.data_iterator = DataSetIter(dataset=data, batch_size=batch_size, sampler=SequentialSampler(),
+            sampler = kwargs.get('sampler', None)
+            if sampler is None:
+                sampler = SequentialSampler()
+            elif not isinstance(sampler, (Sampler, torch.utils.data.Sampler)):
+                raise ValueError(f"The type of sampler should be fastNLP.BaseSampler or pytorch's Sampler, got {type(sampler)}")
+            if hasattr(sampler, 'set_batch_size'):
+                sampler.set_batch_size(batch_size)
+            self.data_iterator = DataSetIter(dataset=data, batch_size=batch_size, sampler=sampler,
                                              num_workers=num_workers)
         elif isinstance(data, BatchIter):
             self.data_iterator = data
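With this change the evaluation order can be controlled, e.g. by the new SortedSampler. A sketch only — model, test_data, and the choice of metric are placeholders assumed to exist:

from fastNLP import Tester, AccuracyMetric
from fastNLP.core.sampler import SortedSampler

# long sequences end up in the same batches, so each batch needs less padding
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(),
                batch_size=32, sampler=SortedSampler('seq_len'))
tester.test()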


fastNLP/core/trainer.py (+13, -4)

@@ -422,6 +422,9 @@ class Trainer(object):
        report warnings; 2: raise an error for any unused field. The check works by running the code with a very small batch (2 samples by default);
        in theory this process does not modify any parameters, it only checks that the code runs. However, if (1) the model hard-codes batch_size to some fixed value, or
        (2) the model accumulates a count of forward passes, one extra computation may occur. In those cases, set check_code_level to -1.
+    :param kwargs: supports optional configuration parameters:
+        bool test_use_tqdm: whether to enable tqdm when validating on dev
+        Sampler test_sampler: the sampler used during evaluation
     """
     super(Trainer, self).__init__()
     if not isinstance(model, nn.Module):
@@ -561,7 +564,8 @@ class Trainer(object):
                                  batch_size=kwargs.get("dev_batch_size", self.batch_size),
                                  device=None,  # device is handled by the code above
                                  verbose=0,
-                                 use_tqdm=self.test_use_tqdm)
+                                 use_tqdm=self.test_use_tqdm,
+                                 sampler=kwargs.get('test_sampler', None))

         self.start_time = None  # start timestamp

@@ -691,8 +695,7 @@ class Trainer(object):
                     avg_loss = 0
                 self.callback_manager.on_batch_end()

-                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
-                    (self.validate_every < 0 and self.step % len(self.data_iterator) == 0)) \
+                if (self.validate_every > 0 and self.step % self.validate_every == 0) \
                         and self.dev_data is not None:
                     eval_res = self._do_validation(epoch=epoch, step=self.step)
                     eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,
@@ -701,7 +704,13 @@ class Trainer(object):
                     self.logger.info(eval_str)
                     self.logger.info(self.tester._format_eval_results(eval_res)+'\n')
             # ================= mini-batch end ==================== #
+            if self.validate_every < 0 and self.dev_data is not None:  # evaluate after the epoch ends
+                eval_res = self._do_validation(epoch=epoch, step=self.step)
+                eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,
+                                                                                   self.n_steps)
+                # pbar.write(eval_str + '\n')
+                self.logger.info(eval_str)
+                self.logger.info(self.tester._format_eval_results(eval_res) + '\n')
             # lr decay; early stopping
             self.callback_manager.on_epoch_end()
         # =============== epochs end =================== #
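Taken together with the callback and Tester changes above, evaluation during training can now use a custom sampler via the new kwargs. A sketch only — the data, model, loss, and metric objects are placeholders:

from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
from fastNLP.core.sampler import SortedSampler

trainer = Trainer(train_data=train_data, model=model, loss=CrossEntropyLoss(),
                  metrics=AccuracyMetric(), dev_data=dev_data, batch_size=32,
                  validate_every=-1,  # with this commit: evaluate once, after each epoch ends
                  test_use_tqdm=False, test_sampler=SortedSampler('seq_len'))
trainer.train()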


fastNLP/models/seq2seq_generator.py (+6, -4)

@@ -12,15 +12,16 @@ class SequenceGeneratorModel(nn.Module):

     """

-    def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, num_beams=1,
-                 do_sample=True, temperature=1.0, top_k=50, top_p=1.0,
+    def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, max_len_a=0.0,
+                 num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0,
                  repetition_penalty=1, length_penalty=1.0, pad_token_id=0):
         """

         :param Seq2SeqModel seq2seq_model: the sequence-to-sequence model
         :param int,None bos_token_id: token id marking the start of a sentence
         :param int,None eos_token_id: token id marking the end of a sentence
-        :param int max_length: the maximum sentence length
+        :param int max_length: the maximum length of generated sentences; each sentence may decode up to max_length + max_len_a*src_len tokens
+        :param float max_len_a: each sentence may decode up to max_length + max_len_a*src_len tokens; if non-zero, the State must contain encoder_mask
         :param int num_beams: the beam size for beam search
         :param bool do_sample: whether to generate by sampling
         :param float temperature: only meaningful when do_sample is True
@@ -32,7 +33,8 @@ class SequenceGeneratorModel(nn.Module):
         """
         super().__init__()
         self.seq2seq_model = seq2seq_model
-        self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, num_beams=num_beams,
+        self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, max_len_a=max_len_a,
+                                           num_beams=num_beams,
                                            do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p,
                                            bos_token_id=bos_token_id,
                                            eos_token_id=eos_token_id,
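A construction sketch showing the new max_len_a parameter — seq2seq_model is assumed to be a trained fastNLP Seq2SeqModel, and the token ids are made up:

from fastNLP.models.seq2seq_generator import SequenceGeneratorModel

# each sentence may decode up to 50 + 1.2 * src_len tokens;
# a non-zero max_len_a requires the State to contain encoder_mask
model = SequenceGeneratorModel(seq2seq_model, bos_token_id=1, eos_token_id=2,
                               max_length=50, max_len_a=1.2, num_beams=4,
                               do_sample=False)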


fastNLP/modules/decoder/seq2seq_decoder.py (+2, -2)

@@ -263,8 +263,8 @@ class TransformerSeq2SeqDecoderLayer(nn.Module):
         """

         :param x: (batch, seq_len, dim), the decoder-side input
-        :param encoder_output: (batch,src_seq_len,dim)
-        :param encoder_mask: batch,src_seq_len
+        :param encoder_output: (batch,src_seq_len,dim), the encoder output
+        :param encoder_mask: batch,src_seq_len, positions marked 1 should be attended to
         :param self_attn_mask: seq_len, seq_len, lower-triangular mask matrix, only passed during training
         :param TransformerState state: only passed at inference time
         :return:


fastNLP/modules/generator/seq2seq_generator.py (+72, -43)

@@ -12,19 +12,19 @@ import torch.nn.functional as F
 from ...core.utils import _get_model_device
 from functools import partial


 class SequenceGenerator:
     """
     Given a Seq2SeqDecoder, decodes sentences

     """
-    def __init__(self, decoder: Seq2SeqDecoder, max_length=20, num_beams=1,
+    def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1,
                  do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None,
                  repetition_penalty=1, length_penalty=1.0, pad_token_id=0):
         """

         :param Seq2SeqDecoder decoder: the Decoder object
-        :param int max_length: the maximum sentence length
+        :param int max_length: the maximum length of generated sentences; each sentence may decode up to max_length + max_len_a*src_len tokens
+        :param float max_len_a: each sentence may decode up to max_length + max_len_a*src_len tokens; if non-zero, the State must contain encoder_mask
         :param int num_beams: the beam size for beam search
         :param bool do_sample: whether to generate by sampling
         :param float temperature: only meaningful when do_sample is True
@@ -37,12 +37,14 @@ class SequenceGenerator:
         :param int pad_token_id: once a sentence has finished generating, subsequent positions are filled with pad_token_id
         """
         if do_sample:
-            self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
+            self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a,
+                                         num_beams=num_beams,
                                          temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id,
                                          eos_token_id=eos_token_id, repetition_penalty=repetition_penalty,
                                          length_penalty=length_penalty, pad_token_id=pad_token_id)
         else:
-            self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
+            self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a,
+                                         num_beams=num_beams,
                                          bos_token_id=bos_token_id, eos_token_id=eos_token_id,
                                          repetition_penalty=repetition_penalty,
                                          length_penalty=length_penalty, pad_token_id=pad_token_id)
@@ -71,7 +73,7 @@ class SequenceGenerator:


 @torch.no_grad()
-def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1,
+def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1,
                     bos_token_id=None, eos_token_id=None, pad_token_id=0,
                     repetition_penalty=1, length_penalty=1.0):
     """
@@ -80,7 +82,8 @@ def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1
     :param Decoder decoder: the Decoder object
     :param torch.LongTensor tokens: batch_size x len, the input for decoding; if None, generation starts from bos_token_id
     :param State state: should contain some of the encoder's outputs.
-    :param int max_length: the maximum length of generated sentences.
+    :param int max_length: the maximum length of generated sentences; each sentence may decode up to max_length + max_len_a*src_len tokens
+    :param float max_len_a: each sentence may decode up to max_length + max_len_a*src_len tokens; if non-zero, the State must contain encoder_mask
     :param int num_beams: the beam size used for decoding.
     :param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
     :param int eos_token_id: the end-of-sentence token; if None, decoding always runs to max_length.
@@ -90,13 +93,14 @@ def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1
     :return:
     """
     if num_beams == 1:
-        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=1, top_k=50, top_p=1,
+        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
+                                             temperature=1, top_k=50, top_p=1,
                                              bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
                                              repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                              pad_token_id=pad_token_id)
     else:
-        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams,
-                                          temperature=1, top_k=50, top_p=1,
+        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
+                                          num_beams=num_beams, temperature=1, top_k=50, top_p=1,
                                           bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
                                           repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                           pad_token_id=pad_token_id)
@@ -105,7 +109,7 @@ def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1


 @torch.no_grad()
-def sample_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1, temperature=1.0, top_k=50,
+def sample_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1, temperature=1.0, top_k=50,
                     top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0,
                     length_penalty=1.0):
     """
@@ -114,7 +118,8 @@ def sample_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1
     :param Decoder decoder: the Decoder object
     :param torch.LongTensor tokens: batch_size x len, the input for decoding; if None, generation starts from bos_token_id
     :param State state: should contain some of the encoder's outputs.
-    :param int max_length: the maximum length of generated sentences.
+    :param int max_length: the maximum length of generated sentences; each sentence may decode up to max_length + max_len_a*src_len tokens
+    :param float max_len_a: each sentence may decode up to max_length + max_len_a*src_len tokens; if non-zero, the State must contain encoder_mask
     :param int num_beam: the beam size used for decoding.
     :param float temperature: annealing temperature used when sampling
     :param int top_k: sample only from the top_k candidates
@@ -128,21 +133,21 @@ def sample_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1
     """
     # each position is generated by sampling
     if num_beams == 1:
-        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=temperature,
-                                             top_k=top_k, top_p=top_p,
+        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
+                                             temperature=temperature, top_k=top_k, top_p=top_p,
                                              bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
                                              repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                              pad_token_id=pad_token_id)
     else:
-        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams,
-                                          temperature=temperature, top_k=top_k, top_p=top_p,
+        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
+                                          num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p,
                                           bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
                                           repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                           pad_token_id=pad_token_id)
     return token_ids




-def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, temperature=1.0, top_k=50,
+def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, max_len_a=0.0, temperature=1.0, top_k=50,
                              top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
                              repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0):
     device = _get_model_device(decoder)
@@ -169,7 +174,14 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
     dones = token_ids.new_zeros(batch_size).eq(1)
     # tokens = tokens[:, -1:]

-    while cur_len < max_length:
+    if max_len_a != 0:
+        max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
+        real_max_length = max_lengths.max()
+    else:
+        real_max_length = max_length
+        max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
+
+    while cur_len < real_max_length:
         scores = decoder.decode(tokens=token_ids, state=state)  # batch_size x vocab_size

         if repetition_penalty != 1.0:
@@ -194,11 +206,12 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
             # the 1e-12 is added to avoid https://github.com/pytorch/pytorch/pull/27523
             probs = F.softmax(scores, dim=-1) + 1e-12

+            # guarantee that at least one value is not eos
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # batch_size
         else:
             next_tokens = torch.argmax(scores, dim=-1)  # batch_size

+        # if the per-sentence length limit has been reached, fill directly with eos
+        next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id)
         next_tokens = next_tokens.masked_fill(dones, pad_token_id)  # pad the samples whose search has finished
         tokens = next_tokens.unsqueeze(1)


@@ -211,14 +224,14 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
         if dones.min() == 1:
             break

-    if eos_token_id is not None:
-        if cur_len == max_length:
-            token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # if eos has not appeared by the maximum length, force the last token to eos
+    # if eos_token_id is not None:
+    #     tokens.scatter(index=max_lengths[:, None], dim=1, value=eos_token_id)  # set the position at the maximum length to eos
+    # if cur_len == max_length:
+    #     token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # if eos has not appeared by the maximum length, force the last token to eos
     return token_ids


-def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, num_beams=4, temperature=1.0,
+def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=4, temperature=1.0,
                           top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
                           repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor:
     # perform beam search
@@ -268,14 +281,22 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
     # records the length of the tokens generated so far
     cur_len = token_ids.size(1)

+    if max_len_a != 0:
+        # (bsz x num_beams, )
+        max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
+        real_max_length = max_lengths.max().item()
+    else:
+        real_max_length = max_length
+        max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
+
     hypos = [
-        BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
+        BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
     ]
-    # 0,num_beams, 2*num_beams, ...
+    # 0, num_beams, 2*num_beams, ...
     batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)

-    while cur_len < max_length:
-        scores = decoder.decode(token_ids, state)
+    while cur_len < real_max_length:
+        scores = decoder.decode(token_ids, state)  # (bsz x num_beams, vocab_size)
         if repetition_penalty != 1.0:
             token_scores = scores.gather(dim=1, index=token_ids)
             lt_zero_mask = token_scores.lt(0).float()
@@ -283,6 +304,12 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
             scores.scatter_(dim=1, index=token_ids, src=token_scores)

+        if _eos_token_id != -1:
+            max_len_eos_mask = max_lengths.eq(cur_len+1)
+            eos_scores = scores[:, _eos_token_id]
+            # if the maximum length has been reached, boost the eos score
+            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+100, eos_scores)
+
         if do_sample:
             if temperature > 0 and temperature != 1:
                 scores = scores / temperature
@@ -309,7 +336,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
             _scores = scores + beam_scores[:, None]  # (batch_size * num_beams, vocab_size)
             _scores = _scores.view(batch_size, -1)  # (batch_size, num_beams*vocab_size)
-            next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+            next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)  # (bsz, 2*num_beams)
             from_which_beam = ids // vocab_size  # (batch_size, 2*num_beams)
             next_tokens = ids % vocab_size  # (batch_size, 2*num_beams)


@@ -328,12 +355,8 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
         beam_scores = _next_scores.view(-1)

-        # update the state, reorganize token_ids
-        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to 1-D
-        state.reorder_state(reorder_inds)
-
         flag = True
-        if cur_len+1 == max_length:
+        if cur_len+1 == real_max_length:
             eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
             eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)  # these are indices
             eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)  # which beam each was obtained from
@@ -348,19 +371,24 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         else:
             flag = False

-        # reorganize the state of token_ids
-        tokens = _next_tokens
-        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
-
         if flag:
             for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(),
                                                      eos_beam_idx.tolist()):
                 if not dones[batch_idx]:
                     score = next_scores[batch_idx, beam_ind].item()
-                    hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len+1].clone(), score)
+                    # an eos needs to be appended at the end later
+                    hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)

         for batch_idx in range(batch_size):
-            dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
+            dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \
+                               max_lengths[batch_idx*num_beams] == cur_len+1

+        # update the state, reorder token_ids
+        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to 1-D
+        state.reorder_state(reorder_inds)
+        # reorganize the state of token_ids
+        tokens = _next_tokens
+        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)

         cur_len += 1


@@ -373,15 +401,16 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_

     for i, hypotheses in enumerate(hypos):
         best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
-        tgt_len[i] = len(best_hyp)  # +1 for the <EOS> symbol
+        # restore the eos that was replaced by a non-eos token above
+        if _eos_token_id != -1:
+            best_hyp = torch.cat([best_hyp, best_hyp.new_ones(1)*_eos_token_id])
+        tgt_len[i] = len(best_hyp)
         best.append(best_hyp)

     # generate target batch
     decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
     for i, hypo in enumerate(best):
         decoded[i, :tgt_len[i]] = hypo
-        if eos_token_id is not None:
-            decoded[i, tgt_len[i] - 1] = _eos_token_id

     return decoded
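The per-sentence length cap introduced in both generate paths can be checked in isolation. A sketch only — the mask tensor below is made up:

import torch

max_length, max_len_a = 20, 0.5
# encoder_mask: 1 where source tokens exist; two sources of length 4 and 2
encoder_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

max_lengths = (encoder_mask.sum(dim=1).float() * max_len_a).long() + max_length
print(max_lengths.tolist())      # [22, 21] -> max_length + max_len_a*src_len per sentence
print(max_lengths.max().item())  # 22 -> real_max_length, the decoding loop bound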



