
update configLoader to load hyper-parameters from file

tags/v0.1.0
FengZiYjun, 6 years ago
commit 621b79ee19
3 changed files with 80 additions and 67 deletions
  1. fastNLP/loader/config_loader.py (+13, -10)
  2. test/data_for_tests/config (+60, -54)
  3. test/test_POS_pipeline.py (+7, -3)

fastNLP/loader/config_loader.py (+13, -10)

@@ -1,5 +1,6 @@
 import configparser
 import json
+import os
 
 from fastNLP.loader.base_loader import BaseLoader
 
@@ -23,6 +24,8 @@ class ConfigLoader(BaseLoader):
         :return:
         """
         cfg = configparser.ConfigParser()
+        if not os.path.exists(file_path):
+            raise FileNotFoundError("config file {} not found. ".format(file_path))
         cfg.read(file_path)
         for s in sections:
             attr_list = [i for i in sections[s].__dict__.keys() if
@@ -34,7 +37,7 @@ class ConfigLoader(BaseLoader):
             for attr in gen_sec.keys():
                 try:
                     val = json.loads(gen_sec[attr])
-                    #print(s, attr, val, type(val))
+                    # print(s, attr, val, type(val))
                     if attr in attr_list:
                         assert type(val) == type(getattr(sections[s], attr)), \
                             'type not match, except %s but got %s' % \
@@ -50,6 +53,7 @@ class ConfigLoader(BaseLoader):
                                   % (attr, s))
                     pass
 
+
 class ConfigSection(object):
 
     def __init__(self):
@@ -57,6 +61,8 @@ class ConfigSection(object):
 
     def __getitem__(self, key):
         """
+        :param key: str, the name of the attribute
+        :return attr: the value of this attribute
         if key not in self.__dict__.keys():
             return self[key]
         else:
@@ -68,19 +74,21 @@ class ConfigSection(object):
 
     def __setitem__(self, key, value):
         """
+        :param key: str, the name of the attribute
+        :param value: the value of this attribute
         if key not in self.__dict__.keys():
             self[key] will be added
         else:
             self[key] will be updated
         """
         if key in self.__dict__.keys():
-            if not type(value) == type(getattr(self, key)):
-                raise AttributeError('attr %s except %s but got %s' % \
+            if not isinstance(value, type(getattr(self, key))):
+                raise AttributeError("attr %s except %s but got %s" %
                                      (key, str(type(getattr(self, key))), str(type(value))))
         setattr(self, key, value)
 
 
-if __name__ == "__name__":
+if __name__ == "__main__":
     config = ConfigLoader('configLoader', 'there is no data')
 
     section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
@@ -90,13 +98,8 @@ if __name__ == "__name__":
     A cannot be found in config file, so nothing will be done
     """
 
-    config.load_config("config", section)
+    config.load_config("../../test/data_for_tests/config", section)
     for s in section:
         print(s)
         for attr in section[s].__dict__.keys():
             print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr)))
-    se = section['General']
-    print(se["pre_trained"])
-    se["pre_trained"] = False
-    print(se["pre_trained"])
-    #se["pre_trained"] = 5 #this will raise AttributeError: attr pre_trained except <class 'bool'> but got <class 'int'>
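
For context, a minimal sketch of how the updated loader is meant to be called, mirroring the test change below (assumes fastNLP is importable; the [POS] keys and values come from the config file in this commit):

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

# Fill a ConfigSection with the key/value pairs of the [POS] section.
# The os.path.exists check added above turns a bad path into a clear
# FileNotFoundError instead of a silently empty ConfigParser.
args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": args})

print(args["epochs"])    # 20   -- json.loads parses "20" as an int
print(args["validate"])  # True -- and "true" as a bool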

fastNLP/loader/config → test/data_for_tests/config (+60, -54)

@@ -1,54 +1,60 @@
-[General]
-revision = "first"
-datapath = "./data/smallset/imdb/"
-embed_path = "./data/smallset/imdb/embedding.txt"
-optimizer = "adam"
-attn_mode = "rout"
-seq_encoder = "bilstm"
-out_caps_num = 5
-rout_iter = 3
-max_snt_num = 40
-max_wd_num = 40
-max_epochs = 50
-pre_trained = true
-batch_sz = 32
-batch_sz_min = 32
-bucket_sz = 5000
-partial_update_until_epoch = 2
-embed_size = 300
-hidden_size = 200
-dense_hidden = [300, 10]
-lr = 0.0002
-decay_steps = 1000
-decay_rate = 0.9
-dropout = 0.2
-early_stopping = 7
-reg = 1e-06
-[My]
-datapath = "./data/smallset/imdb/"
-embed_path = "./data/smallset/imdb/embedding.txt"
-optimizer = "adam"
-attn_mode = "rout"
-seq_encoder = "bilstm"
-out_caps_num = 5
-rout_iter = 3
-max_snt_num = 40
-max_wd_num = 40
-max_epochs = 50
-pre_trained = true
-batch_sz = 32
-batch_sz_min = 32
-bucket_sz = 5000
-partial_update_until_epoch = 2
-embed_size = 300
-hidden_size = 200
-dense_hidden = [300, 10]
-lr = 0.0002
-decay_steps = 1000
-decay_rate = 0.9
-dropout = 0.2
-early_stopping = 70
-reg = 1e-05
-test = 5
-new_attr = 40
+[General]
+revision = "first"
+datapath = "./data/smallset/imdb/"
+embed_path = "./data/smallset/imdb/embedding.txt"
+optimizer = "adam"
+attn_mode = "rout"
+seq_encoder = "bilstm"
+out_caps_num = 5
+rout_iter = 3
+max_snt_num = 40
+max_wd_num = 40
+max_epochs = 50
+pre_trained = true
+batch_sz = 32
+batch_sz_min = 32
+bucket_sz = 5000
+partial_update_until_epoch = 2
+embed_size = 300
+hidden_size = 200
+dense_hidden = [300, 10]
+lr = 0.0002
+decay_steps = 1000
+decay_rate = 0.9
+dropout = 0.2
+early_stopping = 7
+reg = 1e-06
+
+[My]
+datapath = "./data/smallset/imdb/"
+embed_path = "./data/smallset/imdb/embedding.txt"
+optimizer = "adam"
+attn_mode = "rout"
+seq_encoder = "bilstm"
+out_caps_num = 5
+rout_iter = 3
+max_snt_num = 40
+max_wd_num = 40
+max_epochs = 50
+pre_trained = true
+batch_sz = 32
+batch_sz_min = 32
+bucket_sz = 5000
+partial_update_until_epoch = 2
+embed_size = 300
+hidden_size = 200
+dense_hidden = [300, 10]
+lr = 0.0002
+decay_steps = 1000
+decay_rate = 0.9
+dropout = 0.2
+early_stopping = 70
+reg = 1e-05
+test = 5
+new_attr = 40
+
+[POS]
+epochs = 20
+batch_size = 1
+pickle_path = "./data_for_tests/"
+validate = true
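
Every right-hand side in this file is a JSON literal, which is what lets the loader's json.loads call recover native Python types rather than raw strings. A quick illustration (plain Python, independent of fastNLP):

import json

print(json.loads('"adam"'))     # 'adam'      (str)
print(json.loads('true'))       # True        (bool)
print(json.loads('[300, 10]'))  # [300, 10]   (list)
print(json.loads('1e-06'))      # 1e-06       (float)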

test/test_POS_pipeline.py (+7, -3)

@@ -2,6 +2,7 @@ import sys
 
 sys.path.append("..")
 
+from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.action.trainer import POSTrainer
 from fastNLP.loader.dataset_loader import POSDatasetLoader
 from fastNLP.loader.preprocess import POSPreprocess
@@ -12,6 +13,9 @@ data_path = "data_for_tests/people.txt"
 pickle_path = "data_for_tests"
 
 if __name__ == "__main__":
+    train_args = ConfigSection()
+    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
+
     # Data Loader
     pos = POSDatasetLoader(data_name, data_path)
     train_data = pos.load_lines()
@@ -21,9 +25,9 @@ if __name__ == "__main__":
     vocab_size = p.vocab_size
     num_classes = p.num_classes
 
     # Trainer
-    train_args = {"epochs": 20, "batch_size": 1, "num_classes": num_classes,
-                  "vocab_size": vocab_size, "pickle_path": pickle_path, "validate": True}
+    train_args["vocab_size"] = vocab_size
+    train_args["num_classes"] = num_classes
     trainer = POSTrainer(train_args)
 
     # Model
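
With this change the hard-coded hyper-parameter dict is gone: epochs, batch_size, pickle_path and validate now come from the [POS] section above, and only the data-dependent values are set in code. A short sketch of the type guard this buys (the batch_size lines are illustrative, not part of the commit):

train_args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})

train_args["vocab_size"] = 10000  # new key: added without a check
train_args["batch_size"] = 8      # existing int key: an int value is accepted
# train_args["batch_size"] = "8"  # would raise AttributeError (str vs int)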

