From 48c1e8700b3736b0e7f55d4d343b5e40c7764590 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sun, 9 Sep 2018 16:26:16 +0800 Subject: [PATCH] fix code style of config saver --- fastNLP/saver/config_saver.py | 189 ++++++++++++++++------------------ 1 file changed, 89 insertions(+), 100 deletions(-) diff --git a/fastNLP/saver/config_saver.py b/fastNLP/saver/config_saver.py index bc682733..de06f8b0 100644 --- a/fastNLP/saver/config_saver.py +++ b/fastNLP/saver/config_saver.py @@ -13,103 +13,95 @@ class ConfigSaver(object): if not os.path.exists(self.file_path): raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) - """ - def save_section(self, section_name, section): - cfg = configparser.ConfigParser() - if not os.path.exists(self.file_path): - raise FileNotFoundError("config file {} not found. ".format(self.file_path)) - cfg.read(self.file_path) - if section_name not in cfg: - cfg.add_section(section_name) - gen_sec = cfg[section_name] - for key in section: - if key in gen_sec.keys(): - try: - val = json.load(gen_sec[key]) - except Exception as e: - print("cannot load attribute %s in section %s" - % (key, section_name)) - try: - assert section[key] == val - except Exception as e: - logger = create_logger(__name__, "./config_saver.log") - logger.warning("this is a warning #TODO") - cfg.set(section_name,key, section[key]) - cfg.write(open(self.file_path, 'w')) - """ + def _get_section(self, sect_name): + """ + :param sect_name: the name of section what wants to load + :return: the section + """ + sect = ConfigSection() + ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect}) + return sect + + def _read_section(self): + """ + :return: sect_list, sect_key_list + sect_list is a list of ConfigSection() + sect_key_list is a list of names in sect_list + """ + sect_name = None + + sect_list = {} + sect_key_list = [] + + single_section = {} + single_section_key = [] + + with open(self.file_path, 'r') as f: + lines = f.readlines() + + for line in lines: + if line.startswith('[') and line.endswith(']\n'): + if sect_name is None: + pass + else: + sect_list[sect_name] = single_section, single_section_key + single_section = {} + single_section_key = [] + sect_key_list.append(sect_name) + sect_name = line[1: -2] + continue + + if line.startswith('#'): + single_section[line] = '#' + single_section_key.append(line) + continue + + if line.startswith('\n'): + single_section_key.append('\n') + continue + + if '=' not in line: + log = create_logger(__name__, './config_saver.log') + log.error("can NOT load config file [%s]" % self.file_path) + raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) + + key = line.split('=', maxsplit=1)[0].strip() + value = line.split('=', maxsplit=1)[1].strip() + '\n' + single_section[key] = value + single_section_key.append(key) + + if sect_name is not None: + sect_list[sect_name] = single_section, single_section_key + sect_key_list.append(sect_name) + return sect_list, sect_key_list + + def _write_section(self, sect_list, sect_key_list): + """ + :param sect_list: a list of ConfigSection() need to be writen into file + :param sect_key_list: a list of name of sect_list + :return: + """ + with open(self.file_path, 'w') as f: + for sect_key in sect_key_list: + single_section, single_section_key = sect_list[sect_key] + f.write('[' + sect_key + ']\n') + for key in single_section_key: + if key == '\n': + f.write('\n') + continue + if single_section[key] == '#': + f.write(key) + continue + f.write(key + ' = ' + single_section[key]) + f.write('\n') def save_config_file(self, section_name, section): - - def get_section(file_path, sect_name): - sect = ConfigSection() - ConfigLoader(file_path).load_config(file_path, {sect_name: sect}) - return sect - - def read_section(file_path): - sect_name = None - - sect_list = {} - sect_key_list = [] - - single_section = {} - single_section_key = [] - - with open(file_path, 'r') as f: - lines = f.readlines() - - for line in lines: - if line.startswith('[') and line.endswith(']\n'): - if sect_name is None: - pass - else: - sect_list[sect_name] = single_section, single_section_key - single_section = {} - single_section_key = [] - sect_key_list.append(sect_name) - sect_name = line[1: -2] - continue - - if line.startswith('#'): - single_section[line] = '#' - single_section_key.append(line) - continue - - if line.startswith('\n'): - single_section_key.append('\n') - continue - - if '=' not in line: - log = create_logger(__name__, './config_saver.log') - log.error("can NOT load config file [%s]" % file_path) - raise RuntimeError("can NOT load config file {}".__format__(file_path)) - - key = line.split('=', maxsplit=1)[0].strip() - value = line.split('=', maxsplit=1)[1].strip() + '\n' - single_section[key] = value - single_section_key.append(key) - - if sect_name is not None: - sect_list[sect_name] = single_section, single_section_key - sect_key_list.append(sect_name) - - return sect_list, sect_key_list - - def write_section(file_path, sect_list, sect_key_list): - with open(file_path, 'w') as f: - for sect_key in sect_key_list: - single_section, single_section_key = sect_list[sect_key] - f.write('[' + sect_key + ']\n') - for key in single_section_key: - if key == '\n': - f.write('\n') - continue - if single_section[key] == '#': - f.write(key) - continue - f.write(key + ' = ' + single_section[key]) - f.write('\n') - - section_file = get_section(self.file_path, section_name) + """ + :param section_name: the name of section what needs to be changed and saved + :param section: the section with key and value what needs to be changed and saved + :return: + """ + section_file = self._get_section(section_name) if len(section_file.__dict__.keys()) == 0:#the section not in file before with open(self.file_path, 'a') as f: f.write('[' + section_name + ']\n') @@ -135,7 +127,7 @@ class ConfigSaver(object): if not change_file: return - sect_list, sect_key_list = read_section(self.file_path) + sect_list, sect_key_list = self._read_section() if section_name not in sect_key_list: raise AttributeError() @@ -149,7 +141,4 @@ class ConfigSaver(object): sect[k] = "\"" + sect[k] + "\"" sect[k] = sect[k] + "\n" sect_list[section_name] = sect, sect_key - write_section(self.file_path, sect_list, sect_key_list) - - - + self._write_section(sect_list, sect_key_list)