diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index b1b1a200..e15c15f5 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -18,7 +18,7 @@
 from itertools import chain
 from ..core.vocabulary import Vocabulary
 from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
-from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
+from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer, _get_bert_dir
 from .contextual_embedding import ContextualEmbedding
 import warnings
 from ..core import logger
@@ -70,19 +70,16 @@ class BertEmbedding(ContextualEmbedding):
                  pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False,
                  pooled_cls=True, requires_grad: bool = False, auto_truncate: bool = False):
         super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
-        # 根据model_dir_or_name检查是否存在并下载
+
         if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
             if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'):
+                logger.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
+                            " faster speed.")
                 warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
                               " faster speed.")
-            model_url = _get_embedding_url('bert', model_dir_or_name.lower())
-            model_dir = cached_path(model_url, name='embedding')
-        # 检查是否存在
-        elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
-            model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
-        else:
-            raise ValueError(f"Cannot recognize {model_dir_or_name}.")
+
+        # 根据model_dir_or_name检查是否存在并下载
+        model_dir = _get_bert_dir(model_dir_or_name)
 
         self._word_sep_index = None
         if '[SEP]' in vocab:
@@ -173,15 +170,9 @@ class BertWordPieceEncoder(nn.Module):
     def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
                  word_dropout=0, dropout=0, requires_grad: bool = False):
         super().__init__()
-
-        if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
-            model_url = _get_embedding_url('bert', model_dir_or_name.lower())
-            model_dir = cached_path(model_url, name='embedding')
-        # 检查是否存在
-        elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
-            model_dir = model_dir_or_name
-        else:
-            raise ValueError(f"Cannot recognize {model_dir_or_name}.")
+
+        # 根据model_dir_or_name检查是否存在并下载
+        model_dir = _get_bert_dir(model_dir_or_name)
 
         self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls)
         self._sep_index = self.model._sep_index
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index 5026f48a..89a1b09d 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -18,13 +18,13 @@
 import torch
 from torch import nn
 
 from ..utils import _get_file_name_base_on_postfix
+from ...io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from ...core import logger
 
 CONFIG_FILE = 'bert_config.json'
 VOCAB_NAME = 'vocab.txt'
 
-
 class BertConfig(object):
     """Configuration class to store the configuration of a `BertModel`.
     """
@@ -133,6 +133,19 @@ def swish(x):
 ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
 
 
