
Add __all__ attribute

tags/v0.6.0
yh_cc, 4 years ago
commit 89ad195411
6 changed files with 19 additions and 1 deletion
  1. fastNLP/core/collate_fn.py  (+2, -0)
  2. fastNLP/embeddings/transformers_embedding.py  (+5, -1)
  3. fastNLP/models/seq2seq_generator.py  (+3, -0)
  4. fastNLP/models/seq2seq_model.py  (+3, -0)
  5. fastNLP/modules/decoder/seq2seq_decoder.py  (+3, -0)
  6. fastNLP/modules/encoder/seq2seq_encoder.py  (+3, -0)
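
Each file gets the same treatment: an explicit __all__ list declaring the module's public API, which controls what a star-import binds. As a reminder of the semantics, here is a minimal sketch (a hypothetical demo.py, not from this repository):

    # demo.py -- hypothetical module illustrating __all__ semantics
    __all__ = ['ConcatCollateFn']

    class ConcatCollateFn:
        """Exported: listed in __all__."""

    class InternalHelper:
        """Not exported: a star-import of this module will skip this name."""

    # client code:
    #   from demo import *
    #   ConcatCollateFn()    # bound by the star-import
    #   InternalHelper()     # NameError: not in __all__

Note that __all__ only governs star-imports; an explicit "from demo import InternalHelper" still works.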

fastNLP/core/collate_fn.py  (+2, -0)

@@ -7,6 +7,8 @@ from .field import _get_ele_type_and_dim
 from .utils import logger
 from copy import deepcopy
 
+__all__ = ['ConcatCollateFn']
+
 
 def _check_type(batch_dict, fields):
     if len(fields) == 0:


fastNLP/embeddings/transformers_embedding.py  (+5, -1)

@@ -15,6 +15,9 @@ from ..core import logger
 from ..core.vocabulary import Vocabulary
 
 
+__all__ = ['TransformersEmbedding', 'TransformersWordPieceEncoder']
+
+
 class TransformersEmbedding(ContextualEmbedding):
     r"""
     An Embedding that uses models from transformers to encode words. It is recommended to keep the input words within 430 tokens rather than using the full 512 (this may vary with the pretrained model's parameters). This is because
@@ -280,7 +283,7 @@ class _TransformersWordModel(nn.Module):
                 word = tokenizer.unk_token
             elif vocab.word_count[word]<min_freq:
                 word = tokenizer.unk_token
-            word_pieces = self.tokenizer.tokenize(word)
+            word_pieces = self.tokenizer.tokenize(word, add_prefix_space=True)
             word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces)
             word_to_wordpieces.append(word_pieces)
             word_pieces_lengths.append(len(word_pieces))
@@ -453,6 +456,7 @@ class _WordPieceTransformersModel(nn.Module):
         :return:
         """
         kwargs['add_special_tokens'] = kwargs.get('add_special_tokens', True)
+        kwargs['add_prefix_space'] = kwargs.get('add_special_tokens', True)
 
 
         encode_func = partial(self.tokenizer.encode, **kwargs)
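
The add_prefix_space change targets byte-level BPE tokenizers such as RoBERTa and GPT-2, whose vocabularies fold the leading space into the token itself: a word tokenized in isolation gets its sentence-initial pieces, not the pieces it would receive mid-sentence, which is where the embedding actually looks words up. A quick sketch of the difference (assuming a transformers version whose slow RoBERTa tokenizer accepts add_prefix_space as a tokenize() keyword, as the versions contemporary with this commit do; 'roberta-base' is only an example checkpoint):

    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    # Treated as sentence-initial: no leading-space marker on the pieces.
    print(tokenizer.tokenize('world'))
    # Treated as mid-sentence: pieces carry the byte-level space marker,
    # e.g. ['Ġworld'].
    print(tokenizer.tokenize('world', add_prefix_space=True))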




fastNLP/models/seq2seq_generator.py  (+3, -0)

@@ -6,6 +6,9 @@ from .seq2seq_model import Seq2SeqModel
 from ..modules.generator.seq2seq_generator import SequenceGenerator
 
 
+__all__ = ['SequenceGeneratorModel']
+
+
 class SequenceGeneratorModel(nn.Module):
     """
     Wraps a Seq2SeqModel so that it can perform generation tasks
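
Conceptually, such a wrapper couples an encoder-decoder model with a decoding loop. A generic sketch of that pattern in plain PyTorch (hypothetical names and call signatures; this is not fastNLP's SequenceGenerator API):

    import torch

    def greedy_generate(encoder, decoder, src_tokens, bos_id, eos_id, max_len=50):
        """Greedy decoding: repeatedly feed the decoder its own argmax prediction."""
        state = encoder(src_tokens)                  # assumed: returns decoder state
        batch_size = src_tokens.size(0)
        tokens = torch.full((batch_size, 1), bos_id, dtype=torch.long)
        for _ in range(max_len):
            logits = decoder(tokens, state)          # assumed: (batch, len, vocab)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
            if (next_token == eos_id).all():         # every sequence emitted EOS
                break
        return tokens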


fastNLP/models/seq2seq_model.py  (+3, -0)

@@ -12,6 +12,9 @@ from ..modules.decoder.seq2seq_decoder import Seq2SeqDecoder, TransformerSeq2Seq
 from ..modules.encoder.seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
 
 
+__all__ = ['Seq2SeqModel', 'TransformerSeq2SeqModel', 'LSTMSeq2SeqModel']
+
+
 class Seq2SeqModel(nn.Module):
     def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder):
         """


fastNLP/modules/decoder/seq2seq_decoder.py  (+3, -0)

@@ -11,6 +11,9 @@ from ...embeddings.utils import get_embeddings
 from .seq2seq_state import State, LSTMState, TransformerState
 
 
+__all__ = ['Seq2SeqDecoder', 'TransformerSeq2SeqDecoder', 'LSTMSeq2SeqDecoder']
+
+
 class Seq2SeqDecoder(nn.Module):
     """
     Base class for Sequence-to-Sequence decoders. The forward function must be implemented; the remaining functions may be implemented as needed. Each Seq2SeqDecoder should have a corresponding State object
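
The contract described here, implement forward() and pair the decoder with a State object carrying cross-step context, looks roughly like the following hypothetical illustration in plain PyTorch (not fastNLP's actual classes):

    import torch
    from torch import nn

    class State:
        """Holds whatever the decoder needs across steps, e.g. the encoder output."""
        def __init__(self, encoder_output, encoder_mask):
            self.encoder_output = encoder_output
            self.encoder_mask = encoder_mask

    class ToyDecoder(nn.Module):
        def __init__(self, vocab_size, hidden_size):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)
            self.out = nn.Linear(hidden_size, vocab_size)

        def forward(self, tokens, state):
            # Mix token embeddings with a mean-pooled view of the encoder output.
            pooled = state.encoder_output.mean(dim=1, keepdim=True)
            return self.out(self.embed(tokens) + pooled)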


fastNLP/modules/encoder/seq2seq_encoder.py  (+3, -0)

@@ -12,6 +12,9 @@ from ...embeddings import StaticEmbedding
 from ...embeddings.utils import get_embeddings
 
 
+__all__ = ['Seq2SeqEncoder', 'TransformerSeq2SeqEncoder', 'LSTMSeq2SeqEncoder']
+
+
 class Seq2SeqEncoder(nn.Module):
     """
     Base class for all Sequence2Sequence encoders. The forward function must be implemented

