
1. Move internal helper functions; 2. Modify part of BertWordPieceEncoder's functionality

tags/v0.4.10
yh · 6 years ago · commit 4718804e22
8 changed files with 48 additions and 30 deletions

1. +14 -0   fastNLP/core/_parallel_utils.py
2. +1  -1   fastNLP/core/tester.py
3. +2  -1   fastNLP/core/trainer.py
4. +0  -12  fastNLP/core/utils.py
5. +17 -8   fastNLP/embeddings/bert_embedding.py
6. +2  -2   fastNLP/embeddings/char_embedding.py
7. +6  -1   fastNLP/embeddings/embedding.py
8. +6  -5   fastNLP/modules/encoder/bert.py

fastNLP/core/_parallel_utils.py  +14 -0

@@ -1,6 +1,7 @@
 import threading
 import torch
+from torch import nn
 from torch.nn.parallel.parallel_apply import get_a_var
 
 from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
@@ -86,3 +87,16 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
         outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
         return gather(outputs, output_device)
     return wrapper
+
+
+def _model_contains_inner_module(model):
+    """
+
+    :param nn.Module model: the model to check; tells whether it wraps an inner model.module, i.e. whether it is
+        nn.DataParallel or nn.parallel.DistributedDataParallel. Signature matching must use the innermost model's methods.
+    :return: bool
+    """
+    if isinstance(model, nn.Module):
+        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+            return True
+    return False
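A minimal usage sketch of the new helper (the `network` model below is hypothetical, not part of this commit): unwrap a parallel wrapper so that signature inspection sees the innermost module.

    import torch
    from torch import nn
    from fastNLP.core._parallel_utils import _model_contains_inner_module

    network = nn.DataParallel(nn.Linear(10, 2))  # hypothetical wrapped model
    if _model_contains_inner_module(network):
        network = network.module  # reach the innermost model for signature matching
    print(type(network))  # <class 'torch.nn.modules.linear.Linear'>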

fastNLP/core/tester.py  +1 -1

@@ -47,7 +47,7 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _data_parallel_wrapper
-from .utils import _model_contains_inner_module
+from fastNLP.core._parallel_utils import _model_contains_inner_module
 from functools import partial
 
 __all__ = [


fastNLP/core/trainer.py  +2 -1

@@ -352,7 +352,8 @@ from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
-from .utils import _model_contains_inner_module
+from fastNLP.core._parallel_utils import _model_contains_inner_module
+
 
 class Trainer(object):
     """


fastNLP/core/utils.py  +0 -12

@@ -187,18 +187,6 @@ def _save_model(model, model_name, save_dir, only_param=False):
         torch.save(model, model_path)
         model.to(_model_device)
 
-def _model_contains_inner_module(model):
-    """
-
-    :param nn.Module model: the model to check; tells whether it wraps an inner model.module, i.e. whether it is
-        nn.DataParallel or nn.parallel.DistributedDataParallel. Signature matching must use the innermost model's methods.
-    :return: bool
-    """
-    if isinstance(model, nn.Module):
-        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
-            return True
-    return False
-
 def _move_model_to_device(model, device):
     """
     Move the model to the given device.


fastNLP/embeddings/bert_embedding.py  +17 -8

@@ -129,14 +129,14 @@ class BertWordPieceEncoder(nn.Module):
     def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1',
                  pooled_cls: bool = False, requires_grad: bool=False):
         super().__init__()
-        PRETRAIN_URL = _get_base_url('bert')
 
         if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR:
+            PRETRAIN_URL = _get_base_url('bert')
             model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
             model_url = PRETRAIN_URL + model_name
             model_dir = cached_path(model_url)
             # check whether the path exists
-        elif os.path.isdir(model_dir_or_name):
+        elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
             model_dir = model_dir_or_name
         else:
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
@@ -166,16 +166,25 @@ class BertWordPieceEncoder(nn.Module):
     def embed_size(self):
         return self._embed_size
 
-    def index_datasets(self, *datasets, field_name):
+    @property
+    def embedding_dim(self):
+        return self._embed_size
+
+    @property
+    def num_embedding(self):
+        return self.model.encoder.config.vocab_size
+
+    def index_datasets(self, *datasets, field_name, add_cls_sep=True):
         """
-        Use bert's tokenizer to generate a new word_pieces field, add it to the datasets and set it as input. If the first and
-        last tokens are not [CLS] and [SEP], they are added at the ends; the pad value of word_pieces is set to bert's pad value.
+        Use bert's tokenizer to generate a new word_pieces field, add it to the datasets and set it as input; the pad value
+        of the word_pieces field is set to bert's pad value.
 
-        :param datasets: DataSet objects
-        :param field_name: the field used to generate the word_pieces field. Each entry in it should be a List[str].
+        :param DataSet datasets: DataSet objects
+        :param str field_name: the field used to generate the word_pieces field. Each entry in it should be a List[str].
+        :param bool add_cls_sep: if the sequence does not start with [CLS] and end with [SEP], add them at the ends.
         :return:
         """
-        self.model.index_dataset(*datasets, field_name=field_name)
+        self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep)
 
     def forward(self, word_pieces, token_type_ids=None):
         """

fastNLP/embeddings/char_embedding.py  +2 -2

@@ -92,7 +92,7 @@ class CNNCharEmbedding(TokenEmbedding):
                                     for i in range(len(kernel_sizes))])
         self._embed_size = embed_size
         self.fc = nn.Linear(sum(filter_nums), embed_size)
-        self.init_param()
+        self.reset_parameters()
 
     def forward(self, words):
         """
@@ -149,7 +149,7 @@ class CNNCharEmbedding(TokenEmbedding):
                 continue
             param.requires_grad = value
 
-    def init_param(self):
+    def reset_parameters(self):
         for name, param in self.named_parameters():
             if 'words_to_chars_embedding' in name or 'word_lengths' in name:  # these must not be reset
                 continue
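The rename matches PyTorch's usual reset_parameters convention. A hedged sketch of re-initializing the embedding after construction (the vocabulary and constructor arguments are illustrative):

    from fastNLP import Vocabulary
    from fastNLP.embeddings import CNNCharEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("this is an example".split())
    char_embed = CNNCharEmbedding(vocab, embed_size=50)
    # re-initialize all parameters except the char-index buffers that must not be reset
    char_embed.reset_parameters()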


fastNLP/embeddings/embedding.py  +6 -1

@@ -41,7 +41,12 @@ class Embedding(nn.Module):
         self.dropout = nn.Dropout(dropout)
         if not isinstance(self.embed, TokenEmbedding):
-            self._embed_size = self.embed.weight.size(1)
+            if hasattr(self.embed, 'embed_size'):
+                self._embed_size = self.embed.embed_size
+            elif hasattr(self.embed, 'embedding_dim'):
+                self._embed_size = self.embed.embedding_dim
+            else:
+                self._embed_size = self.embed.weight.size(1)
             if word_dropout>0 and not isinstance(unk_index, int):
                 raise ValueError("When drop word is set, you need to pass in the unk_index.")
         else:
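A hedged sketch of what the new fallback chain enables: wrapping a plain nn.Embedding, whose size is exposed as embedding_dim rather than embed_size (the sizes below are made up):

    import torch
    from torch import nn
    from fastNLP.embeddings import Embedding

    embed = Embedding(nn.Embedding(100, 64))  # embedding_dim=64 is picked up by the elif branch
    words = torch.LongTensor([[1, 2, 3]])
    print(embed(words).size())  # torch.Size([1, 3, 64])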


fastNLP/modules/encoder/bert.py  +6 -5

@@ -871,7 +871,7 @@ class _WordPieceBertModel(nn.Module):
         self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed when generating word_pieces
         self.pooled_cls = pooled_cls
 
-    def index_dataset(self, *datasets, field_name):
+    def index_dataset(self, *datasets, field_name, add_cls_sep=True):
         """
         Use bert's tokenizer to generate a new word_pieces field, add it to the datasets and set it as input. If the first
         and last tokens are not [CLS] and [SEP], they are added at the ends; the pad value of word_pieces is set to bert's pad value.
@@ -887,10 +887,11 @@ class _WordPieceBertModel(nn.Module):
                 tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word)
                 word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens)
                 word_pieces.extend(word_piece_ids)
-            if word_pieces[0] != self._cls_index:
-                word_pieces.insert(0, self._cls_index)
-            if word_pieces[-1] != self._sep_index:
-                word_pieces.insert(-1, self._sep_index)
+            if add_cls_sep:
+                if word_pieces[0] != self._cls_index:
+                    word_pieces.insert(0, self._cls_index)
+                if word_pieces[-1] != self._sep_index:
+                    word_pieces.insert(-1, self._sep_index)
             return word_pieces
 
         for index, dataset in enumerate(datasets):

