diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index 0f6da45f..35309bd3 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -1,3 +1,5 @@ from .core import * from . import models from . import modules + +__version__ = '0.4.0' \ No newline at end of file diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index a416f655..914e4d28 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -1,4 +1,5 @@ """ +Callback的说明文档 .. _Callback: @@ -28,7 +29,6 @@ class Callback(object): def trainer(self): """ 该属性可以通过self.trainer获取到,一般情况下不需要使用这个属性。 - :return: """ return self._trainer @@ -323,11 +323,16 @@ class GradientClipCallback(Callback): class CallbackException(BaseException): def __init__(self, msg): + """ + 当需要通过callback跳出训练的时候可以通过抛出CallbackException并在on_exception中捕获这个值。 + :param str msg: Exception的信息。 + """ super(CallbackException, self).__init__(msg) class EarlyStopError(CallbackException): def __init__(self, msg): + """用于EarlyStop时从Trainer训练循环中跳出。""" super(EarlyStopError, self).__init__(msg) @@ -360,7 +365,13 @@ class EarlyStopCallback(Callback): class LRScheduler(Callback): def __init__(self, lr_scheduler): - """对PyTorch LR Scheduler的包装 + """对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用 + + Example:: + + from fastNLP import LRScheduler + + :param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler """ diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py index da2c45fe..584aa5ff 100644 --- a/fastNLP/core/optimizer.py +++ b/fastNLP/core/optimizer.py @@ -13,6 +13,9 @@ class Optimizer(object): self.model_params = model_params self.settings = kwargs + def construct_from_pytorch(self, model_params): + raise NotImplementedError + def _get_require_grads_param(self, params): """ 将params中不需要gradient的删除 diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 9737f53a..6e3f98b5 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -14,20 +14,56 @@ from fastNLP.core.utils import _get_device class Tester(object): - """An collection of model inference and evaluation of performance, used over validation/dev set and test set. + """ + Tester是在提供数据,模型以及metric的情况下进行性能测试的类 + + Example:: + + import numpy as np + import torch + from torch import nn + from fastNLP import Tester + from fastNLP import DataSet + from fastNLP import AccuracyMetric + + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(1, 1) + def forward(self, a): + return {'pred': self.fc(a.unsqueeze(1)).squeeze(1)} + + model = Model() + + dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2}) + + dataset.set_input('a') + dataset.set_target('b') + + tester = Tester(dataset, model, metrics=AccuracyMetric()) + eval_results = tester.test() + + 这里Metric的映射规律是和 Trainer_ 中一致的,请参考 Trainer_ 使用metrics。 + - :param DataSet data: a validation/development set - :param torch.nn.modules.module model: a PyTorch model - :param MetricBase metrics: a metric object or a list of metrics (List[MetricBase]) - :param int batch_size: batch size for validation - :param str,torch.device,None device: 将模型load到哪个设备。默认为None,即Trainer不对模型的计算位置进行管理。支持 - 以下的输入str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, - 可见的第二个GPU中; torch.device,将模型装载到torch.device上。 - :param int verbose: the number of steps after which an information is printed. 
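The Example:: block added to LRScheduler above stops at the import. A minimal sketch of the intended usage, adapting the toy Model/DataSet from the Tester example (the forward here returns a 'loss' key so the default LossInForward can be used) and assuming the Trainer's callbacks and n_epochs arguments:

    import numpy as np
    import torch
    from torch import nn
    from torch.optim import SGD
    from torch.optim.lr_scheduler import StepLR
    from fastNLP import Trainer, DataSet, LRScheduler

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, a, b):
            pred = self.fc(a.unsqueeze(1)).squeeze(1)
            return {'loss': torch.mean((pred - b) ** 2)}

    dataset = DataSet({'a': np.arange(10, dtype=float), 'b': np.arange(10, dtype=float) * 2})
    dataset.set_input('a', 'b')

    model = Model()
    optimizer = SGD(model.parameters(), lr=0.1)
    # wrap the PyTorch scheduler so the Trainer can step it during training
    scheduler = LRScheduler(StepLR(optimizer, step_size=1, gamma=0.9))
    trainer = Trainer(train_data=dataset, model=model, optimizer=optimizer,
                      n_epochs=3, callbacks=[scheduler])
    trainer.train()
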
""" def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): + """传入模型,数据以及metric进行验证。 + + :param DataSet data: 需要测试的数据集 + :param torch.nn.module model: 使用的模型 + :param MetricBase metrics: 一个Metric或者一个列表的metric对象 + :param int batch_size: evaluation时使用的batch_size有多大。 + :param str,torch.device,None device: 将模型load到哪个设备。默认为None,即Trainer不对模型的计算位置进行管理。支持 + 以下的输入str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, + 可见的第二个GPU中; torch.device,将模型装载到torch.device上。 + :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 + + """ + super(Tester, self).__init__() if not isinstance(data, DataSet): @@ -59,10 +95,10 @@ class Tester(object): self._predict_func = self._model.forward def test(self): - """Start test or validation. - - :return eval_results: a dictionary whose keys are the class name of metrics to use, values are the evaluation results of these metrics. + """开始进行验证,并返回验证结果。 + :return dict(dict) eval_results: dict为二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。 + 一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。 """ # turn on the testing mode; clean up the history network = self._model diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 44d88d3c..48733652 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -213,7 +213,7 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在 from torch.optim import SGD from fastNLP import Trainer from fastNLP import DataSet - from fastNLP.core.metrics import AccuracyMetric + from fastNLP import AccuracyMetric import torch class Model(nn.Module): @@ -322,7 +322,7 @@ from fastNLP.core.utils import _check_loss_evaluate from fastNLP.core.utils import _move_dict_value_to_device from fastNLP.core.utils import _get_func_signature from fastNLP.core.utils import _get_device - +from fastNLP.core.optimizer import Optimizer class Trainer(object): def __init__(self, train_data, model, optimizer, loss=None, @@ -336,8 +336,7 @@ class Trainer(object): """ :param DataSet train_data: 训练集 :param nn.modules model: 待训练的模型 - :param Optimizer,None optimizer: 优化器,pytorch的torch.optim.Optimizer类型。如果为None,则Trainer不会更新模型, - 请确保已在callback中进行了更新。 + :param torch.optim.Optimizer,None optimizer: 优化器。如果为None,则Trainer不会更新模型,请确保已在callback中进行了更新。 :param int batch_size: 训练和验证的时候的batch大小。 :param LossBase loss: 使用的Loss对象。 详见 LossBase_ 。当loss为None时,默认使用 LossInForward_ 。 :param Sampler sampler: Batch数据生成的顺序。详见 Sampler_ 。如果为None,默认使用 RandomSampler_ 。 @@ -438,6 +437,8 @@ class Trainer(object): if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer + elif isinstance(optimizer, Optimizer): + self.optimizer = optimizer.construct_from_pytorch(model.parameters()) elif optimizer is None: warnings.warn("The optimizer is set to None, Trainer will update your model. Make sure you update the model" " in the callback.") diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index e1f20b94..31e590da 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -8,7 +8,7 @@ from fastNLP.io.base_loader import BaseLoader import warnings class EmbedLoader(BaseLoader): - """docstring for EmbedLoader""" + """这个类用于从预训练的Embedding中load数据。""" def __init__(self): super(EmbedLoader, self).__init__() @@ -16,18 +16,17 @@ class EmbedLoader(BaseLoader): @staticmethod def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'): """ - load pretraining embedding in {embed_file} based on words in vocab. 
Words in vocab but not in the pretraining - embedding are initialized from a normal distribution which has the mean and std of the found words vectors. - The embedding type is determined automatically, support glove and word2vec(the first line only has two elements). - - :param embed_filepath: str, where to read pretrain embedding - :param vocab: Vocabulary. - :param dtype: the dtype of the embedding matrix - :param normalize: bool, whether to normalize each word vector so that every vector has norm 1. - :param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will - raise - :return: np.ndarray() will have the same [len(vocab), dimension], dimension is determined by the pretrain - embedding + 从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 + word2vec(第一行只有两个元素)还是glove格式的数据。 + + :param str embed_filepath: 预训练的embedding的路径。 + :param Vocabulary vocab: 词表,读取出现在vocab中的词的embedding。没有出现在vocab中的词的embedding将通过找到的词的 + embedding的正态分布采样出来,以使得整个Embedding是同分布的。 + :param dtype: 读出的embedding的类型 + :param bool normalize: 是否将每个vector归一化到norm为1 + :param str error: 'ignore', 'strict'; 如果'ignore',错误将自动跳过; 如果strict, 错误将抛出。这里主要可能出错的地 + 方在于词表有空行或者词表出现了维度不一致。 + :return: numpy.ndarray, shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 """ assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." if not os.path.exists(embed_filepath): @@ -76,19 +75,18 @@ class EmbedLoader(BaseLoader): def load_without_vocab(embed_filepath, dtype=np.float32, padding='', unknown='', normalize=True, error='ignore'): """ - load pretraining embedding in {embed_file}. And construct a Vocabulary based on the pretraining embedding. - The embedding type is determined automatically, support glove and word2vec(the first line only has two elements). - - :param embed_filepath: str, where to read pretrain embedding - :param dtype: the dtype of the embedding matrix - :param padding: the padding tag for vocabulary. - :param unknown: the unknown tag for vocabulary. - :param normalize: bool, whether to normalize each word vector so that every vector has norm 1. - :param error: str, 'ignore', 'strict'; if 'ignore' errors will not raise. if strict, any bad format error will - :raise - :return: np.ndarray() is determined by the pretraining embeddings - Vocabulary: contain all pretraining words and two special tag[, ] - + 从embed_filepath中读取预训练的word vector。根据预训练的词表读取embedding并生成一个对应的Vocabulary。 + + :param str embed_filepath: 预训练的embedding的路径。 + :param dtype: 读出的embedding的类型 + :param str padding: the padding tag for vocabulary. + :param str unknown: the unknown tag for vocabulary. 
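A brief usage sketch for the two loaders; the glove.txt path and the word list are placeholders, and load_without_vocab is documented in full just below:

    from fastNLP import Vocabulary
    from fastNLP.io.embed_loader import EmbedLoader

    vocab = Vocabulary()
    vocab.add_word_lst(["the", "quick", "brown", "fox"])

    # rows follow vocab's index order; words missing from the file are sampled
    # from the distribution of the vectors that were found
    matrix = EmbedLoader.load_with_vocab("glove.txt", vocab, normalize=True)

    # alternatively, build the Vocabulary from whatever the pretrained file contains
    matrix2, vocab2 = EmbedLoader.load_without_vocab("glove.txt", padding='<pad>', unknown='<unk>')
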
+ :param bool normalize: 是否将每个vector归一化到norm为1 + :param str error: 'ignore', 'strict'; 如果'ignore',错误将自动跳过; 如果strict, 错误将抛出。这里主要可能出错的地 + 方在于词表有空行或者词表出现了维度不一致。 + :return: numpy.ndarray, shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 + :return: numpy.ndarray,Vocabulary embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 + 是否使用padding, 以及unknown有没有在词表中找到对应的词。Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 """ vocab = Vocabulary(padding=padding, unknown=unknown) vec_dict = {} diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py index f3898c00..37551e14 100644 --- a/fastNLP/models/cnn_text_classification.py +++ b/fastNLP/models/cnn_text_classification.py @@ -3,29 +3,38 @@ import torch import torch.nn as nn +import numpy as np -# import torch.nn.functional as F import fastNLP.modules.encoder as encoder class CNNText(torch.nn.Module): """ - Text classification model by character CNN, the implementation of paper - 'Yoon Kim. 2014. Convolution Neural Networks for Sentence - Classification.' + 使用CNN进行文本分类的模型 + 'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.' """ - def __init__(self, embed_num, + def __init__(self, vocab_size, embed_dim, num_classes, kernel_nums=(3, 4, 5), kernel_sizes=(3, 4, 5), padding=0, dropout=0.5): + """ + + :param int vocab_size: 词表的大小 + :param int embed_dim: 词embedding的维度大小 + :param int num_classes: 一共有多少类 + :param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 + :param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 + :param int padding: + :param float dropout: Dropout的大小 + """ super(CNNText, self).__init__() # no support for pre-trained embedding currently - self.embed = encoder.Embedding(embed_num, embed_dim) + self.embed = encoder.Embedding(vocab_size, embed_dim) self.conv_pool = encoder.ConvMaxpool( in_channels=embed_dim, out_channels=kernel_nums, @@ -34,24 +43,36 @@ class CNNText(torch.nn.Module): self.dropout = nn.Dropout(dropout) self.fc = encoder.Linear(sum(kernel_nums), num_classes) - def forward(self, word_seq): + def init_embed(self, embed): + """ + 加载预训练的模型 + :param numpy.ndarray embed: vocab_size x embed_dim的embedding + :return: + """ + assert isinstance(embed, np.ndarray) + assert embed.shape == self.embed.embed.weight.shape + self.embed.embed.weight.data = torch.from_numpy(embed) + + def forward(self, words, seq_len=None): """ - :param word_seq: torch.LongTensor, [batch_size, seq_len] + :param torch.LongTensor words: [batch_size, seq_len],句子中word的index + :param torch.LongTensor seq_len: [batch,] 每个句子的长度 :return output: dict of torch.LongTensor, [batch_size, num_classes] """ - x = self.embed(word_seq) # [N,L] -> [N,L,C] + x = self.embed(words) # [N,L] -> [N,L,C] x = self.conv_pool(x) # [N,L,C] -> [N,C] x = self.dropout(x) x = self.fc(x) # [N,C] -> [N, N_class] return {'pred': x} - def predict(self, word_seq): + def predict(self, words, seq_len=None): """ + :param torch.LongTensor words: [batch_size, seq_len],句子中word的index + :param torch.LongTensor seq_len: [batch,] 每个句子的长度 - :param word_seq: torch.LongTensor, [batch_size, seq_len] - :return predict: dict of torch.LongTensor, [batch_size, seq_len] + :return predict: dict of torch.LongTensor, [batch_size, ] """ - output = self(word_seq) + output = self(words, seq_len) _, predict = output['pred'].max(dim=1) return {'pred': predict} diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py index cb615daf..bd04a803 100644 --- a/fastNLP/models/sequence_modeling.py +++ 
b/fastNLP/models/sequence_modeling.py @@ -8,47 +8,64 @@ from fastNLP.modules.utils import seq_mask class SeqLabeling(BaseModel): """ - PyTorch Network for sequence labeling + 一个基础的Sequence labeling的模型 """ - def __init__(self, args): + def __init__(self, vocab_size, embed_dim, hidden_size, num_classes): + """ + 用于做sequence labeling的基础类。结构包含一层Embedding,一层LSTM(单向,一层),一层FC,以及一层CRF。 + + :param int vocab_size: 词表大小。 + :param int embed_dim: embedding的维度 + :param int hidden_size: LSTM隐藏层的大小 + :param int num_classes: 一共有多少类 + """ super(SeqLabeling, self).__init__() - vocab_size = args["vocab_size"] - word_emb_dim = args["word_emb_dim"] - hidden_dim = args["rnn_hidden_units"] - num_classes = args["num_classes"] - - self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim) - self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim) - self.Linear = encoder.linear.Linear(hidden_dim, num_classes) + + self.Embedding = encoder.embedding.Embedding(vocab_size, embed_dim) + self.Rnn = encoder.lstm.LSTM(embed_dim, hidden_size) + self.Linear = encoder.linear.Linear(hidden_size, num_classes) self.Crf = decoder.CRF.ConditionalRandomField(num_classes) self.mask = None - def forward(self, word_seq, word_seq_origin_len, truth=None): + def forward(self, words, seq_len, target): """ - :param word_seq: LongTensor, [batch_size, mex_len] - :param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences. - :param truth: LongTensor, [batch_size, max_len] + :param torch.LongTensor words: [batch_size, max_len],序列的index + :param torch.LongTensor seq_len: [batch_size,], 这个序列的长度 + :param torch.LongTensor target: [batch_size, max_len], 序列的目标值 :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. If truth is not None, return loss, a scalar. Used in training. 
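A minimal sketch of driving the refactored SeqLabeling through its new keyword-style constructor; all sizes and tensors below are illustrative:

    import torch
    from fastNLP.models.sequence_modeling import SeqLabeling

    model = SeqLabeling(vocab_size=100, embed_dim=50, hidden_size=64, num_classes=5)

    words = torch.randint(0, 100, (4, 7)).long()    # [batch_size, max_len]
    seq_len = torch.LongTensor([7, 7, 5, 3])        # true length of each sequence
    target = torch.randint(0, 5, (4, 7)).long()     # gold tag ids

    loss = model(words, seq_len, target)['loss']    # training: mean CRF negative log-likelihood
    paths = model.predict(words, seq_len)['pred']   # inference: list of viterbi-decoded tag lists
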
""" - assert word_seq.shape[0] == word_seq_origin_len.shape[0] - if truth is not None: - assert truth.shape == word_seq.shape - self.mask = self.make_mask(word_seq, word_seq_origin_len) + assert words.shape[0] == seq_len.shape[0] + assert target.shape == words.shape + self.mask = self._make_mask(words, seq_len) - x = self.Embedding(word_seq) + x = self.Embedding(words) # [batch_size, max_len, word_emb_dim] x = self.Rnn(x) # [batch_size, max_len, hidden_size * direction] x = self.Linear(x) # [batch_size, max_len, num_classes] - return {"loss": self._internal_loss(x, truth) if truth is not None else None, - "predict": self.decode(x)} + return {"loss": self._internal_loss(x, target)} - def loss(self, x, y): - """ Since the loss has been computed in forward(), this function simply returns x.""" - return x + def predict(self, words, seq_len): + """ + 用于在预测时使用 + + :param torch.LongTensor words: [batch_size, max_len] + :param torch.LongTensor seq_len: [batch_size,] + :return: + """ + self.mask = self._make_mask(words, seq_len) + + x = self.Embedding(words) + # [batch_size, max_len, word_emb_dim] + x = self.Rnn(x) + # [batch_size, max_len, hidden_size * direction] + x = self.Linear(x) + # [batch_size, max_len, num_classes] + pred = self._decode(x) + return {'pred': pred} def _internal_loss(self, x, y): """ @@ -65,89 +82,114 @@ class SeqLabeling(BaseModel): total_loss = self.Crf(x, y, self.mask) return torch.mean(total_loss) - def make_mask(self, x, seq_len): + def _make_mask(self, x, seq_len): batch_size, max_len = x.size(0), x.size(1) mask = seq_mask(seq_len, max_len) mask = mask.view(batch_size, max_len) mask = mask.to(x).float() return mask - def decode(self, x, pad=True): + def _decode(self, x): """ - :param x: FloatTensor, [batch_size, max_len, tag_size] - :param pad: pad the output sequence to equal lengths + :param torch.FloatTensor x: [batch_size, max_len, tag_size] :return prediction: list of [decode path(list)] """ - max_len = x.shape[1] - tag_seq, _ = self.Crf.viterbi_decode(x, self.mask) - # pad prediction to equal length - if pad is True: - for pred in tag_seq: - if len(pred) < max_len: - pred += [0] * (max_len - len(pred)) + tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True) return tag_seq -class AdvSeqLabel(SeqLabeling): +class AdvSeqLabel: """ - Advanced Sequence Labeling Model + 更复杂的Sequence Labelling模型。结构为Embedding, LayerNorm, 双向LSTM(两层),FC,LayerNorm,DropOut,FC,CRF。 """ - def __init__(self, args, emb=None, id2words=None): - super(AdvSeqLabel, self).__init__(args) - - vocab_size = args["vocab_size"] - word_emb_dim = args["word_emb_dim"] - hidden_dim = args["rnn_hidden_units"] - num_classes = args["num_classes"] - dropout = args['dropout'] + def __init__(self, vocab_size, embed_dim, hidden_size, num_classes, dropout=0.3, embedding=None, + id2words=None, encoding_type='bmes'): + """ - self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) - self.norm1 = torch.nn.LayerNorm(word_emb_dim) - # self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) - self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, + :param int vocab_size: 词表的大小 + :param int embed_dim: embedding的维度 + :param int hidden_size: LSTM的隐层大小 + :param int num_classes: 有多少个类 + :param float dropout: LSTM中以及DropOut层的drop概率 + :param numpy.ndarray embedding: 预训练的embedding,需要与指定的词表大小等一致 + :param dict id2words: tag id转为其tag word的表。用于在CRF解码时防止解出非法的顺序,比如'BMES'这个标签规范中,'S' + 
不能出现在'B'之后。这里也支持类似与'B-NN',即'-'前为标签类型的指示,后面为具体的tag的情况。这里不但会保证 + 'B-NN'后面不为'S-NN'还会保证'B-NN'后面不会出现'M-xx'(任何非'M-NN'和'E-NN'的情况。) + :param str encoding_type: 支持"BIO", "BMES", "BEMSO"。 + """ + self.Embedding = encoder.embedding.Embedding(vocab_size, embed_dim, init_emb=embedding) + self.norm1 = torch.nn.LayerNorm(embed_dim) + self.Rnn = torch.nn.LSTM(input_size=embed_dim, hidden_size=hidden_size, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True) - self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) - self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) - # self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) + self.Linear1 = encoder.Linear(hidden_size * 2, hidden_size * 2 // 3) + self.norm2 = torch.nn.LayerNorm(hidden_size * 2 // 3) self.relu = torch.nn.LeakyReLU() self.drop = torch.nn.Dropout(dropout) - self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) + self.Linear2 = encoder.Linear(hidden_size * 2 // 3, num_classes) if id2words is None: self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) else: self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False, allowed_transitions=allowed_transitions(id2words, - encoding_type="bmes")) + encoding_type=encoding_type)) + + def _decode(self, x): + """ + :param torch.FloatTensor x: [batch_size, max_len, tag_size] + :return prediction: list of [decode path(list)] + """ + tag_seq, _ = self.Crf.viterbi_decode(x, self.mask, unpad=True) + return tag_seq + + def _internal_loss(self, x, y): + """ + Negative log likelihood loss. + :param x: Tensor, [batch_size, max_len, tag_size] + :param y: Tensor, [batch_size, max_len] + :return loss: a scalar Tensor + + """ + x = x.float() + y = y.long() + assert x.shape[:2] == y.shape + assert y.shape == self.mask.shape + total_loss = self.Crf(x, y, self.mask) + return torch.mean(total_loss) + + def _make_mask(self, x, seq_len): + batch_size, max_len = x.size(0), x.size(1) + mask = seq_mask(seq_len, max_len) + mask = mask.view(batch_size, max_len) + mask = mask.to(x).float() + return mask - def forward(self, word_seq, word_seq_origin_len, truth=None): + def _forward(self, words, seq_len, target=None): """ - :param word_seq: LongTensor, [batch_size, mex_len] - :param word_seq_origin_len: LongTensor, [batch_size, ] - :param truth: LongTensor, [batch_size, max_len] + :param torch.LongTensor words: [batch_size, mex_len] + :param torch.LongTensor seq_len:[batch_size, ] + :param torch.LongTensor target: [batch_size, max_len] :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. If truth is not None, return loss, a scalar. Used in training. 
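The id2words argument expects the same id-to-tag mapping that allowed_transitions() consumes, so it can be built from a tag Vocabulary. A construction-only sketch; the tag set and all sizes are illustrative:

    from fastNLP import Vocabulary
    from fastNLP.models.sequence_modeling import AdvSeqLabel

    tag_vocab = Vocabulary(padding=None, unknown=None)
    tag_vocab.add_word_lst(['B-PER', 'M-PER', 'E-PER', 'S-PER',
                            'B-LOC', 'M-LOC', 'E-LOC', 'S-LOC'])

    # id -> tag string; the CRF uses it to forbid illegal jumps such as B-PER -> S-LOC
    id2words = {i: tag_vocab.to_word(i) for i in range(len(tag_vocab))}

    model = AdvSeqLabel(vocab_size=100, embed_dim=50, hidden_size=128,
                        num_classes=len(tag_vocab), dropout=0.3,
                        id2words=id2words, encoding_type='bmes')
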
""" - word_seq = word_seq.long() - word_seq_origin_len = word_seq_origin_len.long() - self.mask = self.make_mask(word_seq, word_seq_origin_len) - sent_len, idx_sort = torch.sort(word_seq_origin_len, descending=True) + words = words.long() + seq_len = seq_len.long() + self.mask = self._make_mask(words, seq_len) + sent_len, idx_sort = torch.sort(seq_len, descending=True) _, idx_unsort = torch.sort(idx_sort, descending=False) - # word_seq_origin_len = word_seq_origin_len.long() - truth = truth.long() if truth is not None else None + # seq_len = seq_len.long() + target = target.long() if target is not None else None - batch_size = word_seq.size(0) - max_len = word_seq.size(1) if next(self.parameters()).is_cuda: - word_seq = word_seq.cuda() + words = words.cuda() idx_sort = idx_sort.cuda() idx_unsort = idx_unsort.cuda() self.mask = self.mask.cuda() - x = self.Embedding(word_seq) + x = self.Embedding(words) x = self.norm1(x) # [batch_size, max_len, word_emb_dim] @@ -155,71 +197,35 @@ class AdvSeqLabel(SeqLabeling): sent_packed = torch.nn.utils.rnn.pack_padded_sequence(sent_variable, sent_len, batch_first=True) x, _ = self.Rnn(sent_packed) - # print(x) - # [batch_size, max_len, hidden_size * direction] sent_output = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0] x = sent_output[idx_unsort] x = x.contiguous() - # x = x.view(batch_size * max_len, -1) x = self.Linear1(x) - # x = self.batch_norm(x) x = self.norm2(x) x = self.relu(x) x = self.drop(x) x = self.Linear2(x) - # x = x.view(batch_size, max_len, -1) - # [batch_size, max_len, num_classes] - # TODO seq_lens的key这样做不合理 - return {"loss": self._internal_loss(x, truth) if truth is not None else None, - "predict": self.decode(x), - 'word_seq_origin_len': word_seq_origin_len} - - def predict(self, **x): - out = self.forward(**x) - return {"predict": out["predict"]} - - def loss(self, **kwargs): - assert 'loss' in kwargs - return kwargs['loss'] - - -if __name__ == '__main__': - args = { - 'vocab_size': 20, - 'word_emb_dim': 100, - 'rnn_hidden_units': 100, - 'num_classes': 10, - } - model = AdvSeqLabel(args) - data = [] - for i in range(20): - word_seq = torch.randint(20, (15,)).long() - word_seq_len = torch.LongTensor([15]) - truth = torch.randint(10, (15,)).long() - data.append((word_seq, word_seq_len, truth)) - optimizer = torch.optim.Adam(model.parameters(), lr=0.01) - print(model) - curidx = 0 - for i in range(1000): - endidx = min(len(data), curidx + 5) - b_word, b_len, b_truth = [], [], [] - for word_seq, word_seq_len, truth in data[curidx: endidx]: - b_word.append(word_seq) - b_len.append(word_seq_len) - b_truth.append(truth) - word_seq = torch.stack(b_word, dim=0) - word_seq_len = torch.cat(b_len, dim=0) - truth = torch.stack(b_truth, dim=0) - res = model(word_seq, word_seq_len, truth) - loss = res['loss'] - pred = res['predict'] - print('loss: {} acc {}'.format(loss.item(), - ((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) - optimizer.zero_grad() - loss.backward() - optimizer.step() - curidx = endidx - if curidx == len(data): - curidx = 0 + if target is not None: + return {"loss": self._internal_loss(x, target)} + else: + return {"pred": self._decode(x)} + + def forward(self, words, seq_len, target): + """ + :param torch.LongTensor words: [batch_size, mex_len] + :param torch.LongTensor seq_len:[batch_size, ] + :param torch.LongTensor target: [batch_size, max_len], 目标 + :return torch.Tensor, a scalar loss + """ + return self._forward(words, seq_len, target) + + def predict(self, words, seq_len): + """ 
+ + :param torch.LongTensor words: [batch_size, mex_len] + :param torch.LongTensor seq_len:[batch_size, ] + :return: [list1, list2, ...], 内部每个list为一个路径,已经unpad了。 + """ + return self._forward(words, seq_len, ) diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index 99e7a9c2..cc713bc6 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -19,10 +19,10 @@ def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): """ 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 - :param id2label: Dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 + :param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。 - :param encoding_type: str, 支持"bio", "bmes", "bmeso"。 - :param include_start_end: bool, 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; + :param str encoding_type: 支持"bio", "bmes", "bmeso"。 + :param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; 为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); start_idx=len(id2label), end_idx=len(id2label)+1。 为False, 返回的结果中不含与开始结尾相关的内容 @@ -62,11 +62,11 @@ def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): """ - :param encoding_type: str, 支持"BIO", "BMES", "BEMSO"。 - :param from_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag - :param from_label: str, 比如"PER", "LOC"等label - :param to_tag: str, 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag - :param to_label: str, 比如"PER", "LOC"等label + :param str encoding_type: 支持"BIO", "BMES", "BEMSO"。 + :param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag + :param str from_label: 比如"PER", "LOC"等label + :param str to_tag: 比如"B", "M"之类的标注tag. 
还包括start, end等两种特殊tag + :param str to_label: 比如"PER", "LOC"等label :return: bool,能否跃迁 """ if to_tag=='start' or from_tag=='end': @@ -149,12 +149,12 @@ class ConditionalRandomField(nn.Module): """条件随机场。 提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 - :param num_tags: int, 标签的数量 - :param include_start_end_trans: bool, 是否考虑各个tag作为开始以及结尾的分数。 - :param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]], 内部的Tuple[from_tag_id(int), + :param int num_tags: 标签的数量 + :param bool include_start_end_trans: 是否考虑各个tag作为开始以及结尾的分数。 + :param List[Tuple[from_tag_id(int), to_tag_id(int)]] allowed_transitions: 内部的Tuple[from_tag_id(int), to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 - :param initial_method: str, 初始化方法。见initial_parameter + :param str initial_method: 初始化方法。见initial_parameter """ super(ConditionalRandomField, self).__init__() @@ -237,10 +237,10 @@ class ConditionalRandomField(nn.Module): """ 用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 - :param feats:FloatTensor, batch_size x max_len x num_tags,特征矩阵。 - :param tags:LongTensor, batch_size x max_len,标签矩阵。 - :param mask:ByteTensor batch_size x max_len,为0的位置认为是padding。 - :return:FloatTensor, batch_size + :param torch.FloatTensor feats:batch_size x max_len x num_tags,特征矩阵。 + :param torch.LongTensor tags: batch_size x max_len,标签矩阵。 + :param torch.ByteTensor mask: batch_size x max_len,为0的位置认为是padding。 + :return:torch.FloatTensor, (batch_size,) """ feats = feats.transpose(0, 1) tags = tags.transpose(0, 1).long() @@ -250,27 +250,26 @@ class ConditionalRandomField(nn.Module): return all_path_score - gold_path_score - def viterbi_decode(self, feats, mask, unpad=False): + def viterbi_decode(self, logits, mask, unpad=False): """给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 - :param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 - :param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 - :param unpad: bool, 是否将结果删去padding, - False, 返回的是batch_size x max_len的tensor, - True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int] - 的长度是这个sample的有效长度。 + :param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 + :param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 + :param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 + List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 + 个sample的有效长度。 :return: 返回 (paths, scores)。 paths: 是解码后的路径, 其值参照unpad参数. 
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 """ - batch_size, seq_len, n_tags = feats.size() - feats = feats.transpose(0, 1).data # L, B, H + batch_size, seq_len, n_tags = logits.size() + logits = logits.transpose(0, 1).data # L, B, H mask = mask.transpose(0, 1).data.byte() # L, B # dp - vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) - vscore = feats[0] + vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) + vscore = logits[0] transitions = self._constrain.data.clone() transitions[:n_tags, :n_tags] += self.trans_m.data if self.include_start_end_trans: @@ -281,7 +280,7 @@ class ConditionalRandomField(nn.Module): trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data for i in range(1, seq_len): prev_score = vscore.view(batch_size, n_tags, 1) - cur_score = feats[i].view(batch_size, 1, n_tags) + cur_score = logits[i].view(batch_size, 1, n_tags) score = prev_score + trans_score + cur_score best_score, best_dst = score.max(1) vpath[i] = best_dst @@ -292,13 +291,13 @@ class ConditionalRandomField(nn.Module): vscore += transitions[:n_tags, n_tags+1].view(1, -1) # backtrace - batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) - seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) + batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) + seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) lens = (mask.long().sum(0) - 1) # idxes [L, B], batched idx from seq_len-1 to 0 idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len - ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) + ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) ans_score, last_tags = vscore.max(1) ans[idxes[0], batch_idx] = last_tags for i in range(seq_len - 1): @@ -311,6 +310,5 @@ class ConditionalRandomField(nn.Module): paths.append(ans[idx, :seq_len+1].tolist()) else: paths = ans - if get_score: - return paths, ans_score.tolist() - return paths + return paths, ans_score + diff --git a/fastNLP/modules/decoder/utils.py b/fastNLP/modules/decoder/utils.py index dfaac622..67db08f7 100644 --- a/fastNLP/modules/decoder/utils.py +++ b/fastNLP/modules/decoder/utils.py @@ -11,13 +11,12 @@ def log_sum_exp(x, dim=-1): def viterbi_decode(logits, transitions, mask=None, unpad=False): """给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 - :param logits: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 - :param transitions: FloatTensor, n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。 - :param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 - :param unpad: bool, 是否将结果删去padding, - False, 返回的是batch_size x max_len的tensor, - True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是 - 这个sample的有效长度。 + :param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 + :param torch.FloatTensor transitions: n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。 + :param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 + :param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 + List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 + 个sample的有效长度。 :return: 返回 (paths, scores)。 paths: 是解码后的路径, 其值参照unpad参数. 
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py index 56b9ca59..06d8b86a 100644 --- a/fastNLP/modules/encoder/__init__.py +++ b/fastNLP/modules/encoder/__init__.py @@ -1,4 +1,3 @@ -from .conv import Conv from .conv_maxpool import ConvMaxpool from .embedding import Embedding from .linear import Linear @@ -8,6 +7,5 @@ from .bert import BertModel __all__ = ["LSTM", "Embedding", "Linear", - "Conv", "ConvMaxpool", "BertModel"] diff --git a/fastNLP/modules/encoder/conv.py b/fastNLP/modules/encoder/conv.py deleted file mode 100644 index 42254a8b..00000000 --- a/fastNLP/modules/encoder/conv.py +++ /dev/null @@ -1,58 +0,0 @@ -# python: 3.6 -# encoding: utf-8 - -import torch -import torch.nn as nn - -from fastNLP.modules.utils import initial_parameter - - -# import torch.nn.functional as F - - -class Conv(nn.Module): - """Basic 1-d convolution module, initialized with xavier_uniform. - - :param int in_channels: - :param int out_channels: - :param tuple kernel_size: - :param int stride: - :param int padding: - :param int dilation: - :param int groups: - :param bool bias: - :param str activation: - :param str initial_method: - """ - def __init__(self, in_channels, out_channels, kernel_size, - stride=1, padding=0, dilation=1, - groups=1, bias=True, activation='relu', initial_method=None): - super(Conv, self).__init__() - self.conv = nn.Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias) - # xavier_uniform_(self.conv.weight) - - activations = { - 'relu': nn.ReLU(), - 'tanh': nn.Tanh()} - if activation in activations: - self.activation = activations[activation] - else: - raise Exception( - 'Should choose activation function from: ' + - ', '.join([x for x in activations])) - initial_parameter(self, initial_method) - - def forward(self, x): - x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] - x = self.conv(x) # [N,C_in,L] -> [N,C_out,L] - x = self.activation(x) - x = torch.transpose(x, 1, 2) # [N,C,L] -> [N,L,C] - return x diff --git a/fastNLP/modules/encoder/conv_maxpool.py b/fastNLP/modules/encoder/conv_maxpool.py index 8b035871..d7a8b286 100644 --- a/fastNLP/modules/encoder/conv_maxpool.py +++ b/fastNLP/modules/encoder/conv_maxpool.py @@ -9,18 +9,21 @@ from fastNLP.modules.utils import initial_parameter class ConvMaxpool(nn.Module): - """Convolution and max-pooling module with multiple kernel sizes. 
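Putting the renamed viterbi_decode argument and its unconditional (paths, scores) return value together with allowed_transitions(), a short sketch; the tag ids and tensor sizes are illustrative:

    import torch
    from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions

    id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'}
    constraints = allowed_transitions(id2label, encoding_type='bmes', include_start_end=True)
    crf = ConditionalRandomField(num_tags=4, include_start_end_trans=False,
                                 allowed_transitions=constraints)

    logits = torch.randn(2, 6, 4)              # batch_size x max_len x num_tags
    mask = torch.ones(2, 6).byte()             # positions with 0 are treated as padding
    tags = torch.randint(0, 4, (2, 6)).long()

    loss = crf(logits, tags, mask).mean()      # forward() returns a (batch_size,) tensor
    paths, scores = crf.viterbi_decode(logits, mask, unpad=True)
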
+ """集合了Convolution和Max-Pooling于一体的层。 + 给定一个batch_size x max_len x input_size的输入,返回batch_size x sum(output_channels) 大小的matrix。在内部,是先使用 + CNN给输入做卷积,然后经过activation激活层,在通过在长度(max_len)这一维进行max_pooling。最后得到每个sample的一个vector + 表示。 - :param int in_channels: - :param int out_channels: - :param tuple kernel_sizes: - :param int stride: - :param int padding: - :param int dilation: - :param int groups: - :param bool bias: - :param str activation: - :param str initial_method: + :param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度 + :param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 + :param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 + :param int stride: 见pytorch Conv1D文档。所有kernel共享一个stride。 + :param int padding: 见pytorch Conv1D文档。所有kernel共享一个padding。 + :param int dilation: 见pytorch Conv1D文档。所有kernel共享一个dilation。 + :param int groups: 见pytorch Conv1D文档。所有kernel共享一个groups。 + :param bool bias: 见pytorch Conv1D文档。所有kernel共享一个bias。 + :param str activation: Convolution后的结果将通过该activation后再经过max-pooling。支持relu, sigmoid, tanh + :param str initial_method: str。 """ def __init__(self, in_channels, out_channels, kernel_sizes, stride=1, padding=0, dilation=1, @@ -29,9 +32,14 @@ class ConvMaxpool(nn.Module): # convolution if isinstance(kernel_sizes, (list, tuple, int)): - if isinstance(kernel_sizes, int): + if isinstance(kernel_sizes, int) and isinstance(out_channels, int): out_channels = [out_channels] kernel_sizes = [kernel_sizes] + elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)): + assert len(out_channels)==len(kernel_sizes), "The number of out_channels should be equal to the number" \ + " of kernel_sizes." + else: + raise ValueError("The type of out_channels and kernel_sizes should be the same.") self.convs = nn.ModuleList([nn.Conv1d( in_channels=in_channels, @@ -51,18 +59,31 @@ class ConvMaxpool(nn.Module): # activation function if activation == 'relu': self.activation = F.relu + elif activation == 'sigmoid': + self.activation = F.sigmoid + elif activation == 'tanh': + self.activation = F.tanh else: raise Exception( - "Undefined activation function: choose from: relu") + "Undefined activation function: choose from: relu, tanh, sigmoid") initial_parameter(self, initial_method) - def forward(self, x): + def forward(self, x, mask=None): + """ + + :param torch.FloatTensor x: batch_size x max_len x input_size, 一般是经过embedding后的值 + :param mask: batch_size x max_len, pad的地方为0。不影响卷积运算,max-pool一定不会pool到pad为0的位置 + :return: + """ # [N,L,C] -> [N,C,L] x = torch.transpose(x, 1, 2) # convolution - xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L]] + xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] + if mask is not None: + mask = mask.unsqueeze(1) # B x 1 x L + xs = [x.masked_fill_(mask, float('-inf')) for x in xs] # max-pooling xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) - for i in xs] # [[N, C]] - return torch.cat(xs, dim=-1) # [N,C] + for i in xs] # [[N, C], ...] 
+ return torch.cat(xs, dim=-1) # [N, C] \ No newline at end of file diff --git a/fastNLP/modules/encoder/masked_rnn.py b/fastNLP/modules/encoder/masked_rnn.py deleted file mode 100644 index 321546c4..00000000 --- a/fastNLP/modules/encoder/masked_rnn.py +++ /dev/null @@ -1,424 +0,0 @@ -__author__ = 'max' - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from fastNLP.modules.utils import initial_parameter - - -def MaskedRecurrent(reverse=False): - def forward(input, hidden, cell, mask, train=True, dropout=0): - """ - :param input: - :param hidden: - :param cell: - :param mask: - :param dropout: step之间的dropout,对mask了的也会drop,应该是没问题的,反正没有gradient - :param train: 控制dropout的行为,在StackedRNN的forward中调用 - :return: - """ - output = [] - steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0)) - for i in steps: - if mask is None or mask[i].data.min() > 0.5: # 没有mask,都是1 - hidden = cell(input[i], hidden) - elif mask[i].data.max() > 0.5: # 有mask,但不全为0 - hidden_next = cell(input[i], hidden) # 一次喂入一个batch! - # hack to handle LSTM - if isinstance(hidden, tuple): # LSTM outputs a tuple of (hidden, cell), this is a common hack 😁 - mask = mask.float() - hx, cx = hidden - hp1, cp1 = hidden_next - hidden = ( - hx + (hp1 - hx) * mask[i].squeeze(), - cx + (cp1 - cx) * mask[i].squeeze()) # Why? 我知道了!!如果是mask就不用改变 - else: - hidden = hidden + (hidden_next - hidden) * mask[i] - - # if dropout != 0 and train: # warning, should i treat masked tensor differently? - # if isinstance(hidden, tuple): - # hidden = (F.dropout(hidden[0], p=dropout, training=train), - # F.dropout(hidden[1], p=dropout, training=train)) - # else: - # hidden = F.dropout(hidden, p=dropout, training=train) - - # hack to handle LSTM - output.append(hidden[0] if isinstance(hidden, tuple) else hidden) - - if reverse: - output.reverse() - output = torch.cat(output, 0).view(input.size(0), *output[0].size()) - - return hidden, output - - return forward - - -def StackedRNN(inners, num_layers, lstm=False, train=True, step_dropout=0, layer_dropout=0): - num_directions = len(inners) # rec_factory! 
- total_layers = num_layers * num_directions - - def forward(input, hidden, cells, mask): - assert (len(cells) == total_layers) - next_hidden = [] - - if lstm: - hidden = list(zip(*hidden)) - - for i in range(num_layers): - all_output = [] - for j, inner in enumerate(inners): - l = i * num_directions + j - hy, output = inner(input, hidden[l], cells[l], mask, step_dropout, train) - next_hidden.append(hy) - all_output.append(output) - - input = torch.cat(all_output, input.dim() - 1) # 下一层的输入 - - if layer_dropout != 0 and i < num_layers - 1: - input = F.dropout(input, p=layer_dropout, training=train, inplace=False) - - if lstm: - next_h, next_c = zip(*next_hidden) - next_hidden = ( - torch.cat(next_h, 0).view(total_layers, *next_h[0].size()), - torch.cat(next_c, 0).view(total_layers, *next_c[0].size()) - ) - else: - next_hidden = torch.cat(next_hidden, 0).view(total_layers, *next_hidden[0].size()) - - return next_hidden, input - - return forward - - -def AutogradMaskedRNN(num_layers=1, batch_first=False, train=True, layer_dropout=0, step_dropout=0, - bidirectional=False, lstm=False): - rec_factory = MaskedRecurrent - - if bidirectional: - layer = (rec_factory(), rec_factory(reverse=True)) - else: - layer = (rec_factory(),) # rec_factory 就是每层的结构啦!!在MaskedRecurrent中进行每层的计算!然后用StackedRNN接起来 - - func = StackedRNN(layer, - num_layers, - lstm=lstm, - layer_dropout=layer_dropout, step_dropout=step_dropout, - train=train) - - def forward(input, cells, hidden, mask): - if batch_first: - input = input.transpose(0, 1) - if mask is not None: - mask = mask.transpose(0, 1) - - nexth, output = func(input, hidden, cells, mask) - - if batch_first: - output = output.transpose(0, 1) - - return output, nexth - - return forward - - -def MaskedStep(): - def forward(input, hidden, cell, mask): - if mask is None or mask.data.min() > 0.5: - hidden = cell(input, hidden) - elif mask.data.max() > 0.5: - hidden_next = cell(input, hidden) - # hack to handle LSTM - if isinstance(hidden, tuple): - hx, cx = hidden - hp1, cp1 = hidden_next - hidden = (hx + (hp1 - hx) * mask, cx + (cp1 - cx) * mask) - else: - hidden = hidden + (hidden_next - hidden) * mask - # hack to handle LSTM - output = hidden[0] if isinstance(hidden, tuple) else hidden - - return hidden, output - - return forward - - -def StackedStep(layer, num_layers, lstm=False, dropout=0, train=True): - def forward(input, hidden, cells, mask): - assert (len(cells) == num_layers) - next_hidden = [] - - if lstm: - hidden = list(zip(*hidden)) - - for l in range(num_layers): - hy, output = layer(input, hidden[l], cells[l], mask) - next_hidden.append(hy) - input = output - - if dropout != 0 and l < num_layers - 1: - input = F.dropout(input, p=dropout, training=train, inplace=False) - - if lstm: - next_h, next_c = zip(*next_hidden) - next_hidden = ( - torch.cat(next_h, 0).view(num_layers, *next_h[0].size()), - torch.cat(next_c, 0).view(num_layers, *next_c[0].size()) - ) - else: - next_hidden = torch.cat(next_hidden, 0).view(num_layers, *next_hidden[0].size()) - - return next_hidden, input - - return forward - - -def AutogradMaskedStep(num_layers=1, dropout=0, train=True, lstm=False): - layer = MaskedStep() - - func = StackedStep(layer, - num_layers, - lstm=lstm, - dropout=dropout, - train=train) - - def forward(input, cells, hidden, mask): - nexth, output = func(input, hidden, cells, mask) - return output, nexth - - return forward - - -class MaskedRNNBase(nn.Module): - def __init__(self, Cell, input_size, hidden_size, - num_layers=1, bias=True, batch_first=False, - 
layer_dropout=0, step_dropout=0, bidirectional=False, initial_method = None , **kwargs): - """ - :param Cell: - :param input_size: - :param hidden_size: - :param num_layers: - :param bias: - :param batch_first: - :param layer_dropout: - :param step_dropout: - :param bidirectional: - :param kwargs: - """ - - super(MaskedRNNBase, self).__init__() - self.Cell = Cell - self.input_size = input_size - self.hidden_size = hidden_size - self.num_layers = num_layers - self.bias = bias - self.batch_first = batch_first - self.layer_dropout = layer_dropout - self.step_dropout = step_dropout - self.bidirectional = bidirectional - num_directions = 2 if bidirectional else 1 - - self.all_cells = [] - for layer in range(num_layers): # 初始化所有cell - for direction in range(num_directions): - layer_input_size = input_size if layer == 0 else hidden_size * num_directions - - cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs) - self.all_cells.append(cell) - self.add_module('cell%d' % (layer * num_directions + direction), cell) # Max的代码写得真好看 - initial_parameter(self, initial_method) - def reset_parameters(self): - for cell in self.all_cells: - cell.reset_parameters() - - def forward(self, input, mask=None, hx=None): - batch_size = input.size(0) if self.batch_first else input.size(1) - lstm = self.Cell is nn.LSTMCell - if hx is None: - num_directions = 2 if self.bidirectional else 1 - hx = torch.autograd.Variable( - input.data.new(self.num_layers * num_directions, batch_size, self.hidden_size).zero_()) - if lstm: - hx = (hx, hx) - - func = AutogradMaskedRNN(num_layers=self.num_layers, - batch_first=self.batch_first, - step_dropout=self.step_dropout, - layer_dropout=self.layer_dropout, - train=self.training, - bidirectional=self.bidirectional, - lstm=lstm) # 传入all_cells,继续往底层封装走 - - output, hidden = func(input, self.all_cells, hx, - None if mask is None else mask.view(mask.size() + (1,))) # 这个+ (1, )是个什么操作? - return output, hidden - - def step(self, input, hx=None, mask=None): - """Execute one step forward (only for one-directional RNN). - - :param Tensor input: input tensor of this step. (batch, input_size) - :param Tensor hx: the hidden state of last step. (num_layers, batch, hidden_size) - :param Tensor mask: the mask tensor of this step. (batch, ) - :returns: - **output** (batch, hidden_size), tensor containing the output of this step from the last layer of RNN. - **hn** (num_layers, batch, hidden_size), tensor containing the hidden state of this step - - """ - assert not self.bidirectional, "step only cannot be applied to bidirectional RNN." # aha, typo! - batch_size = input.size(0) - lstm = self.Cell is nn.LSTMCell - if hx is None: - hx = torch.autograd.Variable(input.data.new(self.num_layers, batch_size, self.hidden_size).zero_()) - if lstm: - hx = (hx, hx) - - func = AutogradMaskedStep(num_layers=self.num_layers, - dropout=self.step_dropout, - train=self.training, - lstm=lstm) - - output, hidden = func(input, self.all_cells, hx, mask) - return output, hidden - - -class MaskedRNN(MaskedRNNBase): - r"""Applies a multi-layer Elman RNN with costomized non-linearity to an - input sequence. - For each element in the input sequence, each layer computes the following - function. :math:`h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})` - - where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is - the hidden state of the previous layer at time `t` or :math:`input_t` - for the first layer. If nonlinearity='relu', then `ReLU` is used instead - of `tanh`. 
- - - :param int input_size: The number of expected features in the input x - :param int hidden_size: The number of features in the hidden state h - :param int num_layers: Number of recurrent layers. - :param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh' - :param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True - :param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) - :param float dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer - :param bool bidirectional: If True, becomes a bidirectional RNN. Default: False - - Inputs: input, mask, h_0 - - **input** (seq_len, batch, input_size): tensor containing the features - of the input sequence. - **mask** (seq_len, batch): 0-1 tensor containing the mask of the input sequence. - - **h_0** (num_layers * num_directions, batch, hidden_size): tensor - containing the initial hidden state for each element in the batch. - Outputs: output, h_n - - **output** (seq_len, batch, hidden_size * num_directions): tensor - containing the output features (h_k) from the last layer of the RNN, - for each k. If a :class:`torch.nn.utils.rnn.PackedSequence` has - been given as the input, the output will also be a packed sequence. - - **h_n** (num_layers * num_directions, batch, hidden_size): tensor - containing the hidden state for k=seq_len. - """ - - def __init__(self, *args, **kwargs): - super(MaskedRNN, self).__init__(nn.RNNCell, *args, **kwargs) - - -class MaskedLSTM(MaskedRNNBase): - r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input - sequence. - For each element in the input sequence, each layer computes the following - function. - - .. math:: - - \begin{array}{ll} - i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\ - f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\ - g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hc} h_{(t-1)} + b_{hg}) \\ - o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\ - c_t = f_t * c_{(t-1)} + i_t * g_t \\ - h_t = o_t * \tanh(c_t) - \end{array} - - where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell - state at time `t`, :math:`x_t` is the hidden state of the previous layer at - time `t` or :math:`input_t` for the first layer, and :math:`i_t`, - :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, - and out gates, respectively. - - :param int input_size: The number of expected features in the input x - :param int hidden_size: The number of features in the hidden state h - :param int num_layers: Number of recurrent layers. - :param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True - :param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) - :param bool dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer - :param bool bidirectional: If True, becomes a bidirectional RNN. Default: False - - Inputs: input, mask, (h_0, c_0) - - **input** (seq_len, batch, input_size): tensor containing the features - of the input sequence. - **mask** (seq_len, batch): 0-1 tensor containing the mask of the input sequence. - - **h_0** (num_layers \* num_directions, batch, hidden_size): tensor - containing the initial hidden state for each element in the batch. 
- - **c_0** (num_layers \* num_directions, batch, hidden_size): tensor - containing the initial cell state for each element in the batch. - Outputs: output, (h_n, c_n) - - **output** (seq_len, batch, hidden_size * num_directions): tensor - containing the output features `(h_t)` from the last layer of the RNN, - for each t. If a :class:`torch.nn.utils.rnn.PackedSequence` has been - given as the input, the output will also be a packed sequence. - - **h_n** (num_layers * num_directions, batch, hidden_size): tensor - containing the hidden state for t=seq_len - - **c_n** (num_layers * num_directions, batch, hidden_size): tensor - containing the cell state for t=seq_len - """ - - def __init__(self, *args, **kwargs): - super(MaskedLSTM, self).__init__(nn.LSTMCell, *args, **kwargs) - - -class MaskedGRU(MaskedRNNBase): - r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. - For each element in the input sequence, each layer computes the following - function: - - .. math:: - - \begin{array}{ll} - r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ - z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ - n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ - h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\ - \end{array} - - where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden - state of the previous layer at time `t` or :math:`input_t` for the first - layer, and :math:`r_t`, :math:`z_t`, :math:`n_t` are the reset, input, - and new gates, respectively. - - :param int input_size: The number of expected features in the input x - :param int hidden_size: The number of features in the hidden state h - :param int num_layers: Number of recurrent layers. - :param str nonlinearity: The non-linearity to use ['tanh'|'relu']. Default: 'tanh' - :param bool bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True - :param bool batch_first: If True, then the input and output tensors are provided as (batch, seq, feature) - :param bool dropout: If non-zero, introduces a dropout layer on the outputs of each RNN layer except the last layer - :param bool bidirectional: If True, becomes a bidirectional RNN. Default: False - - Inputs: input, mask, h_0 - - **input** (seq_len, batch, input_size): tensor containing the features - of the input sequence. - **mask** (seq_len, batch): 0-1 tensor containing the mask of the input sequence. - - **h_0** (num_layers * num_directions, batch, hidden_size): tensor - containing the initial hidden state for each element in the batch. - Outputs: output, h_n - - **output** (seq_len, batch, hidden_size * num_directions): tensor - containing the output features (h_k) from the last layer of the RNN, - for each k. If a :class:`torch.nn.utils.rnn.PackedSequence` has - been given as the input, the output will also be a packed sequence. - - **h_n** (num_layers * num_directions, batch, hidden_size): tensor - containing the hidden state for k=seq_len. 
- """ - - def __init__(self, *args, **kwargs): - super(MaskedGRU, self).__init__(nn.GRUCell, *args, **kwargs) diff --git a/test/modules/test_masked_rnn.py b/test/modules/test_masked_rnn.py deleted file mode 100644 index 80f49f33..00000000 --- a/test/modules/test_masked_rnn.py +++ /dev/null @@ -1,27 +0,0 @@ - -import torch -import unittest - -from fastNLP.modules.encoder.masked_rnn import MaskedRNN - -class TestMaskedRnn(unittest.TestCase): - def test_case_1(self): - masked_rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) - x = torch.tensor([[[1.0], [2.0]]]) - print(x.size()) - y = masked_rnn(x) - mask = torch.tensor([[[1], [1]]]) - y = masked_rnn(x, mask=mask) - mask = torch.tensor([[[1], [0]]]) - y = masked_rnn(x, mask=mask) - - def test_case_2(self): - masked_rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=False, batch_first=True) - x = torch.tensor([[[1.0], [2.0]]]) - print(x.size()) - y = masked_rnn(x) - mask = torch.tensor([[[1], [1]]]) - y = masked_rnn(x, mask=mask) - xx = torch.tensor([[[1.0]]]) - y = masked_rnn.step(xx) - y = masked_rnn.step(xx, mask=mask) \ No newline at end of file diff --git a/test/test_tutorials.py b/test/test_tutorials.py index eb77321c..a1d47dde 100644 --- a/test/test_tutorials.py +++ b/test/test_tutorials.py @@ -70,7 +70,7 @@ class TestTutorial(unittest.TestCase): break from fastNLP.models import CNNText - model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) + model = CNNText(vocab_size=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) from fastNLP import Trainer from copy import deepcopy @@ -145,7 +145,7 @@ class TestTutorial(unittest.TestCase): is_input=True) from fastNLP.models import CNNText - model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) + model = CNNText(vocab_size=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric trainer = Trainer(model=model, @@ -405,7 +405,7 @@ class TestTutorial(unittest.TestCase): # 另一个例子:加载CNN文本分类模型 from fastNLP.models import CNNText - cnn_text_model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) + cnn_text_model = CNNText(vocab_size=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) cnn_text_model from fastNLP import CrossEntropyLoss