From a137038eb2cc840581adacdcfb76e685a2eed63b Mon Sep 17 00:00:00 2001
From: yh
Date: Wed, 19 Jun 2019 19:43:53 +0800
Subject: [PATCH] Fix the issue that ELMO and LSTM cannot be used with nn.DataParallel
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/io/file_utils.py                               | 4 ++--
 fastNLP/modules/encoder/_bert.py                       | 2 +-
 fastNLP/modules/encoder/_elmo.py                       | 6 +++++-
 fastNLP/modules/encoder/lstm.py                        | 9 ++++++++-
 .../seqence_labelling/ner/data/Conll2003Loader.py      | 4 +---
 .../seqence_labelling/ner/data/OntoNoteLoader.py       | 3 +--
 6 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py
index 11c7ab64..d178626b 100644
--- a/fastNLP/io/file_utils.py
+++ b/fastNLP/io/file_utils.py
@@ -13,7 +13,7 @@ import hashlib
 def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path:
     """
     Given a url or a filename (either a specific file name or a file), first check whether the file already exists under cache_dir; if it does not, download it and
-    put the file into
+    put the file into cache_dir
     """
     if cache_dir is None:
         dataset_cache = Path(get_defalt_path())
@@ -88,7 +88,7 @@ def split_filename_suffix(filepath):
 def get_from_cache(url: str, cache_dir: Path = None) -> Path:
     """
     Try to find the resource defined by url in cache_dir; if it is not found, download it from url and put the result under cache_dir. The cache name is inferred from the url.
-    If the resource downloaded from url contains multiple files after extraction, the directory path is returned; if there is only one resource, its concrete path is returned
+    If the resource downloaded from url contains multiple files after extraction, the directory path is returned; if there is only one resource, its concrete path is returned.
     """
     cache_dir.mkdir(parents=True, exist_ok=True)
 
diff --git a/fastNLP/modules/encoder/_bert.py b/fastNLP/modules/encoder/_bert.py
index 317b78d8..a860054d 100644
--- a/fastNLP/modules/encoder/_bert.py
+++ b/fastNLP/modules/encoder/_bert.py
@@ -791,7 +791,7 @@ class _WordBertModel(nn.Module):
         # +2 because [CLS] and [SEP] need to be added
         word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index)
         word_pieces[:, 0].fill_(self._cls_index)
-        word_pieces[:, word_pieces_lengths+1] = self._sep_index
+        word_pieces[torch.arange(batch_size).to(words), word_pieces_lengths+1] = self._sep_index
         attn_masks = torch.zeros_like(word_pieces)
         # 1. get the word_piece ids of words and the corresponding span ranges
         word_indexes = words.tolist()
diff --git a/fastNLP/modules/encoder/_elmo.py b/fastNLP/modules/encoder/_elmo.py
index 1f400f1d..7fa29201 100644
--- a/fastNLP/modules/encoder/_elmo.py
+++ b/fastNLP/modules/encoder/_elmo.py
@@ -16,6 +16,7 @@ import json
 from ..utils import get_dropout_mask
 import codecs
