Browse Source

Merge pull request #65 from xuyige/test_code

add config saver
tags/v0.1.0
Coet GitHub 6 years ago
parent
commit
4bcfc5f930
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 278 additions and 8 deletions
  1. +32
    -0
      fastNLP/loader/config_loader.py
  2. +147
    -0
      fastNLP/saver/config_saver.py
  3. +11
    -0
      test/loader/config
  4. +6
    -8
      test/loader/test_loader.py
  5. +82
    -0
      test/saver/test_config_saver.py

+ 32
- 0
fastNLP/loader/config_loader.py View File

@@ -92,8 +92,40 @@ class ConfigSection(object):
setattr(self, key, value) setattr(self, key, value)


def __contains__(self, item): def __contains__(self, item):
"""
:param item: The key of item.
:return: True if the key in self.__dict__.keys() else False.
"""
return item in self.__dict__.keys() return item in self.__dict__.keys()


def __eq__(self, other):
"""Overwrite the == operator

:param other: Another ConfigSection() object which to be compared.
:return: True if value of each key in each ConfigSection() object are equal to the other, else False.
"""
for k in self.__dict__.keys():
if k not in other.__dict__.keys():
return False
if getattr(self, k) != getattr(self, k):
return False

for k in other.__dict__.keys():
if k not in self.__dict__.keys():
return False
if getattr(self, k) != getattr(self, k):
return False

return True

def __ne__(self, other):
"""Overwrite the != operator

:param other:
:return:
"""
return not self.__eq__(other)

@property @property
def data(self): def data(self):
return self.__dict__ return self.__dict__


+ 147
- 0
fastNLP/saver/config_saver.py View File

@@ -0,0 +1,147 @@
import os

from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.saver.logger import create_logger


class ConfigSaver(object):

def __init__(self, file_path):
self.file_path = file_path
if not os.path.exists(self.file_path):
raise FileNotFoundError("file {} NOT found!".__format__(self.file_path))

def _get_section(self, sect_name):
"""This is the function to get the section with the section name.

:param sect_name: The name of section what wants to load.
:return: The section.
"""
sect = ConfigSection()
ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect})
return sect

def _read_section(self):
"""This is the function to read sections from the config file.

:return: sect_list, sect_key_list
sect_list: A list of ConfigSection().
sect_key_list: A list of names in sect_list.
"""
sect_name = None

sect_list = {}
sect_key_list = []

single_section = {}
single_section_key = []

with open(self.file_path, 'r') as f:
lines = f.readlines()

for line in lines:
if line.startswith('[') and line.endswith(']\n'):
if sect_name is None:
pass
else:
sect_list[sect_name] = single_section, single_section_key
single_section = {}
single_section_key = []
sect_key_list.append(sect_name)
sect_name = line[1: -2]
continue

if line.startswith('#'):
single_section[line] = '#'
single_section_key.append(line)
continue

if line.startswith('\n'):
single_section_key.append('\n')
continue

if '=' not in line:
log = create_logger(__name__, './config_saver.log')
log.error("can NOT load config file [%s]" % self.file_path)
raise RuntimeError("can NOT load config file {}".__format__(self.file_path))

key = line.split('=', maxsplit=1)[0].strip()
value = line.split('=', maxsplit=1)[1].strip() + '\n'
single_section[key] = value
single_section_key.append(key)

if sect_name is not None:
sect_list[sect_name] = single_section, single_section_key
sect_key_list.append(sect_name)
return sect_list, sect_key_list

def _write_section(self, sect_list, sect_key_list):
"""This is the function to write config file with section list and name list.

:param sect_list: A list of ConfigSection() need to be writen into file.
:param sect_key_list: A list of name of sect_list.
:return:
"""
with open(self.file_path, 'w') as f:
for sect_key in sect_key_list:
single_section, single_section_key = sect_list[sect_key]
f.write('[' + sect_key + ']\n')
for key in single_section_key:
if key == '\n':
f.write('\n')
continue
if single_section[key] == '#':
f.write(key)
continue
f.write(key + ' = ' + single_section[key])
f.write('\n')

def save_config_file(self, section_name, section):
"""This is the function to be called to change the config file with a single section and its name.

:param section_name: The name of section what needs to be changed and saved.
:param section: The section with key and value what needs to be changed and saved.
:return:
"""
section_file = self._get_section(section_name)
if len(section_file.__dict__.keys()) == 0:#the section not in file before
with open(self.file_path, 'a') as f:
f.write('[' + section_name + ']\n')
for k in section.__dict__.keys():
f.write(k + ' = ')
if isinstance(section[k], str):
f.write('\"' + str(section[k]) + '\"\n\n')
else:
f.write(str(section[k]) + '\n\n')
else:
change_file = False
for k in section.__dict__.keys():
if k not in section_file:
change_file = True
break
if section_file[k] != section[k]:
logger = create_logger(__name__, "./config_loader.log")
logger.warning("section [%s] in config file [%s] has been changed" % (
section_name, self.file_path
))
change_file = True
break
if not change_file:
return

