
Merge pull request #121 from FengZiYjun/doc-fixing

[Doc] Improve Documentation (2)
tags/v0.3.0
Coet (GitHub), 5 years ago
Commit ef82c1f100
24 changed files with 360 additions and 334 deletions
  1. +2 -1    .travis.yml
  2. +4 -4    fastNLP/core/batch.py
  3. +5 -1    fastNLP/io/base_loader.py
  4. +22 -27  fastNLP/io/config_io.py
  5. +87 -77  fastNLP/io/dataset_loader.py
  6. +5 -4    fastNLP/io/embed_loader.py
  7. +5 -4    fastNLP/io/logger.py
  8. +9 -9    fastNLP/io/model_io.py
  9. +5 -5    fastNLP/modules/aggregator/attention.py
  10. +10 -9   fastNLP/modules/aggregator/self_attention.py
  11. +26 -23  fastNLP/modules/decoder/CRF.py
  12. +11 -9   fastNLP/modules/decoder/MLP.py
  13. +2 -2    fastNLP/modules/dropout.py
  14. +16 -18  fastNLP/modules/encoder/char_embedding.py
  15. +17 -7   fastNLP/modules/encoder/conv.py
  16. +15 -5   fastNLP/modules/encoder/conv_maxpool.py
  17. +6 -9    fastNLP/modules/encoder/embedding.py
  18. +5 -8    fastNLP/modules/encoder/linear.py
  19. +9 -7    fastNLP/modules/encoder/lstm.py
  20. +59 -57  fastNLP/modules/encoder/masked_rnn.py
  21. +3 -6    fastNLP/modules/encoder/transformer.py
  22. +12 -6   fastNLP/modules/encoder/variational_rnn.py
  23. +23 -34  fastNLP/modules/other_modules.py
  24. +2 -2    fastNLP/modules/utils.py

+2 -1  .travis.yml

@@ -4,7 +4,8 @@ python:
# command to install dependencies
install:
- pip install --quiet -r requirements.txt
- pip install pytest pytest-cov
- pip install pytest>=3.6
- pip install pytest-cov
# command to run tests
script:
- pytest --cov=./


+4 -4  fastNLP/core/batch.py

@@ -10,10 +10,10 @@ class Batch(object):
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
# ...

:param dataset: a DataSet object
:param batch_size: int, the size of the batch
:param sampler: a Sampler object
:param as_numpy: bool. If True, return Numpy array. Otherwise, return torch tensors.
:param DataSet dataset: a DataSet object
:param int batch_size: the size of the batch
:param Sampler sampler: a Sampler object
:param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors.

"""



+5 -1  fastNLP/io/base_loader.py

@@ -3,7 +3,9 @@ import os


class BaseLoader(object):
"""Base loader for all loaders.

"""
def __init__(self):
super(BaseLoader, self).__init__()

@@ -32,7 +34,9 @@ class BaseLoader(object):


class DataLoaderRegister:
""""register for data sets"""
"""Register for all data sets.

"""
_readers = {}

@classmethod


+22 -27  fastNLP/io/config_io.py

@@ -6,7 +6,11 @@ from fastNLP.io.base_loader import BaseLoader


class ConfigLoader(BaseLoader):
"""loader for configuration files"""
"""Loader for configuration.

:param str data_path: path to the config

"""

def __init__(self, data_path=None):
super(ConfigLoader, self).__init__()
@@ -19,13 +23,15 @@ class ConfigLoader(BaseLoader):

@staticmethod
def load_config(file_path, sections):
"""
:param file_path: the path of config file
:param sections: the dict of {section_name(string): Section instance}
Example:
"""Load section(s) of configuration into the ``sections`` provided. No returns.

:param str file_path: the path of config file
:param dict sections: the dict of ``{section_name(string): ConfigSection object}``
Example::

test_args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
:return: return nothing, but the value of attributes are saved in sessions
"""
assert isinstance(sections, dict)
cfg = configparser.ConfigParser()
@@ -60,9 +66,12 @@ class ConfigLoader(BaseLoader):


class ConfigSection(object):
"""ConfigSection is the data structure storing all key-value pairs in one section in a config file.

"""

def __init__(self):
pass
super(ConfigSection, self).__init__()

def __getitem__(self, key):
"""
@@ -132,25 +141,12 @@ class ConfigSection(object):
return self.__dict__


if __name__ == "__main__":
config = ConfigLoader('there is no data')

section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
"""
General and My can be found in config file, so the attr and
value will be updated
A cannot be found in config file, so nothing will be done
"""

config.load_config("../../test/data_for_tests/config", section)
for s in section:
print(s)
for attr in section[s].__dict__.keys():
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr)))


class ConfigSaver(object):
"""ConfigSaver is used to save config file and solve related conflicts.

:param str file_path: path to the config file

"""
def __init__(self, file_path):
self.file_path = file_path
if not os.path.exists(self.file_path):
@@ -244,9 +240,8 @@ class ConfigSaver(object):
def save_config_file(self, section_name, section):
"""This is the function to be called to change the config file with a single section and its name.

:param section_name: The name of section what needs to be changed and saved.
:param section: The section with key and value what needs to be changed and saved.
:return:
:param str section_name: The name of section what needs to be changed and saved.
:param ConfigSection section: The section with key and value what needs to be changed and saved.
"""
section_file = self._get_section(section_name)
if len(section_file.__dict__.keys()) == 0: # the section not in the file before
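
``ConfigLoader.load_config`` is built on the standard ``configparser`` module, as the hunk above shows. A minimal sketch of the documented behaviour, using a hypothetical ``Section`` stand-in for ``ConfigSection``::

    import configparser

    class Section:
        """Stand-in for ConfigSection: stores key-value pairs as attributes."""
        pass

    def load_config(file_path, sections):
        # sections: {section_name: Section()}; names missing from the file are skipped silently.
        cfg = configparser.ConfigParser()
        cfg.read(file_path)
        for name, target in sections.items():
            if name not in cfg:
                continue
            for key, value in cfg[name].items():
                setattr(target, key, value)   # values stay strings in this sketch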


+87 -77  fastNLP/io/dataset_loader.py

@@ -9,11 +9,12 @@ def convert_seq_dataset(data):
"""Create an DataSet instance that contains no labels.

:param data: list of list of strings, [num_examples, *].
::
[
[word_11, word_12, ...],
...
]
Example::

[
[word_11, word_12, ...],
...
]

:return: a DataSet.
"""
@@ -24,15 +25,16 @@ def convert_seq_dataset(data):


def convert_seq2tag_dataset(data):
"""Convert list of data into DataSet
"""Convert list of data into DataSet.

:param data: list of list of strings, [num_examples, *].
::
[
[ [word_11, word_12, ...], label_1 ],
[ [word_21, word_22, ...], label_2 ],
...
]
Example::

[
[ [word_11, word_12, ...], label_1 ],
[ [word_21, word_22, ...], label_2 ],
...
]

:return: a DataSet.
"""
@@ -43,15 +45,16 @@ def convert_seq2tag_dataset(data):


def convert_seq2seq_dataset(data):
"""Convert list of data into DataSet
"""Convert list of data into DataSet.

:param data: list of list of strings, [num_examples, *].
::
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
Example::

[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]

:return: a DataSet.
"""
@@ -62,20 +65,31 @@ def convert_seq2seq_dataset(data):


class DataSetLoader:
""""loader for data sets"""
"""Interface for all DataSetLoaders.

"""

def load(self, path):
""" load data in `path` into a dataset
"""Load data from a given file.

:param str path: file path
:return: a DataSet object
"""
raise NotImplementedError

def convert(self, data):
"""convert list of data into dataset
"""Optional operation to build a DataSet.

:param data: inner data structure (user-defined) to represent the data.
:return: a DataSet object
"""
raise NotImplementedError


class NativeDataSetLoader(DataSetLoader):
"""A simple example of DataSetLoader

"""
def __init__(self):
super(NativeDataSetLoader, self).__init__()

@@ -90,6 +104,9 @@ DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive')


class RawDataSetLoader(DataSetLoader):
"""A simple example of raw data reader

"""
def __init__(self):
super(RawDataSetLoader, self).__init__()

@@ -108,37 +125,35 @@ DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')


class POSDataSetLoader(DataSetLoader):
"""Dataset Loader for POS Tag datasets.

In these datasets, each line are divided by '\t'
while the first Col is the vocabulary and the second
Col is the label.
Different sentence are divided by an empty line.
e.g:
Tom label1
and label2
Jerry label1
. label3
(separated by an empty line)
Hello label4
world label5
! label3
In this file, there are two sentence "Tom and Jerry ."
and "Hello world !". Each word has its own label from label1
to label5.
"""Dataset Loader for a POS Tag dataset.

In these datasets, each line are divided by "\t". The first Col is the vocabulary and the second
Col is the label. Different sentence are divided by an empty line.
E.g::

Tom label1
and label2
Jerry label1
. label3
(separated by an empty line)
Hello label4
world label5
! label3

In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label.
"""

def __init__(self):
super(POSDataSetLoader, self).__init__()

def load(self, data_path):
"""
:return data: three-level list
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
Example::
[
[ [word_11, word_12, ...], [label_1, label_1, ...] ],
[ [word_21, word_22, ...], [label_2, label_1, ...] ],
...
]
"""
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
@@ -188,17 +203,17 @@ class TokenizeDataSetLoader(DataSetLoader):
super(TokenizeDataSetLoader, self).__init__()

def load(self, data_path, max_seq_len=32):
"""
load pku dataset for Chinese word segmentation
"""Load pku dataset for Chinese word segmentation.
CWS (Chinese Word Segmentation) pku training dataset format:
1. Each line is a sentence.
2. Each word in a sentence is separated by space.
1. Each line is a sentence.
2. Each word in a sentence is separated by space.
This function convert the pku dataset into three-level lists with labels <BMES>.
B: beginning of a word
M: middle of a word
E: ending of a word
S: single character
B: beginning of a word
M: middle of a word
E: ending of a word
S: single character

:param str data_path: path to the data set.
:param max_seq_len: int, the maximum length of a sequence. If a sequence is longer than it, split it into
several sequences.
:return: three-level lists
@@ -239,7 +254,7 @@ class TokenizeDataSetLoader(DataSetLoader):


class ClassDataSetLoader(DataSetLoader):
"""Loader for classification data sets"""
"""Loader for a dummy classification data set"""

def __init__(self):
super(ClassDataSetLoader, self).__init__()
@@ -254,11 +269,9 @@ class ClassDataSetLoader(DataSetLoader):
@staticmethod
def parse(lines):
"""
Params
lines: lines from dataset
Return
list(list(list())): the three level of lists are
words, sentence, and dataset

:param list lines: lines from dataset
:return: a 3-D list, indicating words, sentence, and dataset respectively.
"""
dataset = list()
for line in lines:
@@ -280,15 +293,9 @@ class ConllLoader(DataSetLoader):
"""loader for conll format files"""

def __init__(self):
"""
:param str data_path: the path to the conll data set
"""
super(ConllLoader, self).__init__()

def load(self, data_path):
"""
:return: list lines: all lines in a conll file
"""
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
data = self.parse(lines)
@@ -297,7 +304,7 @@ class ConllLoader(DataSetLoader):
@staticmethod
def parse(lines):
"""
:param list lines:a list containing all lines in a conll file.
:param list lines: a list containing all lines in a conll file.
:return: a 3D list
"""
sentences = list()
@@ -320,8 +327,8 @@ class ConllLoader(DataSetLoader):
class LMDataSetLoader(DataSetLoader):
"""Language Model Dataset Loader

This loader produces data for language model training in a supervised way.
That means it has X and Y.
This loader produces data for language model training in a supervised way.
That means it has X and Y.

"""

@@ -467,6 +474,7 @@ class Conll2003Loader(DataSetLoader):

return dataset


class SNLIDataSetLoader(DataSetLoader):
"""A data set loader for SNLI data set.

@@ -478,8 +486,8 @@ class SNLIDataSetLoader(DataSetLoader):
def load(self, path_list):
"""

:param path_list: A list of file name, in the order of premise file, hypothesis file, and label file.
:return: data_set: A DataSet object.
:param list path_list: A list of file name, in the order of premise file, hypothesis file, and label file.
:return: A DataSet object.
"""
assert len(path_list) == 3
line_set = []
@@ -507,12 +515,14 @@ class SNLIDataSetLoader(DataSetLoader):
"""Convert a 3D list to a DataSet object.

:param data: A 3D tensor.
[
[ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
[ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
...
]
:return: data_set: A DataSet object.
Example::
[
[ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
[ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
...
]

:return: A DataSet object.
"""

data_set = DataSet()
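
The file format documented for ``POSDataSetLoader`` above (one word-TAB-label pair per line, blank lines between sentences) can be parsed in a few lines of plain Python; a sketch of the documented three-level output, not the loader's actual code::

    def parse_pos_lines(lines):
        # Returns [[ [word_11, ...], [label_1, ...] ], ...]; a blank line ends a sentence.
        data, words, labels = [], [], []
        for line in lines:
            line = line.strip()
            if not line:
                if words:
                    data.append([words, labels])
                    words, labels = [], []
                continue
            word, label = line.split("\t")
            words.append(word)
            labels.append(label)
        if words:
            data.append([words, labels])
        return data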


+5 -4  fastNLP/io/embed_loader.py

@@ -38,7 +38,7 @@ class EmbedLoader(BaseLoader):

:param str emb_file: the pre-trained embedding file path
:param str emb_type: the pre-trained embedding data format
:return dict embedding: `{str: np.array}`
:return: a dict of ``{str: np.array}``
"""
if emb_type == 'glove':
return EmbedLoader._load_glove(emb_file)
@@ -53,8 +53,9 @@ class EmbedLoader(BaseLoader):
:param str emb_file: the pre-trained embedding file path.
:param str emb_type: the pre-trained embedding format, support glove now
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
:return embedding_tensor: Tensor of shape (len(word_dict), emb_dim)
vocab: input vocab or vocab built by pre-train
:return (embedding_tensor, vocab):
embedding_tensor - Tensor of shape (len(word_dict), emb_dim);
vocab - input vocab or vocab built by pre-train

"""
pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
@@ -95,7 +96,7 @@ class EmbedLoader(BaseLoader):
:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding.
:param str emb_file: the pre-trained embedding file path.
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
:return numpy.ndarray embedding_matrix:
:return: the embedding matrix, numpy.ndarray

"""
if vocab is None:
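
For reference, reading a GloVe-style text file into the documented ``{str: np.array}`` dict is straightforward; a sketch assuming the usual one-word-per-line GloVe format, not ``EmbedLoader._load_glove`` itself::

    import numpy as np

    def load_glove(emb_file):
        # Each line: "<word> <v1> <v2> ...", whitespace-separated.
        emb = {}
        with open(emb_file, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.rstrip().split()
                if len(parts) >= 2:
                    emb[parts[0]] = np.array(parts[1:], dtype=np.float32)
        return emb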


+5 -4  fastNLP/io/logger.py

@@ -3,15 +3,16 @@ import os


def create_logger(logger_name, log_path, log_format=None, log_level=logging.INFO):
"""Return a logger.
"""Create a logger.

:param logger_name: str
:param log_path: str
:param str logger_name:
:param str log_path:
:param log_format:
:param log_level:
:return: logger

to use a logger:
To use a logger::

logger.debug("this is a debug message")
logger.info("this is a info message")
logger.warning("this is a warning message")
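
``create_logger`` wraps the standard ``logging`` module. A plausible minimal implementation matching the documented parameters (an assumption, not the function's actual body)::

    import logging

    def create_logger(logger_name, log_path, log_format=None, log_level=logging.INFO):
        logger = logging.getLogger(logger_name)
        logger.setLevel(log_level)
        handler = logging.FileHandler(log_path)           # write records to log_path
        if log_format is None:
            log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        handler.setFormatter(logging.Formatter(log_format))
        logger.addHandler(handler)
        return logger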


+9 -9  fastNLP/io/model_io.py

@@ -13,10 +13,10 @@ class ModelLoader(BaseLoader):

@staticmethod
def load_pytorch(empty_model, model_path):
"""
Load model parameters from .pkl files into the empty PyTorch model.
"""Load model parameters from ".pkl" files into the empty PyTorch model.
:param empty_model: a PyTorch model with initialized parameters.
:param model_path: str, the path to the saved model.
:param str model_path: the path to the saved model.
"""
empty_model.load_state_dict(torch.load(model_path))

@@ -24,30 +24,30 @@ class ModelLoader(BaseLoader):
def load_pytorch_model(model_path):
"""Load the entire model.

:param str model_path: the path to the saved model.
"""
return torch.load(model_path)


class ModelSaver(object):
"""Save a model

:param str save_path: the path to the saving directory.
Example::

saver = ModelSaver("./save/model_ckpt_100.pkl")
saver.save_pytorch(model)

"""

def __init__(self, save_path):
"""

:param save_path: str, the path to the saving directory.
"""
self.save_path = save_path

def save_pytorch(self, model, param_only=True):
"""Save a pytorch model into .pkl file.
"""Save a pytorch model into ".pkl" file.

:param model: a PyTorch model
:param param_only: bool, whether only to save the model parameters or the entire model.
:param bool param_only: whether only to save the model parameters or the entire model.

"""
if param_only is True:
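
The ``param_only`` flag corresponds to the two standard PyTorch saving modes; a sketch of that distinction, with the loader call taken from the hunk above::

    import torch

    def save_pytorch(model, save_path, param_only=True):
        if param_only:
            torch.save(model.state_dict(), save_path)   # parameters only
        else:
            torch.save(model, save_path)                # pickle the whole model object

    def load_pytorch(empty_model, model_path):
        # Load parameters from a ".pkl" file into an already-constructed model.
        empty_model.load_state_dict(torch.load(model_path))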


+5 -5  fastNLP/modules/aggregator/attention.py

@@ -1,11 +1,12 @@
import math

import torch
from torch import nn
import math
from fastNLP.modules.utils import mask_softmax


class Attention(torch.nn.Module):

def __init__(self, normalize=False):
super(Attention, self).__init__()
self.normalize = normalize
@@ -19,9 +20,9 @@ class Attention(torch.nn.Module):
def _atten_forward(self, query, memory):
raise NotImplementedError


class DotAtte(nn.Module):
def __init__(self, key_size, value_size):
# TODO never test
super(DotAtte, self).__init__()
self.key_size = key_size
self.value_size = value_size
@@ -41,10 +42,9 @@ class DotAtte(nn.Module):
output = nn.functional.softmax(output, dim=2)
return torch.matmul(output, V)


class MultiHeadAtte(nn.Module):
def __init__(self, input_size, output_size, key_size, value_size, num_atte):
raise NotImplementedError
# TODO never test
super(MultiHeadAtte, self).__init__()
self.in_linear = nn.ModuleList()
for i in range(num_atte * 3):


+10 -9  fastNLP/modules/aggregator/self_attention.py

@@ -7,13 +7,14 @@ from fastNLP.modules.utils import initial_parameter


class SelfAttention(nn.Module):
"""
Self Attention Module.
"""Self Attention Module.

Args:
input_size: int, the size for the input vector
dim: int, the width of weight matrix.
num_vec: int, the number of encoded vectors
:param int input_size:
:param int attention_unit:
:param int attention_hops:
:param float drop:
:param str initial_method:
:param bool use_cuda:
"""

def __init__(self, input_size, attention_unit=350, attention_hops=10, drop=0.5, initial_method=None,
@@ -48,7 +49,7 @@ class SelfAttention(nn.Module):
def forward(self, input, input_origin):
"""
:param input: the matrix to do attention. [baz, senLen, h_dim]
:param inp: then token index include pad token( 0 ) [baz , senLen]
:param inp: then token index include pad token( 0 ) [baz , senLen]
:return output1: the input matrix after attention operation [baz, multi-head , h_dim]
:return output2: the attention penalty term, a scalar [1]
"""
@@ -59,8 +60,8 @@ class SelfAttention(nn.Module):
input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len]

y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit]
attention = self.ws2(y1).transpose(1,
2).contiguous() # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]
attention = self.ws2(y1).transpose(1, 2).contiguous()
# [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]

attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token.
attention = F.softmax(attention, 2) # [baz ,hop, len]


+26 -23  fastNLP/modules/decoder/CRF.py

@@ -21,11 +21,13 @@ def seq_len_to_byte_mask(seq_lens):


class ConditionalRandomField(nn.Module):
def __init__(self, tag_size, include_start_end_trans=False ,initial_method = None):
"""
:param tag_size: int, num of tags
:param include_start_end_trans: bool, whether to include start/end tag
"""
"""
:param int tag_size: num of tags
:param bool include_start_end_trans: whether to include start/end tag
:param str initial_method: method for initialization
"""

def __init__(self, tag_size, include_start_end_trans=False, initial_method=None):
super(ConditionalRandomField, self).__init__()

self.include_start_end_trans = include_start_end_trans
@@ -39,6 +41,7 @@ class ConditionalRandomField(nn.Module):

# self.reset_parameter()
initial_parameter(self, initial_method)

def reset_parameter(self):
nn.init.xavier_normal_(self.trans_m)
if self.include_start_end_trans:
@@ -46,12 +49,12 @@ class ConditionalRandomField(nn.Module):
nn.init.normal_(self.end_scores)

def _normalizer_likelihood(self, logits, mask):
"""
Computes the (batch_size,) denominator term for the log-likelihood, which is the
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the
sum of the likelihoods across all possible state sequences.
:param logits:FloatTensor, max_len x batch_size x tag_size
:param mask:ByteTensor, max_len x batch_size
:return:FloatTensor, batch_size

:param FloatTensor logits: [max_len, batch_size, tag_size]
:param ByteTensor mask: [max_len, batch_size]
:return: FloatTensor, [batch_size,]
"""
seq_len, batch_size, n_tags = logits.size()
alpha = logits[0]
@@ -70,8 +73,8 @@ class ConditionalRandomField(nn.Module):
return log_sum_exp(alpha, 1)

def _glod_score(self, logits, tags, mask):
"""
Compute the score for the gold path.
"""Compute the score for the gold path.
:param logits: FloatTensor, max_len x batch_size x tag_size
:param tags: LongTensor, max_len x batch_size
:param mask: ByteTensor, max_len x batch_size
@@ -97,12 +100,12 @@ class ConditionalRandomField(nn.Module):
return score

def forward(self, feats, tags, mask):
"""
Calculate the neg log likelihood
:param feats:FloatTensor, batch_size x max_len x tag_size
:param tags:LongTensor, batch_size x max_len
:param mask:ByteTensor batch_size x max_len
:return:FloatTensor, batch_size
"""Calculate the neg log likelihood
:param FloatTensor feats: [batch_size, max_len, tag_size]
:param LongTensor tags: [batch_size, max_len]
:param ByteTensor mask: [batch_size, max_len]
:return: FloatTensor, [batch_size,]
"""
feats = feats.transpose(0, 1)
tags = tags.transpose(0, 1).long()
@@ -113,11 +116,11 @@ class ConditionalRandomField(nn.Module):
return all_path_score - gold_path_score

def viterbi_decode(self, data, mask, get_score=False):
"""
Given a feats matrix, return best decode path and best score.
:param data:FloatTensor, batch_size x max_len x tag_size
:param mask:ByteTensor batch_size x max_len
:param get_score: bool, whether to output the decode score.
"""Given a feats matrix, return best decode path and best score.
:param FloatTensor data: [batch_size, max_len, tag_size]
:param ByteTensor mask: [batch_size, max_len]
:param bool get_score: whether to output the decode score.
:return: scores, paths
"""
batch_size, seq_len, n_tags = data.size()
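
``_normalizer_likelihood``, whose shapes the new docstring spells out, is the forward algorithm: a log-sum-exp over all tag paths that freezes the running scores on padded positions. A standalone sketch with those shapes (the transition-matrix argument is an assumption)::

    import torch

    def crf_normalizer(logits, trans, mask):
        # logits: [max_len, batch_size, n_tags]; trans: [n_tags, n_tags]; mask: [max_len, batch_size]
        seq_len, batch_size, n_tags = logits.size()
        alpha = logits[0]                                                 # [batch, n_tags]
        for t in range(1, seq_len):
            score = alpha.unsqueeze(2) + trans.unsqueeze(0) + logits[t].unsqueeze(1)
            new_alpha = torch.logsumexp(score, dim=1)                     # sum over previous tag
            m = mask[t].unsqueeze(1).float()
            alpha = new_alpha * m + alpha * (1 - m)                       # freeze alpha on padding
        return torch.logsumexp(alpha, dim=1)                              # [batch_size]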


+11 -9  fastNLP/modules/decoder/MLP.py

@@ -1,21 +1,23 @@
import torch
import torch.nn as nn

from fastNLP.modules.utils import initial_parameter


class MLP(nn.Module):
def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0):
"""Multilayer Perceptrons as a decoder
"""Multilayer Perceptrons as a decoder

:param size_layer: list of int, define the size of MLP layers.
:param activation: str or function, the activation function for hidden layers.
:param initial_method: str, the name of init method.
:param dropout: float, the probability of dropout.
:param list size_layer: list of int, define the size of MLP layers.
:param str activation: str or function, the activation function for hidden layers.
:param str initial_method: the name of initialization method.
:param float dropout: the probability of dropout.

.. note::
There is no activation function applying on output layer.
.. note::
There is no activation function applying on output layer.

"""
"""

def __init__(self, size_layer, activation='relu', initial_method=None, dropout=0.0):
super(MLP, self).__init__()
self.hiddens = nn.ModuleList()
self.output = None
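
As the note in the docstring says, no activation is applied to the output layer. An equivalent construction for the documented ``size_layer`` parameter, sketched with ``nn.Sequential`` rather than the class's actual code::

    import torch.nn as nn

    def build_mlp(size_layer, dropout=0.0):
        # size_layer, e.g. [256, 128, 10]: hidden layers get ReLU + dropout, the output layer does not.
        layers = []
        for i in range(len(size_layer) - 1):
            layers.append(nn.Linear(size_layer[i], size_layer[i + 1]))
            if i < len(size_layer) - 2:
                layers.append(nn.ReLU())
                layers.append(nn.Dropout(dropout))
        return nn.Sequential(*layers)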


+2 -2  fastNLP/modules/dropout.py

@@ -2,8 +2,8 @@ import torch


class TimestepDropout(torch.nn.Dropout):
"""This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single
dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step.
"""This module accepts a ``[batch_size, num_timesteps, embedding_dim)]`` and use a single
dropout mask of shape ``(batch_size, embedding_dim)`` to apply on every time step.
"""

def forward(self, x):
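
The behaviour described above (one dropout mask per ``(batch_size, embedding_dim)`` slice, broadcast across every time step) can be written functionally as a short sketch::

    import torch

    def timestep_dropout(x, p=0.5, training=True):
        # x: [batch_size, num_timesteps, embedding_dim]
        if not training or p == 0.0:
            return x
        mask = x.new_empty(x.size(0), 1, x.size(2)).bernoulli_(1 - p) / (1 - p)
        return x * mask      # the same mask is applied at every time step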


+16 -18  fastNLP/modules/encoder/char_embedding.py

@@ -1,5 +1,4 @@
import torch
import torch.nn.functional as F
from torch import nn

from fastNLP.modules.utils import initial_parameter
@@ -7,17 +6,17 @@ from fastNLP.modules.utils import initial_parameter

# from torch.nn.init import xavier_uniform
class ConvCharEmbedding(nn.Module):
"""Character-level Embedding with CNN.

:param int char_emb_size: the size of character level embedding. Default: 50
say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
:param tuple feature_maps: tuple of int. The length of the tuple is the number of convolution operations
over characters. The i-th integer is the number of filters (dim of out channels) for the i-th
convolution.
:param tuple kernels: tuple of int. The width of each kernel.
"""

def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5), initial_method=None):
"""
Character Level Word Embedding
:param char_emb_size: the size of character level embedding. Default: 50
say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
:param feature_maps: tuple of int. The length of the tuple is the number of convolution operations
over characters. The i-th integer is the number of filters (dim of out channels) for the i-th
convolution.
:param kernels: tuple of int. The width of each kernel.
"""
super(ConvCharEmbedding, self).__init__()
self.convs = nn.ModuleList([
nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4))
@@ -27,8 +26,8 @@ class ConvCharEmbedding(nn.Module):

def forward(self, x):
"""
:param x: [batch_size * sent_length, word_length, char_emb_size]
:return: [batch_size * sent_length, sum(feature_maps), 1]
:param x: ``[batch_size * sent_length, word_length, char_emb_size]``
:return: feature map of shape [batch_size * sent_length, sum(feature_maps), 1]
"""
x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2))
# [batch_size*sent_length, channel, width, height]
@@ -51,13 +50,12 @@ class ConvCharEmbedding(nn.Module):


class LSTMCharEmbedding(nn.Module):
"""
Character Level Word Embedding with LSTM with a single layer.
:param char_emb_size: int, the size of character level embedding. Default: 50
"""Character-level Embedding with LSTM.
:param int char_emb_size: the size of character level embedding. Default: 50
say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
:param hidden_size: int, the number of hidden units. Default: equal to char_emb_size.
:param int hidden_size: the number of hidden units. Default: equal to char_emb_size.
"""

def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None):
super(LSTMCharEmbedding, self).__init__()
self.hidden_size = char_emb_size if hidden_size is None else hidden_size
@@ -71,7 +69,7 @@ class LSTMCharEmbedding(nn.Module):

def forward(self, x):
"""
:param x:[ n_batch*n_word, word_length, char_emb_size]
:param x: ``[ n_batch*n_word, word_length, char_emb_size]``
:return: [ n_batch*n_word, char_emb_size]
"""
batch_size = x.shape[0]


+17 -7  fastNLP/modules/encoder/conv.py

@@ -3,20 +3,30 @@

import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
# import torch.nn.functional as F

from fastNLP.modules.utils import initial_parameter


# import torch.nn.functional as F


class Conv(nn.Module):
"""
Basic 1-d convolution module.
initialize with xavier_uniform
"""
"""Basic 1-d convolution module, initialized with xavier_uniform.

:param int in_channels:
:param int out_channels:
:param tuple kernel_size:
:param int stride:
:param int padding:
:param int dilation:
:param int groups:
:param bool bias:
:param str activation:
:param str initial_method:
"""
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, dilation=1,
groups=1, bias=True, activation='relu',initial_method = None ):
groups=1, bias=True, activation='relu', initial_method=None):
super(Conv, self).__init__()
self.conv = nn.Conv1d(
in_channels=in_channels,


+15 -5  fastNLP/modules/encoder/conv_maxpool.py

@@ -4,17 +4,27 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_
from fastNLP.modules.utils import initial_parameter


class ConvMaxpool(nn.Module):
"""
Convolution and max-pooling module with multiple kernel sizes.
"""
"""Convolution and max-pooling module with multiple kernel sizes.

:param int in_channels:
:param int out_channels:
:param tuple kernel_sizes:
:param int stride:
:param int padding:
:param int dilation:
:param int groups:
:param bool bias:
:param str activation:
:param str initial_method:
"""
def __init__(self, in_channels, out_channels, kernel_sizes,
stride=1, padding=0, dilation=1,
groups=1, bias=True, activation='relu',initial_method = None ):
groups=1, bias=True, activation="relu", initial_method=None):
super(ConvMaxpool, self).__init__()

# convolution
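
``ConvMaxpool`` combines 1-d convolutions of several kernel sizes with max-pooling over time (the TextCNN pattern). A compact sketch of that idea under the documented parameters; the padding and activation choices here are assumptions, not the module's exact code::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ConvMaxpoolSketch(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_sizes):
            super().__init__()
            self.convs = nn.ModuleList([
                nn.Conv1d(in_channels, out_channels, ks, padding=ks // 2)
                for ks in kernel_sizes
            ])

        def forward(self, x):                          # x: [batch, seq_len, in_channels]
            x = x.transpose(1, 2)                      # Conv1d expects [batch, channels, seq_len]
            feats = [F.relu(conv(x)).max(dim=2)[0] for conv in self.convs]
            return torch.cat(feats, dim=1)             # [batch, out_channels * len(kernel_sizes)]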


+6 -9  fastNLP/modules/encoder/embedding.py

@@ -2,16 +2,13 @@ import torch.nn as nn


class Embedding(nn.Module):
"""
A simple lookup table
Args:
nums : the size of the lookup table
dims : the size of each vector
padding_idx : pads the tensor with zeros whenever it encounters this index
sparse : If True, gradient matrix will be a sparse tensor. In this case,
only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
"""
"""A simple lookup table.

:param int nums: the size of the lookup table
:param int dims: the size of each vector
:param int padding_idx: pads the tensor with zeros whenever it encounters this index
:param bool sparse: If True, gradient matrix will be a sparse tensor. In this case, only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
"""
def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
super(Embedding, self).__init__()
self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)


+5 -8  fastNLP/modules/encoder/linear.py

@@ -5,15 +5,12 @@ from fastNLP.modules.utils import initial_parameter

class Linear(nn.Module):
"""
Linear module
Args:
input_size : input size
hidden_size : hidden size
num_layers : number of hidden layers
dropout : dropout rate
bidirectional : If True, becomes a bidirectional RNN
"""

:param int input_size: input size
:param int output_size: output size
:param bool bias:
:param str initial_method:
"""
def __init__(self, input_size, output_size, bias=True, initial_method=None):
super(Linear, self).__init__()
self.linear = nn.Linear(input_size, output_size, bias)


+9 -7  fastNLP/modules/encoder/lstm.py

@@ -6,14 +6,16 @@ from fastNLP.modules.utils import initial_parameter
class LSTM(nn.Module):
"""Long Short Term Memory

Args:
input_size : input size
hidden_size : hidden size
num_layers : number of hidden layers. Default: 1
dropout : dropout rate. Default: 0.5
bidirectional : If True, becomes a bidirectional RNN. Default: False.
:param int input_size:
:param int hidden_size:
:param int num_layers:
:param float dropout:
:param bool batch_first:
:param bool bidirectional:
:param bool bias:
:param str initial_method:
:param bool get_hidden:
"""

def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
bidirectional=False, bias=True, initial_method=None, get_hidden=False):
super(LSTM, self).__init__()


+59 -57  fastNLP/modules/encoder/masked_rnn.py

@@ -5,6 +5,8 @@ import torch.nn as nn
import torch.nn.functional as F

from fastNLP.modules.utils import initial_parameter


def MaskedRecurrent(reverse=False):
def forward(input, hidden, cell, mask, train=True, dropout=0):
"""
@@ -254,16 +256,16 @@ class MaskedRNNBase(nn.Module):
return output, hidden

def step(self, input, hx=None, mask=None):
'''
execute one step forward (only for one-directional RNN).
Args:
input (batch, input_size): input tensor of this step.
hx (num_layers, batch, hidden_size): the hidden state of last step.
mask (batch): the mask tensor of this step.
Returns:
output (batch, hidden_size): tensor containing the output of this step from the last layer of RNN.
hn (num_layers, batch, hidden_size): tensor containing the hidden state of this step
'''
"""Execute one step forward (only for one-directional RNN).
:param Tensor input: input tensor of this step. (batch, input_size)
:param Tensor hx: the hidden state of last step. (num_layers, batch, hidden_size)
:param Tensor mask: the mask tensor of this step. (batch, )
:returns:
**output** (batch, hidden_size), tensor containing the output of this step from the last layer of RNN.
**hn** (num_layers, batch, hidden_size), tensor containing the hidden state of this step
"""
assert not self.bidirectional, "step only cannot be applied to bidirectional RNN." # aha, typo!
batch_size = input.size(0)
lstm = self.Cell is nn.LSTMCell
@@ -285,25 +287,23 @@ class MaskedRNN(MaskedRNNBase):
r"""Applies a multi-layer Elman RNN with costomized non-linearity to an
input sequence.
For each element in the input sequence, each layer computes the following
function:
.. math::
h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})
function. :math:`h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})`

where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is
the hidden state of the previous layer at time `t` or :math:`input_t`
for the first layer. If nonlinearity='relu', then `ReLU` is used instead
of `tanh`.
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
num_layers: Number of recurrent layers.
nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
bias: If False, then the layer does not use bias weights b_ih and b_hh.
Default: True
batch_first: If True, then the input and output tensors are provided
as (batch, seq, feature)
dropout: If non-zero, introduces a dropout layer on the outputs of each
RNN layer except the last layer
bidirectional: If True, becomes a bidirectional RNN. Default: False


:param int input_size: The number of expected features in the input x
:param int hidden_size: The number of features in the hidden state h
:param int num_layers: Number of recurrent layers.
:param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
:param float dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False

Inputs: input, mask, h_0
- **input** (seq_len, batch, input_size): tensor containing the features
of the input sequence.
@@ -327,32 +327,33 @@ class MaskedLSTM(MaskedRNNBase):
r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
sequence.
For each element in the input sequence, each layer computes the following
function:
function.

.. math::
\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}

\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}

where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
state at time `t`, :math:`x_t` is the hidden state of the previous layer at
time `t` or :math:`input_t` for the first layer, and :math:`i_t`,
:math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell,
and out gates, respectively.
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
num_layers: Number of recurrent layers.
bias: If False, then the layer does not use bias weights b_ih and b_hh.
Default: True
batch_first: If True, then the input and output tensors are provided
as (batch, seq, feature)
dropout: If non-zero, introduces a dropout layer on the outputs of each
RNN layer except the last layer
bidirectional: If True, becomes a bidirectional RNN. Default: False

:param int input_size: The number of expected features in the input x
:param int hidden_size: The number of features in the hidden state h
:param int num_layers: Number of recurrent layers.
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
:param bool dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False

Inputs: input, mask, (h_0, c_0)
- **input** (seq_len, batch, input_size): tensor containing the features
of the input sequence.
@@ -380,29 +381,30 @@ class MaskedGRU(MaskedRNNBase):
r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
For each element in the input sequence, each layer computes the following
function:

.. math::

\begin{array}{ll}
r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\
\end{array}

where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden
state of the previous layer at time `t` or :math:`input_t` for the first
layer, and :math:`r_t`, :math:`z_t`, :math:`n_t` are the reset, input,
and new gates, respectively.
Args:
input_size: The number of expected features in the input x
hidden_size: The number of features in the hidden state h
num_layers: Number of recurrent layers.
nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
bias: If False, then the layer does not use bias weights b_ih and b_hh.
Default: True
batch_first: If True, then the input and output tensors are provided
as (batch, seq, feature)
dropout: If non-zero, introduces a dropout layer on the outputs of each
RNN layer except the last layer
bidirectional: If True, becomes a bidirectional RNN. Default: False

:param int input_size: The number of expected features in the input x
:param int hidden_size: The number of features in the hidden state h
:param int num_layers: Number of recurrent layers.
:param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
:param bool dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False

Inputs: input, mask, h_0
- **input** (seq_len, batch, input_size): tensor containing the features
of the input sequence.


+3 -6  fastNLP/modules/encoder/transformer.py

@@ -1,10 +1,9 @@
import torch
from torch import nn
import torch.nn.functional as F

from ..aggregator.attention import MultiHeadAtte
from ..other_modules import LayerNormalization


class TransformerEncoder(nn.Module):
class SubLayer(nn.Module):
def __init__(self, input_size, output_size, key_size, value_size, num_atte):
@@ -12,8 +11,8 @@ class TransformerEncoder(nn.Module):
self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
self.norm1 = LayerNormalization(output_size)
self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
nn.ReLU(),
nn.Linear(output_size, output_size))
nn.ReLU(),
nn.Linear(output_size, output_size))
self.norm2 = LayerNormalization(output_size)

def forward(self, input, seq_mask):
@@ -28,5 +27,3 @@ class TransformerEncoder(nn.Module):

def forward(self, x, seq_mask=None):
return self.layers(x, seq_mask)



+12 -6  fastNLP/modules/encoder/variational_rnn.py

@@ -1,5 +1,3 @@
import math

import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence
@@ -9,15 +7,17 @@ from fastNLP.modules.utils import initial_parameter
try:
from torch import flip
except ImportError:
def flip(x, dims):
def flip(x, dims):
indices = [slice(None)] * x.dim()
for dim in dims:
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
return x[tuple(indices)]


class VarRnnCellWrapper(nn.Module):
"""Wrapper for normal RNN Cells, make it support variational dropout
"""

def __init__(self, cell, hidden_size, input_p, hidden_p):
super(VarRnnCellWrapper, self).__init__()
self.cell = cell
@@ -32,9 +32,9 @@ class VarRnnCellWrapper(nn.Module):
for other RNN, h_0, [batch_size, hidden_size]
:param mask_x: [batch_size, input_size] dropout mask for input
:param mask_h: [batch_size, hidden_size] dropout mask for hidden
:return output: [seq_len, bacth_size, hidden_size]
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
for other RNN, h_n, [batch_size, hidden_size]
:return: (output, hidden)
**output**: [seq_len, bacth_size, hidden_size].
**hidden**: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]; For other RNN, h_n, [batch_size, hidden_size].
"""
is_lstm = isinstance(hidden, tuple)
input = input * mask_x.unsqueeze(0) if mask_x is not None else input
@@ -56,6 +56,7 @@ class VarRNNBase(nn.Module):
refer to `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016)
https://arxiv.org/abs/1512.05287`.
"""

def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
bias=True, batch_first=False,
input_dropout=0, hidden_dropout=0, bidirectional=False):
@@ -138,17 +139,22 @@ class VarRNNBase(nn.Module):
class VarLSTM(VarRNNBase):
"""Variational Dropout LSTM.
"""

def __init__(self, *args, **kwargs):
super(VarLSTM, self).__init__(mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)


class VarRNN(VarRNNBase):
"""Variational Dropout RNN.
"""

def __init__(self, *args, **kwargs):
super(VarRNN, self).__init__(mode="RNN", Cell=nn.RNNCell, *args, **kwargs)


class VarGRU(VarRNNBase):
"""Variational Dropout GRU.
"""

def __init__(self, *args, **kwargs):
super(VarGRU, self).__init__(mode="GRU", Cell=nn.GRUCell, *args, **kwargs)

+23 -34  fastNLP/modules/other_modules.py

@@ -29,8 +29,11 @@ class GroupNorm(nn.Module):


class LayerNormalization(nn.Module):
""" Layer normalization module """
"""

:param int layer_size:
:param float eps: default=1e-3
"""
def __init__(self, layer_size, eps=1e-3):
super(LayerNormalization, self).__init__()

@@ -52,12 +55,11 @@ class LayerNormalization(nn.Module):
class BiLinear(nn.Module):
def __init__(self, n_left, n_right, n_out, bias=True):
"""
Args:
n_left: size of left input
n_right: size of right input
n_out: size of output
bias: If set to False, the layer will not learn an additive bias.
Default: True

:param int n_left: size of left input
:param int n_right: size of right input
:param int n_out: size of output
:param bool bias: If set to False, the layer will not learn an additive bias. Default: True
"""
super(BiLinear, self).__init__()
self.n_left = n_left
@@ -83,12 +85,9 @@ class BiLinear(nn.Module):

def forward(self, input_left, input_right):
"""
Args:
input_left: Tensor
the left input tensor with shape = [batch1, batch2, ..., left_features]
input_right: Tensor
the right input tensor with shape = [batch1, batch2, ..., right_features]
Returns:
:param Tensor input_left: the left input tensor with shape = [batch1, batch2, ..., left_features]
:param Tensor input_right: the right input tensor with shape = [batch1, batch2, ..., right_features]

"""
left_size = input_left.size()
right_size = input_right.size()
@@ -118,16 +117,11 @@ class BiLinear(nn.Module):
class BiAffine(nn.Module):
def __init__(self, n_enc, n_dec, n_labels, biaffine=True, **kwargs):
"""
Args:
n_enc: int
the dimension of the encoder input.
n_dec: int
the dimension of the decoder input.
n_labels: int
the number of labels of the crf layer
biaffine: bool
if apply bi-affine parameter.
**kwargs:

:param int n_enc: the dimension of the encoder input.
:param int n_dec: the dimension of the decoder input.
:param int n_labels: the number of labels of the crf layer
:param bool biaffine: if apply bi-affine parameter.
"""
super(BiAffine, self).__init__()
self.n_enc = n_enc
@@ -154,17 +148,12 @@ class BiAffine(nn.Module):

def forward(self, input_d, input_e, mask_d=None, mask_e=None):
"""
Args:
input_d: Tensor
the decoder input tensor with shape = [batch, length_decoder, input_size]
input_e: Tensor
the child input tensor with shape = [batch, length_encoder, input_size]
mask_d: Tensor or None
the mask tensor for decoder with shape = [batch, length_decoder]
mask_e: Tensor or None
the mask tensor for encoder with shape = [batch, length_encoder]
Returns: Tensor
the energy tensor with shape = [batch, num_label, length, length]

:param Tensor input_d: the decoder input tensor with shape = [batch, length_decoder, input_size]
:param Tensor input_e: the child input tensor with shape = [batch, length_encoder, input_size]
:param mask_d: Tensor or None, the mask tensor for decoder with shape = [batch, length_decoder]
:param mask_e: Tensor or None, the mask tensor for encoder with shape = [batch, length_encoder]
:returns: Tensor, the energy tensor with shape = [batch, num_label, length, length]
"""
assert input_d.size(0) == input_e.size(0), 'batch sizes of encoder and decoder are requires to be equal.'
batch, length_decoder, _ = input_d.size()


+2 -2  fastNLP/modules/utils.py

@@ -15,7 +15,7 @@ def initial_parameter(net, initial_method=None):
"""A method used to initialize the weights of PyTorch models.

:param net: a PyTorch model
:param initial_method: str, one of the following initializations
:param str initial_method: one of the following initializations.

- xavier_uniform
- xavier_normal (default)
@@ -79,7 +79,7 @@ def seq_mask(seq_len, max_len):

:param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
:param max_len: int, the maximum sequence length in a batch.
:return mask: torch.LongTensor, [batch_size, max_len]
:return: mask, torch.LongTensor, [batch_size, max_len]

"""
if not isinstance(seq_len, torch.Tensor):
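
``seq_mask``, as documented, maps sequence lengths and a maximum length to a ``[batch_size, max_len]`` LongTensor. A one-broadcast sketch of that computation (not necessarily the library's exact code)::

    import torch

    def seq_mask(seq_len, max_len):
        seq_len = torch.as_tensor(seq_len)                  # list or tensor of lengths
        positions = torch.arange(max_len).unsqueeze(0)      # [1, max_len]
        return (positions < seq_len.unsqueeze(1)).long()    # [batch_size, max_len]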

