From dbf1e492fd2eb27e10df046224859ddc0735c291 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 2 Sep 2018 17:40:43 +0800 Subject: [PATCH 01/16] add config saver --- fastNLP/saver/config_saver.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 fastNLP/saver/config_saver.py diff --git a/fastNLP/saver/config_saver.py b/fastNLP/saver/config_saver.py new file mode 100644 index 00000000..434a0de4 --- /dev/null +++ b/fastNLP/saver/config_saver.py @@ -0,0 +1,35 @@ +import os + +import json +import configparser + +from fastNLP.loader.config_loader import ConfigSection, ConfigLoader +from fastNLP.saver.logger import create_logger + +class ConfigSaver(object): + + def __init__(self, file_path): + self.file_path = file_path + + def save_section(self, section_name, section): + cfg = configparser.ConfigParser() + if not os.path.exists(self.file_path): + raise FileNotFoundError("config file {} not found. ".format(self.file_path)) + cfg.read(self.file_path) + if section_name not in cfg: + cfg.add_section(section_name) + gen_sec = cfg[section_name] + for key in section: + if key in gen_sec.keys(): + try: + val = json.load(gen_sec[key]) + except Exception as e: + print("cannot load attribute %s in section %s" + % (key, section_name)) + try: + assert section[key] == val + except Exception as e: + logger = create_logger(__name__, "./config_saver.log") + logger.warning("this is a warning #TODO") + cfg.set(section_name,key, section[key]) + cfg.write(open(self.file_path, 'w')) From edd9dedb5d365dad41fad293b884e0baf62972cf Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 12:11:38 +0800 Subject: [PATCH 02/16] add config saver --- fastNLP/saver/config_saver.py | 118 ++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/fastNLP/saver/config_saver.py b/fastNLP/saver/config_saver.py index 434a0de4..06522465 100644 --- a/fastNLP/saver/config_saver.py +++ b/fastNLP/saver/config_saver.py @@ -10,6 +10,8 @@ class ConfigSaver(object): def __init__(self, file_path): self.file_path = file_path + if not os.path.exists(self.file_path): + raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) def save_section(self, section_name, section): cfg = configparser.ConfigParser() @@ -33,3 +35,119 @@ class ConfigSaver(object): logger.warning("this is a warning #TODO") cfg.set(section_name,key, section[key]) cfg.write(open(self.file_path, 'w')) + + def save_config_file(self, section_name, section): + + def get_section(file_path, sect_name): + sect = ConfigSection() + ConfigLoader("", "").load_config(file_path, {sect_name: sect}) + return sect + + def read_section(file_path): + sect_name = None + + sect_list = {} + sect_key_list = [] + + single_section = {} + single_section_key = [] + + with open(file_path, 'r') as f: + lines = f.readlines() + + for line in lines: + if line.startswith('[') and line.endswith(']\n'): + if sect_name is None: + pass + else: + sect_list[sect_name] = single_section, single_section_key + single_section = {} + single_section_key = [] + sect_key_list.append(sect_name) + sect_name = line[1: -2] + continue + + if line.startswith('#'): + single_section[line] = '#' + single_section_key.append(line) + continue + + if line.startswith('\n'): + single_section_key.append('\n') + continue + + if '=' not in line: + log = create_logger(__name__, './config_saver.log') + log.error("can NOT load config file [%s]" % file_path) + raise RuntimeError("can NOT load config file {}".__format__(file_path)) + + key = line.split('=', maxsplit=1)[0].strip() + value = line.split('=', maxsplit=1)[1].strip() + '\n' + single_section[key] = value + single_section_key.append(key) + + if sect_name is not None: + sect_list[sect_name] = single_section, single_section_key + sect_key_list.append(sect_name) + + return sect_list, sect_key_list + + def write_section(file_path, sect_list, sect_key_list): + with open(file_path, 'w') as f: + for sect_key in sect_key_list: + single_section, single_section_key = sect_list[sect_key] + f.write('[' + sect_key + ']\n') + for key in single_section_key: + if key == '\n': + f.write('\n') + continue + if single_section[key] == '#': + f.write(key) + continue + f.write(key + ' = ' + single_section[key]) + f.write('\n') + + section_file = get_section(self.file_path, section_name) + if len(section_file.__dict__.keys()) == 0:#the section not in file before + with open(self.file_path, 'a') as f: + f.write('[' + section_name + ']\n') + for k in section.__dict__.keys(): + f.write(k + ' = ') + if isinstance(section[k], str): + f.write('\"' + str(section[k]) + '\"\n\n') + else: + f.write(str(section[k]) + '\n\n') + else: + change_file = False + for k in section.__dict__.keys(): + if k not in section_file: + change_file = True + break + if section_file[k] != section[k]: + logger = create_logger(__name__, "./config_loader.log") + logger.warning("section [%s] in config file [%s] has been changed" % ( + section_name, self.file_path + )) + change_file = True + break + if not change_file: + return + + sect_list, sect_key_list = read_section(self.file_path) + if section_name not in sect_key_list: + raise AttributeError() + + sect, sect_key = sect_list[section_name] + for k in section.__dict__.keys(): + if k not in sect_key: + sect_key.append('\n') + sect_key.append(k) + sect[k] = str(section[k]) + if isinstance(section[k], str): + sect[k] = "\"" + sect[k] + "\"" + sect[k] = sect[k] + "\n" + sect_list[section_name] = sect, sect_key + write_section(self.file_path, sect_list, sect_key_list) + + + From 2dd2f0c8f4784c38013a788f4ca1fcca9dc86f54 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 12:12:16 +0800 Subject: [PATCH 03/16] update config file for test --- test/loader/config | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/loader/config b/test/loader/config index a1922d28..675180e4 100644 --- a/test/loader/config +++ b/test/loader/config @@ -1,7 +1,12 @@ [test] x = 1 + y = 2 + z = 3 +#this is an example input = [1,2,3] + text = "this is text" + doubles = 0.5 From 476988573b2177ed195b612693ece3df5eb470f8 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 12:13:42 +0800 Subject: [PATCH 04/16] add test code for testing config saver --- test/saver/test_config_saver.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 test/saver/test_config_saver.py diff --git a/test/saver/test_config_saver.py b/test/saver/test_config_saver.py new file mode 100644 index 00000000..064b66fa --- /dev/null +++ b/test/saver/test_config_saver.py @@ -0,0 +1,21 @@ +import os + +import unittest + +from fastNLP.loader.config_loader import ConfigSection, ConfigLoader +from fastNLP.saver.config_saver import ConfigSaver + + +class TestConfigSaver(unittest.TestCase): + def test_case_1(self): + config_saver = ConfigSaver("./test/loader/config") + #config_saver = ConfigSaver("./../loader/config") + + section = ConfigSection() + section["test"] = 105 + section["tt"] = 0.5 + section["str"] = "this is a str" + config_saver.save_config_file("test", section) + config_saver.save_config_file("another-test", section) + config_saver.save_config_file("one-another-test", section) + From 58ccb6576fd498b61f2999d855bb19c59498bcd9 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 12:30:33 +0800 Subject: [PATCH 05/16] clean up codes --- test/loader/test_loader.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/test/loader/test_loader.py b/test/loader/test_loader.py index fe826a6f..e5ddc3a6 100644 --- a/test/loader/test_loader.py +++ b/test/loader/test_loader.py @@ -34,18 +34,14 @@ class TestConfigLoader(unittest.TestCase): test_arg = ConfigSection() ConfigLoader("config", "").load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) - #ConfigLoader("config", "").load_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config", - # {"test": test_arg}) + section = read_section_from_config(os.path.join("./test/loader", "config"), "test") - #dict = read_section_from_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config", "test") - dict = read_section_from_config(os.path.join("./test/loader", "config"), "test") - - for sec in dict: - if (sec not in test_arg) or (dict[sec] != test_arg[sec]): + for sec in section: + if (sec not in test_arg) or (section[sec] != test_arg[sec]): raise AttributeError("ERROR") for sec in test_arg.__dict__.keys(): - if (sec not in dict) or (dict[sec] != test_arg[sec]): + if (sec not in section) or (section[sec] != test_arg[sec]): raise AttributeError("ERROR") try: From 7c57bc6fc924ee8ed2e88fbc08a4397948cd0da8 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 12:32:39 +0800 Subject: [PATCH 06/16] fix a bug for config saver --- fastNLP/saver/config_saver.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fastNLP/saver/config_saver.py b/fastNLP/saver/config_saver.py index 06522465..bc682733 100644 --- a/fastNLP/saver/config_saver.py +++ b/fastNLP/saver/config_saver.py @@ -13,6 +13,7 @@ class ConfigSaver(object): if not os.path.exists(self.file_path): raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) + """ def save_section(self, section_name, section): cfg = configparser.ConfigParser() if not os.path.exists(self.file_path): @@ -35,12 +36,13 @@ class ConfigSaver(object): logger.warning("this is a warning #TODO") cfg.set(section_name,key, section[key]) cfg.write(open(self.file_path, 'w')) + """ def save_config_file(self, section_name, section): def get_section(file_path, sect_name): sect = ConfigSection() - ConfigLoader("", "").load_config(file_path, {sect_name: sect}) + ConfigLoader(file_path).load_config(file_path, {sect_name: sect}) return sect def read_section(file_path): From 7fb2bcc78c89092845a05c88ae49dffab8c3e42d Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 12:36:56 +0800 Subject: [PATCH 07/16] update config loader --- fastNLP/loader/config_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastNLP/loader/config_loader.py b/fastNLP/loader/config_loader.py index 9e3ebc1c..20d791c4 100644 --- a/fastNLP/loader/config_loader.py +++ b/fastNLP/loader/config_loader.py @@ -9,7 +9,7 @@ class ConfigLoader(BaseLoader): """loader for configuration files""" def __int__(self, data_name, data_path): - super(ConfigLoader, self).__init__(data_name, data_path) + super(ConfigLoader, self).__init__(data_path) self.config = self.parse(super(ConfigLoader, self).load()) @staticmethod @@ -100,7 +100,7 @@ class ConfigSection(object): if __name__ == "__main__": - config = ConfigLoader('configLoader', 'there is no data') + config = ConfigLoader('there is no data') section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} """ From 12f06d09d20c00f081dc6ead3add9392a3318dd7 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 12:39:47 +0800 Subject: [PATCH 08/16] clean up code in test loader --- test/loader/test_loader.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/test/loader/test_loader.py b/test/loader/test_loader.py index e5ddc3a6..85f6232d 100644 --- a/test/loader/test_loader.py +++ b/test/loader/test_loader.py @@ -1,13 +1,12 @@ -import os import configparser - import json +import os import unittest - from fastNLP.loader.config_loader import ConfigSection, ConfigLoader from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, POSDatasetLoader, LMDatasetLoader + class TestConfigLoader(unittest.TestCase): def test_case_ConfigLoader(self): @@ -33,7 +32,7 @@ class TestConfigLoader(unittest.TestCase): return dict test_arg = ConfigSection() - ConfigLoader("config", "").load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) + ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) section = read_section_from_config(os.path.join("./test/loader", "config"), "test") for sec in section: @@ -54,18 +53,18 @@ class TestConfigLoader(unittest.TestCase): class TestDatasetLoader(unittest.TestCase): def test_case_TokenizeDatasetLoader(self): - loader = TokenizeDatasetLoader("cws_pku_utf_8", "./test/data_for_tests/cws_pku_utf_8") + loader = TokenizeDatasetLoader("./test/data_for_tests/cws_pku_utf_8") data = loader.load_pku(max_seq_len=32) print("pass TokenizeDatasetLoader test!") def test_case_POSDatasetLoader(self): - loader = POSDatasetLoader("people", "./test/data_for_tests/people.txt") + loader = POSDatasetLoader("./test/data_for_tests/people.txt") data = loader.load() datas = loader.load_lines() print("pass POSDatasetLoader test!") def test_case_LMDatasetLoader(self): - loader = LMDatasetLoader("cws_pku_utf_8", "./test/data_for_tests/cws_pku_utf_8") + loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8") data = loader.load() datas = loader.load_lines() - print("pass TokenizeDatasetLoader test!") + print("pass TokenizeDatasetLoader test!") \ No newline at end of file From 6a77731d86ba19411e81556ea01b4845283594e4 Mon Sep 17 00:00:00 2001 From: lyhuang Date: Sun, 9 Sep 2018 14:08:05 +0800 Subject: [PATCH 09/16] add tensorboardX --- docs/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 2809876b..b4dd10cc 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,5 @@ numpy>=1.14.2 http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl torchvision>=0.1.8 -sphinx-rtd-theme==0.4.1 \ No newline at end of file +sphinx-rtd-theme==0.4.1 +tensorboardX \ No newline at end of file From a7fa63a0dbb1e6dd552bd23f6fa0a80854410eb6 Mon Sep 17 00:00:00 2001 From: lyhuang Date: Sun, 9 Sep 2018 14:11:47 +0800 Subject: [PATCH 10/16] add tensorboardX --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index b4dd10cc..294a44d0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,4 +2,4 @@ numpy>=1.14.2 http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl torchvision>=0.1.8 sphinx-rtd-theme==0.4.1 -tensorboardX \ No newline at end of file +tensorboardX>=1.4 \ No newline at end of file From 48c1e8700b3736b0e7f55d4d343b5e40c7764590 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 16:26:16 +0800 Subject: [PATCH 11/16] fix code style of config saver --- fastNLP/saver/config_saver.py | 189 ++++++++++++++++------------------ 1 file changed, 89 insertions(+), 100 deletions(-) diff --git a/fastNLP/saver/config_saver.py b/fastNLP/saver/config_saver.py index bc682733..de06f8b0 100644 --- a/fastNLP/saver/config_saver.py +++ b/fastNLP/saver/config_saver.py @@ -13,103 +13,95 @@ class ConfigSaver(object): if not os.path.exists(self.file_path): raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) - """ - def save_section(self, section_name, section): - cfg = configparser.ConfigParser() - if not os.path.exists(self.file_path): - raise FileNotFoundError("config file {} not found. ".format(self.file_path)) - cfg.read(self.file_path) - if section_name not in cfg: - cfg.add_section(section_name) - gen_sec = cfg[section_name] - for key in section: - if key in gen_sec.keys(): - try: - val = json.load(gen_sec[key]) - except Exception as e: - print("cannot load attribute %s in section %s" - % (key, section_name)) - try: - assert section[key] == val - except Exception as e: - logger = create_logger(__name__, "./config_saver.log") - logger.warning("this is a warning #TODO") - cfg.set(section_name,key, section[key]) - cfg.write(open(self.file_path, 'w')) - """ + def _get_section(self, sect_name): + """ + :param sect_name: the name of section what wants to load + :return: the section + """ + sect = ConfigSection() + ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect}) + return sect + + def _read_section(self): + """ + :return: sect_list, sect_key_list + sect_list is a list of ConfigSection() + sect_key_list is a list of names in sect_list + """ + sect_name = None + + sect_list = {} + sect_key_list = [] + + single_section = {} + single_section_key = [] + + with open(self.file_path, 'r') as f: + lines = f.readlines() + + for line in lines: + if line.startswith('[') and line.endswith(']\n'): + if sect_name is None: + pass + else: + sect_list[sect_name] = single_section, single_section_key + single_section = {} + single_section_key = [] + sect_key_list.append(sect_name) + sect_name = line[1: -2] + continue + + if line.startswith('#'): + single_section[line] = '#' + single_section_key.append(line) + continue + + if line.startswith('\n'): + single_section_key.append('\n') + continue + + if '=' not in line: + log = create_logger(__name__, './config_saver.log') + log.error("can NOT load config file [%s]" % self.file_path) + raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) + + key = line.split('=', maxsplit=1)[0].strip() + value = line.split('=', maxsplit=1)[1].strip() + '\n' + single_section[key] = value + single_section_key.append(key) + + if sect_name is not None: + sect_list[sect_name] = single_section, single_section_key + sect_key_list.append(sect_name) + return sect_list, sect_key_list + + def _write_section(self, sect_list, sect_key_list): + """ + :param sect_list: a list of ConfigSection() need to be writen into file + :param sect_key_list: a list of name of sect_list + :return: + """ + with open(self.file_path, 'w') as f: + for sect_key in sect_key_list: + single_section, single_section_key = sect_list[sect_key] + f.write('[' + sect_key + ']\n') + for key in single_section_key: + if key == '\n': + f.write('\n') + continue + if single_section[key] == '#': + f.write(key) + continue + f.write(key + ' = ' + single_section[key]) + f.write('\n') def save_config_file(self, section_name, section): - - def get_section(file_path, sect_name): - sect = ConfigSection() - ConfigLoader(file_path).load_config(file_path, {sect_name: sect}) - return sect - - def read_section(file_path): - sect_name = None - - sect_list = {} - sect_key_list = [] - - single_section = {} - single_section_key = [] - - with open(file_path, 'r') as f: - lines = f.readlines() - - for line in lines: - if line.startswith('[') and line.endswith(']\n'): - if sect_name is None: - pass - else: - sect_list[sect_name] = single_section, single_section_key - single_section = {} - single_section_key = [] - sect_key_list.append(sect_name) - sect_name = line[1: -2] - continue - - if line.startswith('#'): - single_section[line] = '#' - single_section_key.append(line) - continue - - if line.startswith('\n'): - single_section_key.append('\n') - continue - - if '=' not in line: - log = create_logger(__name__, './config_saver.log') - log.error("can NOT load config file [%s]" % file_path) - raise RuntimeError("can NOT load config file {}".__format__(file_path)) - - key = line.split('=', maxsplit=1)[0].strip() - value = line.split('=', maxsplit=1)[1].strip() + '\n' - single_section[key] = value - single_section_key.append(key) - - if sect_name is not None: - sect_list[sect_name] = single_section, single_section_key - sect_key_list.append(sect_name) - - return sect_list, sect_key_list - - def write_section(file_path, sect_list, sect_key_list): - with open(file_path, 'w') as f: - for sect_key in sect_key_list: - single_section, single_section_key = sect_list[sect_key] - f.write('[' + sect_key + ']\n') - for key in single_section_key: - if key == '\n': - f.write('\n') - continue - if single_section[key] == '#': - f.write(key) - continue - f.write(key + ' = ' + single_section[key]) - f.write('\n') - - section_file = get_section(self.file_path, section_name) + """ + :param section_name: the name of section what needs to be changed and saved + :param section: the section with key and value what needs to be changed and saved + :return: + """ + section_file = self._get_section(section_name) if len(section_file.__dict__.keys()) == 0:#the section not in file before with open(self.file_path, 'a') as f: f.write('[' + section_name + ']\n') @@ -135,7 +127,7 @@ class ConfigSaver(object): if not change_file: return - sect_list, sect_key_list = read_section(self.file_path) + sect_list, sect_key_list = self._read_section() if section_name not in sect_key_list: raise AttributeError() @@ -149,7 +141,4 @@ class ConfigSaver(object): sect[k] = "\"" + sect[k] + "\"" sect[k] = sect[k] + "\n" sect_list[section_name] = sect, sect_key - write_section(self.file_path, sect_list, sect_key_list) - - - + self._write_section(sect_list, sect_key_list) From 534bc675210486f509e51aaa2ea8ac4472525860 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 17:53:32 +0800 Subject: [PATCH 12/16] overwrite '==' operator and '!=' operator in ConfigSection class --- fastNLP/loader/config_loader.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/fastNLP/loader/config_loader.py b/fastNLP/loader/config_loader.py index 20d791c4..94871222 100644 --- a/fastNLP/loader/config_loader.py +++ b/fastNLP/loader/config_loader.py @@ -92,8 +92,40 @@ class ConfigSection(object): setattr(self, key, value) def __contains__(self, item): + """ + :param item: The key of item. + :return: True if the key in self.__dict__.keys() else False. + """ return item in self.__dict__.keys() + def __eq__(self, other): + """Overwrite the == operator + + :param other: Another ConfigSection() object which to be compared. + :return: True if value of each key in each ConfigSection() object are equal to the other, else False. + """ + for k in self.__dict__.keys(): + if k not in other.__dict__.keys(): + return False + if getattr(self, k) != getattr(self, k): + return False + + for k in other.__dict__.keys(): + if k not in self.__dict__.keys(): + return False + if getattr(self, k) != getattr(self, k): + return False + + return True + + def __ne__(self, other): + """Overwrite the != operator + + :param other: + :return: + """ + return not self.__eq__(other) + @property def data(self): return self.__dict__ From bbb02d0c1f4b35b4d35f46076d449916e5748638 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 17:54:49 +0800 Subject: [PATCH 13/16] clean up the code in config saver --- fastNLP/saver/config_saver.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/fastNLP/saver/config_saver.py b/fastNLP/saver/config_saver.py index de06f8b0..e05e865e 100644 --- a/fastNLP/saver/config_saver.py +++ b/fastNLP/saver/config_saver.py @@ -1,11 +1,9 @@ import os -import json -import configparser - from fastNLP.loader.config_loader import ConfigSection, ConfigLoader from fastNLP.saver.logger import create_logger + class ConfigSaver(object): def __init__(self, file_path): @@ -14,19 +12,21 @@ class ConfigSaver(object): raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) def _get_section(self, sect_name): - """ - :param sect_name: the name of section what wants to load - :return: the section + """This is the function to get the section with the section name. + + :param sect_name: The name of section what wants to load. + :return: The section. """ sect = ConfigSection() ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect}) return sect def _read_section(self): - """ + """This is the function to read sections from the config file. + :return: sect_list, sect_key_list - sect_list is a list of ConfigSection() - sect_key_list is a list of names in sect_list + sect_list: A list of ConfigSection(). + sect_key_list: A list of names in sect_list. """ sect_name = None @@ -76,9 +76,10 @@ class ConfigSaver(object): return sect_list, sect_key_list def _write_section(self, sect_list, sect_key_list): - """ - :param sect_list: a list of ConfigSection() need to be writen into file - :param sect_key_list: a list of name of sect_list + """This is the function to write config file with section list and name list. + + :param sect_list: A list of ConfigSection() need to be writen into file. + :param sect_key_list: A list of name of sect_list. :return: """ with open(self.file_path, 'w') as f: @@ -96,9 +97,10 @@ class ConfigSaver(object): f.write('\n') def save_config_file(self, section_name, section): - """ - :param section_name: the name of section what needs to be changed and saved - :param section: the section with key and value what needs to be changed and saved + """This is the function to be called to change the config file with a single section and its name. + + :param section_name: The name of section what needs to be changed and saved. + :param section: The section with key and value what needs to be changed and saved. :return: """ section_file = self._get_section(section_name) @@ -134,7 +136,8 @@ class ConfigSaver(object): sect, sect_key = sect_list[section_name] for k in section.__dict__.keys(): if k not in sect_key: - sect_key.append('\n') + if sect_key[-1] != '\n': + sect_key.append('\n') sect_key.append(k) sect[k] = str(section[k]) if isinstance(section[k], str): From 7138ff210f483690b36ef7c1de6eab56d2238a67 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 17:55:30 +0800 Subject: [PATCH 14/16] update config file for testing code, add more sections for testing. --- test/loader/config | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/loader/config b/test/loader/config index 675180e4..b91e750d 100644 --- a/test/loader/config +++ b/test/loader/config @@ -10,3 +10,9 @@ input = [1,2,3] text = "this is text" doubles = 0.5 + +[t] +x = "this is an test section" + +[test-case-2] +double = 0.5 From 6ddf5fcdcd27552198aaccef82566c0911c2f2c2 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 17:55:54 +0800 Subject: [PATCH 15/16] update test code for testing config saver --- test/saver/test_config_saver.py | 67 +++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 3 deletions(-) diff --git a/test/saver/test_config_saver.py b/test/saver/test_config_saver.py index 064b66fa..18447c90 100644 --- a/test/saver/test_config_saver.py +++ b/test/saver/test_config_saver.py @@ -1,6 +1,8 @@ import os import unittest +import configparser +import json from fastNLP.loader.config_loader import ConfigSection, ConfigLoader from fastNLP.saver.config_saver import ConfigSaver @@ -8,14 +10,73 @@ from fastNLP.saver.config_saver import ConfigSaver class TestConfigSaver(unittest.TestCase): def test_case_1(self): - config_saver = ConfigSaver("./test/loader/config") - #config_saver = ConfigSaver("./../loader/config") + config_file_dir = "./../loader/" + config_file_name = "config" + config_file_path = os.path.join(config_file_dir, config_file_name) + + tmp_config_file_path = os.path.join(config_file_dir, "tmp_config") + + with open(config_file_path, "r") as f: + lines = f.readlines() + + standard_section = ConfigSection() + t_section = ConfigSection() + ConfigLoader(config_file_path).load_config(config_file_path, {"test": standard_section, "t": t_section}) + + config_saver = ConfigSaver(config_file_path) section = ConfigSection() - section["test"] = 105 + section["doubles"] = 0.8 section["tt"] = 0.5 + section["test"] = 105 section["str"] = "this is a str" + + test_case_2_section = section + test_case_2_section["double"] = 0.5 + + for k in section.__dict__.keys(): + standard_section[k] = section[k] + config_saver.save_config_file("test", section) config_saver.save_config_file("another-test", section) config_saver.save_config_file("one-another-test", section) + config_saver.save_config_file("test-case-2", section) + + test_section = ConfigSection() + at_section = ConfigSection() + another_test_section = ConfigSection() + one_another_test_section = ConfigSection() + a_test_case_2_section = ConfigSection() + + ConfigLoader(config_file_path).load_config(config_file_path, {"test": test_section, + "another-test": another_test_section, + "t": at_section, + "one-another-test": one_another_test_section, + "test-case-2": a_test_case_2_section}) + + assert test_section == standard_section + assert at_section == t_section + assert another_test_section == section + assert one_another_test_section == section + assert a_test_case_2_section == test_case_2_section + + config_saver.save_config_file("test", section) + + with open(config_file_path, "w") as f: + f.writelines(lines) + + with open(tmp_config_file_path, "w") as f: + f.write('[test]\n') + f.write('this is an fault example\n') + + tmp_config_saver = ConfigSaver(tmp_config_file_path) + try: + tmp_config_saver._read_section() + except Exception as e: + pass + os.remove(tmp_config_file_path) + try: + tmp_config_saver = ConfigSaver("file-NOT-exist") + except Exception as e: + pass From aac7982e93c1d55dc96ee3093c09ca58ecd2f0ce Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 17:57:09 +0800 Subject: [PATCH 16/16] fix a bug in config saver testing code --- test/saver/test_config_saver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/saver/test_config_saver.py b/test/saver/test_config_saver.py index 18447c90..45daf0c6 100644 --- a/test/saver/test_config_saver.py +++ b/test/saver/test_config_saver.py @@ -10,7 +10,7 @@ from fastNLP.saver.config_saver import ConfigSaver class TestConfigSaver(unittest.TestCase): def test_case_1(self): - config_file_dir = "./../loader/" + config_file_dir = "./test/loader/" config_file_name = "config" config_file_path = os.path.join(config_file_dir, config_file_name)