@@ -1,3 +1,5 @@ | |||
from .core import * | |||
from . import models | |||
from . import modules | |||
__version__ = '0.4.0' |
@@ -1,4 +1,5 @@ | |||
""" | |||
Documentation for Callback
.. _Callback: | |||
@@ -28,7 +29,6 @@ class Callback(object): | |||
def trainer(self): | |||
""" | |||
This attribute can be accessed as self.trainer; normally you will not need it.
:return: | |||
""" | |||
return self._trainer | |||
@@ -323,11 +323,16 @@ class GradientClipCallback(Callback): | |||
class CallbackException(BaseException): | |||
def __init__(self, msg): | |||
""" | |||
When you need to break out of training from a callback, raise a CallbackException and catch it in on_exception.
:param str msg: the message of the exception.
""" | |||
super(CallbackException, self).__init__(msg) | |||
class EarlyStopError(CallbackException): | |||
def __init__(self, msg): | |||
"""用于EarlyStop时从Trainer训练循环中跳出。""" | |||
super(EarlyStopError, self).__init__(msg) | |||
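A minimal sketch (not from the patch) of raising EarlyStopError from a custom callback; the on_epoch_end hook name and the self.epoch attribute are assumptions here:

    from fastNLP import Callback, EarlyStopError  # top-level exports assumed

    class StopAfterThreeEpochs(Callback):
        def on_epoch_end(self):          # hook name assumed
            if self.epoch >= 3:          # current-epoch attribute assumed
                raise EarlyStopError("stop training after three epochs")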
@@ -360,7 +365,13 @@ class EarlyStopCallback(Callback): | |||
class LRScheduler(Callback): | |||
def __init__(self, lr_scheduler): | |||
"""对PyTorch LR Scheduler的包装 | |||
"""对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用 | |||
Example:: | |||
from fastNLP import LRScheduler | |||
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler | |||
""" | |||
@@ -13,6 +13,9 @@ class Optimizer(object): | |||
self.model_params = model_params | |||
self.settings = kwargs | |||
def construct_from_pytorch(self, model_params): | |||
raise NotImplementedError | |||
def _get_require_grads_param(self, params): | |||
""" | |||
Remove the entries in params that do not require gradients
@@ -14,20 +14,56 @@ from fastNLP.core.utils import _get_device | |||
class Tester(object): | |||
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. | |||
""" | |||
Tester evaluates a model's performance given data, a model, and metrics
Example:: | |||
import numpy as np | |||
import torch | |||
from torch import nn | |||
from fastNLP import Tester | |||
from fastNLP import DataSet | |||
from fastNLP import AccuracyMetric | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(1, 1) | |||
def forward(self, a): | |||
return {'pred': self.fc(a.unsqueeze(1)).squeeze(1)} | |||
model = Model() | |||
dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2}) | |||
dataset.set_input('a') | |||
dataset.set_target('b') | |||
tester = Tester(dataset, model, metrics=AccuracyMetric()) | |||
eval_results = tester.test() | |||
The metric argument mapping rules here are the same as in Trainer_ ; please refer to Trainer_ for how to use metrics.
:param DataSet data: a validation/development set | |||
:param torch.nn.modules.module model: a PyTorch model | |||
:param MetricBase metrics: a metric object or a list of metrics (List[MetricBase]) | |||
:param int batch_size: batch size for validation | |||
:param str,torch.device,None device: the device to load the model onto. Defaults to None, i.e. the Tester does not manage where the
model computation happens. The following str inputs are supported: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...], meaning the CPU,
the first visible GPU, the first visible GPU, and the second visible GPU respectively; a torch.device loads the model onto that torch.device.
:param int verbose: the number of steps after which information is printed.
""" | |||
def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): | |||
"""传入模型,数据以及metric进行验证。 | |||
:param DataSet data: 需要测试的数据集 | |||
:param torch.nn.module model: 使用的模型 | |||
:param MetricBase metrics: 一个Metric或者一个列表的metric对象 | |||
:param int batch_size: evaluation时使用的batch_size有多大。 | |||
:param str,torch.device,None device: 将模型load到哪个设备。默认为None,即Trainer不对模型的计算位置进行管理。支持 | |||
以下的输入str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, | |||
可见的第二个GPU中; torch.device,将模型装载到torch.device上。 | |||
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | |||
""" | |||
super(Tester, self).__init__() | |||
if not isinstance(data, DataSet): | |||
@@ -59,10 +95,10 @@ class Tester(object): | |||
self._predict_func = self._model.forward | |||
def test(self): | |||
"""Start test or validation. | |||
:return eval_results: a dictionary whose keys are the class name of metrics to use, values are the evaluation results of these metrics. | |||
"""开始进行验证,并返回验证结果。 | |||
:return dict(dict) eval_results: dict为二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。 | |||
一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。 | |||
""" | |||
# turn on the testing mode; clean up the history | |||
network = self._model | |||
@@ -213,7 +213,7 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在 | |||
from torch.optim import SGD | |||
from fastNLP import Trainer | |||
from fastNLP import DataSet | |||
from fastNLP.core.metrics import AccuracyMetric | |||
from fastNLP import AccuracyMetric | |||
import torch | |||
class Model(nn.Module): | |||
@@ -322,7 +322,7 @@ from fastNLP.core.utils import _check_loss_evaluate | |||
from fastNLP.core.utils import _move_dict_value_to_device | |||
from fastNLP.core.utils import _get_func_signature | |||
from fastNLP.core.utils import _get_device | |||
from fastNLP.core.optimizer import Optimizer | |||
class Trainer(object): | |||
def __init__(self, train_data, model, optimizer, loss=None, | |||
@@ -336,8 +336,7 @@ class Trainer(object): | |||
""" | |||
:param DataSet train_data: the training set
:param nn.modules model: the model to be trained
:param Optimizer,None optimizer: the optimizer, a PyTorch torch.optim.Optimizer. If None, the Trainer will not update the model;
make sure the update happens in a callback.
:param torch.optim.Optimizer,None optimizer: the optimizer. If None, the Trainer will not update the model; make sure the update happens in a callback.
:param int batch_size: the batch size used for training and validation.
:param LossBase loss: the Loss object to use. See LossBase_ . When loss is None, LossInForward_ is used by default.
:param Sampler sampler: the order in which batches are generated. See Sampler_ . If None, RandomSampler_ is used by default.
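A hedged construction sketch (not from the patch) following the parameter names above; train_dataset, model, and my_loss are placeholders assumed to exist:

    from torch.optim import SGD
    from fastNLP import Trainer

    trainer = Trainer(train_data=train_dataset, model=model,
                      optimizer=SGD(model.parameters(), lr=0.1),
                      loss=my_loss,      # None would fall back to LossInForward
                      batch_size=32)
    trainer.train()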
@@ -438,6 +437,8 @@ class Trainer(object): | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
self.optimizer = optimizer | |||
elif isinstance(optimizer, Optimizer): | |||
self.optimizer = optimizer.construct_from_pytorch(model.parameters()) | |||
elif optimizer is None: | |||
warnings.warn("The optimizer is set to None, Trainer will update your model. Make sure you update the model" | |||
" in the callback.") | |||
@@ -8,7 +8,7 @@ from fastNLP.io.base_loader import BaseLoader | |||
import warnings | |||
class EmbedLoader(BaseLoader): | |||
"""docstring for EmbedLoader""" | |||
"""这个类用于从预训练的Embedding中load数据。""" | |||
def __init__(self): | |||
super(EmbedLoader, self).__init__() | |||
@@ -16,18 +16,17 @@ class EmbedLoader(BaseLoader): | |||
@staticmethod | |||
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'): | |||
""" | |||
load pretraining embedding in {embed_file} based on words in vocab. Words in vocab but not in the pretraining | |||
embedding are initialized from a normal distribution which has the mean and std of the found words vectors. | |||
The embedding type is determined automatically, support glove and word2vec(the first line only has two elements). | |||
:param embed_filepath: str, where to read pretrain embedding | |||
:param vocab: Vocabulary. | |||
:param dtype: the dtype of the embedding matrix | |||
:param normalize: bool, whether to normalize each word vector so that every vector has norm 1. | |||
:param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will | |||
raise | |||
:return: np.ndarray() will have the same [len(vocab), dimension], dimension is determined by the pretrain | |||
embedding | |||
Extract the embeddings for the words of the vocabulary vocab from the pre-trained vectors in embed_filepath. EmbedLoader automatically
detects whether embed_filepath is in word2vec format (the first line has only two elements) or glove format.
:param str embed_filepath: path of the pre-trained embedding.
:param Vocabulary vocab: the vocabulary; embeddings are read for the words that appear in vocab. Words not found in the file are sampled
from a normal distribution with the mean and std of the found word vectors, so the whole embedding matrix follows the same distribution.
:param dtype: dtype of the returned embedding
:param bool normalize: whether to normalize each vector to norm 1
:param str error: 'ignore' or 'strict'; with 'ignore', errors are skipped automatically; with 'strict', errors are raised. The typical
failure cases are empty lines in the file or inconsistent dimensions.
:return: numpy.ndarray of shape [len(vocab), dimension]; dimension is determined by the pre-trained embedding.
""" | |||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." | |||
if not os.path.exists(embed_filepath): | |||
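A hedged usage sketch for load_with_vocab; the glove file path is hypothetical and the Vocabulary helper calls are assumptions:

    from fastNLP import Vocabulary
    from fastNLP.io import EmbedLoader   # import path assumed

    vocab = Vocabulary()
    vocab.add_word_lst(["the", "cat", "sat"])   # helper assumed
    vocab.build_vocab()                          # build step assumed
    embed = EmbedLoader.load_with_vocab("glove.6B.50d.txt", vocab)
    # embed.shape == (len(vocab), 50) for a 50-dimensional pre-trained file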
@@ -76,19 +75,18 @@ class EmbedLoader(BaseLoader): | |||
def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | |||
error='ignore'): | |||
""" | |||
load pretraining embedding in {embed_file}. And construct a Vocabulary based on the pretraining embedding. | |||
The embedding type is determined automatically, support glove and word2vec(the first line only has two elements). | |||
:param embed_filepath: str, where to read pretrain embedding | |||
:param dtype: the dtype of the embedding matrix | |||
:param padding: the padding tag for vocabulary. | |||
:param unknown: the unknown tag for vocabulary. | |||
:param normalize: bool, whether to normalize each word vector so that every vector has norm 1. | |||
:param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will | |||
:raise | |||
:return: np.ndarray() is determined by the pretraining embeddings | |||
Vocabulary: contain all pretraining words and two special tag[<pad>, <unk>] | |||
Read pre-trained word vectors from embed_filepath and build a Vocabulary from the pre-trained vocabulary.
:param str embed_filepath: path of the pre-trained embedding.
:param dtype: dtype of the returned embedding
:param str padding: the padding tag for the vocabulary.
:param str unknown: the unknown tag for the vocabulary.
:param bool normalize: whether to normalize each vector to norm 1
:param str error: 'ignore' or 'strict'; with 'ignore', errors are skipped automatically; with 'strict', errors are raised. The typical
failure cases are empty lines in the file or inconsistent dimensions.
:return: numpy.ndarray of shape [len(vocab), dimension]; dimension is determined by the pre-trained embedding.
:return: numpy.ndarray, Vocabulary. The embedding has shape [vocab size + x, embedding dim]; the "+ x" reflects that the final size also
depends on whether padding is used and on whether unknown is matched in the pre-trained vocabulary. The word order in the Vocabulary corresponds one-to-one to the rows of the embedding.
""" | |||
vocab = Vocabulary(padding=padding, unknown=unknown) | |||
vec_dict = {} | |||
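A hedged sketch for load_without_vocab, which builds the Vocabulary from the pre-trained file itself (file path hypothetical, import path assumed):

    from fastNLP.io import EmbedLoader   # import path assumed

    embed, vocab = EmbedLoader.load_without_vocab("glove.6B.50d.txt", padding='<pad>', unknown='<unk>')
    # embed.shape[0] == len(vocab); row i corresponds to the i-th word of vocab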
@@ -3,29 +3,38 @@ | |||
import torch | |||
import torch.nn as nn | |||
import numpy as np | |||
# import torch.nn.functional as F | |||
import fastNLP.modules.encoder as encoder | |||
class CNNText(torch.nn.Module): | |||
""" | |||
Text classification model by character CNN, the implementation of paper | |||
'Yoon Kim. 2014. Convolution Neural Networks for Sentence | |||
Classification.' | |||
A model that classifies text with a CNN
'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.' | |||
""" | |||
def __init__(self, embed_num, | |||
def __init__(self, vocab_size, | |||
embed_dim, | |||
num_classes, | |||
kernel_nums=(3, 4, 5), | |||
kernel_sizes=(3, 4, 5), | |||
padding=0, | |||
dropout=0.5): | |||
""" | |||
:param int vocab_size: size of the vocabulary
:param int embed_dim: dimension of the word embeddings
:param int num_classes: number of classes
:param int,tuple(int) kernel_nums: number of output channels. If a list, its length must match kernel_sizes
:param int,tuple(int) kernel_sizes: kernel sizes of the output channels.
:param int padding: padding used by the convolution layers
:param float dropout: dropout probability
""" | |||
super(CNNText, self).__init__() | |||
# no support for pre-trained embedding currently | |||
self.embed = encoder.Embedding(embed_num, embed_dim) | |||
self.embed = encoder.Embedding(vocab_size, embed_dim) | |||
self.conv_pool = encoder.ConvMaxpool( | |||
in_channels=embed_dim, | |||
out_channels=kernel_nums, | |||
@@ -34,24 +43,36 @@ class CNNText(torch.nn.Module): | |||
self.dropout = nn.Dropout(dropout) | |||
self.fc = encoder.Linear(sum(kernel_nums), num_classes) | |||
def forward(self, word_seq): | |||
def init_embed(self, embed): | |||
""" | |||
Load a pre-trained embedding matrix
:param numpy.ndarray embed: an embedding of shape vocab_size x embed_dim
:return:
""" | |||
assert isinstance(embed, np.ndarray) | |||
assert embed.shape == self.embed.embed.weight.shape | |||
self.embed.embed.weight.data = torch.from_numpy(embed) | |||
def forward(self, words, seq_len=None): | |||
""" | |||
:param word_seq: torch.LongTensor, [batch_size, seq_len] | |||
:param torch.LongTensor words: [batch_size, seq_len], indices of the words in the sentences
:param torch.LongTensor seq_len: [batch,] length of each sentence
:return output: dict of torch.LongTensor, [batch_size, num_classes] | |||
""" | |||
x = self.embed(word_seq) # [N,L] -> [N,L,C] | |||
x = self.embed(words) # [N,L] -> [N,L,C] | |||
x = self.conv_pool(x) # [N,L,C] -> [N,C] | |||
x = self.dropout(x) | |||
x = self.fc(x) # [N,C] -> [N, N_class] | |||
return {'pred': x} | |||
def predict(self, word_seq): | |||
def predict(self, words, seq_len=None): | |||
""" | |||
:param torch.LongTensor words: [batch_size, seq_len], indices of the words in the sentences
:param torch.LongTensor seq_len: [batch,] length of each sentence
:param word_seq: torch.LongTensor, [batch_size, seq_len] | |||
:return predict: dict of torch.LongTensor, [batch_size, seq_len] | |||
:return predict: dict of torch.LongTensor, [batch_size, ] | |||
""" | |||
output = self(word_seq) | |||
output = self(words, seq_len) | |||
_, predict = output['pred'].max(dim=1) | |||
return {'pred': predict} |
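A hedged end-to-end sketch of the renamed CNNText interface (vocab_size / words / seq_len); the sizes below are illustrative only:

    import numpy as np
    import torch
    from fastNLP.models import CNNText

    model = CNNText(vocab_size=100, embed_dim=50, num_classes=5, padding=2, dropout=0.1)
    model.init_embed(np.random.rand(100, 50).astype('float32'))   # optional pre-trained matrix
    words = torch.randint(0, 100, (4, 7)).long()                  # [batch_size, seq_len]
    out = model(words)                                            # {'pred': FloatTensor [4, 5]}
    pred = model.predict(words)                                   # {'pred': LongTensor [4]}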
@@ -8,47 +8,64 @@ from fastNLP.modules.utils import seq_mask | |||
class SeqLabeling(BaseModel): | |||
""" | |||
PyTorch Network for sequence labeling | |||
A basic sequence labeling model
""" | |||
def __init__(self, args): | |||
def __init__(self, vocab_size, embed_dim, hidden_size, num_classes): | |||
""" | |||
A base class for sequence labeling. The structure is one Embedding layer, one LSTM (unidirectional, single layer), one FC layer, and one CRF.
:param int vocab_size: size of the vocabulary.
:param int embed_dim: dimension of the embedding
:param int hidden_size: size of the LSTM hidden layer
:param int num_classes: number of classes
""" | |||
super(SeqLabeling, self).__init__() | |||
vocab_size = args["vocab_size"] | |||
word_emb_dim = args["word_emb_dim"] | |||
hidden_dim = args["rnn_hidden_units"] | |||
num_classes = args["num_classes"] | |||
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim) | |||
self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim) | |||
self.Linear = encoder.linear.Linear(hidden_dim, num_classes) | |||
self.Embedding = encoder.embedding.Embedding(vocab_size, embed_dim) | |||
self.Rnn = encoder.lstm.LSTM(embed_dim, hidden_size) | |||
self.Linear = encoder.linear.Linear(hidden_size, num_classes) | |||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | |||
self.mask = None | |||
def forward(self, word_seq, word_seq_origin_len, truth=None): | |||
def forward(self, words, seq_len, target): | |||
""" | |||
:param word_seq: LongTensor, [batch_size, mex_len] | |||
:param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences. | |||
:param truth: LongTensor, [batch_size, max_len] | |||
:param torch.LongTensor words: [batch_size, max_len], indices of the sequence
:param torch.LongTensor seq_len: [batch_size,], length of each sequence
:param torch.LongTensor target: [batch_size, max_len], target values of the sequence
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | |||
If truth is not None, return loss, a scalar. Used in training. | |||
""" | |||
assert word_seq.shape[0] == word_seq_origin_len.shape[0] | |||
if truth is not None: | |||
assert truth.shape == word_seq.shape | |||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | |||
assert words.shape[0] == seq_len.shape[0] | |||
assert target.shape == words.shape | |||
self.mask = self._make_mask(words, seq_len) | |||
x = self.Embedding(word_seq) | |||
x = self.Embedding(words) | |||
# [batch_size, max_len, word_emb_dim] | |||
x = self.Rnn(x) | |||
# [batch_size, max_len, hidden_size * direction] | |||
x = self.Linear(x) | |||
# [batch_size, max_len, num_classes] | |||
return {"loss": self._internal_loss(x, truth) if truth is not None else None, | |||
"predict": self.decode(x)} | |||
return {"loss": self._internal_loss(x, target)} | |||
def loss(self, x, y): | |||
""" Since the loss has been computed in forward(), this function simply returns x.""" | |||
return x | |||
def predict(self, words, seq_len): | |||
""" | |||
Used at prediction time
:param torch.LongTensor words: [batch_size, max_len] | |||
:param torch.LongTensor seq_len: [batch_size,] | |||
:return: | |||
""" | |||
self.mask = self._make_mask(words, seq_len) | |||
x = self.Embedding(words) | |||
# [batch_size, max_len, word_emb_dim] | |||
x = self.Rnn(x) | |||
# [batch_size, max_len, hidden_size * direction] | |||
x = self.Linear(x) | |||
# [batch_size, max_len, num_classes] | |||
pred = self._decode(x) | |||
return {'pred': pred} | |||
def _internal_loss(self, x, y): | |||
""" | |||
@@ -65,89 +82,114 @@ class SeqLabeling(BaseModel): | |||
total_loss = self.Crf(x, y, self.mask) | |||
return torch.mean(total_loss) | |||
def make_mask(self, x, seq_len): | |||
def _make_mask(self, x, seq_len): | |||
batch_size, max_len = x.size(0), x.size(1) | |||
mask = seq_mask(seq_len, max_len) | |||
mask = mask.view(batch_size, max_len) | |||
mask = mask.to(x).float() | |||
return mask | |||
def decode(self, x, pad=True): | |||
def _decode(self, x): | |||
""" | |||
:param x: FloatTensor, [batch_size, max_len, tag_size] | |||
:param pad: pad the output sequence to equal lengths | |||
:param torch.FloatTensor x: [batch_size, max_len, tag_size] | |||
:return prediction: list of [decode path(list)] | |||
""" | |||
max_len = x.shape[1] | |||
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask) | |||
# pad prediction to equal length | |||
if pad is True: | |||
for pred in tag_seq: | |||
if len(pred) < max_len: | |||
pred += [0] * (max_len - len(pred)) | |||
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True) | |||
return tag_seq | |||
class AdvSeqLabel(SeqLabeling): | |||
class AdvSeqLabel(torch.nn.Module):
""" | |||
Advanced Sequence Labeling Model | |||
A more complex sequence labeling model. The structure is Embedding, LayerNorm, bidirectional LSTM (two layers), FC, LayerNorm, Dropout, FC, CRF.
""" | |||
def __init__(self, args, emb=None, id2words=None): | |||
super(AdvSeqLabel, self).__init__(args) | |||
vocab_size = args["vocab_size"] | |||
word_emb_dim = args["word_emb_dim"] | |||
hidden_dim = args["rnn_hidden_units"] | |||
num_classes = args["num_classes"] | |||
dropout = args['dropout'] | |||
def __init__(self, vocab_size, embed_dim, hidden_size, num_classes, dropout=0.3, embedding=None, | |||
id2words=None, encoding_type='bmes'): | |||
""" | |||
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | |||
self.norm1 = torch.nn.LayerNorm(word_emb_dim) | |||
# self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) | |||
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, | |||
:param int vocab_size: size of the vocabulary
:param int embed_dim: dimension of the embedding
:param int hidden_size: hidden size of the LSTM
:param int num_classes: number of classes
:param float dropout: drop probability used inside the LSTM and in the Dropout layer
:param numpy.ndarray embedding: pre-trained embedding; it must be consistent with the given vocabulary size etc.
:param dict id2words: mapping from tag id to its tag word, used to keep the CRF from decoding illegal sequences; for example, in the
'BMES' scheme 'S' may not follow 'B'. Tags of the form 'B-NN', where the part before '-' marks the tag type, are also supported:
not only is 'S-NN' forbidden after 'B-NN', but any 'M-xx' other than 'M-NN' or 'E-NN' is forbidden after 'B-NN' as well.
:param str encoding_type: supports "BIO", "BMES", "BEMSO".
""" | |||
self.Embedding = encoder.embedding.Embedding(vocab_size, embed_dim, init_emb=embedding) | |||
self.norm1 = torch.nn.LayerNorm(embed_dim) | |||
self.Rnn = torch.nn.LSTM(input_size=embed_dim, hidden_size=hidden_size, num_layers=2, dropout=dropout, | |||
bidirectional=True, batch_first=True) | |||
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | |||
self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) | |||
# self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | |||
self.Linear1 = encoder.Linear(hidden_size * 2, hidden_size * 2 // 3) | |||
self.norm2 = torch.nn.LayerNorm(hidden_size * 2 // 3) | |||
self.relu = torch.nn.LeakyReLU() | |||
self.drop = torch.nn.Dropout(dropout) | |||
self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) | |||
self.Linear2 = encoder.Linear(hidden_size * 2 // 3, num_classes) | |||
if id2words is None: | |||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) | |||
else: | |||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False, | |||
allowed_transitions=allowed_transitions(id2words, | |||
encoding_type="bmes")) | |||
encoding_type=encoding_type)) | |||
def _decode(self, x): | |||
""" | |||
:param torch.FloatTensor x: [batch_size, max_len, tag_size] | |||
:return prediction: list of [decode path(list)] | |||
""" | |||
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True) | |||
return tag_seq | |||
def _internal_loss(self, x, y): | |||
""" | |||
Negative log likelihood loss. | |||
:param x: Tensor, [batch_size, max_len, tag_size] | |||
:param y: Tensor, [batch_size, max_len] | |||
:return loss: a scalar Tensor | |||
""" | |||
x = x.float() | |||
y = y.long() | |||
assert x.shape[:2] == y.shape | |||
assert y.shape == self.mask.shape | |||
total_loss = self.Crf(x, y, self.mask) | |||
return torch.mean(total_loss) | |||
def _make_mask(self, x, seq_len): | |||
batch_size, max_len = x.size(0), x.size(1) | |||
mask = seq_mask(seq_len, max_len) | |||
mask = mask.view(batch_size, max_len) | |||
mask = mask.to(x).float() | |||
return mask | |||
def forward(self, word_seq, word_seq_origin_len, truth=None): | |||
def _forward(self, words, seq_len, target=None): | |||
""" | |||
:param word_seq: LongTensor, [batch_size, mex_len] | |||
:param word_seq_origin_len: LongTensor, [batch_size, ] | |||
:param truth: LongTensor, [batch_size, max_len] | |||
:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size, ]
:param torch.LongTensor target: [batch_size, max_len]
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | |||
If truth is not None, return loss, a scalar. Used in training. | |||
""" | |||
word_seq = word_seq.long() | |||
word_seq_origin_len = word_seq_origin_len.long() | |||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | |||
sent_len, idx_sort = torch.sort(word_seq_origin_len, descending=True) | |||
words = words.long() | |||
seq_len = seq_len.long() | |||
self.mask = self._make_mask(words, seq_len) | |||
sent_len, idx_sort = torch.sort(seq_len, descending=True) | |||
_, idx_unsort = torch.sort(idx_sort, descending=False) | |||
# word_seq_origin_len = word_seq_origin_len.long() | |||
truth = truth.long() if truth is not None else None | |||
# seq_len = seq_len.long() | |||
target = target.long() if target is not None else None | |||
batch_size = word_seq.size(0) | |||
max_len = word_seq.size(1) | |||
if next(self.parameters()).is_cuda: | |||
word_seq = word_seq.cuda() | |||
words = words.cuda() | |||
idx_sort = idx_sort.cuda() | |||
idx_unsort = idx_unsort.cuda() | |||
self.mask = self.mask.cuda() | |||
x = self.Embedding(word_seq) | |||
x = self.Embedding(words) | |||
x = self.norm1(x) | |||
# [batch_size, max_len, word_emb_dim] | |||
@@ -155,71 +197,35 @@ class AdvSeqLabel(SeqLabeling): | |||
sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True) | |||
x, _ = self.Rnn(sent_packed) | |||
# print(x) | |||
# [batch_size, max_len, hidden_size * direction] | |||
sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0] | |||
x = sent_output[idx_unsort] | |||
x = x.contiguous() | |||
# x = x.view(batch_size * max_len, -1) | |||
x = self.Linear1(x) | |||
# x = self.batch_norm(x) | |||
x = self.norm2(x) | |||
x = self.relu(x) | |||
x = self.drop(x) | |||
x = self.Linear2(x) | |||
# x = x.view(batch_size, max_len, -1) | |||
# [batch_size, max_len, num_classes] | |||
# TODO seq_lens的key这样做不合理 | |||
return {"loss": self._internal_loss(x, truth) if truth is not None else None, | |||
"predict": self.decode(x), | |||
'word_seq_origin_len': word_seq_origin_len} | |||
def predict(self, **x): | |||
out = self.forward(**x) | |||
return {"predict": out["predict"]} | |||
def loss(self, **kwargs): | |||
assert 'loss' in kwargs | |||
return kwargs['loss'] | |||
if __name__ == '__main__': | |||
args = { | |||
'vocab_size': 20, | |||
'word_emb_dim': 100, | |||
'rnn_hidden_units': 100, | |||
'num_classes': 10, | |||
} | |||
model = AdvSeqLabel(args) | |||
data = [] | |||
for i in range(20): | |||
word_seq = torch.randint(20, (15,)).long() | |||
word_seq_len = torch.LongTensor([15]) | |||
truth = torch.randint(10, (15,)).long() | |||
data.append((word_seq, word_seq_len, truth)) | |||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) | |||
print(model) | |||
curidx = 0 | |||
for i in range(1000): | |||
endidx = min(len(data), curidx + 5) | |||
b_word, b_len, b_truth = [], [], [] | |||
for word_seq, word_seq_len, truth in data[curidx: endidx]: | |||
b_word.append(word_seq) | |||
b_len.append(word_seq_len) | |||
b_truth.append(truth) | |||
word_seq = torch.stack(b_word, dim=0) | |||
word_seq_len = torch.cat(b_len, dim=0) | |||
truth = torch.stack(b_truth, dim=0) | |||
res = model(word_seq, word_seq_len, truth) | |||
loss = res['loss'] | |||
pred = res['predict'] | |||
print('loss: {} acc {}'.format(loss.item(), | |||
((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) | |||
optimizer.zero_grad() | |||
loss.backward() | |||
optimizer.step() | |||
curidx = endidx | |||
if curidx == len(data): | |||
curidx = 0 | |||
if target is not None: | |||
return {"loss": self._internal_loss(x, target)} | |||
else: | |||
return {"pred": self._decode(x)} | |||
def forward(self, words, seq_len, target): | |||
""" | |||
:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size, ]
:param torch.LongTensor target: [batch_size, max_len], the target labels
:return torch.Tensor, a scalar loss | |||
""" | |||
return self._forward(words, seq_len, target) | |||
def predict(self, words, seq_len): | |||
""" | |||
:param torch.LongTensor words: [batch_size, max_len]
:param torch.LongTensor seq_len: [batch_size, ]
:return: [list1, list2, ...]; each inner list is one decoded path, already unpadded.
""" | |||
return self._forward(words, seq_len, ) |
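A hedged sketch standing in for the removed __main__ demo, exercising the new forward/predict interface with random tensors (it assumes AdvSeqLabel behaves as an nn.Module, like SeqLabeling):

    import torch

    model = AdvSeqLabel(vocab_size=20, embed_dim=100, hidden_size=100, num_classes=10)
    words = torch.randint(0, 20, (5, 15)).long()       # [batch_size, max_len]
    seq_len = torch.LongTensor([15, 13, 12, 10, 9])
    target = torch.randint(0, 10, (5, 15)).long()

    loss = model(words, seq_len, target)['loss']       # training path
    loss.backward()
    paths = model.predict(words, seq_len)['pred']      # list of unpadded label paths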
@@ -19,10 +19,10 @@ def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): | |||
""" | |||
Given a mapping from id to label, return the list of all allowed (from_tag_id, to_tag_id) transitions.
:param id2label: Dict, keys are label indices and values are str tags or tag-labels. A value may be a bare tag, e.g. "B", "M", or
:param dict id2label: keys are label indices and values are str tags or tag-labels. A value may be a bare tag, e.g. "B", "M", or
"B-NN", "M-NN"; the tag and the label must be separated by "-". id2label can usually be obtained via Vocabulary.get_id2word().
:param encoding_type: str, supports "bio", "bmes", "bmeso".
:param include_start_end: bool, whether to include start and end transitions. For example, in bio, b/o may appear at the start but i may not;
:param str encoding_type: supports "bio", "bmes", "bmeso".
:param bool include_start_end: whether to include start and end transitions. For example, in bio, b/o may appear at the start but i may not;
if True, the result will include (start_idx, b_idx) and (start_idx, o_idx), but not (start_idx, i_idx);
start_idx=len(id2label), end_idx=len(id2label)+1.
if False, the result contains nothing related to start or end
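A small hedged example of the expected input and output, using a tiny BIO tag set:

    id2label = {0: "B-PER", 1: "I-PER", 2: "O"}
    transitions = allowed_transitions(id2label, encoding_type="bio", include_start_end=True)
    # transitions is a list of (from_tag_id, to_tag_id) pairs; with include_start_end=True,
    # start_idx == 3 and end_idx == 4, so (3, 0) "start -> B-PER" is allowed while (3, 1) "start -> I-PER" is not.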
@@ -62,11 +62,11 @@ def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): | |||
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
""" | |||
:param encoding_type: str, supports "BIO", "BMES", "BEMSO".
:param from_tag: str, a tag such as "B", "M"; also the two special tags start and end
:param from_label: str, a label such as "PER", "LOC"
:param to_tag: str, a tag such as "B", "M"; also the two special tags start and end
:param to_label: str, a label such as "PER", "LOC"
:param str encoding_type: supports "BIO", "BMES", "BEMSO".
:param str from_tag: a tag such as "B", "M"; also the two special tags start and end
:param str from_label: a label such as "PER", "LOC"
:param str to_tag: a tag such as "B", "M"; also the two special tags start and end
:param str to_label: a label such as "PER", "LOC"
:return: bool, whether the transition is allowed
""" | |||
if to_tag=='start' or from_tag=='end': | |||
@@ -149,12 +149,12 @@ class ConditionalRandomField(nn.Module): | |||
"""条件随机场。 | |||
提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 | |||
:param num_tags: int, 标签的数量 | |||
:param include_start_end_trans: bool, 是否考虑各个tag作为开始以及结尾的分数。 | |||
:param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]], 内部的Tuple[from_tag_id(int), | |||
:param int num_tags: 标签的数量 | |||
:param bool include_start_end_trans: 是否考虑各个tag作为开始以及结尾的分数。 | |||
:param List[Tuple[from_tag_id(int), to_tag_id(int)]] allowed_transitions: 内部的Tuple[from_tag_id(int), | |||
to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 | |||
allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 | |||
:param initial_method: str, 初始化方法。见initial_parameter | |||
:param str initial_method: 初始化方法。见initial_parameter | |||
""" | |||
super(ConditionalRandomField, self).__init__() | |||
@@ -237,10 +237,10 @@ class ConditionalRandomField(nn.Module): | |||
""" | |||
Computes the forward loss of the CRF. The return value is a FloatTensor of size batch_size; mean() may be needed to get a scalar loss.
:param feats: FloatTensor, batch_size x max_len x num_tags, the feature matrix.
:param tags: LongTensor, batch_size x max_len, the tag matrix.
:param mask: ByteTensor, batch_size x max_len, positions with value 0 are treated as padding.
:return: FloatTensor, batch_size
:param torch.FloatTensor feats: batch_size x max_len x num_tags, the feature matrix.
:param torch.LongTensor tags: batch_size x max_len, the tag matrix.
:param torch.ByteTensor mask: batch_size x max_len, positions with value 0 are treated as padding.
:return: torch.FloatTensor, (batch_size,)
""" | |||
feats = feats.transpose(0, 1) | |||
tags = tags.transpose(0, 1).long() | |||
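A hedged sketch of computing the training loss with this forward(); the import path is an assumption:

    import torch
    from fastNLP.modules.decoder.CRF import ConditionalRandomField  # path assumed

    crf = ConditionalRandomField(num_tags=5)
    feats = torch.randn(2, 6, 5)                   # batch_size x max_len x num_tags
    tags = torch.randint(0, 5, (2, 6)).long()      # batch_size x max_len
    mask = torch.ones(2, 6).byte()                 # 0 would mark padding
    loss = crf(feats, tags, mask).mean()           # per-sample losses -> scalar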
@@ -250,27 +250,26 @@ class ConditionalRandomField(nn.Module): | |||
return all_path_score - gold_path_score | |||
def viterbi_decode(self, feats, mask, unpad=False): | |||
def viterbi_decode(self, logits, mask, unpad=False): | |||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
:param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
:param unpad: bool, 是否将结果删去padding, | |||
False, 返回的是batch_size x max_len的tensor, | |||
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int] | |||
的长度是这个sample的有效长度。 | |||
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | |||
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 | |||
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 | |||
个sample的有效长度。 | |||
:return: 返回 (paths, scores)。 | |||
paths: 是解码后的路径, 其值参照unpad参数. | |||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||
""" | |||
batch_size, seq_len, n_tags = feats.size() | |||
feats = feats.transpose(0, 1).data # L, B, H | |||
batch_size, seq_len, n_tags = logits.size() | |||
logits = logits.transpose(0, 1).data # L, B, H | |||
mask = mask.transpose(0, 1).data.byte() # L, B | |||
# dp | |||
vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
vscore = feats[0] | |||
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
vscore = logits[0] | |||
transitions = self._constrain.data.clone() | |||
transitions[:n_tags, :n_tags] += self.trans_m.data | |||
if self.include_start_end_trans: | |||
@@ -281,7 +280,7 @@ class ConditionalRandomField(nn.Module): | |||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||
for i in range(1, seq_len): | |||
prev_score = vscore.view(batch_size, n_tags, 1) | |||
cur_score = feats[i].view(batch_size, 1, n_tags) | |||
cur_score = logits[i].view(batch_size, 1, n_tags) | |||
score = prev_score + trans_score + cur_score | |||
best_score, best_dst = score.max(1) | |||
vpath[i] = best_dst | |||
@@ -292,13 +291,13 @@ class ConditionalRandomField(nn.Module): | |||
vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||
# backtrace | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) | |||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | |||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | |||
lens = (mask.long().sum(0) - 1) | |||
# idxes [L, B], batched idx from seq_len-1 to 0 | |||
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len | |||
ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) | |||
ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) | |||
ans_score, last_tags = vscore.max(1) | |||
ans[idxes[0], batch_idx] = last_tags | |||
for i in range(seq_len - 1): | |||
@@ -311,6 +310,5 @@ class ConditionalRandomField(nn.Module): | |||
paths.append(ans[idx, :seq_len+1].tolist()) | |||
else: | |||
paths = ans | |||
if get_score: | |||
return paths, ans_score.tolist() | |||
return paths | |||
return paths, ans_score | |||
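Continuing the CRF sketch above, the two unpad behaviours of viterbi_decode would look like this:

    paths, scores = crf.viterbi_decode(feats, mask)              # paths: tensor of shape batch_size x max_len
    paths, scores = crf.viterbi_decode(feats, mask, unpad=True)  # paths: List[List[int]] without padding
    # scores: FloatTensor of shape (batch_size,) in both cases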
@@ -11,13 +11,12 @@ def log_sum_exp(x, dim=-1): | |||
def viterbi_decode(logits, transitions, mask=None, unpad=False): | |||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
:param logits: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||
:param transitions: FloatTensor, n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。 | |||
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
:param unpad: bool, 是否将结果删去padding, | |||
False, 返回的是batch_size x max_len的tensor, | |||
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是 | |||
这个sample的有效长度。 | |||
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | |||
:param torch.FloatTensor transitions: n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。 | |||
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 | |||
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 | |||
个sample的有效长度。 | |||
:return: 返回 (paths, scores)。 | |||
paths: 是解码后的路径, 其值参照unpad参数. | |||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||
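A hedged sketch for this standalone helper, which takes an explicit transition matrix instead of a trained CRF (import path assumed):

    import torch
    from fastNLP.modules.decoder.utils import viterbi_decode  # path assumed

    logits = torch.randn(2, 6, 5)        # batch_size x max_len x num_tags
    transitions = torch.randn(5, 5)      # [i, j] scores the transition tag i -> tag j
    mask = torch.ones(2, 6).byte()
    paths, scores = viterbi_decode(logits, transitions, mask, unpad=True)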
@@ -1,4 +1,3 @@ | |||
from .conv import Conv | |||
from .conv_maxpool import ConvMaxpool | |||
from .embedding import Embedding | |||
from .linear import Linear | |||
@@ -8,6 +7,5 @@ from .bert import BertModel | |||
__all__ = ["LSTM", | |||
"Embedding", | |||
"Linear", | |||
"Conv", | |||
"ConvMaxpool", | |||
"BertModel"] |
@@ -1,58 +0,0 @@ | |||
# python: 3.6 | |||
# encoding: utf-8 | |||
import torch | |||
import torch.nn as nn | |||
from fastNLP.modules.utils import initial_parameter | |||
# import torch.nn.functional as F | |||
class Conv(nn.Module): | |||
"""Basic 1-d convolution module, initialized with xavier_uniform. | |||
:param int in_channels: | |||
:param int out_channels: | |||
:param tuple kernel_size: | |||
:param int stride: | |||
:param int padding: | |||
:param int dilation: | |||
:param int groups: | |||
:param bool bias: | |||
:param str activation: | |||
:param str initial_method: | |||
""" | |||
def __init__(self, in_channels, out_channels, kernel_size, | |||
stride=1, padding=0, dilation=1, | |||
groups=1, bias=True, activation='relu', initial_method=None): | |||
super(Conv, self).__init__() | |||
self.conv = nn.Conv1d( | |||
in_channels=in_channels, | |||
out_channels=out_channels, | |||
kernel_size=kernel_size, | |||
stride=stride, | |||
padding=padding, | |||
dilation=dilation, | |||
groups=groups, | |||
bias=bias) | |||
# xavier_uniform_(self.conv.weight) | |||
activations = { | |||
'relu': nn.ReLU(), | |||
'tanh': nn.Tanh()} | |||
if activation in activations: | |||
self.activation = activations[activation] | |||
else: | |||
raise Exception( | |||
'Should choose activation function from: ' + | |||
', '.join([x for x in activations])) | |||
initial_parameter(self, initial_method) | |||
def forward(self, x): | |||
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] | |||
x = self.conv(x) # [N,C_in,L] -> [N,C_out,L] | |||
x = self.activation(x) | |||
x = torch.transpose(x, 1, 2) # [N,C,L] -> [N,L,C] | |||
return x |
@@ -9,18 +9,21 @@ from fastNLP.modules.utils import initial_parameter | |||
class ConvMaxpool(nn.Module): | |||
"""Convolution and max-pooling module with multiple kernel sizes. | |||
"""集合了Convolution和Max-Pooling于一体的层。 | |||
给定一个batch_size x max_len x input_size的输入,返回batch_size x sum(output_channels) 大小的matrix。在内部,是先使用 | |||
CNN给输入做卷积,然后经过activation激活层,在通过在长度(max_len)这一维进行max_pooling。最后得到每个sample的一个vector | |||
表示。 | |||
:param int in_channels: | |||
:param int out_channels: | |||
:param tuple kernel_sizes: | |||
:param int stride: | |||
:param int padding: | |||
:param int dilation: | |||
:param int groups: | |||
:param bool bias: | |||
:param str activation: | |||
:param str initial_method: | |||
:param int in_channels: size of the input channels, usually the embedding dimension or the output dimension of an encoder
:param int,tuple(int) out_channels: number of output channels. If a list, its length must match kernel_sizes
:param int,tuple(int) kernel_sizes: kernel sizes of the output channels.
:param int stride: see the PyTorch Conv1d docs. All kernels share one stride.
:param int padding: see the PyTorch Conv1d docs. All kernels share one padding.
:param int dilation: see the PyTorch Conv1d docs. All kernels share one dilation.
:param int groups: see the PyTorch Conv1d docs. All kernels share one groups.
:param bool bias: see the PyTorch Conv1d docs. All kernels share one bias.
:param str activation: the convolution output is passed through this activation before max-pooling. Supports relu, sigmoid, tanh
:param str initial_method: str.
""" | |||
def __init__(self, in_channels, out_channels, kernel_sizes, | |||
stride=1, padding=0, dilation=1, | |||
@@ -29,9 +32,14 @@ class ConvMaxpool(nn.Module): | |||
# convolution | |||
if isinstance(kernel_sizes, (list, tuple, int)): | |||
if isinstance(kernel_sizes, int): | |||
if isinstance(kernel_sizes, int) and isinstance(out_channels, int): | |||
out_channels = [out_channels] | |||
kernel_sizes = [kernel_sizes] | |||
elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)): | |||
assert len(out_channels)==len(kernel_sizes), "The number of out_channels should be equal to the number" \ | |||
" of kernel_sizes." | |||
else: | |||
raise ValueError("The type of out_channels and kernel_sizes should be the same.") | |||
self.convs = nn.ModuleList([nn.Conv1d( | |||
in_channels=in_channels, | |||
@@ -51,18 +59,31 @@ class ConvMaxpool(nn.Module): | |||
# activation function | |||
if activation == 'relu': | |||
self.activation = F.relu | |||
elif activation == 'sigmoid': | |||
self.activation = F.sigmoid | |||
elif activation == 'tanh': | |||
self.activation = F.tanh | |||
else: | |||
raise Exception( | |||
"Undefined activation function: choose from: relu") | |||
"Undefined activation function: choose from: relu, tanh, sigmoid") | |||
initial_parameter(self, initial_method) | |||
def forward(self, x): | |||
def forward(self, x, mask=None): | |||
""" | |||
:param torch.FloatTensor x: batch_size x max_len x input_size, typically the output of an embedding layer
:param mask: batch_size x max_len, padded positions are 0. The mask does not affect the convolution; max-pooling will never pool a padded position
:return: | |||
""" | |||
# [N,L,C] -> [N,C,L] | |||
x = torch.transpose(x, 1, 2) | |||
# convolution | |||
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L]] | |||
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] | |||
if mask is not None: | |||
mask = mask.unsqueeze(1) # B x 1 x L | |||
xs = [x.masked_fill_(mask.eq(0), float('-inf')) for x in xs]  # fill padded positions so max-pool never selects them
# max-pooling | |||
xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) | |||
for i in xs] # [[N, C]] | |||
return torch.cat(xs, dim=-1) # [N,C] | |||
for i in xs] # [[N, C], ...] | |||
return torch.cat(xs, dim=-1) # [N, C] |
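A hedged shape check for the masked ConvMaxpool forward; the mask convention follows the docstring (padded positions are 0), and the import path is assumed:

    import torch
    from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool  # path assumed

    layer = ConvMaxpool(in_channels=50, out_channels=(3, 4, 5), kernel_sizes=(3, 4, 5), padding=1)
    x = torch.randn(2, 7, 50)             # batch_size x max_len x input_size
    mask = torch.ones(2, 7).byte()        # 1 for real tokens, 0 for padding
    out = layer(x, mask)                  # shape: batch_size x sum(out_channels) == [2, 12]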
@@ -1,424 +0,0 @@ | |||
__author__ = 'max' | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from fastNLP.modules.utils import initial_parameter | |||
def MaskedRecurrent(reverse=False): | |||
def forward(input, hidden, cell, mask, train=True, dropout=0): | |||
""" | |||
:param input: | |||
:param hidden: | |||
:param cell: | |||
:param mask: | |||
:param dropout: step之间的dropout,对mask了的也会drop,应该是没问题的,反正没有gradient | |||
:param train: 控制dropout的行为,在StackedRNN的forward中调用 | |||
:return: | |||
""" | |||
output = [] | |||
steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0)) | |||
for i in steps: | |||
if mask is None or mask[i].data.min() > 0.5: # 没有mask,都是1 | |||
hidden = cell(input[i], hidden) | |||
elif mask[i].data.max() > 0.5: # 有mask,但不全为0 | |||
hidden_next = cell(input[i], hidden) # 一次喂入一个batch! | |||
# hack to handle LSTM | |||
if isinstance(hidden, tuple): # LSTM outputs a tuple of (hidden, cell), this is a common hack 😁 | |||
mask = mask.float() | |||
hx, cx = hidden | |||
hp1, cp1 = hidden_next | |||
hidden = ( | |||
hx + (hp1 - hx) * mask[i].squeeze(), | |||
cx + (cp1 - cx) * mask[i].squeeze()) # Why? 我知道了!!如果是mask就不用改变 | |||
else: | |||
hidden = hidden + (hidden_next - hidden) * mask[i] | |||
# if dropout != 0 and train: # warning, should i treat masked tensor differently? | |||
# if isinstance(hidden, tuple): | |||
# hidden = (F.dropout(hidden[0], p=dropout, training=train), | |||
# F.dropout(hidden[1], p=dropout, training=train)) | |||
# else: | |||
# hidden = F.dropout(hidden, p=dropout, training=train) | |||
# hack to handle LSTM | |||
output.append(hidden[0] if isinstance(hidden, tuple) else hidden) | |||
if reverse: | |||
output.reverse() | |||
output = torch.cat(output, 0).view(input.size(0), *output[0].size()) | |||
return hidden, output | |||
return forward | |||
def StackedRNN(inners, num_layers, lstm=False, train=True, step_dropout=0, layer_dropout=0): | |||
num_directions = len(inners) # rec_factory! | |||
total_layers = num_layers * num_directions | |||
def forward(input, hidden, cells, mask): | |||
assert (len(cells) == total_layers) | |||
next_hidden = [] | |||
if lstm: | |||
hidden = list(zip(*hidden)) | |||
for i in range(num_layers): | |||
all_output = [] | |||
for j, inner in enumerate(inners): | |||
l = i * num_directions + j | |||
hy, output = inner(input, hidden[l], cells[l], mask, step_dropout, train) | |||
next_hidden.append(hy) | |||
all_output.append(output) | |||
input = torch.cat(all_output, input.dim() - 1) # 下一层的输入 | |||
if layer_dropout != 0 and i < num_layers - 1: | |||
input = F.dropout(input, p=layer_dropout, training=train, inplace=False) | |||
if lstm: | |||
next_h, next_c = zip(*next_hidden) | |||
next_hidden = ( | |||
torch.cat(next_h, 0).view(total_layers, *next_h[0].size()), | |||
torch.cat(next_c, 0).view(total_layers, *next_c[0].size()) | |||
) | |||
else: | |||
next_hidden = torch.cat(next_hidden, 0).view(total_layers, *next_hidden[0].size()) | |||
return next_hidden, input | |||
return forward | |||
def AutogradMaskedRNN(num_layers=1, batch_first=False, train=True, layer_dropout=0, step_dropout=0, | |||
bidirectional=False, lstm=False): | |||
rec_factory = MaskedRecurrent | |||
if bidirectional: | |||
layer = (rec_factory(), rec_factory(reverse=True)) | |||
else: | |||
layer = (rec_factory(),) # rec_factory 就是每层的结构啦!!在MaskedRecurrent中进行每层的计算!然后用StackedRNN接起来 | |||
func = StackedRNN(layer, | |||
num_layers, | |||
lstm=lstm, | |||
layer_dropout=layer_dropout, step_dropout=step_dropout, | |||
train=train) | |||
def forward(input, cells, hidden, mask): | |||
if batch_first: | |||
input = input.transpose(0, 1) | |||
if mask is not None: | |||
mask = mask.transpose(0, 1) | |||
nexth, output = func(input, hidden, cells, mask) | |||
if batch_first: | |||
output = output.transpose(0, 1) | |||
return output, nexth | |||
return forward | |||
def MaskedStep(): | |||
def forward(input, hidden, cell, mask): | |||
if mask is None or mask.data.min() > 0.5: | |||
hidden = cell(input, hidden) | |||
elif mask.data.max() > 0.5: | |||
hidden_next = cell(input, hidden) | |||
# hack to handle LSTM | |||
if isinstance(hidden, tuple): | |||
hx, cx = hidden | |||
hp1, cp1 = hidden_next | |||
hidden = (hx + (hp1 - hx) * mask, cx + (cp1 - cx) * mask) | |||
else: | |||
hidden = hidden + (hidden_next - hidden) * mask | |||
# hack to handle LSTM | |||
output = hidden[0] if isinstance(hidden, tuple) else hidden | |||
return hidden, output | |||
return forward | |||
def StackedStep(layer, num_layers, lstm=False, dropout=0, train=True): | |||
def forward(input, hidden, cells, mask): | |||
assert (len(cells) == num_layers) | |||
next_hidden = [] | |||
if lstm: | |||
hidden = list(zip(*hidden)) | |||
for l in range(num_layers): | |||
hy, output = layer(input, hidden[l], cells[l], mask) | |||
next_hidden.append(hy) | |||
input = output | |||
if dropout != 0 and l < num_layers - 1: | |||
input = F.dropout(input, p=dropout, training=train, inplace=False) | |||
if lstm: | |||
next_h, next_c = zip(*next_hidden) | |||
next_hidden = ( | |||
torch.cat(next_h, 0).view(num_layers, *next_h[0].size()), | |||
torch.cat(next_c, 0).view(num_layers, *next_c[0].size()) | |||
) | |||
else: | |||
next_hidden = torch.cat(next_hidden, 0).view(num_layers, *next_hidden[0].size()) | |||
return next_hidden, input | |||
return forward | |||
def AutogradMaskedStep(num_layers=1, dropout=0, train=True, lstm=False): | |||
layer = MaskedStep() | |||
func = StackedStep(layer, | |||
num_layers, | |||
lstm=lstm, | |||
dropout=dropout, | |||
train=train) | |||
def forward(input, cells, hidden, mask): | |||
nexth, output = func(input, hidden, cells, mask) | |||
return output, nexth | |||
return forward | |||
class MaskedRNNBase(nn.Module): | |||
def __init__(self, Cell, input_size, hidden_size, | |||
num_layers=1, bias=True, batch_first=False, | |||
layer_dropout=0, step_dropout=0, bidirectional=False, initial_method = None , **kwargs): | |||
""" | |||
:param Cell: | |||
:param input_size: | |||
:param hidden_size: | |||
:param num_layers: | |||
:param bias: | |||
:param batch_first: | |||
:param layer_dropout: | |||
:param step_dropout: | |||
:param bidirectional: | |||
:param kwargs: | |||
""" | |||
super(MaskedRNNBase, self).__init__() | |||
self.Cell = Cell | |||
self.input_size = input_size | |||
self.hidden_size = hidden_size | |||
self.num_layers = num_layers | |||
self.bias = bias | |||
self.batch_first = batch_first | |||
self.layer_dropout = layer_dropout | |||
self.step_dropout = step_dropout | |||
self.bidirectional = bidirectional | |||
num_directions = 2 if bidirectional else 1 | |||
self.all_cells = [] | |||
for layer in range(num_layers): # 初始化所有cell | |||
for direction in range(num_directions): | |||
layer_input_size = input_size if layer == 0 else hidden_size * num_directions | |||
cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs) | |||
self.all_cells.append(cell) | |||
self.add_module('cell%d' % (layer * num_directions + direction), cell) # Max的代码写得真好看 | |||
initial_parameter(self, initial_method) | |||
def reset_parameters(self): | |||
for cell in self.all_cells: | |||
cell.reset_parameters() | |||
def forward(self, input, mask=None, hx=None): | |||
batch_size = input.size(0) if self.batch_first else input.size(1) | |||
lstm = self.Cell is nn.LSTMCell | |||
if hx is None: | |||
num_directions = 2 if self.bidirectional else 1 | |||
hx = torch.autograd.Variable( | |||
input.data.new(self.num_layers * num_directions, batch_size, self.hidden_size).zero_()) | |||
if lstm: | |||
hx = (hx, hx) | |||
func = AutogradMaskedRNN(num_layers=self.num_layers, | |||
batch_first=self.batch_first, | |||
step_dropout=self.step_dropout, | |||
layer_dropout=self.layer_dropout, | |||
train=self.training, | |||
bidirectional=self.bidirectional, | |||
lstm=lstm) # 传入all_cells,继续往底层封装走 | |||
output, hidden = func(input, self.all_cells, hx, | |||
None if mask is None else mask.view(mask.size() + (1,))) # 这个+ (1, )是个什么操作? | |||
return output, hidden | |||
def step(self, input, hx=None, mask=None): | |||
"""Execute one step forward (only for one-directional RNN). | |||
:param Tensor input: input tensor of this step. (batch, input_size) | |||
:param Tensor hx: the hidden state of last step. (num_layers, batch, hidden_size) | |||
:param Tensor mask: the mask tensor of this step. (batch, ) | |||
:returns: | |||
**output** (batch, hidden_size), tensor containing the output of this step from the last layer of RNN. | |||
**hn** (num_layers, batch, hidden_size), tensor containing the hidden state of this step | |||
""" | |||
assert not self.bidirectional, "step only cannot be applied to bidirectional RNN." # aha, typo! | |||
batch_size = input.size(0) | |||
lstm = self.Cell is nn.LSTMCell | |||
if hx is None: | |||
hx = torch.autograd.Variable(input.data.new(self.num_layers, batch_size, self.hidden_size).zero_()) | |||
if lstm: | |||
hx = (hx, hx) | |||
func = AutogradMaskedStep(num_layers=self.num_layers, | |||
dropout=self.step_dropout, | |||
train=self.training, | |||
lstm=lstm) | |||
output, hidden = func(input, self.all_cells, hx, mask) | |||
return output, hidden | |||
class MaskedRNN(MaskedRNNBase): | |||
r"""Applies a multi-layer Elman RNN with costomized non-linearity to an | |||
input sequence. | |||
For each element in the input sequence, each layer computes the following | |||
function. :math:`h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})` | |||
where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is | |||
the hidden state of the previous layer at time `t` or :math:`input_t` | |||
for the first layer. If nonlinearity='relu', then `ReLU` is used instead | |||
of `tanh`. | |||
:param int input_size: The number of expected features in the input x | |||
:param int hidden_size: The number of features in the hidden state h | |||
:param int num_layers: Number of recurrent layers. | |||
:param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh' | |||
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True | |||
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) | |||
:param float dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer | |||
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False | |||
Inputs: input, mask, h_0 | |||
- **input** (seq_len, batch, input_size): tensor containing the features | |||
of the input sequence. | |||
**mask** (seq_len, batch): 0-1 tensor containing the mask of the input sequence. | |||
- **h_0** (num_layers * num_directions, batch, hidden_size): tensor | |||
containing the initial hidden state for each element in the batch. | |||
Outputs: output, h_n | |||
- **output** (seq_len, batch, hidden_size * num_directions): tensor | |||
containing the output features (h_k) from the last layer of the RNN, | |||
for each k. If a :class:`torch.nn.utils.rnn.PackedSequence` has | |||
been given as the input, the output will also be a packed sequence. | |||
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor | |||
containing the hidden state for k=seq_len. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
super(MaskedRNN, self).__init__(nn.RNNCell, *args, **kwargs) | |||
class MaskedLSTM(MaskedRNNBase): | |||
r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input | |||
sequence. | |||
For each element in the input sequence, each layer computes the following | |||
function. | |||
.. math:: | |||
\begin{array}{ll} | |||
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\ | |||
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\ | |||
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\ | |||
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\ | |||
c_t = f_t * c_{(t-1)} + i_t * g_t \\ | |||
h_t = o_t * \tanh(c_t) | |||
\end{array} | |||
where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell | |||
state at time `t`, :math:`x_t` is the hidden state of the previous layer at | |||
time `t` or :math:`input_t` for the first layer, and :math:`i_t`, | |||
:math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, | |||
and out gates, respectively. | |||
:param int input_size: The number of expected features in the input x | |||
:param int hidden_size: The number of features in the hidden state h | |||
:param int num_layers: Number of recurrent layers. | |||
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True | |||
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) | |||
:param bool dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer | |||
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False | |||
Inputs: input, mask, (h_0, c_0) | |||
- **input** (seq_len, batch, input_size): tensor containing the features | |||
of the input sequence. | |||
**mask** (seq_len, batch): 0-1 tensor containing the mask of the input sequence. | |||
- **h_0** (num_layers \* num_directions, batch, hidden_size): tensor | |||
containing the initial hidden state for each element in the batch. | |||
- **c_0** (num_layers \* num_directions, batch, hidden_size): tensor | |||
containing the initial cell state for each element in the batch. | |||
Outputs: output, (h_n, c_n) | |||
- **output** (seq_len, batch, hidden_size * num_directions): tensor | |||
containing the output features `(h_t)` from the last layer of the RNN, | |||
for each t. If a :class:`torch.nn.utils.rnn.PackedSequence` has been | |||
given as the input, the output will also be a packed sequence. | |||
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor | |||
containing the hidden state for t=seq_len | |||
- **c_n** (num_layers * num_directions, batch, hidden_size): tensor | |||
containing the cell state for t=seq_len | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
super(MaskedLSTM, self).__init__(nn.LSTMCell, *args, **kwargs) | |||
class MaskedGRU(MaskedRNNBase): | |||
r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. | |||
For each element in the input sequence, each layer computes the following | |||
function: | |||
.. math:: | |||
\begin{array}{ll} | |||
r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ | |||
z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ | |||
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ | |||
h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\ | |||
\end{array} | |||
where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden | |||
state of the previous layer at time `t` or :math:`input_t` for the first | |||
layer, and :math:`r_t`, :math:`z_t`, :math:`n_t` are the reset, input, | |||
and new gates, respectively. | |||
:param int input_size: The number of expected features in the input x | |||
:param int hidden_size: The number of features in the hidden state h | |||
:param int num_layers: Number of recurrent layers. | |||
:param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh' | |||
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True | |||
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) | |||
:param bool dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer | |||
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False | |||
Inputs: input, mask, h_0 | |||
- **input** (seq_len, batch, input_size): tensor containing the features | |||
of the input sequence. | |||
**mask** (seq_len, batch): 0-1 tensor containing the mask of the input sequence. | |||
- **h_0** (num_layers * num_directions, batch, hidden_size): tensor | |||
containing the initial hidden state for each element in the batch. | |||
Outputs: output, h_n | |||
- **output** (seq_len, batch, hidden_size * num_directions): tensor | |||
containing the output features (h_k) from the last layer of the RNN, | |||
for each k. If a :class:`torch.nn.utils.rnn.PackedSequence` has | |||
been given as the input, the output will also be a packed sequence. | |||
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor | |||
containing the hidden state for k=seq_len. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
super(MaskedGRU, self).__init__(nn.GRUCell, *args, **kwargs) |
@@ -1,27 +0,0 @@ | |||
import torch | |||
import unittest | |||
from fastNLP.modules.encoder.masked_rnn import MaskedRNN | |||
class TestMaskedRnn(unittest.TestCase): | |||
def test_case_1(self): | |||
masked_rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) | |||
x = torch.tensor([[[1.0], [2.0]]]) | |||
print(x.size()) | |||
y = masked_rnn(x) | |||
mask = torch.tensor([[[1], [1]]]) | |||
y = masked_rnn(x, mask=mask) | |||
mask = torch.tensor([[[1], [0]]]) | |||
y = masked_rnn(x, mask=mask) | |||
def test_case_2(self): | |||
masked_rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=False, batch_first=True) | |||
x = torch.tensor([[[1.0], [2.0]]]) | |||
print(x.size()) | |||
y = masked_rnn(x) | |||
mask = torch.tensor([[[1], [1]]]) | |||
y = masked_rnn(x, mask=mask) | |||
xx = torch.tensor([[[1.0]]]) | |||
y = masked_rnn.step(xx) | |||
y = masked_rnn.step(xx, mask=mask) |
@@ -70,7 +70,7 @@ class TestTutorial(unittest.TestCase): | |||
break | |||
from fastNLP.models import CNNText | |||
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||
model = CNNText(vocab_size=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||
from fastNLP import Trainer | |||
from copy import deepcopy | |||
@@ -145,7 +145,7 @@ class TestTutorial(unittest.TestCase): | |||
is_input=True) | |||
from fastNLP.models import CNNText | |||
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||
model = CNNText(vocab_size=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric | |||
trainer = Trainer(model=model, | |||
@@ -405,7 +405,7 @@ class TestTutorial(unittest.TestCase): | |||
# Another example: load the CNN text-classification model
from fastNLP.models import CNNText | |||
cnn_text_model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||
cnn_text_model = CNNText(vocab_size=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||
cnn_text_model | |||
from fastNLP import CrossEntropyLoss | |||