
Save a copy of the old version

tags/v0.6.0
linzehui, 5 years ago
commit 15360e9724
12 changed files with 300 additions and 67 deletions
  1. fastNLP/models/__init__.py (+4, -2)
  2. fastNLP/models/seq2seq_model.py (+6, -11)
  3. fastNLP/modules/__init__.py (+13, -1)
  4. fastNLP/modules/decoder/__init__.py (+11, -2)
  5. fastNLP/modules/decoder/seq2seq_decoder.py (+82, -33)
  6. fastNLP/modules/decoder/seq2seq_generator.py (+19, -16)
  7. fastNLP/modules/encoder/__init__.py (+5, -0)
  8. fastNLP/modules/encoder/seq2seq_encoder.py (+7, -2)
  9. test/models/test_seq2seq_model.py (+1, -0)
  10. test/modules/decoder/test_seq2seq_decoder.py (+65, -0)
  11. test/modules/decoder/test_seq2seq_generator.py (+52, -0)
  12. test/modules/encoder/test_seq2seq_encoder.py (+35, -0)

fastNLP/models/__init__.py (+4, -2)

@@ -28,7 +28,9 @@ __all__ = [
"BertForSentenceMatching",
"BertForMultipleChoice",
"BertForTokenClassification",
"BertForQuestionAnswering"
"BertForQuestionAnswering",

"TransformerSeq2SeqModel"
]

from .base_model import BaseModel
@@ -39,7 +41,7 @@ from .cnn_text_classification import CNNText
from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF
from .snli import ESIM
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
from .seq2seq_model import TransformerSeq2SeqModel
import sys
from ..doc_utils import doc_process
doc_process(sys.modules[__name__])

fastNLP/models/seq2seq_models.py → fastNLP/models/seq2seq_model.py (+6, -11)

@@ -1,29 +1,24 @@
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast
from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator
import torch.nn as nn
import torch
from typing import Union, Tuple
import numpy as np
from fastNLP.modules import TransformerSeq2SeqDecoder, TransformerSeq2SeqEncoder, TransformerPast


class TransformerSeq2SeqModel(nn.Module):
class TransformerSeq2SeqModel(nn.Module): # todo: refer to how fairseq's FairseqModel is structured
def __init__(self, src_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
tgt_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
num_layers: int = 6,
d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
num_layers: int = 6, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
bind_input_output_embed=False,
sos_id=None, eos_id=None):
bind_input_output_embed=False):
super().__init__()
self.encoder = TransformerSeq2SeqEncoder(src_embed, num_layers, d_model, n_head, dim_ff, dropout)
self.decoder = TransformerSeq2SeqDecoder(tgt_embed, num_layers, d_model, n_head, dim_ff, dropout, output_embed,
bind_input_output_embed)
self.sos_id = sos_id
self.eos_id = eos_id

self.num_layers = num_layers

def forward(self, words, target, seq_len): # todo: does target here include sos and eos? Check how the LSTM version handles it
def forward(self, words, target, seq_len):
encoder_output, encoder_mask = self.encoder(words, seq_len)
past = TransformerPast(encoder_output, encoder_mask, self.num_layers)
outputs = self.decoder(target, past, return_attention=False)

fastNLP/modules/__init__.py (+13, -1)

@@ -49,7 +49,19 @@ __all__ = [

"TimestepDropout",

'summary'
'summary',

"BiLSTMEncoder",
"TransformerSeq2SeqEncoder",

"SequenceGenerator",
"LSTMDecoder",
"LSTMPast",
"TransformerSeq2SeqDecoder",
"TransformerPast",
"Decoder",
"Past"

]

import sys


fastNLP/modules/decoder/__init__.py (+11, -2)

@@ -7,11 +7,20 @@ __all__ = [
"ConditionalRandomField",
"viterbi_decode",
"allowed_transitions",
"seq2seq_decoder",
"seq2seq_generator"

"SequenceGenerator",
"LSTMDecoder",
"LSTMPast",
"TransformerSeq2SeqDecoder",
"TransformerPast",
"Decoder",
"Past",

]

from .crf import ConditionalRandomField
from .crf import allowed_transitions
from .mlp import MLP
from .utils import viterbi_decode
from .seq2seq_generator import SequenceGenerator
from .seq2seq_decoder import *

fastNLP/modules/decoder/seq2seq_decoder.py (+82, -33)

