@@ -7,6 +7,8 @@ from .field import _get_ele_type_and_dim | |||||
from .utils import logger | from .utils import logger | ||||
from copy import deepcopy | from copy import deepcopy | ||||
__all__ = ['ConcatCollateFn'] | |||||
def _check_type(batch_dict, fields): | def _check_type(batch_dict, fields): | ||||
if len(fields) == 0: | if len(fields) == 0: | ||||
@@ -15,6 +15,9 @@ from ..core import logger | |||||
from ..core.vocabulary import Vocabulary | from ..core.vocabulary import Vocabulary | ||||
__all__ = ['TransformersEmbedding', 'TransformersWordPieceEncoder'] | |||||
class TransformersEmbedding(ContextualEmbedding): | class TransformersEmbedding(ContextualEmbedding): | ||||
r""" | r""" | ||||
使用transformers中的模型对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于 | 使用transformers中的模型对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于 | ||||
@@ -280,7 +283,7 @@ class _TransformersWordModel(nn.Module): | |||||
word = tokenizer.unk_token | word = tokenizer.unk_token | ||||
elif vocab.word_count[word]<min_freq: | elif vocab.word_count[word]<min_freq: | ||||
word = tokenizer.unk_token | word = tokenizer.unk_token | ||||
word_pieces = self.tokenizer.tokenize(word) | |||||
word_pieces = self.tokenizer.tokenize(word, add_prefix_space=True) | |||||
word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces) | word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces) | ||||
word_to_wordpieces.append(word_pieces) | word_to_wordpieces.append(word_pieces) | ||||
word_pieces_lengths.append(len(word_pieces)) | word_pieces_lengths.append(len(word_pieces)) | ||||
@@ -453,6 +456,7 @@ class _WordPieceTransformersModel(nn.Module): | |||||
:return: | :return: | ||||
""" | """ | ||||
kwargs['add_special_tokens'] = kwargs.get('add_special_tokens', True) | kwargs['add_special_tokens'] = kwargs.get('add_special_tokens', True) | ||||
kwargs['add_prefix_space'] = kwargs.get('add_prefix_space', True) | |||||
encode_func = partial(self.tokenizer.encode, **kwargs) | encode_func = partial(self.tokenizer.encode, **kwargs) | ||||
@@ -6,6 +6,9 @@ from .seq2seq_model import Seq2SeqModel | |||||
from ..modules.generator.seq2seq_generator import SequenceGenerator | from ..modules.generator.seq2seq_generator import SequenceGenerator | ||||
__all__ = ['SequenceGeneratorModel'] | |||||
class SequenceGeneratorModel(nn.Module): | class SequenceGeneratorModel(nn.Module): | ||||
""" | """ | ||||
用于封装Seq2SeqModel使其可以做生成任务 | 用于封装Seq2SeqModel使其可以做生成任务 | ||||
@@ -12,6 +12,9 @@ from ..modules.decoder.seq2seq_decoder import Seq2SeqDecoder, TransformerSeq2Seq | |||||
from ..modules.encoder.seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | from ..modules.encoder.seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | ||||
__all__ = ['Seq2SeqModel', 'TransformerSeq2SeqModel', 'LSTMSeq2SeqModel'] | |||||
class Seq2SeqModel(nn.Module): | class Seq2SeqModel(nn.Module): | ||||
def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder): | def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder): | ||||
""" | """ | ||||
@@ -11,6 +11,9 @@ from ...embeddings.utils import get_embeddings | |||||
from .seq2seq_state import State, LSTMState, TransformerState | from .seq2seq_state import State, LSTMState, TransformerState | ||||
__all__ = ['Seq2SeqDecoder', 'TransformerSeq2SeqDecoder', 'LSTMSeq2SeqDecoder'] | |||||
class Seq2SeqDecoder(nn.Module): | class Seq2SeqDecoder(nn.Module): | ||||
""" | """ | ||||
Sequence-to-Sequence Decoder的基类。一定需要实现forward函数,剩下的函数根据需要实现。每个Seq2SeqDecoder都应该有相应的State对象 | Sequence-to-Sequence Decoder的基类。一定需要实现forward函数,剩下的函数根据需要实现。每个Seq2SeqDecoder都应该有相应的State对象 | ||||
@@ -12,6 +12,9 @@ from ...embeddings import StaticEmbedding | |||||
from ...embeddings.utils import get_embeddings | from ...embeddings.utils import get_embeddings | ||||
__all__ = ['Seq2SeqEncoder', 'TransformerSeq2SeqEncoder', 'LSTMSeq2SeqEncoder'] | |||||
class Seq2SeqEncoder(nn.Module): | class Seq2SeqEncoder(nn.Module): | ||||
""" | """ | ||||
所有Sequence2Sequence Encoder的基类。需要实现forward函数 | 所有Sequence2Sequence Encoder的基类。需要实现forward函数 | ||||