
bert_embedding: add an auto_truncate parameter; when the word-piece length exceeds 512, the encoded output beyond the limit is automatically padded with 0
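A minimal usage sketch of the new flag. The import paths and the Vocabulary calls are assumptions about the fastNLP API; only the BertEmbedding constructor signature comes from the diff below.

# assumed import paths; the BertEmbedding signature follows the diff below
from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

vocab = Vocabulary()
vocab.add_word_lst("a very long document whose word pieces may exceed 512".split())  # assumed API

# With auto_truncate=True, over-long word-piece sequences are cut down to BERT's
# maximum length before encoding, and the part beyond the limit is encoded as zeros.
embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', auto_truncate=True)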

tags/v0.4.10
yh_cc 5 years ago
parent commit e166c119f5
1 changed file with 26 additions and 3 deletions
  1. +26 -3  fastNLP/embeddings/bert_embedding.py

fastNLP/embeddings/bert_embedding.py

@@ -49,10 +49,13 @@ class BertEmbedding(ContextualEmbedding):
     :param bool pooled_cls: whether the returned [CLS] is mapped through the pretrained BertPool; only effective when
         include_cls_sep is set. If the downstream task only uses [CLS] for prediction, this is usually True.
     :param bool requires_grad: whether gradients are needed so that BERT's weights can be updated.
+    :param bool auto_truncate: when the word pieces of a sentence exceed BERT's maximum allowed length (usually 512),
+        automatically drop everything after the first 510 word pieces and set the 512th word piece to [SEP]; the encode
+        results of the part beyond the limit are all set to zero. Usually only tasks that classify with [CLS] alone set this to True.
     """
     def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
                  pool_method: str='first', word_dropout=0, dropout=0, include_cls_sep: bool=False,
-                 pooled_cls=True, requires_grad: bool=False):
+                 pooled_cls=True, requires_grad: bool=False, auto_truncate:bool=False):
         super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

         # check whether the model exists according to model_dir_or_name and download it if needed
@@ -69,7 +72,7 @@ class BertEmbedding(ContextualEmbedding):

         self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                    pooled_cls=pooled_cls)
+                                    pooled_cls=pooled_cls, auto_truncate=auto_truncate)

         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size
@@ -202,11 +205,12 @@ class BertWordPieceEncoder(nn.Module):

 class _WordBertModel(nn.Module):
     def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first',
-                 include_cls_sep:bool=False, pooled_cls:bool=False):
+                 include_cls_sep:bool=False, pooled_cls:bool=False, auto_truncate:bool=False):
         super().__init__()

         self.tokenzier = BertTokenizer.from_pretrained(model_dir)
         self.encoder = BertModel.from_pretrained(model_dir)
+        self._max_position_embeddings = self.encoder.config.max_position_embeddings
         # check that encoder_layer_number is valid
         encoder_layer_number = len(self.encoder.encoder.layer)
         self.layers = list(map(int, layers.split(',')))
@@ -222,6 +226,7 @@ class _WordBertModel(nn.Module):
         self.pool_method = pool_method
         self.include_cls_sep = include_cls_sep
         self.pooled_cls = pooled_cls
+        self.auto_truncate = auto_truncate

         # compute the word pieces for every word in the vocab; [CLS] and [SEP] also need to be handled
         print("Start to generating word pieces for word.")
@@ -290,6 +295,17 @@ class _WordBertModel(nn.Module):
         batch_word_pieces_length = self.word_pieces_lengths[words]  # batch_size x max_len
         word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)
         max_word_piece_length = word_pieces_lengths.max().item()
+        real_max_word_piece_length = max_word_piece_length  # the word-piece length before truncation
+        if max_word_piece_length+2>self._max_position_embeddings:
+            if self.auto_truncate:
+                word_pieces_lengths = word_pieces_lengths.masked_fill(word_pieces_lengths+2>self._max_position_embeddings,
+                                                                      self._max_position_embeddings-2)
+                max_word_piece_length = self._max_position_embeddings-2
+            else:
+                raise RuntimeError("After splitting words into word pieces, the word pieces are longer than the "
+                                   f"maximum allowed sequence length: {self._max_position_embeddings} of bert.")
+
+
         # +2 because [CLS] and [SEP] need to be added
         word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index)
         word_pieces[:, 0].fill_(self._cls_index)
@@ -300,6 +316,8 @@ class _WordBertModel(nn.Module):
         word_indexes = words.tolist()
         for i in range(batch_size):
             word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]]))
+            if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2:
+                word_pieces_i = word_pieces_i[:self._max_position_embeddings-2]
             word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i)
             attn_masks[i, :len(word_pieces_i)+2].fill_(1)
         # TODO truncate the part that exceeds the limit.
@@ -321,6 +339,11 @@ class _WordBertModel(nn.Module):
         batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1)  # batch_size x max_len
         for l_index, l in enumerate(self.layers):
             output_layer = bert_outputs[l]
+            if real_max_word_piece_length > max_word_piece_length:  # if the batch was actually truncated
+                paddings = output_layer.new_zeros(batch_size,
+                                                  real_max_word_piece_length-max_word_piece_length,
+                                                  output_layer.size(2))
+                output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
             # collapse from word-piece representations to word representations
             truncate_output_layer = output_layer[:, 1:-1]  # remove [CLS] and [SEP]; batch_size x len x hidden_size
             outputs_seq_len = seq_len + s_shift
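The masked_fill added in the forward pass above caps every over-long sentence at 510 word pieces so the batch buffers are sized within BERT's limit. A standalone sketch of that clamping step with made-up lengths (plain torch, not fastNLP code):

import torch

# hypothetical per-sentence word-piece lengths for a batch
word_pieces_lengths = torch.tensor([120, 511, 600])
max_position_embeddings = 512  # BERT's limit, including [CLS] and [SEP]

# sentences whose length + 2 ([CLS]/[SEP]) would exceed the limit are clamped to 510,
# mirroring the masked_fill call added in the diff
word_pieces_lengths = word_pieces_lengths.masked_fill(
    word_pieces_lengths + 2 > max_position_embeddings,
    max_position_embeddings - 2)
print(word_pieces_lengths)  # tensor([120, 510, 510])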


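The per-layer zero padding in the last hunk restores the untruncated word-piece length so that the word-level pooling indices (cumulative sums of the untruncated per-word lengths) stay in range; positions past the truncation point simply contribute zero vectors. A standalone sketch with hypothetical sizes (plain torch, not fastNLP code):

import torch

batch_size, hidden_size = 2, 768
max_word_piece_length = 510        # length actually fed to BERT (without [CLS]/[SEP])
real_max_word_piece_length = 600   # untruncated word-piece length of the batch

# stand-in for one layer of BERT output after truncation (+2 for [CLS] and [SEP])
output_layer = torch.randn(batch_size, max_word_piece_length + 2, hidden_size)

if real_max_word_piece_length > max_word_piece_length:
    # append zero vectors so positions past the truncation point exist but contribute nothing
    paddings = output_layer.new_zeros(batch_size,
                                      real_max_word_piece_length - max_word_piece_length,
                                      hidden_size)
    output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()

assert output_layer.size(1) == real_max_word_piece_length + 2
assert torch.all(output_layer[:, -1] == 0)  # the padded tail is all zeros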