|
|
@@ -1,5 +1,6 @@ |
|
|
import configparser |
|
|
import configparser |
|
|
import json |
|
|
import json |
|
|
|
|
|
import os |
|
|
|
|
|
|
|
|
from fastNLP.loader.base_loader import BaseLoader |
|
|
from fastNLP.loader.base_loader import BaseLoader |
|
|
|
|
|
|
|
|
@@ -23,6 +24,8 @@ class ConfigLoader(BaseLoader): |
|
|
:return: |
|
|
:return: |
|
|
""" |
|
|
""" |
|
|
cfg = configparser.ConfigParser() |
|
|
cfg = configparser.ConfigParser() |
|
|
|
|
|
if not os.path.exists(file_path): |
|
|
|
|
|
raise FileNotFoundError("config file {} not found. ".format(file_path)) |
|
|
cfg.read(file_path) |
|
|
cfg.read(file_path) |
|
|
for s in sections: |
|
|
for s in sections: |
|
|
attr_list = [i for i in sections[s].__dict__.keys() if |
|
|
attr_list = [i for i in sections[s].__dict__.keys() if |
|
|
@@ -34,7 +37,7 @@ class ConfigLoader(BaseLoader): |
|
|
for attr in gen_sec.keys(): |
|
|
for attr in gen_sec.keys(): |
|
|
try: |
|
|
try: |
|
|
val = json.loads(gen_sec[attr]) |
|
|
val = json.loads(gen_sec[attr]) |
|
|
#print(s, attr, val, type(val)) |
|
|
|
|
|
|
|
|
# print(s, attr, val, type(val)) |
|
|
if attr in attr_list: |
|
|
if attr in attr_list: |
|
|
assert type(val) == type(getattr(sections[s], attr)), \ |
|
|
assert type(val) == type(getattr(sections[s], attr)), \ |
|
|
'type not match, except %s but got %s' % \ |
|
|
'type not match, except %s but got %s' % \ |
|
|
@@ -50,6 +53,7 @@ class ConfigLoader(BaseLoader): |
|
|
% (attr, s)) |
|
|
% (attr, s)) |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConfigSection(object): |
|
|
class ConfigSection(object): |
|
|
|
|
|
|
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
@@ -57,6 +61,8 @@ class ConfigSection(object): |
|
|
|
|
|
|
|
|
def __getitem__(self, key): |
|
|
def __getitem__(self, key): |
|
|
""" |
|
|
""" |
|
|
|
|
|
:param key: str, the name of the attribute |
|
|
|
|
|
:return attr: the value of this attribute |
|
|
if key not in self.__dict__.keys(): |
|
|
if key not in self.__dict__.keys(): |
|
|
return self[key] |
|
|
return self[key] |
|
|
else: |
|
|
else: |
|
|
@@ -68,19 +74,21 @@ class ConfigSection(object): |
|
|
|
|
|
|
|
|
def __setitem__(self, key, value): |
|
|
def __setitem__(self, key, value): |
|
|
""" |
|
|
""" |
|
|
|
|
|
:param key: str, the name of the attribute |
|
|
|
|
|
:param value: the value of this attribute |
|
|
if key not in self.__dict__.keys(): |
|
|
if key not in self.__dict__.keys(): |
|
|
self[key] will be added |
|
|
self[key] will be added |
|
|
else: |
|
|
else: |
|
|
self[key] will be updated |
|
|
self[key] will be updated |
|
|
""" |
|
|
""" |
|
|
if key in self.__dict__.keys(): |
|
|
if key in self.__dict__.keys(): |
|
|
if not type(value) == type(getattr(self, key)): |
|
|
|
|
|
raise AttributeError('attr %s except %s but got %s' % \ |
|
|
|
|
|
|
|
|
if not isinstance(value, type(getattr(self, key))): |
|
|
|
|
|
raise AttributeError("attr %s except %s but got %s" % |
|
|
(key, str(type(getattr(self, key))), str(type(value)))) |
|
|
(key, str(type(getattr(self, key))), str(type(value)))) |
|
|
setattr(self, key, value) |
|
|
setattr(self, key, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__name__": |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
config = ConfigLoader('configLoader', 'there is no data') |
|
|
config = ConfigLoader('configLoader', 'there is no data') |
|
|
|
|
|
|
|
|
section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} |
|
|
section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} |
|
|
@@ -90,13 +98,8 @@ if __name__ == "__name__": |
|
|
A cannot be found in config file, so nothing will be done |
|
|
A cannot be found in config file, so nothing will be done |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
config.load_config("config", section) |
|
|
|
|
|
|
|
|
config.load_config("../../test/data_for_tests/config", section) |
|
|
for s in section: |
|
|
for s in section: |
|
|
print(s) |
|
|
print(s) |
|
|
for attr in section[s].__dict__.keys(): |
|
|
for attr in section[s].__dict__.keys(): |
|
|
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) |
|
|
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) |
|
|
se = section['General'] |
|
|
|
|
|
print(se["pre_trained"]) |
|
|
|
|
|
se["pre_trained"] = False |
|
|
|
|
|
print(se["pre_trained"]) |
|
|
|
|
|
#se["pre_trained"] = 5 #this will raise AttributeError: attr pre_trained except <class 'bool'> but got <class 'int'> |
|
|
|