diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py
index 7bcffb8e..c93fa1a3 100644
--- a/fastNLP/modules/encoder/embedding.py
+++ b/fastNLP/modules/encoder/embedding.py
@@ -2,20 +2,25 @@ import torch.nn as nn
 
 
 class Embedding(nn.Module):
-    """A simple lookup table.
+    """Embedding component."""
 
-    :param int nums: the size of the lookup table
-    :param int dims: the size of each vector
-    :param int padding_idx: pads the tensor with zeros whenever it encounters this index
-    :param bool sparse: If True, gradient matrix will be a sparse tensor. In this case, only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
-    """
-    def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
+    def __init__(self, vocab_size, embed_dim, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
+        """
+        :param int vocab_size: vocabulary size.
+        :param int embed_dim: embedding dimension.
+        :param int padding_idx: entries equal to padding_idx are automatically padded with zeros.
+        :param bool sparse: if `True`, the weight matrix is a sparse matrix.
+        :param torch.Tensor init_emb: the initial embedding matrix.
+        :param float dropout: dropout probability.
+        """
         super(Embedding, self).__init__()
-        self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
-        if init_emb is not None:
-            self.embed.weight = nn.Parameter(init_emb)
+        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx, sparse=sparse, _weight=init_emb)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
+        """
+        :param torch.LongTensor x: [batch, seq_len]
+        :return: torch.Tensor : [batch, seq_len, embed_dim]
+        """
         x = self.embed(x)
         return self.dropout(x)
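
For context, a minimal usage sketch of the refactored module follows. The old `(nums, dims)` arguments become `(vocab_size, embed_dim)`, and a pre-built weight matrix is now forwarded to `nn.Embedding` through its `_weight` argument instead of being wrapped in `nn.Parameter` by hand. The import path is inferred from the file touched in this diff; the sizes and the `init_emb` tensor below are hypothetical.

```python
import torch
from fastNLP.modules.encoder.embedding import Embedding  # path assumed from this diff

# Hypothetical vocabulary/embedding sizes and a hypothetical pretrained weight matrix.
vocab_size, embed_dim = 1000, 50
init_emb = torch.randn(vocab_size, embed_dim)

# New signature: vocab_size/embed_dim replace the old nums/dims;
# init_emb is handed to nn.Embedding via _weight, dropout is applied in forward.
embed = Embedding(vocab_size, embed_dim, padding_idx=0, init_emb=init_emb, dropout=0.1)

x = torch.randint(0, vocab_size, (4, 7))  # [batch, seq_len]
out = embed(x)                            # [batch, seq_len, embed_dim]
print(out.shape)                          # torch.Size([4, 7, 50])
```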