|
@@ -0,0 +1,75 @@ |
|
|
|
|
|
import os |
|
|
|
|
|
import configparser |
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
|
|
|
import unittest |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader |
|
|
|
|
|
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, POSDatasetLoader, LMDatasetLoader |
|
|
|
|
|
|
|
|
|
|
|
class TestConfigLoader(unittest.TestCase):
    """Compare what ConfigLoader puts into a ConfigSection with a direct parse of the file."""

    # Resolved against the current working directory, in the same way as the
    # "./data_for_tests/..." paths used by the dataset loader tests in this
    # file, instead of an absolute path that only exists on one machine.
    # NOTE(review): assumes the suite is started from the `test` directory -- confirm.
    CONFIG_PATH = os.path.join("./loader", "config")

    @staticmethod
    def _read_section_from_config(config_path, section_name):
        """Parse one section of an INI-style file, JSON-decoding every value.

        Args:
            config_path: path of the configuration file to read.
            section_name: name of the section to extract.

        Returns:
            A plain dict mapping each option name of the section to its
            JSON-decoded value.

        Raises:
            FileNotFoundError: if ``config_path`` does not exist.
            AttributeError: if the section is missing, or if one of its
                values is not valid JSON.
        """
        if not os.path.exists(config_path):
            raise FileNotFoundError("config file {} NOT found.".format(config_path))

        parser = configparser.ConfigParser()
        parser.read(config_path)
        if section_name not in parser:
            raise AttributeError("config file {} do NOT have section {}".format(
                config_path, section_name
            ))

        section = parser[section_name]
        values = {}
        for option in section:
            raw_value = section[option]
            try:
                values[option] = json.loads(raw_value)
            except ValueError as err:  # json.JSONDecodeError is a ValueError
                raise AttributeError("json can NOT load {} in section {}, config file {}".format(
                    option, section_name, config_path
                )) from err
        return values

    def test_case_ConfigLoader(self):
        """The loaded "test" section must hold exactly the keys and values found in the file."""
        test_arg = ConfigSection()
        ConfigLoader("config", "").load_config(self.CONFIG_PATH, {"test": test_arg})

        expected = self._read_section_from_config(self.CONFIG_PATH, "test")

        # Every option present in the file must have been loaded unchanged...
        for key in expected:
            self.assertIn(key, test_arg)
            self.assertEqual(expected[key], test_arg[key])

        # ...and the section must not carry anything the file does not define.
        for key in test_arg.__dict__:
            self.assertIn(key, expected)
            self.assertEqual(expected[key], test_arg[key])

        # NOTE(review): looking up a missing key is only exercised here, not
        # asserted. Presumably ConfigSection should raise for it -- confirm the
        # contract and tighten this to assertRaises with the concrete type.
        try:
            test_arg["NOT EXIST"]
        except Exception:
            pass

        print("pass config test!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDatasetLoader(unittest.TestCase):
    """Smoke tests: each dataset loader must read its fixture without raising.

    The fixture paths are relative, so the tests have to be started from the
    directory that contains ``data_for_tests``.
    """

    def test_case_TokenizeDatasetLoader(self):
        """Load the PKU word-segmentation fixture through TokenizeDatasetLoader."""
        loader = TokenizeDatasetLoader("cws_pku_utf_8", "./data_for_tests/cws_pku_utf_8")
        # Only the absence of an exception is checked; the result is not inspected.
        loader.load_pku(max_seq_len=32)
        print("pass TokenizeDatasetLoader test!")

    def test_case_POSDatasetLoader(self):
        """Call both load() and load_lines() of POSDatasetLoader on the people.txt fixture."""
        loader = POSDatasetLoader("people", "./data_for_tests/people.txt")
        loader.load()
        loader.load_lines()
        print("pass POSDatasetLoader test!")

    def test_case_LMDatasetLoader(self):
        """Call both load() and load_lines() of LMDatasetLoader on the PKU fixture."""
        loader = LMDatasetLoader("cws_pku_utf_8", "./data_for_tests/cws_pku_utf_8")
        loader.load()
        loader.load_lines()
        # The original printed the TokenizeDatasetLoader message here (copy-paste slip).
        print("pass LMDatasetLoader test!")