
update documents on embedding.py

tags/v0.4.10
xuyige 5 years ago
commit e76dca9ad7
1 changed file with 15 additions and 10 deletions
fastNLP/modules/encoder/embedding.py

@@ -2,20 +2,25 @@ import torch.nn as nn
 
 
 class Embedding(nn.Module):
-    """A simple lookup table.
+    """Embedding component."""
 
-    :param int nums: the size of the lookup table
-    :param int dims: the size of each vector
-    :param int padding_idx: pads the tensor with zeros whenever it encounters this index
-    :param bool sparse: If True, gradient matrix will be a sparse tensor. In this case, only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
-    """
-    def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
+    def __init__(self, vocab_size, embed_dim, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
+        """
+        :param int vocab_size: size of the vocabulary.
+        :param int embed_dim: embedding dimension.
+        :param int padding_idx: indices equal to padding_idx are automatically mapped to zero vectors.
+        :param bool sparse: if `True`, the weight matrix is a sparse tensor.
+        :param torch.Tensor init_emb: initial embedding matrix.
+        :param float dropout: dropout probability.
+        """
         super(Embedding, self).__init__()
-        self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
-        if init_emb is not None:
-            self.embed.weight = nn.Parameter(init_emb)
+        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx, sparse=sparse, _weight=init_emb)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
+        """
+        :param torch.LongTensor x: [batch, seq_len]
+        :return: torch.Tensor : [batch, seq_len, embed_dim]
+        """
         x = self.embed(x)
         return self.dropout(x)
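
For readers skimming the hunk, here is a minimal runnable sketch of the class as it stands after this commit, followed by a hypothetical usage snippet (the vocab size, dimensions, and variable names below are illustrative, not from the repo). `_weight` is a real, if underscore-prefixed, keyword argument of `nn.Embedding` that installs a given tensor as the layer's initial weight; with `_weight=None` the layer falls back to its default initialization, which is why the explicit `if init_emb is not None: ... nn.Parameter(init_emb)` branch could be dropped.

import torch
import torch.nn as nn

class Embedding(nn.Module):
    """Embedding component."""

    def __init__(self, vocab_size, embed_dim, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
        super(Embedding, self).__init__()
        # _weight=init_emb seeds the weight matrix when init_emb is a
        # [vocab_size, embed_dim] tensor; _weight=None keeps default init,
        # covering both branches of the pre-commit code in one call.
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx, sparse=sparse, _weight=init_emb)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.embed(x)
        return self.dropout(x)

# Illustrative usage: a 100-word vocab, 16-dim vectors, a custom init matrix.
init = torch.randn(100, 16)
emb = Embedding(vocab_size=100, embed_dim=16, init_emb=init, dropout=0.1)
tokens = torch.randint(0, 100, (4, 7))  # LongTensor of shape [batch=4, seq_len=7]
out = emb(tokens)                       # -> torch.Tensor of shape [4, 7, 16]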
