diff --git a/fastNLP/models/__init__.py b/fastNLP/models/__init__.py
index ba499ac2..c6930b9a 100644
--- a/fastNLP/models/__init__.py
+++ b/fastNLP/models/__init__.py
@@ -28,7 +28,9 @@ __all__ = [
     "BertForSentenceMatching",
     "BertForMultipleChoice",
     "BertForTokenClassification",
-    "BertForQuestionAnswering"
+    "BertForQuestionAnswering",
+
+    "TransformerSeq2SeqModel"
 ]
 
 from .base_model import BaseModel
@@ -39,7 +41,7 @@ from .cnn_text_classification import CNNText
 from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF
 from .snli import ESIM
 from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
-
+from .seq2seq_model import TransformerSeq2SeqModel
 import sys
 from ..doc_utils import doc_process
 doc_process(sys.modules[__name__])
\ No newline at end of file
diff --git a/fastNLP/models/seq2seq_models.py b/fastNLP/models/seq2seq_model.py
similarity index 58%
rename from fastNLP/models/seq2seq_models.py
rename to fastNLP/models/seq2seq_model.py
index d1126177..94b96198 100644
--- a/fastNLP/models/seq2seq_models.py
+++ b/fastNLP/models/seq2seq_model.py
@@ -1,29 +1,24 @@
-from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder
-from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast
-from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator
 import torch.nn as nn
 import torch
 from typing import Union, Tuple
 import numpy as np
 
+from fastNLP.modules import TransformerSeq2SeqDecoder, TransformerSeq2SeqEncoder, TransformerPast
 
 
-class TransformerSeq2SeqModel(nn.Module):
+class TransformerSeq2SeqModel(nn.Module):  # todo 参考fairseq的FairseqModel的写法
     def __init__(self, src_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
                  tgt_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
-                 num_layers: int = 6,
-                 d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
+                 num_layers: int = 6, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
                  output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
-                 bind_input_output_embed=False,
-                 sos_id=None, eos_id=None):
+                 bind_input_output_embed=False):
         super().__init__()
         self.encoder = TransformerSeq2SeqEncoder(src_embed, num_layers, d_model, n_head, dim_ff, dropout)
         self.decoder = TransformerSeq2SeqDecoder(tgt_embed, num_layers, d_model, n_head, dim_ff, dropout,
                                                  output_embed, bind_input_output_embed)
-        self.sos_id = sos_id
-        self.eos_id = eos_id
+        self.num_layers = num_layers
 
-    def forward(self, words, target, seq_len):  # todo:这里的target有sos和eos吗,参考一下lstm怎么写的
+    def forward(self, words, target, seq_len):
         encoder_output, encoder_mask = self.encoder(words, seq_len)
         past = TransformerPast(encoder_output, encoder_mask, self.num_layers)
         outputs = self.decoder(target, past, return_attention=False)
diff --git a/fastNLP/modules/__init__.py b/fastNLP/modules/__init__.py
index 77283806..e973379d 100644
--- a/fastNLP/modules/__init__.py
+++ b/fastNLP/modules/__init__.py
@@ -49,7 +49,19 @@ __all__ = [
 
     "TimestepDropout",
 
-    'summary'
+    'summary',
+
+    "BiLSTMEncoder",
+    "TransformerSeq2SeqEncoder",
+
+    "SequenceGenerator",
+    "LSTMDecoder",
+    "LSTMPast",
+    "TransformerSeq2SeqDecoder",
+    "TransformerPast",
+    "Decoder",
+    "Past"
+
 ]
 
 import sys
diff --git a/fastNLP/modules/decoder/__init__.py b/fastNLP/modules/decoder/__init__.py
index 2b418927..e3bceff0 100644
--- a/fastNLP/modules/decoder/__init__.py
+++ b/fastNLP/modules/decoder/__init__.py
@@ -7,11 +7,20 @@
"ConditionalRandomField", "viterbi_decode", "allowed_transitions", - "seq2seq_decoder", - "seq2seq_generator" + + "SequenceGenerator", + "LSTMDecoder", + "LSTMPast", + "TransformerSeq2SeqDecoder", + "TransformerPast", + "Decoder", + "Past", + ] from .crf import ConditionalRandomField from .crf import allowed_transitions from .mlp import MLP from .utils import viterbi_decode +from .seq2seq_generator import SequenceGenerator +from .seq2seq_decoder import * diff --git a/fastNLP/modules/decoder/seq2seq_decoder.py b/fastNLP/modules/decoder/seq2seq_decoder.py index fae46c20..9a4e27de 100644 --- a/fastNLP/modules/decoder/seq2seq_decoder.py +++ b/fastNLP/modules/decoder/seq2seq_decoder.py @@ -1,17 +1,46 @@ # coding=utf-8 +__all__ = [ + "TransformerPast", + "LSTMPast", + "Past", + "LSTMDecoder", + "TransformerSeq2SeqDecoder", + "Decoder" +] import torch from torch import nn import abc import torch.nn.functional as F -from fastNLP.embeddings import StaticEmbedding +from ...embeddings import StaticEmbedding import numpy as np from typing import Union, Tuple -from fastNLP.embeddings import get_embeddings -from fastNLP.modules import LSTM +from ...embeddings.utils import get_embeddings from torch.nn import LayerNorm import math -from reproduction.Summarization.Baseline.tools.PositionEmbedding import \ - get_sinusoid_encoding_table # todo: 应该将position embedding移到core + + +# from reproduction.Summarization.Baseline.tools.PositionEmbedding import \ +# get_sinusoid_encoding_table # todo: 应该将position embedding移到core + +def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): + ''' Sinusoid position encoding table ''' + + def cal_angle(position, hid_idx): + return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) + + def get_posi_angle_vec(position): + return [cal_angle(position, hid_j) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) + + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + if padding_idx is not None: + # zero vector for padding dimension + sinusoid_table[padding_idx] = 0. 
+
+    return torch.FloatTensor(sinusoid_table)
 
 
 class Past:
@@ -82,7 +111,7 @@ class Decoder(nn.Module):
         """
         raise NotImplemented
 
-    def decode_one(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]:
+    def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]:
         """
         当模型进行解码时,使用这个函数。只返回一个batch_size x vocab_size的结果。需要考虑一种特殊情况,即tokens长度不是1,即给定了
         解码句子开头的情况,这种情况需要查看Past中是否正确计算了decode的状态
@@ -100,6 +129,7 @@ class DecoderMultiheadAttention(nn.Module):
     """
 
     def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None):
+        super(DecoderMultiheadAttention, self).__init__()
         self.d_model = d_model
         self.n_head = n_head
         self.dropout = dropout
@@ -157,11 +187,11 @@ class DecoderMultiheadAttention(nn.Module):
                 past.encoder_key[self.layer_idx] = k
                 past.encoder_value[self.layer_idx] = v
         if inference and not is_encoder_attn:
-            past.decoder_prev_key[self.layer_idx] = prev_k
-            past.decoder_prev_value[self.layer_idx] = prev_v
+            past.decoder_prev_key[self.layer_idx] = prev_k if prev_k is not None else k
+            past.decoder_prev_value[self.layer_idx] = prev_v if prev_v is not None else v
 
         batch_size, q_len, d_model = query.size()
-        k_len, v_len = key.size(1), value.size(1)
+        k_len, v_len = k.size(1), v.size(1)
         q = q.contiguous().view(batch_size, q_len, self.n_head, self.head_dim)
         k = k.contiguous().view(batch_size, k_len, self.n_head, self.head_dim)
         v = v.contiguous().view(batch_size, v_len, self.n_head, self.head_dim)
@@ -172,8 +202,8 @@ class DecoderMultiheadAttention(nn.Module):
             if len(mask.size()) == 2:  # 是encoder mask, batch,src_len/k_len
                 mask = mask[:, None, :, None]
             else:  # (1, seq_len, seq_len)
-                mask = mask[...:None]
-            _mask = mask
+                mask = mask[..., None]
+            _mask = ~mask.bool()
 
             attn_weights = attn_weights.masked_fill(_mask, float('-inf'))
 
@@ -181,21 +211,22 @@ class DecoderMultiheadAttention(nn.Module):
         attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
 
         output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v)  # batch,q_len,n_head,head_dim
-        output = output.view(batch_size, q_len, -1)
+        output = output.reshape(batch_size, q_len, -1)
         output = self.out_proj(output)  # batch,q_len,dim
 
         return output, attn_weights
 
     def reset_parameters(self):
-        nn.init.xavier_uniform_(self.q_proj)
-        nn.init.xavier_uniform_(self.k_proj)
-        nn.init.xavier_uniform_(self.v_proj)
-        nn.init.xavier_uniform_(self.out_proj)
+        nn.init.xavier_uniform_(self.q_proj.weight)
+        nn.init.xavier_uniform_(self.k_proj.weight)
+        nn.init.xavier_uniform_(self.v_proj.weight)
+        nn.init.xavier_uniform_(self.out_proj.weight)
 
 
 class TransformerSeq2SeqDecoderLayer(nn.Module):
     def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
                  layer_idx: int = None):
+        super(TransformerSeq2SeqDecoderLayer, self).__init__()
         self.d_model = d_model
         self.n_head = n_head
         self.dim_ff = dim_ff
@@ -313,10 +344,10 @@ class TransformerSeq2SeqDecoder(Decoder):
             if isinstance(self.token_embed, StaticEmbedding):
                 for i in self.token_embed.words_to_words:
                     assert i == self.token_embed.words_to_words[i], "The index does not match."
-                self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1))
+                self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1), requires_grad=True)
             else:
                 if isinstance(output_embed, nn.Embedding):
-                    self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1))
+                    self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1), requires_grad=True)
                 else:
                     self.output_embed = output_embed.transpose(0, 1)
             self.output_hidden_size = self.output_embed.size(0)
@@ -326,7 +357,8 @@ class TransformerSeq2SeqDecoder(Decoder):
     def forward(self, tokens, past, return_attention=False, inference=False):
         """
 
-        :param tokens: torch.LongTensor, tokens: batch_size x decode_len
+        :param tokens: torch.LongTensor, tokens: batch_size , decode_len
+        :param self_attn_mask: 在inference的时候不需要,而在train的时候,因为训练的时候交叉熵会自动屏蔽掉padding的地方,所以也不需要
        :param past: TransformerPast: 包含encoder输出及mask,在inference阶段保存了上一时刻的key和value以减少矩阵运算
        :param return_attention:
        :param inference: 是否在inference阶段
@@ -335,13 +367,16 @@ class TransformerSeq2SeqDecoder(Decoder):
         assert past is not None
         batch_size, decode_len = tokens.size()
         device = tokens.device
+        pos_idx = torch.arange(1, decode_len + 1).unsqueeze(0).long()
+
         if not inference:
             self_attn_mask = self._get_triangle_mask(decode_len)
             self_attn_mask = self_attn_mask.to(device)[None, :, :]  # 1,seq,seq
         else:
             self_attn_mask = None
+
         tokens = self.token_embed(tokens) * self.embed_scale  # bs,decode_len,embed_dim
-        pos = self.pos_embed(tokens)  # bs,decode_len,embed_dim
+        pos = self.pos_embed(pos_idx)  # 1,decode_len,embed_dim
         tokens = pos + tokens
         if inference:
             tokens = tokens[:, -1:, :]
@@ -358,7 +393,7 @@ class TransformerSeq2SeqDecoder(Decoder):
         return output
 
     @torch.no_grad()
-    def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]:
+    def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]:
         """
         # todo: 是否不需要return past? 因为past已经被改变了,不需要显式return?
         :param tokens: torch.LongTensor (batch_size,1)
         :param past: TransformerPast
         :return:
@@ -370,7 +405,7 @@ def reorder_past(self, indices: torch.LongTensor, past: TransformerPast) -> TransformerPast:
         past.reorder_past(indices)
-        return past  # todo : 其实可以不要这个的
+        return past
 
     def _get_triangle_mask(self, max_seq_len):
         tensor = torch.ones(max_seq_len, max_seq_len)
@@ -409,6 +444,27 @@ class LSTMPast(Past):
             return tensor[0].size(0)
         return None
 
+    def _reorder_past(self, state, indices, dim=0):
+        if type(state) == torch.Tensor:
+            state = state.index_select(index=indices, dim=dim)
+        elif type(state) == tuple:
+            tmp_list = []
+            for i in range(len(state)):
+                assert state[i] is not None
+                tmp_list.append(state[i].index_select(index=indices, dim=dim))
+            state = tuple(tmp_list)
+        else:
+            raise ValueError('State does not support other format')
+
+        return state
+
+    def reorder_past(self, indices: torch.LongTensor):
+        self.encode_outputs = self._reorder_past(self.encode_outputs, indices)
+        self.encode_mask = self._reorder_past(self.encode_mask, indices)
+        self.hx = self._reorder_past(self.hx, indices, 1)
+        if self.attn_states is not None:
+            self.attn_states = self._reorder_past(self.attn_states, indices)
+
     @property
     def hx(self):
         return self._hx
@@ -493,7 +549,7 @@ class AttentionLayer(nn.Module):
 
 
 class LSTMDecoder(Decoder):
-    def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers, input_size,
+    def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers=3, input_size=400,
                  hidden_size=None, dropout=0,
                  output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
                  bind_input_output_embed=False,
@@ -612,6 +668,7 @@ class LSTMDecoder(Decoder):
         for i in range(tokens.size(1)):
             input = torch.cat([tokens[:, i:i + 1], input_feed.unsqueeze(1)], dim=2)  # batch_size x 1 x h'
             # bsz x 1 x hidden_size, (n_layer x bsz x hidden_size, n_layer x bsz x hidden_size)
+
             _, (hidden, cell) = self.lstm(input, hx=past.hx)
             past.hx = (hidden, cell)
             if self.attention_layer is not None:
@@ -633,7 +690,7 @@ class LSTMDecoder(Decoder):
         return feats
 
     @torch.no_grad()
-    def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]:
+    def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]:
         """
         给定上一个位置的输出,决定当前位置的输出。
         :param torch.LongTensor tokens: batch_size x seq_len
@@ -653,13 +710,5 @@ class LSTMDecoder(Decoder):
         :param LSTMPast past: 保存的过去的状态
         :return:
         """
-        encode_outputs = past.encode_outputs.index_select(index=indices, dim=0)
-        encoder_mask = past.encode_mask.index_select(index=indices, dim=0)
-        hx = (past.hx[0].index_select(index=indices, dim=1),
-              past.hx[1].index_select(index=indices, dim=1))
-        if past.attn_states is not None:
-            past.attn_states = past.attn_states.index_select(index=indices, dim=0)
-        past.encode_mask = encoder_mask
-        past.encode_outputs = encode_outputs
-        past.hx = hx
+        past.reorder_past(indices)
         return past
diff --git a/fastNLP/modules/decoder/seq2seq_generator.py b/fastNLP/modules/decoder/seq2seq_generator.py
index 6d41e717..4ee2c787 100644
--- a/fastNLP/modules/decoder/seq2seq_generator.py
+++ b/fastNLP/modules/decoder/seq2seq_generator.py
@@ -1,7 +1,10 @@
+__all__ = [
+    "SequenceGenerator"
+]
 import torch
 from .seq2seq_decoder import Decoder
 import torch.nn.functional as F
-from fastNLP.core.utils import _get_model_device
+from ...core.utils import _get_model_device
 from functools import partial
 
 
@@ -130,8 +133,8 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
     else:
         _eos_token_id = eos_token_id
 
-    for i in range(tokens.size(1) - 1):
-        scores, past = decoder.decode_one(tokens[:, :i + 1], past)  # batch_size x vocab_size, Past
+    for i in range(tokens.size(1)):
+        scores, past = decoder.decode(tokens[:, :i + 1], past)  # batch_size x vocab_size, Past
 
     token_ids = tokens.clone()
     cur_len = token_ids.size(1)
@@ -139,7 +142,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
     # tokens = tokens[:, -1:]
 
     while cur_len < max_length:
-        scores, past = decoder.decode_one(tokens, past)  # batch_size x vocab_size, Past
+        scores, past = decoder.decode(tokens, past)  # batch_size x vocab_size, Past
 
         if repetition_penalty != 1.0:
             token_scores = scores.gather(dim=1, index=token_ids)
@@ -153,7 +156,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
             eos_mask = scores.new_ones(scores.size(1))
             eos_mask[eos_token_id] = 0
             eos_mask = eos_mask.unsqueeze(0).eq(1)
-            scores = scores.masked_scatter(eos_mask, token_scores)
+            scores = scores.masked_scatter(eos_mask, token_scores)  # 也即除了eos,其他词的分数经过了放大/缩小
 
         if do_sample:
             if temperature > 0 and temperature != 1:
@@ -167,7 +170,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
         else:
             next_tokens = torch.argmax(scores, dim=-1)  # batch_size
 
-        next_tokens = next_tokens.masked_fill(dones, 0)
+        next_tokens = next_tokens.masked_fill(dones, 0)  # 对已经搜索完成的sample做padding
         tokens = next_tokens.unsqueeze(1)
 
         token_ids = torch.cat([token_ids, tokens], dim=-1)  # batch_size x max_len
@@ -181,7 +184,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
 
     if eos_token_id is not None:
         if cur_len == max_length:
-            token_ids[:, -1].masked_fill_(dones, eos_token_id)
+            token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # 若到最长长度仍未到EOS,则强制将最后一个词替换成eos
 
     return token_ids
 
@@ -206,9 +209,9 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
         assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."
 
     for i in range(tokens.size(1) - 1):  # 如果输入的长度较长,先decode
-        scores, past = decoder.decode_one(tokens[:, :i + 1],
-                                          past)  # (batch_size, vocab_size), Past
-    scores, past = decoder.decode_one(tokens, past)  # 这里要传入的是整个句子的长度
+        scores, past = decoder.decode(tokens[:, :i + 1],
+                                      past)  # (batch_size, vocab_size), Past
+    scores, past = decoder.decode(tokens, past)  # 这里要传入的是整个句子的长度
 
     vocab_size = scores.size(1)
     assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."
@@ -224,7 +227,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
 
     indices = torch.arange(batch_size, dtype=torch.long).to(device)
     indices = indices.repeat_interleave(num_beams)
-    past = decoder.reorder_past(indices, past)
+    decoder.reorder_past(indices, past)
 
     tokens = tokens.index_select(dim=0, index=indices)  # batch_size * num_beams x length
     # 记录生成好的token (batch_size', cur_len)
@@ -240,11 +243,11 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
     hypos = [
         BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
     ]
-    # 0,num_beams, 2*num_beams
+    # 0,num_beams, 2*num_beams, ...
     batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)
 
     while cur_len < max_length:
-        scores, past = decoder.decode_one(tokens, past)  # batch_size * num_beams x vocab_size, Past
+        scores, past = decoder.decode(tokens, past)  # batch_size * num_beams x vocab_size, Past
 
         if repetition_penalty != 1.0:
             token_scores = scores.gather(dim=1, index=token_ids)
@@ -298,8 +301,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
         beam_scores = _next_scores.view(-1)
 
         # 更改past状态, 重组token_ids
-        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
-        past = decoder.reorder_past(reorder_inds, past)
+        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten成一维
+        decoder.reorder_past(reorder_inds, past)
 
         flag = True
         if cur_len + 1 == max_length:
@@ -445,7 +448,7 @@ if __name__ == '__main__':
             super().__init__()
             self.num_words = num_words
 
-        def decode_one(self, tokens, past):
+        def decode(self, tokens, past):
             batch_size = tokens.size(0)
             return torch.randn(batch_size, self.num_words), past
 
diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py
index cbb42d7e..579dddd4 100644
--- a/fastNLP/modules/encoder/__init__.py
+++ b/fastNLP/modules/encoder/__init__.py
@@ -30,6 +30,9 @@ __all__ = [
     "MultiHeadAttention",
     "BiAttention",
     "SelfAttention",
+
+    "BiLSTMEncoder",
+    "TransformerSeq2SeqEncoder"
 ]
 
 from .attention import MultiHeadAttention, BiAttention, SelfAttention
@@ -41,3 +44,5 @@ from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPoo
 from .star_transformer import StarTransformer
 from .transformer import TransformerEncoder
 from .variational_rnn import VarRNN, VarLSTM, VarGRU
+
+from .seq2seq_encoder import BiLSTMEncoder, TransformerSeq2SeqEncoder
diff --git a/fastNLP/modules/encoder/seq2seq_encoder.py b/fastNLP/modules/encoder/seq2seq_encoder.py
index 91d4b499..1474c864 100644
--- a/fastNLP/modules/encoder/seq2seq_encoder.py
+++ b/fastNLP/modules/encoder/seq2seq_encoder.py
@@ -1,7 +1,12 @@
+__all__ = [
+    "TransformerSeq2SeqEncoder",
+    "BiLSTMEncoder"
+]
+
 from torch import nn
 import torch
-from fastNLP.modules import LSTM
-from fastNLP import seq_len_to_mask
+from ...modules.encoder import LSTM
+from ...core.utils import seq_len_to_mask
 from torch.nn import TransformerEncoder
 from typing import Union, Tuple
 import numpy as np
diff --git a/test/models/test_seq2seq_model.py b/test/models/test_seq2seq_model.py
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/test/models/test_seq2seq_model.py
@@ -0,0 +1 @@
+
diff --git a/test/modules/decoder/test_seq2seq_decoder.py b/test/modules/decoder/test_seq2seq_decoder.py
new file mode 100644
index 00000000..6c74d527
--- /dev/null
+++ b/test/modules/decoder/test_seq2seq_decoder.py
@@ -0,0 +1,65 @@
+import unittest
+
+import torch
+
+from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
+from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder
+from fastNLP import Vocabulary
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP.core.utils import seq_len_to_mask
+
+
+class TestTransformerSeq2SeqDecoder(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        vocab.add_word_lst("Another test !".split())
+        embed = StaticEmbedding(vocab, embedding_dim=512)
+
+        encoder = TransformerSeq2SeqEncoder(embed)
+        decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True)
+
+        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
+        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
+        src_seq_len = torch.LongTensor([3, 2])
+
+        encoder_outputs, mask = encoder(src_words_idx, src_seq_len)
+        past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask)
+
+        decoder_outputs = decoder(tgt_words_idx, past)
+
+        print(decoder_outputs)
+        print(mask)
+
+        self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))
+
+    def test_decode(self):
+        pass  # todo
+
+
+class TestLSTMDecoder(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        vocab.add_word_lst("Another test !".split())
+        embed = StaticEmbedding(vocab, embedding_dim=512)
+
+        encoder = BiLSTMEncoder(embed)
+        decoder = LSTMDecoder(embed, bind_input_output_embed=True)
+
+        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
+        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
+        src_seq_len = torch.LongTensor([3, 2])
+
+        words, hx = encoder(src_words_idx, src_seq_len)
+        encode_mask = seq_len_to_mask(src_seq_len)
+        hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
+        cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
+        past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell))
+        decoder_outputs = decoder(tgt_words_idx, past)
+
+        print(decoder_outputs)
+        print(encode_mask)
+
+        self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))
+
+    def test_decode(self):
+        pass  # todo
diff --git a/test/modules/decoder/test_seq2seq_generator.py b/test/modules/decoder/test_seq2seq_generator.py
new file mode 100644
index 00000000..23c20fa1
--- /dev/null
+++ b/test/modules/decoder/test_seq2seq_generator.py
@@ -0,0 +1,52 @@
+import unittest
+
+import torch
+
+from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
+from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder
+from fastNLP import Vocabulary
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP.core.utils import seq_len_to_mask
+from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator
+
+
+class TestSequenceGenerator(unittest.TestCase):
+    def test_case_for_transformer(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        vocab.add_word_lst("Another test !".split())
+        embed = StaticEmbedding(vocab, embedding_dim=512)
+        encoder = TransformerSeq2SeqEncoder(embed, num_layers=6)
+        decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True, num_layers=6)
+
+        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
+        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
+        src_seq_len = torch.LongTensor([3, 2])
+
+        encoder_outputs, mask = encoder(src_words_idx, src_seq_len)
+        past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask, num_decoder_layer=6)
+
+        generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2)
+        tokens_ids = generator.generate(past=past)
+
+        print(tokens_ids)
+
+    def test_case_for_lstm(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        vocab.add_word_lst("Another test !".split())
+        embed = StaticEmbedding(vocab, embedding_dim=512)
+        encoder = BiLSTMEncoder(embed)
+        decoder = LSTMDecoder(embed, bind_input_output_embed=True)
+        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
+        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
+        src_seq_len = torch.LongTensor([3, 2])
+
+        words, hx = encoder(src_words_idx, src_seq_len)
+        encode_mask = seq_len_to_mask(src_seq_len)
+        hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
+        cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
+        past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell))
+
+        generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2)
+        tokens_ids = generator.generate(past=past)
+
+        print(tokens_ids)
diff --git a/test/modules/encoder/test_seq2seq_encoder.py b/test/modules/encoder/test_seq2seq_encoder.py
new file mode 100644
index 00000000..a9491f84
--- /dev/null
+++ b/test/modules/encoder/test_seq2seq_encoder.py
@@ -0,0 +1,35 @@
+import unittest
+
+import torch
+
+from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
+from fastNLP import Vocabulary
+from fastNLP.embeddings import StaticEmbedding
+
+class TestTransformerSeq2SeqEncoder(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        embed = StaticEmbedding(vocab, embedding_dim=512)
+        encoder = TransformerSeq2SeqEncoder(embed)
+        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
+        seq_len = torch.LongTensor([3])
+        outputs, mask = encoder(words_idx, seq_len)
+
+        print(outputs)
+        print(mask)
+        self.assertEqual(outputs.size(), (1, 3, 512))
+
+
+class TestBiLSTMEncoder(unittest.TestCase):
+    def test_case(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        embed = StaticEmbedding(vocab, embedding_dim=300)
+        encoder = BiLSTMEncoder(embed, hidden_size=300)
+        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
+        seq_len = torch.LongTensor([3])
+
+        outputs, hx = encoder(words_idx, seq_len)
+
+        # print(outputs)
+        # print(hx)
+        self.assertEqual(outputs.size(), (1, 3, 300))
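Usage sketch. The snippet below is a minimal end-to-end example of how the pieces introduced by this patch fit together, mirroring `test/modules/decoder/test_seq2seq_generator.py`: the encoder produces outputs and a mask, a `TransformerPast` carries them into the decoder, and `SequenceGenerator` drives incremental decoding via `decoder.decode(tokens, past)`. The vocabulary, embedding size, token ids, and bos/eos ids are illustrative only.

```python
import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.modules import (TransformerSeq2SeqEncoder, TransformerSeq2SeqDecoder,
                             TransformerPast, SequenceGenerator)

# toy vocabulary and embeddings (sizes are illustrative)
vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = StaticEmbedding(vocab, embedding_dim=512)

encoder = TransformerSeq2SeqEncoder(embed, num_layers=6)
decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True, num_layers=6)

src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
src_seq_len = torch.LongTensor([3, 2])

# encode once, stash outputs and mask in a Past object for incremental decoding
encoder_outputs, mask = encoder(src_words_idx, src_seq_len)
past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask, num_decoder_layer=6)

# the generator repeatedly calls decoder.decode(tokens, past) until eos or max_length
generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2)
tokens_ids = generator.generate(past=past)
print(tokens_ids)
```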