| @@ -1,5 +1,6 @@ | |||
| import configparser | |||
| import json | |||
| import os | |||
| from fastNLP.loader.base_loader import BaseLoader | |||
| @@ -23,6 +24,8 @@ class ConfigLoader(BaseLoader): | |||
| :return: | |||
| """ | |||
| cfg = configparser.ConfigParser() | |||
| if not os.path.exists(file_path): | |||
| raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||
| cfg.read(file_path) | |||
| for s in sections: | |||
| attr_list = [i for i in sections[s].__dict__.keys() if | |||
| @@ -34,7 +37,7 @@ class ConfigLoader(BaseLoader): | |||
| for attr in gen_sec.keys(): | |||
| try: | |||
| val = json.loads(gen_sec[attr]) | |||
| #print(s, attr, val, type(val)) | |||
| # print(s, attr, val, type(val)) | |||
| if attr in attr_list: | |||
| assert type(val) == type(getattr(sections[s], attr)), \ | |||
| 'type not match, except %s but got %s' % \ | |||
| @@ -50,6 +53,7 @@ class ConfigLoader(BaseLoader): | |||
| % (attr, s)) | |||
| pass | |||
| class ConfigSection(object): | |||
| def __init__(self): | |||
| @@ -57,6 +61,8 @@ class ConfigSection(object): | |||
| def __getitem__(self, key): | |||
| """ | |||
| :param key: str, the name of the attribute | |||
| :return attr: the value of this attribute | |||
| if key not in self.__dict__.keys(): | |||
| return self[key] | |||
| else: | |||
| @@ -68,19 +74,21 @@ class ConfigSection(object): | |||
| def __setitem__(self, key, value): | |||
| """ | |||
| :param key: str, the name of the attribute | |||
| :param value: the value of this attribute | |||
| if key not in self.__dict__.keys(): | |||
| self[key] will be added | |||
| else: | |||
| self[key] will be updated | |||
| """ | |||
| if key in self.__dict__.keys(): | |||
| if not type(value) == type(getattr(self, key)): | |||
| raise AttributeError('attr %s except %s but got %s' % \ | |||
| if not isinstance(value, type(getattr(self, key))): | |||
| raise AttributeError("attr %s except %s but got %s" % | |||
| (key, str(type(getattr(self, key))), str(type(value)))) | |||
| setattr(self, key, value) | |||
| if __name__ == "__name__": | |||
| if __name__ == "__main__": | |||
| config = ConfigLoader('configLoader', 'there is no data') | |||
| section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} | |||
| @@ -90,13 +98,8 @@ if __name__ == "__name__": | |||
| A cannot be found in config file, so nothing will be done | |||
| """ | |||
| config.load_config("config", section) | |||
| config.load_config("../../test/data_for_tests/config", section) | |||
| for s in section: | |||
| print(s) | |||
| for attr in section[s].__dict__.keys(): | |||
| print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) | |||
| se = section['General'] | |||
| print(se["pre_trained"]) | |||
| se["pre_trained"] = False | |||
| print(se["pre_trained"]) | |||
| #se["pre_trained"] = 5 #this will raise AttributeError: attr pre_trained except <class 'bool'> but got <class 'int'> | |||
| @@ -1,54 +1,60 @@ | |||
| [General] | |||
| revision = "first" | |||
| datapath = "./data/smallset/imdb/" | |||
| embed_path = "./data/smallset/imdb/embedding.txt" | |||
| optimizer = "adam" | |||
| attn_mode = "rout" | |||
| seq_encoder = "bilstm" | |||
| out_caps_num = 5 | |||
| rout_iter = 3 | |||
| max_snt_num = 40 | |||
| max_wd_num = 40 | |||
| max_epochs = 50 | |||
| pre_trained = true | |||
| batch_sz = 32 | |||
| batch_sz_min = 32 | |||
| bucket_sz = 5000 | |||
| partial_update_until_epoch = 2 | |||
| embed_size = 300 | |||
| hidden_size = 200 | |||
| dense_hidden = [300, 10] | |||
| lr = 0.0002 | |||
| decay_steps = 1000 | |||
| decay_rate = 0.9 | |||
| dropout = 0.2 | |||
| early_stopping = 7 | |||
| reg = 1e-06 | |||
| [My] | |||
| datapath = "./data/smallset/imdb/" | |||
| embed_path = "./data/smallset/imdb/embedding.txt" | |||
| optimizer = "adam" | |||
| attn_mode = "rout" | |||
| seq_encoder = "bilstm" | |||
| out_caps_num = 5 | |||
| rout_iter = 3 | |||
| max_snt_num = 40 | |||
| max_wd_num = 40 | |||
| max_epochs = 50 | |||
| pre_trained = true | |||
| batch_sz = 32 | |||
| batch_sz_min = 32 | |||
| bucket_sz = 5000 | |||
| partial_update_until_epoch = 2 | |||
| embed_size = 300 | |||
| hidden_size = 200 | |||
| dense_hidden = [300, 10] | |||
| lr = 0.0002 | |||
| decay_steps = 1000 | |||
| decay_rate = 0.9 | |||
| dropout = 0.2 | |||
| early_stopping = 70 | |||
| reg = 1e-05 | |||
| test = 5 | |||
| new_attr = 40 | |||
| [General] | |||
| revision = "first" | |||
| datapath = "./data/smallset/imdb/" | |||
| embed_path = "./data/smallset/imdb/embedding.txt" | |||
| optimizer = "adam" | |||
| attn_mode = "rout" | |||
| seq_encoder = "bilstm" | |||
| out_caps_num = 5 | |||
| rout_iter = 3 | |||
| max_snt_num = 40 | |||
| max_wd_num = 40 | |||
| max_epochs = 50 | |||
| pre_trained = true | |||
| batch_sz = 32 | |||
| batch_sz_min = 32 | |||
| bucket_sz = 5000 | |||
| partial_update_until_epoch = 2 | |||
| embed_size = 300 | |||
| hidden_size = 200 | |||
| dense_hidden = [300, 10] | |||
| lr = 0.0002 | |||
| decay_steps = 1000 | |||
| decay_rate = 0.9 | |||
| dropout = 0.2 | |||
| early_stopping = 7 | |||
| reg = 1e-06 | |||
| [My] | |||
| datapath = "./data/smallset/imdb/" | |||
| embed_path = "./data/smallset/imdb/embedding.txt" | |||
| optimizer = "adam" | |||
| attn_mode = "rout" | |||
| seq_encoder = "bilstm" | |||
| out_caps_num = 5 | |||
| rout_iter = 3 | |||
| max_snt_num = 40 | |||
| max_wd_num = 40 | |||
| max_epochs = 50 | |||
| pre_trained = true | |||
| batch_sz = 32 | |||
| batch_sz_min = 32 | |||
| bucket_sz = 5000 | |||
| partial_update_until_epoch = 2 | |||
| embed_size = 300 | |||
| hidden_size = 200 | |||
| dense_hidden = [300, 10] | |||
| lr = 0.0002 | |||
| decay_steps = 1000 | |||
| decay_rate = 0.9 | |||
| dropout = 0.2 | |||
| early_stopping = 70 | |||
| reg = 1e-05 | |||
| test = 5 | |||
| new_attr = 40 | |||
| [POS] | |||
| epochs = 20 | |||
| batch_size = 1 | |||
| pickle_path = "./data_for_tests/" | |||
| validate = true | |||
| @@ -2,6 +2,7 @@ import sys | |||
| sys.path.append("..") | |||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
| from fastNLP.action.trainer import POSTrainer | |||
| from fastNLP.loader.dataset_loader import POSDatasetLoader | |||
| from fastNLP.loader.preprocess import POSPreprocess | |||
| @@ -12,6 +13,9 @@ data_path = "data_for_tests/people.txt" | |||
| pickle_path = "data_for_tests" | |||
| if __name__ == "__main__": | |||
| train_args = ConfigSection() | |||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) | |||
| # Data Loader | |||
| pos = POSDatasetLoader(data_name, data_path) | |||
| train_data = pos.load_lines() | |||
| @@ -21,9 +25,9 @@ if __name__ == "__main__": | |||
| vocab_size = p.vocab_size | |||
| num_classes = p.num_classes | |||
| # Trainer | |||
| train_args = {"epochs": 20, "batch_size": 1, "num_classes": num_classes, | |||
| "vocab_size": vocab_size, "pickle_path": pickle_path, "validate": True} | |||
| train_args["vocab_size"] = vocab_size | |||
| train_args["num_classes"] = num_classes | |||
| trainer = POSTrainer(train_args) | |||
| # Model | |||