@@ -1,6 +1,7 @@
 import threading
 import torch
+from torch import nn
 from torch.nn.parallel.parallel_apply import get_a_var
 from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
@@ -86,3 +87,16 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
         outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
         return gather(outputs, output_device)
     return wrapper
+
+
+def _model_contains_inner_module(model):
+    """
+    :param nn.Module model: the model to check. Returns whether it wraps an inner model.module;
+        mostly used to detect nn.DataParallel and nn.parallel.DistributedDataParallel, since
+        signature matching must be done against the innermost model's functions.
+    :return: bool
+    """
+    if isinstance(model, nn.Module):
+        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+            return True
+    return False
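As a quick illustration of the helper above (a standalone re-creation for this note, not part of the patch): it returns True only for the two wrapper classes, telling callers to reach for model.module when matching call signatures.

    from torch import nn

    def _model_contains_inner_module(model):
        # True only for wrappers that keep the real network under model.module.
        if isinstance(model, nn.Module):
            if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                return True
        return False

    net = nn.Linear(4, 2)
    assert not _model_contains_inner_module(net)
    wrapped = nn.DataParallel(net)          # constructing the wrapper needs no GPU
    assert _model_contains_inner_module(wrapped)
    inner = wrapped.module                  # the module whose forward signature matters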
@@ -47,7 +47,7 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _data_parallel_wrapper
-from .utils import _model_contains_inner_module
+from fastNLP.core._parallel_utils import _model_contains_inner_module
 from functools import partial

 __all__ = [
@@ -352,7 +352,8 @@ from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
-from .utils import _model_contains_inner_module
+from fastNLP.core._parallel_utils import _model_contains_inner_module
+

 class Trainer(object):
     """
@@ -187,18 +187,6 @@ def _save_model(model, model_name, save_dir, only_param=False):
         torch.save(model, model_path)
         model.to(_model_device)

-
-def _model_contains_inner_module(model):
-    """
-    :param nn.Module model: the model to check. Returns whether it wraps an inner model.module;
-        mostly used to detect nn.DataParallel and nn.parallel.DistributedDataParallel, since
-        signature matching must be done against the innermost model's functions.
-    :return: bool
-    """
-    if isinstance(model, nn.Module):
-        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
-            return True
-    return False

 def _move_model_to_device(model, device):
     """
     Move the model to the given device.
@@ -129,14 +129,14 @@ class BertWordPieceEncoder(nn.Module):
     def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1',
                  pooled_cls: bool = False, requires_grad: bool=False):
         super().__init__()
-        PRETRAIN_URL = _get_base_url('bert')

         if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR:
+            PRETRAIN_URL = _get_base_url('bert')
             model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
             model_url = PRETRAIN_URL + model_name
             model_dir = cached_path(model_url)
             # check whether the directory exists
-        elif os.path.isdir(model_dir_or_name):
+        elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
             model_dir = model_dir_or_name
         else:
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
@@ -166,16 +166,25 @@ class BertWordPieceEncoder(nn.Module):
     def embed_size(self):
         return self._embed_size

-    def index_datasets(self, *datasets, field_name):
+    @property
+    def embedding_dim(self):
+        return self._embed_size
+
+    @property
+    def num_embedding(self):
+        return self.model.encoder.config.vocab_size
+
+    def index_datasets(self, *datasets, field_name, add_cls_sep=True):
         """
-        Use bert's tokenizer to generate a new word_pieces column, add it to the datasets and set it as input.
-        If the sequence does not start with [CLS] and end with [SEP], they are added; the pad value of word_pieces is set to bert's pad value.
+        Use bert's tokenizer to generate a new word_pieces column, add it to the datasets and set it as input;
+        the pad value of the word_pieces column is set to bert's pad value.

-        :param datasets: DataSet objects
-        :param field_name: which column to generate word_pieces from; every entry should be a List[str].
+        :param DataSet datasets: DataSet objects
+        :param str field_name: which column to generate word_pieces from; every entry should be a List[str].
+        :param bool add_cls_sep: add [CLS] and [SEP] at the start and end unless already present.
         :return:
         """
-        self.model.index_dataset(*datasets, field_name=field_name)
+        self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep)

     def forward(self, word_pieces, token_type_ids=None):
         """
@@ -92,7 +92,7 @@ class CNNCharEmbedding(TokenEmbedding):
                                     for i in range(len(kernel_sizes))])
         self._embed_size = embed_size
         self.fc = nn.Linear(sum(filter_nums), embed_size)
-        self.init_param()
+        self.reset_parameters()

     def forward(self, words):
         """
@@ -149,7 +149,7 @@ class CNNCharEmbedding(TokenEmbedding):
                 continue
             param.requires_grad = value

-    def init_param(self):
+    def reset_parameters(self):
         for name, param in self.named_parameters():
             if 'words_to_chars_embedding' in name or 'word_lengths' in name:  # these must not be reset
                 continue
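The rename follows the naming convention of PyTorch's built-in layers, whose reset_parameters() methods can be discovered by duck-typing; a small sketch of that generic pattern:

    import torch.nn as nn

    def reinitialize(module):
        # nn.Linear, nn.Conv1d, etc. all expose reset_parameters(), so duck-typing works.
        for m in module.modules():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()

    reinitialize(nn.Sequential(nn.Linear(8, 8), nn.ReLU()))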
@@ -41,7 +41,12 @@ class Embedding(nn.Module):
         self.dropout = nn.Dropout(dropout)
         if not isinstance(self.embed, TokenEmbedding):
-            self._embed_size = self.embed.weight.size(1)
+            if hasattr(self.embed, 'embed_size'):
+                self._embed_size = self.embed.embed_size
+            elif hasattr(self.embed, 'embedding_dim'):
+                self._embed_size = self.embed.embedding_dim
+            else:
+                self._embed_size = self.embed.weight.size(1)
             if word_dropout>0 and not isinstance(unk_index, int):
                 raise ValueError("When drop word is set, you need to pass in the unk_index.")
             else:
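The new fallback chain resolves the output dimension of an arbitrary embedding object instead of assuming a weight attribute (note the checks must target self.embed, the wrapped object, as fixed above). A minimal sketch of the same logic against a plain nn.Embedding, where the embedding_dim branch fires:

    import torch.nn as nn

    embed = nn.Embedding(100, 50)
    if hasattr(embed, 'embed_size'):          # fastNLP-style embeddings
        dim = embed.embed_size
    elif hasattr(embed, 'embedding_dim'):     # torch.nn.Embedding
        dim = embed.embedding_dim
    else:                                     # last resort: read the weight matrix
        dim = embed.weight.size(1)
    assert dim == 50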
@@ -871,7 +871,7 @@ class _WordPieceBertModel(nn.Module):
         self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed when generating word_pieces
         self.pooled_cls = pooled_cls

-    def index_dataset(self, *datasets, field_name):
+    def index_dataset(self, *datasets, field_name, add_cls_sep=True):
         """
         Use bert's tokenizer to generate a new word_pieces column, add it to the datasets and set it as input.
         If the sequence does not start with [CLS] and end with [SEP], they are added; the pad value of word_pieces is set to bert's pad value.
@@ -887,10 +887,11 @@ class _WordPieceBertModel(nn.Module):
                 tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word)
                 word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens)
                 word_pieces.extend(word_piece_ids)
-            if word_pieces[0] != self._cls_index:
-                word_pieces.insert(0, self._cls_index)
-            if word_pieces[-1] != self._sep_index:
-                word_pieces.insert(-1, self._sep_index)
+            if add_cls_sep:
+                if word_pieces[0] != self._cls_index:
+                    word_pieces.insert(0, self._cls_index)
+                if word_pieces[-1] != self._sep_index:
+                    word_pieces.append(self._sep_index)  # append so [SEP] lands at the end
             return word_pieces

         for index, dataset in enumerate(datasets):
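For reference, a small sketch of how the conditional wrapping behaves on an id sequence; 101 and 102 stand in for the [CLS]/[SEP] ids of a bert-base vocab, and the word-piece ids are made up:

    cls_index, sep_index = 101, 102
    word_pieces = [7592, 2088]                # hypothetical word-piece ids
    add_cls_sep = True
    if add_cls_sep:
        if word_pieces[0] != cls_index:
            word_pieces.insert(0, cls_index)  # prepend [CLS]
        if word_pieces[-1] != sep_index:
            word_pieces.append(sep_index)     # append [SEP]
    assert word_pieces == [101, 7592, 2088, 102]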