will early exit if layer != -1
@@ -93,7 +93,7 @@ class BertEmbedding(ContextualEmbedding):
         """
         super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-        if word_dropout>0:
+        if word_dropout > 0:
             assert vocab.unknown != None, "When word_drop>0, Vocabulary must contain the unknown token."
         if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
@@ -370,17 +370,29 @@ class _BertWordModel(nn.Module):
                  include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
         super().__init__()
-        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
-        self.encoder = BertModel.from_pretrained(model_dir_or_name)
-        self._max_position_embeddings = self.encoder.config.max_position_embeddings
-        # check that encoder_layer_number is reasonable
-        encoder_layer_number = len(self.encoder.encoder.layer)
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
             self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
+        assert len(self.layers) > 0, "There is no layer selected!"
+        neg_num_output_layer = -16384
+        pos_num_output_layer = 0
+        for layer in self.layers:
+            if layer < 0:
+                neg_num_output_layer = max(layer, neg_num_output_layer)
+            else:
+                pos_num_output_layer = max(layer, pos_num_output_layer)
+        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = BertModel.from_pretrained(model_dir_or_name,
+                                                 neg_num_output_layer=neg_num_output_layer,
+                                                 pos_num_output_layer=pos_num_output_layer)
+        self._max_position_embeddings = self.encoder.config.max_position_embeddings
+        # check that encoder_layer_number is reasonable
+        encoder_layer_number = len(self.encoder.encoder.layer)
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
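
Note on the block added above: the loop only records the deepest layer requested on each side of zero, and those two values are all that the patched `BertModel.from_pretrained` needs to decide where the encoder may stop. A minimal standalone sketch of that mapping (the helper name `split_output_layers` and the sample `layers` strings are illustrative, not part of this patch):

def split_output_layers(layers_spec: str):
    # Mirrors the loop added above: keep the largest negative and the largest
    # non-negative layer index that the user asked for.
    layers = list(map(int, layers_spec.split(',')))
    neg_num_output_layer = -16384   # sentinel meaning "no negative index requested"
    pos_num_output_layer = 0
    for layer in layers:
        if layer < 0:
            neg_num_output_layer = max(layer, neg_num_output_layer)
        else:
            pos_num_output_layer = max(layer, pos_num_output_layer)
    return neg_num_output_layer, pos_num_output_layer

print(split_output_layers('-1'))      # (-1, 0)      -> the top layer is needed, no early exit possible
print(split_output_layers('4'))       # (-16384, 4)  -> nothing above layer 4 is needed
print(split_output_layers('0,2,-2'))  # (-2, 2)
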
@@ -196,20 +196,30 @@ class _RobertaWordModel(nn.Module):
                  include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
         super().__init__()
-        self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
-        self.encoder = RobertaModel.from_pretrained(model_dir_or_name)
-        # RobertaEmbedding sets padding_idx to 1 and computes positions in a rather magical way, hence the -2
-        self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
-        # check that encoder_layer_number is reasonable
-        encoder_layer_number = len(self.encoder.encoder.layer)
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
             self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
+        assert len(self.layers) > 0, "There is no layer selected!"
+        neg_num_output_layer = -16384
+        pos_num_output_layer = 0
+        for layer in self.layers:
+            if layer < 0:
+                neg_num_output_layer = max(layer, neg_num_output_layer)
+            else:
+                pos_num_output_layer = max(layer, pos_num_output_layer)
+        self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = RobertaModel.from_pretrained(model_dir_or_name,
+                                                    neg_num_output_layer=neg_num_output_layer,
+                                                    pos_num_output_layer=pos_num_output_layer)
+        # RobertaEmbedding sets padding_idx to 1 and computes positions in a rather magical way, hence the -2
+        self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
+        # check that encoder_layer_number is reasonable
+        encoder_layer_number = len(self.encoder.encoder.layer)
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
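
The `- 2` carried over here is unchanged by this patch but worth spelling out: RoBERTa reserves position indices 0 and 1 (padding_idx is 1), so real tokens start at position 2 and the usable sequence length is two shorter than `max_position_embeddings`. A quick check with the stock roberta-base value (514 is an assumption about the downloaded config, not something this diff sets):

max_position_embeddings = 514   # typical roberta-base config value (assumed)
reserved_positions = 2          # indices 0 and 1 are reserved; real tokens start at padding_idx + 1 = 2
print(max_position_embeddings - reserved_positions)   # 512 usable token positions
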
@@ -366,19 +366,28 @@ class BertLayer(nn.Module):
 class BertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, num_output_layer=-1):
         super(BertEncoder, self).__init__()
         layer = BertLayer(config)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        num_output_layer = num_output_layer if num_output_layer >= 0 else (len(self.layer) + num_output_layer)
+        self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)
+        if self.num_output_layer + 1 < len(self.layer):
+            logger.info(f'The transformer encoder will early exit after layer-{self.num_output_layer} '
+                        f'(start from 0)!')

     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
-        for layer_module in self.layer:
+        for idx, layer_module in enumerate(self.layer):
+            if idx > self.num_output_layer:
+                break
             hidden_states = layer_module(hidden_states, attention_mask)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
+        if len(all_encoder_layers) == 0:
+            all_encoder_layers.append(hidden_states)
         return all_encoder_layers
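
To see the early-exit behaviour in isolation, here is a self-contained toy version of the pattern the new `forward` loop implements; `ToyEarlyExitEncoder` and the plain `nn.Linear` layers are stand-ins for illustration, not the real `BertLayer`:

import torch
from torch import nn

class ToyEarlyExitEncoder(nn.Module):
    def __init__(self, hidden_size=16, num_layers=12, num_output_layer=-1):
        super().__init__()
        self.layer = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        # Same clamping as the diff: negative indices count from the top, then bound to [0, num_layers].
        num_output_layer = num_output_layer if num_output_layer >= 0 else (len(self.layer) + num_output_layer)
        self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)

    def forward(self, hidden_states):
        all_encoder_layers = []
        for idx, layer_module in enumerate(self.layer):
            if idx > self.num_output_layer:
                break                      # layers above the deepest requested one never run
            hidden_states = layer_module(hidden_states)
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers

enc = ToyEarlyExitEncoder(num_output_layer=3)
outputs = enc(torch.randn(2, 5, 16))
print(len(outputs))   # 4 -> only layers 0..3 were executed

In the real code above, `self.num_output_layer` is clamped to at least 0, so the loop always runs layer 0; the final `if len(all_encoder_layers) == 0` guard appears to be a defensive fallback so the method always returns at least one tensor.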
@@ -435,6 +444,9 @@ class BertModel(nn.Module):
         self.config = config
         self.hidden_size = self.config.hidden_size
         self.model_type = 'bert'
+        neg_num_output_layer = kwargs.get('neg_num_output_layer', -1)
+        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers - 1)
+        self.num_output_layer = max(neg_num_output_layer + self.config.num_hidden_layers, pos_num_output_layer)
         if hasattr(config, 'sinusoidal_pos_embds'):
             self.model_type = 'distilbert'
         elif 'model_type' in kwargs:
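
The two kwargs are folded into one cut-off index here; with the defaults (`-1` and `num_hidden_layers - 1`) the expression keeps the full stack, and only an explicit request for a lower layer moves the cut-off down. Worked numbers for an assumed 12-layer model:

num_hidden_layers = 12                                       # assumed, e.g. a bert-base sized model

# defaults: neg=-1, pos=num_hidden_layers-1 -> 11, i.e. run every layer
print(max(-1 + num_hidden_layers, num_hidden_layers - 1))    # 11

# layers='-1' in the embedding: neg=-1, pos=0 -> still 11, the top layer is needed
print(max(-1 + num_hidden_layers, 0))                        # 11

# layers='4': neg=-16384 (sentinel), pos=4 -> 4, layers 5..11 are skipped
print(max(-16384 + num_hidden_layers, 4))                    # 4
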
@@ -445,7 +457,7 @@ class BertModel(nn.Module):
         else:
             self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config)
+        self.encoder = BertEncoder(config, num_output_layer=self.num_output_layer)
         if self.model_type != 'distilbert':
             self.pooler = BertPooler(config)
         else:
@@ -64,8 +64,8 @@ class RobertaModel(BertModel):
     undocumented
     """

-    def __init__(self, config):
-        super().__init__(config)
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
         self.embeddings = RobertaEmbeddings(config)
         self.apply(self.init_bert_weights)
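
End to end, nothing changes for callers; the saving comes automatically from the `layers` argument. A usage sketch in the fastNLP style (the model tag 'en-base-uncased' and the toy vocabulary are placeholders, and the exact pretrained name may differ by version):

from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

vocab = Vocabulary()
vocab.add_word_lst("this is an example sentence .".split())

# Only layer 4 is requested, so with this patch the 12-layer encoder
# stops after layer 4 instead of running the whole stack.
embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='4')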