|
@@ -41,9 +41,9 @@ class Embedding(nn.Module): |
|
|
|
|
|
|
|
|
self.dropout = nn.Dropout(dropout) |
|
|
self.dropout = nn.Dropout(dropout) |
|
|
if not isinstance(self.embed, TokenEmbedding): |
|
|
if not isinstance(self.embed, TokenEmbedding): |
|
|
if hasattr(self, 'embed_size'): |
|
|
|
|
|
|
|
|
if hasattr(self.embed, 'embed_size'): |
|
|
self._embed_size = self.embed.embed_size |
|
|
self._embed_size = self.embed.embed_size |
|
|
elif hasattr(self, 'embedding_dim'): |
|
|
|
|
|
|
|
|
elif hasattr(self.embed, 'embedding_dim'): |
|
|
self._embed_size = self.embed.embedding_dim |
|
|
self._embed_size = self.embed.embedding_dim |
|
|
else: |
|
|
else: |
|
|
self._embed_size = self.embed.weight.size(1) |
|
|
self._embed_size = self.embed.weight.size(1) |
|
|