@@ -1,17 +1,46 @@
# coding=utf-8
__all__ = [
"TransformerPast",
"LSTMPast",
"Past",
"LSTMDecoder",
"TransformerSeq2SeqDecoder",
"Decoder"
]
import torch
from torch import nn
import abc
import torch.nn.functional as F
from fastNLP.embeddings import StaticEmbedding
from ...embeddings import StaticEmbedding
import numpy as np
from typing import Union, Tuple
from fastNLP.embeddings import get_embeddings
from fastNLP.modules import LSTM
from ...embeddings.utils import get_embeddings
from torch.nn import LayerNorm
import math
from reproduction.Summarization.Baseline.tools.PositionEmbedding import \
get_sinusoid_encoding_table # todo: position embedding should be moved into core


# from reproduction.Summarization.Baseline.tools.PositionEmbedding import \
# get_sinusoid_encoding_table # todo: position embedding should be moved into core

def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
''' Sinusoid position encoding table '''

def cal_angle(position, hid_idx):
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

def get_posi_angle_vec(position):
return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])

sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1

if padding_idx is not None:
# zero vector for padding dimension
sinusoid_table[padding_idx] = 0.

return torch.FloatTensor(sinusoid_table)


class Past:
@@ -82,7 +111,7 @@ class Decoder(nn.Module):
"""
raise NotImplemented

def decode_one(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]:
def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]:
"""
Use this function when the model is decoding. It only returns a batch_size x vocab_size result. A special case needs to be considered:
when the length of tokens is not 1, i.e. the beginning of the decoded sentence is given, we need to check whether Past has correctly computed the decoding state.
@@ -100,6 +129,7 @@ class DecoderMultiheadAttention(nn.Module):
"""

def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None):
super(DecoderMultiheadAttention, self).__init__()
self.d_model = d_model
self.n_head = n_head
self.dropout = dropout
@@ -157,11 +187,11 @@ class DecoderMultiheadAttention(nn.Module):
past.encoder_key[self.layer_idx] = k
past.encoder_value[self.layer_idx] = v
if inference and not is_encoder_attn:
past.decoder_prev_key[self.layer_idx] = prev_k
past.decoder_prev_value[self.layer_idx] = prev_v
past.decoder_prev_key[self.layer_idx] = prev_k if prev_k is not None else k
past.decoder_prev_value[self.layer_idx] = prev_v if prev_v is not None else v

batch_size, q_len, d_model = query.size()
k_len, v_len = key.size(1), value.size(1)
k_len, v_len = k.size(1), v.size(1)
q = q.contiguous().view(batch_size, q_len, self.n_head, self.head_dim)
k = k.contiguous().view(batch_size, k_len, self.n_head, self.head_dim)
v = v.contiguous().view(batch_size, v_len, self.n_head, self.head_dim)
@@ -172,8 +202,8 @@ class DecoderMultiheadAttention(nn.Module):
if len(mask.size()) == 2: # encoder mask: batch, src_len/k_len
mask = mask[:, None, :, None]
else: # (1, seq_len, seq_len)
mask = mask[...:None]
_mask = mask
mask = mask[..., None]
_mask = ~mask.bool()

attn_weights = attn_weights.masked_fill(_mask, float('-inf'))

@@ -181,21 +211,22 @@ class DecoderMultiheadAttention(nn.Module):
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v) # batch,q_len,n_head,head_dim
output = output.view(batch_size, q_len, -1)
output = output.reshape(batch_size, q_len, -1)
output = self.out_proj(output) # batch,q_len,dim

return output, attn_weights

def reset_parameters(self):
nn.init.xavier_uniform_(self.q_proj)
nn.init.xavier_uniform_(self.k_proj)
nn.init.xavier_uniform_(self.v_proj)
nn.init.xavier_uniform_(self.out_proj)
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)


class TransformerSeq2SeqDecoderLayer(nn.Module):
def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
layer_idx: int = None):
super(TransformerSeq2SeqDecoderLayer, self).__init__()
self.d_model = d_model
self.n_head = n_head
self.dim_ff = dim_ff
@@ -313,10 +344,10 @@ class TransformerSeq2SeqDecoder(Decoder):
if isinstance(self.token_embed, StaticEmbedding):
for i in self.token_embed.words_to_words:
assert i == self.token_embed.words_to_words[i], "The index does not match."
self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1))
self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1), requires_grad=True)
else:
if isinstance(output_embed, nn.Embedding):
self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1))
self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1), requires_grad=True)
else:
self.output_embed = output_embed.transpose(0, 1)
self.output_hidden_size = self.output_embed.size(0)
@@ -326,7 +357,8 @@ class TransformerSeq2SeqDecoder(Decoder):
def forward(self, tokens, past, return_attention=False, inference=False):
"""

