
Add a pooled_cls option so that Bert can use the pretrained pooler weights when doing classification

tags/v0.4.10
yh, 6 years ago
commit 3c2e419059
2 changed files with 25 additions and 11 deletions
  1. fastNLP/embeddings/bert_embedding.py (+18, -8)
  2. fastNLP/modules/encoder/bert.py (+7, -3)

fastNLP/embeddings/bert_embedding.py (+18, -8)

@@ -46,11 +46,13 @@ class BertEmbedding(ContextualEmbedding):
     :param bool include_cls_sep: bool, when bert computes the sentence representation it prepends [CLS] and appends [SEP];
         whether to keep these two tokens in the result. Keeping them makes the word embedding output two tokens longer than
         the input, so if this is True the length may not match other embedding types inside a :class::StackEmbedding .
+    :param bool pooled_cls: whether the returned [CLS] is mapped through the pretrained BertPool; only effective when
+        include_cls_sep is set. If the downstream task only uses [CLS] for prediction, this should usually be True.
     :param bool requires_grad: whether gradients are needed in order to update Bert's weights.
     """
     def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
-                 pool_method: str='first', word_dropout=0, dropout=0, requires_grad: bool=False,
-                 include_cls_sep: bool=False):
+                 pool_method: str='first', word_dropout=0, dropout=0, include_cls_sep: bool=False,
+                 pooled_cls=True, requires_grad: bool=False):
         super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         # Check whether model_dir_or_name exists and download the model if necessary
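
A minimal usage sketch of the new option for a [CLS]-based classification setup. The toy vocabulary, the sentence, and the linear head are illustrative assumptions, not part of this commit, and constructing the embedding downloads the pretrained weights:

    import torch
    from fastNLP import Vocabulary
    from fastNLP.embeddings.bert_embedding import BertEmbedding

    # Toy vocabulary; the words are placeholders.
    vocab = Vocabulary()
    vocab.add_word_lst("this movie is great".split())

    # include_cls_sep=True keeps [CLS]/[SEP] in the output; pooled_cls=True asks for the
    # returned [CLS] to be mapped through the pretrained BertPool, which is what a
    # classification head is normally trained on.
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1',
                          include_cls_sep=True, pooled_cls=True, requires_grad=False)

    words = torch.LongTensor([[vocab.to_index(w) for w in "this movie is great".split()]])
    reps = embed(words)                    # (batch, seq_len + 2, hidden): [CLS] ... [SEP]
    cls_vec = reps[:, 0]                   # the (pooled) [CLS] representation
    logits = torch.nn.Linear(reps.size(-1), 2)(cls_vec)   # hypothetical 2-class head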
@@ -66,7 +68,8 @@ class BertEmbedding(ContextualEmbedding):
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
 
         self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
-                                    pool_method=pool_method, include_cls_sep=include_cls_sep)
+                                    pool_method=pool_method, include_cls_sep=include_cls_sep,
+                                    pooled_cls=pooled_cls)
 
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size
@@ -119,10 +122,12 @@ class BertWordPieceEncoder(nn.Module):
 
     :param str model_dir_or_name: the directory containing the model, or the model's name. Default is ``en-base-uncased``
     :param str layers: which layers make up the final representation. Layer indices are separated by ',' and negative numbers index from the last layers.
+    :param bool pooled_cls: whether the sentence-initial [CLS] that is returned is mapped through the pretrained BertPool;
+        only effective when include_cls_sep is set. If the downstream task only uses [CLS] for prediction, this should usually be True.
     :param bool requires_grad: whether gradients are needed.
     """
     def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1',
-                 requires_grad: bool=False):
+                 pooled_cls: bool = False, requires_grad: bool=False):
         super().__init__()
         PRETRAIN_URL = _get_base_url('bert')
@@ -136,7 +141,7 @@
         else:
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
 
-        self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers)
+        self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls)
         self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
         self.requires_grad = requires_grad

@@ -187,7 +192,8 @@
 
 
 class _WordBertModel(nn.Module):
-    def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first', include_cls_sep:bool=False):
+    def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first',
+                 include_cls_sep:bool=False, pooled_cls:bool=False):
         super().__init__()
 
         self.tokenzier = BertTokenizer.from_pretrained(model_dir)
@@ -206,6 +212,7 @@ class _WordBertModel(nn.Module):
         assert pool_method in ('avg', 'max', 'first', 'last')
         self.pool_method = pool_method
         self.include_cls_sep = include_cls_sep
+        self.pooled_cls = pooled_cls
 
         # Work out the word pieces for every word in the vocab; [CLS] and [SEP] need extra handling
         print("Start to generating word pieces for word.")
@@ -289,7 +296,7 @@ class _WordBertModel(nn.Module):
             # TODO truncate the part that exceeds the maximum length.
         # 2. Get the hidden states and pool them per word according to word_pieces
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
-        bert_outputs, _ = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
+        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
                                        output_all_encoded_layers=True)
         # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size

@@ -327,7 +334,10 @@ class _WordBertModel(nn.Module):
                         start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j+1]
                         outputs[l_index, i, j+s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2)
             if self.include_cls_sep:
-                outputs[l_index, :, 0] = output_layer[:, 0]
+                if l==len(bert_outputs) and self.pooled_cls:
+                    outputs[l_index, :, 0] = pooled_cls
+                else:
+                    outputs[l_index, :, 0] = output_layer[:, 0]
                 outputs[l_index, batch_indexes, seq_len+s_shift] = output_layer[batch_indexes, seq_len+s_shift]
         # 3. The final embedding result
         return outputs
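
The substitution above is plain tensor indexing over the per-layer outputs. A self-contained sketch with dummy tensors; the shapes are arbitrary, and the layer test here treats both -1 and the last positive index as the top layer, which is one way to express the intent of the hunk:

    import torch

    batch, max_len, hidden, n_layers = 2, 6, 8, 4
    # Stand-ins for the encoder outputs: one hidden-state tensor per layer plus the
    # BertPool output computed from the final layer's [CLS] token.
    bert_outputs = [torch.randn(batch, max_len, hidden) for _ in range(n_layers)]
    pooled_cls = torch.randn(batch, hidden)

    layers = [-1]                              # layers kept by the embedding
    outputs = torch.zeros(len(layers), batch, max_len, hidden)
    for l_index, l in enumerate(layers):
        output_layer = bert_outputs[l]
        outputs[l_index] = output_layer
        # For the top layer, overwrite position 0 ([CLS]) with the pooled vector.
        if l in (len(bert_outputs) - 1, -1):
            outputs[l_index, :, 0] = pooled_cls
    assert torch.equal(outputs[0, :, 0], pooled_cls)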


fastNLP/modules/encoder/bert.py (+7, -3)

@@ -848,7 +848,7 @@ class _WordPieceBertModel(nn.Module):
 
     """
 
-    def __init__(self, model_dir: str, layers: str = '-1'):
+    def __init__(self, model_dir: str, layers: str = '-1', pooled_cls:bool=False):
         super().__init__()
 
         self.tokenzier = BertTokenizer.from_pretrained(model_dir)
@@ -867,6 +867,7 @@ class _WordPieceBertModel(nn.Module):
         self._cls_index = self.tokenzier.vocab['[CLS]']
         self._sep_index = self.tokenzier.vocab['[SEP]']
         self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed when generating word_piece
+        self.pooled_cls = pooled_cls
 
     def index_dataset(self, *datasets, field_name):
         """
@@ -909,10 +910,13 @@ class _WordPieceBertModel(nn.Module):
         batch_size, max_len = word_pieces.size()
 
         attn_masks = word_pieces.ne(self._wordpiece_pad_index)
-        bert_outputs, _ = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
+        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
                                                 output_all_encoded_layers=True)
         # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size
         outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1)))
         for l_index, l in enumerate(self.layers):
-            outputs[l_index] = bert_outputs[l]
+            bert_output = bert_outputs[l]
+            if l==len(bert_outputs) and self.pooled_cls:
+                bert_output[:, 0] = pooled_cls
+            outputs[l_index] = bert_output
         return outputs
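
What pooled_cls ultimately returns is the output of BERT's pooler: a learned dense layer plus tanh applied to the final layer's [CLS] hidden state. A schematic stand-in, not the module defined in bert.py itself, just the standard construction it follows:

    import torch
    from torch import nn

    class PoolerSketch(nn.Module):
        """Schematic BERT pooler: dense + tanh over the [CLS] hidden state."""
        def __init__(self, hidden_size: int):
            super().__init__()
            self.dense = nn.Linear(hidden_size, hidden_size)
            self.activation = nn.Tanh()

        def forward(self, last_hidden_state: torch.Tensor) -> torch.Tensor:
            # last_hidden_state: (batch, seq_len, hidden); pool by taking the first token.
            return self.activation(self.dense(last_hidden_state[:, 0]))

    pooled = PoolerSketch(hidden_size=8)(torch.randn(2, 6, 8))   # shape (2, 8)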
