From 3b3f550cc5a691d7cd81beea26b3ceeb90475b20 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Thu, 23 May 2019 14:56:05 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=E9=83=A8=E5=88=86?= =?UTF-8?q?=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/modules/encoder/embedding.py | 7 +++++++ fastNLP/modules/utils.py | 6 +++--- setup.py | 3 ++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index f3c1f475..c2dfab65 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -41,3 +41,10 @@ class Embedding(nn.Embedding): """ x = super().forward(x) return self.dropout(x) + + def size(self): + """ + Embedding的大小 + :return: torch.Size() + """ + return self.weight.size() diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index c9a1f682..741429bb 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -74,9 +74,9 @@ def get_embeddings(init_embed): """ 根据输入的init_embed生成nn.Embedding对象。 - :param init_embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 - embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, - 此时就以传入的对象作为embedding + :param init_embed: 可以是 tuple:(num_embedings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入 + nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行,将使用传入的ndarray作为作为Embedding初始 + 化; 传入orch.Tensor, 将使用传入的值作为Embedding初始化。 :return nn.Embedding embeddings: """ if isinstance(init_embed, tuple): diff --git a/setup.py b/setup.py index b7834d8d..49646761 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,8 @@ setup( version='0.4.0', description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', long_description=readme, - license=license, + long_description_content_type='text/markdown', + license='Apache License', author='FudanNLP', python_requires='>=3.6', packages=find_packages(),