|
|
@@ -21,6 +21,7 @@ from abl.learning.basic_nn import BasicNN, BasicDataset |
|
|
|
|
|
|
|
from utils import gen_mappings, mapping_res, remapping_res |
|
|
|
from models.nn import SymbolNetAutoencoder |
|
|
|
from torch.utils.data import RandomSampler |
|
|
|
from datasets.get_hed import get_pretrain_data |
|
|
|
|
|
|
|
|
|
|
@@ -29,85 +30,170 @@ def hed_pretrain(kb, cls, recorder): |
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
|
if not os.path.exists("./weights/pretrain_weights.pth"): |
|
|
|
INFO("Pretrain Start") |
|
|
|
pretrain_data_X, pretrain_data_Y = get_pretrain_data(['0', '1', '10', '11']) |
|
|
|
pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"]) |
|
|
|
pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y) |
|
|
|
pretrain_data_loader = torch.utils.data.DataLoader(pretrain_data, batch_size=64, shuffle=True) |
|
|
|
|
|
|
|
pretrain_data_loader = torch.utils.data.DataLoader( |
|
|
|
pretrain_data, batch_size=64, shuffle=True |
|
|
|
) |
|
|
|
|
|
|
|
criterion = nn.MSELoss() |
|
|
|
optimizer = torch.optim.RMSprop(cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6) |
|
|
|
optimizer = torch.optim.RMSprop( |
|
|
|
cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6 |
|
|
|
) |
|
|
|
|
|
|
|
pretrain_model = BasicNN(cls_autoencoder, criterion, optimizer, device, save_interval=1, save_dir=recorder.save_dir, num_epochs=10, recorder=recorder) |
|
|
|
pretrain_model = BasicNN( |
|
|
|
cls_autoencoder, |
|
|
|
criterion, |
|
|
|
optimizer, |
|
|
|
device, |
|
|
|
save_interval=1, |
|
|
|
save_dir=recorder.save_dir, |
|
|
|
num_epochs=10, |
|
|
|
recorder=recorder, |
|
|
|
) |
|
|
|
pretrain_model.fit(pretrain_data_loader) |
|
|
|
torch.save(cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth") |
|
|
|
torch.save( |
|
|
|
cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth" |
|
|
|
) |
|
|
|
cls.load_state_dict(cls_autoencoder.base_model.state_dict()) |
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
|
cls.load_state_dict(torch.load("./weights/pretrain_weights.pth")) |
|
|
|
|
|
|
|
|
|
|
|
def _get_char_acc(model, X, consistent_pred_res, mapping): |
|
|
|
original_pred_res = model.predict(X)['cls'] |
|
|
|
original_pred_res = model.predict(X)["label"] |
|
|
|
pred_res = flatten(mapping_res(original_pred_res, mapping)) |
|
|
|
INFO('Current model\'s output: ', pred_res) |
|
|
|
INFO('Abduced labels: ', flatten(consistent_pred_res)) |
|
|
|
INFO("Current model's output: ", pred_res) |
|
|
|
INFO("Abduced labels: ", flatten(consistent_pred_res)) |
|
|
|
assert len(pred_res) == len(flatten(consistent_pred_res)) |
|
|
|
return sum([pred_res[idx] == flatten(consistent_pred_res)[idx] for idx in range(len(pred_res))]) / len(pred_res) |
|
|
|
return sum( |
|
|
|
[ |
|
|
|
pred_res[idx] == flatten(consistent_pred_res)[idx] |
|
|
|
for idx in range(len(pred_res)) |
|
|
|
] |
|
|
|
) / len(pred_res) |
|
|
|
|
|
|
|
|
|
|
|
def abduce_and_train(model, abducer, mapping, train_X_true, select_num): |
|
|
|
select_idx = np.random.randint(len(train_X_true), size=select_num) |
|
|
|
X = [] |
|
|
|
for idx in select_idx: |
|
|
|
X.append(train_X_true[idx]) |
|
|
|
select_idx = RandomSampler(train_X_true, num_samples=select_num,replacement=False) |
|
|
|
X = [train_X_true[idx] for idx in select_idx] |
|
|
|
|
|
|
|
# original_pred_res = model.predict(X)['label'] |
|
|
|
pred_label = model.predict(X)["label"] |
|
|
|
|
|
|
|
original_pred_res = model.predict(X)['cls'] |
|
|
|
|
|
|
|
if mapping == None: |
|
|
|
mappings = gen_mappings(['+', '=', 0, 1],['+', '=', 0, 1]) |
|
|
|
mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1]) |
|
|
|
else: |
|
|
|
mappings = [mapping] |
|
|
|
|
|
|
|
|
|
|
|
consistent_idx = [] |
|
|
|
consistent_pred_res = [] |
|
|
|
|
|
|
|
|
|
|
|
for m in mappings: |
|
|
|
pred_res = mapping_res(original_pred_res, m) |
|
|
|
max_abduce_num = 20 |
|
|
|
solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), [None] * len(pred_res), max_abduce_num) |
|
|
|
all_address_flag = reform_idx(solution, pred_res) |
|
|
|
pred_pseudo_label = mapping_res(pred_label, m) |
|
|
|
max_revision_num = 20 |
|
|
|
solution = abducer.zoopt_get_solution( |
|
|
|
pred_label, |
|
|
|
pred_pseudo_label, |
|
|
|
[None] * len(pred_label), |
|
|
|
[None] * len(pred_label), |
|
|
|
max_revision_num, |
|
|
|
) |
|
|
|
all_address_flag = reform_idx(solution, pred_label) |
|
|
|
|
|
|
|
consistent_idx_tmp = [] |
|
|
|
consistent_pred_res_tmp = [] |
|
|
|
|
|
|
|
for idx in range(len(pred_res)): |
|
|
|
address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] |
|
|
|
candidate = abducer.revise_by_idx([pred_res[idx]], None, address_idx) |
|
|
|
|
|
|
|
for idx in range(len(pred_label)): |
|
|
|
address_idx = [ |
|
|
|
i for i, flag in enumerate(all_address_flag[idx]) if flag != 0 |
|
|
|
] |
|
|
|
candidate = abducer.revise_by_idx([pred_pseudo_label[idx]], None, address_idx) |
|
|
|
if len(candidate) > 0: |
|
|
|
consistent_idx_tmp.append(idx) |
|
|
|
consistent_pred_res_tmp.append(candidate[0][0]) |
|
|
|
|
|
|
|
|
|
|
|
if len(consistent_idx_tmp) > len(consistent_idx): |
|
|
|
consistent_idx = consistent_idx_tmp |
|
|
|
consistent_pred_res = consistent_pred_res_tmp |
|
|
|
if len(mappings) > 1: |
|
|
|
mapping = m |
|
|
|
|
|
|
|
|
|
|
|
if len(consistent_idx) == 0: |
|
|
|
return 0, 0, None |
|
|
|
|
|
|
|
INFO('Train pool size is:', len(flatten(consistent_pred_res))) |
|
|
|
|
|
|
|
INFO("Train pool size is:", len(flatten(consistent_pred_res))) |
|
|
|
INFO("Start to use abduced pseudo label to train model...") |
|
|
|
model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping)) |
|
|
|
model.train( |
|
|
|
[X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping) |
|
|
|
) |
|
|
|
|
|
|
|
consistent_acc = len(consistent_idx) / select_num |
|
|
|
char_acc = _get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping) |
|
|
|
INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc)) |
|
|
|
char_acc = _get_char_acc( |
|
|
|
model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping |
|
|
|
) |
|
|
|
INFO("consistent_acc is %s, char_acc is %s" % (consistent_acc, char_acc)) |
|
|
|
return consistent_acc, char_acc, mapping |
|
|
|
|
|
|
|
|
|
|
|
# def abduce_and_train(model, abducer, mapping, train_X_true, select_num): |
|
|
|
# select_idx = np.random.randint(len(train_X_true), size=select_num) |
|
|
|
# X = [] |
|
|
|
# for idx in select_idx: |
|
|
|
# X.append(train_X_true[idx]) |
|
|
|
|
|
|
|
# original_pred_res = model.predict(X)['label'] |
|
|
|
|
|
|
|
# if mapping == None: |
|
|
|
# mappings = gen_mappings([0, 1, 2, 3],['+', '=', 0, 1]) |
|
|
|
# else: |
|
|
|
# mappings = [mapping] |
|
|
|
|
|
|
|
# consistent_idx = [] |
|
|
|
# consistent_pred_res = [] |
|
|
|
|
|
|
|
# for m in mappings: |
|
|
|
# pred_res = mapping_res(original_pred_res, m) |
|
|
|
# max_abduce_num = 20 |
|
|
|
# solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), [None] * len(pred_res), max_abduce_num) |
|
|
|
# all_address_flag = reform_idx(solution, pred_res) |
|
|
|
|
|
|
|
# consistent_idx_tmp = [] |
|
|
|
# consistent_pred_res_tmp = [] |
|
|
|
|
|
|
|
# for idx in range(len(pred_res)): |
|
|
|
# address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0] |
|
|
|
# candidate = abducer.revise_by_idx([pred_res[idx]], None, address_idx) |
|
|
|
# if len(candidate) > 0: |
|
|
|
# consistent_idx_tmp.append(idx) |
|
|
|
# consistent_pred_res_tmp.append(candidate[0][0]) |
|
|
|
|
|
|
|
# if len(consistent_idx_tmp) > len(consistent_idx): |
|
|
|
# consistent_idx = consistent_idx_tmp |
|
|
|
# consistent_pred_res = consistent_pred_res_tmp |
|
|
|
# if len(mappings) > 1: |
|
|
|
# mapping = m |
|
|
|
|
|
|
|
# if len(consistent_idx) == 0: |
|
|
|
# return 0, 0, None |
|
|
|
|
|
|
|
# INFO('Train pool size is:', len(flatten(consistent_pred_res))) |
|
|
|
# INFO("Start to use abduced pseudo label to train model...") |
|
|
|
# model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping)) |
|
|
|
|
|
|
|
# consistent_acc = len(consistent_idx) / select_num |
|
|
|
# char_acc = _get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping) |
|
|
|
# INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc)) |
|
|
|
# return consistent_acc, char_acc, mapping |
|
|
|
|
|
|
|
|
|
|
|
def _remove_duplicate_rule(rule_dict): |
|
|
|
add_nums_dict = {} |
|
|
|
for r in list(rule_dict): |
|
|
|
add_nums = str(r.split(']')[0].split('[')[1]) + str(r.split(']')[1].split('[')[1]) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10' |
|
|
|
add_nums = str(r.split("]")[0].split("[")[1]) + str( |
|
|
|
r.split("]")[1].split("[")[1] |
|
|
|
) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10' |
|
|
|
if add_nums in add_nums_dict: |
|
|
|
old_r = add_nums_dict[add_nums] |
|
|
|
if rule_dict[r] >= rule_dict[old_r]: |
|
|
@@ -120,7 +206,9 @@ def _remove_duplicate_rule(rule_dict): |
|
|
|
return list(rule_dict) |
|
|
|
|
|
|
|
|
|
|
|
def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, samples_num): |
|
|
|
def get_rules_from_data( |
|
|
|
model, abducer, mapping, train_X_true, samples_per_rule, samples_num |
|
|
|
): |
|
|
|
rules = [] |
|
|
|
for _ in range(samples_num): |
|
|
|
while True: |
|
|
@@ -128,7 +216,7 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, |
|
|
|
X = [] |
|
|
|
for idx in select_idx: |
|
|
|
X.append(train_X_true[idx]) |
|
|
|
original_pred_res = model.predict(X)['cls'] |
|
|
|
original_pred_res = model.predict(X)["label"] |
|
|
|
pred_res = mapping_res(original_pred_res, mapping) |
|
|
|
|
|
|
|
consistent_idx = [] |
|
|
@@ -143,42 +231,47 @@ def get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, |
|
|
|
if rule != None: |
|
|
|
break |
|
|
|
rules.append(rule) |
|
|
|
|
|
|
|
|
|
|
|
all_rule_dict = {} |
|
|
|
for rule in rules: |
|
|
|
for r in rule: |
|
|
|
all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1 |
|
|
|
rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5} |
|
|
|
rules = _remove_duplicate_rule(rule_dict) |
|
|
|
|
|
|
|
|
|
|
|
return rules |
|
|
|
|
|
|
|
|
|
|
|
def _get_consist_rule_acc(model, abducer, mapping, rules, X): |
|
|
|
cnt = 0 |
|
|
|
for x in X: |
|
|
|
original_pred_res = model.predict([x])['cls'] |
|
|
|
original_pred_res = model.predict([x])["label"] |
|
|
|
pred_res = flatten(mapping_res(original_pred_res, mapping)) |
|
|
|
if abducer.kb.consist_rule(pred_res, rules): |
|
|
|
cnt += 1 |
|
|
|
return cnt / len(X) |
|
|
|
|
|
|
|
|
|
|
|
def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8): |
|
|
|
def train_with_rule( |
|
|
|
model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8 |
|
|
|
): |
|
|
|
train_X = train_data |
|
|
|
val_X = val_data |
|
|
|
|
|
|
|
|
|
|
|
samples_num = 50 |
|
|
|
samples_per_rule = 3 |
|
|
|
|
|
|
|
# Start training / for each length of equations |
|
|
|
for equation_len in range(min_len, max_len): |
|
|
|
INFO("============== equation_len: %d-%d ================" % (equation_len, equation_len + 1)) |
|
|
|
INFO( |
|
|
|
"============== equation_len: %d-%d ================" |
|
|
|
% (equation_len, equation_len + 1) |
|
|
|
) |
|
|
|
train_X_true = train_X[1][equation_len] |
|
|
|
train_X_false = train_X[0][equation_len] |
|
|
|
val_X_true = val_X[1][equation_len] |
|
|
|
val_X_false = val_X[0][equation_len] |
|
|
|
|
|
|
|
|
|
|
|
train_X_true.extend(train_X[1][equation_len + 1]) |
|
|
|
train_X_false.extend(train_X[0][equation_len + 1]) |
|
|
|
val_X_true.extend(val_X[1][equation_len + 1]) |
|
|
@@ -188,12 +281,14 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len |
|
|
|
while True: |
|
|
|
if equation_len == min_len: |
|
|
|
mapping = None |
|
|
|
|
|
|
|
|
|
|
|
# Abduce and train NN |
|
|
|
consistent_acc, char_acc, mapping = abduce_and_train(model, abducer, mapping, train_X_true, select_num) |
|
|
|
consistent_acc, char_acc, mapping = abduce_and_train( |
|
|
|
model, abducer, mapping, train_X_true, select_num |
|
|
|
) |
|
|
|
if consistent_acc == 0: |
|
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
# Test if we can use mlp to evaluate |
|
|
|
if consistent_acc >= 0.9 and char_acc >= 0.9: |
|
|
|
condition_cnt += 1 |
|
|
@@ -203,32 +298,49 @@ def train_with_rule(model, abducer, train_data, val_data, select_num=10, min_len |
|
|
|
# The condition has been satisfied continuously five times |
|
|
|
if condition_cnt >= 5: |
|
|
|
INFO("Now checking if we can go to next course") |
|
|
|
rules = get_rules_from_data(model, abducer, mapping, train_X_true, samples_per_rule, samples_num) |
|
|
|
INFO('Learned rules from data:', rules) |
|
|
|
|
|
|
|
true_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, val_X_true) |
|
|
|
false_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, val_X_false) |
|
|
|
|
|
|
|
INFO('consist_rule_acc is %f, %f\n' %(true_consist_rule_acc, false_consist_rule_acc)) |
|
|
|
rules = get_rules_from_data( |
|
|
|
model, abducer, mapping, train_X_true, samples_per_rule, samples_num |
|
|
|
) |
|
|
|
INFO("Learned rules from data:", rules) |
|
|
|
|
|
|
|
true_consist_rule_acc = _get_consist_rule_acc( |
|
|
|
model, abducer, mapping, rules, val_X_true |
|
|
|
) |
|
|
|
false_consist_rule_acc = _get_consist_rule_acc( |
|
|
|
model, abducer, mapping, rules, val_X_false |
|
|
|
) |
|
|
|
|
|
|
|
INFO( |
|
|
|
"consist_rule_acc is %f, %f\n" |
|
|
|
% (true_consist_rule_acc, false_consist_rule_acc) |
|
|
|
) |
|
|
|
# decide next course or restart |
|
|
|
if true_consist_rule_acc > 0.95 and false_consist_rule_acc < 0.1: |
|
|
|
torch.save(model.cls_list[0].model.state_dict(), "./weights/weights_%d.pth" % equation_len) |
|
|
|
torch.save( |
|
|
|
model.classifier_list[0].model.state_dict(), |
|
|
|
"./weights/weights_%d.pth" % equation_len, |
|
|
|
) |
|
|
|
break |
|
|
|
else: |
|
|
|
if equation_len == min_len: |
|
|
|
INFO('Final mapping is: ', mapping) |
|
|
|
model.cls_list[0].model.load_state_dict(torch.load("./weights/pretrain_weights.pth")) |
|
|
|
INFO("Final mapping is: ", mapping) |
|
|
|
model.classifier_list[0].model.load_state_dict( |
|
|
|
torch.load("./weights/pretrain_weights.pth") |
|
|
|
) |
|
|
|
else: |
|
|
|
model.cls_list[0].model.load_state_dict(torch.load("./weights/weights_%d.pth" % (equation_len - 1))) |
|
|
|
model.classifier_list[0].model.load_state_dict( |
|
|
|
torch.load("./weights/weights_%d.pth" % (equation_len - 1)) |
|
|
|
) |
|
|
|
condition_cnt = 0 |
|
|
|
INFO('Reload Model and retrain') |
|
|
|
|
|
|
|
INFO("Reload Model and retrain") |
|
|
|
|
|
|
|
return model, mapping |
|
|
|
|
|
|
|
|
|
|
|
def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8): |
|
|
|
train_X = train_data |
|
|
|
test_X = test_data |
|
|
|
|
|
|
|
|
|
|
|
# Calcualte how many equations should be selected in each length |
|
|
|
# for each length, there are equation_samples_num[equation_len] rules |
|
|
|
print("Now begin to train final mlp model") |
|
|
@@ -247,16 +359,30 @@ def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len= |
|
|
|
rules = [] |
|
|
|
samples_per_rule = 3 |
|
|
|
for equation_len in range(min_len, max_len + 1): |
|
|
|
equation_rules = get_rules_from_data(model, abducer, mapping, train_X[1][equation_len], samples_per_rule, equation_samples_num[equation_len]) |
|
|
|
equation_rules = get_rules_from_data( |
|
|
|
model, |
|
|
|
abducer, |
|
|
|
mapping, |
|
|
|
train_X[1][equation_len], |
|
|
|
samples_per_rule, |
|
|
|
equation_samples_num[equation_len], |
|
|
|
) |
|
|
|
rules.extend(equation_rules) |
|
|
|
rules = list(set(rules)) |
|
|
|
INFO('Learned rules from data:', rules) |
|
|
|
|
|
|
|
|
|
|
|
INFO("Learned rules from data:", rules) |
|
|
|
|
|
|
|
for equation_len in range(5, 27): |
|
|
|
true_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, test_X[1][equation_len]) |
|
|
|
false_consist_rule_acc = _get_consist_rule_acc(model, abducer, mapping, rules, test_X[0][equation_len]) |
|
|
|
INFO('consist_rule_acc of testing length %d equations are %f, %f' %(equation_len, true_consist_rule_acc, false_consist_rule_acc)) |
|
|
|
true_consist_rule_acc = _get_consist_rule_acc( |
|
|
|
model, abducer, mapping, rules, test_X[1][equation_len] |
|
|
|
) |
|
|
|
false_consist_rule_acc = _get_consist_rule_acc( |
|
|
|
model, abducer, mapping, rules, test_X[0][equation_len] |
|
|
|
) |
|
|
|
INFO( |
|
|
|
"consist_rule_acc of testing length %d equations are %f, %f" |
|
|
|
% (equation_len, true_consist_rule_acc, false_consist_rule_acc) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
pass |