@@ -28,7 +28,9 @@ __all__ = [ | |||
"BertForSentenceMatching", | |||
"BertForMultipleChoice", | |||
"BertForTokenClassification", | |||
"BertForQuestionAnswering" | |||
"BertForQuestionAnswering", | |||
"TransformerSeq2SeqModel" | |||
] | |||
from .base_model import BaseModel | |||
@@ -39,7 +41,7 @@ from .cnn_text_classification import CNNText | |||
from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF | |||
from .snli import ESIM | |||
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel | |||
from .seq2seq_model import TransformerSeq2SeqModel | |||
import sys | |||
from ..doc_utils import doc_process | |||
doc_process(sys.modules[__name__]) |
@@ -1,29 +1,24 @@ | |||
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder | |||
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast | |||
from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator | |||
import torch.nn as nn | |||
import torch | |||
from typing import Union, Tuple | |||
import numpy as np | |||
from fastNLP.modules import TransformerSeq2SeqDecoder, TransformerSeq2SeqEncoder, TransformerPast | |||
class TransformerSeq2SeqModel(nn.Module): | |||
class TransformerSeq2SeqModel(nn.Module):  # todo: follow the structure of fairseq's FairseqModel
def __init__(self, src_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], | |||
tgt_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], | |||
num_layers: int = 6, | |||
d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1, | |||
num_layers: int = 6, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1, | |||
output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None, | |||
bind_input_output_embed=False, | |||
sos_id=None, eos_id=None): | |||
bind_input_output_embed=False): | |||
super().__init__() | |||
self.encoder = TransformerSeq2SeqEncoder(src_embed, num_layers, d_model, n_head, dim_ff, dropout) | |||
self.decoder = TransformerSeq2SeqDecoder(tgt_embed, num_layers, d_model, n_head, dim_ff, dropout, output_embed, | |||
bind_input_output_embed) | |||
self.sos_id = sos_id | |||
self.eos_id = eos_id | |||
self.num_layers = num_layers | |||
def forward(self, words, target, seq_len):  # todo: does target include sos and eos here? check how the LSTM version handles it
def forward(self, words, target, seq_len): | |||
encoder_output, encoder_mask = self.encoder(words, seq_len) | |||
past = TransformerPast(encoder_output, encoder_mask, self.num_layers) | |||
outputs = self.decoder(target, past, return_attention=False) |
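For reference, a minimal usage sketch of the model above (illustrative only; it assumes the constructor signature shown in this diff and that forward returns the decoder logits, and the toy vocabulary mirrors the tests added further down in this PR):

import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel

vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = StaticEmbedding(vocab, embedding_dim=512)  # embedding dim matches the default d_model
model = TransformerSeq2SeqModel(src_embed=embed, tgt_embed=embed, num_layers=2,
                                d_model=512, n_head=8, bind_input_output_embed=True)
words = torch.LongTensor([[3, 1, 2]])      # source token ids
target = torch.LongTensor([[1, 2, 3, 4]])  # target token ids
seq_len = torch.LongTensor([3])
logits = model(words, target, seq_len)     # expected: batch_size x decode_len x vocab_size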
@@ -49,7 +49,19 @@ __all__ = [ | |||
"TimestepDropout", | |||
'summary' | |||
'summary', | |||
"BiLSTMEncoder", | |||
"TransformerSeq2SeqEncoder", | |||
"SequenceGenerator", | |||
"LSTMDecoder", | |||
"LSTMPast", | |||
"TransformerSeq2SeqDecoder", | |||
"TransformerPast", | |||
"Decoder", | |||
"Past" | |||
] | |||
import sys | |||
@@ -7,11 +7,20 @@ __all__ = [ | |||
"ConditionalRandomField", | |||
"viterbi_decode", | |||
"allowed_transitions", | |||
"seq2seq_decoder", | |||
"seq2seq_generator" | |||
"SequenceGenerator", | |||
"LSTMDecoder", | |||
"LSTMPast", | |||
"TransformerSeq2SeqDecoder", | |||
"TransformerPast", | |||
"Decoder", | |||
"Past", | |||
] | |||
from .crf import ConditionalRandomField | |||
from .crf import allowed_transitions | |||
from .mlp import MLP | |||
from .utils import viterbi_decode | |||
from .seq2seq_generator import SequenceGenerator | |||
from .seq2seq_decoder import * |
@@ -1,17 +1,46 @@ | |||
# coding=utf-8 | |||
__all__ = [ | |||
"TransformerPast", | |||
"LSTMPast", | |||
"Past", | |||
"LSTMDecoder", | |||
"TransformerSeq2SeqDecoder", | |||
"Decoder" | |||
] | |||
import torch | |||
from torch import nn | |||
import abc | |||
import torch.nn.functional as F | |||
from fastNLP.embeddings import StaticEmbedding | |||
from ...embeddings import StaticEmbedding | |||
import numpy as np | |||
from typing import Union, Tuple | |||
from fastNLP.embeddings import get_embeddings | |||
from fastNLP.modules import LSTM | |||
from ...embeddings.utils import get_embeddings | |||
from torch.nn import LayerNorm | |||
import math | |||
from reproduction.Summarization.Baseline.tools.PositionEmbedding import \ | |||
get_sinusoid_encoding_table  # todo: position embedding should be moved into core
# from reproduction.Summarization.Baseline.tools.PositionEmbedding import \ | |||
# get_sinusoid_encoding_table  # todo: position embedding should be moved into core
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||
''' Sinusoid position encoding table ''' | |||
def cal_angle(position, hid_idx): | |||
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) | |||
def get_posi_angle_vec(position): | |||
return [cal_angle(position, hid_j) for hid_j in range(d_hid)] | |||
sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) | |||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |||
if padding_idx is not None: | |||
# zero vector for padding dimension | |||
sinusoid_table[padding_idx] = 0. | |||
return torch.FloatTensor(sinusoid_table) | |||
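# Illustrative sanity check, not part of this diff: the helper above builds the
# standard sinusoidal table, sin on even dimensions and cos on odd dimensions.
table = get_sinusoid_encoding_table(n_position=4, d_hid=6)
assert table.shape == (4, 6)
assert torch.allclose(table[0, 0::2], torch.zeros(3))  # position 0: sin(0) == 0
assert torch.allclose(table[0, 1::2], torch.ones(3))   # position 0: cos(0) == 1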
class Past: | |||
@@ -82,7 +111,7 @@ class Decoder(nn.Module): | |||
""" | |||
raise NotImplementedError
def decode_one(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]: | |||
def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]: | |||
""" | |||
Use this function when the model is decoding. It returns only a batch_size x vocab_size result. One special case must be
considered: tokens may have length greater than 1, i.e. a prefix of the decoded sentence is given; in that case we need
to check whether Past has correctly computed the decoding state.
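A hedged sketch of a greedy loop against this contract (decoder, past, bos_token_id, max_length and batch_size are placeholders; the real search loops live in seq2seq_generator.py below):

tokens = torch.full((batch_size, 1), bos_token_id, dtype=torch.long)
token_ids = tokens
while token_ids.size(1) < max_length:
    scores, past = decoder.decode(tokens, past)    # batch_size x vocab_size
    tokens = scores.argmax(dim=-1, keepdim=True)   # feed only the newest token; Past caches the rest
    token_ids = torch.cat([token_ids, tokens], dim=-1)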
@@ -100,6 +129,7 @@ class DecoderMultiheadAttention(nn.Module): | |||
""" | |||
def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None): | |||
super(DecoderMultiheadAttention, self).__init__() | |||
self.d_model = d_model | |||
self.n_head = n_head | |||
self.dropout = dropout | |||
@@ -157,11 +187,11 @@ class DecoderMultiheadAttention(nn.Module): | |||
past.encoder_key[self.layer_idx] = k | |||
past.encoder_value[self.layer_idx] = v | |||
if inference and not is_encoder_attn: | |||
past.decoder_prev_key[self.layer_idx] = prev_k | |||
past.decoder_prev_value[self.layer_idx] = prev_v | |||
past.decoder_prev_key[self.layer_idx] = prev_k if prev_k is not None else k | |||
past.decoder_prev_value[self.layer_idx] = prev_v if prev_v is not None else v | |||
batch_size, q_len, d_model = query.size() | |||
k_len, v_len = key.size(1), value.size(1) | |||
k_len, v_len = k.size(1), v.size(1) | |||
q = q.contiguous().view(batch_size, q_len, self.n_head, self.head_dim) | |||
k = k.contiguous().view(batch_size, k_len, self.n_head, self.head_dim) | |||
v = v.contiguous().view(batch_size, v_len, self.n_head, self.head_dim) | |||
@@ -172,8 +202,8 @@ class DecoderMultiheadAttention(nn.Module): | |||
if len(mask.size()) == 2:  # encoder mask: batch, src_len/k_len
mask = mask[:, None, :, None] | |||
else: # (1, seq_len, seq_len) | |||
mask = mask[...:None] | |||
_mask = mask | |||
mask = mask[..., None] | |||
_mask = ~mask.bool() | |||
attn_weights = attn_weights.masked_fill(_mask, float('-inf')) | |||
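A shape sketch for the masking above (illustrative only): attn_weights has shape (batch, q_len, k_len, n_head), so both mask variants broadcast over it after the unsqueezes.

import torch

enc_mask = torch.ones(2, 5, dtype=torch.bool)          # batch x k_len encoder mask
self_mask = torch.tril(torch.ones(4, 4)).eq(1)[None]   # 1 x q_len x k_len triangular mask
print(enc_mask[:, None, :, None].shape)  # torch.Size([2, 1, 5, 1])
print(self_mask[..., None].shape)        # torch.Size([1, 4, 4, 1])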
@@ -181,21 +211,22 @@ class DecoderMultiheadAttention(nn.Module): | |||
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) | |||
output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v) # batch,q_len,n_head,head_dim | |||
output = output.view(batch_size, q_len, -1) | |||
output = output.reshape(batch_size, q_len, -1) | |||
output = self.out_proj(output) # batch,q_len,dim | |||
return output, attn_weights | |||
def reset_parameters(self): | |||
nn.init.xavier_uniform_(self.q_proj) | |||
nn.init.xavier_uniform_(self.k_proj) | |||
nn.init.xavier_uniform_(self.v_proj) | |||
nn.init.xavier_uniform_(self.out_proj) | |||
nn.init.xavier_uniform_(self.q_proj.weight) | |||
nn.init.xavier_uniform_(self.k_proj.weight) | |||
nn.init.xavier_uniform_(self.v_proj.weight) | |||
nn.init.xavier_uniform_(self.out_proj.weight) | |||
class TransformerSeq2SeqDecoderLayer(nn.Module): | |||
def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1, | |||
layer_idx: int = None): | |||
super(TransformerSeq2SeqDecoderLayer, self).__init__() | |||
self.d_model = d_model | |||
self.n_head = n_head | |||
self.dim_ff = dim_ff | |||
@@ -313,10 +344,10 @@ class TransformerSeq2SeqDecoder(Decoder): | |||
if isinstance(self.token_embed, StaticEmbedding): | |||
for i in self.token_embed.words_to_words: | |||
assert i == self.token_embed.words_to_words[i], "The index does not match." | |||
self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1)) | |||
self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1), requires_grad=True) | |||
else: | |||
if isinstance(output_embed, nn.Embedding): | |||
self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1)) | |||
self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1), requires_grad=True) | |||
else: | |||
self.output_embed = output_embed.transpose(0, 1) | |||
self.output_hidden_size = self.output_embed.size(0) | |||
@@ -326,7 +357,8 @@ class TransformerSeq2SeqDecoder(Decoder): | |||
def forward(self, tokens, past, return_attention=False, inference=False): | |||
""" | |||
:param tokens: torch.LongTensor, tokens: batch_size x decode_len | |||
:param tokens: torch.LongTensor, tokens: batch_size , decode_len | |||
:param self_attn_mask: not needed at inference time, and not needed at training time either, because the cross-entropy loss already ignores padded positions
:param past: TransformerPast: holds the encoder output and mask; at inference time it also caches the keys and values of previous steps to save matrix computation
:param return_attention:
:param inference: whether we are in the inference stage
@@ -335,13 +367,16 @@ class TransformerSeq2SeqDecoder(Decoder): | |||
assert past is not None | |||
batch_size, decode_len = tokens.size() | |||
device = tokens.device | |||
pos_idx = torch.arange(1, decode_len + 1).unsqueeze(0).long().to(device)  # move to the tokens' device before the embedding lookup
if not inference: | |||
self_attn_mask = self._get_triangle_mask(decode_len) | |||
self_attn_mask = self_attn_mask.to(device)[None, :, :] # 1,seq,seq | |||
else: | |||
self_attn_mask = None | |||
tokens = self.token_embed(tokens) * self.embed_scale # bs,decode_len,embed_dim | |||
pos = self.pos_embed(tokens) # bs,decode_len,embed_dim | |||
pos = self.pos_embed(pos_idx) # 1,decode_len,embed_dim | |||
tokens = pos + tokens | |||
if inference: | |||
tokens = tokens[:, -1:, :] | |||
@@ -358,7 +393,7 @@ class TransformerSeq2SeqDecoder(Decoder): | |||
return output | |||
@torch.no_grad() | |||
def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]: | |||
def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]: | |||
""" | |||
# todo: do we still need to return past? It is modified in place, so an explicit return may be unnecessary
:param tokens: torch.LongTensor (batch_size,1) | |||
@@ -370,7 +405,7 @@ class TransformerSeq2SeqDecoder(Decoder): | |||
def reorder_past(self, indices: torch.LongTensor, past: TransformerPast) -> TransformerPast: | |||
past.reorder_past(indices) | |||
return past  # todo: this return could actually be dropped
return past | |||
def _get_triangle_mask(self, max_seq_len): | |||
tensor = torch.ones(max_seq_len, max_seq_len) | |||
@@ -409,6 +444,27 @@ class LSTMPast(Past): | |||
return tensor[0].size(0) | |||
return None | |||
def _reorder_past(self, state, indices, dim=0): | |||
if type(state) == torch.Tensor: | |||
state = state.index_select(index=indices, dim=dim) | |||
elif type(state) == tuple: | |||
tmp_list = [] | |||
for i in range(len(state)): | |||
assert state[i] is not None | |||
tmp_list.append(state[i].index_select(index=indices, dim=dim)) | |||
state = tuple(tmp_list) | |||
else: | |||
raise ValueError('Unsupported state type; expected a Tensor or a tuple of Tensors')
return state | |||
def reorder_past(self, indices: torch.LongTensor): | |||
self.encode_outputs = self._reorder_past(self.encode_outputs, indices) | |||
self.encode_mask = self._reorder_past(self.encode_mask, indices) | |||
self.hx = self._reorder_past(self.hx, indices, 1) | |||
if self.attn_states is not None: | |||
self.attn_states = self._reorder_past(self.attn_states, indices) | |||
@property | |||
def hx(self): | |||
return self._hx | |||
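To make the reordering above concrete (an illustrative example, not part of this diff): during beam search, indices duplicates or selects rows of every cached tensor, e.g. expanding two samples to two beams each.

import torch

state = torch.arange(4.).view(2, 2)        # 2 samples x hidden
indices = torch.LongTensor([0, 0, 1, 1])   # each sample repeated once per beam
print(state.index_select(index=indices, dim=0))
# tensor([[0., 1.],
#         [0., 1.],
#         [2., 3.],
#         [2., 3.]])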
@@ -493,7 +549,7 @@ class AttentionLayer(nn.Module): | |||
class LSTMDecoder(Decoder): | |||
def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers, input_size, | |||
def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers=3, input_size=400, | |||
hidden_size=None, dropout=0, | |||
output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None, | |||
bind_input_output_embed=False, | |||
@@ -612,6 +668,7 @@ class LSTMDecoder(Decoder): | |||
for i in range(tokens.size(1)): | |||
input = torch.cat([tokens[:, i:i + 1], input_feed.unsqueeze(1)], dim=2) # batch_size x 1 x h' | |||
# bsz x 1 x hidden_size, (n_layer x bsz x hidden_size, n_layer x bsz x hidden_size) | |||
_, (hidden, cell) = self.lstm(input, hx=past.hx) | |||
past.hx = (hidden, cell) | |||
if self.attention_layer is not None: | |||
@@ -633,7 +690,7 @@ class LSTMDecoder(Decoder): | |||
return feats | |||
@torch.no_grad() | |||
def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]: | |||
def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]: | |||
""" | |||
Given the output at the previous position, decide the output at the current position.
:param torch.LongTensor tokens: batch_size x seq_len | |||
@@ -653,13 +710,5 @@ class LSTMDecoder(Decoder): | |||
:param LSTMPast past: the stored past state
:return: | |||
""" | |||
encode_outputs = past.encode_outputs.index_select(index=indices, dim=0) | |||
encoder_mask = past.encode_mask.index_select(index=indices, dim=0) | |||
hx = (past.hx[0].index_select(index=indices, dim=1), | |||
past.hx[1].index_select(index=indices, dim=1)) | |||
if past.attn_states is not None: | |||
past.attn_states = past.attn_states.index_select(index=indices, dim=0) | |||
past.encode_mask = encoder_mask | |||
past.encode_outputs = encode_outputs | |||
past.hx = hx | |||
past.reorder_past(indices) | |||
return past |
@@ -1,7 +1,10 @@ | |||
__all__ = [ | |||
"SequenceGenerator" | |||
] | |||
import torch | |||
from .seq2seq_decoder import Decoder | |||
import torch.nn.functional as F | |||
from fastNLP.core.utils import _get_model_device | |||
from ...core.utils import _get_model_device | |||
from functools import partial | |||
@@ -130,8 +133,8 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||
else: | |||
_eos_token_id = eos_token_id | |||
for i in range(tokens.size(1) - 1): | |||
scores, past = decoder.decode_one(tokens[:, :i + 1], past) # batch_size x vocab_size, Past | |||
for i in range(tokens.size(1)): | |||
scores, past = decoder.decode(tokens[:, :i + 1], past) # batch_size x vocab_size, Past | |||
token_ids = tokens.clone() | |||
cur_len = token_ids.size(1) | |||
@@ -139,7 +142,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||
# tokens = tokens[:, -1:] | |||
while cur_len < max_length: | |||
scores, past = decoder.decode_one(tokens, past) # batch_size x vocab_size, Past | |||
scores, past = decoder.decode(tokens, past) # batch_size x vocab_size, Past | |||
if repetition_penalty != 1.0: | |||
token_scores = scores.gather(dim=1, index=token_ids) | |||
@@ -153,7 +156,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||
eos_mask = scores.new_ones(scores.size(1)) | |||
eos_mask[eos_token_id] = 0 | |||
eos_mask = eos_mask.unsqueeze(0).eq(1) | |||
scores = scores.masked_scatter(eos_mask, token_scores) | |||
scores = scores.masked_scatter(eos_mask, token_scores)  # i.e. every score except eos has been scaled up/down
if do_sample: | |||
if temperature > 0 and temperature != 1: | |||
@@ -167,7 +170,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||
else: | |||
next_tokens = torch.argmax(scores, dim=-1) # batch_size | |||
next_tokens = next_tokens.masked_fill(dones, 0) | |||
next_tokens = next_tokens.masked_fill(dones, 0)  # pad the samples whose search has already finished
tokens = next_tokens.unsqueeze(1) | |||
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len | |||
@@ -181,7 +184,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||
if eos_token_id is not None: | |||
if cur_len == max_length: | |||
token_ids[:, -1].masked_fill_(dones, eos_token_id) | |||
token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # if max length is reached without EOS, force the last token to eos
return token_ids | |||
@@ -206,9 +209,9 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." | |||
for i in range(tokens.size(1) - 1):  # if the input prefix is longer than one token, decode it first
scores, past = decoder.decode_one(tokens[:, :i + 1], | |||
past) # (batch_size, vocab_size), Past | |||
scores, past = decoder.decode_one(tokens, past)  # here the full-length tokens must be passed in
scores, past = decoder.decode(tokens[:, :i + 1],
past)  # (batch_size, vocab_size), Past
scores, past = decoder.decode(tokens, past)  # here the full-length tokens must be passed in
vocab_size = scores.size(1) | |||
assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size." | |||
@@ -224,7 +227,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
indices = torch.arange(batch_size, dtype=torch.long).to(device) | |||
indices = indices.repeat_interleave(num_beams) | |||
past = decoder.reorder_past(indices, past) | |||
decoder.reorder_past(indices, past) | |||
tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length | |||
# record the generated tokens (batch_size', cur_len)
@@ -240,11 +243,11 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
hypos = [ | |||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size) | |||
] | |||
# 0,num_beams, 2*num_beams | |||
# 0,num_beams, 2*num_beams, ... | |||
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) | |||
while cur_len < max_length: | |||
scores, past = decoder.decode_one(tokens, past) # batch_size * num_beams x vocab_size, Past | |||
scores, past = decoder.decode(tokens, past) # batch_size * num_beams x vocab_size, Past | |||
if repetition_penalty != 1.0: | |||
token_scores = scores.gather(dim=1, index=token_ids) | |||
@@ -298,8 +301,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
beam_scores = _next_scores.view(-1) | |||
# update the past state and regroup token_ids
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) | |||
past = decoder.reorder_past(reorder_inds, past) | |||
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to 1-D
decoder.reorder_past(reorder_inds, past) | |||
flag = True | |||
if cur_len + 1 == max_length: | |||
@@ -445,7 +448,7 @@ if __name__ == '__main__': | |||
super().__init__() | |||
self.num_words = num_words | |||
def decode_one(self, tokens, past): | |||
def decode(self, tokens, past): | |||
batch_size = tokens.size(0) | |||
return torch.randn(batch_size, self.num_words), past | |||
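To make the beam bookkeeping above concrete, here is a small illustrative computation of reorder_inds (not part of this diff):

import torch

batch_size, num_beams = 2, 3
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1)  # [[0], [3]]
_from_which_beam = torch.LongTensor([[1, 0, 1], [2, 0, 2]])  # surviving beam chosen for each slot
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
print(reorder_inds)  # tensor([1, 0, 1, 5, 3, 5]), row indices into the flattened batch of size 6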
@@ -30,6 +30,9 @@ __all__ = [ | |||
"MultiHeadAttention", | |||
"BiAttention", | |||
"SelfAttention", | |||
"BiLSTMEncoder", | |||
"TransformerSeq2SeqEncoder" | |||
] | |||
from .attention import MultiHeadAttention, BiAttention, SelfAttention | |||
@@ -41,3 +44,5 @@ from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPoo | |||
from .star_transformer import StarTransformer | |||
from .transformer import TransformerEncoder | |||
from .variational_rnn import VarRNN, VarLSTM, VarGRU | |||
from .seq2seq_encoder import BiLSTMEncoder, TransformerSeq2SeqEncoder |
@@ -1,7 +1,12 @@ | |||
__all__ = [ | |||
"TransformerSeq2SeqEncoder", | |||
"BiLSTMEncoder" | |||
] | |||
from torch import nn | |||
import torch | |||
from fastNLP.modules import LSTM | |||
from fastNLP import seq_len_to_mask | |||
from ...modules.encoder import LSTM | |||
from ...core.utils import seq_len_to_mask | |||
from torch.nn import TransformerEncoder | |||
from typing import Union, Tuple | |||
import numpy as np | |||
@@ -0,0 +1 @@ | |||
@@ -0,0 +1,65 @@ | |||
import unittest | |||
import torch | |||
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder | |||
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder | |||
from fastNLP import Vocabulary | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP.core.utils import seq_len_to_mask | |||
class TestTransformerSeq2SeqDecoder(unittest.TestCase): | |||
def test_case(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, embedding_dim=512) | |||
encoder = TransformerSeq2SeqEncoder(embed) | |||
decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True) | |||
src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
src_seq_len = torch.LongTensor([3, 2]) | |||
encoder_outputs, mask = encoder(src_words_idx, src_seq_len) | |||
past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask) | |||
decoder_outputs = decoder(tgt_words_idx, past) | |||
print(decoder_outputs) | |||
print(mask) | |||
self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab))) | |||
def test_decode(self): | |||
pass # todo | |||
class TestLSTMDecoder(unittest.TestCase): | |||
def test_case(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, embedding_dim=512) | |||
encoder = BiLSTMEncoder(embed) | |||
decoder = LSTMDecoder(embed, bind_input_output_embed=True) | |||
src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
src_seq_len = torch.LongTensor([3, 2]) | |||
words, hx = encoder(src_words_idx, src_seq_len) | |||
encode_mask = seq_len_to_mask(src_seq_len) | |||
hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1) | |||
cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1) | |||
past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell)) | |||
decoder_outputs = decoder(tgt_words_idx, past) | |||
print(decoder_outputs) | |||
print(encode_mask) | |||
self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab))) | |||
def test_decode(self): | |||
pass # todo |
@@ -0,0 +1,52 @@ | |||
import unittest | |||
import torch | |||
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder | |||
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder | |||
from fastNLP import Vocabulary | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP.core.utils import seq_len_to_mask | |||
from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator | |||
class TestSequenceGenerator(unittest.TestCase): | |||
def test_case_for_transformer(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, embedding_dim=512) | |||
encoder = TransformerSeq2SeqEncoder(embed, num_layers=6) | |||
decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True, num_layers=6) | |||
src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
src_seq_len = torch.LongTensor([3, 2]) | |||
encoder_outputs, mask = encoder(src_words_idx, src_seq_len) | |||
past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask, num_decoder_layer=6) | |||
generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2) | |||
tokens_ids = generator.generate(past=past) | |||
print(tokens_ids) | |||
def test_case_for_lstm(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, embedding_dim=512) | |||
encoder = BiLSTMEncoder(embed) | |||
decoder = LSTMDecoder(embed, bind_input_output_embed=True) | |||
src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
src_seq_len = torch.LongTensor([3, 2]) | |||
words, hx = encoder(src_words_idx, src_seq_len) | |||
encode_mask = seq_len_to_mask(src_seq_len) | |||
hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1) | |||
cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1) | |||
past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell)) | |||
generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2) | |||
tokens_ids = generator.generate(past=past) | |||
print(tokens_ids) |
@@ -0,0 +1,35 @@ | |||
import unittest | |||
import torch | |||
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder | |||
from fastNLP import Vocabulary | |||
from fastNLP.embeddings import StaticEmbedding | |||
class TestTransformerSeq2SeqEncoder(unittest.TestCase): | |||
def test_case(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
embed = StaticEmbedding(vocab, embedding_dim=512) | |||
encoder = TransformerSeq2SeqEncoder(embed) | |||
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||
seq_len = torch.LongTensor([3]) | |||
outputs, mask = encoder(words_idx, seq_len) | |||
print(outputs) | |||
print(mask) | |||
self.assertEqual(outputs.size(), (1, 3, 512)) | |||
class TestBiLSTMEncoder(unittest.TestCase): | |||
def test_case(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
embed = StaticEmbedding(vocab, embedding_dim=300) | |||
encoder = BiLSTMEncoder(embed, hidden_size=300) | |||
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||
seq_len = torch.LongTensor([3]) | |||
outputs, hx = encoder(words_idx, seq_len) | |||
# print(outputs) | |||
# print(hx) | |||
self.assertEqual(outputs.size(), (1, 3, 300)) |