Browse Source

几个小修复:1. 修改了部分文档; 2. 修改了IMDBLoader自动下载的一个bug; 3. 修改ESIM模型内的lstm相关代码以修改torch版本为1.2.0时不停报warning的问题

tags/v0.5.5
Yige Xu 5 years ago
parent
commit
b3c7ead5e5
5 changed files with 5 additions and 4 deletions
  1. +2
    -1
      fastNLP/core/dataset.py
  2. +1
    -1
      fastNLP/io/file_utils.py
  3. +1
    -1
      fastNLP/io/loader/classification.py
  4. +0
    -1
      fastNLP/models/cnn_text_classification.py
  5. +1
    -0
      fastNLP/models/snli.py

+ 2
- 1
fastNLP/core/dataset.py View File

@@ -864,9 +864,10 @@ class DataSet(object):

def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN):
"""
-将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。
+将使用len()直接对field_name中每个元素作用,将其结果作为sequence length, 并放入seq_len这个field。
+
:param field_name: str.
:param new_field_name: str. 新的field_name
:return:
"""
if self.has_field(field_name=field_name):


+ 1
- 1
fastNLP/io/file_utils.py View File

@@ -120,7 +120,7 @@ DATASET_DIR = {
# Summarization, English
"ext-cnndm": "ext-cnndm.zip",

-# Question & answer
+# Question & answer, Chinese
"cmrc2018": "cmrc2018.zip"

}


+ 1
- 1
fastNLP/io/loader/classification.py View File

@@ -225,7 +225,7 @@ class IMDBLoader(Loader):
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
-if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
+if not os.path.exists(os.path.join(data_dir, 'dev.txt')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
try:


+ 0
- 1
fastNLP/models/cnn_text_classification.py View File

@@ -33,7 +33,6 @@ class CNNText(torch.nn.Module):
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int),
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding
:param int num_classes: 一共有多少类
-:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。
:param float dropout: Dropout的大小
"""


+ 1
- 0
fastNLP/models/snli.py View File

@@ -164,6 +164,7 @@ class BiRNN(nn.Module):
if self.dropout_rate > 0:
dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training)
rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes)
+self.rnn.flatten_parameters()
output = self.rnn(rnn_input)[0]
# Unpack everything
output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0]


Loading…
Cancel
Save