Browse Source

1.更新权重下载url; 2.更新seq2seq,方式第一个位置预测eos

tags/v1.0.0alpha
yh_cc 4 years ago
parent
commit
0a2f546b70
5 changed files with 33 additions and 19 deletions
  1. +2
    -2
      fastNLP/io/file_utils.py
  2. +4
    -3
      fastNLP/models/seq2seq_generator.py
  3. +10
    -4
      fastNLP/models/seq2seq_model.py
  4. +2
    -2
      fastNLP/modules/decoder/seq2seq_decoder.py
  5. +15
    -8
      fastNLP/modules/generator/seq2seq_generator.py

+ 2
- 2
fastNLP/io/file_utils.py View File

@@ -259,8 +259,8 @@ def _get_base_url(name):
return url + '/' return url + '/'
else: else:
URLS = { URLS = {
'embedding': "http://212.129.155.247/embedding/",
"dataset": "http://212.129.155.247/dataset/"
'embedding': "http://download.fastnlp.top/embedding/",
"dataset": "http://download.fastnlp.top/dataset/"
} }
if name.lower() not in URLS: if name.lower() not in URLS:
raise KeyError(f"{name} is not recognized.") raise KeyError(f"{name} is not recognized.")


+ 4
- 3
fastNLP/models/seq2seq_generator.py View File

@@ -11,7 +11,8 @@ __all__ = ['SequenceGeneratorModel']


class SequenceGeneratorModel(nn.Module): class SequenceGeneratorModel(nn.Module):
""" """
用于封装Seq2SeqModel使其可以做生成任务
通过使用本模型封装seq2seq_model使得其既可以用于训练也可以用于生成。训练的时候,本模型的forward函数会被调用,生成的时候本模型的predict
函数会被调用。


""" """


@@ -46,7 +47,7 @@ class SequenceGeneratorModel(nn.Module):


def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
""" """
透传调用seq2seq_model的forward
透传调用seq2seq_model的forward


:param torch.LongTensor src_tokens: bsz x max_len :param torch.LongTensor src_tokens: bsz x max_len
:param torch.LongTensor tgt_tokens: bsz x max_len' :param torch.LongTensor tgt_tokens: bsz x max_len'
@@ -58,7 +59,7 @@ class SequenceGeneratorModel(nn.Module):


def predict(self, src_tokens, src_seq_len=None): def predict(self, src_tokens, src_seq_len=None):
""" """
给定source的内容,输出generate的内容
给定source的内容,输出generate的内容


:param torch.LongTensor src_tokens: bsz x max_len :param torch.LongTensor src_tokens: bsz x max_len
:param torch.LongTensor src_seq_len: bsz :param torch.LongTensor src_seq_len: bsz


+ 10
- 4
fastNLP/models/seq2seq_model.py View File

@@ -18,10 +18,16 @@ __all__ = ['Seq2SeqModel', 'TransformerSeq2SeqModel', 'LSTMSeq2SeqModel']
class Seq2SeqModel(nn.Module): class Seq2SeqModel(nn.Module):
def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder): def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder):
""" """
可以用于在Trainer中训练的Seq2Seq模型。正常情况下,继承了该函数之后,只需要实现classmethod build_model即可。

:param encoder: Encoder
:param decoder: Decoder
可以用于在Trainer中训练的Seq2Seq模型。正常情况下,继承了该函数之后,只需要实现classmethod build_model即可。如果需要使用该模型
进行生成,需要把该模型输入到 :class:`~fastNLP.models.SequenceGeneratorModel` 中。在本模型中,forward()会把encoder后的
结果传入到decoder中,并将decoder的输出output出来。

:param encoder: Seq2SeqEncoder 对象,需要实现对应的forward()函数,接受两个参数,第一个为bsz x max_len的source tokens, 第二个为
bsz的source的长度;需要返回两个tensor: encoder_outputs: bsz x max_len x hidden_size, encoder_mask: bsz x max_len
为1的地方需要被attend。如果encoder的输出或者输入有变化,可以重载本模型的prepare_state()函数或者forward()函数
:param decoder: Seq2SeqDecoder 对象,需要实现init_state()函数,输出为两个参数,第一个为bsz x max_len x hidden_size是
encoder的输出; 第二个为bsz x max_len,为encoder输出的mask,为0的地方为pad。若decoder需要更多输入,请重载当前模型的
prepare_state()或forward()函数
""" """
super().__init__() super().__init__()
self.encoder = encoder self.encoder = encoder


+ 2
- 2
fastNLP/modules/decoder/seq2seq_decoder.py View File

@@ -16,7 +16,7 @@ __all__ = ['Seq2SeqDecoder', 'TransformerSeq2SeqDecoder', 'LSTMSeq2SeqDecoder']


class Seq2SeqDecoder(nn.Module): class Seq2SeqDecoder(nn.Module):
""" """
Sequence-to-Sequence Decoder的基类。一定需要实现forward函数,剩下的函数根据需要实现。每个Seq2SeqDecoder都应该有相应的State对象
Sequence-to-Sequence Decoder的基类。一定需要实现forward、decode函数,剩下的函数根据需要实现。每个Seq2SeqDecoder都应该有相应的State对象
用来承载该Decoder所需要的Encoder输出、Decoder需要记录的历史信息(例如LSTM的hidden信息)。 用来承载该Decoder所需要的Encoder输出、Decoder需要记录的历史信息(例如LSTM的hidden信息)。


