|
|
@@ -129,14 +129,14 @@ class BertWordPieceEncoder(nn.Module): |
|
|
|
def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1', |
|
|
|
pooled_cls: bool = False, requires_grad: bool=False): |
|
|
|
super().__init__() |
|
|
|
PRETRAIN_URL = _get_base_url('bert') |
|
|
|
|
|
|
|
if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: |
|
|
|
PRETRAIN_URL = _get_base_url('bert') |
|
|
|
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] |
|
|
|
model_url = PRETRAIN_URL + model_name |
|
|
|
model_dir = cached_path(model_url) |
|
|
|
# 检查是否存在 |
|
|
|
elif os.path.isdir(model_dir_or_name): |
|
|
|
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): |
|
|
|
model_dir = model_dir_or_name |
|
|
|
else: |
|
|
|
raise ValueError(f"Cannot recognize {model_dir_or_name}.") |
|
|
@@ -166,16 +166,25 @@ class BertWordPieceEncoder(nn.Module): |
|
|
|
def embed_size(self): |
|
|
|
return self._embed_size |
|
|
|
|
|
|
|
def index_datasets(self, *datasets, field_name): |
|
|
|
@property |
|
|
|
def embedding_dim(self): |
|
|
|
return self._embed_size |
|
|
|
|
|
|
|
@property |
|
|
|
def num_embedding(self): |
|
|
|
return self.model.encoder.config.vocab_size |
|
|
|
|
|
|
|
def index_datasets(self, *datasets, field_name, add_cls_sep=True): |
|
|
|
""" |
|
|
|
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是 |
|
|
|
[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。 |
|
|
|
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了 |
|
|
|
bert的pad value。 |
|
|
|
|
|
|
|
:param datasets: DataSet对象 |
|
|
|
:param field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。 |
|
|
|
:param DataSet datasets: DataSet对象 |
|
|
|
:param str field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。 |
|
|
|
:param bool add_cls_sep: 如果首尾不是[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP]。 |
|
|
|
:return: |
|
|
|
""" |
|
|
|
self.model.index_dataset(*datasets, field_name=field_name) |
|
|
|
self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep) |
|
|
|
|
|
|
|
def forward(self, word_pieces, token_type_ids=None): |
|
|
|
""" |
|
|
|