From 89ad1954113b485638c8a784d8bd2bf0e0603e78 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Mon, 26 Oct 2020 14:34:33 +0800
Subject: [PATCH] Add __all__ attribute
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/collate_fn.py                   | 2 ++
 fastNLP/embeddings/transformers_embedding.py | 6 +++++-
 fastNLP/models/seq2seq_generator.py          | 3 +++
 fastNLP/models/seq2seq_model.py              | 3 +++
 fastNLP/modules/decoder/seq2seq_decoder.py   | 3 +++
 fastNLP/modules/encoder/seq2seq_encoder.py   | 3 +++
 6 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/fastNLP/core/collate_fn.py b/fastNLP/core/collate_fn.py
index 7d7f9726..403af270 100644
--- a/fastNLP/core/collate_fn.py
+++ b/fastNLP/core/collate_fn.py
@@ -7,6 +7,8 @@ from .field import _get_ele_type_and_dim
 from .utils import logger
 from copy import deepcopy
 
+__all__ = ['ConcatCollateFn']
+
 
 def _check_type(batch_dict, fields):
     if len(fields) == 0:
diff --git a/fastNLP/embeddings/transformers_embedding.py b/fastNLP/embeddings/transformers_embedding.py
index 4d84fefd..46282af7 100644
--- a/fastNLP/embeddings/transformers_embedding.py
+++ b/fastNLP/embeddings/transformers_embedding.py
@@ -15,6 +15,9 @@
 from ..core import logger
 from ..core.vocabulary import Vocabulary
 
+__all__ = ['TransformersEmbedding', 'TransformersWordPieceEncoder']
+
+
 class TransformersEmbedding(ContextualEmbedding):
     r"""
     Embedding that encodes words with a model from transformers. It is recommended to limit the length of the input words to 430 rather than using 512 (this may vary with the pretrained model's parameters). This is because
@@ -280,7 +283,7 @@ class _TransformersWordModel(nn.Module):
             word = tokenizer.unk_token
         elif vocab.word_count[word]
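
Note (not part of the patch): a minimal sketch of what the added __all__
attributes accomplish. The module and class names below mirror the patch,
but the file is a hypothetical stand-in; the behavior shown is standard
Python wildcard-import semantics.

    # mini_collate_fn.py -- hypothetical stand-in for fastNLP/core/collate_fn.py
    from copy import deepcopy  # would leak through `import *` without __all__

    __all__ = ['ConcatCollateFn']  # only listed names are exported by `import *`


    class ConcatCollateFn:
        """Stand-in for the public class the patch exports."""


    def _check_type(batch_dict, fields):
        """Underscore-prefixed helper; excluded from `import *` either way."""

    # Client side:
    #   from mini_collate_fn import *
    #   ConcatCollateFn   # available: listed in __all__
    #   deepcopy          # NameError: no longer re-exported

Without __all__, `from fastNLP.core.collate_fn import *` would also bind
re-imported helpers such as logger and deepcopy into the caller's namespace;
declaring the public names makes the wildcard export explicit.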