You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

hed_tmp.py 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import os.path as osp
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. from abl.evaluation import SemanticsMetric, SymbolMetric
  6. from abl.learning import ABLModel, BasicNN
  7. from abl.reasoning import PrologKB, ReasonerBase
  8. from abl.utils import ABLLogger, print_log, reform_list
  9. from examples.hed.datasets.get_hed import get_hed, split_equation
  10. from examples.hed.hed_bridge import HEDBridge
  11. from examples.models.nn import SymbolNet
  12. # Build logger
  13. print_log("Abductive Learning on the HED example.", logger="current")
  14. # Retrieve the directory of the Log file and define the directory for saving the model weights.
  15. log_dir = ABLLogger.get_current_instance().log_dir
  16. weights_dir = osp.join(log_dir, "weights")
  17. ### Logic Part
  18. # Initialize knowledge base and abducer
  19. class HedKB(PrologKB):
  20. def __init__(self, pseudo_label_list, pl_file):
  21. super().__init__(pseudo_label_list, pl_file)
  22. def consist_rule(self, exs, rules):
  23. rules = str(rules).replace("'", "")
  24. return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0
  25. def abduce_rules(self, pred_res):
  26. prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))
  27. if len(prolog_result) == 0:
  28. return None
  29. prolog_rules = prolog_result[0]["X"]
  30. rules = [rule.value for rule in prolog_rules]
  31. return rules
  32. class HedReasoner(ReasonerBase):
  33. def revise_at_idx(self, data_sample):
  34. revision_idx = np.where(np.array(data_sample.flatten("revision_flag")) != 0)[0]
  35. candidate = self.kb.revise_at_idx(
  36. data_sample.pred_pseudo_label, data_sample.Y, revision_idx
  37. )
  38. return candidate
  39. def zoopt_revision_score(self, symbol_num, data_sample, sol):
  40. revision_flag = reform_list(list(sol.get_x().astype(np.int32)), data_sample.pred_pseudo_label)
  41. data_sample.revision_flag = revision_flag
  42. lefted_idxs = [i for i in range(len(data_sample.pred_idx))]
  43. candidate_size = []
  44. while lefted_idxs:
  45. idxs = []
  46. idxs.append(lefted_idxs.pop(0))
  47. max_candidate_idxs = []
  48. found = False
  49. for idx in range(-1, len(data_sample.pred_idx)):
  50. if (not idx in idxs) and (idx >= 0):
  51. idxs.append(idx)
  52. candidate = self.revise_at_idx(data_sample[idxs])
  53. if len(candidate) == 0:
  54. if len(idxs) > 1:
  55. idxs.pop()
  56. else:
  57. if len(idxs) > len(max_candidate_idxs):
  58. found = True
  59. max_candidate_idxs = idxs.copy()
  60. removed = [i for i in lefted_idxs if i in max_candidate_idxs]
  61. if found:
  62. candidate_size.append(len(removed) + 1)
  63. lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]
  64. candidate_size.sort()
  65. score = 0
  66. import math
  67. for i in range(0, len(candidate_size)):
  68. score -= math.exp(-i) * candidate_size[i]
  69. return score
  70. def abduce(self, data_sample):
  71. symbol_num = data_sample.elements_num("pred_pseudo_label")
  72. max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)
  73. solution = self.zoopt_get_solution(symbol_num, data_sample, max_revision_num)
  74. data_sample.revision_flag = reform_list(
  75. solution.astype(np.int32), data_sample.pred_pseudo_label
  76. )
  77. abduced_pseudo_label = []
  78. for single_instance in data_sample:
  79. single_instance.pred_pseudo_label = [single_instance.pred_pseudo_label]
  80. candidates = self.revise_at_idx(single_instance)
  81. if len(candidates) == 0:
  82. abduced_pseudo_label.append([])
  83. else:
  84. abduced_pseudo_label.append(candidates[0][0])
  85. data_sample.abduced_pseudo_label = abduced_pseudo_label
  86. return abduced_pseudo_label
  87. def abduce_rules(self, pred_res):
  88. return self.kb.abduce_rules(pred_res)
  89. import os
  90. CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
  91. kb = HedKB(
  92. pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "./datasets/learn_add.pl")
  93. )
  94. reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=20)
  95. ### Machine Learning Part
  96. # Build necessary components for BasicNN
  97. cls = SymbolNet(num_classes=4)
  98. criterion = nn.CrossEntropyLoss()
  99. optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)
  100. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  101. # Build BasicNN
  102. # The function of BasicNN is to wrap NN models into the form of an sklearn estimator
  103. base_model = BasicNN(
  104. cls,
  105. criterion,
  106. optimizer,
  107. device,
  108. batch_size=32,
  109. num_epochs=1,
  110. save_interval=1,
  111. save_dir=weights_dir,
  112. )
  113. # Build ABLModel
  114. # The main function of the ABL model is to serialize data and
  115. # provide a unified interface for different machine learning models
  116. model = ABLModel(base_model)
  117. ### Metric
  118. # Set up metrics
  119. metric_list = [SymbolMetric(prefix="hed"), SemanticsMetric(prefix="hed")]
  120. ### Bridge Machine Learning and Logic Reasoning
  121. bridge = HEDBridge(model, reasoner, metric_list)
  122. ### Dataset
  123. total_train_data = get_hed(train=True)
  124. train_data, val_data = split_equation(total_train_data, 3, 1)
  125. test_data = get_hed(train=False)
  126. ### Train and Test
  127. bridge.pretrain("examples/hed/weights")
  128. bridge.train(train_data, val_data)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.