You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

abducer_base.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # coding: utf-8
  2. #================================================================#
  3. # Copyright (C) 2021 Freecss All rights reserved.
  4. #
  5. # File Name :abducer_base.py
  6. # Author :freecss
  7. # Email :karlfreecss@gmail.com
  8. # Created Date :2021/06/03
  9. # Description :
  10. #
  11. #================================================================#
  12. import abc
  13. from abducer.kb import ClsKB, RegKB
  14. #from kb import ClsKB, RegKB
  15. import numpy as np
  16. def hamming_dist(A, B):
  17. B = np.array(B)
  18. A = np.expand_dims(A, axis = 0).repeat(axis=0, repeats=(len(B)))
  19. return np.sum(A != B, axis = 1)
  20. def confidence_dist(A, B):
  21. B = np.array(B)
  22. #print(A)
  23. A = np.clip(A, 1e-9, 1)
  24. A = np.expand_dims(A, axis=0)
  25. A = A.repeat(axis=0, repeats=(len(B)))
  26. rows = np.array(range(len(B)))
  27. rows = np.expand_dims(rows, axis = 1).repeat(axis = 1, repeats = len(B[0]))
  28. cols = np.array(range(len(B[0])))
  29. cols = np.expand_dims(cols, axis = 0).repeat(axis = 0, repeats = len(B))
  30. return 1 - np.prod(A[rows, cols, B], axis = 1)
  31. class AbducerBase(abc.ABC):
  32. def __init__(self, kb, dist_func = "hamming", pred_res_parse = None):
  33. self.kb = kb
  34. if dist_func == "hamming":
  35. dist_func = hamming_dist
  36. elif dist_func == "confidence":
  37. dist_func = confidence_dist
  38. self.dist_func = dist_func
  39. if pred_res_parse is None:
  40. pred_res_parse = lambda x : x["cls"]
  41. self.pred_res_parse = pred_res_parse
  42. def abduce(self, data, max_address_num, require_more_address, length = -1):
  43. pred_res, ans = data
  44. if length == -1:
  45. length = len(pred_res)
  46. candidates = self.kb.get_candidates(ans, length)
  47. pred_res = np.array(pred_res)
  48. cost_list = self.dist_func(pred_res, candidates)
  49. address_num = np.min(cost_list)
  50. threshold = min(address_num + require_more_address, max_address_num)
  51. idxs = np.where(cost_list <= address_num+require_more_address)[0]
  52. #return [candidates[idx] for idx in idxs], address_num
  53. if len(idxs) > 1:
  54. return None
  55. return [candidates[idx] for idx in idxs][0]
  56. def batch_abduce(self, Y, C, max_address_num = 3, require_more_address = 0):
  57. return [
  58. self.abduce((y, c), max_address_num, require_more_address)\
  59. for y, c in zip(self.pred_res_parse(Y), C)
  60. ]
  61. def __call__(self, Y, C, max_address_num = 3, require_more_address = 0):
  62. return batch_abduce(Y, C, max_address_num, require_more_address)
  63. if __name__ == "__main__":
  64. #["1+1", "0+1", "1+0", "2+0"]
  65. X = [[1,3,1], [0,3,1], [1,2,0], [3,2,0]]
  66. Y = [2, 1, 1, 2]
  67. kb = RegKB(X, Y)
  68. abd = AbducerBase(kb)
  69. res = abd.abduce(([0,2,0], None), 1, 0)
  70. print(res)
  71. res = abd.abduce(([0, 2, 0], 0.99), 1, 0)
  72. print(res)
  73. A = np.array([[0.5, 0.25, 0.25, 0], [0.3, 0.3, 0.3, 0.1], [0.1, 0.2, 0.3, 0.4]])
  74. B = [[1, 2, 3], [0, 1, 3]]
  75. res = confidence_dist(A, B)
  76. print(res)
  77. A = np.array([[0.5, 0.25, 0.25, 0], [0.3, 1.0, 0.3, 0.1], [0.1, 0.2, 0.3, 1.0]])
  78. B = [[0, 1, 3]]
  79. res = confidence_dist(A, B)
  80. print(res)
  81. kb_str = ['10010001011', '00010001100', '00111101011', '11101000011', '11110011001', '11111010001', '10001010010', '11100100001', '10001001100', '11011010001', '00110000100', '11000000111', '01110111111', '11000101100', '10101011010', '00000110110', '11111110010', '11100101100', '10111001111', '10000101100', '01001011101', '01001110000', '01110001110', '01010010001', '10000100010', '01001011011', '11111111100', '01011101101', '00101110101', '11101001101', '10010110000', '10000000011']
  82. X = [[int(c) for c in s] for s in kb_str]
  83. kb = RegKB(X, len(X) * [None])
  84. abd = AbducerBase(kb)
  85. res = abd.abduce(((1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1), None), 1, 0)
  86. print(res)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.