You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

framework_hed.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. # coding: utf-8
  2. # ================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :framework.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/07
  9. # Description :
  10. #
  11. # ================================================================#
  12. import torch
  13. import torch.nn as nn
  14. import numpy as np
  15. import os
  16. from abl.utils.plog import INFO
  17. from abl.utils.utils import flatten, reform_idx
  18. from abl.learning.basic_nn import BasicNN, BasicDataset
  19. from utils import gen_mappings, mapping_res, remapping_res
  20. from models.nn import SymbolNetAutoencoder
  21. from torch.utils.data import RandomSampler
  22. from datasets.get_hed import get_pretrain_data
  23. def hed_pretrain(kb, cls, recorder):
  24. cls_autoencoder = SymbolNetAutoencoder(num_classes=len(kb.pseudo_label_list))
  25. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  26. if not os.path.exists("./weights/pretrain_weights.pth"):
  27. INFO("Pretrain Start")
  28. pretrain_data_X, pretrain_data_Y = get_pretrain_data(["0", "1", "10", "11"])
  29. pretrain_data = BasicDataset(pretrain_data_X, pretrain_data_Y)
  30. pretrain_data_loader = torch.utils.data.DataLoader(
  31. pretrain_data, batch_size=64, shuffle=True
  32. )
  33. criterion = nn.MSELoss()
  34. optimizer = torch.optim.RMSprop(
  35. cls_autoencoder.parameters(), lr=0.001, alpha=0.9, weight_decay=1e-6
  36. )
  37. pretrain_model = BasicNN(
  38. cls_autoencoder,
  39. criterion,
  40. optimizer,
  41. device,
  42. save_interval=1,
  43. save_dir=recorder.save_dir,
  44. num_epochs=10,
  45. recorder=recorder,
  46. )
  47. pretrain_model.fit(pretrain_data_loader)
  48. torch.save(
  49. cls_autoencoder.base_model.state_dict(), "./weights/pretrain_weights.pth"
  50. )
  51. cls.load_state_dict(cls_autoencoder.base_model.state_dict())
  52. else:
  53. cls.load_state_dict(torch.load("./weights/pretrain_weights.pth"))
  54. def _get_char_acc(model, X, consistent_pred_res, mapping):
  55. original_pred_res = model.predict(X)["label"]
  56. pred_res = flatten(mapping_res(original_pred_res, mapping))
  57. INFO("Current model's output: ", pred_res)
  58. INFO("Abduced labels: ", flatten(consistent_pred_res))
  59. assert len(pred_res) == len(flatten(consistent_pred_res))
  60. return sum(
  61. [
  62. pred_res[idx] == flatten(consistent_pred_res)[idx]
  63. for idx in range(len(pred_res))
  64. ]
  65. ) / len(pred_res)
  66. def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
  67. select_idx = RandomSampler(train_X_true, num_samples=select_num,replacement=False)
  68. X = [train_X_true[idx] for idx in select_idx]
  69. # original_pred_res = model.predict(X)['label']
  70. pred_label = model.predict(X)["label"]
  71. if mapping == None:
  72. mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1])
  73. else:
  74. mappings = [mapping]
  75. consistent_idx = []
  76. consistent_pred_res = []
  77. for m in mappings:
  78. pred_pseudo_label = mapping_res(pred_label, m)
  79. max_revision_num = 20
  80. solution = abducer.zoopt_get_solution(
  81. pred_label,
  82. pred_pseudo_label,
  83. [None] * len(pred_label),
  84. [None] * len(pred_label),
  85. max_revision_num,
  86. )
  87. all_address_flag = reform_idx(solution, pred_label)
  88. consistent_idx_tmp = []
  89. consistent_pred_res_tmp = []
  90. for idx in range(len(pred_label)):
  91. address_idx = [
  92. i for i, flag in enumerate(all_address_flag[idx]) if flag != 0
  93. ]
  94. candidate = abducer.revise_by_idx([pred_pseudo_label[idx]], None, address_idx)
  95. if len(candidate) > 0:
  96. consistent_idx_tmp.append(idx)
  97. consistent_pred_res_tmp.append(candidate[0][0])
  98. if len(consistent_idx_tmp) > len(consistent_idx):
  99. consistent_idx = consistent_idx_tmp
  100. consistent_pred_res = consistent_pred_res_tmp
  101. if len(mappings) > 1:
  102. mapping = m
  103. if len(consistent_idx) == 0:
  104. return 0, 0, None
  105. INFO("Train pool size is:", len(flatten(consistent_pred_res)))
  106. INFO("Start to use abduced pseudo label to train model...")
  107. model.train(
  108. [X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping)
  109. )
  110. consistent_acc = len(consistent_idx) / select_num
  111. char_acc = _get_char_acc(
  112. model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping
  113. )
  114. INFO("consistent_acc is %s, char_acc is %s" % (consistent_acc, char_acc))
  115. return consistent_acc, char_acc, mapping
  116. # def abduce_and_train(model, abducer, mapping, train_X_true, select_num):
  117. # select_idx = np.random.randint(len(train_X_true), size=select_num)
  118. # X = []
  119. # for idx in select_idx:
  120. # X.append(train_X_true[idx])
  121. # original_pred_res = model.predict(X)['label']
  122. # if mapping == None:
  123. # mappings = gen_mappings([0, 1, 2, 3],['+', '=', 0, 1])
  124. # else:
  125. # mappings = [mapping]
  126. # consistent_idx = []
  127. # consistent_pred_res = []
  128. # for m in mappings:
  129. # pred_res = mapping_res(original_pred_res, m)
  130. # max_abduce_num = 20
  131. # solution = abducer.zoopt_get_solution(pred_res, [None] * len(pred_res), [None] * len(pred_res), max_abduce_num)
  132. # all_address_flag = reform_idx(solution, pred_res)
  133. # consistent_idx_tmp = []
  134. # consistent_pred_res_tmp = []
  135. # for idx in range(len(pred_res)):
  136. # address_idx = [i for i, flag in enumerate(all_address_flag[idx]) if flag != 0]
  137. # candidate = abducer.revise_by_idx([pred_res[idx]], None, address_idx)
  138. # if len(candidate) > 0:
  139. # consistent_idx_tmp.append(idx)
  140. # consistent_pred_res_tmp.append(candidate[0][0])
  141. # if len(consistent_idx_tmp) > len(consistent_idx):
  142. # consistent_idx = consistent_idx_tmp
  143. # consistent_pred_res = consistent_pred_res_tmp
  144. # if len(mappings) > 1:
  145. # mapping = m
  146. # if len(consistent_idx) == 0:
  147. # return 0, 0, None
  148. # INFO('Train pool size is:', len(flatten(consistent_pred_res)))
  149. # INFO("Start to use abduced pseudo label to train model...")
  150. # model.train([X[idx] for idx in consistent_idx], remapping_res(consistent_pred_res, mapping))
  151. # consistent_acc = len(consistent_idx) / select_num
  152. # char_acc = _get_char_acc(model, [X[idx] for idx in consistent_idx], consistent_pred_res, mapping)
  153. # INFO('consistent_acc is %s, char_acc is %s' % (consistent_acc, char_acc))
  154. # return consistent_acc, char_acc, mapping
  155. def _remove_duplicate_rule(rule_dict):
  156. add_nums_dict = {}
  157. for r in list(rule_dict):
  158. add_nums = str(r.split("]")[0].split("[")[1]) + str(
  159. r.split("]")[1].split("[")[1]
  160. ) # r = 'my_op([1], [0], [1, 0])' then add_nums = '10'
  161. if add_nums in add_nums_dict:
  162. old_r = add_nums_dict[add_nums]
  163. if rule_dict[r] >= rule_dict[old_r]:
  164. rule_dict.pop(old_r)
  165. add_nums_dict[add_nums] = r
  166. else:
  167. rule_dict.pop(r)
  168. else:
  169. add_nums_dict[add_nums] = r
  170. return list(rule_dict)
  171. def get_rules_from_data(
  172. model, abducer, mapping, train_X_true, samples_per_rule, samples_num
  173. ):
  174. rules = []
  175. for _ in range(samples_num):
  176. while True:
  177. select_idx = np.random.randint(len(train_X_true), size=samples_per_rule)
  178. X = []
  179. for idx in select_idx:
  180. X.append(train_X_true[idx])
  181. original_pred_res = model.predict(X)["label"]
  182. pred_res = mapping_res(original_pred_res, mapping)
  183. consistent_idx = []
  184. consistent_pred_res = []
  185. for idx in range(len(pred_res)):
  186. if abducer.kb.logic_forward([pred_res[idx]]):
  187. consistent_idx.append(idx)
  188. consistent_pred_res.append(pred_res[idx])
  189. if len(consistent_pred_res) != 0:
  190. rule = abducer.abduce_rules(consistent_pred_res)
  191. if rule != None:
  192. break
  193. rules.append(rule)
  194. all_rule_dict = {}
  195. for rule in rules:
  196. for r in rule:
  197. all_rule_dict[r] = 1 if r not in all_rule_dict else all_rule_dict[r] + 1
  198. rule_dict = {rule: cnt for rule, cnt in all_rule_dict.items() if cnt >= 5}
  199. rules = _remove_duplicate_rule(rule_dict)
  200. return rules
  201. def _get_consist_rule_acc(model, abducer, mapping, rules, X):
  202. cnt = 0
  203. for x in X:
  204. original_pred_res = model.predict([x])["label"]
  205. pred_res = flatten(mapping_res(original_pred_res, mapping))
  206. if abducer.kb.consist_rule(pred_res, rules):
  207. cnt += 1
  208. return cnt / len(X)
  209. def train_with_rule(
  210. model, abducer, train_data, val_data, select_num=10, min_len=5, max_len=8
  211. ):
  212. train_X = train_data
  213. val_X = val_data
  214. samples_num = 50
  215. samples_per_rule = 3
  216. # Start training / for each length of equations
  217. for equation_len in range(min_len, max_len):
  218. INFO(
  219. "============== equation_len: %d-%d ================"
  220. % (equation_len, equation_len + 1)
  221. )
  222. train_X_true = train_X[1][equation_len]
  223. train_X_false = train_X[0][equation_len]
  224. val_X_true = val_X[1][equation_len]
  225. val_X_false = val_X[0][equation_len]
  226. train_X_true.extend(train_X[1][equation_len + 1])
  227. train_X_false.extend(train_X[0][equation_len + 1])
  228. val_X_true.extend(val_X[1][equation_len + 1])
  229. val_X_false.extend(val_X[0][equation_len + 1])
  230. condition_cnt = 0
  231. while True:
  232. if equation_len == min_len:
  233. mapping = None
  234. # Abduce and train NN
  235. consistent_acc, char_acc, mapping = abduce_and_train(
  236. model, abducer, mapping, train_X_true, select_num
  237. )
  238. if consistent_acc == 0:
  239. continue
  240. # Test if we can use mlp to evaluate
  241. if consistent_acc >= 0.9 and char_acc >= 0.9:
  242. condition_cnt += 1
  243. else:
  244. condition_cnt = 0
  245. # The condition has been satisfied continuously five times
  246. if condition_cnt >= 5:
  247. INFO("Now checking if we can go to next course")
  248. rules = get_rules_from_data(
  249. model, abducer, mapping, train_X_true, samples_per_rule, samples_num
  250. )
  251. INFO("Learned rules from data:", rules)
  252. true_consist_rule_acc = _get_consist_rule_acc(
  253. model, abducer, mapping, rules, val_X_true
  254. )
  255. false_consist_rule_acc = _get_consist_rule_acc(
  256. model, abducer, mapping, rules, val_X_false
  257. )
  258. INFO(
  259. "consist_rule_acc is %f, %f\n"
  260. % (true_consist_rule_acc, false_consist_rule_acc)
  261. )
  262. # decide next course or restart
  263. if true_consist_rule_acc > 0.95 and false_consist_rule_acc < 0.1:
  264. torch.save(
  265. model.classifier_list[0].model.state_dict(),
  266. "./weights/weights_%d.pth" % equation_len,
  267. )
  268. break
  269. else:
  270. if equation_len == min_len:
  271. INFO("Final mapping is: ", mapping)
  272. model.classifier_list[0].model.load_state_dict(
  273. torch.load("./weights/pretrain_weights.pth")
  274. )
  275. else:
  276. model.classifier_list[0].model.load_state_dict(
  277. torch.load("./weights/weights_%d.pth" % (equation_len - 1))
  278. )
  279. condition_cnt = 0
  280. INFO("Reload Model and retrain")
  281. return model, mapping
  282. def hed_test(model, abducer, mapping, train_data, test_data, min_len=5, max_len=8):
  283. train_X = train_data
  284. test_X = test_data
  285. # Calcualte how many equations should be selected in each length
  286. # for each length, there are equation_samples_num[equation_len] rules
  287. print("Now begin to train final mlp model")
  288. equation_samples_num = []
  289. len_cnt = max_len - min_len + 1
  290. samples_num = 50
  291. equation_samples_num += [0] * min_len
  292. if samples_num % len_cnt == 0:
  293. equation_samples_num += [samples_num // len_cnt] * len_cnt
  294. else:
  295. equation_samples_num += [samples_num // len_cnt] * len_cnt
  296. equation_samples_num[-1] += samples_num % len_cnt
  297. assert sum(equation_samples_num) == samples_num
  298. # Abduce rules
  299. rules = []
  300. samples_per_rule = 3
  301. for equation_len in range(min_len, max_len + 1):
  302. equation_rules = get_rules_from_data(
  303. model,
  304. abducer,
  305. mapping,
  306. train_X[1][equation_len],
  307. samples_per_rule,
  308. equation_samples_num[equation_len],
  309. )
  310. rules.extend(equation_rules)
  311. rules = list(set(rules))
  312. INFO("Learned rules from data:", rules)
  313. for equation_len in range(5, 27):
  314. true_consist_rule_acc = _get_consist_rule_acc(
  315. model, abducer, mapping, rules, test_X[1][equation_len]
  316. )
  317. false_consist_rule_acc = _get_consist_rule_acc(
  318. model, abducer, mapping, rules, test_X[0][equation_len]
  319. )
  320. INFO(
  321. "consist_rule_acc of testing length %d equations are %f, %f"
  322. % (equation_len, true_consist_rule_acc, false_consist_rule_acc)
  323. )
  324. if __name__ == "__main__":
  325. pass

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.