From c957ed69c17c2d84252a7c6223f3a564143988c2 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Tue, 7 Jul 2020 22:53:05 +0800
Subject: [PATCH] loader: sort filenames before loading to guarantee a consistent load order and avoid test failures
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/trainer.py                      | 4 ----
 fastNLP/io/utils.py                          | 1 +
 fastNLP/models/seq2seq_generator.py          | 2 ++
 fastNLP/modules/decoder/seq2seq_decoder.py   | 2 +-
 fastNLP/modules/encoder/seq2seq_encoder.py   | 1 +
 test/io/loader/test_classification_loader.py | 2 +-
 test/io/loader/test_matching_loader.py       | 4 ++--
 test/io/pipe/test_matching.py                | 4 ++--
 8 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 5fa89e24..c9f16b9a 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -867,10 +867,6 @@ def _get_value_info(_dict):
     return strs
 
 
-from numbers import Number
-from .batch import _to_tensor
-
-
 def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None, metric_key=None, check_level=0):
     # check get_loss 方法
diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py
index 4a40b83a..c5dc7fd7 100644
--- a/fastNLP/io/utils.py
+++ b/fastNLP/io/utils.py
@@ -36,6 +36,7 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
         return {'train': paths}
     elif os.path.isdir(paths):
         filenames = os.listdir(paths)
+        filenames.sort()
         files = {}
         for filename in filenames:
             path_pair = None
diff --git a/fastNLP/models/seq2seq_generator.py b/fastNLP/models/seq2seq_generator.py
index aa270b5f..81eb344d 100644
--- a/fastNLP/models/seq2seq_generator.py
+++ b/fastNLP/models/seq2seq_generator.py
@@ -1,3 +1,5 @@
+r"""undocumented"""
+
 import torch
 from torch import nn
 from .seq2seq_model import Seq2SeqModel
diff --git a/fastNLP/modules/decoder/seq2seq_decoder.py b/fastNLP/modules/decoder/seq2seq_decoder.py
index 987679b3..41f255b6 100644
--- a/fastNLP/modules/decoder/seq2seq_decoder.py
+++ b/fastNLP/modules/decoder/seq2seq_decoder.py
@@ -1,4 +1,4 @@
-
+r"""undocumented"""
 from typing import Union, Tuple
 import math
 
diff --git a/fastNLP/modules/encoder/seq2seq_encoder.py b/fastNLP/modules/encoder/seq2seq_encoder.py
index c38fa896..b35ca2d5 100644
--- a/fastNLP/modules/encoder/seq2seq_encoder.py
+++ b/fastNLP/modules/encoder/seq2seq_encoder.py
@@ -1,3 +1,4 @@
+r"""undocumented"""
 import torch.nn as nn
 import torch
 from torch.nn import LayerNorm
diff --git a/test/io/loader/test_classification_loader.py b/test/io/loader/test_classification_loader.py
index 72db136c..6ed8eb15 100644
--- a/test/io/loader/test_classification_loader.py
+++ b/test/io/loader/test_classification_loader.py
@@ -30,7 +30,7 @@ class TestLoad(unittest.TestCase):
             'imdb': ('test/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False),
             'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False),
             'THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False),
-            'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 6, 7), False),
+            'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 7, 6), False),
         }
         for k, v in data_set_dict.items():
             path, loader, data_set, warns = v
diff --git a/test/io/loader/test_matching_loader.py b/test/io/loader/test_matching_loader.py
index d2b221c5..30ace410 100644
--- a/test/io/loader/test_matching_loader.py
+++ b/test/io/loader/test_matching_loader.py
@@ -31,8 +31,8 @@ class TestMatchingLoad(unittest.TestCase):
             'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True),
             'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False),
             'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False),
-            'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 8, 6), False),
-            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 6, 5), False),
+            'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 6, 8), False),
+            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 5, 6), False),
         }
         for k, v in data_set_dict.items():
             path, loader, instance, warns = v
diff --git a/test/io/pipe/test_matching.py b/test/io/pipe/test_matching.py
index ea687b2e..92993690 100644
--- a/test/io/pipe/test_matching.py
+++ b/test/io/pipe/test_matching.py
@@ -38,8 +38,8 @@ class TestRunMatchingPipe(unittest.TestCase):
             'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
             'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
             'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
-            'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (39, 3), False),
-            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 6, 5), (36, 2), False),
+            'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False),
+            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False),
         }
         for k, v in data_set_dict.items():
             path, pipe1, pipe2, data_set, vocab, warns = v
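
Note on the change above: `os.listdir` makes no guarantee about the order of the entries it returns, so which file `check_loader_paths` associated with each split could differ across filesystems; sorting the listing makes the mapping deterministic, which is also why the expected (train, dev, test) counts in the tests were reordered. Below is a minimal, self-contained sketch of the same idea; the helper name `find_split_files` is hypothetical and not part of the fastNLP API.

```python
import os
from typing import Dict


def find_split_files(data_dir: str) -> Dict[str, str]:
    """Map split names to file paths in a deterministic order.

    Hypothetical helper illustrating the fix: sort the directory listing
    before matching split names, because ``os.listdir`` returns entries in
    an arbitrary, filesystem-dependent order.
    """
    filenames = sorted(os.listdir(data_dir))  # deterministic across machines
    files: Dict[str, str] = {}
    for filename in filenames:
        for split in ('train', 'dev', 'test'):
            # With a stable iteration order, the first matching filename
            # for each split is the same on every run.
            if split in filename and split not in files:
                files[split] = os.path.join(data_dir, filename)
    return files
```

With the sorted listing, repeated runs resolve the same files to the same splits on every machine, so test expectations such as the per-split instance counts above stay stable.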