@@ -13,7 +13,7 @@ import hashlib | |||||
def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: | def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: | ||||
""" | """ | ||||
给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 | 给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 | ||||
将文件放入到 | |||||
将文件放入到cache_dir中 | |||||
""" | """ | ||||
if cache_dir is None: | if cache_dir is None: | ||||
dataset_cache = Path(get_defalt_path()) | dataset_cache = Path(get_defalt_path()) | ||||
@@ -88,7 +88,7 @@ def split_filename_suffix(filepath): | |||||
def get_from_cache(url: str, cache_dir: Path = None) -> Path: | def get_from_cache(url: str, cache_dir: Path = None) -> Path: | ||||
""" | """ | ||||
尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。 | 尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。 | ||||
如果从url中下载的资源解压后有多个文件,则返回directory的路径; 如果只有一个资源,则返回具体的路径 | |||||
如果从url中下载的资源解压后有多个文件,则返回directory的路径; 如果只有一个资源,则返回具体的路径。 | |||||
""" | """ | ||||
cache_dir.mkdir(parents=True, exist_ok=True) | cache_dir.mkdir(parents=True, exist_ok=True) | ||||
@@ -791,7 +791,7 @@ class _WordBertModel(nn.Module): | |||||
# +2是由于需要加入[CLS]与[SEP] | # +2是由于需要加入[CLS]与[SEP] | ||||
word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index) | word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index) | ||||
word_pieces[:, 0].fill_(self._cls_index) | word_pieces[:, 0].fill_(self._cls_index) | ||||
word_pieces[:, word_pieces_lengths+1] = self._sep_index | |||||
word_pieces[torch.arange(batch_size).to(words), word_pieces_lengths+1] = self._sep_index | |||||
attn_masks = torch.zeros_like(word_pieces) | attn_masks = torch.zeros_like(word_pieces) | ||||
# 1. 获取words的word_pieces的id,以及对应的span范围 | # 1. 获取words的word_pieces的id,以及对应的span范围 | ||||
word_indexes = words.tolist() | word_indexes = words.tolist() | ||||
@@ -16,6 +16,7 @@ import json | |||||
from ..utils import get_dropout_mask | from ..utils import get_dropout_mask | ||||
import codecs | import codecs | ||||
from torch import autograd | |||||
class LstmCellWithProjection(torch.nn.Module): | class LstmCellWithProjection(torch.nn.Module): | ||||
""" | """ | ||||
@@ -760,7 +761,10 @@ class _ElmoModel(nn.Module): | |||||
token_embedding = self.token_embedder(expanded_words, chars) | token_embedding = self.token_embedder(expanded_words, chars) | ||||
if self.config['encoder']['name'] == 'elmo': | if self.config['encoder']['name'] == 'elmo': | ||||
encoder_output = self.encoder(token_embedding, seq_len) | encoder_output = self.encoder(token_embedding, seq_len) | ||||
sz = encoder_output.size() | |||||
if encoder_output.size(2) < max_len: | |||||
dummy_tensor = autograd.Variable(torch.zeros(batch_size, max_len - encoder_output.size(2), encoder_output.size(-1))) | |||||
encoder_output = torch.cat([encoder_output, dummy_tensor], 1) | |||||
sz = encoder_output.size() # batch_size, max_len, hidden_size | |||||
token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3]) | token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3]) | ||||
encoder_output = torch.cat([token_embedding, encoder_output], dim=0) | encoder_output = torch.cat([token_embedding, encoder_output], dim=0) | ||||
elif self.config['encoder']['name'] == 'lstm': | elif self.config['encoder']['name'] == 'lstm': | ||||
@@ -11,13 +11,15 @@ import torch.nn as nn | |||||
import torch.nn.utils.rnn as rnn | import torch.nn.utils.rnn as rnn | ||||
from ..utils import initial_parameter | from ..utils import initial_parameter | ||||
from torch import autograd | |||||
class LSTM(nn.Module): | class LSTM(nn.Module): | ||||
""" | """ | ||||
别名::class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM` | 别名::class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM` | ||||
LSTM 模块, 轻量封装的Pytorch LSTM | |||||
LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | |||||
为1; 且可以应对DataParallel中LSTM的使用问题 | |||||
:param input_size: 输入 `x` 的特征维度 | :param input_size: 输入 `x` 的特征维度 | ||||
:param hidden_size: 隐状态 `h` 的特征维度. | :param hidden_size: 隐状态 `h` 的特征维度. | ||||
@@ -59,6 +61,7 @@ class LSTM(nn.Module): | |||||
:return (output, ht) 或 output: 若 ``get_hidden=True`` [batch, seq_len, hidden_size*num_direction] 输出序列 | :return (output, ht) 或 output: 若 ``get_hidden=True`` [batch, seq_len, hidden_size*num_direction] 输出序列 | ||||
和 [batch, hidden_size*num_direction] 最后时刻隐状态. | 和 [batch, hidden_size*num_direction] 最后时刻隐状态. | ||||
""" | """ | ||||
batch_size, max_len, _ = x.size() | |||||
if h0 is not None and c0 is not None: | if h0 is not None and c0 is not None: | ||||
hx = (h0, c0) | hx = (h0, c0) | ||||
else: | else: | ||||
@@ -77,6 +80,10 @@ class LSTM(nn.Module): | |||||
output = output[unsort_idx] | output = output[unsort_idx] | ||||
else: | else: | ||||
output = output[:, unsort_idx] | output = output[:, unsort_idx] | ||||
# 解决LSTM无法在DataParallel下使用的问题问题https://github.com/pytorch/pytorch/issues/1591 | |||||
if output.size(1) < max_len: | |||||
dummy_tensor = autograd.Variable(torch.zeros(batch_size, max_len - output.size(1), output.size(-1))) | |||||
output = torch.cat([output, dummy_tensor], 1) | |||||
else: | else: | ||||
output, hx = self.lstm(x, hx) | output, hx = self.lstm(x, hx) | ||||
return output, hx | return output, hx |
@@ -16,7 +16,7 @@ class Conll2003DataLoader(DataSetLoader): | |||||
加载Conll2003格式的英语语料,该数据集的信息可以在https://www.clips.uantwerpen.be/conll2003/ner/找到。当task为pos | 加载Conll2003格式的英语语料,该数据集的信息可以在https://www.clips.uantwerpen.be/conll2003/ner/找到。当task为pos | ||||
时,返回的DataSet中target取值于第2列; 当task为chunk时,返回的DataSet中target取值于第3列;当task为ner时,返回 | 时,返回的DataSet中target取值于第2列; 当task为chunk时,返回的DataSet中target取值于第3列;当task为ner时,返回 | ||||
的DataSet中target取值于第4列。所有"-DOCSTART- -X- O O"将被忽略,这会导致数据的数量少于很多文献报道的值,但 | 的DataSet中target取值于第4列。所有"-DOCSTART- -X- O O"将被忽略,这会导致数据的数量少于很多文献报道的值,但 | ||||
鉴于"-DOCSTART- -X- O O"只是用于文档分割的符号,并不应该作为预测对象,所以我们忽略了数据中的中该值 | |||||
鉴于"-DOCSTART- -X- O O"只是用于文档分割的符号,并不应该作为预测对象,所以我们忽略了数据中的-DOCTSTART-开头的行 | |||||
ner与chunk任务读取后的数据的target将为encoding_type类型。pos任务读取后就是pos列的数据。 | ner与chunk任务读取后的数据的target将为encoding_type类型。pos任务读取后就是pos列的数据。 | ||||
:param task: 指定需要标注任务。可选ner, pos, chunk | :param task: 指定需要标注任务。可选ner, pos, chunk | ||||
@@ -64,8 +64,6 @@ class Conll2003DataLoader(DataSetLoader): | |||||
# 对construct vocab | # 对construct vocab | ||||
word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) | word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) | ||||
# word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT) | |||||
# TODO 这样感觉不规范呐 | |||||
word_vocab.from_dataset(*data.datasets.values(), field_name=Const.INPUT) | word_vocab.from_dataset(*data.datasets.values(), field_name=Const.INPUT) | ||||
word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT) | word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT) | ||||
data.vocabs[Const.INPUT] = word_vocab | data.vocabs[Const.INPUT] = word_vocab | ||||
@@ -87,14 +87,13 @@ class OntoNoteNERDataLoader(DataSetLoader): | |||||
# 对construct vocab | # 对construct vocab | ||||
word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) | word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) | ||||
# word_vocab.from_dataset(data.datasets['train'], field_name='raw_words') | |||||
word_vocab.from_dataset(*data.datasets.values(), field_name=Const.INPUT) | word_vocab.from_dataset(*data.datasets.values(), field_name=Const.INPUT) | ||||
word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name=Const.INPUT) | word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name=Const.INPUT) | ||||
data.vocabs[Const.INPUT] = word_vocab | data.vocabs[Const.INPUT] = word_vocab | ||||
# cap words | # cap words | ||||
cap_word_vocab = Vocabulary() | cap_word_vocab = Vocabulary() | ||||
cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words') | |||||
cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words') | |||||
cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words') | cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words') | ||||
input_fields.append('cap_words') | input_fields.append('cap_words') | ||||
data.vocabs['cap_words'] = cap_word_vocab | data.vocabs['cap_words'] = cap_word_vocab | ||||