From 7de69b60b896cee013564b254f21b245395da9d4 Mon Sep 17 00:00:00 2001 From: xuyige Date: Tue, 23 Apr 2019 17:56:16 +0800 Subject: [PATCH] update bert --- fastNLP/component/__init__.py | 1 + fastNLP/component/bert_tokenizer.py | 378 +++++++++++++++++ fastNLP/models/bert.py | 637 ++++++++++++---------------- fastNLP/modules/encoder/__init__.py | 4 +- fastNLP/modules/encoder/bert.py | 362 ++++++++++++++++ 5 files changed, 1027 insertions(+), 355 deletions(-) create mode 100644 fastNLP/component/__init__.py create mode 100644 fastNLP/component/bert_tokenizer.py create mode 100644 fastNLP/modules/encoder/bert.py diff --git a/fastNLP/component/__init__.py b/fastNLP/component/__init__.py new file mode 100644 index 00000000..c6784aef --- /dev/null +++ b/fastNLP/component/__init__.py @@ -0,0 +1 @@ +from .bert_tokenizer import BertTokenizer diff --git a/fastNLP/component/bert_tokenizer.py b/fastNLP/component/bert_tokenizer.py new file mode 100644 index 00000000..6354076d --- /dev/null +++ b/fastNLP/component/bert_tokenizer.py @@ -0,0 +1,378 @@ +""" +bert_tokenizer.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. +""" +import collections +import os +import unicodedata +from io import open + + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'bert-base-uncased': 512, + 'bert-large-uncased': 512, + 'bert-base-cased': 512, + 'bert-large-cased': 512, + 'bert-base-multilingual-uncased': 512, + 'bert-base-multilingual-cased': 512, + 'bert-base-chinese': 512, +} +VOCAB_NAME = 'vocab.txt' + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r", encoding="utf-8") as reader: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BertTokenizer. + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input + Only has an effect when do_wordpiece_only=False + do_basic_tokenize: Whether to do basic tokenization before wordpiece. 
+ max_len: An artificial maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of this + value (if specified) and the underlying BERT model's + sequence length. + never_split: List of tokens which will never be split during tokenization. + Only has an effect when do_wordpiece_only=False + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) + + def tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + if len(ids) > self.max_len: + print( + "WARNING!\n\"" + "Token indices sequence length is longer than the specified maximum " + "sequence length for this BERT model ({} > {}). Running this" + " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary to a directory or file.""" + index = 0 + if os.path.isdir(vocab_path): + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + print("Saving vocabulary to {}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!".format(vocab_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + return vocab_file + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): + print("The pre-trained model you are loading is a cased model but you have not set " + "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " + "you may want to check this behavior.") + kwargs['do_lower_case'] = False + elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): + print("The pre-trained model you are loading is an uncased model but you have set " + "`do_lower_case` to False. 
We are setting `do_lower_case=True` for you " + "but you may want to check this behavior.") + kwargs['do_lower_case'] = True + else: + vocab_file = pretrained_model_name_or_path + if os.path.isdir(vocab_file): + vocab_file = os.path.join(vocab_file, VOCAB_NAME) + # redirect to the cache, if necessary + resolved_vocab_file = vocab_file + print("loading vocabulary file {}".format(vocab_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) + return tokenizer + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BasicTokenizer. + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite 
its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. 
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py index e87f6f5d..98d81025 100644 --- a/fastNLP/models/bert.py +++ b/fastNLP/models/bert.py @@ -2,361 +2,290 @@ bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. """ -import copy -import json -import math -import os - import torch from torch import nn -CONFIG_FILE = 'bert_config.json' -MODEL_WEIGHTS = 'pytorch_model.bin' - - -def gelu(x): - return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) - - -def swish(x): - return x * torch.sigmoid(x) - - -ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} - - -class BertLayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-12): - super(BertLayerNorm, self).__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.bias = nn.Parameter(torch.zeros(hidden_size)) - self.variance_epsilon = eps - - def forward(self, x): - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.variance_epsilon) - return self.weight * x + self.bias - - -class BertEmbeddings(nn.Module): - def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): - super(BertEmbeddings, self).__init__() - self.word_embeddings = nn.Embedding(vocab_size, hidden_size) - self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) - self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(hidden_dropout_prob) - - def forward(self, input_ids, token_type_ids=None): - seq_length = input_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = words_embeddings + position_embeddings + token_type_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class BertSelfAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): - super(BertSelfAttention, self).__init__() - if hidden_size % num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (hidden_size, num_attention_heads)) - self.num_attention_heads = num_attention_heads - self.attention_head_size = int(hidden_size / num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(hidden_size, self.all_head_size) - self.key = nn.Linear(hidden_size, self.all_head_size) - self.value = nn.Linear(hidden_size, self.all_head_size) - - self.dropout = nn.Dropout(attention_probs_dropout_prob) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, 
self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states, attention_mask): - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.Softmax(dim=-1)(attention_scores) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - return context_layer - - -class BertSelfOutput(nn.Module): - def __init__(self, hidden_size, hidden_dropout_prob): - super(BertSelfOutput, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): - super(BertAttention, self).__init__() - self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) - self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) - - def forward(self, input_tensor, attention_mask): - self_output = self.self(input_tensor, attention_mask) - attention_output = self.output(self_output, input_tensor) - return attention_output - - -class BertIntermediate(nn.Module): - def __init__(self, hidden_size, intermediate_size, hidden_act): - super(BertIntermediate, self).__init__() - self.dense = nn.Linear(hidden_size, intermediate_size) - self.intermediate_act_fn = ACT2FN[hidden_act] \ - if isinstance(hidden_act, str) else hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - -class BertOutput(nn.Module): - def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): - super(BertOutput, self).__init__() - self.dense = nn.Linear(intermediate_size, hidden_size) - self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertLayer(nn.Module): - def 
__init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, - intermediate_size, hidden_act): - super(BertLayer, self).__init__() - self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, - hidden_dropout_prob) - self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) - self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) - - def forward(self, hidden_states, attention_mask): - attention_output = self.attention(hidden_states, attention_mask) - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - return layer_output - - -class BertEncoder(nn.Module): - def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, - hidden_dropout_prob, - intermediate_size, hidden_act): - super(BertEncoder, self).__init__() - layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, - intermediate_size, hidden_act) - self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) - - def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): - all_encoder_layers = [] - for layer_module in self.layer: - hidden_states = layer_module(hidden_states, attention_mask) - if output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - if not output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - return all_encoder_layers - - -class BertPooler(nn.Module): - def __init__(self, hidden_size): - super(BertPooler, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class BertModel(nn.Module): - """Bidirectional Embedding Representations from Transformers. - - If you want to use pre-trained weights, please download from the following sources provided by pytorch-pretrained-BERT. - sources:: - - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", - 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", - 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", - - - Construct a BERT model with pre-trained weights:: - - model = BertModel.from_pretrained("path/to/weights/directory") - +from .base_model import BaseModel +from fastNLP.modules.encoder import BertModel + + +class BertForSequenceClassification(BaseModel): + """BERT model for classification. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + Params: + `config`: a BertConfig class instance with the configuration to build a new model. 
+ `num_labels`: the number of classes for the classifier. Default = 2. + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token. (see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_labels]. + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + num_labels = 2 + model = BertForSequenceClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` """ - - def __init__(self, vocab_size, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, **kwargs): - super(BertModel, self).__init__() - self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, - type_vocab_size, hidden_dropout_prob) - self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, - attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, - hidden_act) - self.pooler = BertPooler(hidden_size) - self.initializer_range = initializer_range - - self.apply(self.init_bert_weights) - - def init_bert_weights(self, module): - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.initializer_range) - elif isinstance(module, BertLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - # We create a 3D attention mask from a 2D tensor mask. 
- # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - embedding_output = self.embeddings(input_ids, token_type_ids) - encoded_layers = self.encoder(embedding_output, - extended_attention_mask, - output_all_encoded_layers=output_all_encoded_layers) - sequence_output = encoded_layers[-1] - pooled_output = self.pooler(sequence_output) - if not output_all_encoded_layers: - encoded_layers = encoded_layers[-1] - return encoded_layers, pooled_output - - @classmethod - def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): - # Load config - config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) - config = json.load(open(config_file, "r")) - # config = BertConfig.from_json_file(config_file) - # logger.info("Model config {}".format(config)) - # Instantiate model. - model = cls(*inputs, **config, **kwargs) - if state_dict is None: - weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS) - state_dict = torch.load(weights_path) - - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if 'gamma' in key: - new_key = key.replace('gamma', 'weight') - if 'beta' in key: - new_key = key.replace('beta', 'bias') - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - - missing_keys = [] - unexpected_keys = [] - error_msgs = [] - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - def load(module, prefix=''): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - module._load_from_state_dict( - state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + '.') - - load(model, prefix='' if hasattr(model, 'bert') else 'bert.') - if len(missing_keys) > 0: - print("Weights of {} not initialized from pretrained model: {}".format( - model.__class__.__name__, missing_keys)) - if len(unexpected_keys) > 0: - print("Weights from pretrained model not used in {}: {}".format( - model.__class__.__name__, unexpected_keys)) - return model + def __init__(self, config, num_labels, bert_dir): + super(BertForSequenceClassification, self).__init__() + self.num_labels = num_labels + self.bert = BertModel.from_pretrained(bert_dir) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, num_labels) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, 
labels=None): + _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + return {"pred": logits, "loss": loss} + else: + return {"pred": logits} + + def predict(self, input_ids, token_type_ids=None, attention_mask=None): + logits = self.forward(input_ids, token_type_ids, attention_mask) + return {"pred": torch.argmax(logits, dim=-1)} + + +class BertForMultipleChoice(BaseModel): + """BERT model for multiple choice tasks. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_choices`: the number of classes for the classifier. Default = 2. + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` + and type 1 corresponds to a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_choices]. + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. 
+ Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) + input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) + token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + num_choices = 2 + model = BertForMultipleChoice(config, num_choices, bert_dir) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_choices, bert_dir): + super(BertForMultipleChoice, self).__init__() + self.num_choices = num_choices + self.bert = BertModel.from_pretrained(bert_dir) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) + _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, self.num_choices) + + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + return {"pred": reshaped_logits, "loss": loss} + else: + return {"pred": reshaped_logits} + + def predict(self, input_ids, token_type_ids=None, attention_mask=None): + logits = self.forward(input_ids, token_type_ids, attention_mask)["pred"] + return {"pred": torch.argmax(logits, dim=-1)} + + +class BertForTokenClassification(BaseModel): + """BERT model for token-level classification. + This module is composed of the BERT model with a linear layer on top of + the full hidden state of the last layer. + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_labels`: the number of classes for the classifier. Default = 2. + `bert_dir`: a dir which contains the bert parameters within file `pytorch_model.bin` + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [0, ..., num_labels]. + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. 
+ if `labels` is `None`: + Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + num_labels = 2 + bert_dir = 'your-bert-file-dir' + model = BertForTokenClassification(config, num_labels, bert_dir) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_labels, bert_dir): + super(BertForTokenClassification, self).__init__() + self.num_labels = num_labels + self.bert = BertModel.from_pretrained(bert_dir) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, num_labels) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels)[active_loss] + active_labels = labels.view(-1)[active_loss] + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + return {"pred": logits, "loss": loss} + else: + return {"pred": logits} + + def predict(self, input_ids, token_type_ids=None, attention_mask=None): + logits = self.forward(input_ids, token_type_ids, attention_mask)["pred"] + return {"pred": torch.argmax(logits, dim=-1)} + + +class BertForQuestionAnswering(BaseModel): + """BERT model for Question Answering (span extraction). + This module is composed of the BERT model with a linear layer on top of + the sequence output that computes start_logits and end_logits + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `bert_dir`: a dir which contains the bert parameters within file `pytorch_model.bin` + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. + Positions are clamped to the length of the sequence and position outside of the sequence are not taken + into account for computing the loss. 
+ `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. + Positions are clamped to the length of the sequence and position outside of the sequence are not taken + into account for computing the loss. + Outputs: + if `start_positions` and `end_positions` are not `None`: + Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. + if `start_positions` or `end_positions` is `None`: + Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end + position tokens of shape [batch_size, sequence_length]. + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + bert_dir = 'your-bert-file-dir' + model = BertForQuestionAnswering(config, bert_dir) + start_logits, end_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, bert_dir): + super(BertForQuestionAnswering, self).__init__() + self.bert = BertModel.from_pretrained(bert_dir) + # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version + # self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + return {"loss": total_loss} + else: + return {"pred1": start_logits, "pred2": end_logits} + + def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs): + logits = self.forward(input_ids, token_type_ids, attention_mask) + start_logits = logits["pred1"] + end_logits = logits["pred2"] + return {"pred1": torch.argmax(start_logits, dim=-1), "pred2": torch.argmax(end_logits, dim=-1)} diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py index b00a0ae9..56b9ca59 100644 --- a/fastNLP/modules/encoder/__init__.py +++ b/fastNLP/modules/encoder/__init__.py @@ -3,9 +3,11 @@ from .conv_maxpool import ConvMaxpool from .embedding import Embedding from .linear import Linear from .lstm import LSTM +from .bert import BertModel __all__ = ["LSTM", "Embedding", "Linear", "Conv", - "ConvMaxpool"] + "ConvMaxpool", 
+ "BertModel"] diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py new file mode 100644 index 00000000..e87f6f5d --- /dev/null +++ b/fastNLP/modules/encoder/bert.py @@ -0,0 +1,362 @@ +""" +bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. + +""" +import copy +import json +import math +import os + +import torch +from torch import nn + +CONFIG_FILE = 'bert_config.json' +MODEL_WEIGHTS = 'pytorch_model.bin' + + +def gelu(x): + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class BertEmbeddings(nn.Module): + def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): + super(BertSelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (hidden_size, num_attention_heads)) + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + 
mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfOutput(nn.Module): + def __init__(self, hidden_size, hidden_dropout_prob): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) + self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + def __init__(self, hidden_size, intermediate_size, hidden_act): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_act_fn = ACT2FN[hidden_act] \ + if isinstance(hidden_act, str) else hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): + super(BertOutput, self).__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, + intermediate_size, hidden_act): + super(BertLayer, self).__init__() + self.attention = 
BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob) + self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) + self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob, + intermediate_size, hidden_act): + super(BertEncoder, self).__init__() + layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, + intermediate_size, hidden_act) + self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): + all_encoder_layers = [] + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + def __init__(self, hidden_size): + super(BertPooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertModel(nn.Module): + """Bidirectional Embedding Representations from Transformers. + + If you want to use pre-trained weights, please download from the following sources provided by pytorch-pretrained-BERT. 
+ sources:: + + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", + + + Construct a BERT model with pre-trained weights:: + + model = BertModel.from_pretrained("path/to/weights/directory") + + """ + + def __init__(self, vocab_size, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, **kwargs): + super(BertModel, self).__init__() + self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, + type_vocab_size, hidden_dropout_prob) + self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, + attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, + hidden_act) + self.pooler = BertPooler(hidden_size) + self.initializer_range = initializer_range + + self.apply(self.init_bert_weights) + + def init_bert_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + encoded_layers = self.encoder(embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers) + sequence_output = encoded_layers[-1] + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + @classmethod + def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): + # Load config + config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) + config = json.load(open(config_file, "r")) + # config = BertConfig.from_json_file(config_file) + # logger.info("Model config {}".format(config)) + # Instantiate model. + model = cls(*inputs, **config, **kwargs) + if state_dict is None: + weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS) + state_dict = torch.load(weights_path) + + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(model, prefix='' if hasattr(model, 'bert') else 'bert.') + if len(missing_keys) > 0: + print("Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + print("Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys)) + return model
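
---

A minimal usage sketch of the new `fastNLP.component.BertTokenizer` added by this patch. The directory path is a hypothetical placeholder; it only needs to contain a standard BERT `vocab.txt` (one wordpiece per line), as `from_pretrained` expects.

```python
import torch  # not needed for tokenization itself, kept for parity with the model sketches below
from fastNLP.component import BertTokenizer

# Hypothetical local directory holding a BERT-style vocab.txt
tokenizer = BertTokenizer.from_pretrained("/path/to/bert-base-uncased", do_lower_case=True)

# Basic tokenization followed by greedy wordpiece splitting
tokens = tokenizer.tokenize("fastNLP wraps the BERT wordpiece tokenizer.")
ids = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokens + ["[SEP]"])
print(tokens)
print(ids)
```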
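A minimal sketch of the refactored `BertForSequenceClassification`, assuming a hypothetical pretrained directory that contains `bert_config.json` and `pytorch_model.bin` (the files `BertModel.from_pretrained` reads). The lightweight config object below is a stand-in, not part of this patch; it only carries the two attributes the classification head uses (`hidden_size`, `hidden_dropout_prob`).

```python
import types
import torch
from fastNLP.models.bert import BertForSequenceClassification

# Hypothetical directory with bert_config.json and pytorch_model.bin
bert_dir = "/path/to/bert-base-uncased"

# Stand-in config object; only the attributes read by the head are set
config = types.SimpleNamespace(hidden_size=768, hidden_dropout_prob=0.1)

model = BertForSequenceClassification(config, num_labels=2, bert_dir=bert_dir)

# Inputs already converted to WordPiece ids, as in the docstring examples above
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

# Without labels, forward returns {"pred": logits}
output = model(input_ids, token_type_ids, attention_mask)
print(output["pred"].shape)  # [batch_size, num_labels]
```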
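For reference, the additive attention-mask extension that `BertModel.forward` performs can be reproduced in isolation with the same arithmetic as in the patch: a `[batch, seq]` mask of 1s/0s becomes a `[batch, 1, 1, seq]` tensor of 0.0 / -10000.0 that broadcasts over heads and query positions before the softmax.

```python
import torch

attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

# [batch, seq] -> [batch, 1, 1, seq], then map 1 -> 0.0 (attend) and 0 -> -10000.0 (masked)
extended = attention_mask.unsqueeze(1).unsqueeze(2).to(torch.float32)
extended = (1.0 - extended) * -10000.0

print(extended.shape)  # torch.Size([2, 1, 1, 3])
print(extended[1, 0, 0])  # tensor([    -0.,     -0., -10000.])
```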