
Merge branch 'master' of https://github.com/fastnlp/fastNLP into current branch

# Conflicts solved:
#	fastNLP/io/dataset_loader.py
#	fastNLP/io/embed_loader.py
#	fastNLP/modules/aggregator/attention.py
#	fastNLP/modules/decoder/CRF.py
tags/v0.3.0^2
FengZiYjun committed 6 years ago
commit 751fe2768e
17 changed files with 196 additions and 188 deletions
 1. fastNLP/io/dataset_loader.py                  +2   -2
 2. fastNLP/modules/aggregator/attention.py       +2   -4
 3. fastNLP/modules/aggregator/self_attention.py  +10  -9
 4. fastNLP/modules/decoder/CRF.py                +2   -3
 5. fastNLP/modules/decoder/MLP.py                +11  -9
 6. fastNLP/modules/dropout.py                    +2   -2
 7. fastNLP/modules/encoder/char_embedding.py     +16  -18
 8. fastNLP/modules/encoder/conv.py               +17  -7
 9. fastNLP/modules/encoder/conv_maxpool.py       +15  -5
10. fastNLP/modules/encoder/embedding.py          +6   -9
11. fastNLP/modules/encoder/linear.py             +5   -8
12. fastNLP/modules/encoder/lstm.py               +9   -7
13. fastNLP/modules/encoder/masked_rnn.py         +59  -57
14. fastNLP/modules/encoder/transformer.py        +3   -6
15. fastNLP/modules/encoder/variational_rnn.py    +12  -6
16. fastNLP/modules/other_modules.py              +23  -34
17. fastNLP/modules/utils.py                      +2   -2

fastNLP/io/dataset_loader.py (+2, -2)

@@ -254,7 +254,7 @@ class TokenizeDataSetLoader(DataSetLoader):


class ClassDataSetLoader(DataSetLoader):
-    """Loader for classification data sets"""
+    """Loader for a dummy classification data set"""

    def __init__(self):
        super(ClassDataSetLoader, self).__init__()

@@ -304,7 +304,7 @@ class ConllLoader(DataSetLoader):
    @staticmethod
    def parse(lines):
        """
-        :param list lines:a list containing all lines in a conll file.
+        :param list lines: a list containing all lines in a conll file.
        :return: a 3D list
        """
        sentences = list()
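
A minimal usage sketch for the loader touched here, assuming the fastNLP module layout of this commit; the file name example.conll is hypothetical, and ConllLoader.parse is the static method shown above, which turns the raw lines of a CoNLL file into a 3D list (sentences, tokens, columns).

# Sketch only: ConllLoader.parse is the @staticmethod shown in the hunk above.
from fastNLP.io.dataset_loader import ConllLoader

with open("example.conll", encoding="utf-8") as f:   # hypothetical input file
    lines = f.readlines()

sentences = ConllLoader.parse(lines)   # 3D list: [sentence][token][column]
print(len(sentences), "sentences")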


fastNLP/modules/aggregator/attention.py (+2, -4)

@@ -6,7 +6,6 @@ from fastNLP.modules.utils import mask_softmax
-


class Attention(torch.nn.Module):

    def __init__(self, normalize=False):
        super(Attention, self).__init__()
        self.normalize = normalize

@@ -20,9 +19,9 @@ class Attention(torch.nn.Module):
    def _atten_forward(self, query, memory):
        raise NotImplementedError

-

class DotAtte(nn.Module):
    def __init__(self, key_size, value_size):
+        # TODO never test
        super(DotAtte, self).__init__()
        self.key_size = key_size
        self.value_size = value_size

@@ -42,10 +41,9 @@ class DotAtte(nn.Module):
        output = nn.functional.softmax(output, dim=2)
        return torch.matmul(output, V)

-

class MultiHeadAtte(nn.Module):
    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
-        raise NotImplementedError
+        # TODO never test
        super(MultiHeadAtte, self).__init__()
        self.in_linear = nn.ModuleList()
        for i in range(num_atte * 3):
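
The DotAtte hunk above ends with a softmax over the key dimension followed by a matmul with V. A minimal plain-PyTorch sketch of that dot-product attention step; the tensor shapes are assumptions, and only the softmax and matmul lines come from the diff itself.

import torch
import torch.nn.functional as F

# Q: [batch, len_q, key_size], K: [batch, len_k, key_size], V: [batch, len_k, value_size]
Q = torch.rand(2, 5, 16)
K = torch.rand(2, 7, 16)
V = torch.rand(2, 7, 32)

scores = torch.matmul(Q, K.transpose(1, 2))   # [batch, len_q, len_k]
weights = F.softmax(scores, dim=2)            # softmax over keys, as in DotAtte.forward
output = torch.matmul(weights, V)             # [batch, len_q, value_size]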


fastNLP/modules/aggregator/self_attention.py (+10, -9)

@@ -7,13 +7,14 @@ from fastNLP.modules.utils import initial_parameter


class SelfAttention(nn.Module):
-    """
-    Self Attention Module.
-    Args:
-        input_size: int, the size for the input vector
-        dim: int, the width of weight matrix.
-        num_vec: int, the number of encoded vectors
+    """Self Attention Module.
+    :param int input_size:
+    :param int attention_unit:
+    :param int attention_hops:
+    :param float drop:
+    :param str initial_method:
+    :param bool use_cuda:
    """

    def __init__(self, input_size, attention_unit=350, attention_hops=10, drop=0.5, initial_method=None,

@@ -48,7 +49,7 @@ class SelfAttention(nn.Module):
    def forward(self, input, input_origin):
        """
        :param input: the matrix to do attention. [baz, senLen, h_dim]
        :param inp: then token index include pad token( 0 ) [baz , senLen]
        :return output1: the input matrix after attention operation [baz, multi-head , h_dim]
        :return output2: the attention penalty term, a scalar [1]
        """

@@ -59,8 +60,8 @@ class SelfAttention(nn.Module):
        input_origin = input_origin.transpose(0, 1).contiguous()  # [baz, hops,len]

        y1 = self.tanh(self.ws1(self.drop(input)))  # [baz,len,dim] -->[bsz,len, attention-unit]
-        attention = self.ws2(y1).transpose(1,
-                                           2).contiguous()  # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]
+        attention = self.ws2(y1).transpose(1, 2).contiguous()
+        # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]

        attention = attention + (-999999 * (input_origin == 0).float())  # remove the weight on padding token.
        attention = F.softmax(attention, 2)  # [baz ,hop, len]

fastNLP/modules/decoder/CRF.py (+2, -3)

@@ -161,7 +161,6 @@ class ConditionalRandomField(nn.Module):

        # self.reset_parameter()
        initial_parameter(self, initial_method)
-

    def reset_parameter(self):
        nn.init.xavier_normal_(self.trans_m)
        if self.include_start_end_trans:

@@ -169,9 +168,9 @@ class ConditionalRandomField(nn.Module):
            nn.init.normal_(self.end_scores)

    def _normalizer_likelihood(self, logits, mask):
-        """
-        Computes the (batch_size,) denominator term for the log-likelihood, which is the
+        """Computes the (batch_size,) denominator term for the log-likelihood, which is the
        sum of the likelihoods across all possible state sequences.
+
        :param logits:FloatTensor, max_len x batch_size x num_tags
        :param mask:ByteTensor, max_len x batch_size
        :return:FloatTensor, batch_size
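
The _normalizer_likelihood docstring above describes the forward-algorithm denominator of a linear-chain CRF. A small standalone sketch of that computation in plain PyTorch (not the class's own code), using the same [max_len, batch_size, num_tags] layout and ignoring start/end transition scores:

import torch

def crf_log_partition(logits, trans, mask):
    """logits: [max_len, batch, n_tags], trans: [n_tags, n_tags], mask: [max_len, batch] of 0/1."""
    max_len, batch, n_tags = logits.size()
    alpha = logits[0]                                   # [batch, n_tags]
    for t in range(1, max_len):
        # scores[b, i, j] = alpha[b, i] + trans[i, j] + logits[t, b, j]
        scores = alpha.unsqueeze(2) + trans.unsqueeze(0) + logits[t].unsqueeze(1)
        new_alpha = torch.logsumexp(scores, dim=1)      # sum over previous tag i
        m = mask[t].float().unsqueeze(1)                # keep the old alpha on padded steps
        alpha = new_alpha * m + alpha * (1 - m)
    return torch.logsumexp(alpha, dim=1)                # [batch]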


fastNLP/modules/decoder/MLP.py (+11, -9)

@@ -1,21 +1,23 @@
import torch
import torch.nn as nn

from fastNLP.modules.utils import initial_parameter


class MLP(nn.Module):
-    def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0):
-        """Multilayer Perceptrons as a decoder
+    """Multilayer Perceptrons as a decoder

-        :param size_layer: list of int, define the size of MLP layers.
-        :param activation: str or function, the activation function for hidden layers.
-        :param initial_method: str, the name of init method.
-        :param dropout: float, the probability of dropout.
+    :param list size_layer: list of int, define the size of MLP layers.
+    :param str activation: str or function, the activation function for hidden layers.
+    :param str initial_method: the name of initialization method.
+    :param float dropout: the probability of dropout.

-        .. note::
-            There is no activation function applying on output layer.
+    .. note::
+        There is no activation function applying on output layer.

-        """
+    """
+
+    def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0):
        super(MLP, self).__init__()
        self.hiddens = nn.ModuleList()
        self.output = None
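
A usage sketch for the MLP decoder, following the constructor documented above; the forward call on a [batch, input_size] tensor is an assumption rather than something shown in this hunk.

import torch
from fastNLP.modules.decoder.MLP import MLP

# Three Linear layers, 100 -> 64 -> 32 -> 10, ReLU on the hidden layers,
# no activation on the output layer (see the .. note:: above).
mlp = MLP([100, 64, 32, 10], activation='relu', dropout=0.1)

x = torch.rand(8, 100)   # assumed [batch, input_size] input
y = mlp(x)               # expected shape [8, 10]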


fastNLP/modules/dropout.py (+2, -2)

@@ -2,8 +2,8 @@ import torch


class TimestepDropout(torch.nn.Dropout):
-    """This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single
-    dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step.
+    """This module accepts a ``[batch_size, num_timesteps, embedding_dim)]`` and use a single
+    dropout mask of shape ``(batch_size, embedding_dim)`` to apply on every time step.
    """

    def forward(self, x):
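
A small sketch of what the docstring above promises: TimestepDropout samples one dropout mask of shape (batch_size, embedding_dim) and reuses it at every time step, so a zeroed feature stays zeroed across the whole sequence. The check at the end reflects that documented behaviour, not code shown in this hunk.

import torch
from fastNLP.modules.dropout import TimestepDropout

drop = TimestepDropout(p=0.5)
drop.train()

x = torch.ones(2, 6, 8)   # [batch_size, num_timesteps, embedding_dim]
y = drop(x)

# Every time step of a sample should share the same mask, so the zero pattern
# of y[0, 0] is expected to match y[0, t] for every t.
print(torch.equal(y[0, 0] == 0, y[0, 5] == 0))   # expected: True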


fastNLP/modules/encoder/char_embedding.py (+16, -18)

@@ -1,5 +1,4 @@
import torch
-import torch.nn.functional as F
from torch import nn

from fastNLP.modules.utils import initial_parameter

@@ -7,17 +6,17 @@ from fastNLP.modules.utils import initial_parameter

# from torch.nn.init import xavier_uniform
class ConvCharEmbedding(nn.Module):
+    """Character-level Embedding with CNN.
+
+    :param int char_emb_size: the size of character level embedding. Default: 50
+        say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
+    :param tuple feature_maps: tuple of int. The length of the tuple is the number of convolution operations
+        over characters. The i-th integer is the number of filters (dim of out channels) for the i-th
+        convolution.
+    :param tuple kernels: tuple of int. The width of each kernel.
+    """

    def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5), initial_method=None):
-        """
-        Character Level Word Embedding
-        :param char_emb_size: the size of character level embedding. Default: 50
-            say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
-        :param feature_maps: tuple of int. The length of the tuple is the number of convolution operations
-            over characters. The i-th integer is the number of filters (dim of out channels) for the i-th
-            convolution.
-        :param kernels: tuple of int. The width of each kernel.
-        """
        super(ConvCharEmbedding, self).__init__()
        self.convs = nn.ModuleList([
            nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4))

@@ -27,8 +26,8 @@ class ConvCharEmbedding(nn.Module):

    def forward(self, x):
        """
-        :param x: [batch_size * sent_length, word_length, char_emb_size]
-        :return: [batch_size * sent_length, sum(feature_maps), 1]
+        :param x: ``[batch_size * sent_length, word_length, char_emb_size]``
+        :return: feature map of shape [batch_size * sent_length, sum(feature_maps), 1]
        """
        x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2))
        # [batch_size*sent_length, channel, width, height]

@@ -51,13 +50,12 @@ class ConvCharEmbedding(nn.Module):


class LSTMCharEmbedding(nn.Module):
-    """
-    Character Level Word Embedding with LSTM with a single layer.
-    :param char_emb_size: int, the size of character level embedding. Default: 50
+    """Character-level Embedding with LSTM.
+    :param int char_emb_size: the size of character level embedding. Default: 50
        say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
-    :param hidden_size: int, the number of hidden units. Default: equal to char_emb_size.
+    :param int hidden_size: the number of hidden units. Default: equal to char_emb_size.
    """

    def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None):
        super(LSTMCharEmbedding, self).__init__()
        self.hidden_size = char_emb_size if hidden_size is None else hidden_size

@@ -71,7 +69,7 @@ class LSTMCharEmbedding(nn.Module):

    def forward(self, x):
        """
-        :param x:[ n_batch*n_word, word_length, char_emb_size]
+        :param x: ``[ n_batch*n_word, word_length, char_emb_size]``
        :return: [ n_batch*n_word, char_emb_size]
        """
        batch_size = x.shape[0]
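
A usage sketch for the two character encoders, following the constructor arguments and input/output shapes documented in the hunks above; the batch and word-length numbers are made up.

import torch
from fastNLP.modules.encoder.char_embedding import ConvCharEmbedding, LSTMCharEmbedding

n_words, word_len, char_emb = 32, 10, 50
chars = torch.rand(n_words, word_len, char_emb)   # [batch_size * sent_length, word_length, char_emb_size]

conv_emb = ConvCharEmbedding(char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5))
conv_out = conv_emb(chars)    # documented as [batch_size * sent_length, sum(feature_maps) = 100, 1]

lstm_emb = LSTMCharEmbedding(char_emb_size=50)    # hidden_size defaults to char_emb_size
lstm_out = lstm_emb(chars)    # documented as [n_batch * n_word, char_emb_size]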


fastNLP/modules/encoder/conv.py (+17, -7)

@@ -3,20 +3,30 @@

import torch
import torch.nn as nn
-from torch.nn.init import xavier_uniform_
-# import torch.nn.functional as F

from fastNLP.modules.utils import initial_parameter

+# import torch.nn.functional as F


class Conv(nn.Module):
-    """
-    Basic 1-d convolution module.
-    initialize with xavier_uniform
-    """
+    """Basic 1-d convolution module, initialized with xavier_uniform.
+
+    :param int in_channels:
+    :param int out_channels:
+    :param tuple kernel_size:
+    :param int stride:
+    :param int padding:
+    :param int dilation:
+    :param int groups:
+    :param bool bias:
+    :param str activation:
+    :param str initial_method:
+    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1,
-                 groups=1, bias=True, activation='relu',initial_method = None ):
+                 groups=1, bias=True, activation='relu', initial_method=None):
        super(Conv, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
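
The Conv module above wraps nn.Conv1d. A plain-PyTorch sketch of the underlying 1-d convolution it configures; the channel-first layout is nn.Conv1d's own convention, and how the fastNLP wrapper arranges its input is not shown in this hunk.

import torch
import torch.nn as nn

conv = nn.Conv1d(in_channels=50, out_channels=64, kernel_size=3, padding=1)

x = torch.rand(8, 50, 20)   # [batch, in_channels, seq_len]
y = conv(x)                 # [batch, out_channels, seq_len] since padding=1 with kernel_size=3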


fastNLP/modules/encoder/conv_maxpool.py (+15, -5)

@@ -4,17 +4,27 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch.nn.init import xavier_uniform_
from fastNLP.modules.utils import initial_parameter


class ConvMaxpool(nn.Module):
-    """
-    Convolution and max-pooling module with multiple kernel sizes.
-    """
+    """Convolution and max-pooling module with multiple kernel sizes.
+
+    :param int in_channels:
+    :param int out_channels:
+    :param tuple kernel_sizes:
+    :param int stride:
+    :param int padding:
+    :param int dilation:
+    :param int groups:
+    :param bool bias:
+    :param str activation:
+    :param str initial_method:
+    """
    def __init__(self, in_channels, out_channels, kernel_sizes,
                 stride=1, padding=0, dilation=1,
-                 groups=1, bias=True, activation='relu',initial_method = None ):
+                 groups=1, bias=True, activation="relu", initial_method=None):
        super(ConvMaxpool, self).__init__()

        # convolution
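
ConvMaxpool combines several kernel widths with max-pooling over time. A plain-PyTorch sketch of that pattern; the wrapper's own input layout and forward are not shown in the hunk, so this illustrates the technique rather than the exact fastNLP API.

import torch
import torch.nn as nn
import torch.nn.functional as F

in_channels, out_channels, kernel_sizes = 50, 30, (3, 4, 5)
convs = nn.ModuleList([nn.Conv1d(in_channels, out_channels, k) for k in kernel_sizes])

x = torch.rand(8, in_channels, 40)            # [batch, channels, seq_len]
feats = [F.relu(conv(x)) for conv in convs]   # each [batch, out_channels, seq_len - k + 1]
pooled = [f.max(dim=2)[0] for f in feats]     # max over time -> [batch, out_channels]
sentence_repr = torch.cat(pooled, dim=1)      # [batch, out_channels * len(kernel_sizes)]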


fastNLP/modules/encoder/embedding.py (+6, -9)

@@ -2,16 +2,13 @@ import torch.nn as nn


class Embedding(nn.Module):
-    """
-    A simple lookup table
-    Args:
-        nums : the size of the lookup table
-        dims : the size of each vector
-        padding_idx : pads the tensor with zeros whenever it encounters this index
-        sparse : If True, gradient matrix will be a sparse tensor. In this case,
-            only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
-    """
+    """A simple lookup table.
+    :param int nums: the size of the lookup table
+    :param int dims: the size of each vector
+    :param int padding_idx: pads the tensor with zeros whenever it encounters this index
+    :param bool sparse: If True, gradient matrix will be a sparse tensor. In this case, only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
+    """
    def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
        super(Embedding, self).__init__()
        self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
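
A usage sketch for the Embedding wrapper documented above; calling it on token indices is assumed to behave like nn.Embedding followed by dropout, which is an assumption beyond what this hunk shows.

import torch
from fastNLP.modules.encoder.embedding import Embedding

embed = Embedding(nums=10000, dims=300, padding_idx=0, dropout=0.1)

tokens = torch.LongTensor([[3, 17, 256, 0, 0],
                           [9, 42, 7, 88, 0]])   # 0 = padding index
vectors = embed(tokens)                          # assumed [batch, seq_len, dims]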


fastNLP/modules/encoder/linear.py (+5, -8)

@@ -5,15 +5,12 @@ from fastNLP.modules.utils import initial_parameter

class Linear(nn.Module):
    """
-    Linear module
-    Args:
-        input_size : input size
-        hidden_size : hidden size
-        num_layers : number of hidden layers
-        dropout : dropout rate
-        bidirectional : If True, becomes a bidirectional RNN
-    """
+    :param int input_size: input size
+    :param int output_size: output size
+    :param bool bias:
+    :param str initial_method:
+    """
    def __init__(self, input_size, output_size, bias=True, initial_method=None):
        super(Linear, self).__init__()
        self.linear = nn.Linear(input_size, output_size, bias)


fastNLP/modules/encoder/lstm.py (+9, -7)

@@ -6,14 +6,16 @@ from fastNLP.modules.utils import initial_parameter
class LSTM(nn.Module):
    """Long Short Term Memory

-    Args:
-        input_size : input size
-        hidden_size : hidden size
-        num_layers : number of hidden layers. Default: 1
-        dropout : dropout rate. Default: 0.5
-        bidirectional : If True, becomes a bidirectional RNN. Default: False.
+    :param int input_size:
+    :param int hidden_size:
+    :param int num_layers:
+    :param float dropout:
+    :param bool batch_first:
+    :param bool bidirectional:
+    :param bool bias:
+    :param str initial_method:
+    :param bool get_hidden:
    """

    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
                 bidirectional=False, bias=True, initial_method=None, get_hidden=False):
        super(LSTM, self).__init__()
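
A usage sketch for the LSTM encoder, following the constructor defaults shown above (batch_first=True); the return value noted in the comment is an assumption, since the forward is not part of this hunk and get_hidden may change what is returned.

import torch
from fastNLP.modules.encoder.lstm import LSTM

encoder = LSTM(input_size=300, hidden_size=100, num_layers=1,
               bidirectional=True, batch_first=True, get_hidden=False)

x = torch.rand(8, 20, 300)   # [batch, seq_len, input_size] because batch_first=True
out = encoder(x)             # assumed [batch, seq_len, 2 * hidden_size] for a BiLSTM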


fastNLP/modules/encoder/masked_rnn.py (+59, -57)

@@ -5,6 +5,8 @@ import torch.nn as nn
import torch.nn.functional as F

from fastNLP.modules.utils import initial_parameter

+
+
def MaskedRecurrent(reverse=False):
    def forward(input, hidden, cell, mask, train=True, dropout=0):
        """

@@ -254,16 +256,16 @@ class MaskedRNNBase(nn.Module):
        return output, hidden

    def step(self, input, hx=None, mask=None):
-        '''
-        execute one step forward (only for one-directional RNN).
-        Args:
-            input (batch, input_size): input tensor of this step.
-            hx (num_layers, batch, hidden_size): the hidden state of last step.
-            mask (batch): the mask tensor of this step.
-        Returns:
-            output (batch, hidden_size): tensor containing the output of this step from the last layer of RNN.
-            hn (num_layers, batch, hidden_size): tensor containing the hidden state of this step
-        '''
+        """Execute one step forward (only for one-directional RNN).
+
+        :param Tensor input: input tensor of this step. (batch, input_size)
+        :param Tensor hx: the hidden state of last step. (num_layers, batch, hidden_size)
+        :param Tensor mask: the mask tensor of this step. (batch, )
+        :returns:
+            **output** (batch, hidden_size), tensor containing the output of this step from the last layer of RNN.
+            **hn** (num_layers, batch, hidden_size), tensor containing the hidden state of this step
+        """
        assert not self.bidirectional, "step only cannot be applied to bidirectional RNN."  # aha, typo!
        batch_size = input.size(0)
        lstm = self.Cell is nn.LSTMCell

@@ -285,25 +287,23 @@ class MaskedRNN(MaskedRNNBase):
    r"""Applies a multi-layer Elman RNN with costomized non-linearity to an
    input sequence.
    For each element in the input sequence, each layer computes the following
-    function:
-    .. math::
-        h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})
+    function. :math:`h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})`
+
    where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is
    the hidden state of the previous layer at time `t` or :math:`input_t`
    for the first layer. If nonlinearity='relu', then `ReLU` is used instead
    of `tanh`.
-    Args:
-        input_size: The number of expected features in the input x
-        hidden_size: The number of features in the hidden state h
-        num_layers: Number of recurrent layers.
-        nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
-        bias: If False, then the layer does not use bias weights b_ih and b_hh.
-            Default: True
-        batch_first: If True, then the input and output tensors are provided
-            as (batch, seq, feature)
-        dropout: If non-zero, introduces a dropout layer on the outputs of each
-            RNN layer except the last layer
-        bidirectional: If True, becomes a bidirectional RNN. Default: False
+
+    :param int input_size: The number of expected features in the input x
+    :param int hidden_size: The number of features in the hidden state h
+    :param int num_layers: Number of recurrent layers.
+    :param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
+    :param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
+    :param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
+    :param float dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
+    :param bool bidirectional: If True, becomes a bidirectional RNN. Default: False
+
    Inputs: input, mask, h_0
        - **input** (seq_len, batch, input_size): tensor containing the features
          of the input sequence.

@@ -327,32 +327,33 @@ class MaskedLSTM(MaskedRNNBase):
    r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
    sequence.
    For each element in the input sequence, each layer computes the following
-    function:
+    function.
+
    .. math::
-    \begin{array}{ll}
-    i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
-    f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
-    g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\
-    o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
-    c_t = f_t * c_{(t-1)} + i_t * g_t \\
-    h_t = o_t * \tanh(c_t)
-    \end{array}
+
+        \begin{array}{ll}
+        i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
+        f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
+        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\
+        o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
+        c_t = f_t * c_{(t-1)} + i_t * g_t \\
+        h_t = o_t * \tanh(c_t)
+        \end{array}
+
    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
    state at time `t`, :math:`x_t` is the hidden state of the previous layer at
    time `t` or :math:`input_t` for the first layer, and :math:`i_t`,
    :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell,
    and out gates, respectively.
-    Args:
-        input_size: The number of expected features in the input x
-        hidden_size: The number of features in the hidden state h
-        num_layers: Number of recurrent layers.
-        bias: If False, then the layer does not use bias weights b_ih and b_hh.
-            Default: True
-        batch_first: If True, then the input and output tensors are provided
-            as (batch, seq, feature)
-        dropout: If non-zero, introduces a dropout layer on the outputs of each
-            RNN layer except the last layer
-        bidirectional: If True, becomes a bidirectional RNN. Default: False
+
+    :param int input_size: The number of expected features in the input x
+    :param int hidden_size: The number of features in the hidden state h
+    :param int num_layers: Number of recurrent layers.
+    :param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
+    :param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
+    :param bool dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
+    :param bool bidirectional: If True, becomes a bidirectional RNN. Default: False
+
    Inputs: input, mask, (h_0, c_0)
        - **input** (seq_len, batch, input_size): tensor containing the features
          of the input sequence.

@@ -380,29 +381,30 @@ class MaskedGRU(MaskedRNNBase):
    r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
    For each element in the input sequence, each layer computes the following
    function:
+
    .. math::
+
        \begin{array}{ll}
        r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
        z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
        n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
        h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\
        \end{array}
+
    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden
    state of the previous layer at time `t` or :math:`input_t` for the first
    layer, and :math:`r_t`, :math:`z_t`, :math:`n_t` are the reset, input,
    and new gates, respectively.
-    Args:
-        input_size: The number of expected features in the input x
-        hidden_size: The number of features in the hidden state h
-        num_layers: Number of recurrent layers.
-        nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
-        bias: If False, then the layer does not use bias weights b_ih and b_hh.
-            Default: True
-        batch_first: If True, then the input and output tensors are provided
-            as (batch, seq, feature)
-        dropout: If non-zero, introduces a dropout layer on the outputs of each
-            RNN layer except the last layer
-        bidirectional: If True, becomes a bidirectional RNN. Default: False
+
+    :param int input_size: The number of expected features in the input x
+    :param int hidden_size: The number of features in the hidden state h
+    :param int num_layers: Number of recurrent layers.
+    :param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
+    :param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
+    :param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
+    :param bool dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
+    :param bool bidirectional: If True, becomes a bidirectional RNN. Default: False
+
    Inputs: input, mask, h_0
        - **input** (seq_len, batch, input_size): tensor containing the features
          of the input sequence.
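
A usage sketch for MaskedLSTM, following the docstring above: inputs are (seq_len, batch, input_size) plus a per-step mask. The constructor keywords are taken from the documented parameter list, the sequence lengths are hypothetical, and the exact return unpacking is an assumption based on the "return output, hidden" line in MaskedRNNBase.

import torch
from fastNLP.modules.encoder.masked_rnn import MaskedLSTM

seq_len, batch, input_size, hidden_size = 15, 4, 50, 64
x = torch.rand(seq_len, batch, input_size)        # (seq_len, batch, input_size)

lengths = torch.LongTensor([15, 12, 9, 5])        # hypothetical lengths per sample
steps = torch.arange(seq_len).unsqueeze(1)        # (seq_len, 1)
mask = (steps < lengths.unsqueeze(0)).float()     # (seq_len, batch), 1 for real steps

rnn = MaskedLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1)
output, hidden = rnn(x, mask)                     # hidden is (h_n, c_n) for the LSTM variant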


fastNLP/modules/encoder/transformer.py (+3, -6)

@@ -1,10 +1,9 @@
-import torch
from torch import nn
-import torch.nn.functional as F

from ..aggregator.attention import MultiHeadAtte
from ..other_modules import LayerNormalization


class TransformerEncoder(nn.Module):
    class SubLayer(nn.Module):
        def __init__(self, input_size, output_size, key_size, value_size, num_atte):

@@ -12,8 +11,8 @@ class TransformerEncoder(nn.Module):
            self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
            self.norm1 = LayerNormalization(output_size)
            self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
                                     nn.ReLU(),
                                     nn.Linear(output_size, output_size))
            self.norm2 = LayerNormalization(output_size)

        def forward(self, input, seq_mask):

@@ -28,5 +27,3 @@ class TransformerEncoder(nn.Module):

    def forward(self, x, seq_mask=None):
        return self.layers(x, seq_mask)
-
-



fastNLP/modules/encoder/variational_rnn.py (+12, -6)

@@ -1,5 +1,3 @@
-import math
-
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence

@@ -9,15 +7,17 @@ from fastNLP.modules.utils import initial_parameter
try:
    from torch import flip
except ImportError:
    def flip(x, dims):
        indices = [slice(None)] * x.dim()
        for dim in dims:
            indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
        return x[tuple(indices)]
+

class VarRnnCellWrapper(nn.Module):
    """Wrapper for normal RNN Cells, make it support variational dropout
    """
+
    def __init__(self, cell, hidden_size, input_p, hidden_p):
        super(VarRnnCellWrapper, self).__init__()
        self.cell = cell

@@ -32,9 +32,9 @@ class VarRnnCellWrapper(nn.Module):
        for other RNN, h_0, [batch_size, hidden_size]
        :param mask_x: [batch_size, input_size] dropout mask for input
        :param mask_h: [batch_size, hidden_size] dropout mask for hidden
-        :return output: [seq_len, bacth_size, hidden_size]
-                hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
-                        for other RNN, h_n, [batch_size, hidden_size]
+        :return: (output, hidden)
+            **output**: [seq_len, bacth_size, hidden_size].
+            **hidden**: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]; For other RNN, h_n, [batch_size, hidden_size].
        """
        is_lstm = isinstance(hidden, tuple)
        input = input * mask_x.unsqueeze(0) if mask_x is not None else input

@@ -56,6 +56,7 @@ class VarRNNBase(nn.Module):
    refer to `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016)
    https://arxiv.org/abs/1512.05287`.
    """
+
    def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
                 bias=True, batch_first=False,
                 input_dropout=0, hidden_dropout=0, bidirectional=False):

@@ -138,17 +139,22 @@ class VarRNNBase(nn.Module):
class VarLSTM(VarRNNBase):
    """Variational Dropout LSTM.
    """
+
    def __init__(self, *args, **kwargs):
        super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)


class VarRNN(VarRNNBase):
    """Variational Dropout RNN.
    """
+
    def __init__(self, *args, **kwargs):
        super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs)


class VarGRU(VarRNNBase):
    """Variational Dropout GRU.
    """
+
    def __init__(self, *args, **kwargs):
        super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
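
A usage sketch for VarLSTM, using the constructor parameters of VarRNNBase shown above; input_dropout and hidden_dropout are the variational (per-sequence) dropout rates, and the nn.LSTM-like return layout noted in the comment is an assumption.

import torch
from fastNLP.modules.encoder.variational_rnn import VarLSTM

rnn = VarLSTM(input_size=300, hidden_size=200, num_layers=2,
              batch_first=True, input_dropout=0.3, hidden_dropout=0.3,
              bidirectional=True)

x = torch.rand(8, 25, 300)   # [batch, seq_len, input_size] since batch_first=True
output, hidden = rnn(x)      # assumed nn.LSTM-like return: output plus (h_n, c_n)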

fastNLP/modules/other_modules.py (+23, -34)

@@ -29,8 +29,11 @@ class GroupNorm(nn.Module):


class LayerNormalization(nn.Module):
-    """ Layer normalization module """
+    """
+    :param int layer_size:
+    :param float eps: default=1e-3
+    """
    def __init__(self, layer_size, eps=1e-3):
        super(LayerNormalization, self).__init__()

@@ -52,12 +55,11 @@ class LayerNormalization(nn.Module):
class BiLinear(nn.Module):
    def __init__(self, n_left, n_right, n_out, bias=True):
        """
-        Args:
-            n_left: size of left input
-            n_right: size of right input
-            n_out: size of output
-            bias: If set to False, the layer will not learn an additive bias.
-                Default: True
+
+        :param int n_left: size of left input
+        :param int n_right: size of right input
+        :param int n_out: size of output
+        :param bool bias: If set to False, the layer will not learn an additive bias. Default: True
        """
        super(BiLinear, self).__init__()
        self.n_left = n_left

@@ -83,12 +85,9 @@ class BiLinear(nn.Module):

    def forward(self, input_left, input_right):
        """
-        Args:
-            input_left: Tensor
-                the left input tensor with shape = [batch1, batch2, ..., left_features]
-            input_right: Tensor
-                the right input tensor with shape = [batch1, batch2, ..., right_features]
-        Returns:
+
+        :param Tensor input_left: the left input tensor with shape = [batch1, batch2, ..., left_features]
+        :param Tensor input_right: the right input tensor with shape = [batch1, batch2, ..., right_features]
        """
        left_size = input_left.size()
        right_size = input_right.size()

@@ -118,16 +117,11 @@ class BiLinear(nn.Module):
class BiAffine(nn.Module):
    def __init__(self, n_enc, n_dec, n_labels, biaffine=True, **kwargs):
        """
-        Args:
-            n_enc: int
-                the dimension of the encoder input.
-            n_dec: int
-                the dimension of the decoder input.
-            n_labels: int
-                the number of labels of the crf layer
-            biaffine: bool
-                if apply bi-affine parameter.
-            **kwargs:
+
+        :param int n_enc: the dimension of the encoder input.
+        :param int n_dec: the dimension of the decoder input.
+        :param int n_labels: the number of labels of the crf layer
+        :param bool biaffine: if apply bi-affine parameter.
        """
        super(BiAffine, self).__init__()
        self.n_enc = n_enc

@@ -154,17 +148,12 @@ class BiAffine(nn.Module):

    def forward(self, input_d, input_e, mask_d=None, mask_e=None):
        """
-        Args:
-            input_d: Tensor
-                the decoder input tensor with shape = [batch, length_decoder, input_size]
-            input_e: Tensor
-                the child input tensor with shape = [batch, length_encoder, input_size]
-            mask_d: Tensor or None
-                the mask tensor for decoder with shape = [batch, length_decoder]
-            mask_e: Tensor or None
-                the mask tensor for encoder with shape = [batch, length_encoder]
-            Returns: Tensor
-                the energy tensor with shape = [batch, num_label, length, length]
+
+        :param Tensor input_d: the decoder input tensor with shape = [batch, length_decoder, input_size]
+        :param Tensor input_e: the child input tensor with shape = [batch, length_encoder, input_size]
+        :param mask_d: Tensor or None, the mask tensor for decoder with shape = [batch, length_decoder]
+        :param mask_e: Tensor or None, the mask tensor for encoder with shape = [batch, length_encoder]
+        :returns: Tensor, the energy tensor with shape = [batch, num_label, length, length]
        """
        assert input_d.size(0) == input_e.size(0), 'batch sizes of encoder and decoder are requires to be equal.'
        batch, length_decoder, _ = input_d.size()
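
A usage sketch for BiAffine, following the constructor and forward shapes documented in its docstrings above; a biaffine scorer of this kind is typically used to score head/dependent pairs in a parser, and the concrete sizes below are made up.

import torch
from fastNLP.modules.other_modules import BiAffine

batch, length, dim, n_labels = 4, 10, 128, 3
input_d = torch.rand(batch, length, dim)   # decoder states [batch, length_decoder, input_size]
input_e = torch.rand(batch, length, dim)   # encoder states [batch, length_encoder, input_size]

scorer = BiAffine(n_enc=dim, n_dec=dim, n_labels=n_labels, biaffine=True)
energy = scorer(input_d, input_e)          # documented as [batch, num_label, length, length]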


fastNLP/modules/utils.py (+2, -2)

@@ -15,7 +15,7 @@ def initial_parameter(net, initial_method=None):
    """A method used to initialize the weights of PyTorch models.

    :param net: a PyTorch model
-    :param initial_method: str, one of the following initializations
+    :param str initial_method: one of the following initializations.

        - xavier_uniform
        - xavier_normal (default)

@@ -79,7 +79,7 @@ def seq_mask(seq_len, max_len):

    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
    :param max_len: int, the maximum sequence length in a batch.
-    :return mask: torch.LongTensor, [batch_size, max_len]
+    :return: mask, torch.LongTensor, [batch_size, max_len]

    """
    if not isinstance(seq_len, torch.Tensor):
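
A usage sketch for the two helpers documented above: seq_mask turns a list of lengths into a [batch_size, max_len] LongTensor mask, and initial_parameter applies one of the listed initializations to every parameter of a module.

import torch
import torch.nn as nn
from fastNLP.modules.utils import seq_mask, initial_parameter

mask = seq_mask([5, 3, 1], max_len=5)   # torch.LongTensor of shape [3, 5]
print(mask)

net = nn.Linear(10, 4)
initial_parameter(net, initial_method="xavier_uniform")   # xavier_normal is the default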

