diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index 28c47eb6..f304073d 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -374,20 +374,18 @@ class BertEncoder(nn.Module):
         self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)
         if self.num_output_layer + 1 < len(self.layer):
             logger.info(f'The transformer encoder will early exit after layer-{self.num_output_layer} '
-                        f'(start from 0)!')
+                        f'(layer 0 means embedding layer)!')

     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
         for idx, layer_module in enumerate(self.layer):
-            if idx > self.num_output_layer:
+            if idx >= self.num_output_layer:
                 break
             hidden_states = layer_module(hidden_states, attention_mask)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
-        if len(all_encoder_layers) == 0:
-            all_encoder_layers.append(hidden_states)
         return all_encoder_layers


@@ -445,8 +443,8 @@ class BertModel(nn.Module):
         self.hidden_size = self.config.hidden_size
         self.model_type = 'bert'
         neg_num_output_layer = kwargs.get('neg_num_output_layer', -1)
-        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers - 1)
-        self.num_output_layer = max(neg_num_output_layer + self.config.num_hidden_layers, pos_num_output_layer)
+        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers)
+        self.num_output_layer = max(neg_num_output_layer + 1 + self.config.num_hidden_layers, pos_num_output_layer)
         if hasattr(config, 'sinusoidal_pos_embds'):
             self.model_type = 'distilbert'
         elif 'model_type' in kwargs:
@@ -535,6 +533,7 @@ class BertModel(nn.Module):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        encoded_layers.insert(0, embedding_output)
         sequence_output = encoded_layers[-1]
         if self.model_type != 'distilbert':
             pooled_output = self.pooler(sequence_output)
@@ -542,8 +541,6 @@
             pooled_output = sequence_output[:, 0]
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
-        else:
-            encoded_layers.insert(0, embedding_output)
         return encoded_layers, pooled_output

     @classmethod
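
Note (not part of the patch): a minimal usage sketch of the numbering convention the change above establishes, where index 0 of the returned encoded_layers is the embedding output and index k is the output of the k-th transformer block. The checkpoint path and token ids below are placeholders, and the call pattern assumes BertModel.from_pretrained and forward as they appear in this file. With the patched defaults (neg_num_output_layer=-1, pos_num_output_layer=num_hidden_layers), num_output_layer = max(-1 + 1 + num_hidden_layers, num_hidden_layers) = num_hidden_layers, so no transformer block is skipped.

    import torch
    from fastNLP.modules.encoder.bert import BertModel

    # Placeholder checkpoint directory; illustration only, not a tested invocation.
    model = BertModel.from_pretrained('path/to/bert-base')

    input_ids = torch.LongTensor([[101, 2023, 2003, 102]])  # toy token ids
    encoded_layers, pooled_output = model(input_ids, output_all_encoded_layers=True)

    # The embedding output is now prepended unconditionally, so the list holds
    # num_hidden_layers + 1 tensors and "layer 0" means the embedding layer.
    assert len(encoded_layers) == model.config.num_hidden_layers + 1
    embedding_output = encoded_layers[0]    # layer 0: embeddings
    first_block_output = encoded_layers[1]  # layer 1: first transformer block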