From e166c119f58d52ca08973f59772702c62bb39d7a Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Tue, 6 Aug 2019 02:01:02 +0800
Subject: [PATCH] bert_embedding: add an auto_truncate parameter; when the
 word pieces are longer than 512, the overflowing encodings are automatically
 padded with zeros
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/embeddings/bert_embedding.py | 29 +++++++++++++++++++++++++---
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index adc205c2..38b8daf2 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -49,10 +49,13 @@ class BertEmbedding(ContextualEmbedding):
     :param bool pooled_cls: whether the returned [CLS] is mapped through the pretrained BertPool; only effective when
         include_cls_sep is True. If the downstream task only uses [CLS] for prediction, this is usually set to True.
     :param bool requires_grad: whether gradients are required to update the Bert weights.
+    :param bool auto_truncate: when the word pieces of a sentence exceed bert's maximum allowed length (usually 512),
+        automatically drop everything after the first 510 word pieces and set the 512th word piece to [SEP]. The encoded
+        results of the truncated part are set to all zeros. Usually only tasks that classify with [CLS] alone set this to True.
     """
     def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
                  pool_method: str='first', word_dropout=0, dropout=0, include_cls_sep: bool=False,
-                 pooled_cls=True, requires_grad: bool=False):
+                 pooled_cls=True, requires_grad: bool=False, auto_truncate: bool=False):
         super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         # check whether model_dir_or_name exists locally and download it if necessary
@@ -69,7 +72,7 @@ class BertEmbedding(ContextualEmbedding):
 
         self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                    pooled_cls=pooled_cls)
+                                    pooled_cls=pooled_cls, auto_truncate=auto_truncate)
 
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size
@@ -202,11 +205,12 @@ class BertWordPieceEncoder(nn.Module):
 
 class _WordBertModel(nn.Module):
     def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first',
-                 include_cls_sep:bool=False, pooled_cls:bool=False):
+                 include_cls_sep:bool=False, pooled_cls:bool=False, auto_truncate:bool=False):
         super().__init__()
 
         self.tokenzier = BertTokenizer.from_pretrained(model_dir)
         self.encoder = BertModel.from_pretrained(model_dir)
+        self._max_position_embeddings = self.encoder.config.max_position_embeddings
         # check that the requested encoder layer numbers are valid
         encoder_layer_number = len(self.encoder.encoder.layer)
         self.layers = list(map(int, layers.split(',')))
@@ -222,6 +226,7 @@ class _WordBertModel(nn.Module):
         self.pool_method = pool_method
         self.include_cls_sep = include_cls_sep
         self.pooled_cls = pooled_cls
+        self.auto_truncate = auto_truncate
 
         # compute the word pieces of every word in the vocab; [CLS] and [SEP] need extra handling
         print("Start generating word pieces for words.")
@@ -290,6 +295,17 @@ class _WordBertModel(nn.Module):
         batch_word_pieces_length = self.word_pieces_lengths[words]  # batch_size x max_len
         word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)
         max_word_piece_length = word_pieces_lengths.max().item()
+        real_max_word_piece_length = max_word_piece_length  # the word piece length before any truncation
+        if max_word_piece_length+2>self._max_position_embeddings:
+            if self.auto_truncate:
+                word_pieces_lengths = word_pieces_lengths.masked_fill(word_pieces_lengths+2>self._max_position_embeddings,
+                                                                      self._max_position_embeddings-2)
+                max_word_piece_length = self._max_position_embeddings-2
+            else:
+                raise RuntimeError("After splitting words into word pieces, the word pieces are longer than the "
+                                   f"maximum allowed sequence length {self._max_position_embeddings} of bert.")
+
+        # +2 because [CLS] and [SEP] need to be added
         word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index)
         word_pieces[:, 0].fill_(self._cls_index)
@@ -300,6 +316,8 @@ class _WordBertModel(nn.Module):
         word_indexes = words.tolist()
         for i in range(batch_size):
             word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]]))
+            if self.auto_truncate and len(word_pieces_i)>self._max_position_embeddings-2:
+                word_pieces_i = word_pieces_i[:self._max_position_embeddings-2]
             word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i)
             attn_masks[i, :len(word_pieces_i)+2].fill_(1)
             # TODO truncate the part that exceeds the maximum length.
@@ -321,6 +339,11 @@ class _WordBertModel(nn.Module):
         batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1)  # batch_size x max_len
         for l_index, l in enumerate(self.layers):
             output_layer = bert_outputs[l]
+            if real_max_word_piece_length > max_word_piece_length:  # the word pieces were actually truncated
+                paddings = output_layer.new_zeros(batch_size,
+                                                  real_max_word_piece_length-max_word_piece_length,
+                                                  output_layer.size(2))
+                output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
             # collapse from word_piece representations to word representations
             truncate_output_layer = output_layer[:, 1:-1]  # remove [CLS] and [SEP]  batch_size x len x hidden_size
             outputs_seq_len = seq_len + s_shift
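Usage sketch (not part of the diff): a minimal example of the new flag, assuming a fastNLP build that already contains this change and downloadable 'en-base-uncased' weights; the sentence and vocabulary below are made up for illustration.

    import torch
    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    # Toy vocabulary built from a made-up sentence; real usage builds it from a dataset.
    words = "this is a very long sentence".split()
    vocab = Vocabulary()
    vocab.add_word_lst(words)

    # With auto_truncate=True, a sentence whose word pieces (plus [CLS]/[SEP]) would exceed
    # max_position_embeddings (usually 512) is cut to the first 510 word pieces, and the
    # embeddings of the truncated tail come back as zeros instead of raising a RuntimeError.
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1',
                          pool_method='first', auto_truncate=True)

    word_ids = torch.LongTensor([[vocab.to_index(w) for w in words]])  # batch_size x max_len
    reps = embed(word_ids)  # batch_size x max_len x embed_size
    print(reps.size())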