diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py
index dfb66d71..59e4b67c 100644
--- a/fastNLP/__init__.py
+++ b/fastNLP/__init__.py
@@ -66,6 +66,7 @@ __all__ = [
     "SequentialSampler",
     "BucketSampler",
     "RandomSampler",
+    "SortedSampler",
 
     "LossFunc",
     "CrossEntropyLoss",
diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index 6eb3e424..629dd786 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -84,6 +84,7 @@ __all__ = [
     "BucketSampler",
     "RandomSampler",
     "Sampler",
+    "SortedSampler"
 ]
 
 from ._logger import logger, init_logger_dist
@@ -100,7 +101,7 @@ from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \
 from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\
     ConfusionMatrixMetric
 from .optimizer import Optimizer, SGD, Adam, AdamW
-from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
+from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler
 from .tester import Tester
 from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 9ab4c3ff..9830ff1e 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -521,7 +521,7 @@ class FitlogCallback(Callback):
         fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。
     """
 
-    def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False):
+    def __init__(self, data=None, tester=None, log_loss_every=0, verbose=1, log_exception=False):
         r"""
 
         :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要
@@ -572,7 +572,8 @@ class FitlogCallback(Callback):
                                 batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
                                 metrics=self.trainer.metrics,
                                 verbose=0,
-                                use_tqdm=self.trainer.test_use_tqdm)
+                                use_tqdm=self.trainer.test_use_tqdm,
+                                sampler=self.trainer.kwargs.get('test_sampler', None))
                 self.testers[key] = tester
         fitlog.add_progress(total_steps=self.n_steps)
diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 574738bb..afd2d083 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -36,7 +36,7 @@ from ..core.const import Const
 
 class LossBase(object):
     r"""
-    所有loss的基类。如果想了解其中的原理,请查看源码。
+    所有loss的基类。如果需要结合到Trainer之中使用,需要实现get_loss方法。
     """
     
     def __init__(self):
@@ -53,6 +53,12 @@ class LossBase(object):
         return self._param_map
     
     def get_loss(self, *args, **kwargs):
+        """
+
+        :param args:
+        :param kwargs:
+        :return: torch.Tensor
+        """
         raise NotImplementedError
     
     def _init_param_map(self, key_map=None, **kwargs):
diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py
index 230a921a..61f47315 100644
--- a/fastNLP/core/sampler.py
+++ b/fastNLP/core/sampler.py
@@ -5,7 +5,8 @@ __all__ = [
     "Sampler",
     "BucketSampler",
     "SequentialSampler",
-    "RandomSampler"
+    "RandomSampler",
+    "SortedSampler"
 ]
 
 from itertools import chain
@@ -57,8 +58,8 @@ class BucketSampler(Sampler):
         r"""
         
         :param int num_buckets: bucket的数量
-        :param int batch_size: batch的大小. 默认为None,Trainer在调用BucketSampler时,会将该值正确设置,如果是非Trainer场景使用,需
-            要显示传递该值
+        :param int batch_size: batch的大小. 默认为None,Trainer/Tester在调用BucketSampler时,会将该值正确设置,如果是非
+            Trainer/Tester场景使用,需要显式传递该值
         :param str seq_len_field_name: 对应序列长度的 `field` 的名字
         """
         self.num_buckets = num_buckets
@@ -110,6 +111,27 @@ class BucketSampler(Sampler):
         return list(chain(*batchs))
 
 
+class SortedSampler(Sampler):
+    r"""
+    按照sample的长度进行排序,主要在测试的时候使用,可以加速测试(因为减少了padding)
+    """
+    def __init__(self, seq_len_field_name='seq_len', descending=True):
+        """
+
+        :param str seq_len_field_name: 对应序列长度的 `field` 的名字
+        :param bool descending: 是否降序排列
+        """
+        self.seq_len_field_name = seq_len_field_name
+        self.descending = descending
+
+    def __call__(self, data_set):
+        seq_lens = data_set.get_field(self.seq_len_field_name).content
+        orders = np.argsort(seq_lens).tolist()  # 从小到大的顺序
+        if self.descending:
+            orders = orders[::-1]
+        return orders
+
+
 def simple_sort_bucketing(lengths):
     r"""
 
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 680782b1..abb39c56 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -57,6 +57,7 @@ from ._parallel_utils import _data_parallel_wrapper
 from ._parallel_utils import _model_contains_inner_module
 from functools import partial
 from ._logger import logger
+from .sampler import Sampler
 
 __all__ = [
     "Tester"
 ]
@@ -68,7 +69,8 @@ class Tester(object):
     Tester是在提供数据,模型以及metric的情况下进行性能测试的类。需要传入模型,数据以及metric进行验证。
     """
     
-    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
+    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True,
+                 **kwargs):
         r"""
         
         :param ~fastNLP.DataSet,~fastNLP.BatchIter data: 需要测试的数据集
@@ -91,6 +93,7 @@ class Tester(object):
             如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。
         :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。
         :param bool use_tqdm: 是否使用tqdm来显示测试进度; 如果为False,则不会显示任何内容。
+        :param kwargs: 支持传入sampler控制测试顺序
         """
         super(Tester, self).__init__()
@@ -107,7 +110,14 @@ class Tester(object):
         self.logger = logger
 
         if isinstance(data, DataSet):
-            self.data_iterator = DataSetIter(dataset=data, batch_size=batch_size, sampler=SequentialSampler(),
+            sampler = kwargs.get('sampler', None)
+            if sampler is None:
+                sampler = SequentialSampler()
+            elif not isinstance(sampler, (Sampler, torch.utils.data.Sampler)):
+                raise ValueError(f"The type of sampler should be fastNLP.Sampler or pytorch's Sampler, got {type(sampler)}")
+            if hasattr(sampler, 'set_batch_size'):
+                sampler.set_batch_size(batch_size)
+            self.data_iterator = DataSetIter(dataset=data, batch_size=batch_size, sampler=sampler,
                                              num_workers=num_workers)
         elif isinstance(data, BatchIter):
             self.data_iterator = data
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index c9f16b9a..e183632e 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -422,6 +422,9 @@ class Trainer(object):
        报告警告信息; 2: 有任何field没有被使用都报错. 检查的原理是通过使用很小的batch(默认2个sample)来运行代码,但是
        这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个固定值的情况;
        (2)模型中存在累加前向计算次数的,可能会多计算1次。以上情况建议将check_code_level设置为-1。
+    :param kwargs: 支持配置可选参数
+        bool test_use_tqdm: 在dev上验证的时候是否开启tqdm
+        Sampler test_sampler: 在evaluate的时候使用的sampler
     """
     super(Trainer, self).__init__()
     if not isinstance(model, nn.Module):
@@ -561,7 +564,8 @@ class Trainer(object):
                                  batch_size=kwargs.get("dev_batch_size", self.batch_size),
                                  device=None,  # 由上面的部分处理device
                                  verbose=0,
-                                 use_tqdm=self.test_use_tqdm)
+                                 use_tqdm=self.test_use_tqdm,
+                                 sampler=kwargs.get('test_sampler', None))
 
         self.start_time = None  # start timestamp
@@ -691,8 +695,7 @@ class Trainer(object):
                 avg_loss = 0
                 self.callback_manager.on_batch_end()
 
-                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
-                    (self.validate_every < 0 and self.step % len(self.data_iterator) == 0)) \
+                if (self.validate_every > 0 and self.step % self.validate_every == 0) \
                         and self.dev_data is not None:
                     eval_res = self._do_validation(epoch=epoch, step=self.step)
                     eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,
@@ -701,7 +704,13 @@ class Trainer(object):
                     self.logger.info(eval_str)
                     self.logger.info(self.tester._format_eval_results(eval_res)+'\n')
             # ================= mini-batch end ==================== #
-
+            if self.validate_every<0 and self.dev_data is not None:  # 在epoch结束之后的evaluate
+                eval_res = self._do_validation(epoch=epoch, step=self.step)
+                eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,
+                                                                                   self.n_steps)
+                # pbar.write(eval_str + '\n')
+                self.logger.info(eval_str)
+                self.logger.info(self.tester._format_eval_results(eval_res) + '\n')
             # lr decay; early stopping
             self.callback_manager.on_epoch_end()
         # =============== epochs end =================== #
diff --git a/fastNLP/models/seq2seq_generator.py b/fastNLP/models/seq2seq_generator.py
index 81eb344d..33da67bd 100644
--- a/fastNLP/models/seq2seq_generator.py
+++ b/fastNLP/models/seq2seq_generator.py
@@ -12,15 +12,16 @@ class SequenceGeneratorModel(nn.Module):
 
     """
 
-    def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, num_beams=1,
-                 do_sample=True, temperature=1.0, top_k=50, top_p=1.0,
+    def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, max_len_a=0.0,
+                 num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0,
                  repetition_penalty=1, length_penalty=1.0, pad_token_id=0):
         """
 
         :param Seq2SeqModel seq2seq_model: 序列到序列模型
         :param int,None bos_token_id: 句子开头的token id
         :param int,None eos_token_id: 句子结束的token id
-        :param int max_length: 句子的最大长度
+        :param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len
+        :param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask
         :param int num_beams: beam search的大小
         :param bool do_sample: 是否通过采样的方式生成
         :param float temperature: 只有在do_sample为True才有意义
@@ -32,7 +33,8 @@ class SequenceGeneratorModel(nn.Module):
         """
         super().__init__()
         self.seq2seq_model = seq2seq_model
-        self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, num_beams=num_beams,
+        self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, max_len_a=max_len_a,
+                                           num_beams=num_beams,
                                            do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p,
                                            bos_token_id=bos_token_id, eos_token_id=eos_token_id,
diff --git a/fastNLP/modules/decoder/seq2seq_decoder.py b/fastNLP/modules/decoder/seq2seq_decoder.py
index 41f255b6..2c223dea 100644
--- a/fastNLP/modules/decoder/seq2seq_decoder.py
+++ b/fastNLP/modules/decoder/seq2seq_decoder.py
@@ -263,8 +263,8 @@ class TransformerSeq2SeqDecoderLayer(nn.Module):
         """
 
         :param x: (batch, seq_len, dim), decoder端的输入
-        :param encoder_output: (batch,src_seq_len,dim)
-        :param encoder_mask: batch,src_seq_len
+        :param encoder_output: (batch,src_seq_len,dim), encoder的输出
+        :param encoder_mask: (batch, src_seq_len), 为1的地方需要attend
         :param self_attn_mask: seq_len, seq_len,下三角的mask矩阵,只在训练时传入
         :param TransformerState state: 只在inference阶段传入
         :return:
diff --git a/fastNLP/modules/generator/seq2seq_generator.py b/fastNLP/modules/generator/seq2seq_generator.py
index e6115407..8aa9eddc 100644
--- a/fastNLP/modules/generator/seq2seq_generator.py
+++ b/fastNLP/modules/generator/seq2seq_generator.py
@@ -12,19 +12,19 @@ import torch.nn.functional as F
 from ...core.utils import _get_model_device
 from functools import partial
 
-
 class SequenceGenerator:
     """
     给定一个Seq2SeqDecoder,decode出句子
     """
-    def __init__(self, decoder: Seq2SeqDecoder, max_length=20, num_beams=1,
+    def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1,
                  do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None,
                  repetition_penalty=1, length_penalty=1.0, pad_token_id=0):
         """
 
         :param Seq2SeqDecoder decoder: Decoder对象
-        :param int max_length: 句子的最大长度
+        :param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len
+        :param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask
         :param int num_beams: beam search的大小
         :param bool do_sample: 是否通过采样的方式生成
        :param float temperature: 只有在do_sample为True才有意义
@@ -37,12 +37,14 @@ class SequenceGenerator:
         :param int pad_token_id: 当某句话生成结束之后,之后生成的内容用pad_token_id补充
         """
         if do_sample:
-            self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
+            self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a,
+                                         num_beams=num_beams,
                                          temperature=temperature, top_k=top_k, top_p=top_p,
                                          bos_token_id=bos_token_id, eos_token_id=eos_token_id,
                                          repetition_penalty=repetition_penalty,
                                          length_penalty=length_penalty, pad_token_id=pad_token_id)
         else:
-            self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
+            self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a,
+                                         num_beams=num_beams,
                                          bos_token_id=bos_token_id, eos_token_id=eos_token_id,
                                          repetition_penalty=repetition_penalty,
                                          length_penalty=length_penalty, pad_token_id=pad_token_id)
@@ -71,7 +73,7 @@ class SequenceGenerator:
 
 
 @torch.no_grad()
-def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1,
+def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1,
                     bos_token_id=None, eos_token_id=None, pad_token_id=0,
                     repetition_penalty=1, length_penalty=1.0):
     """
@@ -80,7 +82,8 @@ def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1
     :param Decoder decoder: Decoder对象
     :param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成
     :param State state: 应该包含encoder的一些输出。
-    :param int max_length: 生成句子的最大长度。
+    :param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len
+    :param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask
     :param int num_beams: 使用多大的beam进行解码。
     :param int bos_token_id: 如果tokens传入为None,则使用bos_token_id开始往后解码。
     :param int eos_token_id: 结束的token,如果为None,则一定会解码到max_length这么长。
@@ -90,13 +93,14 @@ def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1
     :return:
     """
     if num_beams == 1:
-        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=1, top_k=50, top_p=1,
+        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
+                                             temperature=1, top_k=50, top_p=1,
                                              bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
                                              repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                              pad_token_id=pad_token_id)
     else:
-        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams,
-                                          temperature=1, top_k=50, top_p=1,
+        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
+                                          num_beams=num_beams, temperature=1, top_k=50, top_p=1,
                                           bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
                                           repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                           pad_token_id=pad_token_id)
@@ -105,7 +109,7 @@
 
 
 @torch.no_grad()
-def sample_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1, temperature=1.0, top_k=50,
+def sample_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1, temperature=1.0, top_k=50,
                     top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0,
                     length_penalty=1.0):
     """
@@ -114,7 +118,8 @@ def sample_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1
     :param Decoder decoder: Decoder对象
     :param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成
     :param State state: 应该包含encoder的一些输出。
-    :param int max_length: 生成句子的最大长度。
+    :param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len
+    :param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask
     :param int num_beam: 使用多大的beam进行解码。
     :param float temperature: 采样时的退火大小
     :param int top_k: 只在top_k的sample里面采样
@@ -128,21 +133,21 @@
     """
     # 每个位置在生成的时候会sample生成
     if num_beams == 1:
-        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=temperature,
-                                             top_k=top_k, top_p=top_p,
+        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
+                                             temperature=temperature, top_k=top_k, top_p=top_p,
                                              bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
                                              repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                              pad_token_id=pad_token_id)
     else:
-        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams,
-                                          temperature=temperature, top_k=top_k, top_p=top_p,
+        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
+                                          num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p,
                                           bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
                                           repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                           pad_token_id=pad_token_id)
 
     return token_ids
 
 
-def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, temperature=1.0, top_k=50,
+def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, max_len_a=0.0, temperature=1.0, top_k=50,
                              top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
                              repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0):
     device = _get_model_device(decoder)
@@ -169,7 +174,14 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
     dones = token_ids.new_zeros(batch_size).eq(1)
     # tokens = tokens[:, -1:]
 
-    while cur_len < max_length:
+    if max_len_a!=0:
+        max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
+        real_max_length = max_lengths.max()
+    else:
+        real_max_length = max_length
+        max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
+
+    while cur_len < real_max_length:
         scores = decoder.decode(tokens=token_ids, state=state)  # batch_size x vocab_size
 
         if repetition_penalty != 1.0:
@@ -194,11 +206,12 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
                 # 加上1e-12是为了避免https://github.com/pytorch/pytorch/pull/27523
                 probs = F.softmax(scores, dim=-1) + 1e-12
 
-            # 保证至少有一个不是eos的值
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # batch_size
         else:
             next_tokens = torch.argmax(scores, dim=-1)  # batch_size
 
+        # 如果已经达到对应的sequence长度了,就直接填为eos了
+        next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id)
         next_tokens = next_tokens.masked_fill(dones, pad_token_id)  # 对已经搜索完成的sample做padding
         tokens = next_tokens.unsqueeze(1)
@@ -211,14 +224,14 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
         if dones.min() == 1:
             break
 
-    if eos_token_id is not None:
-        if cur_len == max_length:
-            token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # 若到最长长度仍未到EOS,则强制将最后一个词替换成eos
-
+    # if eos_token_id is not None:
+    #     tokens.scatter(index=max_lengths[:, None], dim=1, value=eos_token_id)  # 将最大长度位置设置为eos
+    #     if cur_len == max_length:
+    #         token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # 若到最长长度仍未到EOS,则强制将最后一个词替换成eos
     return token_ids
 
 
-def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, num_beams=4, temperature=1.0,
+def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=4, temperature=1.0,
                           top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
                           repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor:
     # 进行beam search
@@ -268,14 +281,22 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
     # 用来记录已经生成好的token的长度
     cur_len = token_ids.size(1)
 
+    if max_len_a!=0:
+        # (bsz x num_beams, )
+        max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
+        real_max_length = max_lengths.max().item()
+    else:
+        real_max_length = max_length
+        max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
+
     hypos = [
-        BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
+        BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
     ]
-    # 0,num_beams, 2*num_beams, ...
+    # 0, num_beams, 2*num_beams, ...
     batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)
 
-    while cur_len < max_length:
-        scores = decoder.decode(token_ids, state)
+    while cur_len < real_max_length:
+        scores = decoder.decode(token_ids, state)  # (bsz x num_beams, vocab_size)
         if repetition_penalty != 1.0:
             token_scores = scores.gather(dim=1, index=token_ids)
             lt_zero_mask = token_scores.lt(0).float()
@@ -283,6 +304,12 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
             scores.scatter_(dim=1, index=token_ids, src=token_scores)
 
+        if _eos_token_id!=-1:
+            max_len_eos_mask = max_lengths.eq(cur_len+1)
+            eos_scores = scores[:, _eos_token_id]
+            # 如果已经达到最大长度,就把eos的分数加大
+            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+100, eos_scores)
+
         if do_sample:
             if temperature > 0 and temperature != 1:
                 scores = scores / temperature
@@ -309,7 +336,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
             scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
             _scores = scores + beam_scores[:, None]  # (batch_size * num_beams, vocab_size)
             _scores = _scores.view(batch_size, -1)  # (batch_size, num_beams*vocab_size)
-            next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+            next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)  # (bsz, 2*num_beams)
             from_which_beam = ids // vocab_size  # (batch_size, 2*num_beams)
             next_tokens = ids % vocab_size  # (batch_size, 2*num_beams)
@@ -328,12 +355,8 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
         beam_scores = _next_scores.view(-1)
 
-        # 更改state状态, 重组token_ids
-        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten成一维
-        state.reorder_state(reorder_inds)
-
         flag = True
-        if cur_len+1 == max_length:
+        if cur_len+1 == real_max_length:
             eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
             eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)  # 表示的是indice
             eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)  # 表示的是从哪个beam获取得到的
@@ -348,19 +371,24 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         else:
             flag = False
 
-        # 重新组织token_ids的状态
-        tokens = _next_tokens
-        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
-
         if flag:
             for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), eos_beam_idx.tolist()):
                 if not dones[batch_idx]:
                     score = next_scores[batch_idx, beam_ind].item()
-                    hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len+1].clone(), score)
+                    # 之后需要在结尾新增一个eos
+                    hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
 
         for batch_idx in range(batch_size):
-            dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
+            dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \
+                               max_lengths[batch_idx*num_beams]==cur_len+1
+
+        # 更改state状态, 重组token_ids
+        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten成一维
+        state.reorder_state(reorder_inds)
+        # 重新组织token_ids的状态
+        tokens = _next_tokens
+        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)
 
         cur_len += 1
@@ -373,15 +401,16 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
 
     for i, hypotheses in enumerate(hypos):
         best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
-        tgt_len[i] = len(best_hyp)  # +1 for the <eos> symbol
+        # 把上面替换为非eos的词替换回eos
+        if _eos_token_id!=-1:
+            best_hyp = torch.cat([best_hyp, best_hyp.new_ones(1)*_eos_token_id])
+        tgt_len[i] = len(best_hyp)
         best.append(best_hyp)
 
     # generate target batch
     decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
     for i, hypo in enumerate(best):
         decoded[i, :tgt_len[i]] = hypo
-        if eos_token_id is not None:
-            decoded[i, tgt_len[i] - 1] = _eos_token_id
 
     return decoded
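Taken together, the sampler-related changes let evaluation run over length-sorted batches. A minimal usage sketch (the `train_data`, `dev_data`, `test_data`, `model`, `loss`, and `metrics` objects below are hypothetical placeholders, and each instance is assumed to carry a `seq_len` field):

```python
from fastNLP import Trainer, Tester, SortedSampler

# Evaluate with length-sorted batches to reduce padding: Tester now accepts an
# optional `sampler` keyword; a fastNLP Sampler or torch Sampler is allowed.
tester = Tester(data=test_data, model=model, metrics=metrics,
                batch_size=32, sampler=SortedSampler('seq_len', descending=True))
tester.test()

# Trainer forwards the optional `test_sampler` kwarg to its internal dev Tester,
# and FitlogCallback reads the same value from trainer.kwargs.
trainer = Trainer(train_data=train_data, model=model, loss=loss, metrics=metrics,
                  dev_data=dev_data, batch_size=32,
                  test_sampler=SortedSampler('seq_len'))
trainer.train()
```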
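The new `max_len_a` argument makes the decode budget depend on the source length: each sentence may generate up to `max_length + max_len_a * src_len` tokens, where `src_len` is read from `state.encoder_mask.sum(dim=1)`. A small self-contained sketch of the per-sentence cap computed the same way as in the patched generators (the batch below is made up for illustration):

```python
import torch

# Suppose a batch of 3 source sentences with lengths 5, 12 and 8 (1 = real token, 0 = padding).
encoder_mask = torch.tensor([[1]*5 + [0]*7,
                             [1]*12,
                             [1]*8 + [0]*4]).bool()

max_length, max_len_a = 20, 1.5
# Same formula as the patch: per-sentence cap = max_length + max_len_a * src_len (truncated to int).
max_lengths = (encoder_mask.sum(dim=1).float() * max_len_a).long() + max_length
print(max_lengths.tolist())            # [27, 38, 32]
real_max_length = max_lengths.max().item()  # the decoding loop runs until the longest cap, 38
```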