+from torch import autograd
 
 
 class LstmCellWithProjection(torch.nn.Module):
     """
@@ -760,7 +761,10 @@ class _ElmoModel(nn.Module):
             token_embedding = self.token_embedder(expanded_words, chars)
             if self.config['encoder']['name'] == 'elmo':
                 encoder_output = self.encoder(token_embedding, seq_len)
-                sz = encoder_output.size()
+                if encoder_output.size(2) < max_len:
+                    dummy_tensor = autograd.Variable(torch.zeros(batch_size, max_len - encoder_output.size(2), encoder_output.size(-1)))
+                    encoder_output = torch.cat([encoder_output, dummy_tensor], 1)
+                sz = encoder_output.size()  # batch_size, max_len, hidden_size
                 token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3])
                 encoder_output = torch.cat([token_embedding, encoder_output], dim=0)
             elif self.config['encoder']['name'] == 'lstm':
diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py
index 537a446d..0118d6d7 100644
--- a/fastNLP/modules/encoder/lstm.py
+++ b/fastNLP/modules/encoder/lstm.py
@@ -11,13 +11,15 @@ import torch.nn as nn
 import torch.nn.utils.rnn as rnn
 
 from ..utils import initial_parameter
+from torch import autograd
 
 
 class LSTM(nn.Module):
     """
     Alias: :class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM`
 
-    LSTM module, a lightweight wrapper around the PyTorch LSTM
+    LSTM module, a lightweight wrapper around the PyTorch LSTM. When seq_len is provided, pack_padded_sequence is used automatically;
+    the forget-gate bias is initialized to 1 by default; and the module handles the problem of using LSTM inside DataParallel
 
     :param input_size:  feature dimension of the input `x`
     :param hidden_size: feature dimension of the hidden state `h`.
@@ -59,6 +61,7 @@ class LSTM(nn.Module):
         :return (output, ht) or output: if ``get_hidden=True``, the output sequence [batch, seq_len, hidden_size*num_direction]
             and the hidden state at the last time step [batch, hidden_size*num_direction].
         """
+        batch_size, max_len, _ = x.size()
         if h0 is not None and c0 is not None:
             hx = (h0, c0)
         else:
@@ -77,6 +80,10 @@ class LSTM(nn.Module):
                 output = output[unsort_idx]
             else:
                 output = output[:, unsort_idx]
+            # fix for LSTM not being usable under DataParallel, see https://github.com/pytorch/pytorch/issues/1591
+            if output.size(1) < max_len:
+                dummy_tensor = autograd.Variable(torch.zeros(batch_size, max_len - output.size(1), output.size(-1)))
+                output = torch.cat([output, dummy_tensor], 1)
         else:
             output, hx = self.lstm(x, hx)
         return output, hx
diff --git a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py b/reproduction/seqence_labelling/ner/data/Conll2003Loader.py
index 037d6081..3140af18 100644
--- a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py
+++ b/reproduction/seqence_labelling/ner/data/Conll2003Loader.py
@@ -16,7 +16,7 @@ class Conll2003DataLoader(DataSetLoader):
         Loads English corpora in the Conll2003 format; information about this dataset can be found at https://www.clips.uantwerpen.be/conll2003/ner/. When task is pos,
         the target of the returned DataSet comes from the 2nd column; when task is chunk, from the 3rd column; when task is ner,
         from the 4th column. All "-DOCSTART- -X- O O" lines are ignored, which makes the amount of data smaller than the values reported in much of the literature, but
-        since "-DOCSTART- -X- O O" is only a document-separation marker and should not be a prediction target, we ignore that value in the data
+        since "-DOCSTART- -X- O O" is only a document-separation marker and should not be a prediction target, we ignore the lines starting with -DOCSTART- in the data
         For the ner and chunk tasks the target of the loaded data is of encoding_type; for the pos task it is simply the pos column.
 
     :param task: the labelling task. One of ner, pos, chunk
@@ -64,8 +64,6 @@ class Conll2003DataLoader(DataSetLoader):
 
         # construct vocab
         word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
-        # word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT)
-        # TODO this does not feel quite proper
         word_vocab.from_dataset(*data.datasets.values(), field_name=Const.INPUT)
         word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
         data.vocabs[Const.INPUT] = word_vocab
diff --git a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py b/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py
index 5abfe7c5..fe0236ad 100644
--- a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py
+++ b/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py
@@ -87,14 +87,13 @@ class OntoNoteNERDataLoader(DataSetLoader):
 
         # construct vocab
         word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
-        # word_vocab.from_dataset(data.datasets['train'], field_name='raw_words')
         word_vocab.from_dataset(*data.datasets.values(), field_name=Const.INPUT)
         word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name=Const.INPUT)
        data.vocabs[Const.INPUT] = word_vocab
 
         # cap words
         cap_word_vocab = Vocabulary()
-        cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words')
+        cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words')
         cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
         input_fields.append('cap_words')
         data.vocabs['cap_words'] = cap_word_vocab