@@ -374,20 +374,18 @@ class BertEncoder(nn.Module):
         self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)
         if self.num_output_layer + 1 < len(self.layer):
             logger.info(f'The transformer encoder will early exit after layer-{self.num_output_layer} '
-                        f'(start from 0)!')
+                        f'(layer 0 means embedding layer)!')
 
     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
         for idx, layer_module in enumerate(self.layer):
-            if idx > self.num_output_layer:
+            if idx >= self.num_output_layer:
                 break
             hidden_states = layer_module(hidden_states, attention_mask)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
-        if len(all_encoder_layers) == 0:
-            all_encoder_layers.append(hidden_states)
         return all_encoder_layers
 
 
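Editorial note on the hunk above, not part of the patch: under the new convention that layer 0 denotes the embedding output, num_output_layer == k means only the first k transformer layers have to run, which is why the break condition tightens from idx > to idx >=. The len(all_encoder_layers) == 0 guard can be dropped because BertModel.forward (see the later hunks) now always prepends the embedding output before indexing into the list. A minimal sketch with assumed values:

    num_hidden_layers = 12     # illustrative value only, not taken from the patch
    num_output_layer = 1       # caller wants the embedding plus transformer layer 1
    ran = [idx for idx in range(num_hidden_layers) if idx < num_output_layer]
    assert ran == [0]          # exactly one transformer layer runs before the break
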
@@ -445,8 +443,8 @@ class BertModel(nn.Module):
         self.hidden_size = self.config.hidden_size
         self.model_type = 'bert'
         neg_num_output_layer = kwargs.get('neg_num_output_layer', -1)
-        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers - 1)
-        self.num_output_layer = max(neg_num_output_layer + self.config.num_hidden_layers, pos_num_output_layer)
+        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers)
+        self.num_output_layer = max(neg_num_output_layer + 1 + self.config.num_hidden_layers, pos_num_output_layer)
         if hasattr(config, 'sinusoidal_pos_embds'):
             self.model_type = 'distilbert'
         elif 'model_type' in kwargs:
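
The two replaced lines above account for the embedding output now counting as a layer: there are num_hidden_layers + 1 retrievable outputs, so the default pos_num_output_layer grows to num_hidden_layers and a negative index is shifted by one. A worked example with assumed values, not taken from the patch:

    num_hidden_layers = 12                    # assumed config value for illustration
    neg_num_output_layer = -1                 # "last layer" requested via a negative index
    pos_num_output_layer = num_hidden_layers  # new default: all transformer layers
    num_output_layer = max(neg_num_output_layer + 1 + num_hidden_layers, pos_num_output_layer)
    assert num_output_layer == 12             # -1 resolves to layer 12, the 13th output,
                                              # since layer 0 is the embedding output
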
@@ -535,6 +533,7 @@ class BertModel(nn.Module):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        encoded_layers.insert(0, embedding_output)
         sequence_output = encoded_layers[-1]
         if self.model_type != 'distilbert':
             pooled_output = self.pooler(sequence_output)
@@ -542,8 +541,6 @@ class BertModel(nn.Module):
             pooled_output = sequence_output[:, 0]
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
-        else:
-            encoded_layers.insert(0, embedding_output)
         return encoded_layers, pooled_output
 
     @classmethod
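
Taken together, the last two hunks prepend the embedding output unconditionally and before pooling, instead of only when all encoded layers are requested: encoded_layers[0] is now always the embedding output and encoded_layers[-1] is still the deepest transformer layer that was computed. A toy stand-in with plain strings instead of tensors, purely illustrative:

    embedding_output = 'embedding'
    encoded_layers = ['layer_1', 'layer_2']     # what a truncated encoder might return
    encoded_layers.insert(0, embedding_output)  # the patch now does this unconditionally
    assert encoded_layers[0] == 'embedding'     # layer 0 is the embedding output
    assert encoded_layers[-1] == 'layer_2'      # pooling still reads the last layer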