Browse Source

Update framework.py

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
b037edaf0b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 3 deletions
  1. +5
    -3
      framework.py

+ 5
- 3
framework.py View File

@@ -22,7 +22,7 @@ def block_sample(X, Z, Y, sample_num, epoch_idx):
if part_num == 0:
part_num = 1
seg_idx = epoch_idx % part_num
INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X_bak))
INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X))
X = X[sample_num * seg_idx: sample_num * (seg_idx + 1)]
Z = Z[sample_num * seg_idx: sample_num * (seg_idx + 1)]
Y = Y[sample_num * seg_idx: sample_num * (seg_idx + 1)]
@@ -150,10 +150,12 @@ def train(model, abducer, X, Z, Y, epochs = 10, sample_num = -1, verbose = -1):
abduced_Z = abduce_func(preds_res, Y)

abl_acc = get_abl_acc(Y, preds_res['cls'], abducer.kb.logic_forward)
if(not char_acc_flag):
if(char_acc_flag):
ori_char_acc = get_char_acc(Z, preds_res['cls'])
abd_char_acc = get_char_acc(abduced_Z, preds_res['cls'])
print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc, ' ori_char_acc:', ori_char_acc, ' abd_char_acc:', abd_char_acc)
print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc, ' ori_char_acc:', ori_char_acc, ' abd_char_acc:', abd_char_acc)
else:
print('epoch_idx:', epoch_idx, ' abl_acc:', abl_acc)
finetune_X, finetune_Z = filter_data(X, abduced_Z)
if len(finetune_X) > 0:


Loading…
Cancel
Save