
add lstm module and generator

tags/v0.6.0
linzehui 5 years ago
commit 879cef62c6
4 changed files with 820 additions and 8 deletions:
  1. fastNLP/modules/decoder/__init__.py (+3, -1)
  2. fastNLP/modules/decoder/seq2seq_decoder.py (+333, -7)
  3. fastNLP/modules/decoder/seq2seq_generator.py (+470, -0)
  4. fastNLP/modules/encoder/seq2seq_encoder.py (+14, -0)

fastNLP/modules/decoder/__init__.py (+3, -1)

@@ -6,7 +6,9 @@ __all__ = [
"MLP",
"ConditionalRandomField",
"viterbi_decode",
"allowed_transitions"
"allowed_transitions",
"seq2seq_decoder",
"seq2seq_generator"
]

from .crf import ConditionalRandomField


fastNLP/modules/decoder/seq2seq_decoder.py (+333, -7)

@@ -7,6 +7,7 @@ from fastNLP.embeddings import StaticEmbedding
import numpy as np
from typing import Union, Tuple
from fastNLP.embeddings import get_embeddings
from fastNLP.modules import LSTM
from torch.nn import LayerNorm
import math
from reproduction.Summarization.Baseline.tools.PositionEmbedding import \
@@ -24,8 +25,7 @@ class Past:

class TransformerPast(Past):
def __init__(self, encoder_outputs: torch.Tensor = None, encoder_mask: torch.Tensor = None,
encoder_key: torch.Tensor = None, encoder_value: torch.Tensor = None,
decoder_prev_key: torch.Tensor = None, decoder_prev_value: torch.Tensor = None):
num_decoder_layer: int = 6):
"""

:param encoder_outputs: (batch,src_seq_len,dim)
@@ -37,16 +37,36 @@ class TransformerPast(Past):
"""
self.encoder_outputs = encoder_outputs
self.encoder_mask = encoder_mask
self.encoder_kv = encoder_key
self.encoder_value = encoder_value
self.decoder_prev_key = decoder_prev_key
self.decoder_prev_value = decoder_prev_value
self.encoder_key = [None] * num_decoder_layer
self.encoder_value = [None] * num_decoder_layer
self.decoder_prev_key = [None] * num_decoder_layer
self.decoder_prev_value = [None] * num_decoder_layer

def num_samples(self):
if self.encoder_outputs is not None:
return self.encoder_outputs.size(0)
return None

def _reorder_state(self, state, indices):
if type(state) == torch.Tensor:
state = state.index_select(index=indices, dim=0)
elif type(state) == list:
for i in range(len(state)):
assert state[i] is not None
state[i] = state[i].index_select(index=indices, dim=0)
else:
raise ValueError('Unsupported state type: only Tensor and list are supported')

return state

def reorder_past(self, indices: torch.LongTensor):
self.encoder_outputs = self._reorder_state(self.encoder_outputs, indices)
self.encoder_mask = self._reorder_state(self.encoder_mask, indices)
self.encoder_key = self._reorder_state(self.encoder_key, indices)
self.encoder_value = self._reorder_state(self.encoder_value, indices)
self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices)
self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices)


class Decoder(nn.Module):
def __init__(self):
@@ -324,7 +344,7 @@ class TransformerSeq2SeqDecoder(Decoder):
pos = self.pos_embed(tokens) # bs,decode_len,embed_dim
tokens = pos + tokens
if inference:
tokens = tokens[:, -1, :]
tokens = tokens[:, -1:, :]

x = F.dropout(tokens, p=self.dropout, training=self.training)
for layer in self.layer_stacks:
@@ -349,6 +369,312 @@ class TransformerSeq2SeqDecoder(Decoder):
output = self.forward(tokens, past, inference=True) # batch,1,vocab_size
return output.squeeze(1), past

def reorder_past(self, indices: torch.LongTensor, past: TransformerPast) -> TransformerPast:
past.reorder_past(indices)
return past  # TODO: the return value could be dropped since past is reordered in place

def _get_triangle_mask(self, max_seq_len):
tensor = torch.ones(max_seq_len, max_seq_len)
return torch.tril(tensor).byte()
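For reference, a minimal check of the causal mask this helper builds (illustrative size only, not part of the diff):

mask = torch.tril(torch.ones(3, 3)).byte()
# tensor([[1, 0, 0],
#         [1, 1, 0],
#         [1, 1, 1]], dtype=torch.uint8)  -> position i may only attend to positions <= i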


class BiLSTMEncoder(nn.Module):
def __init__(self, embed, num_layers=3, hidden_size=400, dropout=0.3):
super().__init__()
self.embed = embed
self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size // 2, bidirectional=True,
batch_first=True, dropout=dropout, num_layers=num_layers)

def forward(self, words, seq_len):
words = self.embed(words)
words, hx = self.lstm(words, seq_len)

return words, hx


class LSTMPast(Past):
def __init__(self, encode_outputs=None, encode_mask=None, decode_states=None, hx=None):
"""

:param torch.Tensor encode_outputs: batch_size x max_len x input_size
:param torch.Tensor encode_mask: batch_size x max_len, matching encode_outputs; helps decoding attend to the
correct tokens. Positions holding a real token are 1.
:param torch.Tensor decode_states: batch_size x decode_len x hidden_size, outputs of the LSTM in the Decoder
:param tuple hx: the h and c needed by the LSTM; h: num_layer x batch_size x hidden_size, c: num_layer x batch_size x hidden_size
"""
super().__init__()
self._encode_outputs = encode_outputs
if encode_mask is None:
if encode_outputs is not None:
self._encode_mask = encode_outputs.new_ones(encode_outputs.size(0), encode_outputs.size(1)).eq(1)
else:
self._encode_mask = None
else:
self._encode_mask = encode_mask
self._decode_states = decode_states
self._hx = hx  # holds the hidden and cell states
self._attn_states = None  # used when the LSTM decoder runs with attention

def num_samples(self):
for tensor in (self.encode_outputs, self.decode_states, self.hx):
if tensor is not None:
if isinstance(tensor, torch.Tensor):
return tensor.size(0)
else:
return tensor[0].size(0)
return None

@property
def hx(self):
return self._hx

@hx.setter
def hx(self, hx):
self._hx = hx

@property
def encode_outputs(self):
return self._encode_outputs

@encode_outputs.setter
def encode_outputs(self, value):
self._encode_outputs = value

@property
def encode_mask(self):
return self._encode_mask

@encode_mask.setter
def encode_mask(self, value):
self._encode_mask = value

@property
def decode_states(self):
return self._decode_states

@decode_states.setter
def decode_states(self, value):
self._decode_states = value

@property
def attn_states(self):
"""
The output of the attention module in LSTMDecoder; normally it does not need to be set manually.
:return:
"""
return self._attn_states

@attn_states.setter
def attn_states(self, value):
self._attn_states = value


class AttentionLayer(nn.Module):
def __init__(self, input_size, encode_hidden_size, decode_hidden_size, bias=False):
super().__init__()

self.input_proj = nn.Linear(input_size, encode_hidden_size, bias=bias)
self.output_proj = nn.Linear(input_size + encode_hidden_size, decode_hidden_size, bias=bias)

def forward(self, input, encode_outputs, encode_mask):
"""

:param input: batch_size x input_size
:param encode_outputs: batch_size x max_len x encode_hidden_size
:param encode_mask: batch_size x max_len
:return: batch_size x decode_hidden_size, batch_size x max_len
"""

# x: bsz x encode_hidden_size
x = self.input_proj(input)

# compute attention
attn_scores = torch.matmul(encode_outputs, x.unsqueeze(-1)).squeeze(-1) # b x max_len

# don't attend over padding
if encode_mask is not None:
attn_scores = attn_scores.float().masked_fill_(
encode_mask.eq(0),
float('-inf')
).type_as(attn_scores) # FP16 support: cast to float and back

attn_scores = F.softmax(attn_scores, dim=-1)  # bsz x max_len

# sum weighted sources
x = torch.matmul(attn_scores.unsqueeze(1), encode_outputs).squeeze(1) # b x encode_hidden_size

x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1)))
return x, attn_scores
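A shape-only sketch of how this layer is called (batch of 2, source length 5; all sizes below are made-up assumptions, not part of the diff):

attn = AttentionLayer(input_size=8, encode_hidden_size=8, decode_hidden_size=8)
hidden = torch.randn(2, 8)                           # current decoder hidden state
encode_outputs = torch.randn(2, 5, 8)                # encoder states
encode_mask = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1]]).eq(1)  # padded positions get zero attention weight
out, scores = attn(hidden, encode_outputs, encode_mask)
# out: (2, 8) re-projected context vector, scores: (2, 5) attention distribution over source positions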


class LSTMDecoder(Decoder):
def __init__(self, embed: Union[Tuple[int, int], nn.Module, torch.Tensor, np.ndarray], num_layers, input_size,
hidden_size=None, dropout=0,
output_embed: Union[Tuple[int, int], int, nn.Module, torch.Tensor, np.ndarray] = None,
bind_input_output_embed=False,
attention=True):
"""
# Note: if embed is a TokenEmbedding there is no one-to-one index mapping (one token may correspond to
# several words), so results looked up through the vocab would be wrong.

:param embed: embedding of the input tokens
:param int num_layers: number of LSTM layers
:param int input_size: dimension of the encoded input
:param int hidden_size: hidden dimension of the LSTM
:param float dropout: dropout between the LSTM layers
:param int output_embed: how to initialize the output vocabulary embedding; ignored when bind_input_output_embed is True
:param bool bind_input_output_embed: whether the input and output embeddings share the same weights
:param bool attention: whether to attend over the encoder outputs
"""

super().__init__()
self.token_embed = get_embeddings(embed)
if hidden_size is None:
hidden_size = input_size
self.hidden_size = hidden_size
self.input_size = input_size
if num_layers == 1:
self.lstm = nn.LSTM(self.token_embed.embedding_dim + hidden_size, hidden_size, num_layers=num_layers,
bidirectional=False, batch_first=True)
else:
self.lstm = nn.LSTM(self.token_embed.embedding_dim + hidden_size, hidden_size, num_layers=num_layers,
bidirectional=False, batch_first=True, dropout=dropout)
if input_size != hidden_size:
self.encode_hidden_proj = nn.Linear(input_size, hidden_size)
self.encode_cell_proj = nn.Linear(input_size, hidden_size)
self.dropout_layer = nn.Dropout(p=dropout)

if isinstance(output_embed, int):
output_embed = (output_embed, hidden_size)
output_embed = get_embeddings(output_embed)
elif output_embed is not None:
assert not bind_input_output_embed, "When `output_embed` is not None, `bind_input_output_embed` must " \
"be False."
if isinstance(output_embed, StaticEmbedding):
for i in self.token_embed.words_to_words:
assert i == self.token_embed.words_to_words[i], "The index does not match."
output_embed = self.token_embed.embedding.weight
else:
output_embed = get_embeddings(output_embed)
else:
if not bind_input_output_embed:
raise RuntimeError("You have to specify output embedding.")

if bind_input_output_embed:
assert output_embed is None, "When `bind_input_output_embed=True`, `output_embed` must be None"
if isinstance(self.token_embed, StaticEmbedding):
for i in self.token_embed.words_to_words:
assert i == self.token_embed.words_to_words[i], "The index does not match."
self.output_embed = nn.Parameter(self.token_embed.weight.transpose(0, 1))
self.output_hidden_size = self.token_embed.embedding_dim
else:
if isinstance(output_embed, nn.Embedding):
self.output_embed = nn.Parameter(output_embed.weight.transpose(0, 1))
else:
self.output_embed = output_embed.transpose(0, 1)
self.output_hidden_size = self.output_embed.size(0)

self.ffn = nn.Sequential(nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, self.output_hidden_size))
self.num_layers = num_layers

if attention:
self.attention_layer = AttentionLayer(hidden_size, input_size, hidden_size, bias=False)
else:
self.attention_layer = None

def _init_hx(self, past, tokens):
batch_size = tokens.size(0)
if past.hx is None:
zeros = tokens.new_zeros((self.num_layers, batch_size, self.hidden_size)).float()
past.hx = (zeros, zeros)
else:
assert past.hx[0].size(-1) == self.input_size
if self.attention_layer is not None:
if past.attn_states is None:
past.attn_states = past.hx[0].new_zeros(batch_size, self.hidden_size)
else:
assert past.attn_states.size(-1) == self.hidden_size, "The attention states dimension mismatch."
if self.hidden_size != past.hx[0].size(-1):
hidden, cell = past.hx
hidden = self.encode_hidden_proj(hidden)
cell = self.encode_cell_proj(cell)
past.hx = (hidden, cell)
return past

def forward(self, tokens, past=None, return_attention=False):
"""

:param torch.LongTensor tokens: batch_size x decode_len; the whole target sequence should be passed in
:param LSTMPast past: should contain the encoder outputs
:param bool return_attention: whether to also return the attention scores at each step
:return:
"""
batch_size, decode_len = tokens.size()
tokens = self.token_embed(tokens) # b x decode_len x embed_size

past = self._init_hx(past, tokens)

tokens = self.dropout_layer(tokens)

decode_states = tokens.new_zeros((batch_size, decode_len, self.hidden_size))
if self.attention_layer is not None:
attn_scores = tokens.new_zeros((tokens.size(0), tokens.size(1), past.encode_outputs.size(1)))
if self.attention_layer is not None:
input_feed = past.attn_states
else:
input_feed = past.hx[0][-1]
for i in range(tokens.size(1)):
input = torch.cat([tokens[:, i:i + 1], input_feed.unsqueeze(1)], dim=2) # batch_size x 1 x h'
# bsz x 1 x hidden_size, (n_layer x bsz x hidden_size, n_layer x bsz x hidden_size)
_, (hidden, cell) = self.lstm(input, hx=past.hx)
past.hx = (hidden, cell)
if self.attention_layer is not None:
input_feed, attn_score = self.attention_layer(hidden[-1], past.encode_outputs, past.encode_mask)
attn_scores[:, i] = attn_score
past.attn_states = input_feed
else:
input_feed = hidden[-1]
decode_states[:, i] = input_feed

decode_states = self.dropout_layer(decode_states)

outputs = self.ffn(decode_states) # batch_size x decode_len x output_hidden_size

feats = torch.matmul(outputs, self.output_embed) # bsz x decode_len x vocab_size
if return_attention:
return feats, attn_scores
else:
return feats

@torch.no_grad()
def decode_one(self, tokens, past) -> Tuple[torch.Tensor, Past]:
"""
Given the output of the previous positions, produce the scores for the current position.
:param torch.LongTensor tokens: batch_size x seq_len
:param LSTMPast past:
:return:
"""
# past = self._init_hx(past, tokens)
tokens = tokens[:, -1:]
feats = self.forward(tokens, past, return_attention=False)
return feats.squeeze(1), past

def reorder_past(self, indices: torch.LongTensor, past: LSTMPast) -> LSTMPast:
"""
Reorder the states stored in LSTMPast along the batch dimension.

:param torch.LongTensor indices: indices along the batch dimension
:param LSTMPast past: the saved past states
:return:
"""
encode_outputs = past.encode_outputs.index_select(index=indices, dim=0)
encoder_mask = past.encode_mask.index_select(index=indices, dim=0)
hx = (past.hx[0].index_select(index=indices, dim=1),
past.hx[1].index_select(index=indices, dim=1))
if past.attn_states is not None:
past.attn_states = past.attn_states.index_select(index=indices, dim=0)
past.encode_mask = encoder_mask
past.encode_outputs = encode_outputs
past.hx = hx
return past
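A minimal sketch of wiring the new pieces together for training-style scoring (the vocabulary size, dimensions and random inputs below are illustrative assumptions, not part of the diff):

import torch
from fastNLP.embeddings import get_embeddings

encoder = BiLSTMEncoder(embed=get_embeddings((100, 32)), num_layers=1, hidden_size=64, dropout=0.0)
decoder = LSTMDecoder(embed=(100, 32), num_layers=1, input_size=64, hidden_size=64, output_embed=100)

src = torch.randint(1, 100, (2, 7))
src_len = torch.tensor([7, 5])
encode_outputs, _ = encoder(src, src_len)                  # (2, 7, 64)
encode_mask = torch.arange(7)[None, :] < src_len[:, None]  # (2, 7), True where a real token sits
past = LSTMPast(encode_outputs=encode_outputs, encode_mask=encode_mask)

tgt = torch.randint(1, 100, (2, 4))
feats = decoder(tgt, past)                                 # (2, 4, 100) scores over the vocabulary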

fastNLP/modules/decoder/seq2seq_generator.py (+470, -0)

@@ -0,0 +1,470 @@
import torch
from .seq2seq_decoder import Decoder
import torch.nn.functional as F
from fastNLP.core.utils import _get_model_device
from functools import partial


class SequenceGenerator:
def __init__(self, decoder: Decoder, max_length=20, num_beams=1,
do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None,
repetition_penalty=1, length_penalty=1.0):
if do_sample:
self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id,
eos_token_id=eos_token_id, repetition_penalty=repetition_penalty,
length_penalty=length_penalty)
else:
self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, num_beams=num_beams,
bos_token_id=bos_token_id, eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty)
self.do_sample = do_sample
self.max_length = max_length
self.num_beams = num_beams
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.repetition_penalty = repetition_penalty
self.length_penalty = length_penalty
self.decoder = decoder

@torch.no_grad()
def generate(self, tokens=None, past=None):
"""

:param torch.LongTensor tokens: batch_size x length, the tokens to start decoding from
:param past:
:return:
"""
# TODO check whether decoding still works correctly when tokens is longer than 1.
return self.generate_func(tokens=tokens, past=past)


@torch.no_grad()
def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1,
bos_token_id=None, eos_token_id=None,
repetition_penalty=1, length_penalty=1.0):
"""
Greedy search for sentences.

:param Decoder decoder: the Decoder object
:param torch.LongTensor tokens: batch_size x len, the input to start decoding from; if None, generation starts from bos_token_id
:param Past past: should contain the encoder outputs.
:param int max_length: maximum length of the generated sentences.
:param int num_beams: beam size used for decoding.
:param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
:param int eos_token_id: the end-of-sentence token; if None, decoding always runs to max_length.
:param float repetition_penalty: how strongly repeated tokens are penalized.
:param float length_penalty: length penalty applied to every token except eos.
:return:
"""
if num_beams == 1:
token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=1, top_k=50, top_p=1,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
repetition_penalty=repetition_penalty, length_penalty=length_penalty)
else:
token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams,
temperature=1, top_k=50, top_p=1,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
repetition_penalty=repetition_penalty, length_penalty=length_penalty)

return token_ids


@torch.no_grad()
def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, temperature=1.0, top_k=50,
top_p=1.0, bos_token_id=None, eos_token_id=None, repetition_penalty=1.0, length_penalty=1.0):
"""
Generate sentences by sampling.

:param Decoder decoder: the Decoder object
:param torch.LongTensor tokens: batch_size x len, the input to start decoding from; if None, generation starts from bos_token_id
:param Past past: should contain the encoder outputs.
:param int max_length: maximum length of the generated sentences.
:param int num_beams: beam size used for decoding.
:param float temperature: sampling temperature.
:param int top_k: sample only from the top_k candidates.
:param float top_p: nucleus sampling threshold, a value between 0 and 1.
:param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
:param int eos_token_id: the end-of-sentence token; if None, decoding always runs to max_length.
:param float repetition_penalty: how strongly repeated tokens are penalized.
:param float length_penalty: length penalty applied to every token except eos.
:return:
"""
# each position is generated by sampling
if num_beams == 1:
token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=temperature,
top_k=top_k, top_p=top_p,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
repetition_penalty=repetition_penalty, length_penalty=length_penalty)
else:
token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams,
temperature=temperature, top_k=top_k, top_p=top_p,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True,
repetition_penalty=repetition_penalty, length_penalty=length_penalty)
return token_ids


def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, temperature=1.0, top_k=50,
top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
repetition_penalty=1.0, length_penalty=1.0):
device = _get_model_device(decoder)
if tokens is None:
if bos_token_id is None:
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
if past is None:
raise RuntimeError("You have to specify either `past` or `tokens`.")
batch_size = past.num_samples()
if batch_size is None:
raise RuntimeError("Cannot infer the number of samples from `past`.")
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
batch_size = tokens.size(0)
if past is not None:
assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."

if eos_token_id is None:
_eos_token_id = float('nan')  # sentinel that never matches any token id
else:
_eos_token_id = eos_token_id

for i in range(tokens.size(1) - 1):
scores, past = decoder.decode_one(tokens[:, :i + 1], past) # batch_size x vocab_size, Past

token_ids = tokens.clone()
cur_len = token_ids.size(1)
dones = token_ids.new_zeros(batch_size).eq(1)
# tokens = tokens[:, -1:]

while cur_len < max_length:
scores, past = decoder.decode_one(tokens, past) # batch_size x vocab_size, Past

if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
ge_zero_mask = lt_zero_mask.eq(0).float()
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
scores.scatter_(dim=1, index=token_ids, src=token_scores)

if eos_token_id is not None and length_penalty != 1.0:
token_scores = scores / cur_len ** length_penalty # batch_size x vocab_size
eos_mask = scores.new_ones(scores.size(1))
eos_mask[eos_token_id] = 0
eos_mask = eos_mask.unsqueeze(0).eq(1)
scores = scores.masked_scatter(eos_mask, token_scores)

if do_sample:
if temperature > 0 and temperature != 1:
scores = scores / temperature

scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
probs = F.softmax(scores, dim=-1)

# make sure at least one non-eos value remains
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size
else:
next_tokens = torch.argmax(scores, dim=-1) # batch_size

next_tokens = next_tokens.masked_fill(dones, 0)
tokens = next_tokens.unsqueeze(1)

token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len

end_mask = next_tokens.eq(_eos_token_id)
dones = dones.__or__(end_mask)
cur_len += 1

if dones.min() == 1:
break

if eos_token_id is not None:
if cur_len == max_length:
token_ids[:, -1].masked_fill_(dones, eos_token_id)

return token_ids


def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, num_beams=4, temperature=1.0,
top_k=50,
top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True,
repetition_penalty=1.0, length_penalty=None) -> torch.LongTensor:
# run beam search
device = _get_model_device(decoder)
if tokens is None:
if bos_token_id is None:
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
if past is None:
raise RuntimeError("You have to specify either `past` or `tokens`.")
batch_size = past.num_samples()
if batch_size is None:
raise RuntimeError("Cannot infer the number of samples from `past`.")
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
batch_size = tokens.size(0)
if past is not None:
assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match."

for i in range(tokens.size(1) - 1):  # if the input is longer than one token, decode it first
scores, past = decoder.decode_one(tokens[:, :i + 1],
past) # (batch_size, vocab_size), Past
scores, past = decoder.decode_one(tokens, past)  # the whole sequence needs to be passed in here
vocab_size = scores.size(1)
assert vocab_size >= num_beams, "num_beams should not be larger than the vocabulary size."

if do_sample:
probs = F.softmax(scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=num_beams) # (batch_size, num_beams)
logits = probs.log()
next_scores = logits.gather(dim=1, index=next_tokens) # (batch_size, num_beams)
else:
scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size)
# get (batch_size, num_beams), (batch_size, num_beams)
next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)

indices = torch.arange(batch_size, dtype=torch.long).to(device)
indices = indices.repeat_interleave(num_beams)
past = decoder.reorder_past(indices, past)

tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length
# record the generated tokens (batch_size', cur_len)
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
dones = [False] * batch_size
tokens = next_tokens.view(-1, 1)

beam_scores = next_scores.view(-1) # batch_size * num_beams

# tracks the length of the tokens generated so far
cur_len = token_ids.size(1)

hypos = [
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
]
# 0,num_beams, 2*num_beams
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)

while cur_len < max_length:
scores, past = decoder.decode_one(tokens, past) # batch_size * num_beams x vocab_size, Past

if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
ge_zero_mask = lt_zero_mask.eq(0).float()
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
scores.scatter_(dim=1, index=token_ids, src=token_scores)

if do_sample:
if temperature > 0 and temperature != 1:
scores = scores / temperature

# recall one extra candidate in case eos is selected
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1)
probs = F.softmax(scores, dim=-1)

# make sure at least one non-eos value remains
_tokens = torch.multinomial(probs, num_samples=num_beams + 1) # batch_size' x (num_beams+1)

logits = probs.log()
# avoid all selected candidates coming from the same beam, and account for eos being selected
_scores = logits.gather(dim=1, index=_tokens) # batch_size' x (num_beams+1)
_scores = _scores + beam_scores[:, None] # batch_size' x (num_beams+1)
# from these, select the top 2*num_beams
_scores = _scores.view(batch_size, num_beams * (num_beams + 1))
next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
_tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
next_tokens = _tokens.gather(dim=1, index=ids) # (batch_size, 2*num_beams)
from_which_beam = ids // (num_beams + 1) # (batch_size, 2*num_beams)
else:
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None] # (batch_size * num_beams, vocab_size)
_scores = _scores.view(batch_size, -1) # (batch_size, num_beams*vocab_size)
next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
from_which_beam = ids // vocab_size # (batch_size, 2*num_beams)
next_tokens = ids % vocab_size # (batch_size, 2*num_beams)

# assemble the candidates for the next step
# and decide which ones to keep
next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)

not_eos_mask = next_tokens.ne(eos_token_id)  # positions with 1 are not eos
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)  # positions with 1 should be kept
keep_mask = not_eos_mask.__and__(keep_mask)  # positions with 1 continue to the next search step

_next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
_from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)  # which beam each kept token comes from
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
beam_scores = _next_scores.view(-1)

# update the past states and regroup token_ids
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
past = decoder.reorder_past(reorder_inds, past)

flag = True
if cur_len + 1 == max_length:
eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)  # beam indices
eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)  # which beam each sequence was taken from
else:
# mark sequences within the top num_beams that produced eos as finished; positions with 1 are finished
effective_eos_mask = next_tokens[:, :num_beams].eq(eos_token_id) # batch_size x num_beams
if effective_eos_mask.sum().gt(0):
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True)
# from_which_beam is (batch_size, 2*num_beams), hence the stride of 2*num_beams
eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind
eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]  # the actual beam each eos came from
else:
flag = False
if flag:
for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(),
eos_beam_idx.tolist()):
if not dones[batch_idx]:
score = next_scores[batch_idx, beam_ind].item()
hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)

# reorganize token_ids
tokens = _next_tokens
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1)

for batch_idx in range(batch_size):
dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())

cur_len += 1

if all(dones):
break

# select the best hypotheses
tgt_len = token_ids.new(batch_size)
best = []

for i, hypotheses in enumerate(hypos):
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
tgt_len[i] = len(best_hyp) + 1 # +1 for the <EOS> symbol
best.append(best_hyp)

# generate target batch
decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(0)
for i, hypo in enumerate(best):
decoded[i, :tgt_len[i] - 1] = hypo
if eos_token_id is not None:
decoded[i, tgt_len[i] - 1] = eos_token_id

return decoded


class BeamHypotheses(object):
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.hyp = []
self.worst_score = 1e9

def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.hyp)

def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.hyp.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
del self.hyp[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)

def is_done(self, best_sum_logprobs):
"""
If there are enough hypotheses and that none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
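A tiny sketch of how the beam-search loop above interacts with this container (the hypotheses and scores are made up):

hypos = BeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0, early_stopping=False)
hypos.add(torch.tensor([0, 5, 7]), sum_logprobs=-1.2)   # length-normalized score -0.40
hypos.add(torch.tensor([0, 5, 8]), sum_logprobs=-2.5)   # score -0.83
hypos.add(torch.tensor([0, 6, 7]), sum_logprobs=-0.9)   # score -0.30, evicts the -0.83 hypothesis
hypos.is_done(best_sum_logprobs=-5.0)                   # True: -5.0 / 9 is worse than the worst kept score (-0.40)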


def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
"""
According to top_k and top_p, set the logits that do not qualify to filter_value.

:param torch.Tensor logits: bsz x vocab_size
:param int top_k: if greater than 0, keep only the top_k tokens and set the remaining positions to filter_value
:param float top_p: nucleus filtering following http://arxiv.org/abs/1904.09751
:param float filter_value:
:param int min_tokens_to_keep: each sample keeps at least this many tokens with non-zero probability
:return:
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value

if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0

# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
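A small numeric sketch of the filtering step as it is used in the sampling branches above (values are arbitrary):

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
filtered = top_k_top_p_filtering(logits.clone(), top_k=3, top_p=1.0)  # filtering modifies its input, hence clone()
# filtered == [[2.0, 1.0, 0.5, -inf, -inf]] -- only the 3 largest logits survive
probs = F.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)                  # sampled from the surviving tokens only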


if __name__ == '__main__':
# TODO check that greedy_generate and sample_generate work as expected.
from torch import nn


class DummyDecoder(nn.Module):
def __init__(self, num_words):
super().__init__()
self.num_words = num_words

def decode_one(self, tokens, past):
batch_size = tokens.size(0)
return torch.randn(batch_size, self.num_words), past

def reorder_past(self, indices, past):
return past


num_words = 10
batch_size = 3
decoder = DummyDecoder(num_words)

tokens = greedy_generate(decoder=decoder, tokens=torch.zeros(batch_size, 1).long(), past=None, max_length=20,
num_beams=2,
bos_token_id=0, eos_token_id=num_words - 1,
repetition_penalty=1, length_penalty=1.0)
print(tokens)

tokens = sample_generate(decoder, tokens=torch.zeros(batch_size, 1).long(),
past=None, max_length=20, num_beams=2, temperature=1.0, top_k=50,
top_p=1.0, bos_token_id=0, eos_token_id=num_words - 1, repetition_penalty=1.0,
length_penalty=1.0)
print(tokens)
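The same dummy decoder could also be driven through the SequenceGenerator wrapper; a sketch only, since the decoder returns random scores and this is just a smoke test:

generator = SequenceGenerator(decoder, max_length=20, num_beams=1, do_sample=False,
                              bos_token_id=0, eos_token_id=num_words - 1)
print(generator.generate(tokens=torch.zeros(batch_size, 1).long(), past=None))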

fastNLP/modules/encoder/seq2seq_encoder.py (+14, -0)

@@ -27,3 +27,17 @@ class TransformerSeq2SeqEncoder(nn.Module):
words = self.transformer(words, src_key_padding_mask=~encoder_mask) # seq_len,batch,dim

return words.transpose(0, 1), encoder_mask


class BiLSTMEncoder(nn.Module):
def __init__(self, embed, num_layers=3, hidden_size=400, dropout=0.3):
super().__init__()
self.embed = embed
self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size // 2, bidirectional=True,
batch_first=True, dropout=dropout, num_layers=num_layers)

def forward(self, words, seq_len):
words = self.embed(words)
words, hx = self.lstm(words, seq_len)

return words, hx
