
[ENH] Run HED example successfully after changing ABLModel output

pull/3/head
Gao Enhao 2 years ago
commit e389b6427e
3 changed files with 203 additions and 79 deletions:
  1. abl/learning/abl_model.py (+2, -4)
  2. examples/hed/framework_hed.py (+195, -69)
  3. examples/hed/hed_example.ipynb (+6, -6)

abl/learning/abl_model.py (+2, -4)

@@ -87,8 +87,7 @@ class ABLModel:
         The accuracy score for the given data.
         """
         data_X, _ = self.merge_data(X)
-        _data_Y, _ = self.merge_data(Y)
-        data_Y = list(map(lambda y: self.mapping[y], _data_Y))
+        data_Y, _ = self.merge_data(Y)
         score = self.classifier_list[0].score(X=data_X, y=data_Y)
         return score

@@ -104,8 +103,7 @@ class ABLModel:
         The true labels for the given data.
         """
         data_X, _ = self.merge_data(X)
-        _data_Y, _ = self.merge_data(Y)
-        data_Y = list(map(lambda y: self.mapping[y], _data_Y))
+        data_Y, _ = self.merge_data(Y)
         self.classifier_list[0].fit(X=data_X, y=data_Y)

     @staticmethod
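Note: both hunks drop the `self.mapping[y]` lookup, so `merge_data` is now expected to return labels that are already class indices. A minimal sketch of that contract (with a hypothetical stand-in for `merge_data`, not the library's exact implementation):

    # Hypothetical stand-in for ABLModel.merge_data: flatten nested label
    # lists and also return the per-sample lengths.
    def merge_data(nested):
        flat = [item for seq in nested for item in seq]
        lengths = [len(seq) for seq in nested]
        return flat, lengths

    Y = [[0, 2, 1], [3, 3]]        # labels already arrive as class indices
    data_Y, _ = merge_data(Y)      # no extra self.mapping[y] step is needed
    assert data_Y == [0, 2, 1, 3, 3]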


examples/hed/framework_hed.py (+195, -69)

@@ -21,6 +21,7 @@ from abl.learning.basic_nn import BasicNN, BasicDataset

 from utils import gen_mappings, mapping_res, remapping_res
 from models.nn import SymbolNetAutoencoder
+from torch.utils.data import RandomSampler
 from datasets.get_hed import get_pretrain_data

@@ -29,85 +30,170 @@ def hed_pretrain(kb, cls, recorder):
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     if not os.path.exists("./weights/pretrain_weights.pth"):
         INFO("Pretrain Start")
-        pretrain_data_X, pretrain_data_Y = get_pretrain_data(['0', '1', '10', '11'])
+        pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"])
         pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y)
-        pretrain_data_loader = torch.utils.data.DataLoader(pretrain_data, batch_size=64, shuffle=True)
+        pretrain_data_loader = torch.utils.data.DataLoader(
+            pretrain_data, batch_size=64, shuffle=True
+        )

         criterion = nn.MSELoss()
-        optimizer = torch.optim.RMSprop(cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6)
+        optimizer = torch.optim.RMSprop(
+            cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6
+        )

-        pretrain_model = BasicNN(cls_autoencoder, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, num_epochs=10, recorder=recorder)
+        pretrain_model = BasicNN(
+            cls_autoencoder,
+            criterion,
+            optimizer,
+            device,
+            save_interval=1,
+            save_dir=recorder.save_dir,
+            num_epochs=10,
+            recorder=recorder,
+        )
         pretrain_model.fit(pretrain_data_loader)
-        torch.save(cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth")
+        torch.save(
+            cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth"
+        )
         cls.load_state_dict(cls_autoencoder.base_model.state_dict())
     else:
         cls.load_state_dict(torch.load("./weights/pretrain_weights.pth"))
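Note: the branch above is a plain pretrain-once-then-cache pattern. A generic sketch (hypothetical helper, same file layout as above):

    import os
    import torch

    def load_or_pretrain(model, weights_path, pretrain_fn):
        # Run the (possibly expensive) pretraining only when no cached
        # weights exist; otherwise just restore them.
        if os.path.exists(weights_path):
            model.load_state_dict(torch.load(weights_path))
        else:
            pretrain_fn(model)
            torch.save(model.state_dict(), weights_path)
        return model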


 def _get_char_acc(model, X, consistent_pred_res, mapping):
-    original_pred_res = model.predict(X)['cls']
+    original_pred_res = model.predict(X)["label"]
     pred_res = flatten(mapping_res(original_pred_res, mapping))
-    INFO('Current model\'s output: ', pred_res)
-    INFO('Abduced labels: ', flatten(consistent_pred_res))
+    INFO("Current model's output: ", pred_res)
+    INFO("Abduced labels: ", flatten(consistent_pred_res))
     assert len(pred_res) == len(flatten(consistent_pred_res))
-    return sum([pred_res[idx] == flatten(consistent_pred_res)[idx] for idx in range(len(pred_res))]) / len(pred_res)
+    return sum(
+        [
+            pred_res[idx] == flatten(consistent_pred_res)[idx]
+            for idx in range(len(pred_res))
+        ]
+    ) / len(pred_res)
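Note: the reformatted return value is simply the element-wise accuracy between the model's mapped output and the abduced labels. The same computation, in isolation:

    pred_res = [1, 0, 1, 1]
    abduced = [1, 0, 0, 1]
    char_acc = sum(p == a for p, a in zip(pred_res, abduced)) / len(pred_res)
    assert char_acc == 0.75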


 def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
-    select_idx = np.random.randint(len(train_X_true), size=select_num)
-    X = []
-    for idx in select_idx:
-        X.append(train_X_true[idx])
+    select_idx = RandomSampler(train_X_true, num_samples=select_num, replacement=False)
+    X = [train_X_true[idx] for idx in select_idx]

-    original_pred_res = model.predict(X)['cls']
+    # original_pred_res = model.predict(X)['label']
+    pred_label = model.predict(X)["label"]

     if mapping == None:
-        mappings = gen_mappings(['+', '=', 0, 1],['+', '=', 0, 1])
+        mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1])
     else:
         mappings = [mapping]

     consistent_idx = []
     consistent_pred_res = []

     for m in mappings:
-        pred_res = mapping_res(original_pred_res, m)
-        max_abduce_num = 20
-        solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), [None] * len(pred_res), max_abduce_num)
-        all_address_flag = reform_idx(solution, pred_res)
+        pred_pseudo_label = mapping_res(pred_label, m)
+        max_revision_num = 20
+        solution = abducer.zoopt_get_solution(
+            pred_label,
+            pred_pseudo_label,
+            [None] * len(pred_label),
+            [None] * len(pred_label),
+            max_revision_num,
+        )
+        all_address_flag = reform_idx(solution, pred_label)

         consistent_idx_tmp = []
         consistent_pred_res_tmp = []

-        for idx in range(len(pred_res)):
-            address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
-            candidate = abducer.revise_by_idx([pred_res[idx]], None, address_idx)
+        for idx in range(len(pred_label)):
+            address_idx = [
+                i for i, flag in enumerate(all_address_flag[idx]) if flag != 0
+            ]
+            candidate = abducer.revise_by_idx([pred_pseudo_label[idx]], None, address_idx)
             if len(candidate) > 0:
                 consistent_idx_tmp.append(idx)
                 consistent_pred_res_tmp.append(candidate[0][0])

         if len(consistent_idx_tmp) > len(consistent_idx):
             consistent_idx = consistent_idx_tmp
             consistent_pred_res = consistent_pred_res_tmp
             if len(mappings) > 1:
                 mapping = m

     if len(consistent_idx) == 0:
         return 0, 0, None

-    INFO('Train pool size is:', len(flatten(consistent_pred_res)))
+    INFO("Train pool size is:", len(flatten(consistent_pred_res)))
     INFO("Start to use abduced pseudo label to train model...")
-    model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping))
+    model.train(
+        [X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping)
+    )

     consistent_acc = len(consistent_idx) / select_num
-    char_acc = _get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping)
-    INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc))
+    char_acc = _get_char_acc(
+        model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping
+    )
+    INFO("consistent_acc is %s, char_acc is %s" % (consistent_acc, char_acc))
     return consistent_acc, char_acc, mapping
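Note: replacing `np.random.randint` (which can repeat indices) with `RandomSampler(..., replacement=False)` makes the selected training equations distinct. A dependency-free sketch of the same sampling behavior, using only the standard library:

    import random

    train_X_true = list(range(100))   # stand-in for the real equation pool
    select_num = 10
    select_idx = random.sample(range(len(train_X_true)), k=select_num)  # no repeats
    X = [train_X_true[idx] for idx in select_idx]
    assert len(set(select_idx)) == select_num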


 def _remove_duplicate_rule(rule_dict):
     add_nums_dict = {}
     for r in list(rule_dict):
-        add_nums = str(r.split(']')[0].split('[')[1]) + str(r.split(']')[1].split('[')[1]) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10'
+        add_nums = str(r.split("]")[0].split("[")[1]) + str(
+            r.split("]")[1].split("[")[1]
+        )  # r = 'my_op([1], [0], [1, 0])' then add_nums = '10'
         if add_nums in add_nums_dict:
             old_r = add_nums_dict[add_nums]
             if rule_dict[r] >= rule_dict[old_r]:
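Note: the `add_nums` expression pulls the contents of the first two bracketed arguments out of a rule string (the `str(...)` wrappers are redundant, since `split` already yields strings). In isolation:

    r = "my_op([1], [0], [1, 0])"
    add_nums = r.split("]")[0].split("[")[1] + r.split("]")[1].split("[")[1]
    assert add_nums == "10"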
@@ -120,7 +206,9 @@ def _remove_duplicate_rule(rule_dict):
     return list(rule_dict)


-def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, samples_num):
+def get_rules_from_data(
+    model, abducer, mapping, train_X_true, samples_per_rule, samples_num
+):
     rules = []
     for _ in range(samples_num):
         while True:
@@ -128,7 +216,7 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule,
             X = []
             for idx in select_idx:
                 X.append(train_X_true[idx])
-            original_pred_res = model.predict(X)['cls']
+            original_pred_res = model.predict(X)["label"]
             pred_res = mapping_res(original_pred_res, mapping)

             consistent_idx = []
@@ -143,42 +231,47 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule,
             if rule != None:
                 break
         rules.append(rule)

     all_rule_dict = {}
     for rule in rules:
         for r in rule:
             all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1
     rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5}
     rules = _remove_duplicate_rule(rule_dict)
     return rules


 def _get_consist_rule_acc(model, abducer, mapping, rules, X):
     cnt = 0
     for x in X:
-        original_pred_res = model.predict([x])['cls']
+        original_pred_res = model.predict([x])["label"]
         pred_res = flatten(mapping_res(original_pred_res, mapping))
         if abducer.kb.consist_rule(pred_res, rules):
             cnt += 1
     return cnt / len(X)
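Note: `_get_consist_rule_acc` is the fraction of sequences whose mapped predictions satisfy the learned rules. A self-contained illustration with a toy stand-in for `abducer.kb.consist_rule` (the real check lives in the knowledge base):

    def consist_rule(seq, rules):
        # Toy stand-in: every rule is a predicate the sequence must satisfy.
        return all(rule(seq) for rule in rules)

    rules = [lambda s: s.count("=") == 1]   # toy rule: exactly one '='
    seqs = [["1", "+", "0", "=", "1"], ["1", "=", "=", "0"]]
    acc = sum(consist_rule(s, rules) for s in seqs) / len(seqs)
    assert acc == 0.5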


-def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8):
+def train_with_rule(
+    model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8
+):
     train_X = train_data
     val_X = val_data
     samples_num = 50
     samples_per_rule = 3

     # Start training / for each length of equations
     for equation_len in range(min_len, max_len):
-        INFO("============== equation_len: %d-%d ================" % (equation_len, equation_len + 1))
+        INFO(
+            "============== equation_len: %d-%d ================"
+            % (equation_len, equation_len + 1)
+        )
         train_X_true = train_X[1][equation_len]
         train_X_false = train_X[0][equation_len]
         val_X_true = val_X[1][equation_len]
         val_X_false = val_X[0][equation_len]
         train_X_true.extend(train_X[1][equation_len + 1])
         train_X_false.extend(train_X[0][equation_len + 1])
         val_X_true.extend(val_X[1][equation_len + 1])
@@ -188,12 +281,14 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len
         while True:
             if equation_len == min_len:
                 mapping = None
             # Abduce and train NN
-            consistent_acc, char_acc, mapping = abduce_and_train(model, abducer, mapping, train_X_true, select_num)
+            consistent_acc, char_acc, mapping = abduce_and_train(
+                model, abducer, mapping, train_X_true, select_num
+            )
             if consistent_acc == 0:
                 continue
             # Test if we can use mlp to evaluate
             if consistent_acc >= 0.9 and char_acc >= 0.9:
                 condition_cnt += 1
@@ -203,32 +298,49 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len
             # The condition has been satisfied continuously five times
             if condition_cnt >= 5:
                 INFO("Now checking if we can go to next course")
-                rules = get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, samples_num)
-                INFO('Learned rules from data:', rules)
-                true_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, val_X_true)
-                false_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, val_X_false)
-                INFO('consist_rule_acc is %f, %f\n' %(true_consist_rule_acc, false_consist_rule_acc))
+                rules = get_rules_from_data(
+                    model, abducer, mapping, train_X_true, samples_per_rule, samples_num
+                )
+                INFO("Learned rules from data:", rules)
+
+                true_consist_rule_acc = _get_consist_rule_acc(
+                    model, abducer, mapping, rules, val_X_true
+                )
+                false_consist_rule_acc = _get_consist_rule_acc(
+                    model, abducer, mapping, rules, val_X_false
+                )
+
+                INFO(
+                    "consist_rule_acc is %f, %f\n"
+                    % (true_consist_rule_acc, false_consist_rule_acc)
+                )
                 # decide next course or restart
                 if true_consist_rule_acc > 0.95 and false_consist_rule_acc < 0.1:
-                    torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len)
+                    torch.save(
+                        model.classifier_list[0].model.state_dict(),
+                        "./weights/weights_%d.pth" % equation_len,
+                    )
                     break
                 else:
                     if equation_len == min_len:
-                        INFO('Final mapping is: ', mapping)
-                        model.cls_list[0].model.load_state_dict(torch.load("./weights/pretrain_weights.pth"))
+                        INFO("Final mapping is: ", mapping)
+                        model.classifier_list[0].model.load_state_dict(
+                            torch.load("./weights/pretrain_weights.pth")
+                        )
                     else:
-                        model.cls_list[0].model.load_state_dict(torch.load("./weights/weights_%d.pth" % (equation_len - 1)))
+                        model.classifier_list[0].model.load_state_dict(
+                            torch.load("./weights/weights_%d.pth" % (equation_len - 1))
+                        )
                     condition_cnt = 0
-                    INFO('Reload Model and retrain')
+                    INFO("Reload Model and retrain")

     return model, mapping
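Note: the save/reload logic above is a per-course checkpoint scheme: passing a course saves `weights_%d.pth`, and failing one restores either the pretrain weights (first course) or the previous course's checkpoint. A sketch of that fallback rule (hypothetical helper, same file layout as above):

    def fallback_weights_path(equation_len, min_len):
        # Hypothetical helper mirroring the branch above: which weights to
        # reload when the current course has to be restarted.
        if equation_len == min_len:
            return "./weights/pretrain_weights.pth"
        return "./weights/weights_%d.pth" % (equation_len - 1)

    assert fallback_weights_path(5, 5) == "./weights/pretrain_weights.pth"
    assert fallback_weights_path(7, 5) == "./weights/weights_6.pth"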


 def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8):
     train_X = train_data
     test_X = test_data
     # Calculate how many equations should be selected in each length
     # for each length, there are equation_samples_num[equation_len] rules
     print("Now begin to train final mlp model")
@@ -247,16 +359,30 @@ def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=
     rules = []
     samples_per_rule = 3
     for equation_len in range(min_len, max_len + 1):
-        equation_rules = get_rules_from_data(model, abducer, mapping, train_X[1][equation_len], samples_per_rule, equation_samples_num[equation_len])
+        equation_rules = get_rules_from_data(
+            model,
+            abducer,
+            mapping,
+            train_X[1][equation_len],
+            samples_per_rule,
+            equation_samples_num[equation_len],
+        )
         rules.extend(equation_rules)
     rules = list(set(rules))
-    INFO('Learned rules from data:', rules)
+    INFO("Learned rules from data:", rules)

     for equation_len in range(5, 27):
-        true_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, test_X[1][equation_len])
-        false_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, test_X[0][equation_len])
-        INFO('consist_rule_acc of testing length %d equations are %f, %f' %(equation_len, true_consist_rule_acc, false_consist_rule_acc))
+        true_consist_rule_acc = _get_consist_rule_acc(
+            model, abducer, mapping, rules, test_X[1][equation_len]
+        )
+        false_consist_rule_acc = _get_consist_rule_acc(
+            model, abducer, mapping, rules, test_X[0][equation_len]
+        )
+        INFO(
+            "consist_rule_acc of testing length %d equations are %f, %f"
+            % (equation_len, true_consist_rule_acc, false_consist_rule_acc)
+        )


 if __name__ == "__main__":
     pass

examples/hed/hed_example.ipynb (+6, -6)

@@ -83,8 +83,8 @@
     "        candidate = self.revise_by_idx(pred, k, address_idx)\n",
     "        return candidate\n",
     "    \n",
-    "    def zoopt_revision_score(self, pred_res, pred_res_prob, key, sol): \n",
-    "        all_address_flag = reform_idx(sol.get_x(), pred_res)\n",
+    "    def zoopt_revision_score(self, pred_res, pseudo_label, pred_res_prob, key, sol): \n",
+    "        all_address_flag = reform_idx(sol.get_x(), pseudo_label)\n",
     "        lefted_idxs = [i for i in range(len(pred_res))]\n",
     "        candidate_size = [] \n",
     "        while lefted_idxs:\n",
@@ -95,7 +95,7 @@
     "            for idx in range(-1, len(pred_res)):\n",
     "                if (not idx in idxs) and (idx >= 0):\n",
     "                    idxs.append(idx)\n",
-    "                candidate = self._revise_by_idxs(pred_res, key, all_address_flag, idxs)\n",
+    "                candidate = self._revise_by_idxs(pseudo_label, key, all_address_flag, idxs)\n",
     "                if len(candidate) == 0:\n",
     "                    if len(idxs) > 1:\n",
     "                        idxs.pop()\n",
@@ -106,7 +106,7 @@
     "            removed = [i for i in lefted_idxs if i in max_candidate_idxs]\n",
     "            if found:\n",
     "                candidate_size.append(len(removed) + 1)\n",
-    "            lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs] \n",
+    "            lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]\n",
     "        candidate_size.sort()\n",
     "        score = 0\n",
     "        import math\n",
@@ -189,7 +189,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model = ABLModel(base_model, kb.pseudo_label_list)"
+    "model = ABLModel(base_model)"
    ]
   },
   {
@@ -221,7 +221,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [

