
Update the BertModel.from_pretrained function: a model_dir_or_name can now be passed instead of a model_dir.
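A minimal usage sketch of the new behaviour (assuming 'en-base-uncased' is one of the shorthands registered in PRETRAINED_BERT_MODEL_DIR, and that a local directory contains vocab.txt, a *.json config file and the weights):

    from fastNLP.modules.encoder.bert import BertModel, BertTokenizer

    # a registered model name is resolved (and downloaded/cached) via the new _get_bert_dir helper
    model = BertModel.from_pretrained('en-base-uncased')
    tokenizer = BertTokenizer.from_pretrained('en-base-uncased')

    # a local directory with config, vocab and weights still works as before
    model = BertModel.from_pretrained('path/to/weights/directory')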

tags/v0.4.10
Yige Xu, 5 years ago
commit 39de27f472
2 changed files with 52 additions and 58 deletions
  1. fastNLP/embeddings/bert_embedding.py (+10, -19)
  2. fastNLP/modules/encoder/bert.py (+42, -39)

fastNLP/embeddings/bert_embedding.py (+10, -19)

@@ -18,7 +18,7 @@ from itertools import chain

from ..core.vocabulary import Vocabulary
from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer, _get_bert_dir
from .contextual_embedding import ContextualEmbedding
import warnings
from ..core import logger
@@ -70,19 +70,16 @@ class BertEmbedding(ContextualEmbedding):
pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False,
pooled_cls=True, requires_grad: bool = False, auto_truncate: bool = False):
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
# check whether model_dir_or_name exists and download it if needed

if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'):
logger.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
" faster speed.")
warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
" faster speed.")
model_url = _get_embedding_url('bert', model_dir_or_name.lower())
model_dir = cached_path(model_url, name='embedding')
# check whether the path exists
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
else:
raise ValueError(f"Cannot recognize {model_dir_or_name}.")

# check whether model_dir_or_name exists and download it if needed
model_dir = _get_bert_dir(model_dir_or_name)
self._word_sep_index = None
if '[SEP]' in vocab:
@@ -173,15 +170,9 @@ class BertWordPieceEncoder(nn.Module):
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
word_dropout=0, dropout=0, requires_grad: bool = False):
super().__init__()
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
model_url = _get_embedding_url('bert', model_dir_or_name.lower())
model_dir = cached_path(model_url, name='embedding')
# check whether the path exists
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
model_dir = model_dir_or_name
else:
raise ValueError(f"Cannot recognize {model_dir_or_name}.")

# check whether model_dir_or_name exists and download it if needed
model_dir = _get_bert_dir(model_dir_or_name)
self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls)
self._sep_index = self.model._sep_index
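
Both constructors in this file now resolve model_dir_or_name through the shared _get_bert_dir helper instead of duplicating the download/lookup logic. A rough sketch of how they are called after this change (assuming the constructor arguments not shown in the hunks above keep their existing defaults):

    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.embeddings.bert_embedding import BertEmbedding, BertWordPieceEncoder

    vocab = Vocabulary()
    vocab.add_word_lst("this is a test".split())

    # both a registered model name and a local directory are accepted
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased')
    encoder = BertWordPieceEncoder(model_dir_or_name='en-base-uncased')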


fastNLP/modules/encoder/bert.py (+42, -39)

@@ -18,13 +18,13 @@ import torch
from torch import nn

from ..utils import _get_file_name_base_on_postfix
from ...io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from ...core import logger

CONFIG_FILE = 'bert_config.json'
VOCAB_NAME = 'vocab.txt'



class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
@@ -133,6 +133,19 @@ def swish(x):
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'):
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
model_url = _get_embedding_url('bert', model_dir_or_name.lower())
model_dir = cached_path(model_url, name='embedding')
# check whether the path exists
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
else:
logger.error(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
raise ValueError(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
return model_dir


class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
@@ -339,27 +352,9 @@ class BertModel(nn.Module):

BERT(Bidirectional Embedding Representations from Transformers).

If you want to use pretrained weights, download them from the following URLs.
sources::

'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-pytorch_model.bin",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin"


Build a BERT model from pretrained weights::

model = BertModel.from_pretrained("path/to/weights/directory")
model = BertModel.from_pretrained(model_dir_or_name)

Build a BERT model with randomly initialized weights::

@@ -440,11 +435,15 @@ class BertModel(nn.Module):
return encoded_layers, pooled_output

@classmethod
def from_pretrained(cls, pretrained_model_dir, *inputs, **kwargs):
def from_pretrained(cls, pretrained_model_dir_or_name, *inputs, **kwargs):
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
kwargs.pop('cache_dir', None)
kwargs.pop('from_tf', None)

# get model dir from name or dir
pretrained_model_dir = _get_bert_dir(pretrained_model_dir_or_name)

# Load config
config_file = _get_file_name_base_on_postfix(pretrained_model_dir, '.json')
config = BertConfig.from_json_file(config_file)
@@ -493,6 +492,8 @@ class BertModel(nn.Module):
if len(unexpected_keys) > 0:
logger.warn("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))

logger.info(f"Load pre-trained BERT parameters from dir {pretrained_model_dir}.")
return model


@@ -562,7 +563,7 @@ class WordpieceTokenizer(object):
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
if len(output_tokens)==0: # guard against input that is all whitespace or newlines
if len(output_tokens) == 0: # guard against input that is all whitespace or newlines
return [self.unk_token]
return output_tokens

@@ -673,14 +674,14 @@ class BasicTokenizer(object):
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
if (((cp >= 0x4E00) and (cp <= 0x9FFF)) or #
((cp >= 0x3400) and (cp <= 0x4DBF)) or #
((cp >= 0x20000) and (cp <= 0x2A6DF)) or #
((cp >= 0x2A700) and (cp <= 0x2B73F)) or #
((cp >= 0x2B740) and (cp <= 0x2B81F)) or #
((cp >= 0x2B820) and (cp <= 0x2CEAF)) or
((cp >= 0xF900) and (cp <= 0xFAFF)) or #
((cp >= 0x2F800) and (cp <= 0x2FA1F))): #
return True

return False
@@ -730,8 +731,8 @@ def _is_punctuation(char):
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
if (((cp >= 33) and (cp <= 47)) or ((cp >= 58) and (cp <= 64)) or
((cp >= 91) and (cp <= 96)) or ((cp >= 123) and (cp <= 126))):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
@@ -830,11 +831,11 @@ class BertTokenizer(object):
return vocab_file

@classmethod
def from_pretrained(cls, model_dir, *inputs, **kwargs):
def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs):
"""
Given a path, read the vocab directly.

Given a model name or path, read the vocab directly.
"""
model_dir = _get_bert_dir(model_dir_or_name)
pretrained_model_name_or_path = _get_file_name_base_on_postfix(model_dir, '.txt')
logger.info("loading vocabulary file {}".format(pretrained_model_name_or_path))
max_len = 512
@@ -843,17 +844,19 @@ class BertTokenizer(object):
tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs)
return tokenizer


class _WordPieceBertModel(nn.Module):
"""
This module computes word_piece level results directly.

"""

def __init__(self, model_dir: str, layers: str = '-1', pooled_cls:bool=False):
def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False):
super().__init__()

self.tokenzier = BertTokenizer.from_pretrained(model_dir)
self.encoder = BertModel.from_pretrained(model_dir)
self.model_dir = _get_bert_dir(model_dir_or_name)
self.tokenzier = BertTokenizer.from_pretrained(self.model_dir)
self.encoder = BertModel.from_pretrained(self.model_dir)
# check that encoder_layer_number is valid
encoder_layer_number = len(self.encoder.encoder.layer)
self.layers = list(map(int, layers.split(',')))
@@ -914,7 +917,7 @@ class _WordPieceBertModel(nn.Module):

attn_masks = word_pieces.ne(self._wordpiece_pad_index)
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
output_all_encoded_layers=True)
output_all_encoded_layers=True)
# output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size
outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1)))
for l_index, l in enumerate(self.layers):

