@@ -1,5 +1,6 @@ | |||||
import configparser | import configparser | ||||
import json | import json | ||||
import os | |||||
from fastNLP.loader.base_loader import BaseLoader | from fastNLP.loader.base_loader import BaseLoader | ||||
@@ -23,6 +24,8 @@ class ConfigLoader(BaseLoader): | |||||
:return: | :return: | ||||
""" | """ | ||||
cfg = configparser.ConfigParser() | cfg = configparser.ConfigParser() | ||||
if not os.path.exists(file_path): | |||||
raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||||
cfg.read(file_path) | cfg.read(file_path) | ||||
for s in sections: | for s in sections: | ||||
attr_list = [i for i in sections[s].__dict__.keys() if | attr_list = [i for i in sections[s].__dict__.keys() if | ||||
@@ -34,7 +37,7 @@ class ConfigLoader(BaseLoader): | |||||
for attr in gen_sec.keys(): | for attr in gen_sec.keys(): | ||||
try: | try: | ||||
val = json.loads(gen_sec[attr]) | val = json.loads(gen_sec[attr]) | ||||
#print(s, attr, val, type(val)) | |||||
# print(s, attr, val, type(val)) | |||||
if attr in attr_list: | if attr in attr_list: | ||||
assert type(val) == type(getattr(sections[s], attr)), \ | assert type(val) == type(getattr(sections[s], attr)), \ | ||||
'type not match, except %s but got %s' % \ | 'type not match, except %s but got %s' % \ | ||||
@@ -50,6 +53,7 @@ class ConfigLoader(BaseLoader): | |||||
% (attr, s)) | % (attr, s)) | ||||
pass | pass | ||||
class ConfigSection(object): | class ConfigSection(object): | ||||
def __init__(self): | def __init__(self): | ||||
@@ -57,6 +61,8 @@ class ConfigSection(object): | |||||
def __getitem__(self, key): | def __getitem__(self, key): | ||||
""" | """ | ||||
:param key: str, the name of the attribute | |||||
:return attr: the value of this attribute | |||||
if key not in self.__dict__.keys(): | if key not in self.__dict__.keys(): | ||||
return self[key] | return self[key] | ||||
else: | else: | ||||
@@ -68,19 +74,21 @@ class ConfigSection(object): | |||||
def __setitem__(self, key, value):
    """Set attribute ``key`` to ``value``, guarding the type of existing attributes.

    :param key: str, the name of the attribute
    :param value: the value of this attribute
    :raises AttributeError: if ``key`` is already an attribute and ``value``
        is not an instance of the type of the value currently stored

    if key not in self.__dict__.keys():
        self[key] will be added
    else:
        self[key] will be updated
    """
    if key in self.__dict__:
        current_type = type(getattr(self, key))
        # isinstance rather than an exact type comparison, so subclasses of
        # the stored type are accepted when updating an existing attribute.
        if not isinstance(value, current_type):
            raise AttributeError("attr %s except %s but got %s" %
                                 (key, str(current_type), str(type(value))))
    setattr(self, key, value)
