|
|
@@ -22,7 +22,7 @@ def block_sample(X, Z, Y, sample_num, epoch_idx): |
|
|
|
if part_num == 0: |
|
|
|
part_num = 1 |
|
|
|
seg_idx = epoch_idx % part_num |
|
|
|
INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X_bak)) |
|
|
|
INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X)) |
|
|
|
X = X[sample_num * seg_idx: sample_num * (seg_idx + 1)] |
|
|
|
Z = Z[sample_num * seg_idx: sample_num * (seg_idx + 1)] |
|
|
|
Y = Y[sample_num * seg_idx: sample_num * (seg_idx + 1)] |
|
|
@@ -150,10 +150,12 @@ def train(model, abducer, X, Z, Y, epochs = 10, sample_num = -1, verbose = -1): |
|
|
|
abduced_Z = abduce_func(preds_res, Y) |
|
|
|
|
|
|
|
abl_acc = get_abl_acc(Y, preds_res['cls'], abducer.kb.logic_forward) |
|
|
|
if(not char_acc_flag): |
|
|
|
if(char_acc_flag): |
|
|
|
ori_char_acc = get_char_acc(Z, preds_res['cls']) |
|
|
|
abd_char_acc = get_char_acc(abduced_Z, preds_res['cls']) |
|
|
|
print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc, ' ori_char_acc:', ori_char_acc, ' abd_char_acc:', abd_char_acc) |
|
|
|
print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc, ' ori_char_acc:', ori_char_acc, ' abd_char_acc:', abd_char_acc) |
|
|
|
else: |
|
|
|
print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc) |
|
|
|
|
|
|
|
finetune_X, finetune_Z = filter_data(X, abduced_Z) |
|
|
|
if len(finetune_X) > 0: |
|
|
|