@@ -81,7 +81,7 @@ __all__ = [ | |||||
'logger', | 'logger', | ||||
"init_logger_dist", | "init_logger_dist", | ||||
] | ] | ||||
__version__ = '0.5.5' | |||||
__version__ = '0.5.6' | |||||
import sys | import sys | ||||
@@ -70,6 +70,7 @@ class BertForSequenceClassification(BaseModel): | |||||
def forward(self, words): | def forward(self, words): | ||||
r""" | r""" | ||||
输入为 [[w1, w2, w3, ...], [...]], BERTEmbedding会在开头和结尾额外加入[CLS]与[SEP] | |||||
:param torch.LongTensor words: [batch_size, seq_len] | :param torch.LongTensor words: [batch_size, seq_len] | ||||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels] | :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels] | ||||
""" | """ | ||||
@@ -115,6 +116,8 @@ class BertForSentenceMatching(BaseModel): | |||||
def forward(self, words): | def forward(self, words): | ||||
r""" | r""" | ||||
输入words的格式为 [sent1] + [SEP] + [sent2](BertEmbedding会在开头加入[CLS]和在结尾加入[SEP]),输出为batch_size x num_labels | |||||
:param torch.LongTensor words: [batch_size, seq_len] | :param torch.LongTensor words: [batch_size, seq_len] | ||||
:return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels] | :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.Tensor [batch_size, num_labels] | ||||
""" | """ | ||||
@@ -247,6 +250,10 @@ class BertForQuestionAnswering(BaseModel): | |||||
def forward(self, words): | def forward(self, words): | ||||
r""" | r""" | ||||
输入words为question + [SEP] + [paragraph],BERTEmbedding在之后会额外加入开头的[CLS]和结尾的[SEP]. note: | |||||
如果BERTEmbedding中include_cls_sep=True,则输出的start和end index相对输入words会增加一位;如果为BERTEmbedding中 | |||||
include_cls_sep=False, 则输出start和end index的位置与输入words的顺序完全一致 | |||||
:param torch.LongTensor words: [batch_size, seq_len] | :param torch.LongTensor words: [batch_size, seq_len] | ||||
:return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len + 2] | :return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len + 2] | ||||
""" | """ | ||||
@@ -473,6 +473,17 @@ class BertModel(nn.Module): | |||||
module.bias.data.zero_() | module.bias.data.zero_() | ||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): | ||||
""" | |||||
:param torch.LongTensor input_ids: bsz x max_len的输入id | |||||
:param torch.LongTensor token_type_ids: bsz x max_len,如果不输入认为全为0,一般第一个sep(含)及以前为0, 一个sep之后为1 | |||||
:param attention_mask: 需要attend的为1,不需要为0 | |||||
:param bool output_all_encoded_layers: 是否输出所有层,默认输出token embedding(包含bpe, position以及type embedding) | |||||
及每一层的hidden states。如果为False,只输出最后一层的结果 | |||||
:return: encode_layers: 如果output_all_encoded_layers为True,返回list(共num_layers+1个元素),每个元素为 | |||||
bsz x max_len x hidden_size否则返回bsz x max_len x hidden_size的tensor; | |||||
pooled_output: bsz x hidden_size为cls的表示,可以用于句子的分类 | |||||
""" | |||||
if attention_mask is None: | if attention_mask is None: | ||||
attention_mask = torch.ones_like(input_ids) | attention_mask = torch.ones_like(input_ids) | ||||
if token_type_ids is None: | if token_type_ids is None: | ||||
@@ -504,7 +515,8 @@ class BertModel(nn.Module): | |||||
pooled_output = sequence_output[:, 0] | pooled_output = sequence_output[:, 0] | ||||
if not output_all_encoded_layers: | if not output_all_encoded_layers: | ||||
encoded_layers = encoded_layers[-1] | encoded_layers = encoded_layers[-1] | ||||
encoded_layers.insert(0, embedding_output) | |||||
else: | |||||
encoded_layers.insert(0, embedding_output) | |||||
return encoded_layers, pooled_output | return encoded_layers, pooled_output | ||||
@classmethod | @classmethod | ||||
@@ -16,7 +16,7 @@ print(pkgs) | |||||
setup( | setup( | ||||
name='FastNLP', | name='FastNLP', | ||||
version='0.5.0', | |||||
version='0.5.6', | |||||
url='https://github.com/fastnlp/fastNLP', | url='https://github.com/fastnlp/fastNLP', | ||||
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | ||||
long_description=readme, | long_description=readme, | ||||