
optimize BertEmbedding and RoBERTaEmbedding

The transformer encoder now exits early when the selected layers do not require the final layer (e.g. when no layer index equals -1).
tags/v1.0.0alpha
Yige Xu committed 4 years ago
commit 9c4f802b6b
4 changed files with 52 additions and 18 deletions:
  1. fastNLP/embeddings/bert_embedding.py     (+18, -6)
  2. fastNLP/embeddings/roberta_embedding.py  (+17, -7)
  3. fastNLP/modules/encoder/bert.py          (+15, -3)
  4. fastNLP/modules/encoder/roberta.py       (+2, -2)

fastNLP/embeddings/bert_embedding.py  (+18, -6)

@@ -93,7 +93,7 @@ class BertEmbedding(ContextualEmbedding):
"""
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

if word_dropout>0:
if word_dropout > 0:
assert vocab.unknown != None, "When word_drop>0, Vocabulary must contain the unknown token."

if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
@@ -370,17 +370,29 @@ class _BertWordModel(nn.Module):
                  include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
         super().__init__()
 
-        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
-        self.encoder = BertModel.from_pretrained(model_dir_or_name)
-        self._max_position_embeddings = self.encoder.config.max_position_embeddings
-        # check that encoder_layer_number is reasonable
-        encoder_layer_number = len(self.encoder.encoder.layer)
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
             self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
+        assert len(self.layers) > 0, "There is no layer selected!"
+
+        neg_num_output_layer = -16384
+        pos_num_output_layer = 0
+        for layer in self.layers:
+            if layer < 0:
+                neg_num_output_layer = max(layer, neg_num_output_layer)
+            else:
+                pos_num_output_layer = max(layer, pos_num_output_layer)
+
+        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = BertModel.from_pretrained(model_dir_or_name,
+                                                 neg_num_output_layer=neg_num_output_layer,
+                                                 pos_num_output_layer=pos_num_output_layer)
+        self._max_position_embeddings = self.encoder.config.max_position_embeddings
+        # check that encoder_layer_number is reasonable
+        encoder_layer_number = len(self.encoder.encoder.layer)
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \

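Note: the embedding only hands the encoder a compact summary of the selected layers, namely the largest negative index and the largest non-negative index. Below is a minimal standalone sketch of that mapping; `summarize_layers` is a hypothetical helper used here for illustration, not a function in fastNLP.

    # Sketch of how `self.layers` collapses into the two kwargs passed to
    # BertModel.from_pretrained. `summarize_layers` is hypothetical, not fastNLP API.
    def summarize_layers(layers):
        neg_num_output_layer = -16384   # sentinel meaning "no negative index selected"
        pos_num_output_layer = 0
        for layer in layers:
            if layer < 0:
                neg_num_output_layer = max(layer, neg_num_output_layer)
            else:
                pos_num_output_layer = max(layer, pos_num_output_layer)
        return neg_num_output_layer, pos_num_output_layer

    print(summarize_layers([-1]))       # (-1, 0): the last layer is needed, so no early exit
    print(summarize_layers([0, 1, 2]))  # (-16384, 2): the encoder can stop after layer 2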

fastNLP/embeddings/roberta_embedding.py  (+17, -7)

@@ -196,20 +196,30 @@ class _RobertaWordModel(nn.Module):
                  include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
         super().__init__()
 
-        self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
-        self.encoder = RobertaModel.from_pretrained(model_dir_or_name)
-        # RobertaEmbedding sets padding_idx to 1 and uses a rather unusual position-id scheme, hence the -2
-        self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
-        # check that encoder_layer_number is reasonable
-        encoder_layer_number = len(self.encoder.encoder.layer)
-
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
             self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
         assert len(self.layers) > 0, "There is no layer selected!"
+
+        neg_num_output_layer = -16384
+        pos_num_output_layer = 0
+        for layer in self.layers:
+            if layer < 0:
+                neg_num_output_layer = max(layer, neg_num_output_layer)
+            else:
+                pos_num_output_layer = max(layer, pos_num_output_layer)
+
+        self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = RobertaModel.from_pretrained(model_dir_or_name,
+                                                    neg_num_output_layer=neg_num_output_layer,
+                                                    pos_num_output_layer=pos_num_output_layer)
+        # RobertaEmbedding sets padding_idx to 1 and uses a rather unusual position-id scheme, hence the -2
+        self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
+        # check that encoder_layer_number is reasonable
+        encoder_layer_number = len(self.encoder.encoder.layer)
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \

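From the user's side the embedding API is unchanged; whether the encoder can exit early is driven entirely by the `layers` argument. A rough usage sketch follows (it assumes fastNLP is installed and the pretrained weights are available or downloadable; the model tag and the toy vocabulary are placeholders).

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("this is a test".split())

    # Only low layers are requested, so after this commit the upper transformer
    # layers are skipped during forward.
    embed_low = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='0,1')

    # The last layer is requested, so the full stack still runs.
    embed_full = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1')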

fastNLP/modules/encoder/bert.py  (+15, -3)

@@ -366,19 +366,28 @@ class BertLayer(nn.Module):


 class BertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, num_output_layer=-1):
         super(BertEncoder, self).__init__()
         layer = BertLayer(config)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        num_output_layer = num_output_layer if num_output_layer >= 0 else (len(self.layer) + num_output_layer)
+        self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)
+        if self.num_output_layer + 1 < len(self.layer):
+            logger.info(f'The transformer encoder will early exit after layer-{self.num_output_layer} '
+                        f'(start from 0)!')
 
     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
-        for layer_module in self.layer:
+        for idx, layer_module in enumerate(self.layer):
+            if idx > self.num_output_layer:
+                break
             hidden_states = layer_module(hidden_states, attention_mask)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
+        if len(all_encoder_layers) == 0:
+            all_encoder_layers.append(hidden_states)
         return all_encoder_layers

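This hunk is the core of the optimization: the constructor resolves and clamps a cut-off layer, and the forward loop breaks as soon as that layer has been computed. A self-contained toy version of the same pattern, for illustration only (not fastNLP code; linear layers and sizes are arbitrary):

    import torch
    from torch import nn

    class ToyEarlyExitEncoder(nn.Module):
        """Toy stand-in for a transformer encoder stack with an early-exit cut-off."""

        def __init__(self, hidden_size=16, num_layers=12, num_output_layer=-1):
            super().__init__()
            self.layer = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
            # Resolve a negative cut-off and clamp it, mirroring the commit.
            if num_output_layer < 0:
                num_output_layer += len(self.layer)
            self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)

        def forward(self, hidden_states, output_all_encoded_layers=True):
            all_encoder_layers = []
            for idx, layer_module in enumerate(self.layer):
                if idx > self.num_output_layer:
                    break                      # early exit: deeper layers never run
                hidden_states = layer_module(hidden_states)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
            if not output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
            return all_encoder_layers

    enc = ToyEarlyExitEncoder(num_output_layer=2)   # only layers 0..2 execute
    print(len(enc(torch.randn(1, 4, 16))))          # 3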

@@ -435,6 +444,9 @@ class BertModel(nn.Module):
         self.config = config
         self.hidden_size = self.config.hidden_size
         self.model_type = 'bert'
+        neg_num_output_layer = kwargs.get('neg_num_output_layer', -1)
+        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers - 1)
+        self.num_output_layer = max(neg_num_output_layer + self.config.num_hidden_layers, pos_num_output_layer)
         if hasattr(config, 'sinusoidal_pos_embds'):
             self.model_type = 'distilbert'
         elif 'model_type' in kwargs:
@@ -445,7 +457,7 @@ class BertModel(nn.Module):
         else:
             self.embeddings = BertEmbeddings(config)
 
-        self.encoder = BertEncoder(config)
+        self.encoder = BertEncoder(config, num_output_layer=self.num_output_layer)
         if self.model_type != 'distilbert':
             self.pooler = BertPooler(config)
         else:

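A negative layer index can only be resolved once `num_hidden_layers` is known from the config, which is why BertModel receives both summaries and combines them into a single cut-off via `max(neg_num_output_layer + num_hidden_layers, pos_num_output_layer)`. Two worked examples, assuming a 12-layer model:

    # Worked examples of the cut-off computed in BertModel.__init__ (12 layers assumed).
    num_hidden_layers = 12

    # layers='-1'  -> neg=-1, pos=0:     cut-off 11, the whole stack still runs
    print(max(-1 + num_hidden_layers, 0))         # 11

    # layers='0,2' -> neg unused, pos=2: cut-off 2, layers 3..11 are skipped
    print(max(-16384 + num_hidden_layers, 2))     # 2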

fastNLP/modules/encoder/roberta.py  (+2, -2)

@@ -64,8 +64,8 @@ class RobertaModel(BertModel):
     undocumented
     """
 
-    def __init__(self, config):
-        super().__init__(config)
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
 
         self.embeddings = RobertaEmbeddings(config)
         self.apply(self.init_bert_weights)

