Browse Source

conflict merge

tags/v0.4.10
yh_cc 5 years ago
parent
commit
5d0877583e
1 changed files with 4 additions and 3 deletions
  1. +4
    -3
      fastNLP/modules/encoder/embedding.py

+ 4
- 3
fastNLP/modules/encoder/embedding.py View File

@@ -521,7 +521,7 @@ class BertEmbedding(ContextualEmbedding):


:param fastNLP.Vocabulary vocab: 词表
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为``en-base-uncased``
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased``.
:param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层
:param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces
中计算得到它对应的表示。支持``last``, ``first``, ``avg``, ``max``。
@@ -572,7 +572,7 @@ class BertEmbedding(ContextualEmbedding):
计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要
删除这两个token的表示。

:param words: batch_size x max_len
:param torch.LongTensor words: [batch_size, max_len]
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
"""
outputs = self._get_sent_reprs(words)
@@ -891,7 +891,8 @@ class StackEmbedding(TokenEmbedding):

Example::

>>>
>>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
>>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)


:param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致


Loading…
Cancel
Save