diff --git a/docs/requirements.txt b/docs/requirements.txt index 2809876b..294a44d0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,5 @@ numpy>=1.14.2 http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl torchvision>=0.1.8 -sphinx-rtd-theme==0.4.1 \ No newline at end of file +sphinx-rtd-theme==0.4.1 +tensorboardX>=1.4 \ No newline at end of file diff --git a/fastNLP/loader/config_loader.py b/fastNLP/loader/config_loader.py index 20d791c4..94871222 100644 --- a/fastNLP/loader/config_loader.py +++ b/fastNLP/loader/config_loader.py @@ -92,8 +92,40 @@ class ConfigSection(object): setattr(self, key, value) def __contains__(self, item): + """ + :param item: The key to look up. + :return: True if the key is in self.__dict__.keys(), else False. + """ return item in self.__dict__.keys() + def __eq__(self, other): + """Overload the == operator + + :param other: Another ConfigSection() object to compare against. + :return: True if the value of each key in each ConfigSection() object is equal to the other's, else False.
+ """ + for k in self.__dict__.keys(): + if k not in other.__dict__.keys(): + return False + if getattr(self, k) != getattr(other, k): + return False + + for k in other.__dict__.keys(): + if k not in self.__dict__.keys(): + return False + if getattr(self, k) != getattr(other, k): + return False + + return True + + def __ne__(self, other): + """Overwrite the != operator + + :param other: Another ConfigSection() object to compare against. + :return: True if the two objects are not equal according to __eq__, else False. + """ + return not self.__eq__(other) + @property def data(self): return self.__dict__ diff --git a/fastNLP/saver/config_saver.py b/fastNLP/saver/config_saver.py new file mode 100644 index 00000000..e05e865e --- /dev/null +++ b/fastNLP/saver/config_saver.py @@ -0,0 +1,147 @@ +import os + +from fastNLP.loader.config_loader import ConfigSection, ConfigLoader +from fastNLP.saver.logger import create_logger + + +class ConfigSaver(object): + + def __init__(self, file_path): + self.file_path = file_path + if not os.path.exists(self.file_path): + raise FileNotFoundError("file {} NOT found!".format(self.file_path)) + + def _get_section(self, sect_name): + """This is the function to get the section with the section name. + + :param sect_name: The name of the section to load. + :return: The section. + """ + sect = ConfigSection() + ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect}) + return sect + + def _read_section(self): + """This is the function to read sections from the config file. + + :return: sect_list, sect_key_list + sect_list: A dict mapping each section name to a (entries dict, ordered key list) tuple. + sect_key_list: A list of the section names in sect_list, in file order.
+ """ + sect_name = None + + sect_list = {} + sect_key_list = [] + + single_section = {} + single_section_key = [] + + with open(self.file_path, 'r') as f: + lines = f.readlines() + + for line in lines: + if line.startswith('[') and line.endswith(']\n'): + if sect_name is None: + pass + else: + sect_list[sect_name] = single_section, single_section_key + single_section = {} + single_section_key = [] + sect_key_list.append(sect_name) + sect_name = line[1: -2] + continue + + if line.startswith('#'): + single_section[line] = '#' + single_section_key.append(line) + continue + + if line.startswith('\n'): + single_section_key.append('\n') + continue + + if '=' not in line: + log = create_logger(__name__, './config_saver.log') + log.error("can NOT load config file [%s]" % self.file_path) + raise RuntimeError("can NOT load config file {}".format(self.file_path)) + + key = line.split('=', maxsplit=1)[0].strip() + value = line.split('=', maxsplit=1)[1].strip() + '\n' + single_section[key] = value + single_section_key.append(key) + + if sect_name is not None: + sect_list[sect_name] = single_section, single_section_key + sect_key_list.append(sect_name) + return sect_list, sect_key_list + + def _write_section(self, sect_list, sect_key_list): + """This is the function to write config file with section list and name list. + + :param sect_list: A dict mapping each section name to a (entries dict, ordered key list) tuple to be written into the file. + :param sect_key_list: A list of the section names in sect_list, in the order they are written. + :return: + """ + with open(self.file_path, 'w') as f: + for sect_key in sect_key_list: + single_section, single_section_key = sect_list[sect_key] + f.write('[' + sect_key + ']\n') + for key in single_section_key: + if key == '\n': + f.write('\n') + continue + if single_section[key] == '#': + f.write(key) + continue + f.write(key + ' = ' + single_section[key]) + f.write('\n') + + def save_config_file(self, section_name, section): + """This is the function to be called to change the config file with a single section and its name.
+ + :param section_name: The name of the section that needs to be changed and saved. + :param section: The section holding the keys and values that need to be changed and saved. + :return: + """ + section_file = self._get_section(section_name) + if len(section_file.__dict__.keys()) == 0:  # no keys were loaded for this section, so append it to the file + with open(self.file_path, 'a') as f: + f.write('[' + section_name + ']\n') + for k in section.__dict__.keys(): + f.write(k + ' = ') + if isinstance(section[k], str): + f.write('\"' + str(section[k]) + '\"\n\n') + else: + f.write(str(section[k]) + '\n\n') + else: + change_file = False + for k in section.__dict__.keys(): + if k not in section_file: + change_file = True + break + if section_file[k] != section[k]: + logger = create_logger(__name__, "./config_loader.log") + logger.warning("section [%s] in config file [%s] has been changed" % ( + section_name, self.file_path + )) + change_file = True + break + if not change_file: + return + + sect_list, sect_key_list = self._read_section() + if section_name not in sect_key_list: + raise AttributeError() + + sect, sect_key = sect_list[section_name] + for k in section.__dict__.keys(): + if k not in sect_key: + if sect_key[-1] != '\n': + sect_key.append('\n') + sect_key.append(k) + sect[k] = str(section[k]) + if isinstance(section[k], str): + sect[k] = "\"" + sect[k] + "\"" + sect[k] = sect[k] + "\n" + sect_list[section_name] = sect, sect_key + self._write_section(sect_list, sect_key_list) diff --git a/test/loader/config b/test/loader/config index a1922d28..b91e750d 100644 --- a/test/loader/config +++ b/test/loader/config @@ -1,7 +1,18 @@ [test] x = 1 + y = 2 + z = 3 +#this is an example input = [1,2,3] + text = "this is text" + doubles = 0.5 + +[t] +x = "this is an test section" + +[test-case-2] +double = 0.5 diff --git a/test/loader/test_loader.py b/test/loader/test_loader.py index d2f22166..740ff952 100644 --- a/test/loader/test_loader.py +++ b/test/loader/test_loader.py @@ -33,18 +33,16 @@ class
TestConfigLoader(unittest.TestCase): test_arg = ConfigSection() ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) - # ConfigLoader("config").load_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config", - # {"test": test_arg}) - #dict = read_section_from_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config", "test") - dict = read_section_from_config(os.path.join("./test/loader", "config"), "test") + section = read_section_from_config(os.path.join("./test/loader", "config"), "test") - for sec in dict: - if (sec not in test_arg) or (dict[sec] != test_arg[sec]): + + for sec in section: + if (sec not in test_arg) or (section[sec] != test_arg[sec]): raise AttributeError("ERROR") for sec in test_arg.__dict__.keys(): - if (sec not in dict) or (dict[sec] != test_arg[sec]): + if (sec not in section) or (section[sec] != test_arg[sec]): raise AttributeError("ERROR") try: @@ -71,4 +69,4 @@ class TestDatasetLoader(unittest.TestCase): loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8") data = loader.load() datas = loader.load_lines() - print("pass TokenizeDatasetLoader test!") + print("pass TokenizeDatasetLoader test!") \ No newline at end of file diff --git a/test/saver/test_config_saver.py b/test/saver/test_config_saver.py new file mode 100644 index 00000000..45daf0c6 --- /dev/null +++ b/test/saver/test_config_saver.py @@ -0,0 +1,82 @@ +import os + +import unittest +import configparser +import json + +from fastNLP.loader.config_loader import ConfigSection, ConfigLoader +from fastNLP.saver.config_saver import ConfigSaver + + +class TestConfigSaver(unittest.TestCase): + def test_case_1(self): + config_file_dir = "./test/loader/" + config_file_name = "config" + config_file_path = os.path.join(config_file_dir, config_file_name) + + tmp_config_file_path = os.path.join(config_file_dir, "tmp_config") + + with open(config_file_path, "r") as f: + lines = f.readlines() + + standard_section = 
ConfigSection() + t_section = ConfigSection() + ConfigLoader(config_file_path).load_config(config_file_path, {"test": standard_section, "t": t_section}) + + config_saver = ConfigSaver(config_file_path) + + section = ConfigSection() + section["doubles"] = 0.8 + section["tt"] = 0.5 + section["test"] = 105 + section["str"] = "this is a str" + + test_case_2_section = section + test_case_2_section["double"] = 0.5 + + for k in section.__dict__.keys(): + standard_section[k] = section[k] + + config_saver.save_config_file("test", section) + config_saver.save_config_file("another-test", section) + config_saver.save_config_file("one-another-test", section) + config_saver.save_config_file("test-case-2", section) + + test_section = ConfigSection() + at_section = ConfigSection() + another_test_section = ConfigSection() + one_another_test_section = ConfigSection() + a_test_case_2_section = ConfigSection() + + ConfigLoader(config_file_path).load_config(config_file_path, {"test": test_section, + "another-test": another_test_section, + "t": at_section, + "one-another-test": one_another_test_section, + "test-case-2": a_test_case_2_section}) + + assert test_section == standard_section + assert at_section == t_section + assert another_test_section == section + assert one_another_test_section == section + assert a_test_case_2_section == test_case_2_section + + config_saver.save_config_file("test", section) + + with open(config_file_path, "w") as f: + f.writelines(lines) + + with open(tmp_config_file_path, "w") as f: + f.write('[test]\n') + f.write('this is an fault example\n') + + tmp_config_saver = ConfigSaver(tmp_config_file_path) + try: + tmp_config_saver._read_section() + except Exception as e: + pass + os.remove(tmp_config_file_path) + + try: + tmp_config_saver = ConfigSaver("file-NOT-exist") + except Exception as e: + pass