
保存一版旧版 ("save a copy of the old version")

tags/v0.6.0
linzehui committed 5 years ago
commit 15360e9724
12 changed files with 300 additions and 67 deletions
 1. fastNLP/models/__init__.py                      +4   -2
 2. fastNLP/models/seq2seq_model.py                 +6   -11
 3. fastNLP/modules/__init__.py                     +13  -1
 4. fastNLP/modules/decoder/__init__.py             +11  -2
 5. fastNLP/modules/decoder/seq2seq_decoder.py      +82  -33
 6. fastNLP/modules/decoder/seq2seq_generator.py    +19  -16
 7. fastNLP/modules/encoder/__init__.py             +5   -0
 8. fastNLP/modules/encoder/seq2seq_encoder.py      +7   -2
 9. test/models/test_seq2seq_model.py               +1   -0
10. test/modules/decoder/test_seq2seq_decoder.py    +65  -0
11. test/modules/decoder/test_seq2seq_generator.py  +52  -0
12. test/modules/encoder/test_seq2seq_encoder.py    +35  -0

+4 -2  fastNLP/models/__init__.py

@@ -28,7 +28,9 @@ __all__ = [
"BertForSentenceMatching", "BertForSentenceMatching",
"BertForMultipleChoice", "BertForMultipleChoice",
"BertForTokenClassification", "BertForTokenClassification",
"BertForQuestionAnswering"
"BertForQuestionAnswering",

"TransformerSeq2SeqModel"
] ]


from .base_model import BaseModel from .base_model import BaseModel
@@ -39,7 +41,7 @@ from .cnn_text_classification import CNNText
 from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF
 from .snli import ESIM
 from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
+from .seq2seq_model import TransformerSeq2SeqModel
 import sys
 from ..doc_utils import doc_process
 doc_process(sys.modules[__name__])

fastNLP/models/seq2seq_models.py → fastNLP/models/seq2seq_model.py

@@ -1,29 +1,24 @@
-from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder
-from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast
-from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator
 import torch.nn as nn
 import torch
 from typing import Union, Tuple
 import numpy as np
+from fastNLP.modules import TransformerSeq2SeqDecoder, TransformerSeq2SeqEncoder, TransformerPast


-class TransformerSeq2SeqModel(nn.Module):
+class TransformerSeq2SeqModel(nn.Module):  # todo: refer to how fairseq's FairseqModel is structured
 def __init__(self, src_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
 tgt_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
-num_layers: int = 6,
-d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
+num_layers: int = 6, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
 output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
-bind_input_output_embed=False,
-sos_id=None, eos_id=None):
+bind_input_output_embed=False):
 super().__init__()
 self.encoder = TransformerSeq2SeqEncoder(src_embed, num_layers, d_model, n_head, dim_ff, dropout)
 self.decoder = TransformerSeq2SeqDecoder(tgt_embed, num_layers, d_model, n_head, dim_ff, dropout, output_embed,
 bind_input_output_embed)
-self.sos_id = sos_id
-self.eos_id = eos_id

 self.num_layers = num_layers

-def forward(self, words, target, seq_len):  # todo: does target here include sos and eos? check how the LSTM version handles it
+def forward(self, words, target, seq_len):
 encoder_output, encoder_mask = self.encoder(words, seq_len)
 past = TransformerPast(encoder_output, encoder_mask, self.num_layers)
 outputs = self.decoder(target, past, return_attention=False)
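For orientation, this is how the reworked model is driven end to end. A minimal sketch only, assuming a small Vocabulary/StaticEmbedding setup like the one used in the new tests further down; sizes and index values are illustrative, not part of this commit:

    import torch
    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding
    from fastNLP.models import TransformerSeq2SeqModel

    vocab = Vocabulary().add_word_lst("This is a test .".split())
    embed = StaticEmbedding(vocab, embedding_dim=512)

    # source and target side may share one embedding; the output projection is tied via bind_input_output_embed
    model = TransformerSeq2SeqModel(src_embed=embed, tgt_embed=embed, num_layers=6,
                                    d_model=512, n_head=8, bind_input_output_embed=True)

    src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])        # batch_size x src_len
    tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])  # batch_size x tgt_len
    src_seq_len = torch.LongTensor([3, 2])

    # forward runs encoder -> TransformerPast -> decoder and returns batch_size x tgt_len x vocab_size scores
    outputs = model(src_words_idx, tgt_words_idx, src_seq_len)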

+13 -1  fastNLP/modules/__init__.py

@@ -49,7 +49,19 @@ __all__ = [


"TimestepDropout", "TimestepDropout",


'summary'
'summary',

"BiLSTMEncoder",
"TransformerSeq2SeqEncoder",

"SequenceGenerator",
"LSTMDecoder",
"LSTMPast",
"TransformerSeq2SeqDecoder",
"TransformerPast",
"Decoder",
"Past"

] ]


import sys import sys


+11 -2  fastNLP/modules/decoder/__init__.py

@@ -7,11 +7,20 @@ __all__ = [
"ConditionalRandomField", "ConditionalRandomField",
"viterbi_decode", "viterbi_decode",
"allowed_transitions", "allowed_transitions",
"seq2seq_decoder",
"seq2seq_generator"

"SequenceGenerator",
"LSTMDecoder",
"LSTMPast",
"TransformerSeq2SeqDecoder",
"TransformerPast",
"Decoder",
"Past",

] ]


from .crf import ConditionalRandomField from .crf import ConditionalRandomField
from .crf import allowed_transitions from .crf import allowed_transitions
from .mlp import MLP from .mlp import MLP
from .utils import viterbi_decode from .utils import viterbi_decode
from .seq2seq_generator import SequenceGenerator
from .seq2seq_decoder import *

+82 -33  fastNLP/modules/decoder/seq2seq_decoder.py

@@ -1,17 +1,46 @@
 # coding=utf-8
+__all__ = [
+    "TransformerPast",
+    "LSTMPast",
+    "Past",
+    "LSTMDecoder",
+    "TransformerSeq2SeqDecoder",
+    "Decoder"
+]
 import torch
 from torch import nn
 import abc
 import torch.nn.functional as F
-from fastNLP.embeddings import StaticEmbedding
+from ...embeddings import StaticEmbedding
 import numpy as np
 from typing import Union, Tuple
-from fastNLP.embeddings import get_embeddings
-from fastNLP.modules import LSTM
+from ...embeddings.utils import get_embeddings
 from torch.nn import LayerNorm
 import math
-from reproduction.Summarization.Baseline.tools.PositionEmbedding import \
-    get_sinusoid_encoding_table  # todo: position embedding should be moved into core
+
+# from reproduction.Summarization.Baseline.tools.PositionEmbedding import \
+#     get_sinusoid_encoding_table  # todo: position embedding should be moved into core
+
+def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
+    ''' Sinusoid position encoding table '''
+
+    def cal_angle(position, hid_idx):
+        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
+
+    def get_posi_angle_vec(position):
+        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
+
+    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
+
+    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+
+    if padding_idx is not None:
+        # zero vector for padding dimension
+        sinusoid_table[padding_idx] = 0.
+
+    return torch.FloatTensor(sinusoid_table)


 class Past:
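A quick note on the get_sinusoid_encoding_table helper that replaces the old reproduction.Summarization import: it returns an (n_position, d_hid) tensor holding the usual sin/cos position encodings, and the decoder further down looks positions up starting from index 1. How the table becomes self.pos_embed is not visible in this hunk, so the following is only a sketch of the typical wiring (a frozen nn.Embedding with row 0 reserved for padding):

    import torch
    import torch.nn as nn

    max_positions, d_model = 1024, 512
    table = get_sinusoid_encoding_table(max_positions + 1, d_model, padding_idx=0)
    pos_embed = nn.Embedding.from_pretrained(table, freeze=True, padding_idx=0)

    pos_idx = torch.arange(1, 5).unsqueeze(0)  # 1 x decode_len, positions start at 1
    pos = pos_embed(pos_idx)                   # 1 x decode_len x d_model, added to the scaled token embeddings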
@@ -82,7 +111,7 @@ class Decoder(nn.Module):
""" """
raise NotImplemented raise NotImplemented


def decode_one(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]:
def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]:
""" """
当模型进行解码时,使用这个函数。只返回一个batch_size x vocab_size的结果。需要考虑一种特殊情况,即tokens长度不是1,即给定了 当模型进行解码时,使用这个函数。只返回一个batch_size x vocab_size的结果。需要考虑一种特殊情况,即tokens长度不是1,即给定了
解码句子开头的情况,这种情况需要查看Past中是否正确计算了decode的状态 解码句子开头的情况,这种情况需要查看Past中是否正确计算了decode的状态
@@ -100,6 +129,7 @@ class DecoderMultiheadAttention(nn.Module):
""" """


def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None): def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None):
super(DecoderMultiheadAttention, self).__init__()
self.d_model = d_model self.d_model = d_model
self.n_head = n_head self.n_head = n_head
self.dropout = dropout self.dropout = dropout
@@ -157,11 +187,11 @@ class DecoderMultiheadAttention(nn.Module):
 past.encoder_key[self.layer_idx] = k
 past.encoder_value[self.layer_idx] = v
 if inference and not is_encoder_attn:
-past.decoder_prev_key[self.layer_idx] = prev_k
-past.decoder_prev_value[self.layer_idx] = prev_v
+past.decoder_prev_key[self.layer_idx] = prev_k if prev_k is not None else k
+past.decoder_prev_value[self.layer_idx] = prev_v if prev_v is not None else v

 batch_size, q_len, d_model = query.size()
-k_len, v_len = key.size(1), value.size(1)
+k_len, v_len = k.size(1), v.size(1)
 q = q.contiguous().view(batch_size, q_len, self.n_head, self.head_dim)
 k = k.contiguous().view(batch_size, k_len, self.n_head, self.head_dim)
 v = v.contiguous().view(batch_size, v_len, self.n_head, self.head_dim)
@@ -172,8 +202,8 @@ class DecoderMultiheadAttention(nn.Module):
 if len(mask.size()) == 2:  # this is the encoder mask: batch, src_len/k_len
 mask = mask[:, None, :, None]
 else:  # (1, seq_len, seq_len)
-mask = mask[...:None]
-_mask = mask
+mask = mask[..., None]
+_mask = ~mask.bool()

 attn_weights = attn_weights.masked_fill(_mask, float('-inf'))
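The two added lines above fix the masking path: assuming fastNLP's usual convention that the mask is 1/True on real tokens and 0/False on padding (as produced by seq_len_to_mask), the attention weights must be filled with -inf exactly where the mask is 0, which is what _mask = ~mask.bool() expresses; the removed _mask = mask selected the opposite positions, and mask[...:None] was not valid tensor indexing. A tiny illustration of the convention:

    import torch

    mask = torch.tensor([[1, 1, 0]])   # 1 for real tokens, 0 for padding
    _mask = ~mask.bool()               # True exactly where attention must be blocked
    scores = torch.zeros(1, 3).masked_fill(_mask, float('-inf'))
    # tensor([[0., 0., -inf]])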


@@ -181,21 +211,22 @@ class DecoderMultiheadAttention(nn.Module):
 attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

 output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v)  # batch,q_len,n_head,head_dim
-output = output.view(batch_size, q_len, -1)
+output = output.reshape(batch_size, q_len, -1)
 output = self.out_proj(output)  # batch,q_len,dim

 return output, attn_weights

 def reset_parameters(self):
-nn.init.xavier_uniform_(self.q_proj)
-nn.init.xavier_uniform_(self.k_proj)
-nn.init.xavier_uniform_(self.v_proj)
-nn.init.xavier_uniform_(self.out_proj)
+nn.init.xavier_uniform_(self.q_proj.weight)
+nn.init.xavier_uniform_(self.k_proj.weight)
+nn.init.xavier_uniform_(self.v_proj.weight)
+nn.init.xavier_uniform_(self.out_proj.weight)


 class TransformerSeq2SeqDecoderLayer(nn.Module):
 def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
 layer_idx: int = None):
+super(TransformerSeq2SeqDecoderLayer, self).__init__()
 self.d_model = d_model
 self.n_head = n_head
 self.dim_ff = dim_ff
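The reset_parameters change above is also a correctness fix: nn.init.xavier_uniform_ operates on tensors, so passing the projection modules themselves (the removed lines) would fail at runtime, while the new lines initialise the .weight tensors in place. A minimal illustration with a hypothetical layer, not taken from this commit:

    import torch.nn as nn

    proj = nn.Linear(512, 512)
    # nn.init.xavier_uniform_(proj)        # wrong: proj is an nn.Module, init functions expect a tensor
    nn.init.xavier_uniform_(proj.weight)   # right: initialise the weight tensor in place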
@@ -313,10 +344,10 @@ class TransformerSeq2SeqDecoder(Decoder):
 if isinstance(self.token_embed, StaticEmbedding):
 for i in self.token_embed.words_to_words:
 assert i == self.token_embed.words_to_words[i], "The index does not match."
-self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1))
+self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1), requires_grad=True)
 else:
 if isinstance(output_embed, nn.Embedding):
-self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1))
+self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1), requires_grad=True)
 else:
 self.output_embed = output_embed.transpose(0, 1)
 self.output_hidden_size = self.output_embed.size(0)
@@ -326,7 +357,8 @@ class TransformerSeq2SeqDecoder(Decoder):
 def forward(self, tokens, past, return_attention=False, inference=False):
 """

-:param tokens: torch.LongTensor, tokens: batch_size x decode_len
+:param tokens: torch.LongTensor, tokens: batch_size , decode_len
+:param self_attn_mask: not needed at inference time; not needed at training time either, because the cross-entropy loss already masks the padded positions
 :param past: TransformerPast: contains the encoder output and mask; during inference it also stores the previous step's key and value to reduce matrix computation
 :param return_attention:
 :param inference: whether this is the inference stage
@@ -335,13 +367,16 @@ class TransformerSeq2SeqDecoder(Decoder):
 assert past is not None
 batch_size, decode_len = tokens.size()
 device = tokens.device
+pos_idx = torch.arange(1, decode_len + 1).unsqueeze(0).long()

 if not inference:
 self_attn_mask = self._get_triangle_mask(decode_len)
 self_attn_mask = self_attn_mask.to(device)[None, :, :]  # 1,seq,seq
 else:
 self_attn_mask = None

 tokens = self.token_embed(tokens) * self.embed_scale  # bs,decode_len,embed_dim
-pos = self.pos_embed(tokens)  # bs,decode_len,embed_dim
+pos = self.pos_embed(pos_idx)  # 1,decode_len,embed_dim
 tokens = pos + tokens
 if inference:
 tokens = tokens[:, -1:, :]
@@ -358,7 +393,7 @@ class TransformerSeq2SeqDecoder(Decoder):
 return output

 @torch.no_grad()
-def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]:
+def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]:
 """
 # todo: maybe past does not need to be returned, since it has already been modified in place?
 :param tokens: torch.LongTensor (batch_size,1)
@@ -370,7 +405,7 @@ class TransformerSeq2SeqDecoder(Decoder):


 def reorder_past(self, indices: torch.LongTensor, past: TransformerPast) -> TransformerPast:
 past.reorder_past(indices)
-return past  # todo: this could actually be dropped
+return past

 def _get_triangle_mask(self, max_seq_len):
 tensor = torch.ones(max_seq_len, max_seq_len)
@@ -409,6 +444,27 @@ class LSTMPast(Past):
 return tensor[0].size(0)
 return None

+def _reorder_past(self, state, indices, dim=0):
+    if type(state) == torch.Tensor:
+        state = state.index_select(index=indices, dim=dim)
+    elif type(state) == tuple:
+        tmp_list = []
+        for i in range(len(state)):
+            assert state[i] is not None
+            tmp_list.append(state[i].index_select(index=indices, dim=dim))
+        state = tuple(tmp_list)
+    else:
+        raise ValueError('State does not support other format')
+
+    return state
+
+def reorder_past(self, indices: torch.LongTensor):
+    self.encode_outputs = self._reorder_past(self.encode_outputs, indices)
+    self.encode_mask = self._reorder_past(self.encode_mask, indices)
+    self.hx = self._reorder_past(self.hx, indices, 1)
+    if self.attn_states is not None:
+        self.attn_states = self._reorder_past(self.attn_states, indices)

 @property
 def hx(self):
 return self._hx
@@ -493,7 +549,7 @@ class AttentionLayer(nn.Module):




 class LSTMDecoder(Decoder):
-def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers, input_size,
+def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers=3, input_size=400,
 hidden_size=None, dropout=0,
 output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
 bind_input_output_embed=False,
@@ -612,6 +668,7 @@ class LSTMDecoder(Decoder):
 for i in range(tokens.size(1)):
 input = torch.cat([tokens[:, i:i + 1], input_feed.unsqueeze(1)], dim=2)  # batch_size x 1 x h'
 # bsz x 1 x hidden_size, (n_layer x bsz x hidden_size, n_layer x bsz x hidden_size)
+
 _, (hidden, cell) = self.lstm(input, hx=past.hx)
 past.hx = (hidden, cell)
 if self.attention_layer is not None:
@@ -633,7 +690,7 @@ class LSTMDecoder(Decoder):
 return feats

 @torch.no_grad()
-def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]:
+def decode(self, tokens, past) -> Tuple[torch.Tensor, Past]:
 """
 Given the output at the previous position, decide the output at the current position.
 :param torch.LongTensor tokens: batch_size x seq_len
@@ -653,13 +710,5 @@ class LSTMDecoder(Decoder):
 :param LSTMPast past: the saved past state
 :return:
 """
-encode_outputs = past.encode_outputs.index_select(index=indices, dim=0)
-encoder_mask = past.encode_mask.index_select(index=indices, dim=0)
-hx = (past.hx[0].index_select(index=indices, dim=1),
-      past.hx[1].index_select(index=indices, dim=1))
-if past.attn_states is not None:
-    past.attn_states = past.attn_states.index_select(index=indices, dim=0)
-past.encode_mask = encoder_mask
-past.encode_outputs = encode_outputs
-past.hx = hx
+past.reorder_past(indices)
 return past
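LSTMDecoder.reorder_past now simply delegates to the new LSTMPast.reorder_past above, which index_selects every cached tensor: encoder outputs and masks along the batch dimension (dim 0) and the hx tuple along dim 1 (its layout is n_layers x batch x hidden). A small sketch of what beam search does with this hook, with illustrative shapes:

    import torch

    # batch_size=2, num_beams=2: every source sample is repeated once per beam
    indices = torch.arange(2).repeat_interleave(2)                 # tensor([0, 0, 1, 1])

    encode_outputs = torch.randn(2, 5, 512)                        # batch x src_len x hidden
    expanded = encode_outputs.index_select(dim=0, index=indices)   # 4 x 5 x 512

    hx = torch.randn(3, 2, 512)                                    # n_layers x batch x hidden
    expanded_hx = hx.index_select(dim=1, index=indices)            # 3 x 4 x 512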

+19 -16  fastNLP/modules/decoder/seq2seq_generator.py

@@ -1,7 +1,10 @@
+__all__ = [
+    "SequenceGenerator"
+]
 import torch
 from .seq2seq_decoder import Decoder
 import torch.nn.functional as F
-from fastNLP.core.utils import _get_model_device
+from ...core.utils import _get_model_device
 from functools import partial




@@ -130,8 +133,8 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
 else:
 _eos_token_id = eos_token_id

-for i in range(tokens.size(1) - 1):
-scores, past = decoder.decode_one(tokens[:, :i + 1], past)  # batch_size x vocab_size, Past
+for i in range(tokens.size(1)):
+scores, past = decoder.decode(tokens[:, :i + 1], past)  # batch_size x vocab_size, Past

 token_ids = tokens.clone()
 cur_len = token_ids.size(1)
@@ -139,7 +142,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
 # tokens = tokens[:, -1:]

 while cur_len < max_length:
-scores, past = decoder.decode_one(tokens, past)  # batch_size x vocab_size, Past
+scores, past = decoder.decode(tokens, past)  # batch_size x vocab_size, Past

 if repetition_penalty != 1.0:
 token_scores = scores.gather(dim=1, index=token_ids)
@@ -153,7 +156,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
 eos_mask = scores.new_ones(scores.size(1))
 eos_mask[eos_token_id] = 0
 eos_mask = eos_mask.unsqueeze(0).eq(1)
-scores = scores.masked_scatter(eos_mask, token_scores)
+scores = scores.masked_scatter(eos_mask, token_scores)  # i.e. every token except eos has its score scaled up/down

 if do_sample:
 if temperature > 0 and temperature != 1:
@@ -167,7 +170,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt
 else:
 next_tokens = torch.argmax(scores, dim=-1)  # batch_size

-next_tokens = next_tokens.masked_fill(dones, 0)
+next_tokens = next_tokens.masked_fill(dones, 0)  # pad the samples whose search has already finished
 tokens = next_tokens.unsqueeze(1)

 token_ids = torch.cat([token_ids, tokens], dim=-1)  # batch_size x max_len
@@ -181,7 +184,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt


 if eos_token_id is not None:
 if cur_len == max_length:
-token_ids[:, -1].masked_fill_(dones, eos_token_id)
+token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # if the max length is reached without an EOS, force the last token to be eos

 return token_ids


@@ -206,9 +209,9 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
 assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."

 for i in range(tokens.size(1) - 1):  # if the input is already long, decode it first
-scores, past = decoder.decode_one(tokens[:, :i + 1],
-                                  past)  # (batch_size, vocab_size), Past
-scores, past = decoder.decode_one(tokens, past)  # the full sentence length has to be passed in here
+scores, past = decoder.decode(tokens[:, :i + 1],
+                              past)  # (batch_size, vocab_size), Past
+scores, past = decoder.decode(tokens, past)  # the full sentence length has to be passed in here
 vocab_size = scores.size(1)
 assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."


@@ -224,7 +227,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2


 indices = torch.arange(batch_size, dtype=torch.long).to(device)
 indices = indices.repeat_interleave(num_beams)
-past = decoder.reorder_past(indices, past)
+decoder.reorder_past(indices, past)

 tokens = tokens.index_select(dim=0, index=indices)  # batch_size * num_beams x length
 # record the already-generated tokens (batch_size', cur_len)
@@ -240,11 +243,11 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
 hypos = [
     BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
 ]
-# 0,num_beams, 2*num_beams
+# 0,num_beams, 2*num_beams, ...
 batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)

 while cur_len < max_length:
-scores, past = decoder.decode_one(tokens, past)  # batch_size * num_beams x vocab_size, Past
+scores, past = decoder.decode(tokens, past)  # batch_size * num_beams x vocab_size, Past

 if repetition_penalty != 1.0:
 token_scores = scores.gather(dim=1, index=token_ids)
@@ -298,8 +301,8 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2
 beam_scores = _next_scores.view(-1)

 # update the past state and regroup token_ids
-reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
-past = decoder.reorder_past(reorder_inds, past)
+reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten into 1-D
+decoder.reorder_past(reorder_inds, past)

 flag = True
 if cur_len + 1 == max_length:
@@ -445,7 +448,7 @@ if __name__ == '__main__':
 super().__init__()
 self.num_words = num_words

-def decode_one(self, tokens, past):
+def decode(self, tokens, past):
 batch_size = tokens.size(0)
 return torch.randn(batch_size, self.num_words), past
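One piece of the beam-search bookkeeping above that is easy to misread is the reorder index arithmetic: batch_inds_with_numbeams_interval holds the row offsets 0, num_beams, 2*num_beams, ..., and adding _from_which_beam turns per-sample beam choices into flat row indices for decoder.reorder_past. A small numeric example (values are illustrative):

    import torch

    batch_size, num_beams = 2, 3
    batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1)
    # tensor([[0],
    #         [3]])

    # sample 0 keeps beams (1, 0, 2); sample 1 keeps beams (2, 2, 0)
    _from_which_beam = torch.LongTensor([[1, 0, 2], [2, 2, 0]])

    reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
    # tensor([1, 0, 2, 5, 5, 3]) -> flat indices into the (batch_size * num_beams) rows of past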




+5 -0  fastNLP/modules/encoder/__init__.py

@@ -30,6 +30,9 @@ __all__ = [
"MultiHeadAttention", "MultiHeadAttention",
"BiAttention", "BiAttention",
"SelfAttention", "SelfAttention",

"BiLSTMEncoder",
"TransformerSeq2SeqEncoder"
] ]


from .attention import MultiHeadAttention, BiAttention, SelfAttention from .attention import MultiHeadAttention, BiAttention, SelfAttention
@@ -41,3 +44,5 @@ from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPoo
 from .star_transformer import StarTransformer
 from .transformer import TransformerEncoder
 from .variational_rnn import VarRNN, VarLSTM, VarGRU
+
+from .seq2seq_encoder import BiLSTMEncoder, TransformerSeq2SeqEncoder

+7 -2  fastNLP/modules/encoder/seq2seq_encoder.py

@@ -1,7 +1,12 @@
+__all__ = [
+    "TransformerSeq2SeqEncoder",
+    "BiLSTMEncoder"
+]
+
 from torch import nn
 import torch
-from fastNLP.modules import LSTM
-from fastNLP import seq_len_to_mask
+from ...modules.encoder import LSTM
+from ...core.utils import seq_len_to_mask
 from torch.nn import TransformerEncoder
 from typing import Union, Tuple
 import numpy as np


+1 -0  test/models/test_seq2seq_model.py

@@ -0,0 +1 @@


+65 -0  test/modules/decoder/test_seq2seq_decoder.py

@@ -0,0 +1,65 @@
import unittest

import torch

from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.core.utils import seq_len_to_mask


class TestTransformerSeq2SeqDecoder(unittest.TestCase):
    def test_case(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        vocab.add_word_lst("Another test !".split())
        embed = StaticEmbedding(vocab, embedding_dim=512)

        encoder = TransformerSeq2SeqEncoder(embed)
        decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True)

        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
        src_seq_len = torch.LongTensor([3, 2])

        encoder_outputs, mask = encoder(src_words_idx, src_seq_len)
        past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask)

        decoder_outputs = decoder(tgt_words_idx, past)

        print(decoder_outputs)
        print(mask)

        self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))

    def test_decode(self):
        pass  # todo


class TestLSTMDecoder(unittest.TestCase):
    def test_case(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        vocab.add_word_lst("Another test !".split())
        embed = StaticEmbedding(vocab, embedding_dim=512)

        encoder = BiLSTMEncoder(embed)
        decoder = LSTMDecoder(embed, bind_input_output_embed=True)

        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
        src_seq_len = torch.LongTensor([3, 2])

        words, hx = encoder(src_words_idx, src_seq_len)
        encode_mask = seq_len_to_mask(src_seq_len)
        hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
        cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
        past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell))
        decoder_outputs = decoder(tgt_words_idx, past)

        print(decoder_outputs)
        print(encode_mask)

        self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))

    def test_decode(self):
        pass  # todo

+52 -0  test/modules/decoder/test_seq2seq_generator.py

@@ -0,0 +1,52 @@
import unittest

import torch

from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, LSTMDecoder
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.core.utils import seq_len_to_mask
from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator


class TestSequenceGenerator(unittest.TestCase):
    def test_case_for_transformer(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        vocab.add_word_lst("Another test !".split())
        embed = StaticEmbedding(vocab, embedding_dim=512)
        encoder = TransformerSeq2SeqEncoder(embed, num_layers=6)
        decoder = TransformerSeq2SeqDecoder(embed=embed, bind_input_output_embed=True, num_layers=6)

        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
        src_seq_len = torch.LongTensor([3, 2])

        encoder_outputs, mask = encoder(src_words_idx, src_seq_len)
        past = TransformerPast(encoder_outputs=encoder_outputs, encoder_mask=mask, num_decoder_layer=6)

        generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2)
        tokens_ids = generator.generate(past=past)

        print(tokens_ids)

    def test_case_for_lstm(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        vocab.add_word_lst("Another test !".split())
        embed = StaticEmbedding(vocab, embedding_dim=512)
        encoder = BiLSTMEncoder(embed)
        decoder = LSTMDecoder(embed, bind_input_output_embed=True)
        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
        src_seq_len = torch.LongTensor([3, 2])

        words, hx = encoder(src_words_idx, src_seq_len)
        encode_mask = seq_len_to_mask(src_seq_len)
        hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
        cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
        past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell))

        generator = SequenceGenerator(decoder, bos_token_id=1, eos_token_id=2, num_beams=2)
        tokens_ids = generator.generate(past=past)

        print(tokens_ids)

+35 -0  test/modules/encoder/test_seq2seq_encoder.py

@@ -0,0 +1,35 @@
import unittest

import torch

from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, BiLSTMEncoder
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding


class TestTransformerSeq2SeqEncoder(unittest.TestCase):
    def test_case(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = StaticEmbedding(vocab, embedding_dim=512)
        encoder = TransformerSeq2SeqEncoder(embed)
        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
        seq_len = torch.LongTensor([3])
        outputs, mask = encoder(words_idx, seq_len)

        print(outputs)
        print(mask)
        self.assertEqual(outputs.size(), (1, 3, 512))


class TestBiLSTMEncoder(unittest.TestCase):
    def test_case(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = StaticEmbedding(vocab, embedding_dim=300)
        encoder = BiLSTMEncoder(embed, hidden_size=300)
        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
        seq_len = torch.LongTensor([3])

        outputs, hx = encoder(words_idx, seq_len)

        # print(outputs)
        # print(hx)
        self.assertEqual(outputs.size(), (1, 3, 300))
