
Add more comments and modify parts of the code

tags/v0.4.10
yh_cc 6 years ago
commit 06891cf90a
16 changed files with 340 additions and 755 deletions
  1. fastNLP/__init__.py (+2, -0)
  2. fastNLP/core/callback.py (+13, -2)
  3. fastNLP/core/optimizer.py (+3, -0)
  4. fastNLP/core/tester.py (+48, -12)
  5. fastNLP/core/trainer.py (+5, -4)
  6. fastNLP/io/embed_loader.py (+24, -26)
  7. fastNLP/models/cnn_text_classification.py (+34, -13)
  8. fastNLP/models/sequence_modeling.py (+132, -126)
  9. fastNLP/modules/decoder/CRF.py (+32, -34)
  10. fastNLP/modules/decoder/utils.py (+6, -7)
  11. fastNLP/modules/encoder/__init__.py (+0, -2)
  12. fastNLP/modules/encoder/conv.py (+0, -58)
  13. fastNLP/modules/encoder/conv_maxpool.py (+38, -17)
  14. fastNLP/modules/encoder/masked_rnn.py (+0, -424)
  15. test/modules/test_masked_rnn.py (+0, -27)
  16. test/test_tutorials.py (+3, -3)

fastNLP/__init__.py (+2, -0)

@@ -1,3 +1,5 @@
from .core import *
from . import models
from . import modules

__version__ = '0.4.0'

fastNLP/core/callback.py (+13, -2)

@@ -1,4 +1,5 @@
"""
Documentation for Callback

.. _Callback:

@@ -28,7 +29,6 @@ class Callback(object):
def trainer(self):
"""
This attribute can be accessed via self.trainer; in most cases it is not needed.
:return:
"""
return self._trainer

@@ -323,11 +323,16 @@ class GradientClipCallback(Callback):

class CallbackException(BaseException):
def __init__(self, msg):
"""
When a callback needs to break out of training, raise a CallbackException and catch it in on_exception.
:param str msg: the exception message.
"""
super(CallbackException, self).__init__(msg)


class EarlyStopError(CallbackException):
def __init__(self, msg):
"""用于EarlyStop时从Trainer训练循环中跳出。"""
super(EarlyStopError, self).__init__(msg)


@@ -360,7 +365,13 @@ class EarlyStopCallback(Callback):

class LRScheduler(Callback):
def __init__(self, lr_scheduler):
"""对PyTorch LR Scheduler的包装
"""对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用

Example::

from fastNLP import LRScheduler



:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: a PyTorch lr_scheduler
"""


fastNLP/core/optimizer.py (+3, -0)

@@ -13,6 +13,9 @@ class Optimizer(object):
self.model_params = model_params
self.settings = kwargs

def construct_from_pytorch(self, model_params):
raise NotImplementedError

def _get_require_grads_param(self, params):
"""
Remove the entries in params that do not require gradients
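construct_from_pytorch is left abstract above. A minimal sketch of a hypothetical subclass (not the library's own SGD/Adam wrappers), showing how the method is expected to hand back a real torch optimizer using the base-class attributes visible in this hunk:

import torch
from fastNLP.core.optimizer import Optimizer

class MyAdam(Optimizer):  # hypothetical subclass, for illustration only
    def __init__(self, model_params=None, lr=0.001, **kwargs):
        super(MyAdam, self).__init__(model_params, lr=lr, **kwargs)

    def construct_from_pytorch(self, model_params):
        # Build the concrete torch optimizer from whichever parameter list is available.
        params = self.model_params if self.model_params is not None else model_params
        return torch.optim.Adam(self._get_require_grads_param(params), **self.settings)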


fastNLP/core/tester.py (+48, -12)

@@ -14,20 +14,56 @@ from fastNLP.core.utils import _get_device


class Tester(object):
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set.
"""
Tester是在提供数据,模型以及metric的情况下进行性能测试的类

Example::

import numpy as np
import torch
from torch import nn
from fastNLP import Tester
from fastNLP import DataSet
from fastNLP import AccuracyMetric


class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
def forward(self, a):
return {'pred': self.fc(a.unsqueeze(1)).squeeze(1)}

model = Model()

dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2})

dataset.set_input('a')
dataset.set_target('b')

tester = Tester(dataset, model, metrics=AccuracyMetric())
eval_results = tester.test()

The field-mapping rules for metrics here are the same as in Trainer_; please refer to Trainer_ for how to use metrics.


:param DataSet data: a validation/development set
:param torch.nn.modules.module model: a PyTorch model
:param MetricBase metrics: a metric object or a list of metrics (List[MetricBase])
:param int batch_size: batch size for validation
:param str,torch.device,None device: the device to load the model onto. Defaults to None, i.e. the Tester does not manage where the model is computed. Supported
str inputs: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...], meaning, in order, the CPU, the first visible GPU, the first visible GPU,
and the second visible GPU; a torch.device loads the model onto that torch.device.
:param int verbose: the number of steps after which information is printed.

"""

def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1):
"""传入模型,数据以及metric进行验证。

:param DataSet data: 需要测试的数据集
:param torch.nn.module model: 使用的模型
:param MetricBase metrics: 一个Metric或者一个列表的metric对象
:param int batch_size: evaluation时使用的batch_size有多大。
:param str,torch.device,None device: 将模型load到哪个设备。默认为None,即Trainer不对模型的计算位置进行管理。支持
以下的输入str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中,
可见的第二个GPU中; torch.device,将模型装载到torch.device上。
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。

"""

super(Tester, self).__init__()

if not isinstance(data, DataSet):
@@ -59,10 +95,10 @@ class Tester(object):
self._predict_func = self._model.forward

def test(self):
"""Start test or validation.

:return eval_results: a dictionary whose keys are the class name of metrics to use, values are the evaluation results of these metrics.
"""开始进行验证,并返回验证结果。

:return dict(dict) eval_results: dict为二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。
一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。
"""
# turn on the testing mode; clean up the history
network = self._model


fastNLP/core/trainer.py (+5, -4)

@@ -213,7 +213,7 @@ Trainer is used in fastNLP to organize the training of a single task, saving the user from
from torch.optim import SGD
from fastNLP import Trainer
from fastNLP import DataSet
from fastNLP.core.metrics import AccuracyMetric
from fastNLP import AccuracyMetric
import torch

class Model(nn.Module):
@@ -322,7 +322,7 @@ from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import _get_func_signature
from fastNLP.core.utils import _get_device
from fastNLP.core.optimizer import Optimizer

class Trainer(object):
def __init__(self, train_data, model, optimizer, loss=None,
@@ -336,8 +336,7 @@ class Trainer(object):
"""
:param DataSet train_data: the training set
:param nn.modules model: the model to train
:param Optimizer,None optimizer: the optimizer, a pytorch torch.optim.Optimizer. If None, the Trainer will not update the model;
make sure the update is done in a callback.
:param torch.optim.Optimizer,None optimizer: the optimizer. If None, the Trainer will not update the model; make sure the update is done in a callback.
:param int batch_size: the batch size used for training and validation.
:param LossBase loss: the Loss object to use. See LossBase_ . If loss is None, LossInForward_ is used by default.
:param Sampler sampler: the order in which batches are generated. See Sampler_ . If None, RandomSampler_ is used by default.
@@ -438,6 +437,8 @@ class Trainer(object):

if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer
elif isinstance(optimizer, Optimizer):
self.optimizer = optimizer.construct_from_pytorch(model.parameters())
elif optimizer is None:
warnings.warn("The optimizer is set to None, Trainer will update your model. Make sure you update the model"
" in the callback.")


fastNLP/io/embed_loader.py (+24, -26)

@@ -8,7 +8,7 @@ from fastNLP.io.base_loader import BaseLoader
import warnings

class EmbedLoader(BaseLoader):
"""docstring for EmbedLoader"""
"""这个类用于从预训练的Embedding中load数据。"""

def __init__(self):
super(EmbedLoader, self).__init__()
@@ -16,18 +16,17 @@ class EmbedLoader(BaseLoader):
@staticmethod
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'):
"""
load pretraining embedding in {embed_file} based on words in vocab. Words in vocab but not in the pretraining
embedding are initialized from a normal distribution which has the mean and std of the found words vectors.
The embedding type is determined automatically, support glove and word2vec(the first line only has two elements).

:param embed_filepath: str, where to read pretrain embedding
:param vocab: Vocabulary.
:param dtype: the dtype of the embedding matrix
:param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
:param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will
raise
:return: np.ndarray() will have the same [len(vocab), dimension], dimension is determined by the pretrain
embedding
Extract from the pretrained word vectors in embed_filepath the embeddings of the words in the vocabulary vocab. EmbedLoader automatically detects whether
embed_filepath is in word2vec format (the first line has only two elements) or in glove format.

:param str embed_filepath: path of the pretrained embedding.
:param Vocabulary vocab: the vocabulary; embeddings are read for the words that appear in vocab. Words not found in the file are sampled
from a normal distribution fitted to the embeddings of the found words, so that the whole embedding matrix has the same distribution.
:param dtype: dtype of the returned embedding
:param bool normalize: whether to normalize each vector to norm 1
:param str error: 'ignore', 'strict'; with 'ignore' errors are skipped automatically, with 'strict' they are raised. The typical failure
modes are empty lines in the file or vectors with inconsistent dimensions.
:return: numpy.ndarray of shape [len(vocab), dimension], where dimension is determined by the pretrained embedding.
"""
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
if not os.path.exists(embed_filepath):
@@ -76,19 +75,18 @@ class EmbedLoader(BaseLoader):
def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True,
error='ignore'):
"""
load pretraining embedding in {embed_file}. And construct a Vocabulary based on the pretraining embedding.
The embedding type is determined automatically, support glove and word2vec(the first line only has two elements).

:param embed_filepath: str, where to read pretrain embedding
:param dtype: the dtype of the embedding matrix
:param padding: the padding tag for vocabulary.
:param unknown: the unknown tag for vocabulary.
:param normalize: bool, whether to normalize each word vector so that every vector has norm 1.
:param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will
:raise
:return: np.ndarray() is determined by the pretraining embeddings
Vocabulary: contain all pretraining words and two special tag[<pad>, <unk>]

Read pretrained word vectors from embed_filepath. The embedding is read according to the word list in the pretrained file, and a matching Vocabulary is built.

:param str embed_filepath: path of the pretrained embedding.
:param dtype: dtype of the returned embedding
:param str padding: the padding tag for the vocabulary.
:param str unknown: the unknown tag for the vocabulary.
:param bool normalize: whether to normalize each vector to norm 1
:param str error: 'ignore', 'strict'; with 'ignore' errors are skipped automatically, with 'strict' they are raised. The typical failure
modes are empty lines in the file or vectors with inconsistent dimensions.
:return: (numpy.ndarray, Vocabulary). The embedding has shape [vocabulary size + x, embedding dimension]; the "+x" is because the final size also
depends on whether padding is used and on whether unknown found a matching word among the pretrained words. The order of the words in the
Vocabulary corresponds one-to-one to the rows of the embedding.
"""
vocab = Vocabulary(padding=padding, unknown=unknown)
vec_dict = {}
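A hedged usage sketch for the two loaders documented above; the file path is a placeholder, and add_word_lst/build_vocab are assumed Vocabulary helpers:

from fastNLP import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader

vocab = Vocabulary()
vocab.add_word_lst("the quick brown fox".split())  # assumed helper for adding words
vocab.build_vocab()

# Rows follow the order of vocab; words missing from the file are sampled from a fitted normal distribution.
matrix = EmbedLoader.load_with_vocab("glove.6B.50d.txt", vocab, normalize=True)

# Or let the loader build the Vocabulary from the pretrained file itself.
matrix2, vocab2 = EmbedLoader.load_without_vocab("glove.6B.50d.txt", padding='<pad>', unknown='<unk>')
print(matrix.shape, matrix2.shape, len(vocab2))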


fastNLP/models/cnn_text_classification.py (+34, -13)

@@ -3,29 +3,38 @@

import torch
import torch.nn as nn
import numpy as np

# import torch.nn.functional as F
import fastNLP.modules.encoder as encoder


class CNNText(torch.nn.Module):
"""
Text classification model by character CNN, the implementation of paper
'Yoon Kim. 2014. Convolution Neural Networks for Sentence
Classification.'
A model that uses a CNN for text classification, following
'Yoon Kim. 2014. Convolutional Neural Networks for Sentence Classification.'
"""

def __init__(self, embed_num,
def __init__(self, vocab_size,
embed_dim,
num_classes,
kernel_nums=(3, 4, 5),
kernel_sizes=(3, 4, 5),
padding=0,
dropout=0.5):
"""

:param int vocab_size: size of the vocabulary
:param int embed_dim: dimension of the word embeddings
:param int num_classes: number of classes
:param int,tuple(int) kernel_nums: number of output channels per kernel. If a list, its length must match kernel_sizes
:param int,tuple(int) kernel_sizes: kernel size of each set of output channels.
:param int padding:
:param float dropout: dropout probability
"""
super(CNNText, self).__init__()

# no support for pre-trained embedding currently
self.embed = encoder.Embedding(embed_num, embed_dim)
self.embed = encoder.Embedding(vocab_size, embed_dim)
self.conv_pool = encoder.ConvMaxpool(
in_channels=embed_dim,
out_channels=kernel_nums,
@@ -34,24 +43,36 @@ class CNNText(torch.nn.Module):
self.dropout = nn.Dropout(dropout)
self.fc = encoder.Linear(sum(kernel_nums), num_classes)

def forward(self, word_seq):
def init_embed(self, embed):
"""
Load a pretrained embedding matrix.
:param numpy.ndarray embed: an embedding matrix of shape vocab_size x embed_dim
:return:
"""
assert isinstance(embed, np.ndarray)
assert embed.shape == self.embed.embed.weight.shape
self.embed.embed.weight.data = torch.from_numpy(embed)

def forward(self, words, seq_len=None):
"""

:param word_seq: torch.LongTensor, [batch_size, seq_len]
:param torch.LongTensor words: [batch_size, seq_len], word indices of the sentences
:param torch.LongTensor seq_len: [batch,] length of each sentence
:return output: dict of torch.LongTensor, [batch_size, num_classes]
"""
x = self.embed(word_seq) # [N,L] -> [N,L,C]
x = self.embed(words) # [N,L] -> [N,L,C]
x = self.conv_pool(x) # [N,L,C] -> [N,C]
x = self.dropout(x)
x = self.fc(x) # [N,C] -> [N, N_class]
return {'pred': x}

def predict(self, word_seq):
def predict(self, words, seq_len=None):
"""
:param torch.LongTensor words: [batch_size, seq_len], word indices of the sentences
:param torch.LongTensor seq_len: [batch,] length of each sentence

:param word_seq: torch.LongTensor, [batch_size, seq_len]
:return predict: dict of torch.LongTensor, [batch_size, seq_len]
:return predict: dict of torch.LongTensor, [batch_size, ]
"""
output = self(word_seq)
output = self(words, seq_len)
_, predict = output['pred'].max(dim=1)
return {'pred': predict}
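A minimal sketch of the renamed CNNText interface from this hunk, with random tensors standing in for real data:

import numpy as np
import torch
from fastNLP.models import CNNText  # import path as used in test/test_tutorials.py below

model = CNNText(vocab_size=100, embed_dim=50, num_classes=5,
                kernel_nums=(3, 4, 5), kernel_sizes=(3, 4, 5), padding=0, dropout=0.1)

# Optionally load a pretrained embedding of shape vocab_size x embed_dim.
model.init_embed(np.random.rand(100, 50).astype('float32'))

words = torch.randint(0, 100, (4, 20))             # [batch_size, seq_len] word indices
seq_len = torch.full((4,), 20, dtype=torch.long)   # [batch,] sentence lengths
print(model(words, seq_len)['pred'].shape)         # forward: [batch_size, num_classes]
print(model.predict(words, seq_len)['pred'])       # predict: [batch_size,] class indices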

fastNLP/models/sequence_modeling.py (+132, -126)

@@ -8,47 +8,64 @@ from fastNLP.modules.utils import seq_mask

class SeqLabeling(BaseModel):
"""
PyTorch Network for sequence labeling
A basic sequence labeling model
"""

def __init__(self, args):
def __init__(self, vocab_size, embed_dim, hidden_size, num_classes):
"""
A basic class for sequence labeling. The architecture is one Embedding layer, one LSTM (unidirectional, single layer), one fully connected layer, and a CRF.

:param int vocab_size: size of the vocabulary.
:param int embed_dim: dimension of the embeddings
:param int hidden_size: size of the LSTM hidden layer
:param int num_classes: number of classes
"""
super(SeqLabeling, self).__init__()
vocab_size = args["vocab_size"]
word_emb_dim = args["word_emb_dim"]
hidden_dim = args["rnn_hidden_units"]
num_classes = args["num_classes"]

self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim)
self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim)
self.Linear = encoder.linear.Linear(hidden_dim, num_classes)

self.Embedding = encoder.embedding.Embedding(vocab_size, embed_dim)
self.Rnn = encoder.lstm.LSTM(embed_dim, hidden_size)
self.Linear = encoder.linear.Linear(hidden_size, num_classes)
self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
self.mask = None

def forward(self, word_seq, word_seq_origin_len, truth=None):
def forward(self, words, seq_len, target):
"""
:param word_seq: LongTensor, [batch_size, mex_len]
:param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences.
:param truth: LongTensor, [batch_size, max_len]
:param torch.LongTensor words: [batch_size, max_len], indices of the sequence
:param torch.LongTensor seq_len: [batch_size,], length of each sequence
:param torch.LongTensor target: [batch_size, max_len], target values of the sequence
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
If truth is not None, return loss, a scalar. Used in training.
"""
assert word_seq.shape[0] == word_seq_origin_len.shape[0]
if truth is not None:
assert truth.shape == word_seq.shape
self.mask = self.make_mask(word_seq, word_seq_origin_len)
assert words.shape[0] == seq_len.shape[0]
assert target.shape == words.shape
self.mask = self._make_mask(words, seq_len)

x = self.Embedding(word_seq)
x = self.Embedding(words)
# [batch_size, max_len, word_emb_dim]
x = self.Rnn(x)
# [batch_size, max_len, hidden_size * direction]
x = self.Linear(x)
# [batch_size, max_len, num_classes]
return {"loss": self._internal_loss(x, truth) if truth is not None else None,
"predict": self.decode(x)}
return {"loss": self._internal_loss(x, target)}

def loss(self, x, y):
""" Since the loss has been computed in forward(), this function simply returns x."""
return x
def predict(self, words, seq_len):
"""
Used at prediction time.

:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size,]
:return:
"""
self.mask = self._make_mask(words, seq_len)

x = self.Embedding(words)
# [batch_size, max_len, word_emb_dim]
x = self.Rnn(x)
# [batch_size, max_len, hidden_size * direction]
x = self.Linear(x)
# [batch_size, max_len, num_classes]
pred = self._decode(x)
return {'pred': pred}

def _internal_loss(self, x, y):
"""
@@ -65,89 +82,114 @@ class SeqLabeling(BaseModel):
total_loss = self.Crf(x, y, self.mask)
return torch.mean(total_loss)

def make_mask(self, x, seq_len):
def _make_mask(self, x, seq_len):
batch_size, max_len = x.size(0), x.size(1)
mask = seq_mask(seq_len, max_len)
mask = mask.view(batch_size, max_len)
mask = mask.to(x).float()
return mask

def decode(self, x, pad=True):
def _decode(self, x):
"""
:param x: FloatTensor, [batch_size, max_len, tag_size]
:param pad: pad the output sequence to equal lengths
:param torch.FloatTensor x: [batch_size, max_len, tag_size]
:return prediction: list of [decode path(list)]
"""
max_len = x.shape[1]
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask)
# pad prediction to equal length
if pad is True:
for pred in tag_seq:
if len(pred) < max_len:
pred += [0] * (max_len - len(pred))
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True)
return tag_seq


class AdvSeqLabel(SeqLabeling):
class AdvSeqLabel:
"""
Advanced Sequence Labeling Model
A more complex sequence labelling model. The architecture is Embedding, LayerNorm, bidirectional LSTM (two layers), FC, LayerNorm, Dropout, FC, CRF.
"""

def __init__(self, args, emb=None, id2words=None):
super(AdvSeqLabel, self).__init__(args)

vocab_size = args["vocab_size"]
word_emb_dim = args["word_emb_dim"]
hidden_dim = args["rnn_hidden_units"]
num_classes = args["num_classes"]
dropout = args['dropout']
def __init__(self, vocab_size, embed_dim, hidden_size, num_classes, dropout=0.3, embedding=None,
id2words=None, encoding_type='bmes'):
"""

self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
self.norm1 = torch.nn.LayerNorm(word_emb_dim)
# self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True)
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout,
:param int vocab_size: size of the vocabulary
:param int embed_dim: dimension of the embeddings
:param int hidden_size: hidden size of the LSTM
:param int num_classes: number of classes
:param float dropout: dropout probability used inside the LSTM and in the Dropout layer
:param numpy.ndarray embedding: pretrained embedding; it must be consistent with the given vocabulary size etc.
:param dict id2words: a mapping from tag id to tag word, used during CRF decoding to prevent decoding illegal sequences; for example, in the 'BMES'
scheme, 'S' may not follow 'B'. Tags like 'B-NN' are also supported, where the part before '-' indicates the tag type and the part after
it is the concrete label: not only is 'B-NN' prevented from being followed by 'S-NN', it is also prevented from being followed by any
'M-xx' (anything other than 'M-NN' or 'E-NN').
:param str encoding_type: one of "BIO", "BMES", "BEMSO".
"""
self.Embedding = encoder.embedding.Embedding(vocab_size, embed_dim, init_emb=embedding)
self.norm1 = torch.nn.LayerNorm(embed_dim)
self.Rnn = torch.nn.LSTM(input_size=embed_dim, hidden_size=hidden_size, num_layers=2, dropout=dropout,
bidirectional=True, batch_first=True)
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3)
# self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
self.Linear1 = encoder.Linear(hidden_size * 2, hidden_size * 2 // 3)
self.norm2 = torch.nn.LayerNorm(hidden_size * 2 // 3)
self.relu = torch.nn.LeakyReLU()
self.drop = torch.nn.Dropout(dropout)
self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes)
self.Linear2 = encoder.Linear(hidden_size * 2 // 3, num_classes)

if id2words is None:
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
else:
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False,
allowed_transitions=allowed_transitions(id2words,
encoding_type="bmes"))
encoding_type=encoding_type))

def _decode(self, x):
"""
:param torch.FloatTensor x: [batch_size, max_len, tag_size]
:return prediction: list of [decode path(list)]
"""
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True)
return tag_seq

def _internal_loss(self, x, y):
"""
Negative log likelihood loss.
:param x: Tensor, [batch_size, max_len, tag_size]
:param y: Tensor, [batch_size, max_len]
:return loss: a scalar Tensor

"""
x = x.float()
y = y.long()
assert x.shape[:2] == y.shape
assert y.shape == self.mask.shape
total_loss = self.Crf(x, y, self.mask)
return torch.mean(total_loss)

def _make_mask(self, x, seq_len):
batch_size, max_len = x.size(0), x.size(1)
mask = seq_mask(seq_len, max_len)
mask = mask.view(batch_size, max_len)
mask = mask.to(x).float()
return mask

def forward(self, word_seq, word_seq_origin_len, truth=None):
def _forward(self, words, seq_len, target=None):
"""
:param word_seq: LongTensor, [batch_size, mex_len]
:param word_seq_origin_len: LongTensor, [batch_size, ]
:param truth: LongTensor, [batch_size, max_len]
:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size, ]
:param torch.LongTensor target: [batch_size, max_len]
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
If truth is not None, return loss, a scalar. Used in training.
"""

word_seq = word_seq.long()
word_seq_origin_len = word_seq_origin_len.long()
self.mask = self.make_mask(word_seq, word_seq_origin_len)
sent_len, idx_sort = torch.sort(word_seq_origin_len, descending=True)
words = words.long()
seq_len = seq_len.long()
self.mask = self._make_mask(words, seq_len)
sent_len, idx_sort = torch.sort(seq_len, descending=True)
_, idx_unsort = torch.sort(idx_sort, descending=False)

# word_seq_origin_len = word_seq_origin_len.long()
truth = truth.long() if truth is not None else None
# seq_len = seq_len.long()
target = target.long() if target is not None else None

batch_size = word_seq.size(0)
max_len = word_seq.size(1)
if next(self.parameters()).is_cuda:
word_seq = word_seq.cuda()
words = words.cuda()
idx_sort = idx_sort.cuda()
idx_unsort = idx_unsort.cuda()
self.mask = self.mask.cuda()

x = self.Embedding(word_seq)
x = self.Embedding(words)
x = self.norm1(x)
# [batch_size, max_len, word_emb_dim]

@@ -155,71 +197,35 @@ class AdvSeqLabel(SeqLabeling):
sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True)

x, _ = self.Rnn(sent_packed)
# print(x)
# [batch_size, max_len, hidden_size * direction]

sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0]
x = sent_output[idx_unsort]

x = x.contiguous()
# x = x.view(batch_size * max_len, -1)
x = self.Linear1(x)
# x = self.batch_norm(x)
x = self.norm2(x)
x = self.relu(x)
x = self.drop(x)
x = self.Linear2(x)
# x = x.view(batch_size, max_len, -1)
# [batch_size, max_len, num_classes]
# TODO using this key for seq_lens is not reasonable
return {"loss": self._internal_loss(x, truth) if truth is not None else None,
"predict": self.decode(x),
'word_seq_origin_len': word_seq_origin_len}

def predict(self, **x):
out = self.forward(**x)
return {"predict": out["predict"]}

def loss(self, **kwargs):
assert 'loss' in kwargs
return kwargs['loss']


if __name__ == '__main__':
args = {
'vocab_size': 20,
'word_emb_dim': 100,
'rnn_hidden_units': 100,
'num_classes': 10,
}
model = AdvSeqLabel(args)
data = []
for i in range(20):
word_seq = torch.randint(20, (15,)).long()
word_seq_len = torch.LongTensor([15])
truth = torch.randint(10, (15,)).long()
data.append((word_seq, word_seq_len, truth))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
print(model)
curidx = 0
for i in range(1000):
endidx = min(len(data), curidx + 5)
b_word, b_len, b_truth = [], [], []
for word_seq, word_seq_len, truth in data[curidx: endidx]:
b_word.append(word_seq)
b_len.append(word_seq_len)
b_truth.append(truth)
word_seq = torch.stack(b_word, dim=0)
word_seq_len = torch.cat(b_len, dim=0)
truth = torch.stack(b_truth, dim=0)
res = model(word_seq, word_seq_len, truth)
loss = res['loss']
pred = res['predict']
print('loss: {} acc {}'.format(loss.item(),
((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
optimizer.zero_grad()
loss.backward()
optimizer.step()
curidx = endidx
if curidx == len(data):
curidx = 0
if target is not None:
return {"loss": self._internal_loss(x, target)}
else:
return {"pred": self._decode(x)}

def forward(self, words, seq_len, target):
"""
:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size, ]
:param torch.LongTensor target: [batch_size, max_len], the target labels
:return torch.Tensor, a scalar loss
"""
return self._forward(words, seq_len, target)

def predict(self, words, seq_len):
"""

:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size, ]
:return: [list1, list2, ...], each inner list is a decoded path, already unpadded.
"""
return self._forward(words, seq_len, )
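A hedged sketch of the refactored SeqLabeling interface in this file: plain-int constructor, forward returning a loss dict, and predict returning unpadded paths. The tensors below are random placeholders:

import torch
from fastNLP.models.sequence_modeling import SeqLabeling  # module path as in this diff

model = SeqLabeling(vocab_size=100, embed_dim=50, hidden_size=64, num_classes=7)

words = torch.randint(1, 100, (4, 15))      # [batch_size, max_len] word indices
seq_len = torch.tensor([15, 12, 9, 5])      # valid length of each sequence
target = torch.randint(0, 7, (4, 15))       # gold tag ids

loss = model(words, seq_len, target)['loss']   # training: mean CRF negative log-likelihood
loss.backward()

paths = model.predict(words, seq_len)['pred']  # inference: List[List[int]], already unpadded
print(len(paths), [len(p) for p in paths])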

fastNLP/modules/decoder/CRF.py (+32, -34)

@@ -19,10 +19,10 @@ def allowed_transitions(id2label, encoding_type='bio', include_start_end=True):
"""
Given an id-to-label mapping, return the list of all allowed (from_tag_id, to_tag_id) transitions.

:param id2label: Dict, keys are label indices, values are str tags or tag-labels. A value can be a bare tag such as "B" or "M"; it can also be
:param dict id2label: keys are label indices, values are str tags or tag-labels. A value can be a bare tag such as "B" or "M"; it can also be
"B-NN" or "M-NN", where tag and label must be separated by "-". id2label can usually be obtained via Vocabulary.get_id2word().
:param encoding_type: str, one of "bio", "bmes", "bmeso".
:param include_start_end: bool, whether to include transitions involving the start and end tags. For example, in bio, b/o may appear first but i may not;
:param str encoding_type: one of "bio", "bmes", "bmeso".
:param bool include_start_end: whether to include transitions involving the start and end tags. For example, in bio, b/o may appear first but i may not;
if True, the result includes (start_idx, b_idx) and (start_idx, o_idx) but not (start_idx, i_idx);
start_idx=len(id2label), end_idx=len(id2label)+1.
If False, the result contains nothing related to the start and end tags
@@ -62,11 +62,11 @@ def allowed_transitions(id2label, encoding_type='bio', include_start_end=True):
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
"""

:param encoding_type: str, one of "BIO", "BMES", "BEMSO".
:param from_tag: str, a tag such as "B" or "M", including the two special tags start and end
:param from_label: str, a label such as "PER" or "LOC"
:param to_tag: str, a tag such as "B" or "M", including the two special tags start and end
:param to_label: str, a label such as "PER" or "LOC"
:param str encoding_type: one of "BIO", "BMES", "BEMSO".
:param str from_tag: a tag such as "B" or "M", including the two special tags start and end
:param str from_label: a label such as "PER" or "LOC"
:param str to_tag: a tag such as "B" or "M", including the two special tags start and end
:param str to_label: a label such as "PER" or "LOC"
:return: bool, whether the transition is allowed
"""
if to_tag=='start' or from_tag=='end':
@@ -149,12 +149,12 @@ class ConditionalRandomField(nn.Module):
"""条件随机场。
提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。

:param num_tags: int, 标签的数量
:param include_start_end_trans: bool, 是否考虑各个tag作为开始以及结尾的分数。
:param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]], 内部的Tuple[from_tag_id(int),
:param int num_tags: 标签的数量
:param bool include_start_end_trans: 是否考虑各个tag作为开始以及结尾的分数。
:param List[Tuple[from_tag_id(int), to_tag_id(int)]] allowed_transitions: 内部的Tuple[from_tag_id(int),
to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过
allowed_transitions()函数得到;如果为None,则所有跃迁均为合法
:param initial_method: str, 初始化方法。见initial_parameter
:param str initial_method: 初始化方法。见initial_parameter
"""
super(ConditionalRandomField, self).__init__()

@@ -237,10 +237,10 @@ class ConditionalRandomField(nn.Module):
"""
Compute the forward loss of the CRF; the return value is a FloatTensor of size batch_size, so mean() may be needed to get the final loss.

:param feats:FloatTensor, batch_size x max_len x num_tags, the feature matrix.
:param tags:LongTensor, batch_size x max_len, the tag matrix.
:param mask:ByteTensor batch_size x max_len, positions that are 0 are treated as padding.
:return:FloatTensor, batch_size
:param torch.FloatTensor feats: batch_size x max_len x num_tags, the feature matrix.
:param torch.LongTensor tags: batch_size x max_len, the tag matrix.
:param torch.ByteTensor mask: batch_size x max_len, positions that are 0 are treated as padding.
:return: torch.FloatTensor, (batch_size,)
"""
feats = feats.transpose(0, 1)
tags = tags.transpose(0, 1).long()
@@ -250,27 +250,26 @@ class ConditionalRandomField(nn.Module):

return all_path_score - gold_path_score

def viterbi_decode(self, feats, mask, unpad=False):
def viterbi_decode(self, logits, mask, unpad=False):
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数

:param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。
:param unpad: bool, 是否将结果删去padding,
False, 返回的是batch_size x max_len的tensor,
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]
的长度是这个sample的有效长度。
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这
个sample的有效长度。
:return: 返回 (paths, scores)。
paths: 是解码后的路径, 其值参照unpad参数.
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。

"""
batch_size, seq_len, n_tags = feats.size()
feats = feats.transpose(0, 1).data # L, B, H
batch_size, seq_len, n_tags = logits.size()
logits = logits.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.byte() # L, B

# dp
vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = feats[0]
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = logits[0]
transitions = self._constrain.data.clone()
transitions[:n_tags, :n_tags] += self.trans_m.data
if self.include_start_end_trans:
@@ -281,7 +280,7 @@ class ConditionalRandomField(nn.Module):
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
for i in range(1, seq_len):
prev_score = vscore.view(batch_size, n_tags, 1)
cur_score = feats[i].view(batch_size, 1, n_tags)
cur_score = logits[i].view(batch_size, 1, n_tags)
score = prev_score + trans_score + cur_score
best_score, best_dst = score.max(1)
vpath[i] = best_dst
@@ -292,13 +291,13 @@ class ConditionalRandomField(nn.Module):
vscore += transitions[:n_tags, n_tags+1].view(1, -1)

# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device)
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
lens = (mask.long().sum(0) - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len

ans = feats.new_empty((seq_len, batch_size), dtype=torch.long)
ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
for i in range(seq_len - 1):
@@ -311,6 +310,5 @@ class ConditionalRandomField(nn.Module):
paths.append(ans[idx, :seq_len+1].tolist())
else:
paths = ans
if get_score:
return paths, ans_score.tolist()
return paths
return paths, ans_score
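A hedged end-to-end sketch of the decoder API touched here: allowed_transitions builds the constraint list, the CRF forward returns per-sample losses, and viterbi_decode (whose argument was renamed from feats to logits in this hunk) returns (paths, scores). The tag set and tensors are illustrative:

import torch
from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions

id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}   # a tiny BMES tag set
crf = ConditionalRandomField(num_tags=len(id2label), include_start_end_trans=False,
                             allowed_transitions=allowed_transitions(id2label, encoding_type='bmes'))

logits = torch.randn(2, 6, len(id2label))     # batch_size x max_len x num_tags
tags = torch.randint(0, len(id2label), (2, 6))
mask = torch.tensor([[1, 1, 1, 1, 1, 0],
                     [1, 1, 1, 0, 0, 0]], dtype=torch.uint8)

loss = crf(logits, tags, mask).mean()                         # forward gives a (batch_size,) loss vector
paths, scores = crf.viterbi_decode(logits, mask, unpad=True)  # constrained Viterbi decoding
print(loss.item(), paths, scores)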


fastNLP/modules/decoder/utils.py (+6, -7)

@@ -11,13 +11,12 @@ def log_sum_exp(x, dim=-1):
def viterbi_decode(logits, transitions, mask=None, unpad=False):
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数

:param logits: FloatTensor, batch_size x max_len x num_tags,特征矩阵。
:param transitions: FloatTensor, n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。
:param unpad: bool, 是否将结果删去padding,
False, 返回的是batch_size x max_len的tensor,
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是
这个sample的有效长度。
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。
:param torch.FloatTensor transitions: n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这
个sample的有效长度。
:return: 返回 (paths, scores)。
paths: 是解码后的路径, 其值参照unpad参数.
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。
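The standalone viterbi_decode above has the same contract but takes an explicit transition matrix. A short sketch, assuming the function is importable from fastNLP.modules.decoder.utils:

import torch
from fastNLP.modules.decoder.utils import viterbi_decode

n_tags = 4
logits = torch.randn(2, 6, n_tags)           # batch_size x max_len x num_tags
transitions = torch.randn(n_tags, n_tags)    # [i, j] = score of moving from tag i to tag j
mask = torch.tensor([[1, 1, 1, 1, 1, 0],
                     [1, 1, 1, 0, 0, 0]], dtype=torch.uint8)

paths, scores = viterbi_decode(logits, transitions, mask=mask, unpad=True)
print(paths)    # List[List[int]]: one unpadded label path per sample
print(scores)   # scores of the best paths, one per sample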


fastNLP/modules/encoder/__init__.py (+0, -2)

@@ -1,4 +1,3 @@
from .conv import Conv
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .linear import Linear
@@ -8,6 +7,5 @@ from .bert import BertModel
__all__ = ["LSTM",
"Embedding",
"Linear",
"Conv",
"ConvMaxpool",
"BertModel"]

fastNLP/modules/encoder/conv.py (+0, -58)

@@ -1,58 +0,0 @@
# python: 3.6
# encoding: utf-8

import torch
import torch.nn as nn

from fastNLP.modules.utils import initial_parameter


# import torch.nn.functional as F


class Conv(nn.Module):
"""Basic 1-d convolution module, initialized with xavier_uniform.

:param int in_channels:
:param int out_channels:
:param tuple kernel_size:
:param int stride:
:param int padding:
:param int dilation:
:param int groups:
:param bool bias:
:param str activation:
:param str initial_method:
"""
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, dilation=1,
groups=1, bias=True, activation='relu', initial_method=None):
super(Conv, self).__init__()
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
# xavier_uniform_(self.conv.weight)

activations = {
'relu': nn.ReLU(),
'tanh': nn.Tanh()}
if activation in activations:
self.activation = activations[activation]
else:
raise Exception(
'Should choose activation function from: ' +
', '.join([x for x in activations]))
initial_parameter(self, initial_method)

def forward(self, x):
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L]
x = self.conv(x) # [N,C_in,L] -> [N,C_out,L]
x = self.activation(x)
x = torch.transpose(x, 1, 2) # [N,C,L] -> [N,L,C]
return x

fastNLP/modules/encoder/conv_maxpool.py (+38, -17)

@@ -9,18 +9,21 @@ from fastNLP.modules.utils import initial_parameter


class ConvMaxpool(nn.Module):
"""Convolution and max-pooling module with multiple kernel sizes.
"""集合了Convolution和Max-Pooling于一体的层。
给定一个batch_size x max_len x input_size的输入,返回batch_size x sum(output_channels) 大小的matrix。在内部,是先使用
CNN给输入做卷积,然后经过activation激活层,在通过在长度(max_len)这一维进行max_pooling。最后得到每个sample的一个vector
表示。

:param int in_channels:
:param int out_channels:
:param tuple kernel_sizes:
:param int stride:
:param int padding:
:param int dilation:
:param int groups:
:param bool bias:
:param str activation:
:param str initial_method:
:param int in_channels: size of the input channels, usually the embedding dimension, or the output dimension of an encoder
:param int,tuple(int) out_channels: number of output channels. If a list, its length must match kernel_sizes
:param int,tuple(int) kernel_sizes: kernel size of each set of output channels.
:param int stride: see the PyTorch Conv1d documentation. All kernels share one stride.
:param int padding: see the PyTorch Conv1d documentation. All kernels share one padding.
:param int dilation: see the PyTorch Conv1d documentation. All kernels share one dilation.
:param int groups: see the PyTorch Conv1d documentation. All kernels share one groups value.
:param bool bias: see the PyTorch Conv1d documentation. All kernels share one bias.
:param str activation: the convolution output is passed through this activation before max-pooling. Supports relu, sigmoid, tanh
:param str initial_method: initialization method. See initial_parameter.
"""
def __init__(self, in_channels, out_channels, kernel_sizes,
stride=1, padding=0, dilation=1,
@@ -29,9 +32,14 @@ class ConvMaxpool(nn.Module):

# convolution
if isinstance(kernel_sizes, (list, tuple, int)):
if isinstance(kernel_sizes, int):
if isinstance(kernel_sizes, int) and isinstance(out_channels, int):
out_channels = [out_channels]
kernel_sizes = [kernel_sizes]
elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)):
assert len(out_channels)==len(kernel_sizes), "The number of out_channels should be equal to the number" \
" of kernel_sizes."
else:
raise ValueError("The type of out_channels and kernel_sizes should be the same.")

self.convs = nn.ModuleList([nn.Conv1d(
in_channels=in_channels,
@@ -51,18 +59,31 @@ class ConvMaxpool(nn.Module):
# activation function
if activation == 'relu':
self.activation = F.relu
elif activation == 'sigmoid':
self.activation = F.sigmoid
elif activation == 'tanh':
self.activation = F.tanh
else:
raise Exception(
"Undefined activation function: choose from: relu")
"Undefined activation function: choose from: relu, tanh, sigmoid")

initial_parameter(self, initial_method)

def forward(self, x):
def forward(self, x, mask=None):
"""

:param torch.FloatTensor x: batch_size x max_len x input_size, usually the output of the embedding layer
:param mask: batch_size x max_len, 0 at padded positions. Padding does not affect the convolution, and max-pooling will never pool a padded (0) position
:return:
"""
# [N,L,C] -> [N,C,L]
x = torch.transpose(x, 1, 2)
# convolution
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L]]
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...]
if mask is not None:
mask = mask.unsqueeze(1) # B x 1 x L
xs = [x.masked_fill_(mask.eq(0), float('-inf')) for x in xs]  # fill padded positions (mask==0) with -inf so max-pool ignores them
# max-pooling
xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2)
for i in xs] # [[N, C]]
return torch.cat(xs, dim=-1) # [N,C]
for i in xs] # [[N, C], ...]
return torch.cat(xs, dim=-1) # [N, C]
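A minimal sketch of the ConvMaxpool layer documented above; the kernel counts and sizes are arbitrary illustration values:

import torch
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool

# Three kernel widths with 30+40+50 output channels in total.
layer = ConvMaxpool(in_channels=50, out_channels=(30, 40, 50), kernel_sizes=(1, 3, 5))

x = torch.randn(4, 20, 50)   # batch_size x max_len x input_size (e.g. embedded words)
out = layer(x)               # convolve, apply the activation, then max-pool over max_len
print(out.shape)             # torch.Size([4, 120]) == batch_size x sum(out_channels)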

fastNLP/modules/encoder/masked_rnn.py (+0, -424)

@@ -1,424 +0,0 @@
__author__ = 'max'

import torch
import torch.nn as nn
import torch.nn.functional as F

from fastNLP.modules.utils import initial_parameter


def MaskedRecurrent(reverse=False):
def forward(input, hidden, cell, mask, train=True, dropout=0):
"""
:param input:
:param hidden:
:param cell:
:param mask:
:param dropout: dropout between steps; masked positions are also dropped, which should be fine since they have no gradient anyway
:param train: controls the dropout behavior; called inside StackedRNN's forward
:return:
"""
output = []
steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
for i in steps:
if mask is None or mask[i].data.min() > 0.5:  # no mask, or all ones
hidden = cell(input[i], hidden)
elif mask[i].data.max() > 0.5:  # there is a mask, but it is not all zeros
hidden_next = cell(input[i], hidden)  # feed one batch at a time!
# hack to handle LSTM
if isinstance(hidden, tuple): # LSTM outputs a tuple of (hidden, cell), this is a common hack 😁
mask = mask.float()
hx, cx = hidden
hp1, cp1 = hidden_next
hidden = (
hx + (hp1 - hx) * mask[i].squeeze(),
cx + (cp1 - cx) * mask[i].squeeze())  # Why? Got it!! If it is masked there is no need to change it
else:
hidden = hidden + (hidden_next - hidden) * mask[i]

# if dropout != 0 and train: # warning, should i treat masked tensor differently?
# if isinstance(hidden, tuple):
# hidden = (F.dropout(hidden[0], p=dropout, training=train),
# F.dropout(hidden[1], p=dropout, training=train))
# else:
# hidden = F.dropout(hidden, p=dropout, training=train)

# hack to handle LSTM
output.append(hidden[0] if isinstance(hidden, tuple) else hidden)

if reverse:
output.reverse()
output = torch.cat(output, 0).view(input.size(0), *output[0].size())

return hidden, output

return forward


def StackedRNN(inners, num_layers, lstm=False, train=True, step_dropout=0, layer_dropout=0):
num_directions = len(inners) # rec_factory!
total_layers = num_layers * num_directions

def forward(input, hidden, cells, mask):
assert (len(cells) == total_layers)
next_hidden = []

if lstm:
hidden = list(zip(*hidden))

for i in range(num_layers):
all_output = []
for j, inner in enumerate(inners):
l = i * num_directions + j
hy, output = inner(input, hidden[l], cells[l], mask, step_dropout, train)
next_hidden.append(hy)
all_output.append(output)

input = torch.cat(all_output, input.dim() - 1)  # input to the next layer

if layer_dropout != 0 and i < num_layers - 1:
input = F.dropout(input, p=layer_dropout, training=train, inplace=False)

if lstm:
next_h, next_c = zip(*next_hidden)
next_hidden = (
torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
torch.cat(next_c, 0).view(total_layers, *next_c[0].size())
)
else:
next_hidden = torch.cat(next_hidden, 0).view(total_layers, *next_hidden[0].size())

return next_hidden, input

return forward


def AutogradMaskedRNN(num_layers=1, batch_first=False, train=True, layer_dropout=0, step_dropout=0,
bidirectional=False, lstm=False):
rec_factory = MaskedRecurrent

if bidirectional:
layer = (rec_factory(), rec_factory(reverse=True))
else:
layer = (rec_factory(),)  # rec_factory is the per-layer structure! Each layer's computation happens in MaskedRecurrent, then the layers are chained with StackedRNN

func = StackedRNN(layer,
num_layers,
lstm=lstm,
layer_dropout=layer_dropout, step_dropout=step_dropout,
train=train)

def forward(input, cells, hidden, mask):
if batch_first:
input = input.transpose(0, 1)
if mask is not None:
mask = mask.transpose(0, 1)

nexth, output = func(input, hidden, cells, mask)

if batch_first:
output = output.transpose(0, 1)

return output, nexth

return forward


def MaskedStep():
def forward(input, hidden, cell, mask):
if mask is None or mask.data.min() > 0.5:
hidden = cell(input, hidden)
elif mask.data.max() > 0.5:
hidden_next = cell(input, hidden)
# hack to handle LSTM
if isinstance(hidden, tuple):
hx, cx = hidden
hp1, cp1 = hidden_next
hidden = (hx + (hp1 - hx) * mask, cx + (cp1 - cx) * mask)
else:
hidden = hidden + (hidden_next - hidden) * mask
# hack to handle LSTM
output = hidden[0] if isinstance(hidden, tuple) else hidden

return hidden, output

return forward


def StackedStep(layer, num_layers, lstm=False, dropout=0, train=True):
def forward(input, hidden, cells, mask):
assert (len(cells) == num_layers)
next_hidden = []

if lstm:
hidden = list(zip(*hidden))

for l in range(num_layers):
hy, output = layer(input, hidden[l], cells[l], mask)
next_hidden.append(hy)
input = output

if dropout != 0 and l < num_layers - 1:
input = F.dropout(input, p=dropout, training=train, inplace=False)

if lstm:
next_h, next_c = zip(*next_hidden)
next_hidden = (
torch.cat(next_h, 0).view(num_layers, *next_h[0].size()),
torch.cat(next_c, 0).view(num_layers, *next_c[0].size())
)
else:
next_hidden = torch.cat(next_hidden, 0).view(num_layers, *next_hidden[0].size())

return next_hidden, input

return forward


def AutogradMaskedStep(num_layers=1, dropout=0, train=True, lstm=False):
layer = MaskedStep()

func = StackedStep(layer,
num_layers,
lstm=lstm,
dropout=dropout,
train=train)

def forward(input, cells, hidden, mask):
nexth, output = func(input, hidden, cells, mask)
return output, nexth

return forward


class MaskedRNNBase(nn.Module):
def __init__(self, Cell, input_size, hidden_size,
num_layers=1, bias=True, batch_first=False,
layer_dropout=0, step_dropout=0, bidirectional=False, initial_method = None , **kwargs):
"""
:param Cell:
:param input_size:
:param hidden_size:
:param num_layers:
:param bias:
:param batch_first:
:param layer_dropout:
:param step_dropout:
:param bidirectional:
:param kwargs:
"""

super(MaskedRNNBase, self).__init__()
self.Cell = Cell
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.layer_dropout = layer_dropout
self.step_dropout = step_dropout
self.bidirectional = bidirectional
num_directions = 2 if bidirectional else 1

self.all_cells = []
for layer in range(num_layers):  # initialize all cells
for direction in range(num_directions):
layer_input_size = input_size if layer == 0 else hidden_size * num_directions

cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs)
self.all_cells.append(cell)
self.add_module('cell%d' % (layer * num_directions + direction), cell)  # Max's code is really nicely written
initial_parameter(self, initial_method)
def reset_parameters(self):
for cell in self.all_cells:
cell.reset_parameters()

def forward(self, input, mask=None, hx=None):
batch_size = input.size(0) if self.batch_first else input.size(1)
lstm = self.Cell is nn.LSTMCell
if hx is None:
num_directions = 2 if self.bidirectional else 1
hx = torch.autograd.Variable(
input.data.new(self.num_layers * num_directions, batch_size, self.hidden_size).zero_())
if lstm:
hx = (hx, hx)

func = AutogradMaskedRNN(num_layers=self.num_layers,
batch_first=self.batch_first,
step_dropout=self.step_dropout,
layer_dropout=self.layer_dropout,
train=self.training,
bidirectional=self.bidirectional,
lstm=lstm)  # pass in all_cells and keep delegating down to the lower-level wrappers

output, hidden = func(input, self.all_cells, hx,
None if mask is None else mask.view(mask.size() + (1,)))  # what does this + (1,) do?
return output, hidden

def step(self, input, hx=None, mask=None):
"""Execute one step forward (only for one-directional RNN).

:param Tensor input: input tensor of this step. (batch, input_size)
:param Tensor hx: the hidden state of last step. (num_layers, batch, hidden_size)
:param Tensor mask: the mask tensor of this step. (batch, )
:returns:
**output** (batch, hidden_size), tensor containing the output of this step from the last layer of RNN.
**hn** (num_layers, batch, hidden_size), tensor containing the hidden state of this step

"""
assert not self.bidirectional, "step only cannot be applied to bidirectional RNN." # aha, typo!
batch_size = input.size(0)
lstm = self.Cell is nn.LSTMCell
if hx is None:
hx = torch.autograd.Variable(input.data.new(self.num_layers, batch_size, self.hidden_size).zero_())
if lstm:
hx = (hx, hx)

func = AutogradMaskedStep(num_layers=self.num_layers,
dropout=self.step_dropout,
train=self.training,
lstm=lstm)

output, hidden = func(input, self.all_cells, hx, mask)
return output, hidden


class MaskedRNN(MaskedRNNBase):
r"""Applies a multi-layer Elman RNN with costomized non-linearity to an
input sequence.
For each element in the input sequence, each layer computes the following
function. :math:`h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})`

where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is
the hidden state of the previous layer at time `t` or :math:`input_t`
for the first layer. If nonlinearity='relu', then `ReLU` is used instead
of `tanh`.


:param int input_size: The number of expected features in the input x
:param int hidden_size: The number of features in the hidden state h
:param int num_layers: Number of recurrent layers.
:param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
:param float dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False

Inputs: input, mask, h_0
- **input** (seq_len, batch, input_size): tensor containing the features
of the input sequence.
**mask** (seq_len, batch): 0-1 tensor containing the mask of the input sequence.
- **h_0** (num_layers * num_directions, batch, hidden_size): tensor
containing the initial hidden state for each element in the batch.
Outputs: output, h_n
- **output** (seq_len, batch, hidden_size * num_directions): tensor
containing the output features (h_k) from the last layer of the RNN,
for each k. If a :class:`torch.nn.utils.rnn.PackedSequence` has
been given as the input, the output will also be a packed sequence.
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor
containing the hidden state for k=seq_len.
"""

def __init__(self, *args, **kwargs):
super(MaskedRNN, self).__init__(nn.RNNCell, *args, **kwargs)


class MaskedLSTM(MaskedRNNBase):
r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
sequence.
For each element in the input sequence, each layer computes the following
function.

.. math::

\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}

where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
state at time `t`, :math:`x_t` is the hidden state of the previous layer at
time `t` or :math:`input_t` for the first layer, and :math:`i_t`,
:math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell,
and out gates, respectively.

:param int input_size: The number of expected features in the input x
:param int hidden_size: The number of features in the hidden state h
:param int num_layers: Number of recurrent layers.
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
:param bool dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False

Inputs: input, mask, (h_0, c_0)
- **input** (seq_len, batch, input_size): tensor containing the features
of the input sequence.
**mask** (seq_len, batch): 0-1 tensor containing the mask of the input sequence.
- **h_0** (num_layers \* num_directions, batch, hidden_size): tensor
containing the initial hidden state for each element in the batch.
- **c_0** (num_layers \* num_directions, batch, hidden_size): tensor
containing the initial cell state for each element in the batch.
Outputs: output, (h_n, c_n)
- **output** (seq_len, batch, hidden_size * num_directions): tensor
containing the output features `(h_t)` from the last layer of the RNN,
for each t. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
given as the input, the output will also be a packed sequence.
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor
containing the hidden state for t=seq_len
- **c_n** (num_layers * num_directions, batch, hidden_size): tensor
containing the cell state for t=seq_len
"""

def __init__(self, *args, **kwargs):
super(MaskedLSTM, self).__init__(nn.LSTMCell, *args, **kwargs)


class MaskedGRU(MaskedRNNBase):
r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
For each element in the input sequence, each layer computes the following
function:

.. math::

\begin{array}{ll}
r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\
\end{array}

where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden
state of the previous layer at time `t` or :math:`input_t` for the first
layer, and :math:`r_t`, :math:`z_t`, :math:`n_t` are the reset, input,
and new gates, respectively.

:param int input_size: The number of expected features in the input x
:param int hidden_size: The number of features in the hidden state h
:param int num_layers: Number of recurrent layers.
:param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh'
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
:param bool dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False

Inputs: input, mask, h_0
- **input** (seq_len, batch, input_size): tensor containing the features
of the input sequence.
**mask** (seq_len, batch): 0-1 tensor containing the mask of the input sequence.
- **h_0** (num_layers * num_directions, batch, hidden_size): tensor
containing the initial hidden state for each element in the batch.
Outputs: output, h_n
- **output** (seq_len, batch, hidden_size * num_directions): tensor
containing the output features (h_k) from the last layer of the RNN,
for each k. If a :class:`torch.nn.utils.rnn.PackedSequence` has
been given as the input, the output will also be a packed sequence.
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor
containing the hidden state for k=seq_len.
"""

def __init__(self, *args, **kwargs):
super(MaskedGRU, self).__init__(nn.GRUCell, *args, **kwargs)

test/modules/test_masked_rnn.py (+0, -27)

@@ -1,27 +0,0 @@

import torch
import unittest

from fastNLP.modules.encoder.masked_rnn import MaskedRNN

class TestMaskedRnn(unittest.TestCase):
def test_case_1(self):
masked_rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
x = torch.tensor([[[1.0], [2.0]]])
print(x.size())
y = masked_rnn(x)
mask = torch.tensor([[[1], [1]]])
y = masked_rnn(x, mask=mask)
mask = torch.tensor([[[1], [0]]])
y = masked_rnn(x, mask=mask)

def test_case_2(self):
masked_rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=False, batch_first=True)
x = torch.tensor([[[1.0], [2.0]]])
print(x.size())
y = masked_rnn(x)
mask = torch.tensor([[[1], [1]]])
y = masked_rnn(x, mask=mask)
xx = torch.tensor([[[1.0]]])
y = masked_rnn.step(xx)
y = masked_rnn.step(xx, mask=mask)

test/test_tutorials.py (+3, -3)

@@ -70,7 +70,7 @@ class TestTutorial(unittest.TestCase):
break

from fastNLP.models import CNNText
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)
model = CNNText(vocab_size=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)

from fastNLP import Trainer
from copy import deepcopy
@@ -145,7 +145,7 @@ class TestTutorial(unittest.TestCase):
is_input=True)

from fastNLP.models import CNNText
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)
model = CNNText(vocab_size=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)

from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
trainer = Trainer(model=model,
@@ -405,7 +405,7 @@ class TestTutorial(unittest.TestCase):

# Another example: load the CNN text classification model
from fastNLP.models import CNNText
cnn_text_model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)
cnn_text_model = CNNText(vocab_size=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)
cnn_text_model

from fastNLP import CrossEntropyLoss

