@@ -15,6 +15,9 @@ from ..core import logger
 from ..core.vocabulary import Vocabulary
 
 
 __all__ = ['TransformersEmbedding', 'TransformersWordPieceEncoder']
 
 
 class TransformersEmbedding(ContextualEmbedding):
     r"""
     An Embedding that encodes words with a model from transformers. It is recommended to keep the input words within 430 rather than using 512 (this may vary with the pretrained model's parameters). This is because
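Note on the 430-word guideline in the docstring above: each word may expand into several word pieces, so a 512-word input can overflow the model's 512-token limit once tokenized and wrapped in special tokens. A minimal sketch of the expansion, assuming a bert-base-uncased checkpoint (the model name and sample words are illustrative, not part of this patch):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # assumed checkpoint
    words = ['the', 'unaffordable', 'electrocardiogram']
    pieces = [tokenizer.tokenize(w) for w in words]
    print(pieces)  # rare words split into several '##'-prefixed word pieces
    print(sum(len(p) for p in pieces))  # 3 words can become 6+ pieces; [CLS]/[SEP] add to the budget too
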
@@ -280,7 +283,7 @@ class _TransformersWordModel(nn.Module):
                 word = tokenizer.unk_token
             elif vocab.word_count[word]<min_freq:
                 word = tokenizer.unk_token
-            word_pieces = self.tokenizer.tokenize(word)
+            word_pieces = self.tokenizer.tokenize(word, add_prefix_space=True)
             word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces)
             word_to_wordpieces.append(word_pieces)
             word_pieces_lengths.append(len(word_pieces))
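For byte-level BPE tokenizers (GPT-2, RoBERTa), add_prefix_space=True tokenizes a word as it would appear mid-sentence rather than sentence-initially, which is what this per-word vocabulary loop needs. A minimal sketch of the difference, assuming a slow roberta-base tokenizer (illustrative only, not part of this patch):

    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')  # assumed checkpoint
    print(tokenizer.tokenize('world'))                         # e.g. ['world'], the sentence-initial form
    print(tokenizer.tokenize('world', add_prefix_space=True))  # e.g. ['Ġworld'], the mid-sentence form
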
@@ -453,6 +456,7 @@ class _WordPieceTransformersModel(nn.Module):
         :return:
         """
         kwargs['add_special_tokens'] = kwargs.get('add_special_tokens', True)
+        kwargs['add_prefix_space'] = kwargs.get('add_prefix_space', True)
 
         encode_func = partial(self.tokenizer.encode, **kwargs)
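With the defaults above, encode_func adds special tokens and a prefix space unless the caller overrides either kwarg. A minimal usage sketch of the resulting partial, assuming a slow roberta-base tokenizer (names are illustrative, not part of this patch):

    from functools import partial
    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')  # assumed checkpoint
    kwargs = {'add_special_tokens': True, 'add_prefix_space': True}  # mirrors the defaults set above
    encode_func = partial(tokenizer.encode, **kwargs)
    print(encode_func('hello world'))  # token ids wrapped in <s> ... </s> since add_special_tokens=True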