You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

reasoning.py 3.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import math
  2. import os
  3. import numpy as np
  4. from abl.reasoning import PrologKB, Reasoner
  5. from abl.utils import reform_list
  6. CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
  7. class HedKB(PrologKB):
  8. def __init__(self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")):
  9. super().__init__(pseudo_label_list, pl_file)
  10. self.learned_rules = {}
  11. def consist_rule(self, exs, rules):
  12. rules = str(rules).replace("'", "")
  13. return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0
  14. def abduce_rules(self, pred_res):
  15. prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))
  16. if len(prolog_result) == 0:
  17. return None
  18. prolog_rules = prolog_result[0]["X"]
  19. rules = [rule.value for rule in prolog_rules]
  20. return rules
  21. class HedReasoner(Reasoner):
  22. def revise_at_idx(self, data_example):
  23. revision_idx = np.where(np.array(data_example.flatten("revision_flag")) != 0)[0]
  24. candidate = self.kb.revise_at_idx(
  25. data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx
  26. )
  27. return candidate
  28. def zoopt_budget(self, symbol_num):
  29. return 200
  30. def zoopt_score(self, symbol_num, data_example, sol, get_score=True):
  31. revision_flag = reform_list(
  32. list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label
  33. )
  34. data_example.revision_flag = revision_flag
  35. lefted_idxs = [i for i in range(len(data_example.pred_idx))]
  36. candidate_size = []
  37. max_consistent_idxs = []
  38. while lefted_idxs:
  39. idxs = []
  40. idxs.append(lefted_idxs.pop(0))
  41. max_candidate_idxs = []
  42. found = False
  43. for idx in range(-1, len(data_example.pred_idx)):
  44. if (not idx in idxs) and (idx >= 0):
  45. idxs.append(idx)
  46. candidates, _ = self.revise_at_idx(data_example[idxs])
  47. if len(candidates) == 0:
  48. if len(idxs) > 1:
  49. idxs.pop()
  50. else:
  51. if len(idxs) > len(max_candidate_idxs):
  52. found = True
  53. max_candidate_idxs = idxs.copy()
  54. removed = [i for i in lefted_idxs if i in max_candidate_idxs]
  55. if found:
  56. removed.insert(0, idxs[0])
  57. candidate_size.append(len(removed))
  58. max_consistent_idxs = max_candidate_idxs.copy()
  59. lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]
  60. candidate_size.sort()
  61. score = 0
  62. for i in range(0, len(candidate_size)):
  63. score -= math.exp(-i) * candidate_size[i]
  64. if get_score:
  65. return score
  66. else:
  67. return max_consistent_idxs
  68. def abduce(self, data_example):
  69. symbol_num = data_example.elements_num("pred_pseudo_label")
  70. max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)
  71. solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)
  72. max_candidate_idxs = self.zoopt_score(symbol_num, data_example, solution, get_score=False)
  73. abduced_pseudo_label = [[] for _ in range(len(data_example))]
  74. if len(max_candidate_idxs) > 0:
  75. candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs])
  76. for i, idx in enumerate(max_candidate_idxs):
  77. abduced_pseudo_label[idx] = candidates[0][i]
  78. data_example.abduced_pseudo_label = abduced_pseudo_label
  79. return abduced_pseudo_label
  80. def abduce_rules(self, pred_res):
  81. return self.kb.abduce_rules(pred_res)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.