
Update some functions for BERT and RoBERTa

tags/v1.0.0alpha
Yige Xu committed 3 years ago
commit 030e0aa3ee
3 changed files with 69 additions and 31 deletions:
  1. fastNLP/embeddings/bert_embedding.py (+28, -14)
  2. fastNLP/embeddings/roberta_embedding.py (+36, -14)
  3. fastNLP/modules/encoder/bert.py (+5, -3)

fastNLP/embeddings/bert_embedding.py (+28, -14)

@@ -110,11 +110,12 @@ class BertEmbedding(ContextualEmbedding):
         if '[CLS]' in vocab:
             self._word_cls_index = vocab['[CLS]']
 
-        min_freq = kwargs.get('min_freq', 1)
+        min_freq = kwargs.pop('min_freq', 1)
         self._min_freq = min_freq
         self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                    pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate)
+                                    pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate,
+                                    **kwargs)
 
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
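The switch from `kwargs.get` to `kwargs.pop` matters because the remaining `kwargs` are now forwarded to `_BertWordModel`: popping removes `min_freq` so it is not passed twice. A minimal standalone sketch of the difference (the helper names here are hypothetical, not fastNLP code):

```python
# Minimal sketch: with get() the key would stay in kwargs and be forwarded again,
# raising "got multiple values for keyword argument 'min_freq'".
def _inner(min_freq=2, **kwargs):
    return min_freq, kwargs

def outer(**kwargs):
    min_freq = kwargs.pop('min_freq', 1)   # pop removes the key before forwarding
    return _inner(min_freq=min_freq, **kwargs)

print(outer(min_freq=3, some_flag=True))  # (3, {'some_flag': True})
```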
@@ -367,32 +368,44 @@ class BertWordPieceEncoder(nn.Module):
 
 class _BertWordModel(nn.Module):
     def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
-                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
+                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
+                 **kwargs):
         super().__init__()
 
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
-            self.layers = list(map(int, layers.split(',')))
+            if layers.lower() == 'all':
+                self.layers = None
+            else:
+                self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
-        assert len(self.layers) > 0, "There is no layer selected!"
 
         neg_num_output_layer = -16384
         pos_num_output_layer = 0
-        for layer in self.layers:
-            if layer < 0:
-                neg_num_output_layer = max(layer, neg_num_output_layer)
-            else:
-                pos_num_output_layer = max(layer, pos_num_output_layer)
+        if self.layers is None:
+            neg_num_output_layer = -1
+        else:
+            for layer in self.layers:
+                if layer < 0:
+                    neg_num_output_layer = max(layer, neg_num_output_layer)
+                else:
+                    pos_num_output_layer = max(layer, pos_num_output_layer)
 
         self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
         self.encoder = BertModel.from_pretrained(model_dir_or_name,
                                                  neg_num_output_layer=neg_num_output_layer,
-                                                 pos_num_output_layer=pos_num_output_layer)
+                                                 pos_num_output_layer=pos_num_output_layer,
+                                                 **kwargs)
         self._max_position_embeddings = self.encoder.config.max_position_embeddings
         # check that encoder_layer_number is reasonable
         encoder_layer_number = len(self.encoder.encoder.layer)
+        if self.layers is None:
+            self.layers = [idx for idx in range(encoder_layer_number + 1)]
+            logger.info(f'Bert Model will return {len(self.layers)} layers (layer-0 '
+                        f'is embedding result): {self.layers}')
+        assert len(self.layers) > 0, "There is no layer selected!"
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
@@ -417,7 +430,7 @@ class _BertWordModel(nn.Module):
                 word = '[PAD]'
             elif index == vocab.unknown_idx:
                 word = '[UNK]'
-            elif vocab.word_count[word]<min_freq:
+            elif vocab.word_count[word] < min_freq:
                 word = '[UNK]'
             word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
             word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces)
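The change above is spacing only; the rule itself is that any vocabulary word rarer than `min_freq` is embedded as the unknown token before word-piece tokenization. A toy sketch of that rule (plain dict, not the fastNLP Vocabulary API):

```python
# Toy illustration of the min_freq rule: words below the threshold fall back to '[UNK]'.
word_count = {'the': 120, 'zyzzyva': 1}
min_freq = 2

def effective_word(word):
    return '[UNK]' if word_count.get(word, 0) < min_freq else word

print([effective_word(w) for w in ['the', 'zyzzyva']])  # ['the', '[UNK]']
```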
@@ -481,14 +494,15 @@ class _BertWordModel(nn.Module):
             token_type_ids = torch.zeros_like(word_pieces)
         # 2. get the hidden states and pool them according to word_pieces
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
-        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
+        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids,
+                                                attention_mask=attn_masks,
                                                 output_all_encoded_layers=True)
         # output_layers = [self.layers]  # len(self.layers) x batch_size x real_word_piece_length x hidden_size
 
         if self.include_cls_sep:
             s_shift = 1
             outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
-                                                 bert_outputs[-1].size(-1))
+                                                     bert_outputs[-1].size(-1))
 
         else:
             s_shift = 0
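Taken together, the bert_embedding.py changes above add a `layers='all'` option that defers layer selection until the encoder is loaded and then returns every layer, with the embedding output as layer 0. A hedged usage sketch, assuming the public `BertEmbedding` and `Vocabulary` APIs and the fastNLP model-name shortcuts are otherwise unchanged:

```python
# Hedged usage sketch: request all encoder layers plus the embedding layer.
# 'en-base-uncased' and the toy vocabulary are placeholders.
from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

vocab = Vocabulary()
vocab.add_word_lst("this is a test".split())
embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='all')
# embed_size should be (num_layers + 1) * hidden_size, e.g. 13 * 768 for a 12-layer base model
print(embed.embed_size)
```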


fastNLP/embeddings/roberta_embedding.py (+36, -14)

@@ -93,12 +93,13 @@ class RobertaEmbedding(ContextualEmbedding):
         if '<s>' in vocab:
             self._word_cls_index = vocab['<s>']
 
-        min_freq = kwargs.get('min_freq', 1)
+        min_freq = kwargs.pop('min_freq', 1)
         self._min_freq = min_freq
 
         self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                        pool_method=pool_method, include_cls_sep=include_cls_sep,
-                                       pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq)
+                                       pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
+                                       **kwargs)
         self.requires_grad = requires_grad
         self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size

@@ -193,33 +194,45 @@ class RobertaEmbedding(ContextualEmbedding):
 
 class _RobertaWordModel(nn.Module):
     def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
-                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
+                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
+                 **kwargs):
         super().__init__()
 
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
-            self.layers = list(map(int, layers.split(',')))
+            if layers.lower() == 'all':
+                self.layers = None
+            else:
+                self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
-        assert len(self.layers) > 0, "There is no layer selected!"
 
         neg_num_output_layer = -16384
         pos_num_output_layer = 0
-        for layer in self.layers:
-            if layer < 0:
-                neg_num_output_layer = max(layer, neg_num_output_layer)
-            else:
-                pos_num_output_layer = max(layer, pos_num_output_layer)
+        if self.layers is None:
+            neg_num_output_layer = -1
+        else:
+            for layer in self.layers:
+                if layer < 0:
+                    neg_num_output_layer = max(layer, neg_num_output_layer)
+                else:
+                    pos_num_output_layer = max(layer, pos_num_output_layer)
 
         self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
         self.encoder = RobertaModel.from_pretrained(model_dir_or_name,
                                                     neg_num_output_layer=neg_num_output_layer,
-                                                    pos_num_output_layer=pos_num_output_layer)
+                                                    pos_num_output_layer=pos_num_output_layer,
+                                                    **kwargs)
         # RobertaEmbedding sets padding_idx to 1 and computes positions in a rather unusual way, hence the -2
         self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
         # check that encoder_layer_number is reasonable
         encoder_layer_number = len(self.encoder.encoder.layer)
+        if self.layers is None:
+            self.layers = [idx for idx in range(encoder_layer_number + 1)]
+            logger.info(f'RoBERTa Model will return {len(self.layers)} layers (layer-0 '
+                        f'is embedding result): {self.layers}')
+        assert len(self.layers) > 0, "There is no layer selected!"
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
@@ -241,7 +254,7 @@ class _RobertaWordModel(nn.Module):
                 word = '<pad>'
             elif index == vocab.unknown_idx:
                 word = '<unk>'
-            elif vocab.word_count[word]<min_freq:
+            elif vocab.word_count[word] < min_freq:
                 word = '<unk>'
             word_pieces = self.tokenizer.tokenize(word)
             word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces)
@@ -265,13 +278,15 @@ class _RobertaWordModel(nn.Module):
         batch_size, max_word_len = words.size()
         word_mask = words.ne(self._word_pad_index)  # positions where the mask is 1 contain a word
         seq_len = word_mask.sum(dim=-1)
-        batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False), 0)  # batch_size x max_len
+        batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
+                                                                               0)  # batch_size x max_len
         word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)  # batch_size
         max_word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # the word-piece length (including padding)
         if max_word_piece_length + 2 > self._max_position_embeddings:
             if self.auto_truncate:
                 word_pieces_lengths = word_pieces_lengths.masked_fill(
-                    word_pieces_lengths + 2 > self._max_position_embeddings, self._max_position_embeddings - 2)
+                    word_pieces_lengths + 2 > self._max_position_embeddings,
+                    self._max_position_embeddings - 2)
             else:
                 raise RuntimeError(
                     "After split words into word pieces, the lengths of word pieces are longer than the "
@@ -290,6 +305,7 @@ class _RobertaWordModel(nn.Module):
                     word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2]
             word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)
             attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)
+        # add <s> and </s>
         word_pieces[:, 0].fill_(self._cls_index)
         batch_indexes = torch.arange(batch_size).to(words)
         word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
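For reference, the `auto_truncate` branch reformatted above clamps any sequence whose word pieces (plus the two special tokens) would overflow the position embeddings. A minimal standalone sketch of that clamping step, with a toy budget instead of the real `self._max_position_embeddings`:

```python
import torch

# Lengths that would need more than the position budget are clamped to budget - 2,
# leaving room for <s> and </s>.
max_position_embeddings = 10
word_pieces_lengths = torch.tensor([4, 9, 12])
word_pieces_lengths = word_pieces_lengths.masked_fill(
    word_pieces_lengths + 2 > max_position_embeddings, max_position_embeddings - 2)
print(word_pieces_lengths)  # tensor([4, 8, 8])
```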
@@ -362,6 +378,12 @@ class _RobertaWordModel(nn.Module):
         return outputs
 
     def save(self, folder):
+        """
+        Save pytorch_model.bin, config.json and vocab.txt to the given folder.
+
+        :param str folder:
+        :return:
+        """
         self.tokenizer.save_pretrained(folder)
         self.encoder.save_pretrained(folder)
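Earlier in this file's diff, a comment explains why `_max_position_embeddings` subtracts 2 for RoBERTa: position ids are offset by `padding_idx + 1`. A small numeric sketch with roberta-base config values (assumed here, not taken from this commit):

```python
# roberta-base ships max_position_embeddings = 514 with padding_idx = 1,
# so only 514 - (padding_idx + 1) = 512 real positions are usable.
padding_idx = 1
max_position_embeddings = 514
usable_positions = max_position_embeddings - (padding_idx + 1)
print(usable_positions)  # 512
```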



fastNLP/modules/encoder/bert.py (+5, -3)

@@ -184,21 +184,23 @@ class DistilBertEmbeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, input_ids, token_type_ids):
+    def forward(self, input_ids, token_type_ids, position_ids=None):
         r"""
         Parameters
         ----------
         input_ids: torch.tensor(bs, max_seq_length)
             The token ids to embed.
         token_type_ids: not used.
+        position_ids: not used.
         Outputs
         -------
         embeddings: torch.tensor(bs, max_seq_length, dim)
             The embedded tokens (plus position embeddings, no token_type embeddings)
         """
         seq_length = input_ids.size(1)
-        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
-        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
+        if position_ids is None:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
 
         word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
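The new `position_ids=None` default keeps the old behaviour when the caller passes nothing and otherwise uses the caller-provided ids. A minimal standalone sketch of that fallback (plain tensors, not the DistilBertEmbeddings module itself):

```python
import torch

def resolve_position_ids(input_ids, position_ids=None):
    # Rebuild position ids from the sequence length only when none are given.
    if position_ids is None:
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids

input_ids = torch.zeros(2, 5, dtype=torch.long)
print(resolve_position_ids(input_ids))        # rows 0..4 repeated for each batch item
print(resolve_position_ids(input_ids).shape)  # torch.Size([2, 5])
```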

