diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 5fa89e24..c9f16b9a 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -867,10 +867,6 @@ def _get_value_info(_dict):
     return strs
 
 
-from numbers import Number
-from .batch import _to_tensor
-
-
 def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None, metric_key=None, check_level=0):
     # check get_loss 方法
diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py
index 4a40b83a..c5dc7fd7 100644
--- a/fastNLP/io/utils.py
+++ b/fastNLP/io/utils.py
@@ -36,6 +36,7 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
             return {'train': paths}
         elif os.path.isdir(paths):
             filenames = os.listdir(paths)
+            filenames.sort()
             files = {}
             for filename in filenames:
                 path_pair = None
diff --git a/fastNLP/models/seq2seq_generator.py b/fastNLP/models/seq2seq_generator.py
index aa270b5f..81eb344d 100644
--- a/fastNLP/models/seq2seq_generator.py
+++ b/fastNLP/models/seq2seq_generator.py
@@ -1,3 +1,5 @@
+r"""undocumented"""
+
 import torch
 from torch import nn
 from .seq2seq_model import Seq2SeqModel
diff --git a/fastNLP/modules/decoder/seq2seq_decoder.py b/fastNLP/modules/decoder/seq2seq_decoder.py
index 987679b3..41f255b6 100644
--- a/fastNLP/modules/decoder/seq2seq_decoder.py
+++ b/fastNLP/modules/decoder/seq2seq_decoder.py
@@ -1,4 +1,4 @@
-
+r"""undocumented"""
 from typing import Union, Tuple
 import math
 
diff --git a/fastNLP/modules/encoder/seq2seq_encoder.py b/fastNLP/modules/encoder/seq2seq_encoder.py
index c38fa896..b35ca2d5 100644
--- a/fastNLP/modules/encoder/seq2seq_encoder.py
+++ b/fastNLP/modules/encoder/seq2seq_encoder.py
@@ -1,3 +1,4 @@
+r"""undocumented"""
 import torch.nn as nn
 import torch
 from torch.nn import LayerNorm
diff --git a/test/io/loader/test_classification_loader.py b/test/io/loader/test_classification_loader.py
index 72db136c..6ed8eb15 100644
--- a/test/io/loader/test_classification_loader.py
+++ b/test/io/loader/test_classification_loader.py
@@ -30,7 +30,7 @@ class TestLoad(unittest.TestCase):
             'imdb': ('test/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False),
             'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False),
             'THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False),
-            'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 6, 7), False),
+            'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 7, 6), False),
         }
         for k, v in data_set_dict.items():
             path, loader, data_set, warns = v
diff --git a/test/io/loader/test_matching_loader.py b/test/io/loader/test_matching_loader.py
index d2b221c5..30ace410 100644
--- a/test/io/loader/test_matching_loader.py
+++ b/test/io/loader/test_matching_loader.py
@@ -31,8 +31,8 @@ class TestMatchingLoad(unittest.TestCase):
             'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True),
             'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False),
             'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False),
-            'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 8, 6), False),
-            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 6, 5), False),
+            'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 6, 8), False),
+            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 5, 6), False),
         }
         for k, v in data_set_dict.items():
             path, loader, instance, warns = v
diff --git a/test/io/pipe/test_matching.py b/test/io/pipe/test_matching.py
index ea687b2e..92993690 100644
--- a/test/io/pipe/test_matching.py
+++ b/test/io/pipe/test_matching.py
@@ -38,8 +38,8 @@ class TestRunMatchingPipe(unittest.TestCase):
             'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
             'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
             'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
-            'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (39, 3), False),
-            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 6, 5), (36, 2), False),
+            'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False),
+            'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False),
         }
         for k, v in data_set_dict.items():
             path, pipe1, pipe2, data_set, vocab, warns = v
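
The reordered expected-count tuples in the loader and pipe tests appear to follow from the new `filenames.sort()` call in `check_loader_paths`: once the directory listing is sorted, split files are discovered in alphabetical (dev/test/train) order rather than in whatever order the filesystem returns. A minimal standalone sketch of that behaviour, using a hypothetical `list_split_files` helper rather than fastNLP's actual loader code:

import os
import tempfile

def list_split_files(data_dir):
    # os.listdir returns entries in an arbitrary, filesystem-dependent order;
    # sorting first makes the split discovery order deterministic across machines.
    filenames = os.listdir(data_dir)
    filenames.sort()
    return filenames

with tempfile.TemporaryDirectory() as d:
    for name in ('train.txt', 'test.txt', 'dev.txt'):
        open(os.path.join(d, name), 'w').close()
    print(list_split_files(d))  # always ['dev.txt', 'test.txt', 'train.txt']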