@@ -92,8 +92,40 @@ class ConfigSection(object): | |||
setattr(self, key, value) | |||
def __contains__(self, item): | |||
""" | |||
:param item: The key of item. | |||
:return: True if the key in self.__dict__.keys() else False. | |||
""" | |||
return item in self.__dict__.keys() | |||
def __eq__(self, other): | |||
"""Overwrite the == operator | |||
:param other: Another ConfigSection() object which to be compared. | |||
:return: True if value of each key in each ConfigSection() object are equal to the other, else False. | |||
""" | |||
for k in self.__dict__.keys(): | |||
if k not in other.__dict__.keys(): | |||
return False | |||
if getattr(self, k) != getattr(self, k): | |||
return False | |||
for k in other.__dict__.keys(): | |||
if k not in self.__dict__.keys(): | |||
return False | |||
if getattr(self, k) != getattr(self, k): | |||
return False | |||
return True | |||
def __ne__(self, other): | |||
"""Overwrite the != operator | |||
:param other: | |||
:return: | |||
""" | |||
return not self.__eq__(other) | |||
@property | |||
def data(self): | |||
return self.__dict__ | |||
@@ -0,0 +1,147 @@ | |||
import os | |||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||
from fastNLP.saver.logger import create_logger | |||
class ConfigSaver(object): | |||
def __init__(self, file_path): | |||
self.file_path = file_path | |||
if not os.path.exists(self.file_path): | |||
raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) | |||
def _get_section(self, sect_name): | |||
"""This is the function to get the section with the section name. | |||
:param sect_name: The name of section what wants to load. | |||
:return: The section. | |||
""" | |||
sect = ConfigSection() | |||
ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect}) | |||
return sect | |||
def _read_section(self): | |||
"""This is the function to read sections from the config file. | |||
:return: sect_list, sect_key_list | |||
sect_list: A list of ConfigSection(). | |||
sect_key_list: A list of names in sect_list. | |||
""" | |||
sect_name = None | |||
sect_list = {} | |||
sect_key_list = [] | |||
single_section = {} | |||
single_section_key = [] | |||
with open(self.file_path, 'r') as f: | |||
lines = f.readlines() | |||
for line in lines: | |||
if line.startswith('[') and line.endswith(']\n'): | |||
if sect_name is None: | |||
pass | |||
else: | |||
sect_list[sect_name] = single_section, single_section_key | |||
single_section = {} | |||
single_section_key = [] | |||
sect_key_list.append(sect_name) | |||
sect_name = line[1: -2] | |||
continue | |||
if line.startswith('#'): | |||
single_section[line] = '#' | |||
single_section_key.append(line) | |||
continue | |||
if line.startswith('\n'): | |||
single_section_key.append('\n') | |||
continue | |||
if '=' not in line: | |||
log = create_logger(__name__, './config_saver.log') | |||
log.error("can NOT load config file [%s]" % self.file_path) | |||
raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | |||
key = line.split('=', maxsplit=1)[0].strip() | |||
value = line.split('=', maxsplit=1)[1].strip() + '\n' | |||
single_section[key] = value | |||
single_section_key.append(key) | |||
if sect_name is not None: | |||
sect_list[sect_name] = single_section, single_section_key | |||
sect_key_list.append(sect_name) | |||
return sect_list, sect_key_list | |||
def _write_section(self, sect_list, sect_key_list): | |||
"""This is the function to write config file with section list and name list. | |||
:param sect_list: A list of ConfigSection() need to be writen into file. | |||
:param sect_key_list: A list of name of sect_list. | |||
:return: | |||
""" | |||
with open(self.file_path, 'w') as f: | |||
for sect_key in sect_key_list: | |||
single_section, single_section_key = sect_list[sect_key] | |||
f.write('[' + sect_key + ']\n') | |||
for key in single_section_key: | |||
if key == '\n': | |||
f.write('\n') | |||
continue | |||
if single_section[key] == '#': | |||
f.write(key) | |||
continue | |||
f.write(key + ' = ' + single_section[key]) | |||
f.write('\n') | |||
def save_config_file(self, section_name, section): | |||
"""This is the function to be called to change the config file with a single section and its name. | |||
:param section_name: The name of section what needs to be changed and saved. | |||
:param section: The section with key and value what needs to be changed and saved. | |||
:return: | |||
""" | |||
section_file = self._get_section(section_name) | |||
if len(section_file.__dict__.keys()) == 0:#the section not in file before | |||
with open(self.file_path, 'a') as f: | |||
f.write('[' + section_name + ']\n') | |||
for k in section.__dict__.keys(): | |||
f.write(k + ' = ') | |||
if isinstance(section[k], str): | |||
f.write('\"' + str(section[k]) + '\"\n\n') | |||
else: | |||
f.write(str(section[k]) + '\n\n') | |||
else: | |||
change_file = False | |||
for k in section.__dict__.keys(): | |||
if k not in section_file: | |||
change_file = True | |||
break | |||
if section_file[k] != section[k]: | |||
logger = create_logger(__name__, "./config_loader.log") | |||
logger.warning("section [%s] in config file [%s] has been changed" % ( | |||
section_name, self.file_path | |||
)) | |||
change_file = True | |||
break | |||
if not change_file: | |||
return | |||
sect_list, sect_key_list = self._read_section() | |||
if section_name not in sect_key_list: | |||
raise AttributeError() | |||
sect, sect_key = sect_list[section_name] | |||
for k in section.__dict__.keys(): | |||
if k not in sect_key: | |||
if sect_key[-1] != '\n': | |||
sect_key.append('\n') | |||
sect_key.append(k) | |||
sect[k] = str(section[k]) | |||
if isinstance(section[k], str): | |||
sect[k] = "\"" + sect[k] + "\"" | |||
sect[k] = sect[k] + "\n" | |||
sect_list[section_name] = sect, sect_key | |||
self._write_section(sect_list, sect_key_list) |
@@ -1,7 +1,18 @@ | |||
[test] | |||
x = 1 | |||
y = 2 | |||
z = 3 | |||
#this is an example | |||
input = [1,2,3] | |||
text = "this is text" | |||
doubles = 0.5 | |||
[t] | |||
x = "this is an test section" | |||
[test-case-2] | |||
double = 0.5 |
@@ -33,18 +33,16 @@ class TestConfigLoader(unittest.TestCase): | |||
test_arg = ConfigSection() | |||
ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) | |||
# ConfigLoader("config").load_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config", | |||
# {"test": test_arg}) | |||
#dict = read_section_from_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config", "test") | |||
dict = read_section_from_config(os.path.join("./test/loader", "config"), "test") | |||
section = read_section_from_config(os.path.join("./test/loader", "config"), "test") | |||
for sec in dict: | |||
if (sec not in test_arg) or (dict[sec] != test_arg[sec]): | |||
for sec in section: | |||
if (sec not in test_arg) or (section[sec] != test_arg[sec]): | |||
raise AttributeError("ERROR") | |||
for sec in test_arg.__dict__.keys(): | |||
if (sec not in dict) or (dict[sec] != test_arg[sec]): | |||
if (sec not in section) or (section[sec] != test_arg[sec]): | |||
raise AttributeError("ERROR") | |||
try: | |||
@@ -71,4 +69,4 @@ class TestDatasetLoader(unittest.TestCase): | |||
loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8") | |||
data = loader.load() | |||
datas = loader.load_lines() | |||
print("pass TokenizeDatasetLoader test!") | |||
print("pass TokenizeDatasetLoader test!") |
@@ -0,0 +1,82 @@ | |||
import os | |||
import unittest | |||
import configparser | |||
import json | |||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||
from fastNLP.saver.config_saver import ConfigSaver | |||
class TestConfigSaver(unittest.TestCase): | |||
def test_case_1(self): | |||
config_file_dir = "./test/loader/" | |||
config_file_name = "config" | |||
config_file_path = os.path.join(config_file_dir, config_file_name) | |||
tmp_config_file_path = os.path.join(config_file_dir, "tmp_config") | |||
with open(config_file_path, "r") as f: | |||
lines = f.readlines() | |||
standard_section = ConfigSection() | |||
t_section = ConfigSection() | |||
ConfigLoader(config_file_path).load_config(config_file_path, {"test": standard_section, "t": t_section}) | |||
config_saver = ConfigSaver(config_file_path) | |||
section = ConfigSection() | |||
section["doubles"] = 0.8 | |||
section["tt"] = 0.5 | |||
section["test"] = 105 | |||
section["str"] = "this is a str" | |||
test_case_2_section = section | |||
test_case_2_section["double"] = 0.5 | |||
for k in section.__dict__.keys(): | |||
standard_section[k] = section[k] | |||
config_saver.save_config_file("test", section) | |||
config_saver.save_config_file("another-test", section) | |||
config_saver.save_config_file("one-another-test", section) | |||
config_saver.save_config_file("test-case-2", section) | |||
test_section = ConfigSection() | |||
at_section = ConfigSection() | |||
another_test_section = ConfigSection() | |||
one_another_test_section = ConfigSection() | |||
a_test_case_2_section = ConfigSection() | |||
ConfigLoader(config_file_path).load_config(config_file_path, {"test": test_section, | |||
"another-test": another_test_section, | |||
"t": at_section, | |||
"one-another-test": one_another_test_section, | |||
"test-case-2": a_test_case_2_section}) | |||
assert test_section == standard_section | |||
assert at_section == t_section | |||
assert another_test_section == section | |||
assert one_another_test_section == section | |||
assert a_test_case_2_section == test_case_2_section | |||
config_saver.save_config_file("test", section) | |||
with open(config_file_path, "w") as f: | |||
f.writelines(lines) | |||
with open(tmp_config_file_path, "w") as f: | |||
f.write('[test]\n') | |||
f.write('this is an fault example\n') | |||
tmp_config_saver = ConfigSaver(tmp_config_file_path) | |||
try: | |||
tmp_config_saver._read_section() | |||
except Exception as e: | |||
pass | |||
os.remove(tmp_config_file_path) | |||
try: | |||
tmp_config_saver = ConfigSaver("file-NOT-exist") | |||
except Exception as e: | |||
pass |