
fix code in BertModel.from_pretrained and BertEmbedding

tags/v0.4.10
Yige Xu, 5 years ago
parent commit 0908c736eb
2 changed files with 14 additions and 18 deletions:
  1. fastNLP/embeddings/bert_embedding.py (+7, -13)
  2. fastNLP/modules/encoder/bert.py (+7, -5)

fastNLP/embeddings/bert_embedding.py (+7, -13)

@@ -17,8 +17,8 @@ import numpy as np
 from itertools import chain
 
 from ..core.vocabulary import Vocabulary
-from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
-from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer, _get_bert_dir
+from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR
+from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
 from .contextual_embedding import ContextualEmbedding
 import warnings
 from ..core import logger
@@ -77,15 +77,12 @@ class BertEmbedding(ContextualEmbedding):
                          " faster speed.")
             warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
                           " faster speed.")
-
-        # Check whether model_dir_or_name exists locally and download it if not
-        model_dir = _get_bert_dir(model_dir_or_name)
         self._word_sep_index = None
         if '[SEP]' in vocab:
             self._word_sep_index = vocab['[SEP]']
-        self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
+        self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
                                     pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2)
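
The public BertEmbedding signature is unchanged by this hunk; only the internal directory resolution moved down into the model constructors. A minimal usage sketch, assuming 'en-base-uncased' is downloadable or already in the local cache:

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("this is a test .".split())

    # model_dir_or_name is now forwarded as-is to _WordBertModel, whose
    # from_pretrained calls take care of locating/downloading the weights.
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1')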
@@ -170,11 +167,8 @@ class BertWordPieceEncoder(nn.Module):
     def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
                  word_dropout=0, dropout=0, requires_grad: bool = False):
         super().__init__()
-
-        # Check whether model_dir_or_name exists locally and download it if not
-        model_dir = _get_bert_dir(model_dir_or_name)
-        self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls)
+        self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
         self._sep_index = self.model._sep_index
         self._wordpiece_pad_index = self.model._wordpiece_pad_index
         self._wordpiece_unk_index = self.model._wordpiece_unknown_index
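
BertWordPieceEncoder follows the same pattern; a minimal sketch under the same caching assumption:

    from fastNLP.embeddings import BertWordPieceEncoder

    # No more resolve-then-construct dance: the raw name goes straight to
    # _WordPieceBertModel, and the special-token indices are read back from it.
    encoder = BertWordPieceEncoder(model_dir_or_name='en-base-uncased', layers='-1')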
@@ -269,12 +263,12 @@ class BertWordPieceEncoder(nn.Module):
 
 
 class _WordBertModel(nn.Module):
-    def __init__(self, model_dir: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
+    def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
                  include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
         super().__init__()
-        self.tokenzier = BertTokenizer.from_pretrained(model_dir)
-        self.encoder = BertModel.from_pretrained(model_dir)
+        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = BertModel.from_pretrained(model_dir_or_name)
         self._max_position_embeddings = self.encoder.config.max_position_embeddings
         # Check that encoder_layer_number is in a valid range
         encoder_layer_number = len(self.encoder.encoder.layer)


fastNLP/modules/encoder/bert.py (+7, -5)

@@ -143,7 +143,7 @@ def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'):
     else:
         logger.error(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
         raise ValueError(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.")
-    return model_dir
+    return str(model_dir)
 
 
 class BertLayerNorm(nn.Module):
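
The str() cast matters because the resolved directory can come back from the download cache as a pathlib.Path, while downstream code treats the return value as a plain string. A small illustration of the mismatch (the path here is made up):

    from pathlib import Path

    model_dir = Path('/tmp') / 'bert-base-uncased'
    # model_dir + '/vocab.txt'              # TypeError: Path does not support + with str
    print(str(model_dir) + '/vocab.txt')    # works once cast to str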
@@ -453,6 +453,9 @@ class BertModel(nn.Module):
         if state_dict is None:
             weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin')
             state_dict = torch.load(weights_path, map_location='cpu')
+        else:
+            logger.error(f'Cannot load parameters through `state_dict` variable.')
+            raise RuntimeError(f'Cannot load parameters through `state_dict` variable.')
 
         old_keys = []
         new_keys = []
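
With the new else branch, from_pretrained always reads the weights itself from the .bin file in the resolved directory, and a caller-supplied state_dict is rejected outright. An illustrative sketch (the state_dict keyword form is an assumption based on this hunk, not a documented API):

    from fastNLP.modules.encoder.bert import BertModel

    model = BertModel.from_pretrained('en-base-uncased')    # ok: loads the .bin itself
    # BertModel.from_pretrained('en-base-uncased',
    #                           state_dict=some_state)      # now raises RuntimeError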
@@ -493,7 +496,7 @@ class BertModel(nn.Module):
             logger.warn("Weights from pretrained model not used in {}: {}".format(
                 model.__class__.__name__, unexpected_keys))
 
-        logger.info(f"Load pre-trained BERT parameters from dir {pretrained_model_dir}.")
+        logger.info(f"Load pre-trained BERT parameters from file {weights_path}.")
         return model


@@ -854,9 +857,8 @@ class _WordPieceBertModel(nn.Module):
     def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False):
         super().__init__()
 
-        self.model_dir = _get_bert_dir(model_dir_or_name)
-        self.tokenzier = BertTokenizer.from_pretrained(self.model_dir)
-        self.encoder = BertModel.from_pretrained(self.model_dir)
+        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = BertModel.from_pretrained(model_dir_or_name)
         # Check that encoder_layer_number is in a valid range
         encoder_layer_number = len(self.encoder.encoder.layer)
         self.layers = list(map(int, layers.split(',')))
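
Note that self.model_dir is gone from _WordPieceBertModel after this hunk. Code that still needs the resolved directory can call the helper directly, which now returns a plain str; a hypothetical example (assuming the name can be resolved or downloaded):

    from fastNLP.modules.encoder.bert import _get_bert_dir

    model_dir = _get_bert_dir('en-base-uncased')  # plain str after this commit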

