From b037edaf0b0eb82ad1d6c57e8ef50663474211ab Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Fri, 18 Nov 2022 16:03:17 +0800 Subject: [PATCH] Update framework.py --- framework.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/framework.py b/framework.py index 5479316..a37483b 100644 --- a/framework.py +++ b/framework.py @@ -22,7 +22,7 @@ def block_sample(X, Z, Y, sample_num, epoch_idx): if part_num == 0: part_num = 1 seg_idx = epoch_idx % part_num - INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X_bak)) + INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X)) X = X[sample_num * seg_idx: sample_num * (seg_idx + 1)] Z = Z[sample_num * seg_idx: sample_num * (seg_idx + 1)] Y = Y[sample_num * seg_idx: sample_num * (seg_idx + 1)] @@ -150,10 +150,12 @@ def train(model, abducer, X, Z, Y, epochs = 10, sample_num = -1, verbose = -1): abduced_Z = abduce_func(preds_res, Y) abl_acc = get_abl_acc(Y, preds_res['cls'], abducer.kb.logic_forward) - if(not char_acc_flag): + if(char_acc_flag): ori_char_acc = get_char_acc(Z, preds_res['cls']) abd_char_acc = get_char_acc(abduced_Z, preds_res['cls']) - print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc, ' ori_char_acc:', ori_char_acc, ' abd_char_acc:', abd_char_acc) + print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc, ' ori_char_acc:', ori_char_acc, ' abd_char_acc:', abd_char_acc) + else: + print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc) finetune_X, finetune_Z = filter_data(X, abduced_Z) if len(finetune_X) > 0: