@@ -1,3 +1,5 @@ | |||||
from .core import * | from .core import * | ||||
from . import models | from . import models | ||||
from . import modules | from . import modules | ||||
__version__ = '0.4.0' |
@@ -1,4 +1,5 @@ | |||||
""" | """ | ||||
Callback的说明文档 | |||||
.. _Callback: | .. _Callback: | ||||
@@ -28,7 +29,6 @@ class Callback(object): | |||||
def trainer(self): | def trainer(self): | ||||
""" | """ | ||||
该属性可以通过self.trainer获取到,一般情况下不需要使用这个属性。 | 该属性可以通过self.trainer获取到,一般情况下不需要使用这个属性。 | ||||
:return: | |||||
""" | """ | ||||
return self._trainer | return self._trainer | ||||
@@ -323,11 +323,16 @@ class GradientClipCallback(Callback): | |||||
class CallbackException(BaseException): | class CallbackException(BaseException): | ||||
def __init__(self, msg): | def __init__(self, msg): | ||||
""" | |||||
当需要通过callback跳出训练的时候可以通过抛出CallbackException并在on_exception中捕获这个值。 | |||||
:param str msg: Exception的信息。 | |||||
""" | |||||
super(CallbackException, self).__init__(msg) | super(CallbackException, self).__init__(msg) | ||||
class EarlyStopError(CallbackException): | class EarlyStopError(CallbackException): | ||||
def __init__(self, msg): | def __init__(self, msg): | ||||
"""用于EarlyStop时从Trainer训练循环中跳出。""" | |||||
super(EarlyStopError, self).__init__(msg) | super(EarlyStopError, self).__init__(msg) | ||||
@@ -360,7 +365,13 @@ class EarlyStopCallback(Callback): | |||||
class LRScheduler(Callback): | class LRScheduler(Callback): | ||||
def __init__(self, lr_scheduler): | def __init__(self, lr_scheduler): | ||||
"""对PyTorch LR Scheduler的包装 | |||||
"""对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用 | |||||
Example:: | |||||
from fastNLP import LRScheduler | |||||
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler | :param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler | ||||
""" | """ | ||||
@@ -13,6 +13,9 @@ class Optimizer(object): | |||||
self.model_params = model_params | self.model_params = model_params | ||||
self.settings = kwargs | self.settings = kwargs | ||||
def construct_from_pytorch(self, model_params): | |||||
raise NotImplementedError | |||||
def _get_require_grads_param(self, params): | def _get_require_grads_param(self, params): | ||||
""" | """ | ||||
将params中不需要gradient的删除 | 将params中不需要gradient的删除 | ||||
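A sketch of how a concrete subclass could implement construct_from_pytorch so that Trainer's `optimizer.construct_from_pytorch(model.parameters())` branch (further below in this diff) works; the SGD wrapper is illustrative and only relies on the base-class attributes visible in this hunk (model_params, settings)::

    import torch

    class SGD(Optimizer):
        def __init__(self, lr=0.01, momentum=0, model_params=None):
            super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

        def construct_from_pytorch(self, model_params):
            # build the real torch optimizer lazily, from the parameters the Trainer hands in
            if self.model_params is None:
                return torch.optim.SGD(model_params, **self.settings)
            return torch.optim.SGD(self.model_params, **self.settings)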
@@ -14,20 +14,56 @@ from fastNLP.core.utils import _get_device | |||||
class Tester(object): | class Tester(object): | ||||
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. | |||||
""" | |||||
Tester是在提供数据,模型以及metric的情况下进行性能测试的类 | |||||
Example:: | |||||
import numpy as np | |||||
import torch | |||||
from torch import nn | |||||
from fastNLP import Tester | |||||
from fastNLP import DataSet | |||||
from fastNLP import AccuracyMetric | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(1, 1) | |||||
def forward(self, a): | |||||
return {'pred': self.fc(a.unsqueeze(1)).squeeze(1)} | |||||
model = Model() | |||||
dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2}) | |||||
dataset.set_input('a') | |||||
dataset.set_target('b') | |||||
tester = Tester(dataset, model, metrics=AccuracyMetric()) | |||||
eval_results = tester.test() | |||||
这里Metric的映射规律是和 Trainer_ 中一致的,请参考 Trainer_ 使用metrics。 | |||||
:param DataSet data: a validation/development set | |||||
:param torch.nn.modules.module model: a PyTorch model | |||||
:param MetricBase metrics: a metric object or a list of metrics (List[MetricBase]) | |||||
:param int batch_size: batch size for validation | |||||
:param str,torch.device,None device: 将模型load到哪个设备。默认为None,即Trainer不对模型的计算位置进行管理。支持 | |||||
以下的输入str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, | |||||
可见的第二个GPU中; torch.device,将模型装载到torch.device上。 | |||||
:param int verbose: the number of steps after which information is printed.
""" | """ | ||||
def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): | def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): | ||||
"""传入模型,数据以及metric进行验证。 | |||||
:param DataSet data: 需要测试的数据集 | |||||
:param torch.nn.module model: 使用的模型 | |||||
:param MetricBase metrics: 一个Metric或者一个列表的metric对象 | |||||
:param int batch_size: evaluation时使用的batch_size有多大。 | |||||
:param str,torch.device,None device: 将模型load到哪个设备。默认为None,即Trainer不对模型的计算位置进行管理。支持 | |||||
以下的输入str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, | |||||
可见的第二个GPU中; torch.device,将模型装载到torch.device上。 | |||||
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | |||||
""" | |||||
super(Tester, self).__init__() | super(Tester, self).__init__() | ||||
if not isinstance(data, DataSet): | if not isinstance(data, DataSet): | ||||
@@ -59,10 +95,10 @@ class Tester(object): | |||||
self._predict_func = self._model.forward | self._predict_func = self._model.forward | ||||
def test(self): | def test(self): | ||||
"""Start test or validation. | |||||
:return eval_results: a dictionary whose keys are the class name of metrics to use, values are the evaluation results of these metrics. | |||||
"""开始进行验证,并返回验证结果。 | |||||
:return dict(dict) eval_results: dict为二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。 | |||||
一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。 | |||||
""" | """ | ||||
# turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
network = self._model | network = self._model | ||||
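Continuing the Tester example above, the nested dict that test() documents can be consumed like this (the metric name and key follow the AccuracyMetric example in the docstring)::

    eval_results = tester.test()          # e.g. {'AccuracyMetric': {'acc': 1.0}}
    acc = eval_results['AccuracyMetric']['acc']
    print('accuracy on the dev set: {}'.format(acc))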
@@ -213,7 +213,7 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在 | |||||
from torch.optim import SGD | from torch.optim import SGD | ||||
from fastNLP import Trainer | from fastNLP import Trainer | ||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP import AccuracyMetric | |||||
import torch | import torch | ||||
class Model(nn.Module): | class Model(nn.Module): | ||||
@@ -322,7 +322,7 @@ from fastNLP.core.utils import _check_loss_evaluate | |||||
from fastNLP.core.utils import _move_dict_value_to_device | from fastNLP.core.utils import _move_dict_value_to_device | ||||
from fastNLP.core.utils import _get_func_signature | from fastNLP.core.utils import _get_func_signature | ||||
from fastNLP.core.utils import _get_device | from fastNLP.core.utils import _get_device | ||||
from fastNLP.core.optimizer import Optimizer | |||||
class Trainer(object): | class Trainer(object): | ||||
def __init__(self, train_data, model, optimizer, loss=None, | def __init__(self, train_data, model, optimizer, loss=None, | ||||
@@ -336,8 +336,7 @@ class Trainer(object): | |||||
""" | """ | ||||
:param DataSet train_data: 训练集 | :param DataSet train_data: 训练集 | ||||
:param nn.modules model: 待训练的模型 | :param nn.modules model: 待训练的模型 | ||||
:param Optimizer,None optimizer: 优化器,pytorch的torch.optim.Optimizer类型。如果为None,则Trainer不会更新模型, | |||||
请确保已在callback中进行了更新。 | |||||
:param torch.optim.Optimizer,None optimizer: 优化器。如果为None,则Trainer不会更新模型,请确保已在callback中进行了更新。 | |||||
:param int batch_size: 训练和验证的时候的batch大小。 | :param int batch_size: 训练和验证的时候的batch大小。 | ||||
:param LossBase loss: 使用的Loss对象。 详见 LossBase_ 。当loss为None时,默认使用 LossInForward_ 。 | :param LossBase loss: 使用的Loss对象。 详见 LossBase_ 。当loss为None时,默认使用 LossInForward_ 。 | ||||
:param Sampler sampler: Batch数据生成的顺序。详见 Sampler_ 。如果为None,默认使用 RandomSampler_ 。 | :param Sampler sampler: Batch数据生成的顺序。详见 Sampler_ 。如果为None,默认使用 RandomSampler_ 。 | ||||
@@ -438,6 +437,8 @@ class Trainer(object): | |||||
if isinstance(optimizer, torch.optim.Optimizer): | if isinstance(optimizer, torch.optim.Optimizer): | ||||
self.optimizer = optimizer | self.optimizer = optimizer | ||||
elif isinstance(optimizer, Optimizer): | |||||
self.optimizer = optimizer.construct_from_pytorch(model.parameters()) | |||||
elif optimizer is None: | elif optimizer is None: | ||||
warnings.warn("The optimizer is set to None, Trainer will update your model. Make sure you update the model" | warnings.warn("The optimizer is set to None, Trainer will update your model. Make sure you update the model" | ||||
" in the callback.") | " in the callback.") | ||||
@@ -8,7 +8,7 @@ from fastNLP.io.base_loader import BaseLoader | |||||
import warnings | import warnings | ||||
class EmbedLoader(BaseLoader): | class EmbedLoader(BaseLoader): | ||||
"""docstring for EmbedLoader""" | |||||
"""这个类用于从预训练的Embedding中load数据。""" | |||||
def __init__(self): | def __init__(self): | ||||
super(EmbedLoader, self).__init__() | super(EmbedLoader, self).__init__() | ||||
@@ -16,18 +16,17 @@ class EmbedLoader(BaseLoader): | |||||
@staticmethod | @staticmethod | ||||
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'): | def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'): | ||||
""" | """ | ||||
load pretraining embedding in {embed_file} based on words in vocab. Words in vocab but not in the pretraining | |||||
embedding are initialized from a normal distribution which has the mean and std of the found words vectors. | |||||
The embedding type is determined automatically, support glove and word2vec(the first line only has two elements). | |||||
:param embed_filepath: str, where to read pretrain embedding | |||||
:param vocab: Vocabulary. | |||||
:param dtype: the dtype of the embedding matrix | |||||
:param normalize: bool, whether to normalize each word vector so that every vector has norm 1. | |||||
:param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will | |||||
raise | |||||
:return: np.ndarray() will have the same [len(vocab), dimension], dimension is determined by the pretrain | |||||
embedding | |||||
从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | |||||
word2vec(第一行只有两个元素)还是glove格式的数据。 | |||||
:param str embed_filepath: 预训练的embedding的路径。 | |||||
:param Vocabulary vocab: 词表,读取出现在vocab中的词的embedding。没有出现在vocab中的词的embedding将通过找到的词的 | |||||
embedding的正态分布采样出来,以使得整个Embedding是同分布的。 | |||||
:param dtype: 读出的embedding的类型 | |||||
:param bool normalize: 是否将每个vector归一化到norm为1 | |||||
:param str error: 'ignore', 'strict'; 如果'ignore',错误将自动跳过; 如果strict, 错误将抛出。这里主要可能出错的地 | |||||
方在于词表有空行或者词表出现了维度不一致。 | |||||
:return: numpy.ndarray, shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 | |||||
""" | """ | ||||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." | assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." | ||||
if not os.path.exists(embed_filepath): | if not os.path.exists(embed_filepath): | ||||
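A usage sketch for load_with_vocab; the file path is hypothetical (any GloVe/word2vec format text file works) and the import paths are assumed::

    from fastNLP import Vocabulary
    from fastNLP.io import EmbedLoader

    vocab = Vocabulary()
    vocab.add_word_lst(['the', 'cat', 'sat'])
    # words not found in the file are sampled from a normal distribution fitted to the found vectors
    matrix = EmbedLoader.load_with_vocab('glove.6B.50d.txt', vocab)   # hypothetical path
    print(matrix.shape)   # (len(vocab), 50) for a 50-dimensional pretrained file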
@@ -76,19 +75,18 @@ class EmbedLoader(BaseLoader): | |||||
def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | ||||
error='ignore'): | error='ignore'): | ||||
""" | """ | ||||
load pretraining embedding in {embed_file}. And construct a Vocabulary based on the pretraining embedding. | |||||
The embedding type is determined automatically, support glove and word2vec(the first line only has two elements). | |||||
:param embed_filepath: str, where to read pretrain embedding | |||||
:param dtype: the dtype of the embedding matrix | |||||
:param padding: the padding tag for vocabulary. | |||||
:param unknown: the unknown tag for vocabulary. | |||||
:param normalize: bool, whether to normalize each word vector so that every vector has norm 1. | |||||
:param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will | |||||
:raise | |||||
:return: np.ndarray() is determined by the pretraining embeddings | |||||
Vocabulary: contain all pretraining words and two special tag[<pad>, <unk>] | |||||
从embed_filepath中读取预训练的word vector。根据预训练的词表读取embedding并生成一个对应的Vocabulary。 | |||||
:param str embed_filepath: 预训练的embedding的路径。 | |||||
:param dtype: 读出的embedding的类型 | |||||
:param str padding: the padding tag for vocabulary. | |||||
:param str unknown: the unknown tag for vocabulary. | |||||
:param bool normalize: 是否将每个vector归一化到norm为1 | |||||
:param str error: 'ignore', 'strict'; 如果'ignore',错误将自动跳过; 如果strict, 错误将抛出。这里主要可能出错的地 | |||||
方在于词表有空行或者词表出现了维度不一致。 | |||||
:return: (numpy.ndarray, Vocabulary)。embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决于
是否使用padding, 以及unknown有没有在词表中找到对应的词。Vocabulary中的词的顺序与embedding的行顺序是一一对应的。
""" | """ | ||||
vocab = Vocabulary(padding=padding, unknown=unknown) | vocab = Vocabulary(padding=padding, unknown=unknown) | ||||
vec_dict = {} | vec_dict = {} | ||||
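And a sketch for load_without_vocab, which returns both the matrix and the Vocabulary it built from the pretrained file (same hypothetical file as above)::

    matrix, vocab = EmbedLoader.load_without_vocab('glove.6B.50d.txt')
    # the docstring guarantees rows of the matrix and entries of vocab are aligned one-to-one
    assert matrix.shape[0] == len(vocab)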
@@ -3,29 +3,38 @@ | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import numpy as np | |||||
# import torch.nn.functional as F | |||||
import fastNLP.modules.encoder as encoder | import fastNLP.modules.encoder as encoder | ||||
class CNNText(torch.nn.Module): | class CNNText(torch.nn.Module): | ||||
""" | """ | ||||
Text classification model by character CNN, the implementation of paper | |||||
'Yoon Kim. 2014. Convolution Neural Networks for Sentence | |||||
Classification.' | |||||
使用CNN进行文本分类的模型 | |||||
'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.' | |||||
""" | """ | ||||
def __init__(self, embed_num, | |||||
def __init__(self, vocab_size, | |||||
embed_dim, | embed_dim, | ||||
num_classes, | num_classes, | ||||
kernel_nums=(3, 4, 5), | kernel_nums=(3, 4, 5), | ||||
kernel_sizes=(3, 4, 5), | kernel_sizes=(3, 4, 5), | ||||
padding=0, | padding=0, | ||||
dropout=0.5): | dropout=0.5): | ||||
""" | |||||
:param int vocab_size: 词表的大小 | |||||
:param int embed_dim: 词embedding的维度大小 | |||||
:param int num_classes: 一共有多少类 | |||||
:param int,tuple(int) kernel_nums: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致
:param int,tuple(int) kernel_sizes: 每个输出channel的kernel大小。
:param int padding: 卷积时的padding大小, 所有kernel共享(见pytorch Conv1d文档)。
:param float dropout: Dropout的大小 | |||||
""" | |||||
super(CNNText, self).__init__() | super(CNNText, self).__init__() | ||||
# no support for pre-trained embedding currently | # no support for pre-trained embedding currently | ||||
self.embed = encoder.Embedding(embed_num, embed_dim) | |||||
self.embed = encoder.Embedding(vocab_size, embed_dim) | |||||
self.conv_pool = encoder.ConvMaxpool( | self.conv_pool = encoder.ConvMaxpool( | ||||
in_channels=embed_dim, | in_channels=embed_dim, | ||||
out_channels=kernel_nums, | out_channels=kernel_nums, | ||||
@@ -34,24 +43,36 @@ class CNNText(torch.nn.Module): | |||||
self.dropout = nn.Dropout(dropout) | self.dropout = nn.Dropout(dropout) | ||||
self.fc = encoder.Linear(sum(kernel_nums), num_classes) | self.fc = encoder.Linear(sum(kernel_nums), num_classes) | ||||
def forward(self, word_seq): | |||||
def init_embed(self, embed): | |||||
""" | |||||
加载预训练的模型 | |||||
:param numpy.ndarray embed: vocab_size x embed_dim的embedding | |||||
:return: | |||||
""" | |||||
assert isinstance(embed, np.ndarray) | |||||
assert embed.shape == self.embed.embed.weight.shape | |||||
self.embed.embed.weight.data = torch.from_numpy(embed) | |||||
def forward(self, words, seq_len=None): | |||||
""" | """ | ||||
:param word_seq: torch.LongTensor, [batch_size, seq_len] | |||||
:param torch.LongTensor words: [batch_size, seq_len],句子中word的index | |||||
:param torch.LongTensor seq_len: [batch,] 每个句子的长度 | |||||
:return output: dict of torch.LongTensor, [batch_size, num_classes] | :return output: dict of torch.LongTensor, [batch_size, num_classes] | ||||
""" | """ | ||||
x = self.embed(word_seq) # [N,L] -> [N,L,C] | |||||
x = self.embed(words) # [N,L] -> [N,L,C] | |||||
x = self.conv_pool(x) # [N,L,C] -> [N,C] | x = self.conv_pool(x) # [N,L,C] -> [N,C] | ||||
x = self.dropout(x) | x = self.dropout(x) | ||||
x = self.fc(x) # [N,C] -> [N, N_class] | x = self.fc(x) # [N,C] -> [N, N_class] | ||||
return {'pred': x} | return {'pred': x} | ||||
def predict(self, word_seq): | |||||
def predict(self, words, seq_len=None): | |||||
""" | """ | ||||
:param torch.LongTensor words: [batch_size, seq_len],句子中word的index | |||||
:param torch.LongTensor seq_len: [batch,] 每个句子的长度 | |||||
:param word_seq: torch.LongTensor, [batch_size, seq_len] | |||||
:return predict: dict of torch.LongTensor, [batch_size, seq_len] | |||||
:return predict: dict of torch.LongTensor, [batch_size, ] | |||||
""" | """ | ||||
output = self(word_seq) | |||||
output = self(words, seq_len) | |||||
_, predict = output['pred'].max(dim=1) | _, predict = output['pred'].max(dim=1) | ||||
return {'pred': predict} | return {'pred': predict} |
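A self-contained sketch of the renamed CNNText interface above (vocabulary size, shapes and class count are arbitrary; the import path is assumed)::

    import numpy as np
    import torch
    from fastNLP.models import CNNText

    model = CNNText(vocab_size=100, embed_dim=50, num_classes=5)
    # optionally start from a pretrained embedding with the matching [vocab_size, embed_dim] shape
    model.init_embed(np.random.rand(100, 50).astype('float32'))

    words = torch.randint(0, 100, (4, 20)).long()   # batch_size=4, seq_len=20
    out = model(words)                              # {'pred': FloatTensor of shape [4, 5]}
    pred = model.predict(words)                     # {'pred': LongTensor of shape [4]}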
@@ -8,47 +8,64 @@ from fastNLP.modules.utils import seq_mask | |||||
class SeqLabeling(BaseModel): | class SeqLabeling(BaseModel): | ||||
""" | """ | ||||
PyTorch Network for sequence labeling | |||||
一个基础的Sequence labeling的模型 | |||||
""" | """ | ||||
def __init__(self, args): | |||||
def __init__(self, vocab_size, embed_dim, hidden_size, num_classes): | |||||
""" | |||||
用于做sequence labeling的基础类。结构包含一层Embedding,一层LSTM(单向,一层),一层FC,以及一层CRF。 | |||||
:param int vocab_size: 词表大小。 | |||||
:param int embed_dim: embedding的维度 | |||||
:param int hidden_size: LSTM隐藏层的大小 | |||||
:param int num_classes: 一共有多少类 | |||||
""" | |||||
super(SeqLabeling, self).__init__() | super(SeqLabeling, self).__init__() | ||||
vocab_size = args["vocab_size"] | |||||
word_emb_dim = args["word_emb_dim"] | |||||
hidden_dim = args["rnn_hidden_units"] | |||||
num_classes = args["num_classes"] | |||||
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim) | |||||
self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim) | |||||
self.Linear = encoder.linear.Linear(hidden_dim, num_classes) | |||||
self.Embedding = encoder.embedding.Embedding(vocab_size, embed_dim) | |||||
self.Rnn = encoder.lstm.LSTM(embed_dim, hidden_size) | |||||
self.Linear = encoder.linear.Linear(hidden_size, num_classes) | |||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | ||||
self.mask = None | self.mask = None | ||||
def forward(self, word_seq, word_seq_origin_len, truth=None): | |||||
def forward(self, words, seq_len, target): | |||||
""" | """ | ||||
:param word_seq: LongTensor, [batch_size, mex_len] | |||||
:param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences. | |||||
:param truth: LongTensor, [batch_size, max_len] | |||||
:param torch.LongTensor words: [batch_size, max_len],序列的index | |||||
:param torch.LongTensor seq_len: [batch_size,], 这个序列的长度 | |||||
:param torch.LongTensor target: [batch_size, max_len], 序列的目标值 | |||||
:return: dict, 其中'loss'为一个标量的loss, 用于训练; 预测请使用predict方法。
""" | """ | ||||
assert word_seq.shape[0] == word_seq_origin_len.shape[0] | |||||
if truth is not None: | |||||
assert truth.shape == word_seq.shape | |||||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | |||||
assert words.shape[0] == seq_len.shape[0] | |||||
assert target.shape == words.shape | |||||
self.mask = self._make_mask(words, seq_len) | |||||
x = self.Embedding(word_seq) | |||||
x = self.Embedding(words) | |||||
# [batch_size, max_len, word_emb_dim] | # [batch_size, max_len, word_emb_dim] | ||||
x = self.Rnn(x) | x = self.Rnn(x) | ||||
# [batch_size, max_len, hidden_size * direction] | # [batch_size, max_len, hidden_size * direction] | ||||
x = self.Linear(x) | x = self.Linear(x) | ||||
# [batch_size, max_len, num_classes] | # [batch_size, max_len, num_classes] | ||||
return {"loss": self._internal_loss(x, truth) if truth is not None else None, | |||||
"predict": self.decode(x)} | |||||
return {"loss": self._internal_loss(x, target)} | |||||
def loss(self, x, y): | |||||
""" Since the loss has been computed in forward(), this function simply returns x.""" | |||||
return x | |||||
def predict(self, words, seq_len): | |||||
""" | |||||
用于在预测时使用 | |||||
:param torch.LongTensor words: [batch_size, max_len] | |||||
:param torch.LongTensor seq_len: [batch_size,] | |||||
:return: | |||||
""" | |||||
self.mask = self._make_mask(words, seq_len) | |||||
x = self.Embedding(words) | |||||
# [batch_size, max_len, word_emb_dim] | |||||
x = self.Rnn(x) | |||||
# [batch_size, max_len, hidden_size * direction] | |||||
x = self.Linear(x) | |||||
# [batch_size, max_len, num_classes] | |||||
pred = self._decode(x) | |||||
return {'pred': pred} | |||||
def _internal_loss(self, x, y): | def _internal_loss(self, x, y): | ||||
""" | """ | ||||
@@ -65,89 +82,114 @@ class SeqLabeling(BaseModel): | |||||
total_loss = self.Crf(x, y, self.mask) | total_loss = self.Crf(x, y, self.mask) | ||||
return torch.mean(total_loss) | return torch.mean(total_loss) | ||||
def make_mask(self, x, seq_len): | |||||
def _make_mask(self, x, seq_len): | |||||
batch_size, max_len = x.size(0), x.size(1) | batch_size, max_len = x.size(0), x.size(1) | ||||
mask = seq_mask(seq_len, max_len) | mask = seq_mask(seq_len, max_len) | ||||
mask = mask.view(batch_size, max_len) | mask = mask.view(batch_size, max_len) | ||||
mask = mask.to(x).float() | mask = mask.to(x).float() | ||||
return mask | return mask | ||||
def decode(self, x, pad=True): | |||||
def _decode(self, x): | |||||
""" | """ | ||||
:param x: FloatTensor, [batch_size, max_len, tag_size] | |||||
:param pad: pad the output sequence to equal lengths | |||||
:param torch.FloatTensor x: [batch_size, max_len, tag_size] | |||||
:return prediction: list of [decode path(list)] | :return prediction: list of [decode path(list)] | ||||
""" | """ | ||||
max_len = x.shape[1] | |||||
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask) | |||||
# pad prediction to equal length | |||||
if pad is True: | |||||
for pred in tag_seq: | |||||
if len(pred) < max_len: | |||||
pred += [0] * (max_len - len(pred)) | |||||
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True) | |||||
return tag_seq | return tag_seq | ||||
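A sketch of the new SeqLabeling interface, using random data only to show the expected shapes and return keys (hyper-parameters are arbitrary, the import path is assumed)::

    import torch
    from fastNLP.models import SeqLabeling

    model = SeqLabeling(vocab_size=100, embed_dim=50, hidden_size=64, num_classes=5)
    words = torch.randint(0, 100, (4, 20)).long()
    seq_len = torch.LongTensor([20, 18, 15, 9])
    target = torch.randint(0, 5, (4, 20)).long()

    loss = model(words, seq_len, target)['loss']    # scalar CRF negative log-likelihood
    paths = model.predict(words, seq_len)['pred']   # list of unpadded tag-id lists, one per sample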
class AdvSeqLabel(SeqLabeling): | |||||
class AdvSeqLabel(torch.nn.Module):
""" | """ | ||||
Advanced Sequence Labeling Model | |||||
更复杂的Sequence Labelling模型。结构为Embedding, LayerNorm, 双向LSTM(两层),FC,LayerNorm,DropOut,FC,CRF。 | |||||
""" | """ | ||||
def __init__(self, args, emb=None, id2words=None): | |||||
super(AdvSeqLabel, self).__init__(args) | |||||
vocab_size = args["vocab_size"] | |||||
word_emb_dim = args["word_emb_dim"] | |||||
hidden_dim = args["rnn_hidden_units"] | |||||
num_classes = args["num_classes"] | |||||
dropout = args['dropout'] | |||||
def __init__(self, vocab_size, embed_dim, hidden_size, num_classes, dropout=0.3, embedding=None, | |||||
id2words=None, encoding_type='bmes'): | |||||
""" | |||||
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | |||||
self.norm1 = torch.nn.LayerNorm(word_emb_dim) | |||||
# self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) | |||||
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, | |||||
:param int vocab_size: 词表的大小 | |||||
:param int embed_dim: embedding的维度 | |||||
:param int hidden_size: LSTM的隐层大小 | |||||
:param int num_classes: 有多少个类 | |||||
:param float dropout: LSTM中以及DropOut层的drop概率 | |||||
:param numpy.ndarray embedding: 预训练的embedding,需要与指定的词表大小等一致 | |||||
:param dict id2words: tag id转为其tag word的表。用于在CRF解码时防止解出非法的顺序,比如'BMES'这个标签规范中,'S' | |||||
不能出现在'B'之后。这里也支持类似与'B-NN',即'-'前为标签类型的指示,后面为具体的tag的情况。这里不但会保证 | |||||
'B-NN'后面不为'S-NN'还会保证'B-NN'后面不会出现'M-xx'(任何非'M-NN'和'E-NN'的情况。) | |||||
:param str encoding_type: 支持"BIO", "BMES", "BEMSO"。 | |||||
""" | |||||
self.Embedding = encoder.embedding.Embedding(vocab_size, embed_dim, init_emb=embedding) | |||||
self.norm1 = torch.nn.LayerNorm(embed_dim) | |||||
self.Rnn = torch.nn.LSTM(input_size=embed_dim, hidden_size=hidden_size, num_layers=2, dropout=dropout, | |||||
bidirectional=True, batch_first=True) | bidirectional=True, batch_first=True) | ||||
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | |||||
self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) | |||||
# self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | |||||
self.Linear1 = encoder.Linear(hidden_size * 2, hidden_size * 2 // 3) | |||||
self.norm2 = torch.nn.LayerNorm(hidden_size * 2 // 3) | |||||
self.relu = torch.nn.LeakyReLU() | self.relu = torch.nn.LeakyReLU() | ||||
self.drop = torch.nn.Dropout(dropout) | self.drop = torch.nn.Dropout(dropout) | ||||
self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) | |||||
self.Linear2 = encoder.Linear(hidden_size * 2 // 3, num_classes) | |||||
if id2words is None: | if id2words is None: | ||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) | self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) | ||||
else: | else: | ||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False, | self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False, | ||||
allowed_transitions=allowed_transitions(id2words, | allowed_transitions=allowed_transitions(id2words, | ||||
encoding_type="bmes")) | |||||
encoding_type=encoding_type)) | |||||
def _decode(self, x): | |||||
""" | |||||
:param torch.FloatTensor x: [batch_size, max_len, tag_size] | |||||
:return prediction: list of [decode path(list)] | |||||
""" | |||||
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True) | |||||
return tag_seq | |||||
def _internal_loss(self, x, y): | |||||
""" | |||||
Negative log likelihood loss. | |||||
:param x: Tensor, [batch_size, max_len, tag_size] | |||||
:param y: Tensor, [batch_size, max_len] | |||||
:return loss: a scalar Tensor | |||||
""" | |||||
x = x.float() | |||||
y = y.long() | |||||
assert x.shape[:2] == y.shape | |||||
assert y.shape == self.mask.shape | |||||
total_loss = self.Crf(x, y, self.mask) | |||||
return torch.mean(total_loss) | |||||
def _make_mask(self, x, seq_len): | |||||
batch_size, max_len = x.size(0), x.size(1) | |||||
mask = seq_mask(seq_len, max_len) | |||||
mask = mask.view(batch_size, max_len) | |||||
mask = mask.to(x).float() | |||||
return mask | |||||
def forward(self, word_seq, word_seq_origin_len, truth=None): | |||||
def _forward(self, words, seq_len, target=None): | |||||
""" | """ | ||||
:param word_seq: LongTensor, [batch_size, mex_len] | |||||
:param word_seq_origin_len: LongTensor, [batch_size, ] | |||||
:param truth: LongTensor, [batch_size, max_len] | |||||
:param torch.LongTensor words: [batch_size, mex_len] | |||||
:param torch.LongTensor seq_len:[batch_size, ] | |||||
:param torch.LongTensor target: [batch_size, max_len] | |||||
:return: 如果target为None, 返回 {'pred': decode路径的list}, 用于测试和预测; 如果target不为None, 返回 {'loss': 标量的loss}, 用于训练。
""" | """ | ||||
word_seq = word_seq.long() | |||||
word_seq_origin_len = word_seq_origin_len.long() | |||||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | |||||
sent_len, idx_sort = torch.sort(word_seq_origin_len, descending=True) | |||||
words = words.long() | |||||
seq_len = seq_len.long() | |||||
self.mask = self._make_mask(words, seq_len) | |||||
sent_len, idx_sort = torch.sort(seq_len, descending=True) | |||||
_, idx_unsort = torch.sort(idx_sort, descending=False) | _, idx_unsort = torch.sort(idx_sort, descending=False) | ||||
# word_seq_origin_len = word_seq_origin_len.long() | |||||
truth = truth.long() if truth is not None else None | |||||
# seq_len = seq_len.long() | |||||
target = target.long() if target is not None else None | |||||
batch_size = word_seq.size(0) | |||||
max_len = word_seq.size(1) | |||||
if next(self.parameters()).is_cuda: | if next(self.parameters()).is_cuda: | ||||
word_seq = word_seq.cuda() | |||||
words = words.cuda() | |||||
idx_sort = idx_sort.cuda() | idx_sort = idx_sort.cuda() | ||||
idx_unsort = idx_unsort.cuda() | idx_unsort = idx_unsort.cuda() | ||||
self.mask = self.mask.cuda() | self.mask = self.mask.cuda() | ||||
x = self.Embedding(word_seq) | |||||
x = self.Embedding(words) | |||||
x = self.norm1(x) | x = self.norm1(x) | ||||
# [batch_size, max_len, word_emb_dim] | # [batch_size, max_len, word_emb_dim] | ||||
@@ -155,71 +197,35 @@ class AdvSeqLabel(SeqLabeling): | |||||
sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True) | sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True) | ||||
x, _ = self.Rnn(sent_packed) | x, _ = self.Rnn(sent_packed) | ||||
# print(x) | |||||
# [batch_size, max_len, hidden_size * direction] | |||||
sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0] | sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0] | ||||
x = sent_output[idx_unsort] | x = sent_output[idx_unsort] | ||||
x = x.contiguous() | x = x.contiguous() | ||||
# x = x.view(batch_size * max_len, -1) | |||||
x = self.Linear1(x) | x = self.Linear1(x) | ||||
# x = self.batch_norm(x) | |||||
x = self.norm2(x) | x = self.norm2(x) | ||||
x = self.relu(x) | x = self.relu(x) | ||||
x = self.drop(x) | x = self.drop(x) | ||||
x = self.Linear2(x) | x = self.Linear2(x) | ||||
# x = x.view(batch_size, max_len, -1) | |||||
# [batch_size, max_len, num_classes] | |||||
# TODO seq_lens的key这样做不合理 | |||||
return {"loss": self._internal_loss(x, truth) if truth is not None else None, | |||||
"predict": self.decode(x), | |||||
'word_seq_origin_len': word_seq_origin_len} | |||||
def predict(self, **x): | |||||
out = self.forward(**x) | |||||
return {"predict": out["predict"]} | |||||
def loss(self, **kwargs): | |||||
assert 'loss' in kwargs | |||||
return kwargs['loss'] | |||||
if __name__ == '__main__': | |||||
args = { | |||||
'vocab_size': 20, | |||||
'word_emb_dim': 100, | |||||
'rnn_hidden_units': 100, | |||||
'num_classes': 10, | |||||
} | |||||
model = AdvSeqLabel(args) | |||||
data = [] | |||||
for i in range(20): | |||||
word_seq = torch.randint(20, (15,)).long() | |||||
word_seq_len = torch.LongTensor([15]) | |||||
truth = torch.randint(10, (15,)).long() | |||||
data.append((word_seq, word_seq_len, truth)) | |||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) | |||||
print(model) | |||||
curidx = 0 | |||||
for i in range(1000): | |||||
endidx = min(len(data), curidx + 5) | |||||
b_word, b_len, b_truth = [], [], [] | |||||
for word_seq, word_seq_len, truth in data[curidx: endidx]: | |||||
b_word.append(word_seq) | |||||
b_len.append(word_seq_len) | |||||
b_truth.append(truth) | |||||
word_seq = torch.stack(b_word, dim=0) | |||||
word_seq_len = torch.cat(b_len, dim=0) | |||||
truth = torch.stack(b_truth, dim=0) | |||||
res = model(word_seq, word_seq_len, truth) | |||||
loss = res['loss'] | |||||
pred = res['predict'] | |||||
print('loss: {} acc {}'.format(loss.item(), | |||||
((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
curidx = endidx | |||||
if curidx == len(data): | |||||
curidx = 0 | |||||
if target is not None: | |||||
return {"loss": self._internal_loss(x, target)} | |||||
else: | |||||
return {"pred": self._decode(x)} | |||||
def forward(self, words, seq_len, target): | |||||
""" | |||||
:param torch.LongTensor words: [batch_size, mex_len] | |||||
:param torch.LongTensor seq_len:[batch_size, ] | |||||
:param torch.LongTensor target: [batch_size, max_len], 目标 | |||||
:return dict: 其中'loss'为一个标量的loss (torch.Tensor)
""" | |||||
return self._forward(words, seq_len, target) | |||||
def predict(self, words, seq_len): | |||||
""" | |||||
:param torch.LongTensor words: [batch_size, mex_len] | |||||
:param torch.LongTensor seq_len:[batch_size, ] | |||||
:return dict: {'pred': [list1, list2, ...]}, 内部每个list为一个路径, 已经unpad了。
""" | |||||
return self._forward(words, seq_len, ) |
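AdvSeqLabel keeps the same forward/predict contract; a sketch that also passes an id2words table so the constrained CRF is exercised (labels, sizes and the import path are assumptions)::

    import torch
    from fastNLP.models import AdvSeqLabel

    id2words = {0: 'B-NN', 1: 'M-NN', 2: 'E-NN', 3: 'S-NN', 4: 'O'}
    model = AdvSeqLabel(vocab_size=100, embed_dim=50, hidden_size=128, num_classes=5,
                        dropout=0.3, id2words=id2words, encoding_type='bmeso')
    words = torch.randint(0, 100, (4, 20)).long()
    seq_len = torch.LongTensor([20, 18, 15, 9])
    target = torch.randint(0, 5, (4, 20)).long()

    loss = model(words, seq_len, target)['loss']
    paths = model.predict(words, seq_len)['pred']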
@@ -19,10 +19,10 @@ def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): | |||||
""" | """ | ||||
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | ||||
:param id2label: Dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||||
:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。 | "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。 | ||||
:param encoding_type: str, 支持"bio", "bmes", "bmeso"。 | |||||
:param include_start_end: bool, 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | |||||
:param str encoding_type: 支持"bio", "bmes", "bmeso"。 | |||||
:param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | |||||
为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | 为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | ||||
start_idx=len(id2label), end_idx=len(id2label)+1。 | start_idx=len(id2label), end_idx=len(id2label)+1。 | ||||
为False, 返回的结果中不含与开始结尾相关的内容 | 为False, 返回的结果中不含与开始结尾相关的内容 | ||||
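A small sketch of what the function returns for a BIO tag map (indices are arbitrary; the import path follows the decoder.CRF module used elsewhere in this diff)::

    from fastNLP.modules.decoder.CRF import allowed_transitions

    id2label = {0: 'B-PER', 1: 'I-PER', 2: 'O'}
    transitions = allowed_transitions(id2label, encoding_type='bio', include_start_end=True)
    # a list of (from_tag_id, to_tag_id) pairs; with include_start_end=True it also uses
    # start_idx = len(id2label) = 3 and end_idx = len(id2label) + 1 = 4, so e.g.
    # (3, 0) and (3, 2) are allowed while (3, 1) is not, because 'I' cannot open a sequence.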
@@ -62,11 +62,11 @@ def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): | |||||
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | ||||
""" | """ | ||||
:param encoding_type: str, 支持"BIO", "BMES", "BEMSO"。 | |||||
:param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||||
:param from_label: str, 比如"PER", "LOC"等label | |||||
:param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||||
:param to_label: str, 比如"PER", "LOC"等label | |||||
:param str encoding_type: 支持"BIO", "BMES", "BEMSO"。 | |||||
:param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||||
:param str from_label: 比如"PER", "LOC"等label | |||||
:param str to_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||||
:param str to_label: 比如"PER", "LOC"等label | |||||
:return: bool,能否跃迁 | :return: bool,能否跃迁 | ||||
""" | """ | ||||
if to_tag=='start' or from_tag=='end': | if to_tag=='start' or from_tag=='end': | ||||
@@ -149,12 +149,12 @@ class ConditionalRandomField(nn.Module): | |||||
"""条件随机场。 | """条件随机场。 | ||||
提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 | 提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 | ||||
:param num_tags: int, 标签的数量 | |||||
:param include_start_end_trans: bool, 是否考虑各个tag作为开始以及结尾的分数。 | |||||
:param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]], 内部的Tuple[from_tag_id(int), | |||||
:param int num_tags: 标签的数量 | |||||
:param bool include_start_end_trans: 是否考虑各个tag作为开始以及结尾的分数。 | |||||
:param List[Tuple[from_tag_id(int), to_tag_id(int)]] allowed_transitions: 内部的Tuple[from_tag_id(int), | |||||
to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 | to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 | ||||
allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 | allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 | ||||
:param initial_method: str, 初始化方法。见initial_parameter | |||||
:param str initial_method: 初始化方法。见initial_parameter | |||||
""" | """ | ||||
super(ConditionalRandomField, self).__init__() | super(ConditionalRandomField, self).__init__() | ||||
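A sketch that ties the CRF to allowed_transitions and to the renamed viterbi_decode arguments below, reusing id2label/transitions from the sketch above (shapes are arbitrary)::

    import torch

    num_tags = 3
    crf = ConditionalRandomField(num_tags, include_start_end_trans=False,
                                 allowed_transitions=transitions)
    feats = torch.randn(4, 20, num_tags)              # batch_size x max_len x num_tags
    tags = torch.randint(0, num_tags, (4, 20)).long()
    mask = torch.ones(4, 20).byte()                   # 0 marks padding positions

    loss = crf(feats, tags, mask).mean()              # forward() returns a (batch_size,) tensor
    paths, scores = crf.viterbi_decode(feats, mask, unpad=True)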
@@ -237,10 +237,10 @@ class ConditionalRandomField(nn.Module): | |||||
""" | """ | ||||
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 | 用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 | ||||
:param feats:FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||||
:param tags:LongTensor, batch_size x max_len,标签矩阵。 | |||||
:param mask:ByteTensor batch_size x max_len,为0的位置认为是padding。 | |||||
:return:FloatTensor, batch_size | |||||
:param torch.FloatTensor feats:batch_size x max_len x num_tags,特征矩阵。 | |||||
:param torch.LongTensor tags: batch_size x max_len,标签矩阵。 | |||||
:param torch.ByteTensor mask: batch_size x max_len,为0的位置认为是padding。 | |||||
:return:torch.FloatTensor, (batch_size,) | |||||
""" | """ | ||||
feats = feats.transpose(0, 1) | feats = feats.transpose(0, 1) | ||||
tags = tags.transpose(0, 1).long() | tags = tags.transpose(0, 1).long() | ||||
@@ -250,27 +250,26 @@ class ConditionalRandomField(nn.Module): | |||||
return all_path_score - gold_path_score | return all_path_score - gold_path_score | ||||
def viterbi_decode(self, feats, mask, unpad=False): | |||||
def viterbi_decode(self, logits, mask, unpad=False): | |||||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | """给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | ||||
:param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||||
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||||
:param unpad: bool, 是否将结果删去padding, | |||||
False, 返回的是batch_size x max_len的tensor, | |||||
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int] | |||||
的长度是这个sample的有效长度。 | |||||
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | |||||
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||||
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 | |||||
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 | |||||
个sample的有效长度。 | |||||
:return: 返回 (paths, scores)。 | :return: 返回 (paths, scores)。 | ||||
paths: 是解码后的路径, 其值参照unpad参数. | paths: 是解码后的路径, 其值参照unpad参数. | ||||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | ||||
""" | """ | ||||
batch_size, seq_len, n_tags = feats.size() | |||||
feats = feats.transpose(0, 1).data # L, B, H | |||||
batch_size, seq_len, n_tags = logits.size() | |||||
logits = logits.transpose(0, 1).data # L, B, H | |||||
mask = mask.transpose(0, 1).data.byte() # L, B | mask = mask.transpose(0, 1).data.byte() # L, B | ||||
# dp | # dp | ||||
vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||||
vscore = feats[0] | |||||
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||||
vscore = logits[0] | |||||
transitions = self._constrain.data.clone() | transitions = self._constrain.data.clone() | ||||
transitions[:n_tags, :n_tags] += self.trans_m.data | transitions[:n_tags, :n_tags] += self.trans_m.data | ||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
@@ -281,7 +280,7 @@ class ConditionalRandomField(nn.Module): | |||||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | ||||
for i in range(1, seq_len): | for i in range(1, seq_len): | ||||
prev_score = vscore.view(batch_size, n_tags, 1) | prev_score = vscore.view(batch_size, n_tags, 1) | ||||
cur_score = feats[i].view(batch_size, 1, n_tags) | |||||
cur_score = logits[i].view(batch_size, 1, n_tags) | |||||
score = prev_score + trans_score + cur_score | score = prev_score + trans_score + cur_score | ||||
best_score, best_dst = score.max(1) | best_score, best_dst = score.max(1) | ||||
vpath[i] = best_dst | vpath[i] = best_dst | ||||
@@ -292,13 +291,13 @@ class ConditionalRandomField(nn.Module): | |||||
vscore += transitions[:n_tags, n_tags+1].view(1, -1) | vscore += transitions[:n_tags, n_tags+1].view(1, -1) | ||||
# backtrace | # backtrace | ||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) | |||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) | |||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | |||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | |||||
lens = (mask.long().sum(0) - 1) | lens = (mask.long().sum(0) - 1) | ||||
# idxes [L, B], batched idx from seq_len-1 to 0 | # idxes [L, B], batched idx from seq_len-1 to 0 | ||||
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len | idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len | ||||
ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) | |||||
ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) | |||||
ans_score, last_tags = vscore.max(1) | ans_score, last_tags = vscore.max(1) | ||||
ans[idxes[0], batch_idx] = last_tags | ans[idxes[0], batch_idx] = last_tags | ||||
for i in range(seq_len - 1): | for i in range(seq_len - 1): | ||||
@@ -311,6 +310,5 @@ class ConditionalRandomField(nn.Module): | |||||
paths.append(ans[idx, :seq_len+1].tolist()) | paths.append(ans[idx, :seq_len+1].tolist()) | ||||
else: | else: | ||||
paths = ans | paths = ans | ||||
if get_score: | |||||
return paths, ans_score.tolist() | |||||
return paths | |||||
return paths, ans_score | |||||
@@ -11,13 +11,12 @@ def log_sum_exp(x, dim=-1): | |||||
def viterbi_decode(logits, transitions, mask=None, unpad=False): | def viterbi_decode(logits, transitions, mask=None, unpad=False): | ||||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | """给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | ||||
:param logits: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||||
:param transitions: FloatTensor, n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。 | |||||
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||||
:param unpad: bool, 是否将结果删去padding, | |||||
False, 返回的是batch_size x max_len的tensor, | |||||
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是 | |||||
这个sample的有效长度。 | |||||
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | |||||
:param torch.FloatTensor transitions: n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。 | |||||
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||||
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 | |||||
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 | |||||
个sample的有效长度。 | |||||
:return: 返回 (paths, scores)。 | :return: 返回 (paths, scores)。 | ||||
paths: 是解码后的路径, 其值参照unpad参数. | paths: 是解码后的路径, 其值参照unpad参数. | ||||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | ||||
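The standalone helper mirrors the CRF method above but takes an explicit transition matrix; a sketch, assuming this module's viterbi_decode is in scope::

    import torch

    logits = torch.randn(4, 20, 3)       # batch_size x max_len x num_tags
    transitions = torch.randn(3, 3)      # transitions[i, j] scores moving from tag i to tag j
    mask = torch.ones(4, 20).byte()
    paths, scores = viterbi_decode(logits, transitions, mask=mask, unpad=True)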
@@ -1,4 +1,3 @@ | |||||
from .conv import Conv | |||||
from .conv_maxpool import ConvMaxpool | from .conv_maxpool import ConvMaxpool | ||||
from .embedding import Embedding | from .embedding import Embedding | ||||
from .linear import Linear | from .linear import Linear | ||||
@@ -8,6 +7,5 @@ from .bert import BertModel | |||||
__all__ = ["LSTM", | __all__ = ["LSTM", | ||||
"Embedding", | "Embedding", | ||||
"Linear", | "Linear", | ||||
"Conv", | |||||
"ConvMaxpool", | "ConvMaxpool", | ||||
"BertModel"] | "BertModel"] |
@@ -1,58 +0,0 @@ | |||||
# python: 3.6 | |||||
# encoding: utf-8 | |||||
import torch | |||||
import torch.nn as nn | |||||
from fastNLP.modules.utils import initial_parameter | |||||
# import torch.nn.functional as F | |||||
class Conv(nn.Module): | |||||
"""Basic 1-d convolution module, initialized with xavier_uniform. | |||||
:param int in_channels: | |||||
:param int out_channels: | |||||
:param tuple kernel_size: | |||||
:param int stride: | |||||
:param int padding: | |||||
:param int dilation: | |||||
:param int groups: | |||||
:param bool bias: | |||||
:param str activation: | |||||
:param str initial_method: | |||||
""" | |||||
def __init__(self, in_channels, out_channels, kernel_size, | |||||
stride=1, padding=0, dilation=1, | |||||
groups=1, bias=True, activation='relu', initial_method=None): | |||||
super(Conv, self).__init__() | |||||
self.conv = nn.Conv1d( | |||||
in_channels=in_channels, | |||||
out_channels=out_channels, | |||||
kernel_size=kernel_size, | |||||
stride=stride, | |||||
padding=padding, | |||||
dilation=dilation, | |||||
groups=groups, | |||||
bias=bias) | |||||
# xavier_uniform_(self.conv.weight) | |||||
activations = { | |||||
'relu': nn.ReLU(), | |||||
'tanh': nn.Tanh()} | |||||
if activation in activations: | |||||
self.activation = activations[activation] | |||||
else: | |||||
raise Exception( | |||||
'Should choose activation function from: ' + | |||||
', '.join([x for x in activations])) | |||||
initial_parameter(self, initial_method) | |||||
def forward(self, x): | |||||
x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] | |||||
x = self.conv(x) # [N,C_in,L] -> [N,C_out,L] | |||||
x = self.activation(x) | |||||
x = torch.transpose(x, 1, 2) # [N,C,L] -> [N,L,C] | |||||
return x |
@@ -9,18 +9,21 @@ from fastNLP.modules.utils import initial_parameter | |||||
class ConvMaxpool(nn.Module): | class ConvMaxpool(nn.Module): | ||||
"""Convolution and max-pooling module with multiple kernel sizes. | |||||
"""集合了Convolution和Max-Pooling于一体的层。 | |||||
给定一个batch_size x max_len x input_size的输入,返回batch_size x sum(output_channels) 大小的matrix。在内部,先使用
CNN给输入做卷积,然后经过activation激活层,再在长度(max_len)这一维上进行max_pooling,最后得到每个sample的一个vector
表示。
:param int in_channels: | |||||
:param int out_channels: | |||||
:param tuple kernel_sizes: | |||||
:param int stride: | |||||
:param int padding: | |||||
:param int dilation: | |||||
:param int groups: | |||||
:param bool bias: | |||||
:param str activation: | |||||
:param str initial_method: | |||||
:param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度 | |||||
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 | |||||
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 | |||||
:param int stride: 见pytorch Conv1D文档。所有kernel共享一个stride。 | |||||
:param int padding: 见pytorch Conv1D文档。所有kernel共享一个padding。 | |||||
:param int dilation: 见pytorch Conv1D文档。所有kernel共享一个dilation。 | |||||
:param int groups: 见pytorch Conv1D文档。所有kernel共享一个groups。 | |||||
:param bool bias: 见pytorch Conv1D文档。所有kernel共享一个bias。 | |||||
:param str activation: Convolution后的结果将通过该activation后再经过max-pooling。支持relu, sigmoid, tanh | |||||
:param str initial_method: 初始化方法, 见initial_parameter。
""" | """ | ||||
def __init__(self, in_channels, out_channels, kernel_sizes, | def __init__(self, in_channels, out_channels, kernel_sizes, | ||||
stride=1, padding=0, dilation=1, | stride=1, padding=0, dilation=1, | ||||
@@ -29,9 +32,14 @@ class ConvMaxpool(nn.Module): | |||||
# convolution | # convolution | ||||
if isinstance(kernel_sizes, (list, tuple, int)): | if isinstance(kernel_sizes, (list, tuple, int)): | ||||
if isinstance(kernel_sizes, int): | |||||
if isinstance(kernel_sizes, int) and isinstance(out_channels, int): | |||||
out_channels = [out_channels] | out_channels = [out_channels] | ||||
kernel_sizes = [kernel_sizes] | kernel_sizes = [kernel_sizes] | ||||
elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)): | |||||
assert len(out_channels)==len(kernel_sizes), "The number of out_channels should be equal to the number" \ | |||||
" of kernel_sizes." | |||||
else: | |||||
raise ValueError("The type of out_channels and kernel_sizes should be the same.") | |||||
self.convs = nn.ModuleList([nn.Conv1d( | self.convs = nn.ModuleList([nn.Conv1d( | ||||
in_channels=in_channels, | in_channels=in_channels, | ||||
@@ -51,18 +59,31 @@ class ConvMaxpool(nn.Module): | |||||
# activation function | # activation function | ||||
if activation == 'relu': | if activation == 'relu': | ||||
self.activation = F.relu | self.activation = F.relu | ||||
elif activation == 'sigmoid': | |||||
self.activation = F.sigmoid | |||||
elif activation == 'tanh': | |||||
self.activation = F.tanh | |||||
else: | else: | ||||
raise Exception( | raise Exception( | ||||
"Undefined activation function: choose from: relu") | |||||
"Undefined activation function: choose from: relu, tanh, sigmoid") | |||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def forward(self, x): | |||||
def forward(self, x, mask=None): | |||||
""" | |||||
:param torch.FloatTensor x: batch_size x max_len x input_size, 一般是经过embedding后的值 | |||||
:param mask: batch_size x max_len, pad的地方为0。不影响卷积运算,max-pool一定不会pool到pad为0的位置 | |||||
:return: torch.FloatTensor, batch_size x sum(out_channels) 的结果
""" | |||||
# [N,L,C] -> [N,C,L] | # [N,L,C] -> [N,C,L] | ||||
x = torch.transpose(x, 1, 2) | x = torch.transpose(x, 1, 2) | ||||
# convolution | # convolution | ||||
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L]] | |||||
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] | |||||
if mask is not None:
mask = mask.unsqueeze(1)  # B x 1 x L
# fill the padded positions (mask == 0) with -inf so that max-pooling never picks them
xs = [x.masked_fill_(mask.eq(0), float('-inf')) for x in xs]
# max-pooling | # max-pooling | ||||
xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) | xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) | ||||
for i in xs] # [[N, C]] | |||||
return torch.cat(xs, dim=-1) # [N,C] | |||||
for i in xs] # [[N, C], ...] | |||||
return torch.cat(xs, dim=-1) # [N, C] |
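A self-contained sketch of the ConvMaxpool contract described above (numbers are arbitrary)::

    import torch
    from fastNLP.modules.encoder import ConvMaxpool

    conv_pool = ConvMaxpool(in_channels=50, out_channels=(30, 40, 50), kernel_sizes=(3, 4, 5))
    x = torch.randn(4, 20, 50)    # batch_size x max_len x input_size
    out = conv_pool(x)            # shape: 4 x sum(out_channels) == 4 x 120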
@@ -1,424 +0,0 @@ | |||||
__author__ = 'max' | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from fastNLP.modules.utils import initial_parameter | |||||
def MaskedRecurrent(reverse=False): | |||||
def forward(input, hidden, cell, mask, train=True, dropout=0): | |||||
""" | |||||
:param input: | |||||
:param hidden: | |||||
:param cell: | |||||
:param mask: | |||||
:param dropout: step之间的dropout,对mask了的也会drop,应该是没问题的,反正没有gradient | |||||
:param train: 控制dropout的行为,在StackedRNN的forward中调用 | |||||
:return: | |||||
""" | |||||
output = [] | |||||
steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0)) | |||||
for i in steps: | |||||
if mask is None or mask[i].data.min() > 0.5: # 没有mask,都是1 | |||||
hidden = cell(input[i], hidden) | |||||
elif mask[i].data.max() > 0.5: # 有mask,但不全为0 | |||||
hidden_next = cell(input[i], hidden) # 一次喂入一个batch! | |||||
# hack to handle LSTM | |||||
if isinstance(hidden, tuple): # LSTM outputs a tuple of (hidden, cell), this is a common hack 😁 | |||||
mask = mask.float() | |||||
hx, cx = hidden | |||||
hp1, cp1 = hidden_next | |||||
hidden = ( | |||||
hx + (hp1 - hx) * mask[i].squeeze(), | |||||
cx + (cp1 - cx) * mask[i].squeeze()) # Why? 我知道了!!如果是mask就不用改变 | |||||
else: | |||||
hidden = hidden + (hidden_next - hidden) * mask[i] | |||||
# if dropout != 0 and train: # warning, should i treat masked tensor differently? | |||||
# if isinstance(hidden, tuple): | |||||
# hidden = (F.dropout(hidden[0], p=dropout, training=train), | |||||
# F.dropout(hidden[1], p=dropout, training=train)) | |||||
# else: | |||||
# hidden = F.dropout(hidden, p=dropout, training=train) | |||||
# hack to handle LSTM | |||||
output.append(hidden[0] if isinstance(hidden, tuple) else hidden) | |||||
if reverse: | |||||
output.reverse() | |||||
output = torch.cat(output, 0).view(input.size(0), *output[0].size()) | |||||
return hidden, output | |||||
return forward | |||||
def StackedRNN(inners, num_layers, lstm=False, train=True, step_dropout=0, layer_dropout=0): | |||||
num_directions = len(inners) # rec_factory! | |||||
total_layers = num_layers * num_directions | |||||
def forward(input, hidden, cells, mask): | |||||
assert (len(cells) == total_layers) | |||||
next_hidden = [] | |||||
if lstm: | |||||
hidden = list(zip(*hidden)) | |||||
for i in range(num_layers): | |||||
all_output = [] | |||||
for j, inner in enumerate(inners): | |||||
l = i * num_directions + j | |||||
hy, output = inner(input, hidden[l], cells[l], mask, step_dropout, train) | |||||
next_hidden.append(hy) | |||||
all_output.append(output) | |||||
input = torch.cat(all_output, input.dim() - 1) # 下一层的输入 | |||||
if layer_dropout != 0 and i < num_layers - 1: | |||||
input = F.dropout(input, p=layer_dropout, training=train, inplace=False) | |||||
if lstm: | |||||
next_h, next_c = zip(*next_hidden) | |||||
next_hidden = ( | |||||
torch.cat(next_h, 0).view(total_layers, *next_h[0].size()), | |||||
torch.cat(next_c, 0).view(total_layers, *next_c[0].size()) | |||||
) | |||||
else: | |||||
next_hidden = torch.cat(next_hidden, 0).view(total_layers, *next_hidden[0].size()) | |||||
return next_hidden, input | |||||
return forward | |||||
def AutogradMaskedRNN(num_layers=1, batch_first=False, train=True, layer_dropout=0, step_dropout=0, | |||||
bidirectional=False, lstm=False): | |||||
rec_factory = MaskedRecurrent | |||||
if bidirectional: | |||||
layer = (rec_factory(), rec_factory(reverse=True)) | |||||
else: | |||||
layer = (rec_factory(),) # rec_factory 就是每层的结构啦!!在MaskedRecurrent中进行每层的计算!然后用StackedRNN接起来 | |||||
func = StackedRNN(layer, | |||||
num_layers, | |||||
lstm=lstm, | |||||
layer_dropout=layer_dropout, step_dropout=step_dropout, | |||||
train=train) | |||||
def forward(input, cells, hidden, mask): | |||||
if batch_first: | |||||
input = input.transpose(0, 1) | |||||
if mask is not None: | |||||
mask = mask.transpose(0, 1) | |||||
nexth, output = func(input, hidden, cells, mask) | |||||
if batch_first: | |||||
output = output.transpose(0, 1) | |||||
return output, nexth | |||||
return forward | |||||
def MaskedStep(): | |||||
def forward(input, hidden, cell, mask): | |||||
if mask is None or mask.data.min() > 0.5: | |||||
hidden = cell(input, hidden) | |||||
elif mask.data.max() > 0.5: | |||||
hidden_next = cell(input, hidden) | |||||
# hack to handle LSTM | |||||
if isinstance(hidden, tuple): | |||||
hx, cx = hidden | |||||
hp1, cp1 = hidden_next | |||||
hidden = (hx + (hp1 - hx) * mask, cx + (cp1 - cx) * mask) | |||||
else: | |||||
hidden = hidden + (hidden_next - hidden) * mask | |||||
# hack to handle LSTM | |||||
output = hidden[0] if isinstance(hidden, tuple) else hidden | |||||
return hidden, output | |||||
return forward | |||||
def StackedStep(layer, num_layers, lstm=False, dropout=0, train=True): | |||||
def forward(input, hidden, cells, mask): | |||||
assert (len(cells) == num_layers) | |||||
next_hidden = [] | |||||
if lstm: | |||||
hidden = list(zip(*hidden)) | |||||
for l in range(num_layers): | |||||
hy, output = layer(input, hidden[l], cells[l], mask) | |||||
next_hidden.append(hy) | |||||
input = output | |||||
if dropout != 0 and l < num_layers - 1: | |||||
input = F.dropout(input, p=dropout, training=train, inplace=False) | |||||
if lstm: | |||||
next_h, next_c = zip(*next_hidden) | |||||
next_hidden = ( | |||||
torch.cat(next_h, 0).view(num_layers, *next_h[0].size()), | |||||
torch.cat(next_c, 0).view(num_layers, *next_c[0].size()) | |||||
) | |||||
else: | |||||
next_hidden = torch.cat(next_hidden, 0).view(num_layers, *next_hidden[0].size()) | |||||
return next_hidden, input | |||||
return forward | |||||
def AutogradMaskedStep(num_layers=1, dropout=0, train=True, lstm=False): | |||||
layer = MaskedStep() | |||||
func = StackedStep(layer, | |||||
num_layers, | |||||
lstm=lstm, | |||||
dropout=dropout, | |||||
train=train) | |||||
def forward(input, cells, hidden, mask): | |||||
nexth, output = func(input, hidden, cells, mask) | |||||
return output, nexth | |||||
return forward | |||||
class MaskedRNNBase(nn.Module): | |||||
def __init__(self, Cell, input_size, hidden_size, | |||||
num_layers=1, bias=True, batch_first=False, | |||||
layer_dropout=0, step_dropout=0, bidirectional=False, initial_method = None , **kwargs): | |||||
""" | |||||
:param Cell: | |||||
:param input_size: | |||||
:param hidden_size: | |||||
:param num_layers: | |||||
:param bias: | |||||
:param batch_first: | |||||
:param layer_dropout: | |||||
:param step_dropout: | |||||
:param bidirectional: | |||||
:param kwargs: | |||||
""" | |||||
super(MaskedRNNBase, self).__init__() | |||||
self.Cell = Cell | |||||
self.input_size = input_size | |||||
self.hidden_size = hidden_size | |||||
self.num_layers = num_layers | |||||
self.bias = bias | |||||
self.batch_first = batch_first | |||||
self.layer_dropout = layer_dropout | |||||
self.step_dropout = step_dropout | |||||
self.bidirectional = bidirectional | |||||
num_directions = 2 if bidirectional else 1 | |||||
self.all_cells = [] | |||||
for layer in range(num_layers): # 初始化所有cell | |||||
for direction in range(num_directions): | |||||
layer_input_size = input_size if layer == 0 else hidden_size * num_directions | |||||
cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs) | |||||
self.all_cells.append(cell) | |||||
self.add_module('cell%d' % (layer * num_directions + direction), cell) # Max的代码写得真好看 | |||||
initial_parameter(self, initial_method) | |||||
def reset_parameters(self): | |||||
for cell in self.all_cells: | |||||
cell.reset_parameters() | |||||
def forward(self, input, mask=None, hx=None): | |||||
batch_size = input.size(0) if self.batch_first else input.size(1) | |||||
lstm = self.Cell is nn.LSTMCell | |||||
if hx is None: | |||||
num_directions = 2 if self.bidirectional else 1 | |||||
hx = torch.autograd.Variable( | |||||
input.data.new(self.num_layers * num_directions, batch_size, self.hidden_size).zero_()) | |||||
if lstm: | |||||
hx = (hx, hx) | |||||
func = AutogradMaskedRNN(num_layers=self.num_layers, | |||||
batch_first=self.batch_first, | |||||
step_dropout=self.step_dropout, | |||||
layer_dropout=self.layer_dropout, | |||||
train=self.training, | |||||
bidirectional=self.bidirectional, | |||||
                                 lstm=lstm)  # all_cells are passed down to the lower-level step functions
        output, hidden = func(input, self.all_cells, hx,
                              None if mask is None else mask.view(mask.size() + (1,)))  # append a trailing singleton dim so the mask broadcasts over hidden_size
return output, hidden | |||||
def step(self, input, hx=None, mask=None): | |||||
"""Execute one step forward (only for one-directional RNN). | |||||
:param Tensor input: input tensor of this step. (batch, input_size) | |||||
:param Tensor hx: the hidden state of last step. (num_layers, batch, hidden_size) | |||||
:param Tensor mask: the mask tensor of this step. (batch, ) | |||||
:returns: | |||||
**output** (batch, hidden_size), tensor containing the output of this step from the last layer of RNN. | |||||
**hn** (num_layers, batch, hidden_size), tensor containing the hidden state of this step | |||||
""" | |||||
        assert not self.bidirectional, "step() cannot be applied to a bidirectional RNN."
batch_size = input.size(0) | |||||
lstm = self.Cell is nn.LSTMCell | |||||
if hx is None: | |||||
hx = torch.autograd.Variable(input.data.new(self.num_layers, batch_size, self.hidden_size).zero_()) | |||||
if lstm: | |||||
hx = (hx, hx) | |||||
func = AutogradMaskedStep(num_layers=self.num_layers, | |||||
dropout=self.step_dropout, | |||||
train=self.training, | |||||
lstm=lstm) | |||||
output, hidden = func(input, self.all_cells, hx, mask) | |||||
return output, hidden | |||||
class MaskedRNN(MaskedRNNBase): | |||||
r"""Applies a multi-layer Elman RNN with costomized non-linearity to an | |||||
input sequence. | |||||
For each element in the input sequence, each layer computes the following | |||||
function. :math:`h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})` | |||||
where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is | |||||
the hidden state of the previous layer at time `t` or :math:`input_t` | |||||
for the first layer. If nonlinearity='relu', then `ReLU` is used instead | |||||
of `tanh`. | |||||
:param int input_size: The number of expected features in the input x | |||||
:param int hidden_size: The number of features in the hidden state h | |||||
:param int num_layers: Number of recurrent layers. | |||||
:param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh' | |||||
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True | |||||
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) | |||||
    :param float layer_dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
    :param float step_dropout: If non-zero, introduces dropout applied at each time step
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False | |||||
Inputs: input, mask, h_0 | |||||
- **input** (seq_len, batch, input_size): tensor containing the features | |||||
of the input sequence. | |||||
    - **mask** (seq_len, batch): 0-1 tensor containing the mask of the input sequence.
- **h_0** (num_layers * num_directions, batch, hidden_size): tensor | |||||
containing the initial hidden state for each element in the batch. | |||||
Outputs: output, h_n | |||||
- **output** (seq_len, batch, hidden_size * num_directions): tensor | |||||
containing the output features (h_k) from the last layer of the RNN, | |||||
for each k. If a :class:`torch.nn.utils.rnn.PackedSequence` has | |||||
been given as the input, the output will also be a packed sequence. | |||||
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor | |||||
containing the hidden state for k=seq_len. | |||||
""" | |||||
def __init__(self, *args, **kwargs): | |||||
super(MaskedRNN, self).__init__(nn.RNNCell, *args, **kwargs) | |||||
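# Illustration only (not part of the original module): a minimal usage sketch of
# MaskedRNN, mirroring the shapes exercised by the project's MaskedRNN unit test.
# Mask entries are 1.0 for real time steps and 0.0 for padding; at padded steps
# the hidden state is simply carried over.
def _masked_rnn_sketch():
    rnn = MaskedRNN(input_size=1, hidden_size=1, batch_first=True)
    x = torch.tensor([[[1.0], [2.0]]])        # (batch=1, seq_len=2, input_size=1)
    output, h_n = rnn(x)                      # no mask: behaves like a plain RNN
    mask = torch.tensor([[[1.0], [0.0]]])     # same layout as the unit test's mask
    output, h_n = rnn(x, mask=mask)           # the padded step keeps the previous hidden state
    # single-step decoding is only available for unidirectional RNNs
    step_out, h_1 = rnn.step(torch.tensor([[1.0]]))  # input: (batch, input_size)
    return output, h_n, step_out, h_1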
class MaskedLSTM(MaskedRNNBase): | |||||
r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input | |||||
sequence. | |||||
For each element in the input sequence, each layer computes the following | |||||
function. | |||||
.. math:: | |||||
\begin{array}{ll} | |||||
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\ | |||||
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\ | |||||
        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\ | |||||
c_t = f_t * c_{(t-1)} + i_t * g_t \\ | |||||
h_t = o_t * \tanh(c_t) | |||||
\end{array} | |||||
where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell | |||||
state at time `t`, :math:`x_t` is the hidden state of the previous layer at | |||||
time `t` or :math:`input_t` for the first layer, and :math:`i_t`, | |||||
:math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, | |||||
and out gates, respectively. | |||||
:param int input_size: The number of expected features in the input x | |||||
:param int hidden_size: The number of features in the hidden state h | |||||
:param int num_layers: Number of recurrent layers. | |||||
:param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True | |||||
:param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) | |||||
    :param float layer_dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
    :param float step_dropout: If non-zero, introduces dropout applied at each time step
:param bool bidirectional: If True, becomes a bidirectional RNN. Default: False | |||||
Inputs: input, mask, (h_0, c_0) | |||||
- **input** (seq_len, batch, input_size): tensor containing the features | |||||
of the input sequence. | |||||
    - **mask** (seq_len, batch): 0-1 tensor containing the mask of the input sequence.
- **h_0** (num_layers \* num_directions, batch, hidden_size): tensor | |||||
containing the initial hidden state for each element in the batch. | |||||
- **c_0** (num_layers \* num_directions, batch, hidden_size): tensor | |||||
containing the initial cell state for each element in the batch. | |||||
Outputs: output, (h_n, c_n) | |||||
- **output** (seq_len, batch, hidden_size * num_directions): tensor | |||||
containing the output features `(h_t)` from the last layer of the RNN, | |||||
for each t. If a :class:`torch.nn.utils.rnn.PackedSequence` has been | |||||
given as the input, the output will also be a packed sequence. | |||||
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor | |||||
containing the hidden state for t=seq_len | |||||
- **c_n** (num_layers * num_directions, batch, hidden_size): tensor | |||||
containing the cell state for t=seq_len | |||||
""" | |||||
def __init__(self, *args, **kwargs): | |||||
super(MaskedLSTM, self).__init__(nn.LSTMCell, *args, **kwargs) | |||||
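# Illustration only (not part of the original module): MaskedLSTM follows the same
# calling convention as MaskedRNN above, except that the hidden state is an
# (h, c) tuple, as documented in the class docstring.
def _masked_lstm_sketch():
    lstm = MaskedLSTM(input_size=1, hidden_size=1, batch_first=True)
    x = torch.tensor([[[1.0], [2.0]]])        # (batch=1, seq_len=2, input_size=1)
    mask = torch.tensor([[[1.0], [0.0]]])     # 1.0 = real step, 0.0 = padding
    output, (h_n, c_n) = lstm(x, mask=mask)
    return output, h_n, c_n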
class MaskedGRU(MaskedRNNBase): | |||||
r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. | |||||
For each element in the input sequence, each layer computes the following | |||||
function: | |||||
.. math:: | |||||
\begin{array}{ll} | |||||
r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ | |||||
z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ | |||||
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ | |||||
h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\ | |||||
\end{array} | |||||
where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden | |||||
state of the previous layer at time `t` or :math:`input_t` for the first | |||||
    layer, and :math:`r_t`, :math:`z_t`, :math:`n_t` are the reset, update,
and new gates, respectively. | |||||
:param int input_size: The number of expected features in the input x | |||||
:param int hidden_size: The number of features in the hidden state h | |||||
:param int num_layers: Number of recurrent layers. | |||||
    :param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True
    :param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature)
    :param float layer_dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer
    :param float step_dropout: If non-zero, introduces dropout applied at each time step
    :param bool bidirectional: If True, becomes a bidirectional RNN. Default: False
Inputs: input, mask, h_0 | |||||
- **input** (seq_len, batch, input_size): tensor containing the features | |||||
of the input sequence. | |||||
    - **mask** (seq_len, batch): 0-1 tensor containing the mask of the input sequence.
- **h_0** (num_layers * num_directions, batch, hidden_size): tensor | |||||
containing the initial hidden state for each element in the batch. | |||||
Outputs: output, h_n | |||||
- **output** (seq_len, batch, hidden_size * num_directions): tensor | |||||
containing the output features (h_k) from the last layer of the RNN, | |||||
for each k. If a :class:`torch.nn.utils.rnn.PackedSequence` has | |||||
been given as the input, the output will also be a packed sequence. | |||||
- **h_n** (num_layers * num_directions, batch, hidden_size): tensor | |||||
containing the hidden state for k=seq_len. | |||||
""" | |||||
def __init__(self, *args, **kwargs): | |||||
super(MaskedGRU, self).__init__(nn.GRUCell, *args, **kwargs) |
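# Illustration only (not part of the original module): MaskedGRU is called exactly
# like MaskedRNN; only the underlying cell (nn.GRUCell) differs.
def _masked_gru_sketch():
    gru = MaskedGRU(input_size=1, hidden_size=1, batch_first=True)
    x = torch.tensor([[[1.0], [2.0]]])        # (batch=1, seq_len=2, input_size=1)
    mask = torch.tensor([[[1.0], [0.0]]])     # 1.0 = real step, 0.0 = padding
    output, h_n = gru(x, mask=mask)
    return output, h_n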
@@ -1,27 +0,0 @@ | |||||
import torch | |||||
import unittest | |||||
from fastNLP.modules.encoder.masked_rnn import MaskedRNN | |||||
class TestMaskedRnn(unittest.TestCase): | |||||
def test_case_1(self): | |||||
masked_rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) | |||||
x = torch.tensor([[[1.0], [2.0]]]) | |||||
print(x.size()) | |||||
y = masked_rnn(x) | |||||
mask = torch.tensor([[[1], [1]]]) | |||||
y = masked_rnn(x, mask=mask) | |||||
mask = torch.tensor([[[1], [0]]]) | |||||
y = masked_rnn(x, mask=mask) | |||||
def test_case_2(self): | |||||
masked_rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=False, batch_first=True) | |||||
x = torch.tensor([[[1.0], [2.0]]]) | |||||
print(x.size()) | |||||
y = masked_rnn(x) | |||||
mask = torch.tensor([[[1], [1]]]) | |||||
y = masked_rnn(x, mask=mask) | |||||
xx = torch.tensor([[[1.0]]]) | |||||
y = masked_rnn.step(xx) | |||||
y = masked_rnn.step(xx, mask=mask) |
@@ -70,7 +70,7 @@ class TestTutorial(unittest.TestCase): | |||||
break | break | ||||
from fastNLP.models import CNNText | from fastNLP.models import CNNText | ||||
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||||
model = CNNText(vocab_size=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||||
from fastNLP import Trainer | from fastNLP import Trainer | ||||
from copy import deepcopy | from copy import deepcopy | ||||
@@ -145,7 +145,7 @@ class TestTutorial(unittest.TestCase): | |||||
is_input=True) | is_input=True) | ||||
from fastNLP.models import CNNText | from fastNLP.models import CNNText | ||||
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||||
model = CNNText(vocab_size=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||||
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric | from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric | ||||
trainer = Trainer(model=model, | trainer = Trainer(model=model, | ||||
@@ -405,7 +405,7 @@ class TestTutorial(unittest.TestCase): | |||||
# 另一个例子:加载CNN文本分类模型 | # 另一个例子:加载CNN文本分类模型 | ||||
from fastNLP.models import CNNText | from fastNLP.models import CNNText | ||||
cnn_text_model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||||
cnn_text_model = CNNText(vocab_size=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||||
cnn_text_model | cnn_text_model | ||||
from fastNLP import CrossEntropyLoss | from fastNLP import CrossEntropyLoss | ||||