diff --git a/fastNLP/embeddings/embedding.py b/fastNLP/embeddings/embedding.py index a9f228fb..9447c6ad 100644 --- a/fastNLP/embeddings/embedding.py +++ b/fastNLP/embeddings/embedding.py @@ -41,9 +41,9 @@ class Embedding(nn.Module): self.dropout = nn.Dropout(dropout) if not isinstance(self.embed, TokenEmbedding): - if hasattr(self, 'embed_size'): + if hasattr(self.embed, 'embed_size'): self._embed_size = self.embed.embed_size - elif hasattr(self, 'embedding_dim'): + elif hasattr(self.embed, 'embedding_dim'): self._embed_size = self.embed.embedding_dim else: self._embed_size = self.embed.weight.size(1)