""" """
@@ -61,7 +61,7 @@ class Seq2SeqDecoder(nn.Module):
""" """
根据states中的内容,以及tokens中的内容进行之后的生成。 根据states中的内容,以及tokens中的内容进行之后的生成。


:param torch.LongTensor tokens: bsz x max_len, 上一个时刻的token输出。
:param torch.LongTensor tokens: bsz x max_len, 截止到上一个时刻所有的token输出。
:param State state: 记录了encoder输出与decoder过去状态 :param State state: 记录了encoder输出与decoder过去状态
:return: torch.FloatTensor: bsz x vocab_size, 输出的是下一个时刻的分布 :return: torch.FloatTensor: bsz x vocab_size, 输出的是下一个时刻的分布
""" """


+ 15
- 8
fastNLP/modules/generator/seq2seq_generator.py View File

@@ -12,9 +12,11 @@ import torch.nn.functional as F
from ...core.utils import _get_model_device from ...core.utils import _get_model_device
from functools import partial from functools import partial



class SequenceGenerator: class SequenceGenerator:
""" """
给定一个Seq2SeqDecoder,decode出句子
给定一个Seq2SeqDecoder,decode出句子。输入的decoder对象需要有decode()函数, 接受的第一个参数为decode的到目前位置的所有输出,
第二个参数为state。SequenceGenerator不会对state进行任何操作。


""" """
def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1, def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1,
@@ -65,7 +67,8 @@ class SequenceGenerator:
""" """


:param State state: encoder结果的State, 是与Decoder配套是用的 :param State state: encoder结果的State, 是与Decoder配套是用的
:param torch.LongTensor,None tokens: batch_size x length, 开始的token
:param torch.LongTensor,None tokens: batch_size x length, 开始的token。如果为None,则默认添加bos_token作为开头的token
进行生成。
:return: bsz x max_length' 生成的token序列。如果eos_token_id不为None, 每个sequence的结尾一定是eos_token_id :return: bsz x max_length' 生成的token序列。如果eos_token_id不为None, 每个sequence的结尾一定是eos_token_id
""" """


@@ -168,6 +171,8 @@ def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_le
_eos_token_id = eos_token_id _eos_token_id = eos_token_id


scores = decoder.decode(tokens=tokens, state=state) # 主要是为了update state scores = decoder.decode(tokens=tokens, state=state) # 主要是为了update state
if _eos_token_id!=-1: # 防止第一个位置为结束
scores[:, _eos_token_id] = -1e12
next_tokens = scores.argmax(dim=-1, keepdim=True) next_tokens = scores.argmax(dim=-1, keepdim=True)
token_ids = torch.cat([tokens, next_tokens], dim=1) token_ids = torch.cat([tokens, next_tokens], dim=1)
cur_len = token_ids.size(1) cur_len = token_ids.size(1)
@@ -261,6 +266,8 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
_eos_token_id = eos_token_id _eos_token_id = eos_token_id


scores = decoder.decode(tokens=tokens, state=state) # 这里要传入的是整个句子的长度 scores = decoder.decode(tokens=tokens, state=state) # 这里要传入的是整个句子的长度
if _eos_token_id!=-1: # 防止第一个位置为结束
scores[:, _eos_token_id] = -1e12
vocab_size = scores.size(1) vocab_size = scores.size(1)
assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size." assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."


@@ -322,7 +329,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
max_len_eos_mask = max_lengths.eq(cur_len+1) max_len_eos_mask = max_lengths.eq(cur_len+1)
eos_scores = scores[:, _eos_token_id] eos_scores = scores[:, _eos_token_id]
# 如果已经达到最大长度,就把eos的分数加大 # 如果已经达到最大长度,就把eos的分数加大
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores)
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+float('inf'), eos_scores)


if do_sample: if do_sample:
if temperature > 0 and temperature != 1: if temperature > 0 and temperature != 1:
@@ -356,9 +363,9 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_


# 接下来需要组装下一个batch的结果。 # 接下来需要组装下一个batch的结果。
# 需要选定哪些留下来 # 需要选定哪些留下来
next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)
# next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
# next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
# from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)


not_eos_mask = next_tokens.ne(_eos_token_id) # 为1的地方不是eos not_eos_mask = next_tokens.ne(_eos_token_id) # 为1的地方不是eos
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams) # 为1的地方需要保留 keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams) # 为1的地方需要保留
@@ -413,7 +420,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
break break


# select the best hypotheses # select the best hypotheses
tgt_len = token_ids.new(batch_size)
tgt_len = token_ids.new_zeros(batch_size)
best = [] best = []


for i, hypotheses in enumerate(hypos): for i, hypotheses in enumerate(hypos):
@@ -425,7 +432,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
best.append(best_hyp) best.append(best_hyp)


# generate target batch # generate target batch
decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id)
for i, hypo in enumerate(best): for i, hypo in enumerate(best):
decoded[i, :tgt_len[i]] = hypo decoded[i, :tgt_len[i]] = hypo




Loading…
Cancel
Save