From 995f4ef18d8dd29bb06507d0f3b092eb669a6055 Mon Sep 17 00:00:00 2001
From: Yige Xu
Date: Fri, 27 Sep 2019 03:32:03 +0800
Subject: [PATCH] add distilbert from pytorch-transformers package

---
 fastNLP/io/file_utils.py        |   2 +
 fastNLP/modules/encoder/bert.py | 124 +++++++++++++++++++++++++++++---
 2 files changed, 115 insertions(+), 11 deletions(-)

diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py
index a4abb575..9e7ac6f6 100644
--- a/fastNLP/io/file_utils.py
+++ b/fastNLP/io/file_utils.py
@@ -37,6 +37,8 @@ PRETRAINED_BERT_MODEL_DIR = {
     'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip',
 
+    'en-distilbert-base-uncased': 'distilbert-base-uncased.zip',
+
     'multi-base-cased': 'bert-base-multilingual-cased.zip',
     'multi-base-uncased': 'bert-base-multilingual-uncased.zip',
 
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index 16b456fb..821b9c5c 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -16,6 +16,7 @@ import unicodedata
 
 import torch
 from torch import nn
+import numpy as np
 
 from ..utils import _get_file_name_base_on_postfix
 from ...io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
@@ -24,6 +25,24 @@ from ...core import logger
 CONFIG_FILE = 'bert_config.json'
 VOCAB_NAME = 'vocab.txt'
 
+BERT_KEY_RENAME_MAP_1 = {
+    'gamma': 'weight',
+    'beta': 'bias',
+    'distilbert.embeddings': 'bert.embeddings',
+    'distilbert.transformer': 'bert.encoder',
+}
+
+BERT_KEY_RENAME_MAP_2 = {
+    'q_lin': 'self.query',
+    'k_lin': 'self.key',
+    'v_lin': 'self.value',
+    'out_lin': 'output.dense',
+    'sa_layer_norm': 'attention.output.LayerNorm',
+    'ffn.lin1': 'intermediate.dense',
+    'ffn.lin2': 'output.dense',
+    'output_layer_norm': 'output.LayerNorm',
+}
+
 
 class BertConfig(object):
     """Configuration class to store the configuration of a `BertModel`.
@@ -162,6 +181,55 @@ class BertLayerNorm(nn.Module):
         return self.weight * x + self.bias
 
 
+class DistilBertEmbeddings(nn.Module):
+    def __init__(self, config):
+        super(DistilBertEmbeddings, self).__init__()
+
+        def create_sinusoidal_embeddings(n_pos, dim, out):
+            position_enc = np.array([
+                [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
+                for pos in range(n_pos)
+            ])
+            out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
+            out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+            out.detach_()
+            out.requires_grad = False
+
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        if config.sinusoidal_pos_embds:
+            create_sinusoidal_embeddings(n_pos=config.max_position_embeddings,
+                                         dim=config.hidden_size,
+                                         out=self.position_embeddings.weight)
+
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, input_ids, token_type_ids):
+        """
+        Parameters
+        ----------
+        input_ids: torch.tensor(bs, max_seq_length)
+            The token ids to embed.
+        token_type_ids: not used; accepted only for interface compatibility with BertEmbeddings.
+
+        Outputs
+        -------
+        embeddings: torch.tensor(bs, max_seq_length, dim)
+            The embedded tokens (plus position embeddings, no token_type embeddings)
+        """
+        seq_length = input_ids.size(1)
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
+        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
+
+        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
+        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
+
+        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
+        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
+        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
+        return embeddings
+
+
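As a quick illustration of what `create_sinusoidal_embeddings` builds, here is a standalone sketch (not part of the patch, assuming DistilBERT-base sizes n_pos=512 and dim=768): even columns hold the sine terms, odd columns the cosine terms, and the table is frozen rather than trained. A `torch.no_grad()` block is used here instead of the patch's `detach_()` idiom so the in-place writes also pass on newer PyTorch releases:

```python
import numpy as np
import torch
from torch import nn

n_pos, dim = 512, 768  # DistilBERT-base sizes, for illustration only
emb = nn.Embedding(n_pos, dim)

position_enc = np.array([
    [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
    for pos in range(n_pos)
])
with torch.no_grad():  # allow in-place writes to the leaf weight
    emb.weight[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    emb.weight[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
emb.weight.requires_grad = False  # fixed table, never updated by the optimizer

# position 0 encodes to [sin(0), cos(0), ...] = [0, 1, 0, 1, ...]
assert torch.allclose(emb.weight[0, 0::2], torch.zeros(dim // 2))
assert torch.allclose(emb.weight[0, 1::2], torch.ones(dim // 2))
```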
 class BertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings.
     """
@@ -383,9 +451,22 @@ class BertModel(nn.Module):
         super(BertModel, self).__init__()
         self.config = config
         self.hidden_size = self.config.hidden_size
-        self.embeddings = BertEmbeddings(config)
+        self.model_type = 'bert'
+        if hasattr(config, 'sinusoidal_pos_embds'):
+            self.model_type = 'distilbert'
+        elif 'model_type' in kwargs:
+            self.model_type = kwargs['model_type'].lower()
+
+        if self.model_type == 'distilbert':
+            self.embeddings = DistilBertEmbeddings(config)
+        else:
+            self.embeddings = BertEmbeddings(config)
+
         self.encoder = BertEncoder(config)
-        self.pooler = BertPooler(config)
+        if self.model_type != 'distilbert':
+            self.pooler = BertPooler(config)
+        else:
+            logger.info('DistilBert has no pooler; the hidden state of the [CLS] token will be used as the pooled output.')
         self.apply(self.init_bert_weights)
 
     def init_bert_weights(self, module):
@@ -427,7 +508,10 @@ class BertModel(nn.Module):
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
         sequence_output = encoded_layers[-1]
-        pooled_output = self.pooler(sequence_output)
+        if self.model_type != 'distilbert':
+            pooled_output = self.pooler(sequence_output)
+        else:
+            pooled_output = sequence_output[:, 0]
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
         return encoded_layers, pooled_output
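The pooling split above is the only behavioural change in `forward`. For clarity, a minimal side-by-side sketch of the two paths (toy shapes; `BertPooler` reduced here to its Linear + tanh core):

```python
import torch
from torch import nn

batch, seq_len, hidden = 2, 8, 16  # toy sizes, for illustration
sequence_output = torch.randn(batch, seq_len, hidden)

# BERT path: BertPooler applies Linear + tanh to the first token's hidden state
dense, activation = nn.Linear(hidden, hidden), nn.Tanh()
bert_pooled = activation(dense(sequence_output[:, 0]))

# DistilBERT path in this patch: the raw [CLS] hidden state is the pooled output
distil_pooled = sequence_output[:, 0]

# both give one vector per sequence
assert bert_pooled.shape == distil_pooled.shape == (batch, hidden)
```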
@@ -445,9 +529,7 @@ class BertModel(nn.Module):
         # Load config
         config_file = _get_file_name_base_on_postfix(pretrained_model_dir, '.json')
         config = BertConfig.from_json_file(config_file)
-        # logger.info("Model config {}".format(config))
-        # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
+
         if state_dict is None:
             weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin')
             state_dict = torch.load(weights_path, map_location='cpu')
@@ -455,20 +537,40 @@ class BertModel(nn.Module):
             logger.error(f'Cannot load parameters through `state_dict` variable.')
             raise RuntimeError(f'Cannot load parameters through `state_dict` variable.')
 
+        model_type = 'BERT'
+        old_keys = []
+        new_keys = []
+        for key in state_dict.keys():
+            new_key = None
+            for key_name in BERT_KEY_RENAME_MAP_1:
+                if key_name in key:
+                    new_key = key.replace(key_name, BERT_KEY_RENAME_MAP_1[key_name])
+                    if 'distilbert' in key:
+                        model_type = 'DistilBert'
+                    break
+            if new_key:
+                old_keys.append(key)
+                new_keys.append(new_key)
+        for old_key, new_key in zip(old_keys, new_keys):
+            state_dict[new_key] = state_dict.pop(old_key)
+
         old_keys = []
         new_keys = []
         for key in state_dict.keys():
             new_key = None
-            if 'gamma' in key:
-                new_key = key.replace('gamma', 'weight')
-            if 'beta' in key:
-                new_key = key.replace('beta', 'bias')
+            for key_name in BERT_KEY_RENAME_MAP_2:
+                if key_name in key:
+                    new_key = key.replace(key_name, BERT_KEY_RENAME_MAP_2[key_name])
+                    break
             if new_key:
                 old_keys.append(key)
                 new_keys.append(new_key)
         for old_key, new_key in zip(old_keys, new_keys):
             state_dict[new_key] = state_dict.pop(old_key)
 
+        # Instantiate model.
+        model = cls(config, model_type=model_type, *inputs, **kwargs)
+
         missing_keys = []
         unexpected_keys = []
         error_msgs = []
@@ -494,7 +596,7 @@ class BertModel(nn.Module):
             logger.warning("Weights from pretrained model not used in {}: {}".format(
                 model.__class__.__name__, unexpected_keys))
 
-        logger.info(f"Load pre-trained BERT parameters from file {weights_path}.")
+        logger.info(f"Loaded pre-trained {model_type} parameters from file {weights_path}.")
         return model
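To make the key conversion concrete, here is a condensed standalone version of the two renaming passes in `from_pretrained`, applied to a single hypothetical DistilBERT checkpoint key (the maps are copied from the patch; the real loops walk the whole `state_dict` and rebuild it via `pop`):

```python
BERT_KEY_RENAME_MAP_1 = {
    'gamma': 'weight',
    'beta': 'bias',
    'distilbert.embeddings': 'bert.embeddings',
    'distilbert.transformer': 'bert.encoder',
}
BERT_KEY_RENAME_MAP_2 = {
    'q_lin': 'self.query',
    'k_lin': 'self.key',
    'v_lin': 'self.value',
    'out_lin': 'output.dense',
    'sa_layer_norm': 'attention.output.LayerNorm',
    'ffn.lin1': 'intermediate.dense',
    'ffn.lin2': 'output.dense',
    'output_layer_norm': 'output.LayerNorm',
}

def rename(key):
    # pass 1: move distilbert.* parameters into the bert.* namespace
    for old, new in BERT_KEY_RENAME_MAP_1.items():
        if old in key:
            key = key.replace(old, new)
            break
    # pass 2: map DistilBERT submodule names onto their BERT equivalents
    for old, new in BERT_KEY_RENAME_MAP_2.items():
        if old in key:
            key = key.replace(old, new)
            break
    return key

# hypothetical checkpoint key, purely for illustration
assert (rename('distilbert.transformer.layer.0.attention.q_lin.weight')
        == 'bert.encoder.layer.0.attention.self.query.weight')
```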