if __name__ == "__name__": | |||||
if __name__ == "__main__": | |||||
config = ConfigLoader('configLoader', 'there is no data') | config = ConfigLoader('configLoader', 'there is no data') | ||||
section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} | section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} | ||||
@@ -90,13 +98,8 @@ if __name__ == "__name__": | |||||
A cannot be found in config file, so nothing will be done | A cannot be found in config file, so nothing will be done | ||||
""" | """ | ||||
config.load_config("config", section) | |||||
config.load_config("../../test/data_for_tests/config", section) | |||||
for s in section: | for s in section: | ||||
print(s) | print(s) | ||||
for attr in section[s].__dict__.keys(): | for attr in section[s].__dict__.keys(): | ||||
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) | print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) | ||||
se = section['General'] | |||||
print(se["pre_trained"]) | |||||
se["pre_trained"] = False | |||||
print(se["pre_trained"]) | |||||
#se["pre_trained"] = 5 #this will raise AttributeError: attr pre_trained except <class 'bool'> but got <class 'int'> |
@@ -1,54 +1,60 @@ | |||||
[General] | |||||
revision = "first" | |||||
datapath = "./data/smallset/imdb/" | |||||
embed_path = "./data/smallset/imdb/embedding.txt" | |||||
optimizer = "adam" | |||||
attn_mode = "rout" | |||||
seq_encoder = "bilstm" | |||||
out_caps_num = 5 | |||||
rout_iter = 3 | |||||
max_snt_num = 40 | |||||
max_wd_num = 40 | |||||
max_epochs = 50 | |||||
pre_trained = true | |||||
batch_sz = 32 | |||||
batch_sz_min = 32 | |||||
bucket_sz = 5000 | |||||
partial_update_until_epoch = 2 | |||||
embed_size = 300 | |||||
hidden_size = 200 | |||||
dense_hidden = [300, 10] | |||||
lr = 0.0002 | |||||
decay_steps = 1000 | |||||
decay_rate = 0.9 | |||||
dropout = 0.2 | |||||
early_stopping = 7 | |||||
reg = 1e-06 | |||||
[My] | |||||
datapath = "./data/smallset/imdb/" | |||||
embed_path = "./data/smallset/imdb/embedding.txt" | |||||
optimizer = "adam" | |||||
attn_mode = "rout" | |||||
seq_encoder = "bilstm" | |||||
out_caps_num = 5 | |||||
rout_iter = 3 | |||||
max_snt_num = 40 | |||||
max_wd_num = 40 | |||||
max_epochs = 50 | |||||
pre_trained = true | |||||
batch_sz = 32 | |||||
batch_sz_min = 32 | |||||
bucket_sz = 5000 | |||||
partial_update_until_epoch = 2 | |||||
embed_size = 300 | |||||
hidden_size = 200 | |||||
dense_hidden = [300, 10] | |||||
lr = 0.0002 | |||||
decay_steps = 1000 | |||||
decay_rate = 0.9 | |||||
dropout = 0.2 | |||||
early_stopping = 70 | |||||
reg = 1e-05 | |||||
test = 5 | |||||
new_attr = 40 | |||||
[General] | |||||
revision = "first" | |||||
datapath = "./data/smallset/imdb/" | |||||
embed_path = "./data/smallset/imdb/embedding.txt" | |||||
optimizer = "adam" | |||||
attn_mode = "rout" | |||||
seq_encoder = "bilstm" | |||||
out_caps_num = 5 | |||||
rout_iter = 3 | |||||
max_snt_num = 40 | |||||
max_wd_num = 40 | |||||
max_epochs = 50 | |||||
pre_trained = true | |||||
batch_sz = 32 | |||||
batch_sz_min = 32 | |||||
bucket_sz = 5000 | |||||
partial_update_until_epoch = 2 | |||||
embed_size = 300 | |||||
hidden_size = 200 | |||||
dense_hidden = [300, 10] | |||||
lr = 0.0002 | |||||
decay_steps = 1000 | |||||
decay_rate = 0.9 | |||||
dropout = 0.2 | |||||
early_stopping = 7 | |||||
reg = 1e-06 | |||||
[My] | |||||
datapath = "./data/smallset/imdb/" | |||||
embed_path = "./data/smallset/imdb/embedding.txt" | |||||
optimizer = "adam" | |||||
attn_mode = "rout" | |||||
seq_encoder = "bilstm" | |||||
out_caps_num = 5 | |||||
rout_iter = 3 | |||||
max_snt_num = 40 | |||||
max_wd_num = 40 | |||||
max_epochs = 50 | |||||
pre_trained = true | |||||
batch_sz = 32 | |||||
batch_sz_min = 32 | |||||
bucket_sz = 5000 | |||||
partial_update_until_epoch = 2 | |||||
embed_size = 300 | |||||
hidden_size = 200 | |||||
dense_hidden = [300, 10] | |||||
lr = 0.0002 | |||||
decay_steps = 1000 | |||||
decay_rate = 0.9 | |||||
dropout = 0.2 | |||||
early_stopping = 70 | |||||
reg = 1e-05 | |||||
test = 5 | |||||
new_attr = 40 | |||||
[POS] | |||||
epochs = 20 | |||||
batch_size = 1 | |||||
pickle_path = "./data_for_tests/" | |||||
validate = true |
@@ -2,6 +2,7 @@ import sys | |||||
sys.path.append("..") | sys.path.append("..") | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
from fastNLP.action.trainer import POSTrainer | from fastNLP.action.trainer import POSTrainer | ||||
from fastNLP.loader.dataset_loader import POSDatasetLoader | from fastNLP.loader.dataset_loader import POSDatasetLoader | ||||
from fastNLP.loader.preprocess import POSPreprocess | from fastNLP.loader.preprocess import POSPreprocess | ||||
@@ -12,6 +13,9 @@ data_path = "data_for_tests/people.txt" | |||||
pickle_path = "data_for_tests" | pickle_path = "data_for_tests" | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
train_args = ConfigSection() | |||||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) | |||||
# Data Loader | # Data Loader | ||||
pos = POSDatasetLoader(data_name, data_path) | pos = POSDatasetLoader(data_name, data_path) | ||||
train_data = pos.load_lines() | train_data = pos.load_lines() | ||||
@@ -21,9 +25,9 @@ if __name__ == "__main__": | |||||
vocab_size = p.vocab_size | vocab_size = p.vocab_size | ||||
num_classes = p.num_classes | num_classes = p.num_classes | ||||
# Trainer | |||||
train_args = {"epochs": 20, "batch_size": 1, "num_classes": num_classes, | |||||
"vocab_size": vocab_size, "pickle_path": pickle_path, "validate": True} | |||||
train_args["vocab_size"] = vocab_size | |||||
train_args["num_classes"] = num_classes | |||||
trainer = POSTrainer(train_args) | trainer = POSTrainer(train_args) | ||||
# Model | # Model | ||||