@@ -28,7 +28,9 @@ __all__ = [
     "BertForSentenceMatching",
     "BertForMultipleChoice",
     "BertForTokenClassification",
-    "BertForQuestionAnswering"
+    "BertForQuestionAnswering",
+    "TransformerSeq2SeqModel"
 ]
 from .base_model import BaseModel
@@ -39,7 +41,7 @@ from .cnn_text_classification import CNNText
 from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF
 from .snli import ESIM
 from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
+from .seq2seq_model import TransformerSeq2SeqModel
 import sys
 from ..doc_utils import doc_process
 doc_process(sys.modules[__name__])
@@ -1,29 +1,24 @@
-from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder
-from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast
-from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator
 import torch.nn as nn
 import torch
 from typing import Union, Tuple
 import numpy as np
+from fastNLP.modules import TransformerSeq2SeqDecoder, TransformerSeq2SeqEncoder, TransformerPast


-class TransformerSeq2SeqModel(nn.Module):
+class TransformerSeq2SeqModel(nn.Module):  # todo: follow the structure of fairseq's FairseqModel
     def __init__(self, src_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
                  tgt_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
-                 num_layers: int = 6,
-                 d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
+                 num_layers: int = 6, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
                  output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
-                 bind_input_output_embed=False,
-                 sos_id=None, eos_id=None):
+                 bind_input_output_embed=False):
         super().__init__()
         self.encoder = TransformerSeq2SeqEncoder(src_embed, num_layers, d_model, n_head, dim_ff, dropout)
         self.decoder = TransformerSeq2SeqDecoder(tgt_embed, num_layers, d_model, n_head, dim_ff, dropout, output_embed,
                                                  bind_input_output_embed)
-        self.sos_id = sos_id
-        self.eos_id = eos_id
         self.num_layers = num_layers

-    def forward(self, words, target, seq_len):  # todo: does target carry sos and eos? check how the LSTM version handles this
+    def forward(self, words, target, seq_len):
         encoder_output, encoder_mask = self.encoder(words, seq_len)
         past = TransformerPast(encoder_output, encoder_mask, self.num_layers)
         outputs = self.decoder(target, past, return_attention=False)
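For orientation, a minimal sketch of how the resulting model might be driven for training, mirroring the tests added further below (word indices and sizes are illustrative; `Vocabulary` and `StaticEmbedding` come from fastNLP):

import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.models import TransformerSeq2SeqModel

vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = StaticEmbedding(vocab, embedding_dim=512)        # shared source/target embedding
model = TransformerSeq2SeqModel(src_embed=embed, tgt_embed=embed,
                                num_layers=2, bind_input_output_embed=True)

words = torch.LongTensor([[2, 3, 4], [3, 4, 0]])          # batch_size x src_len
target = torch.LongTensor([[2, 3, 4, 5], [3, 4, 0, 0]])   # batch_size x tgt_len
seq_len = torch.LongTensor([3, 2])
outputs = model(words, target, seq_len)                   # expected: batch_size x tgt_len x vocab_size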
@@ -49,7 +49,19 @@ __all__ = [
     "TimestepDropout",
-    'summary'
+    'summary',
+    "BiLSTMEncoder",
+    "TransformerSeq2SeqEncoder",
+    "SequenceGenerator",
+    "LSTMDecoder",
+    "LSTMPast",
+    "TransformerSeq2SeqDecoder",
+    "TransformerPast",
+    "Decoder",
+    "Past"
 ]
 import sys
@@ -7,11 +7,20 @@ __all__ = [ | |||||
"ConditionalRandomField", | "ConditionalRandomField", | ||||
"viterbi_decode", | "viterbi_decode", | ||||
"allowed_transitions", | "allowed_transitions", | ||||
"seq2seq_decoder", | |||||
"seq2seq_generator" | |||||
"SequenceGenerator", | |||||
"LSTMDecoder", | |||||
"LSTMPast", | |||||
"TransformerSeq2SeqDecoder", | |||||
"TransformerPast", | |||||
"Decoder", | |||||
"Past", | |||||
] | ] | ||||
from .crf import ConditionalRandomField | from .crf import ConditionalRandomField | ||||
from .crf import allowed_transitions | from .crf import allowed_transitions | ||||
from .mlp import MLP | from .mlp import MLP | ||||
from .utils import viterbi_decode | from .utils import viterbi_decode | ||||
from .seq2seq_generator import SequenceGenerator | |||||
from .seq2seq_decoder import * |
@@ -1,17 +1,46 @@
 # coding=utf-8
+__all__ = [
+    "TransformerPast",
+    "LSTMPast",
+    "Past",
+    "LSTMDecoder",
+    "TransformerSeq2SeqDecoder",
+    "Decoder"
+]
 import torch
 from torch import nn
 import abc
 import torch.nn.functional as F
-from fastNLP.embeddings import StaticEmbedding
+from ...embeddings import StaticEmbedding
 import numpy as np
 from typing import Union, Tuple
-from fastNLP.embeddings import get_embeddings
-from fastNLP.modules import LSTM
+from ...embeddings.utils import get_embeddings
 from torch.nn import LayerNorm
 import math
-from reproduction.Summarization.Baseline.tools.PositionEmbedding import \
-    get_sinusoid_encoding_table  # todo: position embedding should be moved into core
+# from reproduction.Summarization.Baseline.tools.PositionEmbedding import \
+#     get_sinusoid_encoding_table  # todo: position embedding should be moved into core
+
+
+def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
+    ''' Sinusoid position encoding table '''
+
+    def cal_angle(position, hid_idx):
+        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
+
+    def get_posi_angle_vec(position):
+        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
+
+    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
+
+    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+
+    if padding_idx is not None:
+        # zero vector for padding dimension
+        sinusoid_table[padding_idx] = 0.
+
+    return torch.FloatTensor(sinusoid_table)
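For reference, the table built above is the standard sinusoidal position encoding (Vaswani et al., 2017): even dimensions hold sines and odd dimensions cosines of the same angle,

PE(pos, 2i)   = sin(pos / 10000^(2i / d_hid))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_hid))

so row `pos` of the returned (n_position, d_hid) tensor can serve directly as the weight of a frozen position-embedding lookup.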
 class Past:
@@ -82,7 +111,7 @@ class Decoder(nn.Module):
         """
         raise NotImplemented

-    def decode_one(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]:
+    def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]:
         """
         Called when the model decodes. It only returns a batch_size x vocab_size result. One special case must be handled:
         the length of tokens may not be 1, i.e. the beginning of the target sentence is given; in that case, make sure Past has computed the decoding state correctly.
@@ -100,6 +129,7 @@ class DecoderMultiheadAttention(nn.Module):
     """

     def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None):
+        super(DecoderMultiheadAttention, self).__init__()
         self.d_model = d_model
         self.n_head = n_head
         self.dropout = dropout
@@ -157,11 +187,11 @@ class DecoderMultiheadAttention(nn.Module):
                 past.encoder_key[self.layer_idx] = k
                 past.encoder_value[self.layer_idx] = v
         if inference and not is_encoder_attn:
-            past.decoder_prev_key[self.layer_idx] = prev_k
-            past.decoder_prev_value[self.layer_idx] = prev_v
+            past.decoder_prev_key[self.layer_idx] = prev_k if prev_k is not None else k
+            past.decoder_prev_value[self.layer_idx] = prev_v if prev_v is not None else v

         batch_size, q_len, d_model = query.size()
-        k_len, v_len = key.size(1), value.size(1)
+        k_len, v_len = k.size(1), v.size(1)
         q = q.contiguous().view(batch_size, q_len, self.n_head, self.head_dim)
         k = k.contiguous().view(batch_size, k_len, self.n_head, self.head_dim)
         v = v.contiguous().view(batch_size, v_len, self.n_head, self.head_dim)
@@ -172,8 +202,8 @@ class DecoderMultiheadAttention(nn.Module):
             if len(mask.size()) == 2:  # encoder mask: (batch, src_len/k_len)
                 mask = mask[:, None, :, None]
             else:  # (1, seq_len, seq_len)
-                mask = mask[...:None]
-            _mask = mask
+                mask = mask[..., None]
+            _mask = ~mask.bool()
             attn_weights = attn_weights.masked_fill(_mask, float('-inf'))
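To make the corrected masking easier to follow: `attn_weights` here has shape (batch, q_len, k_len, n_head), as the einsum below shows, so the two mask variants broadcast into it as

mask[:, None, :, None]   # encoder padding mask (batch, k_len) -> (batch, 1, k_len, 1)
mask[..., None]          # triangular mask (1, q_len, k_len)   -> (1, q_len, k_len, 1)

and `_mask = ~mask.bool()` flips it so that padded or future positions become True and are filled with -inf before the softmax.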
@@ -181,21 +211,22 @@ class DecoderMultiheadAttention(nn.Module):
         attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

         output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v)  # batch,q_len,n_head,head_dim
-        output = output.view(batch_size, q_len, -1)
+        output = output.reshape(batch_size, q_len, -1)
         output = self.out_proj(output)  # batch,q_len,dim

         return output, attn_weights

     def reset_parameters(self):
-        nn.init.xavier_uniform_(self.q_proj)
-        nn.init.xavier_uniform_(self.k_proj)
-        nn.init.xavier_uniform_(self.v_proj)
-        nn.init.xavier_uniform_(self.out_proj)
+        nn.init.xavier_uniform_(self.q_proj.weight)
+        nn.init.xavier_uniform_(self.k_proj.weight)
+        nn.init.xavier_uniform_(self.v_proj.weight)
+        nn.init.xavier_uniform_(self.out_proj.weight)


 class TransformerSeq2SeqDecoderLayer(nn.Module):
     def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
                  layer_idx: int = None):
+        super(TransformerSeq2SeqDecoderLayer, self).__init__()
         self.d_model = d_model
         self.n_head = n_head
         self.dim_ff = dim_ff
@@ -313,10 +344,10 @@ class TransformerSeq2SeqDecoder(Decoder):
             if isinstance(self.token_embed, StaticEmbedding):
                 for i in self.token_embed.words_to_words:
                     assert i == self.token_embed.words_to_words[i], "The index does not match."
-            self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1))
+            self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1), requires_grad=True)
         else:
             if isinstance(output_embed, nn.Embedding):
-                self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1))
+                self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1), requires_grad=True)
             else:
                 self.output_embed = output_embed.transpose(0, 1)
             self.output_hidden_size = self.output_embed.size(0)
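When the embeddings are bound, `output_embed` is just the transposed input embedding matrix of shape (embed_dim, vocab_size), so the output projection amounts to scoring each decoder hidden state against every word vector. A rough sketch of that final step (hypothetical names, not this class's exact code):

# hidden: (batch, decode_len, d_model); output_embed: (d_model, vocab_size)
logits = torch.matmul(hidden, output_embed)   # (batch, decode_len, vocab_size)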
@@ -326,7 +357,8 @@ class TransformerSeq2SeqDecoder(Decoder):
     def forward(self, tokens, past, return_attention=False, inference=False):
         """
-        :param tokens: torch.LongTensor, tokens: batch_size x decode_len
+        :param tokens: torch.LongTensor, tokens: batch_size, decode_len
+        :param self_attn_mask: not needed at inference time; during training the cross-entropy loss already ignores the padded positions, so it is not needed there either
         :param past: TransformerPast: holds the encoder output and mask; at inference time it also caches the key and value of the previous step to reduce matrix computation
         :param return_attention:
         :param inference: whether this call is made in the inference stage
@@ -335,13 +367,16 @@ class TransformerSeq2SeqDecoder(Decoder):
         assert past is not None
         batch_size, decode_len = tokens.size()
         device = tokens.device
+        pos_idx = torch.arange(1, decode_len + 1).unsqueeze(0).long()

         if not inference:
             self_attn_mask = self._get_triangle_mask(decode_len)
             self_attn_mask = self_attn_mask.to(device)[None, :, :]  # 1,seq,seq
         else:
             self_attn_mask = None

         tokens = self.token_embed(tokens) * self.embed_scale  # bs,decode_len,embed_dim
-        pos = self.pos_embed(tokens)  # bs,decode_len,embed_dim
+        pos = self.pos_embed(pos_idx)  # 1,decode_len,embed_dim
         tokens = pos + tokens
         if inference:
             tokens = tokens[:, -1:, :]
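The change above matters because `pos_embed` is an index lookup (for example an embedding built from `get_sinusoid_encoding_table` earlier in this file), so it has to receive integer positions rather than the already-embedded tokens. A small sketch of the intended call, assuming positions are counted from 1:

pos_idx = torch.arange(1, decode_len + 1).unsqueeze(0)   # (1, decode_len)
pos = self.pos_embed(pos_idx.to(device))                 # (1, decode_len, embed_dim), broadcast over the batch when added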
@@ -358,7 +393,7 @@ class TransformerSeq2SeqDecoder(Decoder):
         return output

     @torch.no_grad()
-    def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]:
+    def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]:
         """
         # todo: maybe past does not need to be returned, since it is modified in place?
         :param tokens: torch.LongTensor (batch_size, 1)
@@ -370,7 +405,7 @@ class TransformerSeq2SeqDecoder(Decoder):
     def reorder_past(self, indices: torch.LongTensor, past: TransformerPast) -> TransformerPast:
         past.reorder_past(indices)
-        return past  # todo: this return could probably be dropped
+        return past

     def _get_triangle_mask(self, max_seq_len):
         tensor = torch.ones(max_seq_len, max_seq_len)
@@ -409,6 +444,27 @@ class LSTMPast(Past):
                 return tensor[0].size(0)
         return None

+    def _reorder_past(self, state, indices, dim=0):
+        if type(state) == torch.Tensor:
+            state = state.index_select(index=indices, dim=dim)
+        elif type(state) == tuple:
+            tmp_list = []
+            for i in range(len(state)):
+                assert state[i] is not None
+                tmp_list.append(state[i].index_select(index=indices, dim=dim))
+            state = tuple(tmp_list)
+        else:
+            raise ValueError('State does not support other format')
+
+        return state
+
+    def reorder_past(self, indices: torch.LongTensor):
+        self.encode_outputs = self._reorder_past(self.encode_outputs, indices)
+        self.encode_mask = self._reorder_past(self.encode_mask, indices)
+        self.hx = self._reorder_past(self.hx, indices, 1)
+        if self.attn_states is not None:
+            self.attn_states = self._reorder_past(self.attn_states, indices)
+
     @property
     def hx(self):
         return self._hx
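During beam search, every cached tensor has to be re-indexed so that its rows keep following the surviving beams; that is all `index_select` does in `reorder_past` above. A self-contained toy illustration (hypothetical tensors, not the fields of this class):

import torch
cache = torch.tensor([[0.], [1.], [2.], [3.]])     # one row per (sample, beam) slot
indices = torch.LongTensor([1, 1, 3, 2])           # beams kept after one search step, repeats allowed
print(cache.index_select(index=indices, dim=0))    # rows 1, 1, 3, 2

For `hx` the LSTM state is laid out as (num_layers, batch, hidden), which is why it is reordered along dim=1.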
@@ -493,7 +549,7 @@ class AttentionLayer(nn.Module):


 class LSTMDecoder(Decoder):
-    def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers, input_size,
+    def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers=3, input_size=400,
                  hidden_size=None, dropout=0,
                  output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
                  bind_input_output_embed=False,
@@ -612,6 +668,7 @@ class LSTMDecoder(Decoder):
         for i in range(tokens.size(1)):
             input = torch.cat([tokens[:, i:i + 1], input_feed.unsqueeze(1)], dim=2)  # batch_size x 1 x h'
             # bsz x 1 x hidden_size, (n_layer x bsz x hidden_size, n_layer x bsz x hidden_size)
             _, (hidden, cell) = self.lstm(input, hx=past.hx)
             past.hx = (hidden, cell)
             if self.attention_layer is not None:
@@ -633,7 +690,7 @@ class LSTMDecoder(Decoder):
         return feats

     @torch.no_grad()
-    def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]:
+    def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]:
         """
         Given the output at the previous position, decide the output at the current position.

         :param torch.LongTensor tokens: batch_size x seq_len
@@ -653,13 +710,5 @@ class LSTMDecoder(Decoder):
         :param LSTMPast past: the stored past state
         :return:
         """
-        encode_outputs = past.encode_outputs.index_select(index=indices, dim=0)
-        encoder_mask = past.encode_mask.index_select(index=indices, dim=0)
-        hx = (past.hx[0].index_select(index=indices, dim=1),
-              past.hx[1].index_select(index=indices, dim=1))
-        if past.attn_states is not None:
-            past.attn_states = past.attn_states.index_select(index=indices, dim=0)
-        past.encode_mask = encoder_mask
-        past.encode_outputs = encode_outputs
-        past.hx = hx
+        past.reorder_past(indices)
         return past
@@ -1,7 +1,10 @@
+__all__ = [
+    "SequenceGenerator"
+]
 import torch
 from .seq2seq_decoder import Decoder
 import torch.nn.functional as F
-from fastNLP.core.utils import _get_model_device
+from ...core.utils import _get_model_device
 from functools import partial
@@ -130,8 +133,8 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
     else:
         _eos_token_id = eos_token_id

-    for i in range(tokens.size(1) - 1):
-        scores, past = decoder.decode_one(tokens[:, :i + 1], past)  # batch_size x vocab_size, Past
+    for i in range(tokens.size(1)):
+        scores, past = decoder.decode(tokens[:, :i + 1], past)  # batch_size x vocab_size, Past

     token_ids = tokens.clone()
     cur_len = token_ids.size(1)
@@ -139,7 +142,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
     # tokens = tokens[:, -1:]

     while cur_len < max_length:
-        scores, past = decoder.decode_one(tokens, past)  # batch_size x vocab_size, Past
+        scores, past = decoder.decode(tokens, past)  # batch_size x vocab_size, Past

         if repetition_penalty != 1.0:
             token_scores = scores.gather(dim=1, index=token_ids)
@@ -153,7 +156,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
             eos_mask = scores.new_ones(scores.size(1))
             eos_mask[eos_token_id] = 0
             eos_mask = eos_mask.unsqueeze(0).eq(1)
-            scores = scores.masked_scatter(eos_mask, token_scores)
+            scores = scores.masked_scatter(eos_mask, token_scores)  # i.e. every word except eos has its score scaled up/down

         if do_sample:
             if temperature > 0 and temperature != 1:
@@ -167,7 +170,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
         else:
             next_tokens = torch.argmax(scores, dim=-1)  # batch_size

-        next_tokens = next_tokens.masked_fill(dones, 0)
+        next_tokens = next_tokens.masked_fill(dones, 0)  # pad the samples that have already finished decoding
         tokens = next_tokens.unsqueeze(1)

         token_ids = torch.cat([token_ids, tokens], dim=-1)  # batch_size x max_len
@@ -181,7 +184,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
     if eos_token_id is not None:
         if cur_len == max_length:
-            token_ids[:, -1].masked_fill_(dones, eos_token_id)
+            token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # if max length is reached without eos, force the last token to be eos

     return token_ids
@@ -206,9 +209,9 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
         assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."

     for i in range(tokens.size(1) - 1):  # if the given prefix is longer than one token, decode it first
-        scores, past = decoder.decode_one(tokens[:, :i + 1],
-                                          past)  # (batch_size, vocab_size), Past
-    scores, past = decoder.decode_one(tokens, past)  # the full sentence has to be passed in here
+        scores, past = decoder.decode(tokens[:, :i + 1],
+                                      past)  # (batch_size, vocab_size), Past
+    scores, past = decoder.decode(tokens, past)  # the full sentence has to be passed in here
     vocab_size = scores.size(1)
     assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."
@@ -224,7 +227,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
     indices = torch.arange(batch_size, dtype=torch.long).to(device)
     indices = indices.repeat_interleave(num_beams)
-    past = decoder.reorder_past(indices, past)
+    decoder.reorder_past(indices, past)

     tokens = tokens.index_select(dim=0, index=indices)  # batch_size * num_beams x length
     # record the generated tokens (batch_size', cur_len)
@@ -240,11 +243,11 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
     hypos = [
         BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
     ]
-    # 0, num_beams, 2*num_beams
+    # 0, num_beams, 2*num_beams, ...
     batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)

     while cur_len < max_length:
-        scores, past = decoder.decode_one(tokens, past)  # batch_size * num_beams x vocab_size, Past
+        scores, past = decoder.decode(tokens, past)  # batch_size * num_beams x vocab_size, Past

         if repetition_penalty != 1.0:
             token_scores = scores.gather(dim=1, index=token_ids)
@@ -298,8 +301,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
         beam_scores = _next_scores.view(-1)

         # update the past state and regroup token_ids
-        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
-        past = decoder.reorder_past(reorder_inds, past)
+        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to 1-D
+        decoder.reorder_past(reorder_inds, past)

         flag = True
         if cur_len + 1 == max_length:
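A worked toy example of the index arithmetic behind `reorder_inds`, assuming batch_size=2 and num_beams=3:

import torch
batch_inds = (torch.arange(2) * 3).view(-1, 1)          # [[0], [3]], one offset per sample
from_beam = torch.LongTensor([[1, 0, 2], [2, 2, 0]])    # which beam each surviving hypothesis came from
print((batch_inds + from_beam).view(-1))                # tensor([1, 0, 2, 5, 5, 3])

The flattened result indexes the batch_size * num_beams rows of the cached states, which is exactly what `reorder_past` consumes.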
@@ -445,7 +448,7 @@ if __name__ == '__main__':
             super().__init__()
             self.num_words = num_words

-        def decode_one(self, tokens, past):
+        def decode(self, tokens, past):
             batch_size = tokens.size(0)
             return torch.randn(batch_size, self.num_words), past
@@ -30,6 +30,9 @@ __all__ = [
     "MultiHeadAttention",
     "BiAttention",
     "SelfAttention",
+
+    "BiLSTMEncoder",
+    "TransformerSeq2SeqEncoder"
 ]

 from .attention import MultiHeadAttention, BiAttention, SelfAttention
@@ -41,3 +44,5 @@ from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPoo
 from .star_transformer import StarTransformer
 from .transformer import TransformerEncoder
 from .variational_rnn import VarRNN, VarLSTM, VarGRU
+
+from .seq2seq_encoder import BiLSTMEncoder, TransformerSeq2SeqEncoder
@@ -1,7 +1,12 @@
+__all__ = [
+    "TransformerSeq2SeqEncoder",
+    "BiLSTMEncoder"
+]
 from torch import nn
 import torch
-from fastNLP.modules import LSTM
-from fastNLP import seq_len_to_mask
+from ...modules.encoder import LSTM
+from ...core.utils import seq_len_to_mask
 from torch.nn import TransformerEncoder
 from typing import Union, Tuple
 import numpy as np
@@ -0,0 +1 @@
@@ -0,0 +1,65 @@
+import unittest
+
+import torch
+
+from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
+from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder
+from fastNLP import Vocabulary
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP.core.utils import seq_len_to_mask
+
+
+class TestTransformerSeq2SeqDecoder(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        vocab.add_word_lst("Another test !".split())
+        embed = StaticEmbedding(vocab, embedding_dim=512)
+
+        encoder = TransformerSeq2SeqEncoder(embed)
+        decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True)
+
+        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
+        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
+        src_seq_len = torch.LongTensor([3, 2])
+
+        encoder_outputs, mask = encoder(src_words_idx, src_seq_len)
+        past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask)
+
+        decoder_outputs = decoder(tgt_words_idx, past)
+        print(decoder_outputs)
+        print(mask)
+        self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))
+
+    def test_decode(self):
+        pass  # todo
+
+
+class TestLSTMDecoder(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        vocab.add_word_lst("Another test !".split())
+        embed = StaticEmbedding(vocab, embedding_dim=512)
+
+        encoder = BiLSTMEncoder(embed)
+        decoder = LSTMDecoder(embed, bind_input_output_embed=True)
+
+        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
+        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
+        src_seq_len = torch.LongTensor([3, 2])
+
+        words, hx = encoder(src_words_idx, src_seq_len)
+        encode_mask = seq_len_to_mask(src_seq_len)
+        hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
+        cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
+        past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell))
+
+        decoder_outputs = decoder(tgt_words_idx, past)
+        print(decoder_outputs)
+        print(encode_mask)
+        self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))
+
+    def test_decode(self):
+        pass  # todo
@@ -0,0 +1,52 @@
+import unittest
+
+import torch
+
+from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
+from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder
+from fastNLP import Vocabulary
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP.core.utils import seq_len_to_mask
+from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator
+
+
+class TestSequenceGenerator(unittest.TestCase):
+    def test_case_for_transformer(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        vocab.add_word_lst("Another test !".split())
+        embed = StaticEmbedding(vocab, embedding_dim=512)
+
+        encoder = TransformerSeq2SeqEncoder(embed, num_layers=6)
+        decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True, num_layers=6)
+
+        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
+        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
+        src_seq_len = torch.LongTensor([3, 2])
+
+        encoder_outputs, mask = encoder(src_words_idx, src_seq_len)
+        past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask, num_decoder_layer=6)
+
+        generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2)
+        tokens_ids = generator.generate(past=past)
+        print(tokens_ids)
+
+    def test_case_for_lstm(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        vocab.add_word_lst("Another test !".split())
+        embed = StaticEmbedding(vocab, embedding_dim=512)
+
+        encoder = BiLSTMEncoder(embed)
+        decoder = LSTMDecoder(embed, bind_input_output_embed=True)
+
+        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
+        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
+        src_seq_len = torch.LongTensor([3, 2])
+
+        words, hx = encoder(src_words_idx, src_seq_len)
+        encode_mask = seq_len_to_mask(src_seq_len)
+        hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
+        cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
+        past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell))
+
+        generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2)
+        tokens_ids = generator.generate(past=past)
+        print(tokens_ids)
@@ -0,0 +1,35 @@
+import unittest
+
+import torch
+
+from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
+from fastNLP import Vocabulary
+from fastNLP.embeddings import StaticEmbedding
+
+
+class TestTransformerSeq2SeqEncoder(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        embed = StaticEmbedding(vocab, embedding_dim=512)
+        encoder = TransformerSeq2SeqEncoder(embed)
+        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
+        seq_len = torch.LongTensor([3])
+
+        outputs, mask = encoder(words_idx, seq_len)
+        print(outputs)
+        print(mask)
+        self.assertEqual(outputs.size(), (1, 3, 512))
+
+
+class TestBiLSTMEncoder(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        embed = StaticEmbedding(vocab, embedding_dim=300)
+        encoder = BiLSTMEncoder(embed, hidden_size=300)
+        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
+        seq_len = torch.LongTensor([3])
+
+        outputs, hx = encoder(words_idx, seq_len)
+        # print(outputs)
+        # print(hx)
+        self.assertEqual(outputs.size(), (1, 3, 300))