hed_bridge.py
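This file defines HEDBridge, the training bridge for the Handwritten Equation Decipherment (HED) example. It pretrains the symbol perception network with an autoencoder, searches over candidate mappings from network output indices to the symbols '+', '=', 0, and 1, abduces pseudo-labels through the knowledge base, and trains on a curriculum of progressively longer equations, abducing and validating rules to decide when to advance to the next stage.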

import os
from collections import defaultdict

import torch

from abl.bridge import SimpleBridge
from abl.dataset import RegressionDataset
from abl.evaluation import BaseMetric
from abl.learning import ABLModel, BasicNN
from abl.reasoning import Reasoner
from abl.structures import ListData
from abl.utils import print_log
from examples.hed.datasets.get_hed import get_pretrain_data
from examples.hed.utils import InfiniteSampler, gen_mappings
from examples.models.nn import SymbolNetAutoencoder


class HEDBridge(SimpleBridge):
    def __init__(
        self,
        model: ABLModel,
        reasoner: Reasoner,
        metric_list: BaseMetric,
    ) -> None:
        super().__init__(model, reasoner, metric_list)

    def pretrain(self, weights_dir):
        """Pretrain the perception network as an autoencoder if no cached
        weights exist, then load the pretrained weights into the model."""
        if not os.path.exists(os.path.join(weights_dir, "pretrain_weights.pth")):
            print_log("Pretrain Start", logger="current")

            cls_autoencoder = SymbolNetAutoencoder(
                num_classes=len(self.reasoner.kb.pseudo_label_list)
            )
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            loss_fn = torch.nn.MSELoss()
            optimizer = torch.optim.RMSprop(
                cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6
            )

            pretrain_model = BasicNN(
                cls_autoencoder,
                loss_fn,
                optimizer,
                device,
                save_interval=1,
                save_dir=weights_dir,
                num_epochs=10,
            )

            pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"])
            pretrain_data = RegressionDataset(pretrain_data_X, pretrain_data_Y)
            pretrain_data_loader = torch.utils.data.DataLoader(
                pretrain_data, batch_size=64, shuffle=True
            )

            pretrain_model.fit(pretrain_data_loader)
            save_param_dict = {
                "model": cls_autoencoder.base_model.state_dict(),
            }
            torch.save(save_param_dict, os.path.join(weights_dir, "pretrain_weights.pth"))

        self.model.load(load_path=os.path.join(weights_dir, "pretrain_weights.pth"))

    def select_mapping_and_abduce(self, data_samples: ListData):
        """Try every candidate mapping from network output indices to symbols
        and keep the one under which the most instances can be revised
        (i.e., abduced to a non-empty pseudo-label)."""
        candidate_mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1])
        mapping_score = []
        abduced_pseudo_label_list = []
        for _mapping in candidate_mappings:
            self.reasoner.mapping = _mapping
            self.reasoner.remapping = dict(zip(_mapping.values(), _mapping.keys()))
            self.idx_to_pseudo_label(data_samples)
            abduced_pseudo_label = self.reasoner.abduce(data_samples)
            mapping_score.append(len(abduced_pseudo_label) - abduced_pseudo_label.count([]))
            abduced_pseudo_label_list.append(abduced_pseudo_label)

        max_revisible_instances = max(mapping_score)
        return_idx = mapping_score.index(max_revisible_instances)
        self.reasoner.mapping = candidate_mappings[return_idx]
        self.reasoner.remapping = dict(
            zip(self.reasoner.mapping.values(), self.reasoner.mapping.keys())
        )
        self.idx_to_pseudo_label(data_samples)
        data_samples.abduced_pseudo_label = abduced_pseudo_label_list[return_idx]

        return data_samples.abduced_pseudo_label

    def abduce_pseudo_label(self, data_samples: ListData):
        self.reasoner.abduce(data_samples)
        return data_samples.abduced_pseudo_label

    def check_training_impact(self, filtered_data_samples, data_samples):
        """Return True when both character accuracy and the ratio of
        revisible instances reach 0.9."""
        character_accuracy = self.model.valid(filtered_data_samples)
        revisible_ratio = len(filtered_data_samples.X) / len(data_samples.X)
        log_string = (
            f"Revisible ratio is {revisible_ratio:.3f}, Character "
            f"accuracy is {character_accuracy:.3f}"
        )
        print_log(log_string, logger="current")

        if character_accuracy >= 0.9 and revisible_ratio >= 0.9:
            return True
        return False

    def check_rule_quality(self, rule, val_data, equation_len):
        """Accept the learned rules when over 95% of positive validation
        equations are consistent with them and under 10% of negative ones are."""
        val_X_true = self.data_preprocess(val_data[1], equation_len)
        val_X_false = self.data_preprocess(val_data[0], equation_len)

        true_ratio = self.calc_consistent_ratio(val_X_true, rule)
        false_ratio = self.calc_consistent_ratio(val_X_false, rule)

        log_string = (
            f"True consistent ratio is {true_ratio:.3f}, False inconsistent ratio "
            f"is {1 - false_ratio:.3f}"
        )
        print_log(log_string, logger="current")

        if true_ratio > 0.95 and false_ratio < 0.1:
            return True
        return False

    def calc_consistent_ratio(self, data_samples, rule):
        self.predict(data_samples)
        pred_pseudo_label = self.idx_to_pseudo_label(data_samples)
        consistent_num = sum(
            [self.reasoner.kb.consist_rule(instance, rule) for instance in pred_pseudo_label]
        )
        return consistent_num / len(data_samples.X)

    def get_rules_from_data(self, data_samples, samples_per_rule, samples_num):
        """Repeatedly sample small batches, abduce a rule set from instances
        whose predictions satisfy the KB, keep rules seen at least 5 times,
        and deduplicate them with select_rules."""
        rules = []
        sampler = InfiniteSampler(len(data_samples), batch_size=samples_per_rule)

        for _ in range(samples_num):
            for select_idx in sampler:
                sub_data_samples = data_samples[select_idx]
                self.predict(sub_data_samples)
                pred_pseudo_label = self.idx_to_pseudo_label(sub_data_samples)
                consistent_instance = []
                for instance in pred_pseudo_label:
                    if self.reasoner.kb.logic_forward([instance]):
                        consistent_instance.append(instance)

                if len(consistent_instance) != 0:
                    rule = self.reasoner.abduce_rules(consistent_instance)
                    if rule is not None:
                        rules.append(rule)
                        break

        all_rule_dict = defaultdict(int)
        for rule in rules:
            for r in rule:
                all_rule_dict[r] += 1
        rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5}
        rules = self.select_rules(rule_dict)

        return rules

    @staticmethod
    def filter_empty(data_samples: ListData):
        consistent_idx = [
            i
            for i in range(len(data_samples.abduced_pseudo_label))
            if len(data_samples.abduced_pseudo_label[i]) > 0
        ]
        return data_samples[consistent_idx]

    @staticmethod
    def select_rules(rule_dict):
        """For each pair of operands, keep only the most frequent rule."""
        add_nums_dict = {}
        for r in list(rule_dict):
            add_nums = str(r.split("]")[0].split("[")[1]) + str(
                r.split("]")[1].split("[")[1]
            )  # e.g. r = 'my_op([1], [0], [1, 0])' gives add_nums = '10'
            if add_nums in add_nums_dict:
                old_r = add_nums_dict[add_nums]
                if rule_dict[r] >= rule_dict[old_r]:
                    rule_dict.pop(old_r)
                    add_nums_dict[add_nums] = r
                else:
                    rule_dict.pop(r)
            else:
                add_nums_dict[add_nums] = r
        return list(rule_dict)

    def data_preprocess(self, data, equation_len) -> ListData:
        data_samples = ListData()
        data_samples.X = data[equation_len] + data[equation_len + 1]
        data_samples.gt_pseudo_label = None
        data_samples.Y = [None] * len(data_samples.X)
        return data_samples

    def train(self, train_data, val_data, segment_size=10, min_len=5, max_len=8):
        """Curriculum training over equation lengths. After five consecutive
        segments pass check_training_impact, try to learn and validate rules;
        on success, save the weights and advance to the next length,
        otherwise reload the previous weights and retrain."""
        for equation_len in range(min_len, max_len):
            print_log(
                f"============== equation_len: {equation_len}-{equation_len + 1} ================",
                logger="current",
            )

            condition_num = 0
            data_samples = self.data_preprocess(train_data[1], equation_len)
            sampler = InfiniteSampler(len(data_samples), batch_size=segment_size)
            for seg_idx, select_idx in enumerate(sampler):
                print_log(
                    f"Equation Len(train) [{equation_len}] Segment Index [{seg_idx + 1}]",
                    logger="current",
                )
                sub_data_samples = data_samples[select_idx]
                self.predict(sub_data_samples)
                if equation_len == min_len:
                    self.select_mapping_and_abduce(sub_data_samples)
                else:
                    self.idx_to_pseudo_label(sub_data_samples)
                    self.abduce_pseudo_label(sub_data_samples)
                filtered_sub_data_samples = self.filter_empty(sub_data_samples)
                self.pseudo_label_to_idx(filtered_sub_data_samples)
                self.model.train(filtered_sub_data_samples)

                if self.check_training_impact(filtered_sub_data_samples, sub_data_samples):
                    condition_num += 1
                else:
                    condition_num = 0

                if condition_num >= 5:
                    print_log("Now checking if we can go to next course", logger="current")
                    rules = self.get_rules_from_data(
                        data_samples, samples_per_rule=3, samples_num=50
                    )
                    print_log("Learned rules from data: " + str(rules), logger="current")

                    seems_good = self.check_rule_quality(rules, val_data, equation_len)
                    if seems_good:
                        self.model.save(save_path=f"./weights/eq_len_{equation_len}.pth")
                        break
                    else:
                        if equation_len == min_len:
                            print_log(
                                "Learned mapping is: " + str(self.reasoner.mapping),
                                logger="current",
                            )
                            self.model.load(load_path="./weights/pretrain_weights.pth")
                        else:
                            self.model.load(
                                load_path=f"./weights/eq_len_{equation_len - 1}.pth"
                            )
                        condition_num = 0
                        print_log("Reload Model and retrain", logger="current")

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.
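As a rough illustration, here is a minimal sketch of how HEDBridge might be wired up. Only HEDBridge's own methods (pretrain, train) and the BasicNN arguments are taken from the file above; the knowledge base, network, metric list, and data-loading helpers are hypothetical placeholders, not this repository's actual API.

import torch

from abl.learning import ABLModel, BasicNN
from abl.reasoning import Reasoner
from examples.hed.hed_bridge import HEDBridge

# Hypothetical helpers -- a real run would build these from the hed
# example's knowledge base, network, and dataset utilities.
net = build_symbol_net()
kb = build_hed_kb()
train_data, val_data = load_hed_data()

base_model = BasicNN(
    net,
    torch.nn.CrossEntropyLoss(),
    torch.optim.RMSprop(net.parameters(), lr=0.001),
    torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)
model = ABLModel(base_model)          # assumed constructor
reasoner = Reasoner(kb)               # assumed constructor
bridge = HEDBridge(model, reasoner, metric_list=[])

bridge.pretrain("./weights")          # creates or loads pretrain_weights.pth
bridge.train(train_data, val_data, segment_size=10, min_len=5, max_len=8)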