diff --git a/fastNLP/core/collate_fn.py b/fastNLP/core/collate_fn.py
index 7d7f9726..403af270 100644
--- a/fastNLP/core/collate_fn.py
+++ b/fastNLP/core/collate_fn.py
@@ -7,6 +7,8 @@
 from .field import _get_ele_type_and_dim
 from .utils import logger
 from copy import deepcopy
 
+__all__ = ['ConcatCollateFn']
+
 def _check_type(batch_dict, fields):
     if len(fields) == 0:
diff --git a/fastNLP/embeddings/transformers_embedding.py b/fastNLP/embeddings/transformers_embedding.py
index 4d84fefd..46282af7 100644
--- a/fastNLP/embeddings/transformers_embedding.py
+++ b/fastNLP/embeddings/transformers_embedding.py
@@ -15,6 +15,9 @@
 from ..core import logger
 from ..core.vocabulary import Vocabulary
 
+__all__ = ['TransformersEmbedding', 'TransformersWordPieceEncoder']
+
+
 class TransformersEmbedding(ContextualEmbedding):
     r"""
     An Embedding that encodes words with models from transformers. It is recommended to keep the length of the input words within 430 rather than using 512 (this may vary with the pretrained model's parameters). This is because
@@ -280,7 +283,7 @@ class _TransformersWordModel(nn.Module):
                 word = tokenizer.unk_token
             elif vocab.word_count[word]
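
Below is a minimal usage sketch (not part of the diff) of what the new __all__ entries expose: TransformersEmbedding wraps a transformers model/tokenizer pair to encode words from a fastNLP Vocabulary. The model name and the constructor argument order are assumptions for illustration, not taken from the patch.

from fastNLP import Vocabulary
from fastNLP.embeddings.transformers_embedding import TransformersEmbedding
from transformers import AutoModel, AutoTokenizer

# Assumed model/tokenizer pair; any transformers encoder should work here.
model = AutoModel.from_pretrained('bert-base-uncased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Build a small word-level vocabulary to be encoded by the embedding.
vocab = Vocabulary()
vocab.add_word_lst("this is an example sentence".split())

# Constructor signature assumed to be (vocab, model, tokenizer, ...); per the
# docstring above, keep input word lengths within roughly 430 rather than 512.
embed = TransformersEmbedding(vocab, model, tokenizer)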