| @@ -28,7 +28,9 @@ __all__ = [ | |||
| "BertForSentenceMatching", | |||
| "BertForMultipleChoice", | |||
| "BertForTokenClassification", | |||
| "BertForQuestionAnswering" | |||
| "BertForQuestionAnswering", | |||
| "TransformerSeq2SeqModel" | |||
| ] | |||
| from .base_model import BaseModel | |||
| @@ -39,7 +41,7 @@ from .cnn_text_classification import CNNText | |||
| from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF | |||
| from .snli import ESIM | |||
| from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel | |||
| from .seq2seq_model import TransformerSeq2SeqModel | |||
| import sys | |||
| from ..doc_utils import doc_process | |||
| doc_process(sys.modules[__name__]) | |||
| @@ -1,29 +1,24 @@ | |||
| from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder | |||
| from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast | |||
| from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator | |||
| import torch.nn as nn | |||
| import torch | |||
| from typing import Union, Tuple | |||
| import numpy as np | |||
| from fastNLP.modules import TransformerSeq2SeqDecoder, TransformerSeq2SeqEncoder, TransformerPast | |||
| class TransformerSeq2SeqModel(nn.Module): | |||
| class TransformerSeq2SeqModel(nn.Module): # todo: follow the structure of fairseq's FairseqModel | |||
| def __init__(self, src_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], | |||
| tgt_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], | |||
| num_layers: int = 6, | |||
| d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1, | |||
| num_layers: int = 6, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1, | |||
| output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None, | |||
| bind_input_output_embed=False, | |||
| sos_id=None, eos_id=None): | |||
| bind_input_output_embed=False): | |||
| super().__init__() | |||
| self.encoder = TransformerSeq2SeqEncoder(src_embed, num_layers, d_model, n_head, dim_ff, dropout) | |||
| self.decoder = TransformerSeq2SeqDecoder(tgt_embed, num_layers, d_model, n_head, dim_ff, dropout, output_embed, | |||
| bind_input_output_embed) | |||
| self.sos_id = sos_id | |||
| self.eos_id = eos_id | |||
| self.num_layers = num_layers | |||
| def forward(self, words, target, seq_len): # todo: does target include sos and eos here? check how the LSTM version handles it | |||
| def forward(self, words, target, seq_len): | |||
| encoder_output, encoder_mask = self.encoder(words, seq_len) | |||
| past = TransformerPast(encoder_output, encoder_mask, self.num_layers) | |||
| outputs = self.decoder(target, past, return_attention=False) | |||
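For reference, a minimal usage sketch of this model (it mirrors the unit tests added later in this diff; the `num_layers`/`dim_ff` values are arbitrary, and since the end of `forward` is not shown in this hunk, the return value being the decoder output is an assumption):

```python
import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel

vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = StaticEmbedding(vocab, embedding_dim=512)

model = TransformerSeq2SeqModel(src_embed=embed, tgt_embed=embed,
                                num_layers=2, d_model=512, n_head=8, dim_ff=1024,
                                bind_input_output_embed=True)

src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])        # batch_size x src_len
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])  # batch_size x tgt_len
src_seq_len = torch.LongTensor([3, 2])                          # true source lengths

outputs = model(src_words_idx, tgt_words_idx, src_seq_len)
print(outputs.shape)   # assumed: batch_size x tgt_len x vocab_size
```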
| @@ -49,7 +49,19 @@ __all__ = [ | |||
| "TimestepDropout", | |||
| 'summary' | |||
| 'summary', | |||
| "BiLSTMEncoder", | |||
| "TransformerSeq2SeqEncoder", | |||
| "SequenceGenerator", | |||
| "LSTMDecoder", | |||
| "LSTMPast", | |||
| "TransformerSeq2SeqDecoder", | |||
| "TransformerPast", | |||
| "Decoder", | |||
| "Past" | |||
| ] | |||
| import sys | |||
| @@ -7,11 +7,20 @@ __all__ = [ | |||
| "ConditionalRandomField", | |||
| "viterbi_decode", | |||
| "allowed_transitions", | |||
| "seq2seq_decoder", | |||
| "seq2seq_generator" | |||
| "SequenceGenerator", | |||
| "LSTMDecoder", | |||
| "LSTMPast", | |||
| "TransformerSeq2SeqDecoder", | |||
| "TransformerPast", | |||
| "Decoder", | |||
| "Past", | |||
| ] | |||
| from .crf import ConditionalRandomField | |||
| from .crf import allowed_transitions | |||
| from .mlp import MLP | |||
| from .utils import viterbi_decode | |||
| from .seq2seq_generator import SequenceGenerator | |||
| from .seq2seq_decoder import * | |||
| @@ -1,17 +1,46 @@ | |||
| # coding=utf-8 | |||
| __all__ = [ | |||
| "TransformerPast", | |||
| "LSTMPast", | |||
| "Past", | |||
| "LSTMDecoder", | |||
| "TransformerSeq2SeqDecoder", | |||
| "Decoder" | |||
| ] | |||
| import torch | |||
| from torch import nn | |||
| import abc | |||
| import torch.nn.functional as F | |||
| from fastNLP.embeddings import StaticEmbedding | |||
| from ...embeddings import StaticEmbedding | |||
| import numpy as np | |||
| from typing import Union, Tuple | |||
| from fastNLP.embeddings import get_embeddings | |||
| from fastNLP.modules import LSTM | |||
| from ...embeddings.utils import get_embeddings | |||
| from torch.nn import LayerNorm | |||
| import math | |||
| from reproduction.Summarization.Baseline.tools.PositionEmbedding import \ | |||
| get_sinusoid_encoding_table # todo: position embedding should be moved into core | |||
| # from reproduction.Summarization.Baseline.tools.PositionEmbedding import \ | |||
| # get_sinusoid_encoding_table # todo: position embedding should be moved into core | |||
| def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||
| ''' Sinusoid position encoding table ''' | |||
| def cal_angle(position, hid_idx): | |||
| return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) | |||
| def get_posi_angle_vec(position): | |||
| return [cal_angle(position, hid_j) for hid_j in range(d_hid)] | |||
| sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) | |||
| sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |||
| sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |||
| if padding_idx is not None: | |||
| # zero vector for padding dimension | |||
| sinusoid_table[padding_idx] = 0. | |||
| return torch.FloatTensor(sinusoid_table) | |||
| class Past: | |||
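The inlined helper above is the standard sinusoidal position encoding, PE(pos, 2i) = sin(pos / 10000^(2i/d_hid)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_hid)). A small sanity-check sketch, assuming this diff is applied so the helper is importable from `seq2seq_decoder`:

```python
import numpy as np
import torch
# assumes the module-level helper defined above once this diff is applied
from fastNLP.modules.decoder.seq2seq_decoder import get_sinusoid_encoding_table

table = get_sinusoid_encoding_table(n_position=10, d_hid=8, padding_idx=0)
assert table.shape == (10, 8)
assert torch.all(table[0] == 0)          # the padding position is zeroed out

# even dims hold sin, odd dims hold cos of the same angle
pos, dim = 3, 2
angle = pos / np.power(10000, 2 * (dim // 2) / 8)
assert torch.isclose(table[pos, dim], torch.tensor(np.sin(angle), dtype=torch.float32))
```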
| @@ -82,7 +111,7 @@ class Decoder(nn.Module): | |||
| """ | |||
| raise NotImplemented | |||
| def decode_one(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]: | |||
| def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]: | |||
| """ | |||
| Use this function when the model is decoding; it only returns a batch_size x vocab_size result. One special case must be handled: | |||
| when tokens is longer than 1, i.e. the beginning of the decoded sentence is given; in that case, check whether Past has computed the decoding state correctly | |||
| @@ -100,6 +129,7 @@ class DecoderMultiheadAttention(nn.Module): | |||
| """ | |||
| def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None): | |||
| super(DecoderMultiheadAttention, self).__init__() | |||
| self.d_model = d_model | |||
| self.n_head = n_head | |||
| self.dropout = dropout | |||
| @@ -157,11 +187,11 @@ class DecoderMultiheadAttention(nn.Module): | |||
| past.encoder_key[self.layer_idx] = k | |||
| past.encoder_value[self.layer_idx] = v | |||
| if inference and not is_encoder_attn: | |||
| past.decoder_prev_key[self.layer_idx] = prev_k | |||
| past.decoder_prev_value[self.layer_idx] = prev_v | |||
| past.decoder_prev_key[self.layer_idx] = prev_k if prev_k is not None else k | |||
| past.decoder_prev_value[self.layer_idx] = prev_v if prev_v is not None else v | |||
| batch_size, q_len, d_model = query.size() | |||
| k_len, v_len = key.size(1), value.size(1) | |||
| k_len, v_len = k.size(1), v.size(1) | |||
| q = q.contiguous().view(batch_size, q_len, self.n_head, self.head_dim) | |||
| k = k.contiguous().view(batch_size, k_len, self.n_head, self.head_dim) | |||
| v = v.contiguous().view(batch_size, v_len, self.n_head, self.head_dim) | |||
| @@ -172,8 +202,8 @@ class DecoderMultiheadAttention(nn.Module): | |||
| if len(mask.size()) == 2: # this is the encoder mask: batch, src_len/k_len | |||
| mask = mask[:, None, :, None] | |||
| else: # (1, seq_len, seq_len) | |||
| mask = mask[...:None] | |||
| _mask = mask | |||
| mask = mask[..., None] | |||
| _mask = ~mask.bool() | |||
| attn_weights = attn_weights.masked_fill(_mask, float('-inf')) | |||
| @@ -181,21 +211,22 @@ class DecoderMultiheadAttention(nn.Module): | |||
| attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) | |||
| output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v) # batch,q_len,n_head,head_dim | |||
| output = output.view(batch_size, q_len, -1) | |||
| output = output.reshape(batch_size, q_len, -1) | |||
| output = self.out_proj(output) # batch,q_len,dim | |||
| return output, attn_weights | |||
| def reset_parameters(self): | |||
| nn.init.xavier_uniform_(self.q_proj) | |||
| nn.init.xavier_uniform_(self.k_proj) | |||
| nn.init.xavier_uniform_(self.v_proj) | |||
| nn.init.xavier_uniform_(self.out_proj) | |||
| nn.init.xavier_uniform_(self.q_proj.weight) | |||
| nn.init.xavier_uniform_(self.k_proj.weight) | |||
| nn.init.xavier_uniform_(self.v_proj.weight) | |||
| nn.init.xavier_uniform_(self.out_proj.weight) | |||
| class TransformerSeq2SeqDecoderLayer(nn.Module): | |||
| def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1, | |||
| layer_idx: int = None): | |||
| super(TransformerSeq2SeqDecoderLayer, self).__init__() | |||
| self.d_model = d_model | |||
| self.n_head = n_head | |||
| self.dim_ff = dim_ff | |||
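For clarity, the einsum above contracts attention weights of shape (batch, q_len, k_len, n_head) with values of shape (batch, k_len, n_head, head_dim). A standalone shape sketch with made-up sizes; it also illustrates why the diff switches from `.view` to `.reshape`, since the einsum output is not guaranteed to be contiguous:

```python
import torch

batch, q_len, k_len, n_head, head_dim = 2, 5, 7, 8, 64
attn_weights = torch.randn(batch, q_len, k_len, n_head).softmax(dim=2)  # normalize over keys
v = torch.randn(batch, k_len, n_head, head_dim)

output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v)
assert output.shape == (batch, q_len, n_head, head_dim)

# .reshape works whether or not the result is contiguous; .view would not
output = output.reshape(batch, q_len, -1)     # batch, q_len, n_head * head_dim
assert output.shape == (batch, q_len, n_head * head_dim)
```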
| @@ -313,10 +344,10 @@ class TransformerSeq2SeqDecoder(Decoder): | |||
| if isinstance(self.token_embed, StaticEmbedding): | |||
| for i in self.token_embed.words_to_words: | |||
| assert i == self.token_embed.words_to_words[i], "The index does not match." | |||
| self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1)) | |||
| self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1), requires_grad=True) | |||
| else: | |||
| if isinstance(output_embed, nn.Embedding): | |||
| self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1)) | |||
| self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1), requires_grad=True) | |||
| else: | |||
| self.output_embed = output_embed.transpose(0, 1) | |||
| self.output_hidden_size = self.output_embed.size(0) | |||
| @@ -326,7 +357,8 @@ class TransformerSeq2SeqDecoder(Decoder): | |||
| def forward(self, tokens, past, return_attention=False, inference=False): | |||
| """ | |||
| :param tokens: torch.LongTensor, tokens: batch_size x decode_len | |||
| :param tokens: torch.LongTensor, tokens: batch_size , decode_len | |||
| :param self_attn_mask: not needed at inference time; at training time it is also unnecessary, because cross-entropy already masks out the padding positions | |||
| :param past: TransformerPast: contains the encoder output and mask; at inference time it also caches the previous step's key and value to reduce matrix computation | |||
| :param return_attention: | |||
| :param inference: whether this call is in the inference stage | |||
| @@ -335,13 +367,16 @@ class TransformerSeq2SeqDecoder(Decoder): | |||
| assert past is not None | |||
| batch_size, decode_len = tokens.size() | |||
| device = tokens.device | |||
| pos_idx = torch.arange(1, decode_len + 1).unsqueeze(0).long() | |||
| if not inference: | |||
| self_attn_mask = self._get_triangle_mask(decode_len) | |||
| self_attn_mask = self_attn_mask.to(device)[None, :, :] # 1,seq,seq | |||
| else: | |||
| self_attn_mask = None | |||
| tokens = self.token_embed(tokens) * self.embed_scale # bs,decode_len,embed_dim | |||
| pos = self.pos_embed(tokens) # bs,decode_len,embed_dim | |||
| pos = self.pos_embed(pos_idx) # 1,decode_len,embed_dim | |||
| tokens = pos + tokens | |||
| if inference: | |||
| tokens = tokens[:, -1:, :] | |||
| @@ -358,7 +393,7 @@ class TransformerSeq2SeqDecoder(Decoder): | |||
| return output | |||
| @torch.no_grad() | |||
| def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]: | |||
| def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]: | |||
| """ | |||
| # todo: do we even need to return past? It is modified in place, so an explicit return may be unnecessary | |||
| :param tokens: torch.LongTensor (batch_size,1) | |||
| @@ -370,7 +405,7 @@ class TransformerSeq2SeqDecoder(Decoder): | |||
| def reorder_past(self, indices: torch.LongTensor, past: TransformerPast) -> TransformerPast: | |||
| past.reorder_past(indices) | |||
| return past # todo: this could actually be dropped | |||
| return past | |||
| def _get_triangle_mask(self, max_seq_len): | |||
| tensor = torch.ones(max_seq_len, max_seq_len) | |||
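The body of `_get_triangle_mask` is truncated in this hunk; a minimal sketch of the usual lower-triangular (causal) mask it is expected to build from the all-ones tensor above, under the assumption that positions marked 1 are kept and 0 positions are filled with -inf (consistent with `_mask = ~mask.bool()` earlier):

```python
import torch

def get_triangle_mask(max_seq_len):
    # hypothetical reconstruction: position i may attend to positions <= i
    tensor = torch.ones(max_seq_len, max_seq_len)
    return torch.tril(tensor).byte()

print(get_triangle_mask(4))
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.uint8)
```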
| @@ -409,6 +444,27 @@ class LSTMPast(Past): | |||
| return tensor[0].size(0) | |||
| return None | |||
| def _reorder_past(self, state, indices, dim=0): | |||
| if type(state) == torch.Tensor: | |||
| state = state.index_select(index=indices, dim=dim) | |||
| elif type(state) == tuple: | |||
| tmp_list = [] | |||
| for i in range(len(state)): | |||
| assert state[i] is not None | |||
| tmp_list.append(state[i].index_select(index=indices, dim=dim)) | |||
| state = tuple(tmp_list) | |||
| else: | |||
| raise ValueError('State does not support other format') | |||
| return state | |||
| def reorder_past(self, indices: torch.LongTensor): | |||
| self.encode_outputs = self._reorder_past(self.encode_outputs, indices) | |||
| self.encode_mask = self._reorder_past(self.encode_mask, indices) | |||
| self.hx = self._reorder_past(self.hx, indices, 1) | |||
| if self.attn_states is not None: | |||
| self.attn_states = self._reorder_past(self.attn_states, indices) | |||
| @property | |||
| def hx(self): | |||
| return self._hx | |||
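To make the `dim` argument concrete: encoder outputs and masks are batch-first (dim=0), while the LSTM hidden/cell states in `hx` carry the batch on dim=1, which is why they are reordered with dim=1. A toy illustration with made-up sizes:

```python
import torch

n_layers, batch, hidden = 2, 3, 4
hx = (torch.randn(n_layers, batch, hidden), torch.randn(n_layers, batch, hidden))
indices = torch.LongTensor([2, 0, 1])   # new sample order chosen by beam search

reordered = tuple(t.index_select(index=indices, dim=1) for t in hx)
assert torch.equal(reordered[0][:, 0], hx[0][:, 2])   # sample 2 moved to slot 0
```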
| @@ -493,7 +549,7 @@ class AttentionLayer(nn.Module): | |||
| class LSTMDecoder(Decoder): | |||
| def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers, input_size, | |||
| def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers=3, input_size=400, | |||
| hidden_size=None, dropout=0, | |||
| output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None, | |||
| bind_input_output_embed=False, | |||
| @@ -612,6 +668,7 @@ class LSTMDecoder(Decoder): | |||
| for i in range(tokens.size(1)): | |||
| input = torch.cat([tokens[:, i:i + 1], input_feed.unsqueeze(1)], dim=2) # batch_size x 1 x h' | |||
| # bsz x 1 x hidden_size, (n_layer x bsz x hidden_size, n_layer x bsz x hidden_size) | |||
| _, (hidden, cell) = self.lstm(input, hx=past.hx) | |||
| past.hx = (hidden, cell) | |||
| if self.attention_layer is not None: | |||
| @@ -633,7 +690,7 @@ class LSTMDecoder(Decoder): | |||
| return feats | |||
| @torch.no_grad() | |||
| def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]: | |||
| def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]: | |||
| """ | |||
| Given the output at the previous position, decide the output at the current position. | |||
| :param torch.LongTensor tokens: batch_size x seq_len | |||
| @@ -653,13 +710,5 @@ class LSTMDecoder(Decoder): | |||
| :param LSTMPast past: the stored past state | |||
| :return: | |||
| """ | |||
| encode_outputs = past.encode_outputs.index_select(index=indices, dim=0) | |||
| encoder_mask = past.encode_mask.index_select(index=indices, dim=0) | |||
| hx = (past.hx[0].index_select(index=indices, dim=1), | |||
| past.hx[1].index_select(index=indices, dim=1)) | |||
| if past.attn_states is not None: | |||
| past.attn_states = past.attn_states.index_select(index=indices, dim=0) | |||
| past.encode_mask = encoder_mask | |||
| past.encode_outputs = encode_outputs | |||
| past.hx = hx | |||
| past.reorder_past(indices) | |||
| return past | |||
| @@ -1,7 +1,10 @@ | |||
| __all__ = [ | |||
| "SequenceGenerator" | |||
| ] | |||
| import torch | |||
| from .seq2seq_decoder import Decoder | |||
| import torch.nn.functional as F | |||
| from fastNLP.core.utils import _get_model_device | |||
| from ...core.utils import _get_model_device | |||
| from functools import partial | |||
| @@ -130,8 +133,8 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||
| else: | |||
| _eos_token_id = eos_token_id | |||
| for i in range(tokens.size(1) - 1): | |||
| scores, past = decoder.decode_one(tokens[:, :i + 1], past) # batch_size x vocab_size, Past | |||
| for i in range(tokens.size(1)): | |||
| scores, past = decoder.decode(tokens[:, :i + 1], past) # batch_size x vocab_size, Past | |||
| token_ids = tokens.clone() | |||
| cur_len = token_ids.size(1) | |||
| @@ -139,7 +142,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||
| # tokens = tokens[:, -1:] | |||
| while cur_len < max_length: | |||
| scores, past = decoder.decode_one(tokens, past) # batch_size x vocab_size, Past | |||
| scores, past = decoder.decode(tokens, past) # batch_size x vocab_size, Past | |||
| if repetition_penalty != 1.0: | |||
| token_scores = scores.gather(dim=1, index=token_ids) | |||
| @@ -153,7 +156,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||
| eos_mask = scores.new_ones(scores.size(1)) | |||
| eos_mask[eos_token_id] = 0 | |||
| eos_mask = eos_mask.unsqueeze(0).eq(1) | |||
| scores = scores.masked_scatter(eos_mask, token_scores) | |||
| scores = scores.masked_scatter(eos_mask, token_scores) # i.e. every token except eos has its score scaled up or down | |||
| if do_sample: | |||
| if temperature > 0 and temperature != 1: | |||
| @@ -167,7 +170,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||
| else: | |||
| next_tokens = torch.argmax(scores, dim=-1) # batch_size | |||
| next_tokens = next_tokens.masked_fill(dones, 0) | |||
| next_tokens = next_tokens.masked_fill(dones, 0) # pad the samples whose search has already finished | |||
| tokens = next_tokens.unsqueeze(1) | |||
| token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len | |||
| @@ -181,7 +184,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||
| if eos_token_id is not None: | |||
| if cur_len == max_length: | |||
| token_ids[:, -1].masked_fill_(dones, eos_token_id) | |||
| token_ids[:, -1].masked_fill_(~dones, eos_token_id) # if EOS has not appeared by the maximum length, force the last token to be eos | |||
| return token_ids | |||
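For orientation, a condensed sketch of the greedy branch of `_no_beam_search_generate` (hypothetical helper name; repetition penalty, sampling and the prefix-decoding loop are omitted):

```python
import torch

def greedy_generate(decoder, tokens, past, max_length, eos_token_id):
    # tokens: batch_size x prefix_len; decoder.decode returns (batch_size x vocab_size, Past)
    token_ids = tokens.clone()
    dones = token_ids.new_zeros(token_ids.size(0)).eq(1)      # all False
    while token_ids.size(1) < max_length:
        scores, past = decoder.decode(token_ids[:, -1:], past)
        next_tokens = torch.argmax(scores, dim=-1)            # batch_size
        next_tokens = next_tokens.masked_fill(dones, 0)       # pad finished samples
        token_ids = torch.cat([token_ids, next_tokens.unsqueeze(1)], dim=-1)
        dones = dones | next_tokens.eq(eos_token_id)
        if dones.all():
            break
    return token_ids
```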
| @@ -206,9 +209,9 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
| assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." | |||
| for i in range(tokens.size(1) - 1): # if the given prefix is longer than one token, decode it first | |||
| scores, past = decoder.decode_one(tokens[:, :i + 1], | |||
| past) # (batch_size, vocab_size), Past | |||
| scores, past = decoder.decode_one(tokens, past) # the whole sentence (its full length) has to be passed in here | |||
| scores, past = decoder.decode(tokens[:, :i + 1], | |||
| past) # (batch_size, vocab_size), Past | |||
| scores, past = decoder.decode(tokens, past) # the whole sentence (its full length) has to be passed in here | |||
| vocab_size = scores.size(1) | |||
| assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size." | |||
| @@ -224,7 +227,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
| indices = torch.arange(batch_size, dtype=torch.long).to(device) | |||
| indices = indices.repeat_interleave(num_beams) | |||
| past = decoder.reorder_past(indices, past) | |||
| decoder.reorder_past(indices, past) | |||
| tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length | |||
| # record the generated tokens (batch_size', cur_len) | |||
| @@ -240,11 +243,11 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
| hypos = [ | |||
| BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size) | |||
| ] | |||
| # 0,num_beams, 2*num_beams | |||
| # 0,num_beams, 2*num_beams, ... | |||
| batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) | |||
| while cur_len < max_length: | |||
| scores, past = decoder.decode_one(tokens, past) # batch_size * num_beams x vocab_size, Past | |||
| scores, past = decoder.decode(tokens, past) # batch_size * num_beams x vocab_size, Past | |||
| if repetition_penalty != 1.0: | |||
| token_scores = scores.gather(dim=1, index=token_ids) | |||
| @@ -298,8 +301,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
| beam_scores = _next_scores.view(-1) | |||
| # update the past state and regroup token_ids | |||
| reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) | |||
| past = decoder.reorder_past(reorder_inds, past) | |||
| reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten to one dimension | |||
| decoder.reorder_past(reorder_inds, past) | |||
| flag = True | |||
| if cur_len + 1 == max_length: | |||
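The reorder indices combine a per-batch offset with the beam chosen inside each block of `num_beams` rows; a small worked example with made-up beam choices:

```python
import torch

batch_size, num_beams = 2, 3
# each batch owns a contiguous block of num_beams rows: offsets 0 and 3 here
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1)
_from_which_beam = torch.LongTensor([[1, 0, 2], [2, 2, 0]])   # batch_size x num_beams

reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
print(reorder_inds)   # tensor([1, 0, 2, 5, 5, 3])
```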
| @@ -445,7 +448,7 @@ if __name__ == '__main__': | |||
| super().__init__() | |||
| self.num_words = num_words | |||
| def decode_one(self, tokens, past): | |||
| def decode(self, tokens, past): | |||
| batch_size = tokens.size(0) | |||
| return torch.randn(batch_size, self.num_words), past | |||
| @@ -30,6 +30,9 @@ __all__ = [ | |||
| "MultiHeadAttention", | |||
| "BiAttention", | |||
| "SelfAttention", | |||
| "BiLSTMEncoder", | |||
| "TransformerSeq2SeqEncoder" | |||
| ] | |||
| from .attention import MultiHeadAttention, BiAttention, SelfAttention | |||
| @@ -41,3 +44,5 @@ from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPoo | |||
| from .star_transformer import StarTransformer | |||
| from .transformer import TransformerEncoder | |||
| from .variational_rnn import VarRNN, VarLSTM, VarGRU | |||
| from .seq2seq_encoder import BiLSTMEncoder, TransformerSeq2SeqEncoder | |||
| @@ -1,7 +1,12 @@ | |||
| __all__ = [ | |||
| "TransformerSeq2SeqEncoder", | |||
| "BiLSTMEncoder" | |||
| ] | |||
| from torch import nn | |||
| import torch | |||
| from fastNLP.modules import LSTM | |||
| from fastNLP import seq_len_to_mask | |||
| from ...modules.encoder import LSTM | |||
| from ...core.utils import seq_len_to_mask | |||
| from torch.nn import TransformerEncoder | |||
| from typing import Union, Tuple | |||
| import numpy as np | |||
| @@ -0,0 +1 @@ | |||
| @@ -0,0 +1,65 @@ | |||
| import unittest | |||
| import torch | |||
| from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder | |||
| from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.embeddings import StaticEmbedding | |||
| from fastNLP.core.utils import seq_len_to_mask | |||
| class TestTransformerSeq2SeqDecoder(unittest.TestCase): | |||
| def test_case(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, embedding_dim=512) | |||
| encoder = TransformerSeq2SeqEncoder(embed) | |||
| decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True) | |||
| src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||
| tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
| src_seq_len = torch.LongTensor([3, 2]) | |||
| encoder_outputs, mask = encoder(src_words_idx, src_seq_len) | |||
| past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask) | |||
| decoder_outputs = decoder(tgt_words_idx, past) | |||
| print(decoder_outputs) | |||
| print(mask) | |||
| self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab))) | |||
| def test_decode(self): | |||
| pass # todo | |||
| class TestLSTMDecoder(unittest.TestCase): | |||
| def test_case(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, embedding_dim=512) | |||
| encoder = BiLSTMEncoder(embed) | |||
| decoder = LSTMDecoder(embed, bind_input_output_embed=True) | |||
| src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||
| tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
| src_seq_len = torch.LongTensor([3, 2]) | |||
| words, hx = encoder(src_words_idx, src_seq_len) | |||
| encode_mask = seq_len_to_mask(src_seq_len) | |||
| hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1) | |||
| cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1) | |||
| past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell)) | |||
| decoder_outputs = decoder(tgt_words_idx, past) | |||
| print(decoder_outputs) | |||
| print(encode_mask) | |||
| self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab))) | |||
| def test_decode(self): | |||
| pass # todo | |||
| @@ -0,0 +1,52 @@ | |||
| import unittest | |||
| import torch | |||
| from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder | |||
| from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.embeddings import StaticEmbedding | |||
| from fastNLP.core.utils import seq_len_to_mask | |||
| from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator | |||
| class TestSequenceGenerator(unittest.TestCase): | |||
| def test_case_for_transformer(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, embedding_dim=512) | |||
| encoder = TransformerSeq2SeqEncoder(embed, num_layers=6) | |||
| decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True, num_layers=6) | |||
| src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||
| tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
| src_seq_len = torch.LongTensor([3, 2]) | |||
| encoder_outputs, mask = encoder(src_words_idx, src_seq_len) | |||
| past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask, num_decoder_layer=6) | |||
| generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2) | |||
| tokens_ids = generator.generate(past=past) | |||
| print(tokens_ids) | |||
| def test_case_for_lstm(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, embedding_dim=512) | |||
| encoder = BiLSTMEncoder(embed) | |||
| decoder = LSTMDecoder(embed, bind_input_output_embed=True) | |||
| src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||
| tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
| src_seq_len = torch.LongTensor([3, 2]) | |||
| words, hx = encoder(src_words_idx, src_seq_len) | |||
| encode_mask = seq_len_to_mask(src_seq_len) | |||
| hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1) | |||
| cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1) | |||
| past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell)) | |||
| generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2) | |||
| tokens_ids = generator.generate(past=past) | |||
| print(tokens_ids) | |||
| @@ -0,0 +1,35 @@ | |||
| import unittest | |||
| import torch | |||
| from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.embeddings import StaticEmbedding | |||
| class TestTransformerSeq2SeqEncoder(unittest.TestCase): | |||
| def test_case(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| embed = StaticEmbedding(vocab, embedding_dim=512) | |||
| encoder = TransformerSeq2SeqEncoder(embed) | |||
| words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||
| seq_len = torch.LongTensor([3]) | |||
| outputs, mask = encoder(words_idx, seq_len) | |||
| print(outputs) | |||
| print(mask) | |||
| self.assertEqual(outputs.size(), (1, 3, 512)) | |||
| class TestBiLSTMEncoder(unittest.TestCase): | |||
| def test_case(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| embed = StaticEmbedding(vocab, embedding_dim=300) | |||
| encoder = BiLSTMEncoder(embed, hidden_size=300) | |||
| words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||
| seq_len = torch.LongTensor([3]) | |||
| outputs, hx = encoder(words_idx, seq_len) | |||
| # print(outputs) | |||
| # print(hx) | |||
| self.assertEqual(outputs.size(), (1, 3, 300)) | |||