From bfaf09df8cba78e02ad2aa73dab11c5ff6d7a7b9 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Tue, 29 Jan 2019 20:35:12 +0800 Subject: [PATCH] add BERT model * load pre-trained BERT weights from local binary * add tests --- fastNLP/models/bert.py | 342 ++++++++++++++++++++++++ fastNLP/modules/aggregator/attention.py | 17 +- fastNLP/modules/encoder/transformer.py | 1 - test/models/test_bert.py | 21 ++ 4 files changed, 372 insertions(+), 9 deletions(-) create mode 100644 fastNLP/models/bert.py create mode 100644 test/models/test_bert.py diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py new file mode 100644 index 00000000..754d1bbb --- /dev/null +++ b/fastNLP/models/bert.py @@ -0,0 +1,342 @@ +import copy +import json +import math +import os + +import torch +from torch import nn + +CONFIG_FILE = 'bert_config.json' +MODEL_WEIGHTS = 'pytorch_model.bin' + + +def gelu(x): + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class BertEmbeddings(nn.Module): + def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): + super(BertSelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (hidden_size, num_attention_heads)) + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = 
nn.Linear(hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfOutput(nn.Module): + def __init__(self, hidden_size, hidden_dropout_prob): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) + self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + def __init__(self, hidden_size, intermediate_size, hidden_act): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_act_fn = ACT2FN[hidden_act] \ + if isinstance(hidden_act, str) else hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): + super(BertOutput, self).__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states 
= self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, + intermediate_size, hidden_act): + super(BertLayer, self).__init__() + self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob) + self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) + self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob, + intermediate_size, hidden_act): + super(BertEncoder, self).__init__() + layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, + intermediate_size, hidden_act) + self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): + all_encoder_layers = [] + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + def __init__(self, hidden_size): + super(BertPooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertModel(nn.Module): + """BERT model ("Bidirectional Embedding Representations from a Transformer"). 
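+
+    Default hyper-parameters correspond to the BERT-base configuration.
+
+    :param vocab_size: int, size of the vocabulary used by the token embeddings.
+    :param hidden_size: int, size of the encoder layers and the pooler layer.
+    :param num_hidden_layers: int, number of Transformer encoder layers.
+    :param num_attention_heads: int, number of attention heads in each encoder layer.
+    :param intermediate_size: int, size of the feed-forward (intermediate) layer.
+    :param hidden_act: str, activation of the intermediate layer, one of "gelu", "relu", "swish".
+    :param hidden_dropout_prob: float, dropout applied to the embeddings and hidden states.
+    :param attention_probs_dropout_prob: float, dropout applied to the attention probabilities.
+    :param max_position_embeddings: int, maximum sequence length supported by the position embeddings.
+    :param type_vocab_size: int, number of token type (segment) ids.
+    :param initializer_range: float, standard deviation of the weight initializer.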
+ + """ + + def __init__(self, vocab_size, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, **kwargs): + super(BertModel, self).__init__() + self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, + type_vocab_size, hidden_dropout_prob) + self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, + attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, + hidden_act) + self.pooler = BertPooler(hidden_size) + self.initializer_range = initializer_range + + self.apply(self.init_bert_weights) + + def init_bert_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + encoded_layers = self.encoder(embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers) + sequence_output = encoded_layers[-1] + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + @classmethod + def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): + # Load config + config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) + config = json.load(open(config_file, "r")) + # config = BertConfig.from_json_file(config_file) + # logger.info("Model config {}".format(config)) + # Instantiate model. 
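+        # The config dict read from bert_config.json is expanded into the
+        # constructor's keyword arguments, together with any extra
+        # *inputs / **kwargs passed to from_pretrained.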
+ model = cls(*inputs, **config, **kwargs) + if state_dict is None: + weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS) + state_dict = torch.load(weights_path) + + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(model, prefix='' if hasattr(model, 'bert') else 'bert.') + if len(missing_keys) > 0: + print("Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + print("Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys)) + return model diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py index ef3f3fe5..ef9d159d 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/aggregator/attention.py @@ -4,8 +4,8 @@ import torch import torch.nn.functional as F from torch import nn -from fastNLP.modules.utils import mask_softmax from fastNLP.modules.dropout import TimestepDropout +from fastNLP.modules.utils import mask_softmax class Attention(torch.nn.Module): @@ -49,27 +49,27 @@ class DotAtte(nn.Module): class MultiHeadAtte(nn.Module): - def __init__(self, model_size, key_size, value_size, num_head, dropout=0.1): + def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): """ - :param model_size: int, 输入维度的大小。同时也是输出维度的大小。 + :param input_size: int, 输入维度的大小。同时也是输出维度的大小。 :param key_size: int, 每个head的维度大小。 :param value_size: int,每个head中value的维度。 :param num_head: int,head的数量。 :param dropout: float。 """ super(MultiHeadAtte, self).__init__() - self.input_size = model_size + self.input_size = input_size self.key_size = key_size self.value_size = value_size self.num_head = num_head in_size = key_size * num_head - self.q_in = nn.Linear(model_size, in_size) - self.k_in = nn.Linear(model_size, in_size) - self.v_in = nn.Linear(model_size, in_size) + self.q_in = nn.Linear(input_size, in_size) + self.k_in = nn.Linear(input_size, in_size) + self.v_in = nn.Linear(input_size, in_size) self.attention = DotAtte(key_size=key_size, value_size=value_size) - self.out = nn.Linear(value_size * num_head, model_size) + self.out = nn.Linear(value_size * num_head, input_size) self.drop = TimestepDropout(dropout) self.reset_parameters() @@ -108,6 +108,7 @@ class MultiHeadAtte(nn.Module): output = self.drop(self.out(atte)) return output + class Bi_Attention(nn.Module): def __init__(self): super(Bi_Attention, self).__init__() diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py index 92ccc3fe..fe716bf7 100644 --- a/fastNLP/modules/encoder/transformer.py +++ 
b/fastNLP/modules/encoder/transformer.py @@ -1,4 +1,3 @@ -import torch from torch import nn from ..aggregator.attention import MultiHeadAtte diff --git a/test/models/test_bert.py b/test/models/test_bert.py new file mode 100644 index 00000000..b2899a89 --- /dev/null +++ b/test/models/test_bert.py @@ -0,0 +1,21 @@ +import unittest + +import torch + +from fastNLP.models.bert import BertModel + + +class TestBert(unittest.TestCase): + def test_bert_1(self): + # model = BertModel.from_pretrained("/home/zyfeng/data/bert-base-chinese") + model = BertModel(vocab_size=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + for layer in all_encoder_layers: + self.assertEqual(tuple(layer.shape), (2, 3, 768)) + self.assertEqual(tuple(pooled_output.shape), (2, 768))
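
Usage note: a minimal sketch of how the new module is expected to be used, assuming a local checkpoint directory (the path below is hypothetical) that contains bert_config.json and pytorch_model.bin as required by from_pretrained:

    import torch

    from fastNLP.models.bert import BertModel

    # Hypothetical local directory holding bert_config.json and pytorch_model.bin.
    model = BertModel.from_pretrained("/path/to/bert-base-chinese")

    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

    # With output_all_encoded_layers=False only the last encoder layer is
    # returned, together with the pooled representation of the first token.
    last_layer, pooled = model(input_ids, token_type_ids, input_mask,
                               output_all_encoded_layers=False)
    print(last_layer.shape)  # (batch_size, seq_len, hidden_size)
    print(pooled.shape)      # (batch_size, hidden_size)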