|
|
@@ -15,7 +15,7 @@ class Embedding(nn.Module): |
|
|
|
def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0): |
|
|
|
super(Embedding, self).__init__() |
|
|
|
self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse) |
|
|
|
if init_emb: |
|
|
|
if init_emb is not None: |
|
|
|
self.embed.weight = nn.Parameter(init_emb) |
|
|
|
self.dropout = nn.Dropout(dropout) |
|
|
|
|
|
|
|