From 0908c736ebc1a2afb9c36c908391943b08a45e95 Mon Sep 17 00:00:00 2001
From: Yige Xu
Date: Thu, 29 Aug 2019 16:38:17 +0800
Subject: [PATCH] fix code in BertModel.from_pretrained and BertEmbedding

---
 fastNLP/embeddings/bert_embedding.py | 20 +++++++-------------
 fastNLP/modules/encoder/bert.py      | 12 +++++++-----
 2 files changed, 14 insertions(+), 18 deletions(-)

diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index e15c15f5..d1a5514a 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -17,8 +17,8 @@ import numpy as np
 from itertools import chain
 
 from ..core.vocabulary import Vocabulary
-from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
-from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer, _get_bert_dir
+from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR
+from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
 from .contextual_embedding import ContextualEmbedding
 import warnings
 from ..core import logger
@@ -77,15 +77,12 @@ class BertEmbedding(ContextualEmbedding):
                                " faster speed.")
                 warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
                               " faster speed.")
-
-        # check whether model_dir_or_name exists locally and download it if not
-        model_dir = _get_bert_dir(model_dir_or_name)
 
         self._word_sep_index = None
         if '[SEP]' in vocab:
             self._word_sep_index = vocab['[SEP]']
 
-        self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
+        self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
                                     pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2)
 
@@ -170,11 +167,8 @@ class BertWordPieceEncoder(nn.Module):
     def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
                  word_dropout=0, dropout=0, requires_grad: bool = False):
         super().__init__()
-
-        # check whether model_dir_or_name exists locally and download it if not
-        model_dir = _get_bert_dir(model_dir_or_name)
 
-        self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls)
+        self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
         self._sep_index = self.model._sep_index
         self._wordpiece_pad_index = self.model._wordpiece_pad_index
         self._wordpiece_unk_index = self.model._wordpiece_unknown_index
@@ -269,12 +263,12 @@ class _WordBertModel(nn.Module):
-    def __init__(self, model_dir: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
+    def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
                  include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
         super().__init__()
-        self.tokenzier = BertTokenizer.from_pretrained(model_dir)
-        self.encoder = BertModel.from_pretrained(model_dir)
+        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = BertModel.from_pretrained(model_dir_or_name)
         self._max_position_embeddings = self.encoder.config.max_position_embeddings
         # check that encoder_layer_number is valid
         encoder_layer_number = len(self.encoder.encoder.layer)
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index 89a1b09d..e73a8172 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -143,7 +143,7 @@ def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'):
     else:
         logger.error(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
         raise ValueError(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
-    return model_dir
+    return str(model_dir)
 
 
 class BertLayerNorm(nn.Module):
@@ -453,6 +453,9 @@ class BertModel(nn.Module):
         if state_dict is None:
             weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin')
             state_dict = torch.load(weights_path, map_location='cpu')
+        else:
+            logger.error(f'Cannot load parameters through `state_dict` variable.')
+            raise RuntimeError(f'Cannot load parameters through `state_dict` variable.')
 
         old_keys = []
         new_keys = []
@@ -493,7 +496,7 @@ class BertModel(nn.Module):
             logger.warn("Weights from pretrained model not used in {}: {}".format(
                 model.__class__.__name__, unexpected_keys))
 
-        logger.info(f"Load pre-trained BERT parameters from dir {pretrained_model_dir}.")
+        logger.info(f"Load pre-trained BERT parameters from file {weights_path}.")
         return model
 
@@ -854,9 +857,8 @@ class _WordPieceBertModel(nn.Module):
     def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False):
         super().__init__()
 
-        self.model_dir = _get_bert_dir(model_dir_or_name)
-        self.tokenzier = BertTokenizer.from_pretrained(self.model_dir)
-        self.encoder = BertModel.from_pretrained(self.model_dir)
+        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = BertModel.from_pretrained(model_dir_or_name)
         # check that encoder_layer_number is valid
         encoder_layer_number = len(self.encoder.encoder.layer)
         self.layers = list(map(int, layers.split(',')))
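
After this change, `model_dir_or_name` flows straight through to `BertTokenizer.from_pretrained` and `BertModel.from_pretrained`, which resolve a registered model name or a local directory (downloading if necessary) via `_get_bert_dir`. A minimal usage sketch against fastNLP's public API; the sample vocabulary below is illustrative and not taken from this patch:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    # build a tiny vocabulary; real code would populate it from a dataset
    vocab = Vocabulary()
    vocab.add_word_lst("this is a test .".split())

    # a registered name such as 'en-base-uncased' is resolved (and downloaded
    # on first use) inside from_pretrained; a local directory containing the
    # config, weights, and vocab files also works
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1')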