From 621b79ee1962bb0c1a3350e7f6f09dbbd5ffdfd7 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Mon, 16 Jul 2018 19:53:36 +0800 Subject: [PATCH] update configLoader to load hyper-parameters from file --- fastNLP/loader/config_loader.py | 23 ++-- .../loader => test/data_for_tests}/config | 114 +++++++++--------- test/test_POS_pipeline.py | 10 +- 3 files changed, 80 insertions(+), 67 deletions(-) rename {fastNLP/loader => test/data_for_tests}/config (92%) diff --git a/fastNLP/loader/config_loader.py b/fastNLP/loader/config_loader.py index 29264fb4..9f252821 100644 --- a/fastNLP/loader/config_loader.py +++ b/fastNLP/loader/config_loader.py @@ -1,5 +1,6 @@ import configparser import json +import os from fastNLP.loader.base_loader import BaseLoader @@ -23,6 +24,8 @@ class ConfigLoader(BaseLoader): :return: """ cfg = configparser.ConfigParser() + if not os.path.exists(file_path): + raise FileNotFoundError("config file {} not found. ".format(file_path)) cfg.read(file_path) for s in sections: attr_list = [i for i in sections[s].__dict__.keys() if @@ -34,7 +37,7 @@ class ConfigLoader(BaseLoader): for attr in gen_sec.keys(): try: val = json.loads(gen_sec[attr]) - #print(s, attr, val, type(val)) + # print(s, attr, val, type(val)) if attr in attr_list: assert type(val) == type(getattr(sections[s], attr)), \ 'type not match, except %s but got %s' % \ @@ -50,6 +53,7 @@ class ConfigLoader(BaseLoader): % (attr, s)) pass + class ConfigSection(object): def __init__(self): @@ -57,6 +61,8 @@ class ConfigSection(object): def __getitem__(self, key): """ + :param key: str, the name of the attribute + :return attr: the value of this attribute if key not in self.__dict__.keys(): return self[key] else: @@ -68,19 +74,21 @@ class ConfigSection(object): def __setitem__(self, key, value): """ + :param key: str, the name of the attribute + :param value: the value of this attribute if key not in self.__dict__.keys(): self[key] will be added else: self[key] will be updated """ if key in self.__dict__.keys(): - if not type(value) == type(getattr(self, key)): - raise AttributeError('attr %s except %s but got %s' % \ + if not isinstance(value, type(getattr(self, key))): + raise AttributeError("attr %s except %s but got %s" % (key, str(type(getattr(self, key))), str(type(value)))) setattr(self, key, value) -if __name__ == "__name__": +if __name__ == "__main__": config = ConfigLoader('configLoader', 'there is no data') section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()} @@ -90,13 +98,8 @@ if __name__ == "__name__": A cannot be found in config file, so nothing will be done """ - config.load_config("config", section) + config.load_config("../../test/data_for_tests/config", section) for s in section: print(s) for attr in section[s].__dict__.keys(): print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr))) - se = section['General'] - print(se["pre_trained"]) - se["pre_trained"] = False - print(se["pre_trained"]) - #se["pre_trained"] = 5 #this will raise AttributeError: attr pre_trained except but got \ No newline at end of file diff --git a/fastNLP/loader/config b/test/data_for_tests/config similarity index 92% rename from fastNLP/loader/config rename to test/data_for_tests/config index 5eb57db5..b2d000a6 100644 --- a/fastNLP/loader/config +++ b/test/data_for_tests/config @@ -1,54 +1,60 @@ -[General] -revision = "first" -datapath = "./data/smallset/imdb/" -embed_path = "./data/smallset/imdb/embedding.txt" -optimizer = "adam" -attn_mode = "rout" -seq_encoder = "bilstm" -out_caps_num = 5 -rout_iter = 3 -max_snt_num = 40 -max_wd_num = 40 -max_epochs = 50 -pre_trained = true -batch_sz = 32 -batch_sz_min = 32 -bucket_sz = 5000 -partial_update_until_epoch = 2 -embed_size = 300 -hidden_size = 200 -dense_hidden = [300, 10] -lr = 0.0002 -decay_steps = 1000 -decay_rate = 0.9 -dropout = 0.2 -early_stopping = 7 -reg = 1e-06 - -[My] -datapath = "./data/smallset/imdb/" -embed_path = "./data/smallset/imdb/embedding.txt" -optimizer = "adam" -attn_mode = "rout" -seq_encoder = "bilstm" -out_caps_num = 5 -rout_iter = 3 -max_snt_num = 40 -max_wd_num = 40 -max_epochs = 50 -pre_trained = true -batch_sz = 32 -batch_sz_min = 32 -bucket_sz = 5000 -partial_update_until_epoch = 2 -embed_size = 300 -hidden_size = 200 -dense_hidden = [300, 10] -lr = 0.0002 -decay_steps = 1000 -decay_rate = 0.9 -dropout = 0.2 -early_stopping = 70 -reg = 1e-05 -test = 5 -new_attr = 40 +[General] +revision = "first" +datapath = "./data/smallset/imdb/" +embed_path = "./data/smallset/imdb/embedding.txt" +optimizer = "adam" +attn_mode = "rout" +seq_encoder = "bilstm" +out_caps_num = 5 +rout_iter = 3 +max_snt_num = 40 +max_wd_num = 40 +max_epochs = 50 +pre_trained = true +batch_sz = 32 +batch_sz_min = 32 +bucket_sz = 5000 +partial_update_until_epoch = 2 +embed_size = 300 +hidden_size = 200 +dense_hidden = [300, 10] +lr = 0.0002 +decay_steps = 1000 +decay_rate = 0.9 +dropout = 0.2 +early_stopping = 7 +reg = 1e-06 + +[My] +datapath = "./data/smallset/imdb/" +embed_path = "./data/smallset/imdb/embedding.txt" +optimizer = "adam" +attn_mode = "rout" +seq_encoder = "bilstm" +out_caps_num = 5 +rout_iter = 3 +max_snt_num = 40 +max_wd_num = 40 +max_epochs = 50 +pre_trained = true +batch_sz = 32 +batch_sz_min = 32 +bucket_sz = 5000 +partial_update_until_epoch = 2 +embed_size = 300 +hidden_size = 200 +dense_hidden = [300, 10] +lr = 0.0002 +decay_steps = 1000 +decay_rate = 0.9 +dropout = 0.2 +early_stopping = 70 +reg = 1e-05 +test = 5 +new_attr = 40 + +[POS] +epochs = 20 +batch_size = 1 +pickle_path = "./data_for_tests/" +validate = true \ No newline at end of file diff --git a/test/test_POS_pipeline.py b/test/test_POS_pipeline.py index af22e3b9..46a80170 100644 --- a/test/test_POS_pipeline.py +++ b/test/test_POS_pipeline.py @@ -2,6 +2,7 @@ import sys sys.path.append("..") +from fastNLP.loader.config_loader import ConfigLoader, ConfigSection from fastNLP.action.trainer import POSTrainer from fastNLP.loader.dataset_loader import POSDatasetLoader from fastNLP.loader.preprocess import POSPreprocess @@ -12,6 +13,9 @@ data_path = "data_for_tests/people.txt" pickle_path = "data_for_tests" if __name__ == "__main__": + train_args = ConfigSection() + ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) + # Data Loader pos = POSDatasetLoader(data_name, data_path) train_data = pos.load_lines() @@ -21,9 +25,9 @@ if __name__ == "__main__": vocab_size = p.vocab_size num_classes = p.num_classes - # Trainer - train_args = {"epochs": 20, "batch_size": 1, "num_classes": num_classes, - "vocab_size": vocab_size, "pickle_path": pickle_path, "validate": True} + train_args["vocab_size"] = vocab_size + train_args["num_classes"] = num_classes + trainer = POSTrainer(train_args) # Model