import math
import os

import numpy as np

from abl.reasoning import PrologKB, Reasoner
from abl.utils import reform_list

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))


class HedKB(PrologKB):
    """Prolog-backed knowledge base that queries the rules in ``learn_add.pl``."""

    def __init__(self, pseudo_label_list=None, pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")):
        """Initialise the knowledge base.

        Args:
            pseudo_label_list: Pseudo-labels forwarded to ``PrologKB``.
                ``None`` (the default) selects ``[1, 0, "+", "="]``.
            pl_file: Path of the Prolog program forwarded to ``PrologKB``;
                defaults to ``learn_add.pl`` next to this module.
        """
        if pseudo_label_list is None:
            # Build a fresh list per instance: a list literal used as a default
            # argument would be shared (and mutable) across every HedKB.
            pseudo_label_list = [1, 0, "+", "="]
        super().__init__(pseudo_label_list, pl_file)
        self.learned_rules = {}

    def consist_rule(self, exs, rules):
        """Return True if ``eval_inst_feature(exs, rules)`` has a Prolog solution.

        ``rules`` is converted with ``str`` and stripped of single quotes
        before being interpolated into the query.
        """
        rules = str(rules).replace("'", "")
        query = "eval_inst_feature(%s, %s)." % (exs, rules)
        return len(list(self.prolog.query(query))) != 0

    def abduce_rules(self, pred_res):
        """Query ``consistent_inst_feature(pred_res, X)`` and return the rules.

        Returns:
            ``None`` when the query has no solution; otherwise a list holding
            the ``.value`` of every element bound to ``X`` in the first
            solution.
        """
        prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))
        if not prolog_result:
            return None
        return [rule.value for rule in prolog_result[0]["X"]]


class HedReasoner(Reasoner):
    """Reasoner that searches for the largest mutually consistent example subsets."""

    def revise_at_idx(self, data_example):
        """Ask the knowledge base to revise the positions flagged for revision.

        The positions are the indices at which the flattened
        ``revision_flag`` of ``data_example`` is non-zero.

        Returns:
            Whatever ``self.kb.revise_at_idx`` returns; callers in this class
            unpack it as a 2-tuple whose first item is the candidate list.
        """
        flags = np.array(data_example.flatten("revision_flag"))
        revision_idx = np.where(flags != 0)[0]
        return self.kb.revise_at_idx(
            data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx
        )

    def zoopt_budget(self, symbol_num):
        """Return the fixed optimisation budget (200), regardless of ``symbol_num``."""
        return 200

    def _largest_consistent_idxs(self, data_example, seed_idx):
        """Greedily grow a set of example indices, starting from ``seed_idx``.

        Each further index is kept only if ``revise_at_idx`` still yields at
        least one candidate for the enlarged subset.

        Returns:
            The largest index list that produced candidates, or an empty list
            if no evaluated subset (including the seed alone) did.
        """
        idxs = [seed_idx]
        best_idxs = []
        # The -1 step appends nothing: it evaluates the seed on its own before
        # any other index is tried.
        for idx in range(-1, len(data_example.pred_idx)):
            if idx >= 0 and idx not in idxs:
                idxs.append(idx)
            candidates, _ = self.revise_at_idx(data_example[idxs])
            if len(candidates) == 0:
                # Drop the most recent index, but never the seed itself.
                if len(idxs) > 1:
                    idxs.pop()
            elif len(idxs) > len(best_idxs):
                best_idxs = idxs.copy()
        return best_idxs

    def zoopt_score(self, symbol_num, data_example, sol, get_score=True):
        """Score a candidate solution, or return the indices it makes consistent.

        The revision flags encoded by ``sol`` are stored on
        ``data_example.revision_flag`` (a side effect). The examples are then
        partitioned greedily into consistent groups.

        Args:
            symbol_num: Unused here; kept for interface compatibility.
            data_example: Examples being abduced; its ``revision_flag`` is
                overwritten.
            sol: Solution object exposing ``get_x()``.
            get_score: When True return the score, otherwise the index list.

        Returns:
            If ``get_score`` is True, ``-sum(exp(-i) * size_i)`` over the group
            sizes sorted ascending (``0`` when no group was found). Otherwise
            the index list of the last consistent group found (possibly empty).
        """
        data_example.revision_flag = reform_list(
            list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label
        )

        remaining_idxs = list(range(len(data_example.pred_idx)))
        candidate_size = []
        max_consistent_idxs = []
        while remaining_idxs:
            seed_idx = remaining_idxs.pop(0)
            max_candidate_idxs = self._largest_consistent_idxs(data_example, seed_idx)
            if max_candidate_idxs:
                chosen = set(max_candidate_idxs)
                # Group size: the seed plus every still-pending index it absorbed.
                absorbed = sum(1 for i in remaining_idxs if i in chosen)
                candidate_size.append(absorbed + 1)
                max_consistent_idxs = max_candidate_idxs
                remaining_idxs = [i for i in remaining_idxs if i not in chosen]

        if not get_score:
            return max_consistent_idxs

        candidate_size.sort()
        score = 0
        for rank, size in enumerate(candidate_size):
            score -= math.exp(-rank) * size
        return score

    def abduce(self, data_example):
        """Abduce pseudo-labels for ``data_example`` and store them on it.

        Returns:
            A list with one entry per example: the abduced label sequence for
            examples in the consistent group, and ``[]`` for the others. The
            same list is assigned to ``data_example.abduced_pseudo_label``.
        """
        symbol_num = data_example.elements_num("pred_pseudo_label")
        max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)

        solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)
        max_candidate_idxs = self.zoopt_score(symbol_num, data_example, solution, get_score=False)

        abduced_pseudo_label = [[] for _ in range(len(data_example))]
        if max_candidate_idxs:
            candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs])
            # Only the first candidate is used.
            for pos, idx in enumerate(max_candidate_idxs):
                abduced_pseudo_label[idx] = candidates[0][pos]

        data_example.abduced_pseudo_label = abduced_pseudo_label
        return abduced_pseudo_label

    def abduce_rules(self, pred_res):
        """Delegate rule abduction to the knowledge base."""
        return self.kb.abduce_rules(pred_res)