diff --git a/test/loader/config b/test/loader/config
new file mode 100644
index 00000000..a1922d28
--- /dev/null
+++ b/test/loader/config
@@ -0,0 +1,7 @@
+[test]
+x = 1
+y = 2
+z = 3
+input = [1,2,3]
+text = "this is text"
+doubles = 0.5
diff --git a/test/loader/test_loader.py b/test/loader/test_loader.py
new file mode 100644
index 00000000..ba33801e
--- /dev/null
+++ b/test/loader/test_loader.py
@@ -0,0 +1,103 @@
+"""Tests for the fastNLP config loader and dataset loaders."""
+
+import configparser
+import json
+import os
+import unittest
+
+from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
+from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, POSDatasetLoader, LMDatasetLoader
+
+# Resolve the fixture next to this file so the test depends neither on the
+# current working directory nor on one developer's home directory.
+_CONFIG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config")
+
+
+def _read_section_from_config(config_path, section_name):
+    """Read one section of an INI file without going through ConfigLoader.
+
+    Every value in the section is decoded with ``json.loads``.
+
+    Args:
+        config_path: path of the config file to read.
+        section_name: name of the section to extract.
+
+    Returns:
+        A dict mapping each option name of the section to its decoded value.
+
+    Raises:
+        FileNotFoundError: if ``config_path`` does not exist.
+        AttributeError: if the section is missing or a value is not valid JSON.
+    """
+    if not os.path.exists(config_path):
+        raise FileNotFoundError("config file {} NOT found.".format(config_path))
+    cfg = configparser.ConfigParser()
+    cfg.read(config_path)
+    if section_name not in cfg:
+        raise AttributeError("config file {} do NOT have section {}".format(
+            config_path, section_name
+        ))
+    section = cfg[section_name]
+    values = {}
+    for key in section:
+        try:
+            values[key] = json.loads(section[key])
+        except ValueError as err:
+            # json.JSONDecodeError is a ValueError; keep it as the cause.
+            raise AttributeError("json can NOT load {} in section {}, config file {}".format(
+                key, section_name, config_path
+            )) from err
+    return values
+
+
+class TestConfigLoader(unittest.TestCase):
+    def test_case_ConfigLoader(self):
+        """ConfigLoader must fill a ConfigSection with exactly the file's values."""
+        test_arg = ConfigSection()
+        ConfigLoader("config", "").load_config(_CONFIG_PATH, {"test": test_arg})
+
+        expected = _read_section_from_config(_CONFIG_PATH, "test")
+
+        # Everything in the file must have been loaded ...
+        for key in expected:
+            self.assertIn(key, test_arg)
+            self.assertEqual(expected[key], test_arg[key])
+
+        # ... and the section must hold nothing beyond the file's content.
+        for key in test_arg.__dict__:
+            self.assertIn(key, expected)
+            self.assertEqual(expected[key], test_arg[key])
+
+        # Looking up a missing key must not abort the test; which exception
+        # ConfigSection raises for it (if any) is deliberately not pinned here.
+        try:
+            _ = test_arg["NOT EXIST"]
+        except Exception:
+            pass
+
+        print("pass config test!")
+
+
+class TestDatasetLoader(unittest.TestCase):
+    """Smoke tests: each loader only has to run without raising.
+
+    The ``./data_for_tests`` paths are relative to the working directory the
+    tests are launched from.
+    """
+
+    def test_case_TokenizeDatasetLoader(self):
+        loader = TokenizeDatasetLoader("cws_pku_utf_8", "./data_for_tests/cws_pku_utf_8")
+        loader.load_pku(max_seq_len=32)
+        print("pass TokenizeDatasetLoader test!")
+
+    def test_case_POSDatasetLoader(self):
+        loader = POSDatasetLoader("people", "./data_for_tests/people.txt")
+        loader.load()
+        loader.load_lines()
+        print("pass POSDatasetLoader test!")
+
+    def test_case_LMDatasetLoader(self):
+        loader = LMDatasetLoader("cws_pku_utf_8", "./data_for_tests/cws_pku_utf_8")
+        loader.load()
+        loader.load_lines()
+        print("pass LMDatasetLoader test!")