@@ -17,8 +17,8 @@ import numpy as np
 from itertools import chain
 
 from ..core.vocabulary import Vocabulary
-from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
-from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer, _get_bert_dir
+from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR
+from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
 from .contextual_embedding import ContextualEmbedding
 import warnings
 from ..core import logger
@@ -77,15 +77,12 @@ class BertEmbedding(ContextualEmbedding):
                                " faster speed.")
                 warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
                               " faster speed.")
 
-        # Check model_dir_or_name and download the model if it is not available locally
-        model_dir = _get_bert_dir(model_dir_or_name)
-
         self._word_sep_index = None
         if '[SEP]' in vocab:
             self._word_sep_index = vocab['[SEP]']
 
-        self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
+        self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
                                     pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2)
 
@@ -170,11 +167,8 @@ class BertWordPieceEncoder(nn.Module):
     def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
                  word_dropout=0, dropout=0, requires_grad: bool = False):
         super().__init__()
 
-        # Check model_dir_or_name and download the model if it is not available locally
-        model_dir = _get_bert_dir(model_dir_or_name)
-
-        self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls)
+        self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
         self._sep_index = self.model._sep_index
         self._wordpiece_pad_index = self.model._wordpiece_pad_index
         self._wordpiece_unk_index = self.model._wordpiece_unknown_index
@@ -269,12 +263,12 @@ class BertWordPieceEncoder(nn.Module):
 
 
 class _WordBertModel(nn.Module):
-    def __init__(self, model_dir: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
+    def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
                  include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
         super().__init__()
 
-        self.tokenzier = BertTokenizer.from_pretrained(model_dir)
-        self.encoder = BertModel.from_pretrained(model_dir)
+        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
+        self.encoder = BertModel.from_pretrained(model_dir_or_name)
         self._max_position_embeddings = self.encoder.config.max_position_embeddings
         # Check that encoder_layer_number is reasonable
         encoder_layer_number = len(self.encoder.encoder.layer)
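
For context, a minimal usage sketch of the two entry points touched by this
refactor: after the change, directory lookup and download happen inside
_WordBertModel/_WordPieceBertModel rather than in the callers. Import paths
assume fastNLP's public API (fastNLP.Vocabulary, fastNLP.embeddings); the
vocabulary contents and model name are illustrative, not taken from this diff.

    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding, BertWordPieceEncoder

    # Illustrative vocabulary; any Vocabulary built from a corpus works here.
    vocab = Vocabulary()
    vocab.add_word_lst("this is a demo sentence".split())

    # BertEmbedding now forwards model_dir_or_name straight to _WordBertModel,
    # which resolves the name to a local directory (downloading weights if
    # needed) instead of the embedding resolving it via _get_bert_dir first.
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1')

    # BertWordPieceEncoder gets the same treatment: the raw name is passed to
    # _WordPieceBertModel, which handles the lookup internally.
    encoder = BertWordPieceEncoder(model_dir_or_name='en-base-uncased', layers='-1')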