Browse Source

loader中load先将filename排序保证一致的load顺序, 避免测试不通过

tags/v0.6.0
yh_cc 5 years ago
parent
commit
c957ed69c1
8 changed files with 10 additions and 10 deletions
  1. +0
    -4
      fastNLP/core/trainer.py
  2. +1
    -0
      fastNLP/io/utils.py
  3. +2
    -0
      fastNLP/models/seq2seq_generator.py
  4. +1
    -1
      fastNLP/modules/decoder/seq2seq_decoder.py
  5. +1
    -0
      fastNLP/modules/encoder/seq2seq_encoder.py
  6. +1
    -1
      test/io/loader/test_classification_loader.py
  7. +2
    -2
      test/io/loader/test_matching_loader.py
  8. +2
    -2
      test/io/pipe/test_matching.py

+ 0
- 4
fastNLP/core/trainer.py View File

@@ -867,10 +867,6 @@ def _get_value_info(_dict):
return strs


from numbers import Number
from .batch import _to_tensor


def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None, metric_key=None, check_level=0):
# check get_loss 方法


+ 1
- 0
fastNLP/io/utils.py View File

@@ -36,6 +36,7 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
return {'train': paths}
elif os.path.isdir(paths):
filenames = os.listdir(paths)
filenames.sort()
files = {}
for filename in filenames:
path_pair = None


+ 2
- 0
fastNLP/models/seq2seq_generator.py View File

@@ -1,3 +1,5 @@
r"""undocumented"""

import torch
from torch import nn
from .seq2seq_model import Seq2SeqModel


+ 1
- 1
fastNLP/modules/decoder/seq2seq_decoder.py View File

@@ -1,4 +1,4 @@
r"""undocumented"""
from typing import Union, Tuple
import math



+ 1
- 0
fastNLP/modules/encoder/seq2seq_encoder.py View File

@@ -1,3 +1,4 @@
r"""undocumented"""
import torch.nn as nn
import torch
from torch.nn import LayerNorm


+ 1
- 1
test/io/loader/test_classification_loader.py View File

@@ -30,7 +30,7 @@ class TestLoad(unittest.TestCase):
'imdb': ('test/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False),
'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False),
'THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False),
'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 6, 7), False),
'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 7, 6), False),
}
for k, v in data_set_dict.items():
path, loader, data_set, warns = v


+ 2
- 2
test/io/loader/test_matching_loader.py View File

@@ -31,8 +31,8 @@ class TestMatchingLoad(unittest.TestCase):
'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True),
'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False),
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False),
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 8, 6), False),
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 6, 5), False),
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 6, 8), False),
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 5, 6), False),
}
for k, v in data_set_dict.items():
path, loader, instance, warns = v


+ 2
- 2
test/io/pipe/test_matching.py View File

@@ -38,8 +38,8 @@ class TestRunMatchingPipe(unittest.TestCase):
'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (39, 3), False),
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 6, 5), (36, 2), False),
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False),
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False),
}
for k, v in data_set_dict.items():
path, pipe1, pipe2, data_set, vocab, warns = v


Loading…
Cancel
Save