@@ -867,10 +867,6 @@ def _get_value_info(_dict): | |||||
return strs | return strs | ||||
from numbers import Number | |||||
from .batch import _to_tensor | |||||
def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE, | def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE, | ||||
dev_data=None, metric_key=None, check_level=0): | dev_data=None, metric_key=None, check_level=0): | ||||
# check get_loss 方法 | # check get_loss 方法 | ||||
@@ -36,6 +36,7 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: | |||||
return {'train': paths} | return {'train': paths} | ||||
elif os.path.isdir(paths): | elif os.path.isdir(paths): | ||||
filenames = os.listdir(paths) | filenames = os.listdir(paths) | ||||
filenames.sort() | |||||
files = {} | files = {} | ||||
for filename in filenames: | for filename in filenames: | ||||
path_pair = None | path_pair = None | ||||
@@ -1,3 +1,5 @@ | |||||
r"""undocumented""" | |||||
import torch | import torch | ||||
from torch import nn | from torch import nn | ||||
from .seq2seq_model import Seq2SeqModel | from .seq2seq_model import Seq2SeqModel | ||||
@@ -1,4 +1,4 @@ | |||||
r"""undocumented""" | |||||
from typing import Union, Tuple | from typing import Union, Tuple | ||||
import math | import math | ||||
@@ -1,3 +1,4 @@ | |||||
r"""undocumented""" | |||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch | import torch | ||||
from torch.nn import LayerNorm | from torch.nn import LayerNorm | ||||
@@ -30,7 +30,7 @@ class TestLoad(unittest.TestCase): | |||||
'imdb': ('test/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False), | 'imdb': ('test/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False), | ||||
'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False), | 'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False), | ||||
'THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False), | 'THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False), | ||||
'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 6, 7), False), | |||||
'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 7, 6), False), | |||||
} | } | ||||
for k, v in data_set_dict.items(): | for k, v in data_set_dict.items(): | ||||
path, loader, data_set, warns = v | path, loader, data_set, warns = v | ||||
@@ -31,8 +31,8 @@ class TestMatchingLoad(unittest.TestCase): | |||||
'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True), | 'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True), | ||||
'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False), | 'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False), | ||||
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False), | 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False), | ||||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 8, 6), False), | |||||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 6, 5), False), | |||||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 6, 8), False), | |||||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 5, 6), False), | |||||
} | } | ||||
for k, v in data_set_dict.items(): | for k, v in data_set_dict.items(): | ||||
path, loader, instance, warns = v | path, loader, instance, warns = v | ||||
@@ -38,8 +38,8 @@ class TestRunMatchingPipe(unittest.TestCase): | |||||
'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), | 'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), | ||||
'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), | 'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), | ||||
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), | 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), | ||||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (39, 3), False), | |||||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 6, 5), (36, 2), False), | |||||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False), | |||||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False), | |||||
} | } | ||||
for k, v in data_set_dict.items(): | for k, v in data_set_dict.items(): | ||||
path, pipe1, pipe2, data_set, vocab, warns = v | path, pipe1, pipe2, data_set, vocab, warns = v | ||||