+def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'):
+    if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
+        model_url = _get_embedding_url('bert', model_dir_or_name.lower())
+        model_dir = cached_path(model_url, name='embedding')
+    # 检查是否存在
+    elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
+        model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
+    else:
+        logger.error(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
+        raise ValueError(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
+    return model_dir
+
+
 class BertLayerNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-12):
         """Construct a layernorm module in the TF style (epsilon inside the square root).
@@ -339,27 +352,9 @@ class BertModel(nn.Module):
 
     BERT(Bidirectional Embedding Representations from Transformers).
 
-    如果你想使用预训练好的权重矩阵,请在以下网址下载.
-    sources::
-
-        'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
-        'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
-        'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
-        'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
-        'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
-        'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
-        'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
-        'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-pytorch_model.bin",
-        'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
-        'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
-        'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
-        'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
-        'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin"
-
-
     用预训练权重矩阵来建立BERT模型::
 
-        model = BertModel.from_pretrained("path/to/weights/directory")
+        model = BertModel.from_pretrained(model_dir_or_name)
 
     用随机初始化权重矩阵来建立BERT模型::
 
@@ -440,11 +435,15 @@ class BertModel(nn.Module):
         return encoded_layers, pooled_output
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_dir, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_dir_or_name, *inputs, **kwargs):
         state_dict = kwargs.get('state_dict', None)
         kwargs.pop('state_dict', None)
         kwargs.pop('cache_dir', None)
         kwargs.pop('from_tf', None)
+
+        # get model dir from name or dir
+        pretrained_model_dir = _get_bert_dir(pretrained_model_dir_or_name)
+
         # Load config
         config_file = _get_file_name_base_on_postfix(pretrained_model_dir,
                                                      '.json')
         config = BertConfig.from_json_file(config_file)
@@ -493,6 +492,8 @@ class BertModel(nn.Module):
         if len(unexpected_keys) > 0:
             logger.warn("Weights from pretrained model not used in {}: {}".format(
                 model.__class__.__name__, unexpected_keys))
+
+        logger.info(f"Load pre-trained BERT parameters from dir {pretrained_model_dir}.")
         return model
 
@@ -562,7 +563,7 @@ class WordpieceTokenizer(object):
                 output_tokens.append(self.unk_token)
             else:
                 output_tokens.extend(sub_tokens)
-        if len(output_tokens)==0: #防止里面全是空格或者回车符号
+        if len(output_tokens) == 0:  # 防止里面全是空格或者回车符号
             return [self.unk_token]
         return output_tokens
@@ -673,14 +674,14 @@ class BasicTokenizer(object):
         # as is Japanese Hiragana and Katakana. Those alphabets are used to write
         # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
-        if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
-                (cp >= 0x3400 and cp <= 0x4DBF) or  #
-                (cp >= 0x20000 and cp <= 0x2A6DF) or  #
-                (cp >= 0x2A700 and cp <= 0x2B73F) or  #
-                (cp >= 0x2B740 and cp <= 0x2B81F) or  #
-                (cp >= 0x2B820 and cp <= 0x2CEAF) or
-                (cp >= 0xF900 and cp <= 0xFAFF) or  #
-                (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
+        if (((cp >= 0x4E00) and (cp <= 0x9FFF)) or  #
+                ((cp >= 0x3400) and (cp <= 0x4DBF)) or  #
+                ((cp >= 0x20000) and (cp <= 0x2A6DF)) or  #
+                ((cp >= 0x2A700) and (cp <= 0x2B73F)) or  #
+                ((cp >= 0x2B740) and (cp <= 0x2B81F)) or  #
+                ((cp >= 0x2B820) and (cp <= 0x2CEAF)) or
+                ((cp >= 0xF900) and (cp <= 0xFAFF)) or  #
+                ((cp >= 0x2F800) and (cp <= 0x2FA1F))):  #
             return True
 
         return False
@@ -730,8 +731,8 @@ def _is_punctuation(char):
     # Characters such as "^", "$", and "`" are not in the Unicode
     # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
-    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
-            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+    if (((cp >= 33) and (cp <= 47)) or ((cp >= 58) and (cp <= 64)) or
+            ((cp >= 91) and (cp <= 96)) or ((cp >= 123) and (cp <= 126))):
         return True
     cat = unicodedata.category(char)
     if cat.startswith("P"):
@@ -830,11 +831,11 @@ class BertTokenizer(object):
         return vocab_file
 
     @classmethod
-    def from_pretrained(cls, model_dir, *inputs, **kwargs):
+    def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs):
         """
-        给定path,直接读取vocab.
-
+        给定模型的名字或者路径,直接读取vocab.
         """
+        model_dir = _get_bert_dir(model_dir_or_name)
         pretrained_model_name_or_path = _get_file_name_base_on_postfix(model_dir, '.txt')
         logger.info("loading vocabulary file {}".format(pretrained_model_name_or_path))
         max_len = 512
@@ -843,17 +844,19 @@ class BertTokenizer(object):
         tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs)
         return tokenizer
 
+
 class _WordPieceBertModel(nn.Module):
     """
     这个模块用于直接计算word_piece的结果.
     """
 
-    def __init__(self, model_dir: str, layers: str = '-1', pooled_cls:bool=False):
+    def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False):
         super().__init__()
 
-        self.tokenzier = BertTokenizer.from_pretrained(model_dir)
-        self.encoder = BertModel.from_pretrained(model_dir)
+        self.model_dir = _get_bert_dir(model_dir_or_name)
+        self.tokenzier = BertTokenizer.from_pretrained(self.model_dir)
+        self.encoder = BertModel.from_pretrained(self.model_dir)
         # 检查encoder_layer_number是否合理
         encoder_layer_number = len(self.encoder.encoder.layer)
         self.layers = list(map(int, layers.split(',')))
@@ -914,7 +917,7 @@ class _WordPieceBertModel(nn.Module):
 
         attn_masks = word_pieces.ne(self._wordpiece_pad_index)
         bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
-                                    output_all_encoded_layers=True)
+                                                output_all_encoded_layers=True)
         # output_layers = [self.layers]  # len(self.layers) x batch_size x max_word_piece_length x hidden_size
         outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1)))
         for l_index, l in enumerate(self.layers):