add t2t model

tags/v0.6.0
linzehui 5 years ago
commit 7000cdebc4
3 changed files with 32 additions and 16 deletions
  1. fastNLP/models/seq2seq_models.py (+31, -0)
  2. fastNLP/modules/decoder/seq2seq_decoder.py (+0, -15)
  3. fastNLP/modules/encoder/seq2seq_encoder.py (+1, -1)

fastNLP/models/seq2seq_models.py (+31, -0)

@@ -0,0 +1,31 @@
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast
from fastNLP.modules.decoder.seq2seq_generator import SequenceGenerator
import torch.nn as nn
import torch
from typing import Union, Tuple
import numpy as np


class TransformerSeq2SeqModel(nn.Module):
    def __init__(self, src_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
                 tgt_embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray],
                 num_layers: int = 6,
                 d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1,
                 output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
                 bind_input_output_embed=False,
                 sos_id=None, eos_id=None):
        super().__init__()
        self.encoder = TransformerSeq2SeqEncoder(src_embed, num_layers, d_model, n_head, dim_ff, dropout)
        self.decoder = TransformerSeq2SeqDecoder(tgt_embed, num_layers, d_model, n_head, dim_ff, dropout, output_embed,
                                                 bind_input_output_embed)
        self.sos_id = sos_id
        self.eos_id = eos_id
        self.num_layers = num_layers

    def forward(self, words, target, seq_len):  # todo: does target carry sos and eos here? Check how the LSTM model handles it
        encoder_output, encoder_mask = self.encoder(words, seq_len)
        past = TransformerPast(encoder_output, encoder_mask, self.num_layers)
        outputs = self.decoder(target, past, return_attention=False)

        return outputs
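
For orientation, a minimal usage sketch of the new model (not part of the commit): the vocabulary size, shapes, and sos/eos ids below are placeholders, and it assumes the (vocab_size, d_model) tuple form of src_embed/tgt_embed builds an embedding layer internally.

import torch
from fastNLP.models.seq2seq_models import TransformerSeq2SeqModel

# hypothetical toy setup; all sizes are illustrative, not prescribed by the commit
model = TransformerSeq2SeqModel(src_embed=(1000, 512), tgt_embed=(1000, 512),
                                num_layers=2, sos_id=1, eos_id=2)
words = torch.randint(0, 1000, (4, 10))           # source tokens, (batch_size, src_len)
target = torch.randint(0, 1000, (4, 8))           # target tokens, (batch_size, tgt_len)
seq_len = torch.full((4,), 10, dtype=torch.long)  # source lengths
outputs = model(words, target, seq_len)           # decoder output for the target positions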

fastNLP/modules/decoder/seq2seq_decoder.py (+0, -15)

@@ -360,7 +360,6 @@ class TransformerSeq2SeqDecoder(Decoder):
    @torch.no_grad()
    def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]:
        """
        # todo: for the transformer, positional encoding means the whole prefix sequence must be fed in, so the LSTM decode_one and beam search need adjusting to unify the interface
        # todo: maybe past need not be returned, since it is modified in place and an explicit return is redundant?
        :param tokens: torch.LongTensor (batch_size, 1)
        :param past: TransformerPast
@@ -378,20 +377,6 @@ class TransformerSeq2SeqDecoder(Decoder):
        return torch.tril(tensor).byte()


class BiLSTMEncoder(nn.Module):
    def __init__(self, embed, num_layers=3, hidden_size=400, dropout=0.3):
        super().__init__()
        self.embed = embed
        self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size // 2, bidirectional=True,
                         batch_first=True, dropout=dropout, num_layers=num_layers)

    def forward(self, words, seq_len):
        words = self.embed(words)
        words, hx = self.lstm(words, seq_len)

        return words, hx


class LSTMPast(Past):
    def __init__(self, encode_outputs=None, encode_mask=None, decode_states=None, hx=None):
        """


fastNLP/modules/encoder/seq2seq_encoder.py (+1, -1)

@@ -12,7 +12,7 @@ class TransformerSeq2SeqEncoder(nn.Module):
                 d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, dropout: float = 0.1):
        super(TransformerSeq2SeqEncoder, self).__init__()
        self.embed = embed
        self.transformer = TransformerEncoder(nn.TransformerEncoderLayer(d_model, n_head), num_layers)
        self.transformer = TransformerEncoder(nn.TransformerEncoderLayer(d_model, n_head, dim_ff, dropout), num_layers)

    def forward(self, words, seq_len):
        """

