
* fix doc for modules/

tags/v0.3.0
FengZiYjun, 6 years ago
commit 1c8bca5db9
18 changed files with 226 additions and 212 deletions
  1. fastNLP/io/dataset_loader.py (+4, -4)
  2. fastNLP/io/embed_loader.py (+1, -1)
  3. fastNLP/modules/aggregator/attention.py (+5, -5)
  4. fastNLP/modules/aggregator/self_attention.py (+10, -9)
  5. fastNLP/modules/decoder/CRF.py (+26, -23)
  6. fastNLP/modules/decoder/MLP.py (+11, -9)
  7. fastNLP/modules/dropout.py (+2, -2)
  8. fastNLP/modules/encoder/char_embedding.py (+16, -18)
  9. fastNLP/modules/encoder/conv.py (+17, -7)
  10. fastNLP/modules/encoder/conv_maxpool.py (+15, -5)
  11. fastNLP/modules/encoder/embedding.py (+6, -9)
  12. fastNLP/modules/encoder/linear.py (+5, -8)
  13. fastNLP/modules/encoder/lstm.py (+9, -7)
  14. fastNLP/modules/encoder/masked_rnn.py (+59, -57)
  15. fastNLP/modules/encoder/transformer.py (+3, -6)
  16. fastNLP/modules/encoder/variational_rnn.py (+12, -6)
  17. fastNLP/modules/other_modules.py (+23, -34)
  18. fastNLP/modules/utils.py (+2, -2)

fastNLP/io/dataset_loader.py (+4, -4)

@@ -254,7 +254,7 @@ class TokenizeDataSetLoader(DataSetLoader):


class ClassDataSetLoader(DataSetLoader):
"""Loader for classification data sets"""
"""Loader for a dummy classification data set"""

def __init__(self):
super(ClassDataSetLoader, self).__init__()
@@ -270,8 +270,8 @@ class ClassDataSetLoader(DataSetLoader):
def parse(lines):
"""

:param lines: lines from dataset
:return: list(list(list())): the three level of lists are words, sentence, and dataset
:param list lines: lines from dataset
:return: a 3-D list, indicating words, sentence, and dataset respectively.
"""
dataset = list()
for line in lines:
@@ -304,7 +304,7 @@ class ConllLoader(DataSetLoader):
@staticmethod
def parse(lines):
"""
:param list lines:a list containing all lines in a conll file.
:param list lines: a list containing all lines in a conll file.
:return: a 3D list
"""
sentences = list()
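
To make the "3-D list" wording concrete, here is one plausible shape of what a CoNLL-style parse produces: dataset, then sentences, then per-token field lists. The sample rows below are invented for illustration and are not taken from the commit.

# Hypothetical 3-D list: outer level = dataset, middle = sentences,
# innermost = one [word, tag] field list per token.
dataset = [
    [["EU", "B-ORG"], ["rejects", "O"], ["German", "B-MISC"]],  # sentence 1
    [["Peter", "B-PER"], ["Blackburn", "I-PER"]],               # sentence 2
]
assert isinstance(dataset, list)          # dataset level
assert isinstance(dataset[0], list)       # sentence level
assert isinstance(dataset[0][0], list)    # word/field level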


fastNLP/io/embed_loader.py (+1, -1)

@@ -88,7 +88,7 @@ class EmbedLoader(BaseLoader):
:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding.
:param str emb_file: the pre-trained embedding file path.
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
:return embedding_matrix: numpy.ndarray
:return: the embedding matrix, numpy.ndarray

"""
if vocab is None:
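
As a rough illustration of what a loader with the parameters documented above (emb_dim, emb_file, vocab) has to do, here is a standalone sketch. It is not EmbedLoader's actual code, and `vocab` is treated as a plain word-to-index dict, whereas fastNLP passes a Vocabulary object.

import numpy as np

def build_embedding_matrix(emb_dim, emb_file, vocab):
    """Fill a [len(vocab), emb_dim] matrix from a whitespace-separated
    pre-trained embedding file; words missing from the file keep a small
    random vector."""
    matrix = np.random.uniform(-0.1, 0.1, (len(vocab), emb_dim)).astype("float32")
    with open(emb_file, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip().split()
            if len(parts) != emb_dim + 1:
                continue  # skip header or malformed lines
            word, values = parts[0], parts[1:]
            if word in vocab:
                matrix[vocab[word]] = np.asarray(values, dtype="float32")
    return matrix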


fastNLP/modules/aggregator/attention.py (+5, -5)

@@ -1,11 +1,12 @@
import math

import torch
from torch import nn
import math
from fastNLP.modules.utils import mask_softmax


class Attention(torch.nn.Module):

def __init__(self, normalize=False):
super(Attention, self).__init__()
self.normalize = normalize
@@ -19,9 +20,9 @@ class Attention(torch.nn.Module):
def _atten_forward(self, query, memory):
raise NotImplementedError


class DotAtte(nn.Module):
def __init__(self, key_size, value_size):
# TODO never test
super(DotAtte, self).__init__()
self.key_size = key_size
self.value_size = value_size
@@ -41,10 +42,9 @@ class DotAtte(nn.Module):
output = nn.functional.softmax(output, dim=2)
return torch.matmul(output, V)


class MultiHeadAtte(nn.Module):
def __init__(self, input_size, output_size, key_size, value_size, num_atte):
raise NotImplementedError
# TODO never test
super(MultiHeadAtte, self).__init__()
self.in_linear = nn.ModuleList()
for i in range(num_atte * 3):
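
DotAtte's forward (softmax over dim=2 followed by a matmul with V) is standard dot-product attention. A minimal standalone sketch follows; the helper name and the 1/sqrt(key_size) scaling are the usual convention and are assumptions here, not read off the diff.

import torch
import torch.nn.functional as F

def scaled_dot_attention(Q, K, V, key_size):
    # [batch, len_q, key_size] x [batch, key_size, len_k] -> [batch, len_q, len_k]
    scores = torch.matmul(Q, K.transpose(1, 2)) / key_size ** 0.5
    weights = F.softmax(scores, dim=2)          # normalize over the keys
    return torch.matmul(weights, V)             # [batch, len_q, value_size]

Q = torch.rand(2, 5, 8)   # batch=2, 5 queries, key_size=8
K = torch.rand(2, 7, 8)   # 7 keys
V = torch.rand(2, 7, 16)  # value_size=16
print(scaled_dot_attention(Q, K, V, key_size=8).shape)  # torch.Size([2, 5, 16])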


fastNLP/modules/aggregator/self_attention.py (+10, -9)

@@ -7,13 +7,14 @@ from fastNLP.modules.utils import initial_parameter


class SelfAttention(nn.Module):
"""
Self Attention Module.
"""Self Attention Module.

Args:
input_size: int, the size for the input vector
dim: int, the width of weight matrix.
num_vec: int, the number of encoded vectors
:param int input_size:
:param int attention_unit:
:param int attention_hops:
:param float drop:
:param str initial_method:
:param bool use_cuda:
"""

def __init__(self, input_size, attention_unit=350, attention_hops=10, drop=0.5, initial_method=None,
@@ -48,7 +49,7 @@ class SelfAttention(nn.Module):
def forward(self, input, input_origin):
"""
:param input: the matrix to do attention. [baz, senLen, h_dim]
:param inp: then token index include pad token( 0 ) [baz , senLen]
:param inp: then token index include pad token( 0 ) [baz , senLen]
:return output1: the input matrix after attention operation [baz, multi-head , h_dim]
:return output2: the attention penalty term, a scalar [1]
"""
@@ -59,8 +60,8 @@ class SelfAttention(nn.Module):
input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len]

y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit]
attention = self.ws2(y1).transpose(1,
2).contiguous() # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]
attention = self.ws2(y1).transpose(1, 2).contiguous()
# [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]

attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token.
attention = F.softmax(attention, 2) # [baz ,hop, len]
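
The "attention penalty term" returned as output2 is, in the structured self-attention formulation this module follows (Lin et al., 2017), the Frobenius-norm penalty on A·Aᵀ − I. A small sketch of that computation is below; the exact normalization fastNLP applies may differ.

import torch

def attention_penalty(A):
    """|| A A^T - I ||_F ** 2 averaged over the batch.
    A: [batch, hops, seq_len] attention weights."""
    batch, hops, _ = A.size()
    eye = torch.eye(hops, device=A.device).unsqueeze(0)   # [1, hops, hops]
    diff = torch.bmm(A, A.transpose(1, 2)) - eye           # [batch, hops, hops]
    return (diff ** 2).sum() / batch

A = torch.softmax(torch.rand(4, 10, 20), dim=2)  # 4 sentences, 10 hops, 20 tokens
print(attention_penalty(A))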


fastNLP/modules/decoder/CRF.py (+26, -23)

@@ -21,11 +21,13 @@ def seq_len_to_byte_mask(seq_lens):


class ConditionalRandomField(nn.Module):
def __init__(self, tag_size, include_start_end_trans=False ,initial_method = None):
"""
:param tag_size: int, num of tags
:param include_start_end_trans: bool, whether to include start/end tag
"""
"""
:param int tag_size: num of tags
:param bool include_start_end_trans: whether to include start/end tag
:param str initial_method: method for initialization
"""

def __init__(self, tag_size, include_start_end_trans=False, initial_method=None):
super(ConditionalRandomField, self).__init__()

self.include_start_end_trans = include_start_end_trans
@@ -39,6 +41,7 @@ class ConditionalRandomField(nn.Module):

# self.reset_parameter()
initial_parameter(self, initial_method)

def reset_parameter(self):
nn.init.xavier_normal_(self.trans_m)
if self.include_start_end_trans:
@@ -46,12 +49,12 @@ class ConditionalRandomField(nn.Module):
nn.init.normal_(self.end_scores)

def _normalizer_likelihood(self, logits, mask):
"""
Computes the (batch_size,) denominator term for the log-likelihood, which is the
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the
sum of the likelihoods across all possible state sequences.
:param logits:FloatTensor, max_len x batch_size x tag_size
:param mask:ByteTensor, max_len x batch_size
:return:FloatTensor, batch_size

:param FloatTensor logits: [max_len, batch_size, tag_size]
:param ByteTensor mask: [max_len, batch_size]
:return: FloatTensor, [batch_size,]
"""
seq_len, batch_size, n_tags = logits.size()
alpha = logits[0]
@@ -70,8 +73,8 @@ class ConditionalRandomField(nn.Module):
return log_sum_exp(alpha, 1)

def _glod_score(self, logits, tags, mask):
"""
Compute the score for the gold path.
"""Compute the score for the gold path.
:param logits: FloatTensor, max_len x batch_size x tag_size
:param tags: LongTensor, max_len x batch_size
:param mask: ByteTensor, max_len x batch_size
@@ -97,12 +100,12 @@ class ConditionalRandomField(nn.Module):
return score

def forward(self, feats, tags, mask):
"""
Calculate the neg log likelihood
:param feats:FloatTensor, batch_size x max_len x tag_size
:param tags:LongTensor, batch_size x max_len
:param mask:ByteTensor batch_size x max_len
:return:FloatTensor, batch_size
"""Calculate the neg log likelihood
:param FloatTensor feats: [batch_size, max_len, tag_size]
:param LongTensor tags: [batch_size, max_len]
:param ByteTensor mask: [batch_size, max_len]
:return: FloatTensor, [batch_size,]
"""
feats = feats.transpose(0, 1)
tags = tags.transpose(0, 1).long()
@@ -113,11 +116,11 @@ class ConditionalRandomField(nn.Module):
return all_path_score - gold_path_score

def viterbi_decode(self, data, mask, get_score=False):
"""
Given a feats matrix, return best decode path and best score.
:param data:FloatTensor, batch_size x max_len x tag_size
:param mask:ByteTensor batch_size x max_len
:param get_score: bool, whether to output the decode score.
"""Given a feats matrix, return best decode path and best score.
:param FloatTensor data: [batch_size, max_len, tag_size]
:param ByteTensor mask: [batch_size, max_len]
:param bool get_score: whether to output the decode score.
:return: scores, paths
"""
batch_size, seq_len, n_tags = data.size()
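
A hedged usage sketch of the documented interface: the import path is assumed from this file's location in the commit, and the (scores, paths) return order is taken from the viterbi_decode docstring above.

import torch
from fastNLP.modules.decoder.CRF import ConditionalRandomField  # path assumed from this commit

batch_size, max_len, tag_size = 2, 6, 5
crf = ConditionalRandomField(tag_size)

feats = torch.randn(batch_size, max_len, tag_size)                # emission scores from an encoder
tags = torch.randint(0, tag_size, (batch_size, max_len)).long()   # gold tag ids
mask = torch.ByteTensor([[1, 1, 1, 1, 0, 0],                      # 1 = real token, 0 = padding
                         [1, 1, 1, 1, 1, 1]])

nll = crf(feats, tags, mask)      # negative log-likelihood, shape [batch_size]
nll.mean().backward()

scores, paths = crf.viterbi_decode(feats, mask, get_score=True)   # best tag sequence per sample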


fastNLP/modules/decoder/MLP.py (+11, -9)

@@ -1,21 +1,23 @@
import torch
import torch.nn as nn

from fastNLP.modules.utils import initial_parameter


class MLP(nn.Module):
def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0):
"""Multilayer Perceptrons as a decoder
"""Multilayer Perceptrons as a decoder

:param size_layer: list of int, define the size of MLP layers.
:param activation: str or function, the activation function for hidden layers.
:param initial_method: str, the name of init method.
:param dropout: float, the probability of dropout.
:param list size_layer: list of int, define the size of MLP layers.
:param str activation: str or function, the activation function for hidden layers.
:param str initial_method: the name of initialization method.
:param float dropout: the probability of dropout.

.. note::
There is no activation function applying on output layer.
.. note::
There is no activation function applying on output layer.

"""
"""

def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0):
super(MLP, self).__init__()
self.hiddens = nn.ModuleList()
self.output = None
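
A short usage sketch of the documented constructor; the import path is assumed from this file's location, and the output shape is the expected one given the size_layer list.

import torch
from fastNLP.modules.decoder.MLP import MLP  # path assumed from this commit

# 100 -> 64 -> 32 -> 5: ReLU on the two hidden layers,
# no activation on the final layer, as the note above states.
mlp = MLP([100, 64, 32, 5], activation='relu', dropout=0.1)
x = torch.randn(8, 100)        # a batch of 8 feature vectors
print(mlp(x).shape)            # expected: torch.Size([8, 5])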


fastNLP/modules/dropout.py (+2, -2)

@@ -2,8 +2,8 @@ import torch


class TimestepDropout(torch.nn.Dropout):
"""This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single
dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step.
"""This module accepts a ``[batch_size, num_timesteps, embedding_dim)]`` and use a single
dropout mask of shape ``(batch_size, embedding_dim)`` to apply on every time step.
"""

def forward(self, x):
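
The real module subclasses torch.nn.Dropout; the standalone function below only illustrates the mask sharing that the docstring describes: one mask per (batch, feature) pair, broadcast over every timestep.

import torch

def timestep_dropout(x, p=0.5, training=True):
    """Draw ONE dropout mask of shape [batch_size, embedding_dim] and reuse it
    at every timestep, instead of sampling a fresh mask per position."""
    if not training or p == 0.0:
        return x
    batch_size, _, emb_dim = x.size()
    mask = x.new_empty(batch_size, 1, emb_dim).bernoulli_(1 - p) / (1 - p)
    return x * mask   # broadcasts over the num_timesteps dimension

x = torch.ones(2, 4, 3)
print(timestep_dropout(x, p=0.5))  # whole feature columns are zeroed across all 4 timesteps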


fastNLP/modules/encoder/char_embedding.py (+16, -18)

@@ -1,5 +1,4 @@
import torch
import torch.nn.functional as F
from torch import nn

from fastNLP.modules.utils import initial_parameter
@@ -7,17 +6,17 @@ from fastNLP.modules.utils import initial_parameter

# from torch.nn.init import xavier_uniform
class ConvCharEmbedding(nn.Module):
"""Character-level Embedding with CNN.

:param int char_emb_size: the size of character level embedding. Default: 50
say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
:param tuple feature_maps: tuple of int. The length of the tuple is the number of convolution operations
over characters. The i-th integer is the number of filters (dim of out channels) for the i-th
convolution.
:param tuple kernels: tuple of int. The width of each kernel.
"""

def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5), initial_method=None):
"""
Character Level Word Embedding
:param char_emb_size: the size of character level embedding. Default: 50
say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
:param feature_maps: tuple of int. The length of the tuple is the number of convolution operations
over characters. The i-th integer is the number of filters (dim of out channels) for the i-th
convolution.
:param kernels: tuple of int. The width of each kernel.
"""
super(ConvCharEmbedding, self).__init__()
self.convs = nn.ModuleList([
nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4))
@@ -27,8 +26,8 @@ class ConvCharEmbedding(nn.Module):

def forward(self, x):
"""
:param x: [batch_size * sent_length, word_length, char_emb_size]
:return: [batch_size * sent_length, sum(feature_maps), 1]
:param x: ``[batch_size * sent_length, word_length, char_emb_size]``
:return: feature map of shape [batch_size * sent_length, sum(feature_maps), 1]
"""
x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2))
# [batch_size*sent_length, channel, width, height]
@@ -51,13 +50,12 @@ class ConvCharEmbedding(nn.Module):


class LSTMCharEmbedding(nn.Module):
"""
Character Level Word Embedding with LSTM with a single layer.
:param char_emb_size: int, the size of character level embedding. Default: 50
"""Character-level Embedding with LSTM.
:param int char_emb_size: the size of character level embedding. Default: 50
say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
:param hidden_size: int, the number of hidden units. Default: equal to char_emb_size.
:param int hidden_size: the number of hidden units. Default: equal to char_emb_size.
"""

def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None):
super(LSTMCharEmbedding, self).__init__()
self.hidden_size = char_emb_size if hidden_size is None else hidden_size
@@ -71,7 +69,7 @@ class LSTMCharEmbedding(nn.Module):

def forward(self, x):
"""
:param x:[ n_batch*n_word, word_length, char_emb_size]
:param x: ``[ n_batch*n_word, word_length, char_emb_size]``
:return: [ n_batch*n_word, char_emb_size]
"""
batch_size = x.shape[0]
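
A shape-check sketch built from the docstrings above; the import path is assumed from this file's location and the printed shapes are the documented ones.

import torch
from fastNLP.modules.encoder.char_embedding import ConvCharEmbedding, LSTMCharEmbedding  # paths as in this commit

n_words = 32                    # batch_size * sent_length, already flattened per the docstrings
word_length = 12                # characters per word, padded
x = torch.randn(n_words, word_length, 50)   # char_emb_size = 50

conv_embed = ConvCharEmbedding(char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5))
print(conv_embed(x).shape)      # docstring: [n_words, sum(feature_maps), 1] -> [32, 100, 1]

lstm_embed = LSTMCharEmbedding(char_emb_size=50)
print(lstm_embed(x).shape)      # docstring: [n_words, char_emb_size] -> [32, 50]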


fastNLP/modules/encoder/conv.py (+17, -7)

@@ -3,20 +3,30 @@

import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
# import torch.nn.functional as F

from fastNLP.modules.utils import initial_parameter


# import torch.nn.functional as F


class Conv(nn.Module):
"""
Basic 1-d convolution module.
initialize with xavier_uniform
"""
"""Basic 1-d convolution module, initialized with xavier_uniform.

:param int in_channels:
:param int out_channels:
:param tuple kernel_size:
:param int stride:
:param int padding:
:param int dilation:
:param int groups:
:param bool bias:
:param str activation:
:param str initial_method:
"""
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, dilation=1,
groups=1, bias=True, activation='relu',initial_method = None ):
groups=1, bias=True, activation='relu', initial_method=None):
super(Conv, self).__init__()
self.conv = nn.Conv1d(
in_channels=in_channels,


fastNLP/modules/encoder/conv_maxpool.py (+15, -5)

@@ -4,17 +4,27 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_
from fastNLP.modules.utils import initial_parameter


class ConvMaxpool(nn.Module):
"""
Convolution and max-pooling module with multiple kernel sizes.
"""
"""Convolution and max-pooling module with multiple kernel sizes.

:param int in_channels:
:param int out_channels:
:param tuple kernel_sizes:
:param int stride:
:param int padding:
:param int dilation:
:param int groups:
:param bool bias:
:param str activation:
:param str initial_method:
"""
def __init__(self, in_channels, out_channels, kernel_sizes,
stride=1, padding=0, dilation=1,
groups=1, bias=True, activation='relu',initial_method = None ):
groups=1, bias=True, activation="relu", initial_method=None):
super(ConvMaxpool, self).__init__()

# convolution
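
For orientation, a self-contained sketch of the general convolution + max-over-time pattern this module implements: one Conv1d per kernel size, ReLU, max-pool over the time axis, concatenate. The class name and the assumed [batch, seq_len, channels] input layout are ours, not fastNLP's.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyConvMaxpool(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_sizes):
        super().__init__()
        self.convs = nn.ModuleList(
            nn.Conv1d(in_channels, out_channels, k, padding=k // 2) for k in kernel_sizes
        )

    def forward(self, x):                   # x: [batch, seq_len, in_channels]
        x = x.transpose(1, 2)               # Conv1d expects [batch, channels, seq_len]
        pooled = [F.relu(conv(x)).max(dim=2)[0] for conv in self.convs]
        return torch.cat(pooled, dim=1)     # [batch, out_channels * len(kernel_sizes)]

x = torch.randn(4, 20, 128)                 # batch=4, 20 tokens, 128-dim embeddings
print(TinyConvMaxpool(128, 64, kernel_sizes=(3, 4, 5))(x).shape)  # torch.Size([4, 192])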


fastNLP/modules/encoder/embedding.py (+6, -9)

@@ -2,16 +2,13 @@ import torch.nn as nn


class Embedding(nn.Module):
"""
A simple lookup table
Args:
nums : the size of the lookup table
dims : the size of each vector
padding_idx : pads the tensor with zeros whenever it encounters this index
sparse : If True, gradient matrix will be a sparse tensor. In this case,
only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
"""
"""A simple lookup table.

:param int nums: the size of the lookup table
:param int dims: the size of each vector
:param int padding_idx: pads the tensor with zeros whenever it encounters this index
:param bool sparse: If True, gradient matrix will be a sparse tensor. In this case, only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
"""
def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
super(Embedding, self).__init__()
self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
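
A minimal usage sketch of the documented parameters; the import path is assumed from this file's location.

import torch
from fastNLP.modules.encoder.embedding import Embedding  # path as in this commit

embed = Embedding(nums=1000, dims=64, padding_idx=0, dropout=0.1)
word_ids = torch.LongTensor([[5, 42, 7, 0, 0],   # two sentences, padded with index 0
                             [9, 3, 11, 12, 6]])
vectors = embed(word_ids)
print(vectors.shape)          # expected: torch.Size([2, 5, 64]); padding rows stay zero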


fastNLP/modules/encoder/linear.py (+5, -8)

@@ -5,15 +5,12 @@ from fastNLP.modules.utils import initial_parameter

class Linear(nn.Module):
"""
Linear module
Args:
input_size : input size
hidden_size : hidden size
num_layers : number of hidden layers
dropout : dropout rate
bidirectional : If True, becomes a bidirectional RNN
"""

:param int input_size: input size
:param int output_size: output size
:param bool bias:
:param str initial_method:
"""
def __init__(self, input_size, output_size, bias=True, initial_method=None):
super(Linear, self).__init__()
self.linear = nn.Linear(input_size, output_size, bias)


fastNLP/modules/encoder/lstm.py (+9, -7)

@@ -6,14 +6,16 @@ from fastNLP.modules.utils import initial_parameter
class LSTM(nn.Module):
"""Long Short Term Memory

Args:
input_size : input size
hidden_size : hidden size
num_layers : number of hidden layers. Default: 1
dropout : dropout rate. Default: 0.5
bidirectional : If True, becomes a bidirectional RNN. Default: False.
:param int input_size:
:param int hidden_size:
:param int num_layers:
:param float dropout:
:param bool batch_first:
:param bool bidirectional:
:param bool bias:
:param str initial_method:
:param bool get_hidden:
"""

def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
bidirectional=False, bias=True, initial_method=None, get_hidden=False):
super(LSTM, self).__init__()


fastNLP/modules/encoder/masked_rnn.py (+59, -57)

@@ -5,6 +5,8 @@ import torch.nn as nn
import torch.nn.functional as F

from fastNLP.modules.utils import initial_parameter


def MaskedRecurrent(reverse=False):
def forward(input, hidden, cell, mask, train=True, dropout=0):
"""
@@ -254,16 +256,16 @@ class MaskedRNNBase(nn.Module):
return output, hidden

def step(self, input, hx=None, mask=None):
'''
execute one step forward (only for one-directional RNN).
Args:
input (batch, input_size): input tensor of this step.
hx (num_layers, batch, hidden_size): the hidden state of last step.
mask (batch): the mask tensor of this step.
Returns:
output (batch, hidden_size): tensor containing the output of this step from the last layer of RNN.
hn (num_layers, batch, hidden_size): tensor containing the hidden state of this step
'''
"""Execute one step forward (only for one-directional RNN).
:param Tensor input: input tensor of this step. (batch, input_size)
:param Tensor hx: the hidden state of last step. (num_layers, batch, hidden_size)
:param Tensor mask: the mask tensor of this step. (batch, )
:returns:
**output** (batch, hidden_size), tensor containing the output of this step from the last layer of RNN.
**hn** (num_layers, batch, hidden_size), tensor containing the hidden state of this step
"""
assert not self.bidirectional, "step only cannot be applied to bidirectional RNN." # aha, typo!
batch_size = input.size(0)
lstm = self.Cell is nn.LSTMCell
@@ -285,25 +287,23 @@ class MaskedRNN(MaskedRNNBase):
r"""Applies a multi-layer Elman RNN with costomized non-linearity to an
input sequence.
For each element in the input sequence, each layer computes the following
function:
.. math::
h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})
function. :math:`h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})`

where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is
the hidden state of the previous layer at time `t` or :math:`input_t`
for the first layer. If nonlinearity='relu', then `ReLU` is used instead
of `tanh`.
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
num_layers: Number of recurrent layers.
nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
bias: If False, then the layer does not use bias weights b_ih and b_hh.
Default: True
batch_first: If True, then the input and output tensors are provided
as (batch, seq, feature)
dropout: If non-zero, introduces a dropout layer on the outputs of each
RNN layer except the last layer
bidirectional: If True, becomes a bidirectional RNN. Default: False


:param int input_size: The number of expected features in the input x
:param int hidden_size: The number of features in the hidden state h
:param int num_layers: Number of recurrent layers.
:param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
:param float dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False

Inputs: input, mask, h_0
- **input** (seq_len, batch, input_size): tensor containing the features
of the input sequence.
@@ -327,32 +327,33 @@ class MaskedLSTM(MaskedRNNBase):
r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
sequence.
For each element in the input sequence, each layer computes the following
function:
function.

.. math::
\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}

\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}

where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
state at time `t`, :math:`x_t` is the hidden state of the previous layer at
time `t` or :math:`input_t` for the first layer, and :math:`i_t`,
:math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell,
and out gates, respectively.
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
num_layers: Number of recurrent layers.
bias: If False, then the layer does not use bias weights b_ih and b_hh.
Default: True
batch_first: If True, then the input and output tensors are provided
as (batch, seq, feature)
dropout: If non-zero, introduces a dropout layer on the outputs of each
RNN layer except the last layer
bidirectional: If True, becomes a bidirectional RNN. Default: False

:param int input_size: The number of expected features in the input x
:param int hidden_size: The number of features in the hidden state h
:param int num_layers: Number of recurrent layers.
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
:param bool dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False

Inputs: input, mask, (h_0, c_0)
- **input** (seq_len, batch, input_size): tensor containing the features
of the input sequence.
@@ -380,29 +381,30 @@ class MaskedGRU(MaskedRNNBase):
r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
For each element in the input sequence, each layer computes the following
function:

.. math::

\begin{array}{ll}
r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\
\end{array}

where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden
state of the previous layer at time `t` or :math:`input_t` for the first
layer, and :math:`r_t`, :math:`z_t`, :math:`n_t` are the reset, input,
and new gates, respectively.
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
num_layers: Number of recurrent layers.
nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
bias: If False, then the layer does not use bias weights b_ih and b_hh.
Default: True
batch_first: If True, then the input and output tensors are provided
as (batch, seq, feature)
dropout: If non-zero, introduces a dropout layer on the outputs of each
RNN layer except the last layer
bidirectional: If True, becomes a bidirectional RNN. Default: False

:param int input_size: The number of expected features in the input x
:param int hidden_size: The number of features in the hidden state h
:param int num_layers: Number of recurrent layers.
:param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
:param bool dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False

Inputs: input, mask, h_0
- **input** (seq_len, batch, input_size): tensor containing the features
of the input sequence.
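
The core rule behind these masked RNNs is that padded positions keep their previous hidden state. A minimal sketch of one masked step follows; the real MaskedRecurrent/MaskedRNNBase also handles LSTM cell-state tuples, bidirectionality, and step dropout, which are omitted here.

import torch
import torch.nn as nn

def masked_rnn_step(cell, x_t, h_prev, mask_t):
    """x_t: [batch, input_size], h_prev: [batch, hidden_size], mask_t: [batch].
    Where mask_t == 0 (padding), the previous hidden state is carried forward."""
    h_new = cell(x_t, h_prev)
    mask_t = mask_t.unsqueeze(1).float()    # [batch, 1] for broadcasting
    return h_new * mask_t + h_prev * (1 - mask_t)

cell = nn.RNNCell(input_size=8, hidden_size=16)
x_t = torch.randn(4, 8)
h_prev = torch.zeros(4, 16)
mask_t = torch.tensor([1, 1, 0, 1])         # the 3rd sequence has already ended
h_t = masked_rnn_step(cell, x_t, h_prev, mask_t)
print(torch.equal(h_t[2], h_prev[2]))       # True: padded position is unchanged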


fastNLP/modules/encoder/transformer.py (+3, -6)

@@ -1,10 +1,9 @@
import torch
from torch import nn
import torch.nn.functional as F

from ..aggregator.attention import MultiHeadAtte
from ..other_modules import LayerNormalization


class TransformerEncoder(nn.Module):
class SubLayer(nn.Module):
def __init__(self, input_size, output_size, key_size, value_size, num_atte):
@@ -12,8 +11,8 @@ class TransformerEncoder(nn.Module):
self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
self.norm1 = LayerNormalization(output_size)
self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
nn.ReLU(),
nn.Linear(output_size, output_size))
nn.ReLU(),
nn.Linear(output_size, output_size))
self.norm2 = LayerNormalization(output_size)

def forward(self, input, seq_mask):
@@ -28,5 +27,3 @@ class TransformerEncoder(nn.Module):

def forward(self, x, seq_mask=None):
return self.layers(x, seq_mask)
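
For orientation, a self-contained sketch of the attention / LayerNorm / feed-forward / LayerNorm pattern that SubLayer follows, using torch.nn.MultiheadAttention (batch_first requires a recent PyTorch) in place of the in-repo MultiHeadAtte, which is still unimplemented in this commit. The residual connections are the conventional choice and are not shown in the diff.

import torch
import torch.nn as nn

class TinySubLayer(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.atte = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_model),
                                 nn.ReLU(),
                                 nn.Linear(d_model, d_model))
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, key_padding_mask=None):
        attn_out, _ = self.atte(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + attn_out)             # attention sub-layer + norm
        return self.norm2(x + self.ffn(x))       # feed-forward sub-layer + norm

x = torch.randn(2, 10, 64)
print(TinySubLayer(d_model=64, num_heads=4)(x).shape)   # torch.Size([2, 10, 64])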



fastNLP/modules/encoder/variational_rnn.py (+12, -6)

@@ -1,5 +1,3 @@
import math

import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence
@@ -9,15 +7,17 @@ from fastNLP.modules.utils import initial_parameter
try:
from torch import flip
except ImportError:
def flip(x, dims):
def flip(x, dims):
indices = [slice(None)] * x.dim()
for dim in dims:
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
return x[tuple(indices)]


class VarRnnCellWrapper(nn.Module):
"""Wrapper for normal RNN Cells, make it support variational dropout
"""

def __init__(self, cell, hidden_size, input_p, hidden_p):
super(VarRnnCellWrapper, self).__init__()
self.cell = cell
@@ -32,9 +32,9 @@ class VarRnnCellWrapper(nn.Module):
for other RNN, h_0, [batch_size, hidden_size]
:param mask_x: [batch_size, input_size] dropout mask for input
:param mask_h: [batch_size, hidden_size] dropout mask for hidden
:return output: [seq_len, bacth_size, hidden_size]
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
for other RNN, h_n, [batch_size, hidden_size]
:return: (output, hidden)
**output**: [seq_len, bacth_size, hidden_size].
**hidden**: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]; For other RNN, h_n, [batch_size, hidden_size].
"""
is_lstm = isinstance(hidden, tuple)
input = input * mask_x.unsqueeze(0) if mask_x is not None else input
@@ -56,6 +56,7 @@ class VarRNNBase(nn.Module):
refer to `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016)
https://arxiv.org/abs/1512.05287`.
"""

def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
bias=True, batch_first=False,
input_dropout=0, hidden_dropout=0, bidirectional=False):
@@ -138,17 +139,22 @@ class VarRNNBase(nn.Module):
class VarLSTM(VarRNNBase):
"""Variational Dropout LSTM.
"""

def __init__(self, *args, **kwargs):
super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)


class VarRNN(VarRNNBase):
"""Variational Dropout RNN.
"""

def __init__(self, *args, **kwargs):
super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs)


class VarGRU(VarRNNBase):
"""Variational Dropout GRU.
"""

def __init__(self, *args, **kwargs):
super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
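
The key idea of variational dropout (Gal & Ghahramani, 2016) is that the input and hidden dropout masks are sampled once per sequence and reused at every timestep; VarRnnCellWrapper receives exactly such mask_x [batch, input_size] and mask_h [batch, hidden_size] tensors. A minimal sketch of the mask sampling:

import torch

def variational_dropout_masks(batch_size, input_size, hidden_size, p_in, p_hid):
    """Sample ONE input mask and ONE hidden mask per sequence, to be reused at
    every timestep (ordinary dropout would resample per step)."""
    mask_x = torch.empty(batch_size, input_size).bernoulli_(1 - p_in) / (1 - p_in)
    mask_h = torch.empty(batch_size, hidden_size).bernoulli_(1 - p_hid) / (1 - p_hid)
    return mask_x, mask_h

mask_x, mask_h = variational_dropout_masks(4, 8, 16, p_in=0.3, p_hid=0.3)
# inside the recurrence, every step t applies the SAME masks:
#   x_t     = x_t * mask_x
#   h_{t-1} = h_{t-1} * mask_h
print(mask_x.shape, mask_h.shape)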

fastNLP/modules/other_modules.py (+23, -34)

@@ -29,8 +29,11 @@ class GroupNorm(nn.Module):


class LayerNormalization(nn.Module):
""" Layer normalization module """
"""

:param int layer_size:
:param float eps: default=1e-3
"""
def __init__(self, layer_size, eps=1e-3):
super(LayerNormalization, self).__init__()

@@ -52,12 +55,11 @@ class LayerNormalization(nn.Module):
class BiLinear(nn.Module):
def __init__(self, n_left, n_right, n_out, bias=True):
"""
Args:
n_left: size of left input
n_right: size of right input
n_out: size of output
bias: If set to False, the layer will not learn an additive bias.
Default: True

:param int n_left: size of left input
:param int n_right: size of right input
:param int n_out: size of output
:param bool bias: If set to False, the layer will not learn an additive bias. Default: True
"""
super(BiLinear, self).__init__()
self.n_left = n_left
@@ -83,12 +85,9 @@ class BiLinear(nn.Module):

def forward(self, input_left, input_right):
"""
Args:
input_left: Tensor
the left input tensor with shape = [batch1, batch2, ..., left_features]
input_right: Tensor
the right input tensor with shape = [batch1, batch2, ..., right_features]
Returns:
:param Tensor input_left: the left input tensor with shape = [batch1, batch2, ..., left_features]
:param Tensor input_right: the right input tensor with shape = [batch1, batch2, ..., right_features]

"""
left_size = input_left.size()
right_size = input_right.size()
@@ -118,16 +117,11 @@ class BiLinear(nn.Module):
class BiAffine(nn.Module):
def __init__(self, n_enc, n_dec, n_labels, biaffine=True, **kwargs):
"""
Args:
n_enc: int
the dimension of the encoder input.
n_dec: int
the dimension of the decoder input.
n_labels: int
the number of labels of the crf layer
biaffine: bool
if apply bi-affine parameter.
**kwargs:

:param int n_enc: the dimension of the encoder input.
:param int n_dec: the dimension of the decoder input.
:param int n_labels: the number of labels of the crf layer
:param bool biaffine: if apply bi-affine parameter.
"""
super(BiAffine, self).__init__()
self.n_enc = n_enc
@@ -154,17 +148,12 @@ class BiAffine(nn.Module):

def forward(self, input_d, input_e, mask_d=None, mask_e=None):
"""
Args:
input_d: Tensor
the decoder input tensor with shape = [batch, length_decoder, input_size]
input_e: Tensor
the child input tensor with shape = [batch, length_encoder, input_size]
mask_d: Tensor or None
the mask tensor for decoder with shape = [batch, length_decoder]
mask_e: Tensor or None
the mask tensor for encoder with shape = [batch, length_encoder]
Returns: Tensor
the energy tensor with shape = [batch, num_label, length, length]

:param Tensor input_d: the decoder input tensor with shape = [batch, length_decoder, input_size]
:param Tensor input_e: the child input tensor with shape = [batch, length_encoder, input_size]
:param mask_d: Tensor or None, the mask tensor for decoder with shape = [batch, length_decoder]
:param mask_e: Tensor or None, the mask tensor for encoder with shape = [batch, length_encoder]
:returns: Tensor, the energy tensor with shape = [batch, num_label, length, length]
"""
assert input_d.size(0) == input_e.size(0), 'batch sizes of encoder and decoder are requires to be equal.'
batch, length_decoder, _ = input_d.size()


fastNLP/modules/utils.py (+2, -2)

@@ -15,7 +15,7 @@ def initial_parameter(net, initial_method=None):
"""A method used to initialize the weights of PyTorch models.

:param net: a PyTorch model
:param initial_method: str, one of the following initializations
:param str initial_method: one of the following initializations.

- xavier_uniform
- xavier_normal (default)
@@ -79,7 +79,7 @@ def seq_mask(seq_len, max_len):

:param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
:param max_len: int, the maximum sequence length in a batch.
:return mask: torch.LongTensor, [batch_size, max_len]
:return: mask, torch.LongTensor, [batch_size, max_len]

"""
if not isinstance(seq_len, torch.Tensor):
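
A small sketch of what seq_mask is documented to return: a [batch_size, max_len] LongTensor with 1 for real positions and 0 for padding. This is an illustration of the contract, not fastNLP's implementation.

import torch

def seq_mask_sketch(seq_len, max_len):
    seq_len = torch.as_tensor(seq_len).unsqueeze(1)    # [batch_size, 1]
    positions = torch.arange(max_len).unsqueeze(0)      # [1, max_len]
    return (positions < seq_len).long()                 # [batch_size, max_len]

print(seq_mask_sketch([3, 1, 4], max_len=4))
# tensor([[1, 1, 1, 0],
#         [1, 0, 0, 0],
#         [1, 1, 1, 1]])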

