
Restructure and document the encoder module

tags/v0.4.10
ChenXin, 6 years ago
commit f3a9fc5b79
11 changed files with 161 additions and 144 deletions
  1. fastNLP/modules/__init__.py (+11, -3)
  2. fastNLP/modules/encoder/__init__.py (+6, -6)
  3. fastNLP/modules/encoder/attention.py (+26, -28)
  4. fastNLP/modules/encoder/bert.py (+26, -18)
  5. fastNLP/modules/encoder/char_encoder.py (+12, -11)
  6. fastNLP/modules/encoder/conv_maxpool.py (+8, -7)
  7. fastNLP/modules/encoder/lstm.py (+3, -2)
  8. fastNLP/modules/encoder/pooling.py (+15, -15)
  9. fastNLP/modules/encoder/star_transformer.py (+25, -25)
  10. fastNLP/modules/encoder/transformer.py (+5, -5)
  11. fastNLP/modules/encoder/variational_rnn.py (+24, -24)

fastNLP/modules/__init__.py (+11, -3)

@@ -17,22 +17,30 @@

"""
__all__ = [
# "BertModel",
"BertModel",

"ConvolutionCharEncoder",
"LSTMCharEncoder",

"ConvMaxpool",

"LSTM",

"StarTransformer",

"TransformerEncoder",

"VarRNN",
"VarLSTM",
"VarGRU",
"MaxPool",
"MaxPoolWithMask",
"AvgPool",
"AvgPoolWithMask",

"MultiHeadAttention",
"MLP",
"ConditionalRandomField",
"viterbi_decode",


fastNLP/modules/encoder/__init__.py (+6, -6)

@@ -1,17 +1,17 @@
__all__ = [
"BertModel",
"ConvolutionCharEncoder",
"LSTMCharEncoder",
"ConvMaxpool",
"LSTM",
"StarTransformer",
"TransformerEncoder",
"VarRNN",
"VarLSTM",
"VarGRU",


fastNLP/modules/encoder/attention.py (+26, -28)

@@ -8,8 +8,6 @@ import torch
import torch.nn.functional as F
from torch import nn

from fastNLP.modules.dropout import TimestepDropout

from fastNLP.modules.utils import initial_parameter


@@ -18,7 +16,7 @@ class DotAttention(nn.Module):
.. todo::
Add documentation.
"""
def __init__(self, key_size, value_size, dropout=0.0):
super(DotAttention, self).__init__()
self.key_size = key_size
@@ -26,7 +24,7 @@ class DotAttention(nn.Module):
self.scale = math.sqrt(key_size)
self.drop = nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=2)
def forward(self, Q, K, V, mask_out=None):
"""

@@ -45,7 +43,7 @@ class DotAttention(nn.Module):

class MultiHeadAttention(nn.Module):
"""
- Alias: :class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.encoder.attention.MultiHeadAttention`
+ Alias: :class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.encoder.MultiHeadAttention`

:param input_size: int, dimension of the input, which is also the dimension of the output.
:param key_size: int, dimension of each head.
@@ -53,14 +51,14 @@ class MultiHeadAttention(nn.Module):
:param num_head: int, number of heads.
:param dropout: float.
"""
def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
super(MultiHeadAttention, self).__init__()
self.input_size = input_size
self.key_size = key_size
self.value_size = value_size
self.num_head = num_head
in_size = key_size * num_head
self.q_in = nn.Linear(input_size, in_size)
self.k_in = nn.Linear(input_size, in_size)
@@ -69,14 +67,14 @@ class MultiHeadAttention(nn.Module):
self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout)
self.out = nn.Linear(value_size * num_head, input_size)
self.reset_parameters()
def reset_parameters(self):
sqrt = math.sqrt
nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.value_size)))
nn.init.xavier_normal_(self.out.weight)
def forward(self, Q, K, V, atte_mask_out=None):
"""

@@ -92,7 +90,7 @@ class MultiHeadAttention(nn.Module):
q = self.q_in(Q).view(batch, sq, n_head, d_k)
k = self.k_in(K).view(batch, sk, n_head, d_k)
v = self.v_in(V).view(batch, sk, n_head, d_v)
# transpose q, k and v to do batch attention
q = q.permute(2, 0, 1, 3).contiguous().view(-1, sq, d_k)
k = k.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_k)
@@ -100,7 +98,7 @@ class MultiHeadAttention(nn.Module):
if atte_mask_out is not None:
atte_mask_out = atte_mask_out.repeat(n_head, 1, 1)
atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, sq, d_v)
# concat all heads, do output linear
atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
output = self.out(atte)
@@ -124,11 +122,11 @@ class BiAttention(nn.Module):
\end{array}
"""
def __init__(self):
super(BiAttention, self).__init__()
self.inf = 10e12
def forward(self, in_x1, in_x2, x1_len, x2_len):
"""
:param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] feature representation of the first sentence
@@ -139,36 +137,36 @@ class BiAttention(nn.Module):
torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] the feature representation attended to by the first sentence
"""
assert in_x1.size()[0] == in_x2.size()[0]
assert in_x1.size()[2] == in_x2.size()[2]
# The batch size and hidden size must be equal.
assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
# The seq len in in_x and x_len must be equal.
assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]
batch_size = in_x1.size()[0]
x1_max_len = in_x1.size()[1]
x2_max_len = in_x2.size()[1]
in_x2_t = torch.transpose(in_x2, 1, 2) # [batch_size, hidden_size, x2_seq_len]
attention_matrix = torch.bmm(in_x1, in_x2_t) # [batch_size, x1_seq_len, x2_seq_len]
a_mask = x1_len.le(0.5).float() * -self.inf # [batch_size, x1_seq_len]
a_mask = a_mask.view(batch_size, x1_max_len, -1)
a_mask = a_mask.expand(-1, -1, x2_max_len) # [batch_size, x1_seq_len, x2_seq_len]
b_mask = x2_len.le(0.5).float() * -self.inf
b_mask = b_mask.view(batch_size, -1, x2_max_len)
b_mask = b_mask.expand(-1, x1_max_len, -1) # [batch_size, x1_seq_len, x2_seq_len]
attention_a = F.softmax(attention_matrix + a_mask, dim=2) # [batch_size, x1_seq_len, x2_seq_len]
attention_b = F.softmax(attention_matrix + b_mask, dim=1) # [batch_size, x1_seq_len, x2_seq_len]
out_x1 = torch.bmm(attention_a, in_x2) # [batch_size, x1_seq_len, hidden_size]
attention_b_t = torch.transpose(attention_b, 1, 2)
out_x2 = torch.bmm(attention_b_t, in_x1) # [batch_size, x2_seq_len, hidden_size]
return out_x1, out_x2


@@ -182,10 +180,10 @@ class SelfAttention(nn.Module):
:param float drop: dropout probability, default 0.5
:param str initial_method: method used to initialize the parameters
"""
def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None, ):
super(SelfAttention, self).__init__()
self.attention_hops = attention_hops
self.ws1 = nn.Linear(input_size, attention_unit, bias=False)
self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False)
@@ -194,7 +192,7 @@ class SelfAttention(nn.Module):
self.drop = nn.Dropout(drop)
self.tanh = nn.Tanh()
initial_parameter(self, initial_method)
def _penalization(self, attention):
"""
compute the penalization term for attention module
@@ -208,7 +206,7 @@ class SelfAttention(nn.Module):
mat = torch.bmm(attention, attention_t) - self.I[:attention.size(0)]
ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5
return torch.sum(ret) / size[0]
def forward(self, input, input_origin):
"""
:param torch.Tensor input: [baz, senLen, h_dim] the matrix to compute attention over
@@ -218,14 +216,14 @@ class SelfAttention(nn.Module):
"""
input = input.contiguous()
size = input.size() # [bsz, len, nhid]
input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops,baz, len]
input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len]
y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit]
attention = self.ws2(y1).transpose(1, 2).contiguous()
# [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]
attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token.
attention = F.softmax(attention, 2) # [baz ,hop, len]
return torch.bmm(attention, input), self._penalization(attention) # output1 --> [baz ,hop ,nhid]
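A usage sketch for MultiHeadAttention based on the constructor and forward signatures shown above; the tensor sizes are illustrative and forward is assumed to return the attended tensor of shape [batch, seq_len, input_size]:

import torch
from fastNLP.modules import MultiHeadAttention

attn = MultiHeadAttention(input_size=64, key_size=16, value_size=16, num_head=4, dropout=0.1)
x = torch.rand(2, 7, 64)        # [batch, seq_len, input_size], arbitrary example sizes
out = attn(x, x, x)             # self-attention: Q = K = V
print(out.shape)                # expected: torch.Size([2, 7, 64])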

fastNLP/modules/encoder/bert.py (+26, -18)

@@ -1,11 +1,11 @@



"""
The code in this file is largely based on (copy-pasted from) https://github.com/huggingface/pytorch-pretrained-BERT; if you find this code
useful, please cite them as well.
"""

__all__ = [
"BertModel"
]

import collections

@@ -26,6 +26,7 @@ CONFIG_FILE = 'bert_config.json'
class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""

def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
@@ -65,7 +66,7 @@ class BertConfig(object):
layer_norm_eps: The epsilon used by LayerNorm.
"""
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
@@ -150,6 +151,7 @@ class BertLayerNorm(nn.Module):
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""

def __init__(self, config):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
@@ -331,7 +333,10 @@ class BertPooler(nn.Module):


class BertModel(nn.Module):
"""BERT(Bidirectional Embedding Representations from Transformers).
"""
别名::class:`fastNLP.modules.BertModel` :class:`fastNLP.modules.encoder.BertModel`

BERT(Bidirectional Embedding Representations from Transformers).

If you want to use pretrained weights, download them from the URLs below.
sources::
@@ -449,9 +454,9 @@ class BertModel(nn.Module):
model = cls(config, *inputs, **kwargs)
if state_dict is None:
files = glob.glob(os.path.join(pretrained_model_dir, '*.bin'))
- if len(files)==0:
+ if len(files) == 0:
raise FileNotFoundError(f"There is no *.bin file in {pretrained_model_dir}")
- elif len(files)>1:
+ elif len(files) > 1:
raise FileExistsError(f"There are multiple *.bin files in {pretrained_model_dir}")
weights_path = files[0]
state_dict = torch.load(weights_path, map_location='cpu')
@@ -580,6 +585,7 @@ def load_vocab(vocab_file):
index += 1
return vocab


class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

@@ -765,8 +771,8 @@ class BertTokenizer(object):
[(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split)
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12)

@@ -821,7 +827,7 @@ class BertTokenizer(object):
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
print("Saving vocabulary to {}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!".format(vocab_file))
" Please check that the vocabulary is not corrupted!".format(vocab_file))
index = token_index
writer.write(token + u'\n')
index += 1
@@ -841,6 +847,7 @@ class BertTokenizer(object):
tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs)
return tokenizer


VOCAB_NAME = 'vocab.txt'


@@ -849,7 +856,8 @@ class _WordPieceBertModel(nn.Module):
This module computes word_piece level results directly.

"""
- def __init__(self, model_dir:str, layers:str='-1'):
+ def __init__(self, model_dir: str, layers: str = '-1'):
super().__init__()

self.tokenzier = BertTokenizer.from_pretrained(model_dir)
@@ -858,11 +866,11 @@ class _WordPieceBertModel(nn.Module):
encoder_layer_number = len(self.encoder.encoder.layer)
self.layers = list(map(int, layers.split(',')))
for layer in self.layers:
if layer<0:
assert -layer<=encoder_layer_number, f"The layer index:{layer} is out of scope for " \
if layer < 0:
assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
f"a bert model with {encoder_layer_number} layers."
else:
assert layer<encoder_layer_number, f"The layer index:{layer} is out of scope for " \
assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \
f"a bert model with {encoder_layer_number} layers."

self._cls_index = self.tokenzier.vocab['[CLS]']
@@ -878,15 +886,16 @@ class _WordPieceBertModel(nn.Module):
:param field_name: which column the indexing is based on
:return:
"""

def convert_words_to_word_pieces(words):
word_pieces = []
for word in words:
tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word)
word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens)
word_pieces.extend(word_piece_ids)
- if word_pieces[0]!=self._cls_index:
+ if word_pieces[0] != self._cls_index:
word_pieces.insert(0, self._cls_index)
- if word_pieces[-1]!=self._sep_index:
+ if word_pieces[-1] != self._sep_index:
word_pieces.insert(-1, self._sep_index)
return word_pieces

@@ -910,10 +919,9 @@ class _WordPieceBertModel(nn.Module):

attn_masks = word_pieces.ne(self._wordpiece_pad_index)
bert_outputs, _ = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
output_all_encoded_layers=True)
output_all_encoded_layers=True)
# output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size
outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1)))
for l_index, l in enumerate(self.layers):
outputs[l_index] = bert_outputs[l]
return outputs
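A hedged usage sketch for the now-exported BertModel; '/path/to/bert-dir' is a placeholder for a directory holding bert_config.json, vocab.txt and exactly one *.bin weight file, as from_pretrained above requires, and the second return value is assumed to be the pooled [CLS] representation:

import torch
from fastNLP.modules import BertModel

bert = BertModel.from_pretrained('/path/to/bert-dir')       # placeholder directory
word_pieces = torch.LongTensor([[101, 2023, 2003, 102]])    # hypothetical word-piece ids: [CLS] ... [SEP]
attn_mask = word_pieces.ne(0)                               # 0 is assumed to be the padding id
# output_all_encoded_layers=True returns every layer's hidden states,
# the same call pattern used by _WordPieceBertModel above
all_layers, pooled = bert(word_pieces, attention_mask=attn_mask,
                          output_all_encoded_layers=True)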


fastNLP/modules/encoder/char_encoder.py (+12, -11)

@@ -11,7 +11,7 @@ from ..utils import initial_parameter
# from torch.nn.init import xavier_uniform
class ConvolutionCharEncoder(nn.Module):
"""
- Alias: :class:`fastNLP.modules.ConvolutionCharEncoder` :class:`fastNLP.modules.encoder.char_encoder.ConvolutionCharEncoder`
+ Alias: :class:`fastNLP.modules.ConvolutionCharEncoder` :class:`fastNLP.modules.encoder.ConvolutionCharEncoder`

A char-level convolutional encoder.
@@ -21,15 +21,16 @@ class ConvolutionCharEncoder(nn.Module):
:param tuple kernels: a tuple of ints. Its length is the number of char-level convolutions, and the `i`-th int is the kernel size of the `i`-th convolution.
:param initial_method: how to initialize the parameters; defaults to `xavier normal`
"""
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(1, 3, 5), initial_method=None):
super(ConvolutionCharEncoder, self).__init__()
self.convs = nn.ModuleList([
- nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, kernels[i]//2))
+ nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True,
+ padding=(0, kernels[i] // 2))
for i in range(len(kernels))])
initial_parameter(self, initial_method)
def forward(self, x):
"""
:param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` the input character embeddings
@@ -40,7 +41,7 @@ class ConvolutionCharEncoder(nn.Module):
x = x.transpose(2, 3)
# [batch_size*sent_length, channel, height, width]
return self._convolute(x).unsqueeze(2)
def _convolute(self, x):
feats = []
for conv in self.convs:
@@ -57,13 +58,13 @@ class ConvolutionCharEncoder(nn.Module):

class LSTMCharEncoder(nn.Module):
"""
- Alias: :class:`fastNLP.modules.LSTMCharEncoder` :class:`fastNLP.modules.encoder.char_encoder.LSTMCharEncoder`
+ Alias: :class:`fastNLP.modules.LSTMCharEncoder` :class:`fastNLP.modules.encoder.LSTMCharEncoder`

A char-level LSTM-based encoder.
"""
def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None):
"""
:param int char_emb_size: dimension of the char-level embedding. Default: 50
@@ -73,14 +74,14 @@ class LSTMCharEncoder(nn.Module):
"""
super(LSTMCharEncoder, self).__init__()
self.hidden_size = char_emb_size if hidden_size is None else hidden_size
self.lstm = nn.LSTM(input_size=char_emb_size,
hidden_size=self.hidden_size,
num_layers=1,
bias=True,
batch_first=True)
initial_parameter(self, initial_method)
def forward(self, x):
"""
:param torch.Tensor x: ``[ n_batch*n_word, word_length, char_emb_size]`` the input character embeddings
@@ -91,6 +92,6 @@ class LSTMCharEncoder(nn.Module):
h0 = nn.init.orthogonal_(h0)
c0 = torch.empty(1, batch_size, self.hidden_size)
c0 = nn.init.orthogonal_(c0)
_, hidden = self.lstm(x, (h0, c0))
return hidden[0].squeeze().unsqueeze(2)
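A usage sketch for the two char encoders above; shapes follow the docstrings, the concrete numbers are illustrative only:

import torch
from fastNLP.modules import ConvolutionCharEncoder, LSTMCharEncoder

conv_enc = ConvolutionCharEncoder(char_emb_size=50, feature_maps=(40, 30, 30), kernels=(1, 3, 5))
chars = torch.rand(32 * 20, 7, 50)   # [batch_size * sent_length, word_length, char_emb_size]
conv_out = conv_enc(chars)           # char features; output channels = sum(feature_maps) = 100

lstm_enc = LSTMCharEncoder(char_emb_size=50, hidden_size=64)
lstm_out = lstm_enc(chars)           # final LSTM hidden state per word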

fastNLP/modules/encoder/conv_maxpool.py (+8, -7)

@@ -5,9 +5,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvMaxpool(nn.Module):
"""
- Alias: :class:`fastNLP.modules.ConvMaxpool` :class:`fastNLP.modules.encoder.conv_maxpool.ConvMaxpool`
+ Alias: :class:`fastNLP.modules.ConvMaxpool` :class:`fastNLP.modules.encoder.ConvMaxpool`

A layer that combines convolution and max-pooling. Given an input of batch_size x max_len x input_size, it returns a matrix of size batch_size x
sum(output_channels). Internally, the input is first convolved by the CNN, passed through the activation, and then, along the length (max_len)
@@ -18,12 +19,12 @@ class ConvMaxpool(nn.Module):
:param int,tuple(int) kernel_sizes: kernel sizes of the output channels.
:param str activation: the convolution output is passed through this activation before max-pooling. Supports relu, sigmoid, tanh
"""
def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"):
super(ConvMaxpool, self).__init__()

for kernel_size in kernel_sizes:
- assert kernel_size%2==1, "kernel size has to be odd numbers."
+ assert kernel_size % 2 == 1, "kernel size has to be odd numbers."

# convolution
if isinstance(kernel_sizes, (list, tuple, int)):
@@ -36,22 +37,22 @@ class ConvMaxpool(nn.Module):
" of kernel_sizes."
else:
raise ValueError("The type of out_channels and kernel_sizes should be the same.")
self.convs = nn.ModuleList([nn.Conv1d(
in_channels=in_channels,
out_channels=oc,
kernel_size=ks,
stride=1,
- padding=ks//2,
+ padding=ks // 2,
dilation=1,
groups=1,
bias=None)
for oc, ks in zip(out_channels, kernel_sizes)])
else:
raise Exception(
'Incorrect kernel sizes: should be list, tuple or int')
# activation function
if activation == 'relu':
self.activation = F.relu
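An illustrative use of ConvMaxpool, respecting the odd-kernel-size assertion above; the sizes are arbitrary and the return shape follows the docstring:

import torch
from fastNLP.modules import ConvMaxpool

conv_pool = ConvMaxpool(in_channels=128, out_channels=[100, 100, 100],
                        kernel_sizes=[1, 3, 5], activation='relu')
x = torch.rand(4, 30, 128)     # [batch_size, max_len, input_size]
out = conv_pool(x)             # [batch_size, sum(out_channels)] -> [4, 300]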


fastNLP/modules/encoder/lstm.py (+3, -2)

@@ -10,9 +10,10 @@ import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn


class LSTM(nn.Module):
"""
- Alias: :class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM`
+ Alias: :class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.LSTM`

An LSTM module, a lightweight wrapper around the PyTorch LSTM. When seq_len is provided, pack_padded_sequence is used automatically; by default the forget-gate bias is
initialized to 1; it also handles the problems of using LSTM inside DataParallel.
@@ -26,7 +27,7 @@ class LSTM(nn.Module):
:(batch, seq, feature). Default: ``False``
:param bias: if ``False``, the model will not use bias. Default: ``True``
"""
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
bidirectional=False, bias=True):
super(LSTM, self).__init__()
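A sketch of the wrapped LSTM; per the docstring, passing seq_len triggers pack_padded_sequence. The forward signature and the (output, (h, c)) return convention are assumptions here, since only __init__ appears in this hunk:

import torch
from fastNLP.modules import LSTM

lstm = LSTM(input_size=100, hidden_size=256, num_layers=1, batch_first=True, bidirectional=True)
x = torch.rand(4, 12, 100)                    # [batch, seq_len, input_size]
seq_len = torch.LongTensor([12, 10, 8, 5])    # true lengths per sample
output, (h_n, c_n) = lstm(x, seq_len)         # output: [4, 12, 2 * 256]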


fastNLP/modules/encoder/pooling.py (+15, -15)

@@ -10,7 +10,7 @@ import torch.nn as nn

class MaxPool(nn.Module):
"""
- Alias: :class:`fastNLP.modules.MaxPool` :class:`fastNLP.modules.encoder.pooling.MaxPool`
+ Alias: :class:`fastNLP.modules.MaxPool` :class:`fastNLP.modules.encoder.MaxPool`

The max-pooling module.
@@ -21,9 +21,9 @@ class MaxPool(nn.Module):
:param kernel_size: window size of the max pooling; defaults to the last k dimensions of the tensor, where k is `dimension`
:param ceil_mode:
"""
def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, ceil_mode=False):
super(MaxPool, self).__init__()
assert (1 <= dimension) and (dimension <= 3)
self.dimension = dimension
@@ -32,7 +32,7 @@ class MaxPool(nn.Module):
self.dilation = dilation
self.kernel_size = kernel_size
self.ceil_mode = ceil_mode
def forward(self, x):
if self.dimension == 1:
pooling = nn.MaxPool1d(
@@ -59,15 +59,15 @@ class MaxPool(nn.Module):

class MaxPoolWithMask(nn.Module):
"""
- Alias: :class:`fastNLP.modules.MaxPoolWithMask` :class:`fastNLP.modules.encoder.pooling.MaxPoolWithMask`
+ Alias: :class:`fastNLP.modules.MaxPoolWithMask` :class:`fastNLP.modules.encoder.MaxPoolWithMask`

Max pooling with a mask matrix. Positions where the mask is 0 are ignored during max-pooling.
"""
def __init__(self):
super(MaxPoolWithMask, self).__init__()
self.inf = 10e12
def forward(self, tensor, mask, dim=1):
"""
:param torch.FloatTensor tensor: [batch_size, seq_len, channels] the input tensor
@@ -82,11 +82,11 @@ class MaxPoolWithMask(nn.Module):

class KMaxPool(nn.Module):
"""K max-pooling module."""
def __init__(self, k=1):
super(KMaxPool, self).__init__()
self.k = k
def forward(self, x):
"""
:param torch.Tensor x: [N, C, L] the input tensor
@@ -99,16 +99,16 @@ class KMaxPool(nn.Module):

class AvgPool(nn.Module):
"""
- Alias: :class:`fastNLP.modules.AvgPool` :class:`fastNLP.modules.encoder.pooling.AvgPool`
+ Alias: :class:`fastNLP.modules.AvgPool` :class:`fastNLP.modules.encoder.AvgPool`

Given an input of shape [batch_size, max_len, hidden_size], performs avg pooling over the last dimension. The output is [batch_size, hidden_size]
"""
def __init__(self, stride=None, padding=0):
super(AvgPool, self).__init__()
self.stride = stride
self.padding = padding
def forward(self, x):
"""
:param torch.Tensor x: [N, C, L] the input tensor
@@ -126,16 +126,16 @@ class AvgPool(nn.Module):

class AvgPoolWithMask(nn.Module):
"""
- Alias: :class:`fastNLP.modules.AvgPoolWithMask` :class:`fastNLP.modules.encoder.pooling.AvgPoolWithMask`
+ Alias: :class:`fastNLP.modules.AvgPoolWithMask` :class:`fastNLP.modules.encoder.AvgPoolWithMask`

Given an input of shape [batch_size, max_len, hidden_size], performs avg pooling over the last dimension. The output is [batch_size, hidden_size]; only positions
where the mask is 1 are considered during pooling
"""
def __init__(self):
super(AvgPoolWithMask, self).__init__()
self.inf = 10e12
def forward(self, tensor, mask, dim=1):
"""
:param torch.FloatTensor tensor: [batch_size, seq_len, channels] the input tensor
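A masked-pooling sketch following the shapes in the docstrings above; the mask construction is illustrative, with 1 marking real tokens and 0 marking padding:

import torch
from fastNLP.modules import MaxPoolWithMask, AvgPoolWithMask

x = torch.rand(4, 10, 64)                                     # [batch_size, seq_len, channels]
lengths = torch.LongTensor([10, 7, 5, 3])
mask = (torch.arange(10)[None, :] < lengths[:, None]).long()  # [batch_size, seq_len], 1 = real token

max_pooled = MaxPoolWithMask()(x, mask)                       # [batch_size, channels]
avg_pooled = AvgPoolWithMask()(x, mask)                       # [batch_size, channels]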


fastNLP/modules/encoder/star_transformer.py (+25, -25)

@@ -13,7 +13,7 @@ from torch.nn import functional as F

class StarTransformer(nn.Module):
"""
- Alias: :class:`fastNLP.modules.StarTransformer` :class:`fastNLP.modules.encoder.star_transformer.StarTransformer`
+ Alias: :class:`fastNLP.modules.StarTransformer` :class:`fastNLP.modules.encoder.StarTransformer`


The encoder part of the Star-Transformer. Takes 3d text input and returns a text encoding of the same length.
@@ -29,11 +29,11 @@ class StarTransformer(nn.Module):
the model adds position embeddings to the input sequence.
If `None`, the position-embedding step is skipped. Default: `None`
"""
def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None):
super(StarTransformer, self).__init__()
self.iters = num_layers
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size, eps=1e-6) for _ in range(self.iters)])
# self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1)
self.emb_drop = nn.Dropout(dropout)
@@ -43,12 +43,12 @@ class StarTransformer(nn.Module):
self.star_att = nn.ModuleList(
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
for _ in range(self.iters)])
if max_len is not None:
self.pos_emb = nn.Embedding(max_len, hidden_size)
else:
self.pos_emb = None
def forward(self, data, mask):
"""
:param FloatTensor data: [batch, length, hidden] the input sequence
@@ -58,15 +58,15 @@ class StarTransformer(nn.Module):

[batch, hidden] the global relay node; see the paper for details
"""
def norm_func(f, x):
# B, H, L, 1
return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
B, L, H = data.size()
mask = (mask == 0) # flip the mask for masked_fill_
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)
embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1
if self.pos_emb and False:
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \
@@ -80,13 +80,13 @@ class StarTransformer(nn.Module):
for i in range(self.iters):
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax))
- #nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax))
+ # nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax))
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask))
nodes = nodes.masked_fill_(ex_mask, 0)
nodes = nodes.view(B, H, L).permute(0, 2, 1)
return nodes, relay.view(B, H)


@@ -99,19 +99,19 @@ class _MSA1(nn.Module):
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)
self.drop = nn.Dropout(dropout)
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3
def forward(self, x, ax=None):
# x: B, H, L, 1, ax : B, H, X, L append features
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
B, H, L, _ = x.shape
q, k, v = self.WQ(x), self.WK(x), self.WV(x) # x: (B,H,L,1)
if ax is not None:
aL = ax.shape[2]
ak = self.WK(ax).view(B, nhead, head_dim, aL, L)
@@ -124,12 +124,12 @@ class _MSA1(nn.Module):
if ax is not None:
k = torch.cat([k, ak], 3)
v = torch.cat([v, av], 3)
alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / NP.sqrt(head_dim), 3)) # B N L 1 U
att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1)
ret = self.WO(att)
return ret


@@ -141,19 +141,19 @@ class _MSA2(nn.Module):
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)
self.drop = nn.Dropout(dropout)
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3
def forward(self, x, y, mask=None):
# x: B, H, 1, 1, 1 y: B H L 1
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
B, H, L, _ = y.shape
q, k, v = self.WQ(x), self.WK(y), self.WV(y)
q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h
k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L
v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h
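A usage sketch of the StarTransformer encoder with arbitrary sizes; the mask is assumed to use 1 for real tokens and 0 for padding, which forward flips internally before masked_fill_:

import torch
from fastNLP.modules import StarTransformer

encoder = StarTransformer(hidden_size=128, num_layers=2, num_head=4, head_dim=32, max_len=100)
data = torch.rand(4, 20, 128)          # [batch, length, hidden]
mask = torch.ones(4, 20).long()        # assumed: 1 = real token, 0 = padding
nodes, relay = encoder(data, mask)     # nodes: [4, 20, 128]; relay: [4, 128] global relay node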


fastNLP/modules/encoder/transformer.py (+5, -5)

@@ -9,7 +9,7 @@ from ..dropout import TimestepDropout

class TransformerEncoder(nn.Module):
"""
- Alias: :class:`fastNLP.modules.TransformerEncoder` :class:`fastNLP.modules.encoder.transformer.TransformerEncoder`
+ Alias: :class:`fastNLP.modules.TransformerEncoder` :class:`fastNLP.modules.encoder.TransformerEncoder`


The transformer encoder module, not including the embedding layer
@@ -22,7 +22,7 @@ class TransformerEncoder(nn.Module):
:param int num_head: number of heads.
:param float dropout: dropout probability. Default: 0.1
"""
class SubLayer(nn.Module):
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
super(TransformerEncoder.SubLayer, self).__init__()
@@ -33,7 +33,7 @@ class TransformerEncoder(nn.Module):
nn.Linear(inner_size, model_size),
TimestepDropout(dropout), )
self.norm2 = nn.LayerNorm(model_size)
def forward(self, input, seq_mask=None, atte_mask_out=None):
"""

@@ -48,11 +48,11 @@ class TransformerEncoder(nn.Module):
output = self.norm2(output + norm_atte)
output *= seq_mask
return output
def __init__(self, num_layers, **kargs):
super(TransformerEncoder, self).__init__()
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])
def forward(self, x, seq_mask=None):
"""
:param x: [batch, seq_len, model_size] the input sequence
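A sketch for TransformerEncoder; the keyword arguments are forwarded to SubLayer as shown above, the concrete values are illustrative, and seq_mask is assumed to use 1 for real tokens:

import torch
from fastNLP.modules import TransformerEncoder

encoder = TransformerEncoder(num_layers=2, model_size=128, inner_size=256,
                             key_size=32, value_size=32, num_head=4, dropout=0.1)
x = torch.rand(4, 20, 128)     # [batch, seq_len, model_size]
seq_mask = torch.ones(4, 20)   # assumed: 1 = real token, 0 = padding
out = encoder(x, seq_mask)     # [batch, seq_len, model_size]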


fastNLP/modules/encoder/variational_rnn.py (+24, -24)

@@ -28,14 +28,14 @@ class VarRnnCellWrapper(nn.Module):
"""
Wrapper for normal RNN Cells, make it support variational dropout
"""
def __init__(self, cell, hidden_size, input_p, hidden_p):
super(VarRnnCellWrapper, self).__init__()
self.cell = cell
self.hidden_size = hidden_size
self.input_p = input_p
self.hidden_p = hidden_p
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False):
"""
:param PackedSequence input_x: [seq_len, batch_size, input_size]
@@ -47,13 +47,13 @@ class VarRnnCellWrapper(nn.Module):
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
for other RNN, h_n, [batch_size, hidden_size]
"""
def get_hi(hi, h0, size):
h0_size = size - hi.size(0)
if h0_size > 0:
return torch.cat([hi, h0[:h0_size]], dim=0)
return hi[:size]
is_lstm = isinstance(hidden, tuple)
input, batch_sizes = input_x.data, input_x.batch_sizes
output = []
@@ -64,7 +64,7 @@ class VarRnnCellWrapper(nn.Module):
else:
batch_iter = batch_sizes
idx = 0
if is_lstm:
hn = (hidden[0].clone(), hidden[1].clone())
else:
@@ -91,7 +91,7 @@ class VarRnnCellWrapper(nn.Module):
hi = cell(input_i, hi)
hn[:size] = hi
output.append(hi)
if is_reversed:
output = list(reversed(output))
output = torch.cat(output, dim=0)
@@ -117,7 +117,7 @@ class VarRNNBase(nn.Module):
:param hidden_dropout: dropout probability for each hidden state. Default: 0
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
"""
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
bias=True, batch_first=False,
input_dropout=0, hidden_dropout=0, bidirectional=False):
@@ -141,7 +141,7 @@ class VarRNNBase(nn.Module):
cell, self.hidden_size, input_dropout, hidden_dropout))
initial_parameter(self)
self.is_lstm = (self.mode == "LSTM")
def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h):
is_lstm = self.is_lstm
idx = self.num_directions * n_layer + n_direction
@@ -150,7 +150,7 @@ class VarRNNBase(nn.Module):
output_x, hidden_x = cell(
input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
return output_x, hidden_x
def forward(self, x, hx=None):
"""

@@ -170,13 +170,13 @@ class VarRNNBase(nn.Module):
else:
max_batch_size = int(x.batch_sizes[0])
x, batch_sizes = x.data, x.batch_sizes
if hx is None:
hx = x.new_zeros(self.num_layers * self.num_directions,
max_batch_size, self.hidden_size, requires_grad=True)
if is_lstm:
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True))
mask_x = x.new_ones((max_batch_size, self.input_size))
mask_out = x.new_ones(
(max_batch_size, self.hidden_size * self.num_directions))
@@ -185,7 +185,7 @@ class VarRNNBase(nn.Module):
training=self.training, inplace=True)
nn.functional.dropout(mask_out, p=self.hidden_dropout,
training=self.training, inplace=True)
hidden = x.new_zeros(
(self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
if is_lstm:
@@ -207,22 +207,22 @@ class VarRNNBase(nn.Module):
else:
hidden[idx] = hidden_x
x = torch.cat(output_list, dim=-1)
if is_lstm:
hidden = (hidden, cellstate)
if is_packed:
output = PackedSequence(x, batch_sizes)
else:
x = PackedSequence(x, batch_sizes)
output, _ = pad_packed_sequence(x, batch_first=self.batch_first)
return output, hidden


class VarLSTM(VarRNNBase):
"""
- Alias: :class:`fastNLP.modules.VarLSTM` :class:`fastNLP.modules.encoder.variational_rnn.VarLSTM`
+ Alias: :class:`fastNLP.modules.VarLSTM` :class:`fastNLP.modules.encoder.VarLSTM`

Variational Dropout LSTM.

@@ -236,18 +236,18 @@ class VarLSTM(VarRNNBase):
:param hidden_dropout: dropout probability for each hidden state. Default: 0
:param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False``
"""
def __init__(self, *args, **kwargs):
super(VarLSTM, self).__init__(
mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
def forward(self, x, hx=None):
return super(VarLSTM, self).forward(x, hx)


class VarRNN(VarRNNBase):
"""
- Alias: :class:`fastNLP.modules.VarRNN` :class:`fastNLP.modules.encoder.variational_rnn.VarRNN`
+ Alias: :class:`fastNLP.modules.VarRNN` :class:`fastNLP.modules.encoder.VarRNN`

Variational Dropout RNN.

@@ -261,18 +261,18 @@ class VarRNN(VarRNNBase):
:param hidden_dropout: dropout probability for each hidden state. Default: 0
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
"""
def __init__(self, *args, **kwargs):
super(VarRNN, self).__init__(
mode="RNN", Cell=nn.RNNCell, *args, **kwargs)
def forward(self, x, hx=None):
return super(VarRNN, self).forward(x, hx)


class VarGRU(VarRNNBase):
"""
- Alias: :class:`fastNLP.modules.VarGRU` :class:`fastNLP.modules.encoder.variational_rnn.VarGRU`
+ Alias: :class:`fastNLP.modules.VarGRU` :class:`fastNLP.modules.encoder.VarGRU`

Variational Dropout GRU.

@@ -286,10 +286,10 @@ class VarGRU(VarRNNBase):
:param hidden_dropout: dropout probability for each hidden state. Default: 0
:param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False``
"""
def __init__(self, *args, **kwargs):
super(VarGRU, self).__init__(
mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
def forward(self, x, hx=None):
return super(VarGRU, self).forward(x, hx)
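A sketch of the variational-dropout RNNs; the constructor arguments come from VarRNNBase above (note batch_first defaults to False there), and the sizes are illustrative:

import torch
from fastNLP.modules import VarLSTM

var_lstm = VarLSTM(input_size=100, hidden_size=200, num_layers=2, batch_first=True,
                   input_dropout=0.3, hidden_dropout=0.3, bidirectional=True)
x = torch.rand(4, 15, 100)            # [batch, seq_len, input_size] because batch_first=True
output, (h_n, c_n) = var_lstm(x)      # output: [4, 15, 2 * 200]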
