@@ -1,5 +1,6 @@ | |||
import configparser | |||
import json | |||
import os | |||
from fastNLP.loader.base_loader import BaseLoader | |||
@@ -23,6 +24,8 @@ class ConfigLoader(BaseLoader): | |||
:return: | |||
""" | |||
cfg = configparser.ConfigParser() | |||
if not os.path.exists(file_path): | |||
raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||
cfg.read(file_path) | |||
for s in sections: | |||
attr_list = [i for i in sections[s].__dict__.keys() if | |||
@@ -34,7 +37,7 @@ class ConfigLoader(BaseLoader): | |||
for attr in gen_sec.keys(): | |||
try: | |||
val = json.loads(gen_sec[attr]) | |||
#print(s, attr, val, type(val)) | |||
# print(s, attr, val, type(val)) | |||
if attr in attr_list: | |||
assert type(val) == type(getattr(sections[s], attr)), \ | |||
'type not match, except %s but got %s' % \ | |||
@@ -50,6 +53,7 @@ class ConfigLoader(BaseLoader): | |||
% (attr, s)) | |||
pass | |||
class ConfigSection(object): | |||
def __init__(self): | |||
@@ -57,6 +61,8 @@ class ConfigSection(object): | |||
def __getitem__(self, key): | |||
""" | |||
:param key: str, the name of the attribute | |||
:return attr: the value of this attribute | |||
if key not in self.__dict__.keys(): | |||
return self[key] | |||
else: | |||
@@ -68,19 +74,21 @@ class ConfigSection(object): | |||
def __setitem__(self, key, value): | |||
""" | |||
:param key: str, the name of the attribute | |||
:param value: the value of this attribute | |||
if key not in self.__dict__.keys(): | |||
self[key] will be added | |||
else: | |||
self[key] will be updated | |||
""" | |||
if key in self.__dict__.keys(): | |||
if not type(value) == type(getattr(self, key)): | |||
raise AttributeError('attr %s except %s but got %s' % \ | |||
if not isinstance(value, type(getattr(self, key))): | |||
raise AttributeError("attr %s except %s but got %s" % | |||
(key, str(type(getattr(self, key))), str(type(value)))) | |||
setattr(self, key, value) | |||
if __name__ == "__name__": | |||
if __name__ == "__main__": | |||
config = ConfigLoader('configLoader', 'there is no data') | |||
section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} | |||
@@ -90,13 +98,8 @@ if __name__ == "__name__": | |||
A cannot be found in config file, so nothing will be done | |||
""" | |||
config.load_config("config", section) | |||
config.load_config("../../test/data_for_tests/config", section) | |||
for s in section: | |||
print(s) | |||
for attr in section[s].__dict__.keys(): | |||
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) | |||
se = section['General'] | |||
print(se["pre_trained"]) | |||
se["pre_trained"] = False | |||
print(se["pre_trained"]) | |||
#se["pre_trained"] = 5 #this will raise AttributeError: attr pre_trained except <class 'bool'> but got <class 'int'> |
@@ -1,54 +1,60 @@ | |||
[General] | |||
revision = "first" | |||
datapath = "./data/smallset/imdb/" | |||
embed_path = "./data/smallset/imdb/embedding.txt" | |||
optimizer = "adam" | |||
attn_mode = "rout" | |||
seq_encoder = "bilstm" | |||
out_caps_num = 5 | |||
rout_iter = 3 | |||
max_snt_num = 40 | |||
max_wd_num = 40 | |||
max_epochs = 50 | |||
pre_trained = true | |||
batch_sz = 32 | |||
batch_sz_min = 32 | |||
bucket_sz = 5000 | |||
partial_update_until_epoch = 2 | |||
embed_size = 300 | |||
hidden_size = 200 | |||
dense_hidden = [300, 10] | |||
lr = 0.0002 | |||
decay_steps = 1000 | |||
decay_rate = 0.9 | |||
dropout = 0.2 | |||
early_stopping = 7 | |||
reg = 1e-06 | |||
[My] | |||
datapath = "./data/smallset/imdb/" | |||
embed_path = "./data/smallset/imdb/embedding.txt" | |||
optimizer = "adam" | |||
attn_mode = "rout" | |||
seq_encoder = "bilstm" | |||
out_caps_num = 5 | |||
rout_iter = 3 | |||
max_snt_num = 40 | |||
max_wd_num = 40 | |||
max_epochs = 50 | |||
pre_trained = true | |||
batch_sz = 32 | |||
batch_sz_min = 32 | |||
bucket_sz = 5000 | |||
partial_update_until_epoch = 2 | |||
embed_size = 300 | |||
hidden_size = 200 | |||
dense_hidden = [300, 10] | |||
lr = 0.0002 | |||
decay_steps = 1000 | |||
decay_rate = 0.9 | |||
dropout = 0.2 | |||
early_stopping = 70 | |||
reg = 1e-05 | |||
test = 5 | |||
new_attr = 40 | |||
[General] | |||
revision = "first" | |||
datapath = "./data/smallset/imdb/" | |||
embed_path = "./data/smallset/imdb/embedding.txt" | |||
optimizer = "adam" | |||
attn_mode = "rout" | |||
seq_encoder = "bilstm" | |||
out_caps_num = 5 | |||
rout_iter = 3 | |||
max_snt_num = 40 | |||
max_wd_num = 40 | |||
max_epochs = 50 | |||
pre_trained = true | |||
batch_sz = 32 | |||
batch_sz_min = 32 | |||
bucket_sz = 5000 | |||
partial_update_until_epoch = 2 | |||
embed_size = 300 | |||
hidden_size = 200 | |||
dense_hidden = [300, 10] | |||
lr = 0.0002 | |||
decay_steps = 1000 | |||
decay_rate = 0.9 | |||
dropout = 0.2 | |||
early_stopping = 7 | |||
reg = 1e-06 | |||
[My] | |||
datapath = "./data/smallset/imdb/" | |||
embed_path = "./data/smallset/imdb/embedding.txt" | |||
optimizer = "adam" | |||
attn_mode = "rout" | |||
seq_encoder = "bilstm" | |||
out_caps_num = 5 | |||
rout_iter = 3 | |||
max_snt_num = 40 | |||
max_wd_num = 40 | |||
max_epochs = 50 | |||
pre_trained = true | |||
batch_sz = 32 | |||
batch_sz_min = 32 | |||
bucket_sz = 5000 | |||
partial_update_until_epoch = 2 | |||
embed_size = 300 | |||
hidden_size = 200 | |||
dense_hidden = [300, 10] | |||
lr = 0.0002 | |||
decay_steps = 1000 | |||
decay_rate = 0.9 | |||
dropout = 0.2 | |||
early_stopping = 70 | |||
reg = 1e-05 | |||
test = 5 | |||
new_attr = 40 | |||
[POS] | |||
epochs = 20 | |||
batch_size = 1 | |||
pickle_path = "./data_for_tests/" | |||
validate = true |
@@ -2,6 +2,7 @@ import sys | |||
sys.path.append("..") | |||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
from fastNLP.action.trainer import POSTrainer | |||
from fastNLP.loader.dataset_loader import POSDatasetLoader | |||
from fastNLP.loader.preprocess import POSPreprocess | |||
@@ -12,6 +13,9 @@ data_path = "data_for_tests/people.txt" | |||
pickle_path = "data_for_tests" | |||
if __name__ == "__main__": | |||
train_args = ConfigSection() | |||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) | |||
# Data Loader | |||
pos = POSDatasetLoader(data_name, data_path) | |||
train_data = pos.load_lines() | |||
@@ -21,9 +25,9 @@ if __name__ == "__main__": | |||
vocab_size = p.vocab_size | |||
num_classes = p.num_classes | |||
# Trainer | |||
train_args = {"epochs": 20, "batch_size": 1, "num_classes": num_classes, | |||
"vocab_size": vocab_size, "pickle_path": pickle_path, "validate": True} | |||
train_args["vocab_size"] = vocab_size | |||
train_args["num_classes"] = num_classes | |||
trainer = POSTrainer(train_args) | |||
# Model | |||