Browse Source

Merge remote-tracking branch 'origin/master'

tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
08b0edaebb
1 changed files with 35 additions and 6 deletions
  1. +35
    -6
      fastNLP/loader/config_loader.py

+ 35
- 6
fastNLP/loader/config_loader.py View File

@@ -50,14 +50,38 @@ class ConfigLoader(BaseLoader):
% (attr, s))
pass

if __name__ == "__name__":
config = ConfigLoader('configLoader', 'there is no data')
class ConfigSection(object):

def __init__(self):
pass

def __getitem__(self, key):
"""
if key not in self.__dict__.keys():
return self[key]
else:
raise AttributeError
"""
if key in self.__dict__.keys():
return getattr(self, key)
raise AttributeError('don\'t have attr %s' % (key))

def __setitem__(self, key, value):
"""
if key not in self.__dict__.keys():
self[key] will be added
else:
self[key] will be updated
"""
if key in self.__dict__.keys():
if not type(value) == type(getattr(self, key)):
raise AttributeError('attr %s except %s but got %s' % \
(key, str(type(getattr(self, key))), str(type(value))))
setattr(self, key, value)

class ConfigSection(object):
def __init__(self):
pass

if __name__ == "__name__":
config = ConfigLoader('configLoader', 'there is no data')

section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
"""
@@ -70,4 +94,9 @@ if __name__ == "__name__":
for s in section:
print(s)
for attr in section[s].__dict__.keys():
print(s, attr, getattr(section[s], attr))
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr)))
se = section['General']
print(se["pre_trained"])
se["pre_trained"] = False
print(se["pre_trained"])
#se["pre_trained"] = 5 #this will raise AttributeError: attr pre_trained except <class 'bool'> but got <class 'int'>

Loading…
Cancel
Save