:param tokens: torch.LongTensor, tokens: batch_size x decode_len
:param tokens: torch.LongTensor, tokens: batch_size , decode_len
:param self_attn_mask: not needed at inference time; it is also unnecessary during training, because cross-entropy automatically masks out the padded positions
:param past: TransformerPast: contains the encoder output and mask; during inference it caches the key and value of the previous step to reduce matrix computation
:param return_attention:
:param inference: whether this is the inference stage
@@ -335,13 +367,16 @@ class TransformerSeq2SeqDecoder(Decoder):
assert past is not None
batch_size, decode_len = tokens.size()
device = tokens.device
pos_idx = torch.arange(1, decode_len + 1).unsqueeze(0).long()

if not inference:
self_attn_mask = self._get_triangle_mask(decode_len)
self_attn_mask = self_attn_mask.to(device)[None, :, :] # 1,seq,seq
else:
self_attn_mask = None

tokens = self.token_embed(tokens) * self.embed_scale # bs,decode_len,embed_dim
pos = self.pos_embed(tokens) # bs,decode_len,embed_dim
pos = self.pos_embed(pos_idx) # 1,decode_len,embed_dim
tokens = pos + tokens
if inference:
tokens = tokens[:, -1:, :]
@@ -358,7 +393,7 @@ class TransformerSeq2SeqDecoder(Decoder):
return output