sect_list, sect_key_list = self._read_section()
if section_name not in sect_key_list:
raise AttributeError()

sect, sect_key = sect_list[section_name]
for k in section.__dict__.keys():
if k not in sect_key:
if sect_key[-1] != '\n':
sect_key.append('\n')
sect_key.append(k)
sect[k] = str(section[k])
if isinstance(section[k], str):
sect[k] = "\"" + sect[k] + "\""
sect[k] = sect[k] + "\n"
sect_list[section_name] = sect, sect_key
self._write_section(sect_list, sect_key_list)

+ 11
- 0
test/loader/config View File

@@ -1,7 +1,18 @@
[test] [test]
x = 1 x = 1

y = 2 y = 2

z = 3 z = 3
#this is an example
input = [1,2,3] input = [1,2,3]

text = "this is text" text = "this is text"

doubles = 0.5 doubles = 0.5

[t]
x = "this is an test section"

[test-case-2]
double = 0.5

+ 6
- 8
test/loader/test_loader.py View File

@@ -33,18 +33,16 @@ class TestConfigLoader(unittest.TestCase):


test_arg = ConfigSection() test_arg = ConfigSection()
ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg})
# ConfigLoader("config").load_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config",
# {"test": test_arg})


#dict = read_section_from_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config", "test")
dict = read_section_from_config(os.path.join("./test/loader", "config"), "test")
section = read_section_from_config(os.path.join("./test/loader", "config"), "test")


for sec in dict:
if (sec not in test_arg) or (dict[sec] != test_arg[sec]):

for sec in section:
if (sec not in test_arg) or (section[sec] != test_arg[sec]):
raise AttributeError("ERROR") raise AttributeError("ERROR")


for sec in test_arg.__dict__.keys(): for sec in test_arg.__dict__.keys():
if (sec not in dict) or (dict[sec] != test_arg[sec]):
if (sec not in section) or (section[sec] != test_arg[sec]):
raise AttributeError("ERROR") raise AttributeError("ERROR")


try: try:
@@ -71,4 +69,4 @@ class TestDatasetLoader(unittest.TestCase):
loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8") loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8")
data = loader.load() data = loader.load()
datas = loader.load_lines() datas = loader.load_lines()
print("pass TokenizeDatasetLoader test!")
print("pass TokenizeDatasetLoader test!")

+ 82
- 0
test/saver/test_config_saver.py View File

@@ -0,0 +1,82 @@
import os

import unittest
import configparser
import json

from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.saver.config_saver import ConfigSaver


class TestConfigSaver(unittest.TestCase):
def test_case_1(self):
config_file_dir = "./test/loader/"
config_file_name = "config"
config_file_path = os.path.join(config_file_dir, config_file_name)

tmp_config_file_path = os.path.join(config_file_dir, "tmp_config")

with open(config_file_path, "r") as f:
lines = f.readlines()

standard_section = ConfigSection()
t_section = ConfigSection()
ConfigLoader(config_file_path).load_config(config_file_path, {"test": standard_section, "t": t_section})

config_saver = ConfigSaver(config_file_path)

section = ConfigSection()
section["doubles"] = 0.8
section["tt"] = 0.5
section["test"] = 105
section["str"] = "this is a str"

test_case_2_section = section
test_case_2_section["double"] = 0.5

for k in section.__dict__.keys():
standard_section[k] = section[k]

config_saver.save_config_file("test", section)
config_saver.save_config_file("another-test", section)
config_saver.save_config_file("one-another-test", section)
config_saver.save_config_file("test-case-2", section)

test_section = ConfigSection()
at_section = ConfigSection()
another_test_section = ConfigSection()
one_another_test_section = ConfigSection()
a_test_case_2_section = ConfigSection()

ConfigLoader(config_file_path).load_config(config_file_path, {"test": test_section,
"another-test": another_test_section,
"t": at_section,
"one-another-test": one_another_test_section,
"test-case-2": a_test_case_2_section})

assert test_section == standard_section
assert at_section == t_section
assert another_test_section == section
assert one_another_test_section == section
assert a_test_case_2_section == test_case_2_section

config_saver.save_config_file("test", section)

with open(config_file_path, "w") as f:
f.writelines(lines)

with open(tmp_config_file_path, "w") as f:
f.write('[test]\n')
f.write('this is an fault example\n')

tmp_config_saver = ConfigSaver(tmp_config_file_path)
try:
tmp_config_saver._read_section()
except Exception as e:
pass
os.remove(tmp_config_file_path)

try:
tmp_config_saver = ConfigSaver("file-NOT-exist")
except Exception as e:
pass

Loading…
Cancel
Save