@@ -16,6 +16,7 @@ import unicodedata
 
 import torch
 from torch import nn
+import numpy as np
 
 from ..utils import _get_file_name_base_on_postfix
 from ...io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
@@ -24,6 +25,24 @@ from ...core import logger
 CONFIG_FILE = 'bert_config.json'
 VOCAB_NAME = 'vocab.txt'
 
+BERT_KEY_RENAME_MAP_1 = {
+    'gamma': 'weight',
+    'beta': 'bias',
+    'distilbert.embeddings': 'bert.embeddings',
+    'distilbert.transformer': 'bert.encoder',
+}
+
+BERT_KEY_RENAME_MAP_2 = {
+    'q_lin': 'self.query',
+    'k_lin': 'self.key',
+    'v_lin': 'self.value',
+    'out_lin': 'output.dense',
+    'sa_layer_norm': 'attention.output.LayerNorm',
+    'ffn.lin1': 'intermediate.dense',
+    'ffn.lin2': 'output.dense',
+    'output_layer_norm': 'output.LayerNorm',
+}
+
 
 class BertConfig(object):
     """Configuration class to store the configuration of a `BertModel`.
|
@@ -162,6 +181,55 @@ class BertLayerNorm(nn.Module):
         return self.weight * x + self.bias
 
 
+class DistilBertEmbeddings(nn.Module):
+    def __init__(self, config):
+        super(DistilBertEmbeddings, self).__init__()
+
+        def create_sinusoidal_embeddings(n_pos, dim, out):
+            position_enc = np.array([
+                [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
+                for pos in range(n_pos)
+            ])
+            out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
+            out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+            out.detach_()
+            out.requires_grad = False
+
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        if config.sinusoidal_pos_embds:
+            create_sinusoidal_embeddings(n_pos=config.max_position_embeddings,
+                                         dim=config.hidden_size,
+                                         out=self.position_embeddings.weight)
+
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, input_ids, token_type_ids):
+        """
+        Parameters
+        ----------
+        input_ids: torch.tensor(bs, max_seq_length)
+            The token ids to embed.
+        token_type_ids: no used.
+        Outputs
+        -------
+        embeddings: torch.tensor(bs, max_seq_length, dim)
+            The embedded tokens (plus position embeddings, no token_type embeddings)
+        """
+        seq_length = input_ids.size(1)
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
+        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
+
+        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
+        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
+
+        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
+        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
+        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
+        return embeddings
+
+
 class BertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings.
     """
|
@@ -383,9 +451,22 @@ class BertModel(nn.Module):
         super(BertModel, self).__init__()
         self.config = config
         self.hidden_size = self.config.hidden_size
-        self.embeddings = BertEmbeddings(config)
+        self.model_type = 'bert'
+        if hasattr(config, 'sinusoidal_pos_embds'):
+            self.model_type = 'distilbert'
+        elif 'model_type' in kwargs:
+            self.model_type = kwargs['model_type'].lower()
+
+        if self.model_type == 'distilbert':
+            self.embeddings = DistilBertEmbeddings(config)
+        else:
+            self.embeddings = BertEmbeddings(config)
+
         self.encoder = BertEncoder(config)
-        self.pooler = BertPooler(config)
+        if self.model_type != 'distilbert':
+            self.pooler = BertPooler(config)
+        else:
+            logger.info('DistilBert has NOT pooler, will use hidden states of [CLS] token as pooled output.')
         self.apply(self.init_bert_weights)
 
     def init_bert_weights(self, module):
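
Note: the constructor now dispatches on a model_type flag, inferred from the DistilBert-only config field sinusoidal_pos_embds when present, otherwise from an optional model_type keyword (lowercased). A hedged sketch of both paths; the config path is a placeholder:

    config = BertConfig.from_json_file('/path/to/bert_config.json')  # placeholder path

    bert = BertModel(config)               # no marker field, no kwarg -> 'bert'

    config.sinusoidal_pos_embds = False    # field that only DistilBert configs carry
    distil = BertModel(config)             # hasattr() alone selects 'distilbert'

The hasattr check takes precedence over an explicit model_type keyword, and DistilBertEmbeddings itself reads config.sinusoidal_pos_embds, so the attribute must exist before forcing model_type='distilbert' on a plain BERT config.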
|
@@ -427,7 +508,10 @@ class BertModel(nn.Module):
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
         sequence_output = encoded_layers[-1]
-        pooled_output = self.pooler(sequence_output)
+        if self.model_type != 'distilbert':
+            pooled_output = self.pooler(sequence_output)
+        else:
+            pooled_output = sequence_output[:, 0]
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
         return encoded_layers, pooled_output
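
Note: in 'distilbert' mode the pooled output is simply the hidden state of the first ([CLS]) token, so callers still receive a (bs, hidden_size) tensor, but it has not passed through the dense+tanh transform that BertPooler applies on the regular BERT path.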
|
@@ -445,9 +529,7 @@ class BertModel(nn.Module):
         # Load config
         config_file = _get_file_name_base_on_postfix(pretrained_model_dir, '.json')
         config = BertConfig.from_json_file(config_file)
         # logger.info("Model config {}".format(config))
-        # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
         if state_dict is None:
             weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin')
             state_dict = torch.load(weights_path, map_location='cpu')
@@ -455,20 +537,40 @@ class BertModel(nn.Module):
             logger.error(f'Cannot load parameters through `state_dict` variable.')
             raise RuntimeError(f'Cannot load parameters through `state_dict` variable.')
 
+        model_type = 'BERT'
+        old_keys = []
+        new_keys = []
+        for key in state_dict.keys():
+            new_key = None
+            for key_name in BERT_KEY_RENAME_MAP_1:
+                if key_name in key:
+                    new_key = key.replace(key_name, BERT_KEY_RENAME_MAP_1[key_name])
+                    if 'distilbert' in key:
+                        model_type = 'DistilBert'
+                    break
+            if new_key:
+                old_keys.append(key)
+                new_keys.append(new_key)
+        for old_key, new_key in zip(old_keys, new_keys):
+            state_dict[new_key] = state_dict.pop(old_key)
+
         old_keys = []
         new_keys = []
         for key in state_dict.keys():
             new_key = None
-            if 'gamma' in key:
-                new_key = key.replace('gamma', 'weight')
-            if 'beta' in key:
-                new_key = key.replace('beta', 'bias')
+            for key_name in BERT_KEY_RENAME_MAP_2:
+                if key_name in key:
+                    new_key = key.replace(key_name, BERT_KEY_RENAME_MAP_2[key_name])
+                    break
             if new_key:
                 old_keys.append(key)
                 new_keys.append(new_key)
         for old_key, new_key in zip(old_keys, new_keys):
             state_dict[new_key] = state_dict.pop(old_key)
 
+        # Instantiate model.
+        model = cls(config, model_type=model_type, *inputs, **kwargs)
+
         missing_keys = []
         unexpected_keys = []
         error_msgs = []
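
Note: the model is now instantiated only after the first renaming pass, because that pass is what reveals the checkpoint type: if any key contains 'distilbert', model_type flips to 'DistilBert' and is forwarded to the constructor through the new model_type keyword (which __init__ lowercases).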
|
@@ -494,7 +596,7 @@ class BertModel(nn.Module):
 
             logger.warning("Weights from pretrained model not used in {}: {}".format(
                 model.__class__.__name__, unexpected_keys))
 
-        logger.info(f"Load pre-trained BERT parameters from file {weights_path}.")
+        logger.info(f"Load pre-trained {model_type} parameters from file {weights_path}.")
 
         return model
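
With these changes the same entry point should load either kind of checkpoint. A hedged usage sketch, assuming a local directory that holds exactly one .json config and one .bin state dict (the layout _get_file_name_base_on_postfix expects); the path is a placeholder:

    model = BertModel.from_pretrained('/path/to/distilbert_dir')
    input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
    encoded_layers, pooled_output = model(input_ids)
    print(model.model_type)       # 'distilbert' if the checkpoint keys said so
    print(pooled_output.shape)    # (2, hidden_size): the [CLS] hidden state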
|
|
|
|
|
|
|
|
|
|
|
|
|
|