|
|
@@ -18,13 +18,13 @@ import torch |
|
|
|
from torch import nn |
|
|
|
|
|
|
|
from ..utils import _get_file_name_base_on_postfix |
|
|
|
from ...io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR |
|
|
|
from ...core import logger |
|
|
|
|
|
|
|
CONFIG_FILE = 'bert_config.json' |
|
|
|
VOCAB_NAME = 'vocab.txt' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BertConfig(object): |
|
|
|
"""Configuration class to store the configuration of a `BertModel`. |
|
|
|
""" |
|
|
@@ -133,6 +133,19 @@ def swish(x): |
|
|
|
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} |
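# A minimal usage sketch (illustrative, not part of the patch): the config
# stores the activation as a string such as "gelu", and ACT2FN resolves it
# to the callable defined above.
#
#     import torch
#     x = torch.randn(2, 8)
#     act_fn = ACT2FN["gelu"]  # look the activation up by its config-file name
#     y = act_fn(x)            # "relu" and "swish" resolve the same way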
|
|
|
|
|
|
|
|
|
|
|
def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'): |
|
|
|
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: |
|
|
|
model_url = _get_embedding_url('bert', model_dir_or_name.lower()) |
|
|
|
model_dir = cached_path(model_url, name='embedding') |
|
|
|
# check whether the path exists as a local directory
|
|
|
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): |
|
|
|
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) |
|
|
|
else: |
|
|
|
logger.error(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.") |
|
|
|
raise ValueError(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.") |
|
|
|
return model_dir |
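# A usage sketch (illustrative): a registered name is downloaded and cached,
# while an existing local path is expanded and used as-is. 'en-base-uncased'
# is the function's own default and is assumed to be a key in
# PRETRAINED_BERT_MODEL_DIR.
#
#     model_dir = _get_bert_dir('en-base-uncased')  # fetched via cached_path
#     model_dir = _get_bert_dir('~/bert/chinese')   # hypothetical local directory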
|
|
|
|
|
|
|
|
|
|
|
class BertLayerNorm(nn.Module): |
|
|
|
def __init__(self, hidden_size, eps=1e-12): |
|
|
|
"""Construct a layernorm module in the TF style (epsilon inside the square root). |
|
|
@@ -339,27 +352,9 @@ class BertModel(nn.Module): |
|
|
|
|
|
|
|
BERT (Bidirectional Encoder Representations from Transformers).
|
|
|
|
|
|
|
If you want to use pretrained weights, download them from the URLs below.
|
|
|
sources:: |
|
|
|
|
|
|
|
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin", |
|
|
|
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin", |
|
|
|
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin", |
|
|
|
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin", |
|
|
|
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin", |
|
|
|
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin", |
|
|
|
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin", |
|
|
|
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-pytorch_model.bin", |
|
|
|
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin", |
|
|
|
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin", |
|
|
|
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin", |
|
|
|
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin", |
|
|
|
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin" |
|
|
|
|
|
|
|
|
|
|
|
Build a BERT model from pretrained weights::
|
|
|
|
|
|
|
model = BertModel.from_pretrained("path/to/weights/directory") |
|
|
|
model = BertModel.from_pretrained(model_dir_or_name) |
|
|
|
|
|
|
|
Build a BERT model with randomly initialized weights::
|
|
|
|
|
|
@@ -440,11 +435,15 @@ class BertModel(nn.Module): |
|
|
|
return encoded_layers, pooled_output |
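# A usage sketch (illustrative) of the forward pass above, with toy token ids:
#
#     input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
#     encoded_layers, pooled_output = model(input_ids, output_all_encoded_layers=True)
#     # encoded_layers: per-layer hidden states; pooled_output: projected [CLS] state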
|
|
|
|
|
|
|
@classmethod |
|
|
|
def from_pretrained(cls, pretrained_model_dir, *inputs, **kwargs): |
|
|
|
def from_pretrained(cls, pretrained_model_dir_or_name, *inputs, **kwargs): |
|
|
|
state_dict = kwargs.get('state_dict', None) |
|
|
|
kwargs.pop('state_dict', None) |
|
|
|
kwargs.pop('cache_dir', None) |
|
|
|
kwargs.pop('from_tf', None) |
|
|
|
|
|
|
|
# resolve the model dir from a registered name or a local path
|
|
|
pretrained_model_dir = _get_bert_dir(pretrained_model_dir_or_name) |
|
|
|
|
|
|
|
# Load config |
|
|
|
config_file = _get_file_name_base_on_postfix(pretrained_model_dir, '.json') |
|
|
|
config = BertConfig.from_json_file(config_file) |
|
|
@@ -493,6 +492,8 @@ class BertModel(nn.Module): |
|
|
|
if len(unexpected_keys) > 0: |
|
|
|
logger.warning("Weights from pretrained model not used in {}: {}".format(
|
|
|
model.__class__.__name__, unexpected_keys)) |
|
|
|
|
|
|
|
logger.info(f"Load pre-trained BERT parameters from dir {pretrained_model_dir}.") |
|
|
|
return model |
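# A usage sketch (illustrative): with the renamed argument, both a registered
# name and a local directory now resolve through _get_bert_dir.
#
#     model = BertModel.from_pretrained('en-base-uncased')    # by registered name
#     model = BertModel.from_pretrained('/path/to/bert_dir')  # by local directory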
|
|
|
|
|
|
|
|
|
|
@@ -562,7 +563,7 @@ class WordpieceTokenizer(object): |
|
|
|
output_tokens.append(self.unk_token) |
|
|
|
else: |
|
|
|
output_tokens.extend(sub_tokens) |
|
|
|
if len(output_tokens)==0: # guard against input that is all whitespace or newline characters
|
|
|
if len(output_tokens) == 0: # guard against input that is all whitespace or newline characters
|
|
|
return [self.unk_token] |
|
|
|
return output_tokens |
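# A usage sketch (illustrative), reusing the classic example from the original
# BERT tokenizer docs; the exact output depends on the loaded vocab:
#
#     tokenizer = WordpieceTokenizer(vocab=vocab)
#     tokenizer.tokenize("unaffable")  # -> ["un", "##aff", "##able"]
#     tokenizer.tokenize("  \n")       # -> ["[UNK]"], thanks to the guard above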
|
|
|
|
|
|
@@ -673,14 +674,14 @@ class BasicTokenizer(object): |
|
|
|
# as is Japanese Hiragana and Katakana. Those alphabets are used to write |
|
|
|
# space-separated words, so they are not treated specially and handled |
|
|
|
# like all of the other languages.
|
|
|
if ((cp >= 0x4E00 and cp <= 0x9FFF) or # |
|
|
|
(cp >= 0x3400 and cp <= 0x4DBF) or # |
|
|
|
(cp >= 0x20000 and cp <= 0x2A6DF) or # |
|
|
|
(cp >= 0x2A700 and cp <= 0x2B73F) or # |
|
|
|
(cp >= 0x2B740 and cp <= 0x2B81F) or # |
|
|
|
(cp >= 0x2B820 and cp <= 0x2CEAF) or |
|
|
|
(cp >= 0xF900 and cp <= 0xFAFF) or # |
|
|
|
(cp >= 0x2F800 and cp <= 0x2FA1F)): # |
|
|
|
if (((cp >= 0x4E00) and (cp <= 0x9FFF)) or # |
|
|
|
((cp >= 0x3400) and (cp <= 0x4DBF)) or # |
|
|
|
((cp >= 0x20000) and (cp <= 0x2A6DF)) or # |
|
|
|
((cp >= 0x2A700) and (cp <= 0x2B73F)) or # |
|
|
|
((cp >= 0x2B740) and (cp <= 0x2B81F)) or # |
|
|
|
((cp >= 0x2B820) and (cp <= 0x2CEAF)) or |
|
|
|
((cp >= 0xF900) and (cp <= 0xFAFF)) or # |
|
|
|
((cp >= 0x2F800) and (cp <= 0x2FA1F))): # |
|
|
|
return True |
|
|
|
|
|
|
|
return False |
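# Illustrative checks against the ranges above:
#
#     _is_chinese_char(ord('中'))   # True: U+4E2D lies in 0x4E00-0x9FFF
#     _is_chinese_char(ord('a'))    # False: ASCII is outside every CJK block
#     _is_chinese_char(ord('あ'))   # False: Hiragana is intentionally excluded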
|
|
@@ -730,8 +731,8 @@ def _is_punctuation(char): |
|
|
|
# Characters such as "^", "$", and "`" are not in the Unicode |
|
|
|
# Punctuation class but we treat them as punctuation anyways, for |
|
|
|
# consistency. |
|
|
|
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or |
|
|
|
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): |
|
|
|
if (((cp >= 33) and (cp <= 47)) or ((cp >= 58) and (cp <= 64)) or |
|
|
|
((cp >= 91) and (cp <= 96)) or ((cp >= 123) and (cp <= 126))): |
|
|
|
return True |
|
|
|
cat = unicodedata.category(char) |
|
|
|
if cat.startswith("P"): |
|
|
@@ -830,11 +831,11 @@ class BertTokenizer(object): |
|
|
|
return vocab_file |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def from_pretrained(cls, model_dir, *inputs, **kwargs): |
|
|
|
def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): |
|
|
|
""" |
|
|
|
Given a path, read the vocab directly.
|
|
|
|
|
|
|
Given a model name or a path, read the vocab directly.
|
|
|
""" |
|
|
|
model_dir = _get_bert_dir(model_dir_or_name) |
|
|
|
pretrained_model_name_or_path = _get_file_name_base_on_postfix(model_dir, '.txt') |
|
|
|
logger.info("loading vocabulary file {}".format(pretrained_model_name_or_path)) |
|
|
|
max_len = 512 |
|
|
@@ -843,17 +844,19 @@ class BertTokenizer(object): |
|
|
|
tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs) |
|
|
|
return tokenizer |
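# A usage sketch (illustrative): the tokenizer now accepts either form.
#
#     tokenizer = BertTokenizer.from_pretrained('en-base-uncased')
#     tokens = tokenizer.tokenize("A quick example.")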
|
|
|
|
|
|
|
|
|
|
|
class _WordPieceBertModel(nn.Module): |
|
|
|
""" |
|
|
|
This module computes word_piece-level results directly.
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, model_dir: str, layers: str = '-1', pooled_cls:bool=False): |
|
|
|
def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False): |
|
|
|
super().__init__() |
|
|
|
|
|
|
|
self.tokenzier = BertTokenizer.from_pretrained(model_dir) |
|
|
|
self.encoder = BertModel.from_pretrained(model_dir) |
|
|
|
self.model_dir = _get_bert_dir(model_dir_or_name) |
|
|
|
self.tokenzier = BertTokenizer.from_pretrained(self.model_dir) |
|
|
|
self.encoder = BertModel.from_pretrained(self.model_dir) |
|
|
|
# check that the requested layers are valid for encoder_layer_number
|
|
|
encoder_layer_number = len(self.encoder.encoder.layer) |
|
|
|
self.layers = list(map(int, layers.split(','))) |
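# Illustrative: layers='-1' (the default) selects only the last encoder layer;
# a spec such as layers='0,-1' parses to [0, -1], i.e. the first and the last
# of the encoder_layer_number layers.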
|
|
@@ -914,7 +917,7 @@ class _WordPieceBertModel(nn.Module): |
|
|
|
|
|
|
|
attn_masks = word_pieces.ne(self._wordpiece_pad_index) |
|
|
|
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, |
|
|
|
output_all_encoded_layers=True) |
|
|
|
output_all_encoded_layers=True) |
|
|
|
# output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size |
|
|
|
outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1))) |
|
|
|
for l_index, l in enumerate(self.layers): |
|
|
|