From 4718804e2208e6642d8e2440ddcfa9998296c3e9 Mon Sep 17 00:00:00 2001
From: yh
Date: Fri, 19 Jul 2019 17:33:41 +0800
Subject: [PATCH] 1. Move an internal helper function to a new module; 2. Adjust part of BertWordPieceEncoder's functionality
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/_parallel_utils.py      | 14 ++++++++++++++
 fastNLP/core/tester.py               |  2 +-
 fastNLP/core/trainer.py              |  3 ++-
 fastNLP/core/utils.py                | 12 ------------
 fastNLP/embeddings/bert_embedding.py | 25 +++++++++++++++++--------
 fastNLP/embeddings/char_embedding.py |  4 ++--
 fastNLP/embeddings/embedding.py      |  7 ++++++-
 fastNLP/modules/encoder/bert.py      | 11 ++++++-----
 8 files changed, 48 insertions(+), 30 deletions(-)

diff --git a/fastNLP/core/_parallel_utils.py b/fastNLP/core/_parallel_utils.py
index 4a7757d3..6b24d9f9 100644
--- a/fastNLP/core/_parallel_utils.py
+++ b/fastNLP/core/_parallel_utils.py
@@ -1,6 +1,7 @@
 import threading
 
 import torch
+from torch import nn
 from torch.nn.parallel.parallel_apply import get_a_var
 from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
 
@@ -86,3 +87,16 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
         outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
         return gather(outputs, output_device)
     return wrapper
+
+
+def _model_contains_inner_module(model):
+    """
+
+    :param nn.Module model: the model to check. Returns whether the model wraps an inner model.module, which mostly
+        means it is nn.DataParallel or nn.parallel.DistributedDataParallel; signature matching needs the innermost model's functions.
+    :return: bool
+    """
+    if isinstance(model, nn.Module):
+        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+            return True
+    return False
\ No newline at end of file
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 3d672ccc..067ff30c 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -47,7 +47,7 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _data_parallel_wrapper
-from .utils import _model_contains_inner_module
+from fastNLP.core._parallel_utils import _model_contains_inner_module
 from functools import partial
 
 __all__ = [
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 09e8a437..4ec3d0f4 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -352,7 +352,8 @@ from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
-from .utils import _model_contains_inner_module
+from fastNLP.core._parallel_utils import _model_contains_inner_module
+
 
 class Trainer(object):
     """
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index b849687b..8483f9f2 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -187,18 +187,6 @@ def _save_model(model, model_name, save_dir, only_param=False):
         torch.save(model, model_path)
         model.to(_model_device)
-
-def _model_contains_inner_module(model):
-    """
-
-    :param nn.Module model: the model to check. Returns whether the model wraps an inner model.module, which mostly
-        means it is nn.DataParallel or nn.parallel.DistributedDataParallel; signature matching needs the innermost model's functions.
-    :return: bool
-    """
-    if isinstance(model, nn.Module):
-        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
-            return True
-    return False
 
 def _move_model_to_device(model, device):
     """
     Move the model onto the given device.
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 010b464d..21944570 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -129,14 +129,14 @@ class BertWordPieceEncoder(nn.Module):
     def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1',
                  pooled_cls: bool = False, requires_grad: bool=False):
         super().__init__()
-        PRETRAIN_URL = _get_base_url('bert')
 
         if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR:
+            PRETRAIN_URL = _get_base_url('bert')
             model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
             model_url = PRETRAIN_URL + model_name
             model_dir = cached_path(model_url)
             # check whether the directory exists
-        elif os.path.isdir(model_dir_or_name):
+        elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
             model_dir = model_dir_or_name
         else:
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
@@ -166,16 +166,25 @@ class BertWordPieceEncoder(nn.Module):
     def embed_size(self):
         return self._embed_size
 
-    def index_datasets(self, *datasets, field_name):
+    @property
+    def embedding_dim(self):
+        return self._embed_size
+
+    @property
+    def num_embedding(self):
+        return self.model.encoder.config.vocab_size
+
+    def index_datasets(self, *datasets, field_name, add_cls_sep=True):
         """
-        Use BERT's tokenizer to generate a new word_pieces field for each dataset and set it as input. If the sequence does not
-        start with [CLS] and end with [SEP], they are added, and the pad value of the word_pieces field is set to BERT's pad value.
+        Use BERT's tokenizer to generate a new word_pieces field for each dataset and set it as input; the pad value of
+        the word_pieces field is set to BERT's pad value.
 
-        :param datasets: DataSet objects
-        :param field_name: the field whose content is used to generate word_pieces. Each entry in this field should be a List[str].
+        :param DataSet datasets: DataSet objects
+        :param str field_name: the field whose content is used to generate word_pieces. Each entry in this field should be a List[str].
+        :param bool add_cls_sep: if the sequence does not start with [CLS] and end with [SEP], add them.
         :return:
         """
-        self.model.index_dataset(*datasets, field_name=field_name)
+        self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep)
 
     def forward(self, word_pieces, token_type_ids=None):
         """
diff --git a/fastNLP/embeddings/char_embedding.py b/fastNLP/embeddings/char_embedding.py
index b9e6659e..b670313e 100644
--- a/fastNLP/embeddings/char_embedding.py
+++ b/fastNLP/embeddings/char_embedding.py
@@ -92,7 +92,7 @@ class CNNCharEmbedding(TokenEmbedding):
                                     for i in range(len(kernel_sizes))])
         self._embed_size = embed_size
         self.fc = nn.Linear(sum(filter_nums), embed_size)
-        self.init_param()
+        self.reset_parameters()
 
     def forward(self, words):
         """
@@ -149,7 +149,7 @@ class CNNCharEmbedding(TokenEmbedding):
                 continue
             param.requires_grad = value
 
-    def init_param(self):
+    def reset_parameters(self):
         for name, param in self.named_parameters():
             if 'words_to_chars_embedding' in name or 'word_lengths' in name:  # these must not be reset
                 continue
diff --git a/fastNLP/embeddings/embedding.py b/fastNLP/embeddings/embedding.py
index 111bacd0..a9f228fb 100644
--- a/fastNLP/embeddings/embedding.py
+++ b/fastNLP/embeddings/embedding.py
@@ -41,7 +41,12 @@ class Embedding(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
         if not isinstance(self.embed, TokenEmbedding):
-            self._embed_size = self.embed.weight.size(1)
+            if hasattr(self.embed, 'embed_size'):
+                self._embed_size = self.embed.embed_size
+            elif hasattr(self.embed, 'embedding_dim'):
+                self._embed_size = self.embed.embedding_dim
+            else:
+                self._embed_size = self.embed.weight.size(1)
             if word_dropout>0 and not isinstance(unk_index, int):
                 raise ValueError("When drop word is set, you need to pass in the unk_index.")
         else:
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index 1bd810a8..9a990d9d 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -871,7 +871,7 @@ class _WordPieceBertModel(nn.Module):
         self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed when generating word_pieces
         self.pooled_cls = pooled_cls
 
-    def index_dataset(self, *datasets, field_name):
+    def index_dataset(self, *datasets, field_name, add_cls_sep=True):
         """
         Use BERT's tokenizer to generate a new word_pieces field for each dataset and set it as input. If the sequence does not
         start with [CLS] and end with [SEP], they are added, and the pad value of the word_pieces field is set to BERT's pad value.
@@ -887,10 +887,11 @@ class _WordPieceBertModel(nn.Module):
                 tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word)
                 word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens)
                 word_pieces.extend(word_piece_ids)
-            if word_pieces[0] != self._cls_index:
-                word_pieces.insert(0, self._cls_index)
-            if word_pieces[-1] != self._sep_index:
-                word_pieces.insert(-1, self._sep_index)
+            if add_cls_sep:
+                if word_pieces[0] != self._cls_index:
+                    word_pieces.insert(0, self._cls_index)
+                if word_pieces[-1] != self._sep_index:
+                    word_pieces.insert(-1, self._sep_index)
             return word_pieces
 
         for index, dataset in enumerate(datasets):
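
Note on the new BertWordPieceEncoder surface: index_datasets() gains an add_cls_sep switch, and two read-only properties (embedding_dim, num_embedding) are exposed. The following is a minimal usage sketch, not part of the patch; it assumes BertWordPieceEncoder is exported from fastNLP.embeddings in this tree and that the 'en-base-uncased' weights are already cached or can be downloaded.

    from fastNLP import DataSet
    from fastNLP.embeddings import BertWordPieceEncoder

    # Two toy sentences; the 'words' field holds List[str] entries, as index_datasets expects.
    ds = DataSet({'words': [['this', 'is', 'an', 'example'], ['another', 'short', 'sentence']]})

    encoder = BertWordPieceEncoder(model_dir_or_name='en-base-uncased', layers='-1')

    # add_cls_sep=True (the default added by this patch) adds [CLS]/[SEP] when they are missing;
    # pass False to index the raw word pieces only.
    encoder.index_datasets(ds, field_name='words', add_cls_sep=True)

    print(encoder.embedding_dim)   # hidden size, via the new property
    print(encoder.num_embedding)   # word-piece vocabulary size, via the new property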
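The relocated _model_contains_inner_module helper is what tester.py and trainer.py now import to decide whether signature matching should look at model.module rather than model. A small check of its behaviour, not part of the patch, using only the import path the diff introduces; constructing nn.DataParallel here works on CPU-only machines as well.

    import torch.nn as nn
    from fastNLP.core._parallel_utils import _model_contains_inner_module

    plain = nn.Linear(10, 2)
    wrapped = nn.DataParallel(plain)

    # Only wrappers that hold the real model at .module are reported as containing an inner module.
    assert _model_contains_inner_module(plain) is False
    assert _model_contains_inner_module(wrapped) is True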