
add distilbert from pytorch-transformers package

tags/v0.4.10
Yige Xu, 5 years ago
parent commit: 995f4ef18d

2 changed files with 115 additions and 11 deletions:

  1. fastNLP/io/file_utils.py (+2 −0)
  2. fastNLP/modules/encoder/bert.py (+113 −11)

fastNLP/io/file_utils.py (+2 −0)

@@ -37,6 +37,8 @@ PRETRAINED_BERT_MODEL_DIR = {
     'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip',
+    'en-distilbert-base-uncased': 'distilbert-base-uncased.zip',
     'multi-base-cased': 'bert-base-multilingual-cased.zip',
     'multi-base-uncased': 'bert-base-multilingual-uncased.zip',
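The new entry only registers a download alias next to the existing ones. A minimal sketch of the lookup it enables (`resolve_bert_model` is a hypothetical helper; `PRETRAINED_BERT_MODEL_DIR` is the dict being patched):

```python
from fastNLP.io.file_utils import PRETRAINED_BERT_MODEL_DIR

# Hypothetical illustration: each alias maps to a hosted zip archive
# that is downloaded and unpacked on first use.
def resolve_bert_model(name: str) -> str:
    if name.lower() not in PRETRAINED_BERT_MODEL_DIR:
        raise ValueError(f"unknown BERT model alias: {name}")
    return PRETRAINED_BERT_MODEL_DIR[name.lower()]

print(resolve_bert_model('en-distilbert-base-uncased'))
# -> 'distilbert-base-uncased.zip'
```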




fastNLP/modules/encoder/bert.py (+113 −11)

@@ -16,6 +16,7 @@ import unicodedata

 import torch
 from torch import nn
+import numpy as np

 from ..utils import _get_file_name_base_on_postfix
 from ...io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
@@ -24,6 +25,24 @@ from ...core import logger

 CONFIG_FILE = 'bert_config.json'
 VOCAB_NAME = 'vocab.txt'

+BERT_KEY_RENAME_MAP_1 = {
+    'gamma': 'weight',
+    'beta': 'bias',
+    'distilbert.embeddings': 'bert.embeddings',
+    'distilbert.transformer': 'bert.encoder',
+}
+
+BERT_KEY_RENAME_MAP_2 = {
+    'q_lin': 'self.query',
+    'k_lin': 'self.key',
+    'v_lin': 'self.value',
+    'out_lin': 'output.dense',
+    'sa_layer_norm': 'attention.output.LayerNorm',
+    'ffn.lin1': 'intermediate.dense',
+    'ffn.lin2': 'output.dense',
+    'output_layer_norm': 'output.LayerNorm',
+}
+

 class BertConfig(object):
     """Configuration class to store the configuration of a `BertModel`.
@@ -162,6 +181,55 @@ class BertLayerNorm(nn.Module):
         return self.weight * x + self.bias




+class DistilBertEmbeddings(nn.Module):
+    def __init__(self, config):
+        super(DistilBertEmbeddings, self).__init__()
+
+        def create_sinusoidal_embeddings(n_pos, dim, out):
+            position_enc = np.array([
+                [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
+                for pos in range(n_pos)
+            ])
+            out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
+            out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+            out.detach_()
+            out.requires_grad = False
+
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        if config.sinusoidal_pos_embds:
+            create_sinusoidal_embeddings(n_pos=config.max_position_embeddings,
+                                         dim=config.hidden_size,
+                                         out=self.position_embeddings.weight)
+
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, input_ids, token_type_ids):
+        """
+        Parameters
+        ----------
+        input_ids: torch.tensor(bs, max_seq_length)
+            The token ids to embed.
+        token_type_ids: not used.
+
+        Outputs
+        -------
+        embeddings: torch.tensor(bs, max_seq_length, dim)
+            The embedded tokens (plus position embeddings, no token_type embeddings)
+        """
+        seq_length = input_ids.size(1)
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
+        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
+
+        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
+        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
+
+        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
+        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
+        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
+        return embeddings


 class BertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings.
     """
@@ -383,9 +451,22 @@ class BertModel(nn.Module):
         super(BertModel, self).__init__()
         self.config = config
         self.hidden_size = self.config.hidden_size
-        self.embeddings = BertEmbeddings(config)
+        self.model_type = 'bert'
+        if hasattr(config, 'sinusoidal_pos_embds'):
+            self.model_type = 'distilbert'
+        elif 'model_type' in kwargs:
+            self.model_type = kwargs['model_type'].lower()
+
+        if self.model_type == 'distilbert':
+            self.embeddings = DistilBertEmbeddings(config)
+        else:
+            self.embeddings = BertEmbeddings(config)
+
         self.encoder = BertEncoder(config)
-        self.pooler = BertPooler(config)
+        if self.model_type != 'distilbert':
+            self.pooler = BertPooler(config)
+        else:
+            logger.info('DistilBert has no pooler; the hidden state of the [CLS] token will be used as the pooled output.')
         self.apply(self.init_bert_weights)


     def init_bert_weights(self, module):
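A small sketch of the detection rule used above (`DummyConfig` and `detect_model_type` are hypothetical stand-ins): the mere presence of `sinusoidal_pos_embds` on the config marks a DistilBERT checkpoint, an explicit `model_type` kwarg is honored otherwise, and plain BERT is the default:

```python
class DummyConfig:
    pass

def detect_model_type(config, **kwargs):
    # DistilBERT configs carry 'sinusoidal_pos_embds'; BERT configs do not.
    if hasattr(config, 'sinusoidal_pos_embds'):
        return 'distilbert'
    return kwargs.get('model_type', 'bert').lower()

bert_cfg = DummyConfig()
distil_cfg = DummyConfig()
distil_cfg.sinusoidal_pos_embds = False  # present at all -> DistilBERT

print(detect_model_type(bert_cfg))                           # 'bert'
print(detect_model_type(distil_cfg))                         # 'distilbert'
print(detect_model_type(bert_cfg, model_type='DistilBert'))  # 'distilbert'
```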
@@ -427,7 +508,10 @@ class BertModel(nn.Module):
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
         sequence_output = encoded_layers[-1]
-        pooled_output = self.pooler(sequence_output)
+        if self.model_type != 'distilbert':
+            pooled_output = self.pooler(sequence_output)
+        else:
+            pooled_output = sequence_output[:, 0]
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
         return encoded_layers, pooled_output
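Since a DistilBERT checkpoint ships no pooler weights, the pooled output degrades gracefully to the [CLS] hidden state. A toy illustration of that fallback:

```python
import torch

# sequence_output stands in for the last encoder layer's hidden states.
bs, seq_len, dim = 2, 5, 8
sequence_output = torch.randn(bs, seq_len, dim)

# Hidden state of the first ([CLS]) token, shape (bs, dim).
pooled_output = sequence_output[:, 0]
assert pooled_output.shape == (bs, dim)
```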
@@ -445,9 +529,7 @@ class BertModel(nn.Module):
         # Load config
         config_file = _get_file_name_base_on_postfix(pretrained_model_dir, '.json')
         config = BertConfig.from_json_file(config_file)
-        # logger.info("Model config {}".format(config))
-        # Instantiate model.
-        model = cls(config, *inputs, **kwargs)

         if state_dict is None:
             weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin')
             state_dict = torch.load(weights_path, map_location='cpu')
@@ -455,20 +537,40 @@ class BertModel(nn.Module):
             logger.error(f'Cannot load parameters through `state_dict` variable.')
             raise RuntimeError(f'Cannot load parameters through `state_dict` variable.')


+        model_type = 'BERT'
+        old_keys = []
+        new_keys = []
+        for key in state_dict.keys():
+            new_key = None
+            for key_name in BERT_KEY_RENAME_MAP_1:
+                if key_name in key:
+                    new_key = key.replace(key_name, BERT_KEY_RENAME_MAP_1[key_name])
+                    if 'distilbert' in key:
+                        model_type = 'DistilBert'
+                    break
+            if new_key:
+                old_keys.append(key)
+                new_keys.append(new_key)
+        for old_key, new_key in zip(old_keys, new_keys):
+            state_dict[new_key] = state_dict.pop(old_key)
+
         old_keys = []
         new_keys = []
         for key in state_dict.keys():
             new_key = None
-            if 'gamma' in key:
-                new_key = key.replace('gamma', 'weight')
-            if 'beta' in key:
-                new_key = key.replace('beta', 'bias')
+            for key_name in BERT_KEY_RENAME_MAP_2:
+                if key_name in key:
+                    new_key = key.replace(key_name, BERT_KEY_RENAME_MAP_2[key_name])
+                    break
             if new_key:
                 old_keys.append(key)
                 new_keys.append(new_key)
         for old_key, new_key in zip(old_keys, new_keys):
             state_dict[new_key] = state_dict.pop(old_key)
+
+        # Instantiate model.
+        model = cls(config, model_type=model_type, *inputs, **kwargs)

         missing_keys = []
         unexpected_keys = []
         error_msgs = []
@@ -494,7 +596,7 @@ class BertModel(nn.Module):
             logger.warning("Weights from pretrained model not used in {}: {}".format(
                 model.__class__.__name__, unexpected_keys))


logger.info(f"Load pre-trained BERT parameters from file {weights_path}.")
logger.info(f"Load pre-trained {model_type} parameters from file {weights_path}.")
return model return model
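Putting it together, a hedged usage sketch of the patched loader (the directory path is illustrative; it should contain the unpacked `.json` config and `.bin` weights from the archive registered in file_utils.py):

```python
from fastNLP.modules.encoder.bert import BertModel

# The two rename passes make the pytorch-transformers DistilBERT weights
# line up with fastNLP's BertModel parameter names; model.model_type ends
# up as 'distilbert', so forward() returns sequence_output[:, 0] as the
# pooled output instead of running a pooler.
model = BertModel.from_pretrained('/path/to/distilbert-base-uncased')
```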





