
fix a bug in early exit of bert

commit bffde7857a (tags/v1.0.0alpha)
Yige Xu, 4 years ago
1 changed file with 5 additions and 8 deletions.

fastNLP/modules/encoder/bert.py  (+5, -8)

@@ -374,20 +374,18 @@ class BertEncoder(nn.Module):
         self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)
         if self.num_output_layer + 1 < len(self.layer):
             logger.info(f'The transformer encoder will early exit after layer-{self.num_output_layer} '
-                        f'(start from 0)!')
+                        f'(layer 0 means embedding layer)!')
 
     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
         for idx, layer_module in enumerate(self.layer):
-            if idx > self.num_output_layer:
+            if idx >= self.num_output_layer:
                 break
             hidden_states = layer_module(hidden_states, attention_mask)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
-        if not output_all_encoded_layers:
-            all_encoder_layers.append(hidden_states)
         if len(all_encoder_layers) == 0:
             all_encoder_layers.append(hidden_states)
         return all_encoder_layers
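
For reference, a toy sketch (plain Python, not fastNLP code) of the behaviour the fixed loop is intended to have: num_output_layer now counts the embedding layer as layer 0, so the encoder breaks before running the module at that index, and when no encoder layer ran at all the embedding output itself is kept as the only entry.

# Stand-ins: each "layer" just increments its input value.
layers = [lambda h: h + 1 for _ in range(4)]
num_output_layer = 1               # early exit after layer 1 (layer 0 = embeddings)

hidden_states = 0                  # stand-in for the embedding output
all_encoder_layers = []
for idx, layer in enumerate(layers):
    if idx >= num_output_layer:    # the fixed comparison (was '>')
        break
    hidden_states = layer(hidden_states)
    all_encoder_layers.append(hidden_states)
if len(all_encoder_layers) == 0:   # nothing ran: fall back to the embeddings
    all_encoder_layers.append(hidden_states)

assert all_encoder_layers == [1]   # exactly one encoder layer was executed

With num_output_layer = 0 the loop body never runs and the list falls back to the embedding output.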


@@ -445,8 +443,8 @@ class BertModel(nn.Module):
         self.hidden_size = self.config.hidden_size
         self.model_type = 'bert'
         neg_num_output_layer = kwargs.get('neg_num_output_layer', -1)
-        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers - 1)
-        self.num_output_layer = max(neg_num_output_layer + self.config.num_hidden_layers, pos_num_output_layer)
+        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers)
+        self.num_output_layer = max(neg_num_output_layer + 1 + self.config.num_hidden_layers, pos_num_output_layer)
         if hasattr(config, 'sinusoidal_pos_embds'):
             self.model_type = 'distilbert'
         elif 'model_type' in kwargs:
@@ -535,6 +533,7 @@ class BertModel(nn.Module):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        encoded_layers.insert(0, embedding_output)
         sequence_output = encoded_layers[-1]
         if self.model_type != 'distilbert':
             pooled_output = self.pooler(sequence_output)
@@ -542,8 +541,6 @@ class BertModel(nn.Module):
             pooled_output = sequence_output[:, 0]
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
-        else:
-            encoded_layers.insert(0, embedding_output)
         return encoded_layers, pooled_output
 
     @classmethod
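
To make the revised index arithmetic in BertModel.__init__ concrete, here is a small standalone helper (hypothetical, written only for this note) that mirrors the updated max(...) expression. With layer 0 meaning the embedding layer, the default pos_num_output_layer becomes num_hidden_layers (the last transformer layer), and neg_num_output_layer = -1 resolves to the same index via neg + 1 + num_hidden_layers.

# Hypothetical helper; it only mirrors the expression in the diff above.
def resolve_num_output_layer(num_hidden_layers,
                             neg_num_output_layer=-1,
                             pos_num_output_layer=None):
    if pos_num_output_layer is None:
        pos_num_output_layer = num_hidden_layers        # new default: last layer
    # Negative indices count back from the end: -1 is the last transformer
    # layer, -(num_hidden_layers + 1) is the embedding layer (layer 0).
    return max(neg_num_output_layer + 1 + num_hidden_layers, pos_num_output_layer)

assert resolve_num_output_layer(12) == 12                     # default: last layer
assert resolve_num_output_layer(12, neg_num_output_layer=-13,
                                pos_num_output_layer=0) == 0  # embeddings only

Together with the now unconditional encoded_layers.insert(0, embedding_output), the list returned for output_all_encoded_layers=True has the embedding output at position 0 and the deepest layer actually computed at the last position.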