@torch.no_grad()
def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]:
def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]:
"""
# todo: is returning past unnecessary? Since past has already been modified in place, an explicit return may not be needed
:param tokens: torch.LongTensor (batch_size,1)
@@ -370,7 +405,7 @@ class TransformerSeq2SeqDecoder(Decoder):

def reorder_past(self, indices: torch.LongTensor, past: TransformerPast) -> TransformerPast:
past.reorder_past(indices)
return past # todo: this could actually be dropped
return past

def _get_triangle_mask(self, max_seq_len):
tensor = torch.ones(max_seq_len, max_seq_len)
@@ -409,6 +444,27 @@ class LSTMPast(Past):
return tensor[0].size(0)
return None

def _reorder_past(self, state, indices, dim=0):
if type(state) == torch.Tensor:
state = state.index_select(index=indices, dim=dim)
elif type(state) == tuple:
tmp_list = []
for i in range(len(state)):
assert state[i] is not None
tmp_list.append(state[i].index_select(index=indices, dim=dim))
state = tuple(tmp_list)
else:
raise ValueError('State does not support other format')

return state

def reorder_past(self, indices: torch.LongTensor):
self.encode_outputs = self._reorder_past(self.encode_outputs, indices)
self.encode_mask = self._reorder_past(self.encode_mask, indices)
self.hx = self._reorder_past(self.hx, indices, 1)
if self.attn_states is not None:
self.attn_states = self._reorder_past(self.attn_states, indices)

@property
def hx(self):
return self._hx
@@ -493,7 +549,7 @@ class AttentionLayer(nn.Module):


class LSTMDecoder(Decoder):
def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers, input_size,
def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers=3, input_size=400,
hidden_size=None, dropout=0,
output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
bind_input_output_embed=False,
@@ -612,6 +668,7 @@ class LSTMDecoder(Decoder):
for i in range(tokens.size(1)):
input = torch.cat([tokens[:, i:i + 1], input_feed.unsqueeze(1)], dim=2) # batch_size x 1 x h'
# bsz x 1 x hidden_size, (n_layer x bsz x hidden_size, n_layer x bsz x hidden_size)

_, (hidden, cell) = self.lstm(input, hx=past.hx)
past.hx = (hidden, cell)
if self.attention_layer is not None:
@@ -633,7 +690,7 @@ class LSTMDecoder(Decoder):
return feats

@torch.no_grad()
def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]:
def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]:
"""
Given the output at the previous position, decide the output at the current position.
:param torch.LongTensor tokens: batch_size x seq_len
@@ -653,13 +710,5 @@ class LSTMDecoder(Decoder):
:param LSTMPast past: the saved past state
:return:
"""
encode_outputs = past.encode_outputs.index_select(index=indices, dim=0)
encoder_mask = past.encode_mask.index_select(index=indices, dim=0)
hx = (past.hx[0].index_select(index=indices, dim=1),
past.hx[1].index_select(index=indices, dim=1))
if past.attn_states is not None:
past.attn_states = past.attn_states.index_select(index=indices, dim=0)
past.encode_mask = encoder_mask
past.encode_outputs = encode_outputs
past.hx = hx
past.reorder_past(indices)
return past

fastNLP/modules/decoder/seq2seq_generator.py (+19, -16)

@@ -1,7 +1,10 @@
__all__ = [
"SequenceGenerator"
]
import torch
from .seq2seq_decoder import Decoder
import torch.nn.functional as F
from fastNLP.core.utils import _get_model_device
from ...core.utils import _get_model_device
from functools import partial


@@ -130,8 +133,8 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
else:
_eos_token_id = eos_token_id

for i in range(tokens.size(1) - 1):
scores, past = decoder.decode_one(tokens[:, :i + 1], past) # batch_size x vocab_size, Past
for i in range(tokens.size(1)):
scores, past = decoder.decode(tokens[:, :i + 1], past) # batch_size x vocab_size, Past

token_ids = tokens.clone()
cur_len = token_ids.size(1)
@@ -139,7 +142,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
# tokens = tokens[:, -1:]

while cur_len < max_length:
scores, past = decoder.decode_one(tokens, past) # batch_size x vocab_size, Past
scores, past = decoder.decode(tokens, past) # batch_size x vocab_size, Past

if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
@@ -153,7 +156,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
eos_mask = scores.new_ones(scores.size(1))
eos_mask[eos_token_id] = 0
eos_mask = eos_mask.unsqueeze(0).eq(1)
scores = scores.masked_scatter(eos_mask, token_scores)
scores = scores.masked_scatter(eos_mask, token_scores) # i.e. the scores of every token except eos have been scaled up/down

if do_sample:
if temperature > 0 and temperature != 1:
@@ -167,7 +170,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
else:
next_tokens = torch.argmax(scores, dim=-1) # batch_size

next_tokens = next_tokens.masked_fill(dones, 0)
next_tokens = next_tokens.masked_fill(dones, 0) # pad the samples whose search is already finished
tokens = next_tokens.unsqueeze(1)

token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len
@@ -181,7 +184,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt

if eos_token_id is not None:
if cur_len == max_length:
token_ids[:, -1].masked_fill_(dones, eos_token_id)
token_ids[:, -1].masked_fill_(~dones, eos_token_id) # if EOS was not reached by the maximum length, force the last token to be eos

return token_ids

@@ -206,9 +209,9 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."

for i in range(tokens.size(1) - 1): # if the input is longer, decode it first
scores, past = decoder.decode_one(tokens[:, :i + 1],
past) # (batch_size, vocab_size), Past
scores, past = decoder.decode_one(tokens, past) # the full length of the sentence needs to be passed in here
scores, past = decoder.decode(tokens[:, :i + 1],
past) # (batch_size, vocab_size), Past
scores, past = decoder.decode(tokens, past) # the full length of the sentence needs to be passed in here
vocab_size = scores.size(1)
assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."

@@ -224,7 +227,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2

indices = torch.arange(batch_size, dtype=torch.long).to(device)
indices = indices.repeat_interleave(num_beams)
past = decoder.reorder_past(indices, past)
decoder.reorder_past(indices, past)

tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length
# 记录生成好的token (batch_size', cur_len)
@@ -240,11 +243,11 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
hypos = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
]
# 0,num_beams, 2*num_beams
# 0,num_beams, 2*num_beams, ...
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)

while cur_len < max_length:
scores, past = decoder.decode_one(tokens, past) # batch_size * num_beams x vocab_size, Past
scores, past = decoder.decode(tokens, past) # batch_size * num_beams x vocab_size, Past

if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
@@ -298,8 +301,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
beam_scores = _next_scores.view(-1)

# update the past state and reorder token_ids
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
past = decoder.reorder_past(reorder_inds, past)
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten to 1-D
decoder.reorder_past(reorder_inds, past)

flag = True
if cur_len + 1 == max_length:
@@ -445,7 +448,7 @@ if __name__ == '__main__':
super().__init__()
self.num_words = num_words

def decode_one(self, tokens, past):
def decode(self, tokens, past):
batch_size = tokens.size(0)
return torch.randn(batch_size, self.num_words), past



fastNLP/modules/encoder/__init__.py (+5, -0)

@@ -30,6 +30,9 @@ __all__ = [
"MultiHeadAttention",
"BiAttention",
"SelfAttention",

"BiLSTMEncoder",
"TransformerSeq2SeqEncoder"
]

from .attention import MultiHeadAttention, BiAttention, SelfAttention
@@ -41,3 +44,5 @@ from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPoo
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU

from .seq2seq_encoder import BiLSTMEncoder, TransformerSeq2SeqEncoder

fastNLP/modules/encoder/seq2seq_encoder.py (+7, -2)

@@ -1,7 +1,12 @@
__all__ = [
"TransformerSeq2SeqEncoder",
"BiLSTMEncoder"
]

from torch import nn
import torch
from fastNLP.modules import LSTM
from fastNLP import seq_len_to_mask
from ...modules.encoder import LSTM
from ...core.utils import seq_len_to_mask
from torch.nn import TransformerEncoder
from typing import Union, Tuple
import numpy as np


test/models/test_seq2seq_model.py (+1, -0)

@@ -0,0 +1 @@


test/modules/decoder/test_seq2seq_decoder.py (+65, -0)

@@ -0,0 +1,65 @@
import unittest

import torch

from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.core.utils import seq_len_to_mask


class TestTransformerSeq2SeqDecoder(unittest.TestCase):
def test_case(self):
vocab = Vocabulary().add_word_lst("This is a test .".split())
vocab.add_word_lst("Another test !".split())
embed = StaticEmbedding(vocab, embedding_dim=512)

encoder = TransformerSeq2SeqEncoder(embed)
decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True)

src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
src_seq_len = torch.LongTensor([3, 2])

encoder_outputs, mask = encoder(src_words_idx, src_seq_len)
past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask)

decoder_outputs = decoder(tgt_words_idx, past)

print(decoder_outputs)
print(mask)

self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))

def test_decode(self):
pass # todo


class TestLSTMDecoder(unittest.TestCase):
def test_case(self):
vocab = Vocabulary().add_word_lst("This is a test .".split())
vocab.add_word_lst("Another test !".split())
embed = StaticEmbedding(vocab, embedding_dim=512)

encoder = BiLSTMEncoder(embed)
decoder = LSTMDecoder(embed, bind_input_output_embed=True)

src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
src_seq_len = torch.LongTensor([3, 2])

words, hx = encoder(src_words_idx, src_seq_len)
encode_mask = seq_len_to_mask(src_seq_len)
hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell))
decoder_outputs = decoder(tgt_words_idx, past)

print(decoder_outputs)
print(encode_mask)

self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))

def test_decode(self):
pass # todo

test/modules/decoder/test_seq2seq_generator.py (+52, -0)

@@ -0,0 +1,52 @@
import unittest

import torch

from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.core.utils import seq_len_to_mask
from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator


class TestSequenceGenerator(unittest.TestCase):
def test_case_for_transformer(self):
vocab = Vocabulary().add_word_lst("This is a test .".split())
vocab.add_word_lst("Another test !".split())
embed = StaticEmbedding(vocab, embedding_dim=512)
encoder = TransformerSeq2SeqEncoder(embed, num_layers=6)
decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True, num_layers=6)

src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
src_seq_len = torch.LongTensor([3, 2])

encoder_outputs, mask = encoder(src_words_idx, src_seq_len)
past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask, num_decoder_layer=6)

generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2)
tokens_ids = generator.generate(past=past)

print(tokens_ids)

def test_case_for_lstm(self):
vocab = Vocabulary().add_word_lst("This is a test .".split())
vocab.add_word_lst("Another test !".split())
embed = StaticEmbedding(vocab, embedding_dim=512)
encoder = BiLSTMEncoder(embed)
decoder = LSTMDecoder(embed, bind_input_output_embed=True)
src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
src_seq_len = torch.LongTensor([3, 2])

words, hx = encoder(src_words_idx, src_seq_len)
encode_mask = seq_len_to_mask(src_seq_len)
hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell))

generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2)
tokens_ids = generator.generate(past=past)

print(tokens_ids)

test/modules/encoder/test_seq2seq_encoder.py (+35, -0)

@@ -0,0 +1,35 @@
import unittest

import torch

from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

class TestTransformerSeq2SeqEncoder(unittest.TestCase):
def test_case(self):
vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = StaticEmbedding(vocab, embedding_dim=512)
encoder = TransformerSeq2SeqEncoder(embed)
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
seq_len = torch.LongTensor([3])
outputs, mask = encoder(words_idx, seq_len)

print(outputs)
print(mask)
self.assertEqual(outputs.size(), (1, 3, 512))


class TestBiLSTMEncoder(unittest.TestCase):
def test_case(self):
vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = StaticEmbedding(vocab, embedding_dim=300)
encoder = BiLSTMEncoder(embed, hidden_size=300)
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
seq_len = torch.LongTensor([3])

outputs, hx = encoder(words_idx, seq_len)

# print(outputs)
# print(hx)
self.assertEqual(outputs.size(), (1, 3, 300))
