will early exit if layer != -1
@@ -93,7 +93,7 @@ class BertEmbedding(ContextualEmbedding):
         """
         super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-        if word_dropout>0:
+        if word_dropout > 0:
             assert vocab.unknown != None, "When word_drop>0, Vocabulary must contain the unknown token."
         if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
@@ -370,17 +370,29 @@ class _BertWordModel(nn.Module):
                  include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
         super().__init__()
-        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
-        self.encoder = BertModel.from_pretrained(model_dir_or_name)
-        self._max_position_embeddings = self.encoder.config.max_position_embeddings
-        # check that encoder_layer_number is reasonable
-        encoder_layer_number = len(self.encoder.encoder.layer)
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
             self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
         assert len(self.layers) > 0, "There is no layer selected!"
+        neg_num_output_layer = -16384
+        pos_num_output_layer = 0
+        for layer in self.layers:
+            if layer < 0:
+                neg_num_output_layer = max(layer, neg_num_output_layer)
+            else:
+                pos_num_output_layer = max(layer, pos_num_output_layer)
+        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = BertModel.from_pretrained(model_dir_or_name,
+                                                 neg_num_output_layer=neg_num_output_layer,
+                                                 pos_num_output_layer=pos_num_output_layer)
+        self._max_position_embeddings = self.encoder.config.max_position_embeddings
+        # check that encoder_layer_number is reasonable
+        encoder_layer_number = len(self.encoder.encoder.layer)
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
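For reference, the bookkeeping above reduces the requested `layers` to two bounds: the largest negative index and the largest non-negative index, which are then handed to `BertModel.from_pretrained`. A minimal sketch (the helper name `split_output_layer_bounds` is made up for illustration and is not part of the patch):

# Hypothetical helper (not part of the patch) showing how the requested layers
# are reduced to the two bounds passed to BertModel.from_pretrained above.
def split_output_layer_bounds(layers):
    if isinstance(layers, str):
        layers = list(map(int, layers.split(',')))
    neg_num_output_layer = -16384   # sentinel: "no negative index requested"
    pos_num_output_layer = 0
    for layer in layers:
        if layer < 0:
            neg_num_output_layer = max(layer, neg_num_output_layer)
        else:
            pos_num_output_layer = max(layer, pos_num_output_layer)
    return neg_num_output_layer, pos_num_output_layer

print(split_output_layer_bounds('-1'))      # (-1, 0)  -> full stack, existing behaviour unchanged
print(split_output_layer_bounds('0,4,-2'))  # (-2, 4)  -> with a 12-layer model the encoder stops after layer 10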
@@ -196,20 +196,30 @@ class _RobertaWordModel(nn.Module):
                  include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
         super().__init__()
-        self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
-        self.encoder = RobertaModel.from_pretrained(model_dir_or_name)
-        # RobertaEmbedding sets padding_idx to 1 and uses a rather peculiar position scheme, hence the -2
-        self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
-        # check that encoder_layer_number is reasonable
-        encoder_layer_number = len(self.encoder.encoder.layer)
         if isinstance(layers, list):
             self.layers = [int(l) for l in layers]
         elif isinstance(layers, str):
             self.layers = list(map(int, layers.split(',')))
         else:
             raise TypeError("`layers` only supports str or list[int]")
         assert len(self.layers) > 0, "There is no layer selected!"
+        neg_num_output_layer = -16384
+        pos_num_output_layer = 0
+        for layer in self.layers:
+            if layer < 0:
+                neg_num_output_layer = max(layer, neg_num_output_layer)
+            else:
+                pos_num_output_layer = max(layer, pos_num_output_layer)
+        self.tokenizer = RobertaTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = RobertaModel.from_pretrained(model_dir_or_name,
+                                                    neg_num_output_layer=neg_num_output_layer,
+                                                    pos_num_output_layer=pos_num_output_layer)
+        # RobertaEmbedding sets padding_idx to 1 and uses a rather peculiar position scheme, hence the -2
+        self._max_position_embeddings = self.encoder.config.max_position_embeddings - 2
+        # check that encoder_layer_number is reasonable
+        encoder_layer_number = len(self.encoder.encoder.layer)
         for layer in self.layers:
             if layer < 0:
                 assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
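Same bookkeeping as in `_BertWordModel`; the only RoBERTa-specific twist is the `- 2` on `max_position_embeddings`. A quick sanity check of the usable length, assuming the standard roberta-base config value of 514 (an assumption, not stated in this patch):

# Usable sequence length implied by the "- 2" above, assuming roberta-base's
# max_position_embeddings = 514 (padding_idx=1 plus RoBERTa's position offset).
max_position_embeddings = 514
print(max_position_embeddings - 2)   # 512 positions available for real tokens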
@@ -366,19 +366,28 @@ class BertLayer(nn.Module):
 class BertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, num_output_layer=-1):
         super(BertEncoder, self).__init__()
         layer = BertLayer(config)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        num_output_layer = num_output_layer if num_output_layer >= 0 else (len(self.layer) + num_output_layer)
+        self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)
+        if self.num_output_layer + 1 < len(self.layer):
+            logger.info(f'The transformer encoder will early exit after layer-{self.num_output_layer} '
+                        f'(start from 0)!')

     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
-        for layer_module in self.layer:
+        for idx, layer_module in enumerate(self.layer):
+            if idx > self.num_output_layer:
+                break
             hidden_states = layer_module(hidden_states, attention_mask)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
+        if len(all_encoder_layers) == 0:
+            all_encoder_layers.append(hidden_states)
         return all_encoder_layers
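To make the new control flow concrete, here is a minimal sketch of the early-exit loop with stand-in layers (illustration only, no torch or fastNLP needed; the function name is made up):

# Stand-in for the forward loop above: shows how many layer modules actually
# execute for a given num_output_layer and how many outputs are collected.
def run_encoder(num_layers, num_output_layer, output_all_encoded_layers=True):
    all_encoder_layers = []
    hidden_states = 'h0'
    for idx in range(num_layers):
        if idx > num_output_layer:
            break
        hidden_states = f'h{idx + 1}'      # stands in for layer_module(hidden_states, mask)
        if output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
    if not output_all_encoded_layers:
        all_encoder_layers.append(hidden_states)
    if len(all_encoder_layers) == 0:       # defensive: always return at least one output
        all_encoder_layers.append(hidden_states)
    return hidden_states, len(all_encoder_layers)

print(run_encoder(12, num_output_layer=3))    # ('h4', 4): layers 0-3 ran, 4-11 skipped
print(run_encoder(12, num_output_layer=11))   # ('h12', 12): full pass, same as before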
@@ -435,6 +444,9 @@ class BertModel(nn.Module):
         self.config = config
         self.hidden_size = self.config.hidden_size
         self.model_type = 'bert'
+        neg_num_output_layer = kwargs.get('neg_num_output_layer', -1)
+        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers - 1)
+        self.num_output_layer = max(neg_num_output_layer + self.config.num_hidden_layers, pos_num_output_layer)
         if hasattr(config, 'sinusoidal_pos_embds'):
             self.model_type = 'distilbert'
         elif 'model_type' in kwargs:
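The cutoff combines the two kwargs by converting the negative bound to a 0-based index and keeping the larger of the two. A small worked example, assuming `num_hidden_layers = 12` as in bert-base:

# Sketch of the cutoff computation: num_output_layer =
#   max(neg_num_output_layer + num_hidden_layers, pos_num_output_layer)
num_hidden_layers = 12

# defaults (neg=-1, pos=num_hidden_layers-1): keep the whole stack
print(max(-1 + num_hidden_layers, num_hidden_layers - 1))   # 11

# values passed by _BertWordModel for layers='0,4' (neg=-16384, pos=4)
print(max(-16384 + num_hidden_layers, 4))                   # 4

# values passed for layers='-1' (neg=-1, pos=0): still the full stack
print(max(-1 + num_hidden_layers, 0))                       # 11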
@@ -445,7 +457,7 @@ class BertModel(nn.Module):
         else:
             self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config)
+        self.encoder = BertEncoder(config, num_output_layer=self.num_output_layer)
         if self.model_type != 'distilbert':
             self.pooler = BertPooler(config)
         else:
@@ -64,8 +64,8 @@ class RobertaModel(BertModel):
     undocumented
     """

-    def __init__(self, config):
-        super().__init__(config)
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
         self.embeddings = RobertaEmbeddings(config)
         self.apply(self.init_bert_weights)
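End to end, a user who selects only lower layers now skips the upper transformer blocks entirely, while the default selection keeps the full forward pass. A hypothetical usage sketch, assuming fastNLP's public `BertEmbedding` API and the `en-base-uncased` model alias (both taken from the existing documentation, not introduced by this patch):

# Hypothetical usage: with this patch, layers='4' means transformer layers 0..4
# are executed and the remaining layers are skipped; layers='-1' still runs the
# full stack, so existing behaviour is unchanged.
from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

vocab = Vocabulary().add_word_lst("this is an example sentence".split())
embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='4')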