
wabl_models.py

# coding: utf-8
#================================================================#
#   Copyright (C) 2020 Freecss All rights reserved.
#
#   File Name    : wabl_models.py
#   Author       : freecss
#   Email        : karlfreecss@gmail.com
#   Created Date : 2020/04/02
#   Description  :
#
#================================================================#

from itertools import chain
import pickle as pk
import random

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score
from sklearn.svm import LinearSVC, SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
from sklearn.neighbors import KNeighborsClassifier

def get_part_data(X, i):
    # Extract the i-th position of every example.
    return list(map(lambda x: x[i], X))


def merge_data(X):
    # Flatten a list of per-example sequences into one list, recording each
    # example's length so the grouping can be restored later.
    ret_mark = list(map(lambda x: len(x), X))
    ret_X = list(chain(*X))
    return ret_X, ret_mark


def reshape_data(Y, marks):
    # Inverse of merge_data: regroup a flat list according to the recorded lengths.
    begin_mark = 0
    ret_Y = []
    for mark in marks:
        end_mark = begin_mark + mark
        ret_Y.append(Y[begin_mark:end_mark])
        begin_mark = end_mark
    return ret_Y
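
# Illustrative note (an assumption, not part of the original file): the helpers
# above expect the example-major layout used throughout this module, e.g. for
# code_len == 3:
#     X = [[x00, x01, x02], [x10, x11, x12], ...]      # one sub-list per example
#     merge_data(X)        -> ([x00, x01, x02, x10, ...], [3, 3, ...])
#     get_part_data(X, 1)  -> [x01, x11, ...]
#     reshape_data(Y, marks) regroups a flat prediction list back per example.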


class WABLBasicModel:
    """
    The purpose of label_lists is to assign an index to every symbol: whether a
    model returns probabilities as a dict or as a list, both cases can be handled
    through this indexing. A cleaner mechanism may be introduced later to reduce
    the complexity of this part.
    When models are shared, the elements of label_lists are shared as well.
    """

    def __init__(self):
        pass

    def predict(self, X):
        if self.share:
            # One shared classifier: flatten all positions, predict once,
            # then regroup the results per example.
            data_X, marks = merge_data(X)
            prob = self.cls_list[0].predict_proba(X=data_X)
            cls = np.array(prob).argmax(axis=1)
            prob = reshape_data(prob, marks)
            cls = reshape_data(cls, marks)
        else:
            # One classifier per code position: predict position by position,
            # then zip the results back into per-example tuples.
            cls_result = []
            prob_result = []
            for i in range(self.code_len):
                data_X = get_part_data(X, i)
                tmp_prob = self.cls_list[i].predict_proba(X=data_X)
                cls_result.append(np.array(tmp_prob).argmax(axis=1))
                prob_result.append(tmp_prob)
            cls = list(zip(*cls_result))
            prob = list(zip(*prob_result))
        return {"cls": cls, "prob": prob}

    def valid(self, X, Y):
        if self.share:
            data_X, _ = merge_data(X)
            data_Y, _ = merge_data(Y)
            score = self.cls_list[0].score(X=data_X, y=data_Y)
            return score, [score]
        else:
            score_list = []
            for i in range(self.code_len):
                data_X = get_part_data(X, i)
                data_Y = get_part_data(Y, i)
                score_list.append(self.cls_list[i].score(data_X, data_Y))
            return sum(score_list) / len(score_list), score_list

    def train(self, X, Y):
        #self.label_lists = []
        if self.share:
            data_X, _ = merge_data(X)
            data_Y, _ = merge_data(Y)
            self.cls_list[0].fit(X=data_X, y=data_Y)
        else:
            for i in range(self.code_len):
                data_X = get_part_data(X, i)
                data_Y = get_part_data(Y, i)
                self.cls_list[i].fit(data_X, data_Y)

    def _set_label_lists(self, label_lists):
        # Deduplicate and sort the candidate symbols for each code position.
        label_lists = [sorted(list(set(label_list))) for label_list in label_lists]
        self.label_lists = label_lists


class DecisionTree(WABLBasicModel):
    def __init__(self, code_len, label_lists, share=False):
        self.code_len = code_len
        self._set_label_lists(label_lists)
        self.cls_list = []
        self.share = share
        if share:
            # Essentially a single classifier, reused for every code position.
            self.cls_list.append(DecisionTreeClassifier(random_state=0, min_samples_leaf=3))
            self.cls_list = self.cls_list * self.code_len
        else:
            for _ in range(code_len):
                self.cls_list.append(DecisionTreeClassifier(random_state=0, min_samples_leaf=3))


class KNN(WABLBasicModel):
    def __init__(self, code_len, label_lists, share=False, k=3):
        self.code_len = code_len
        self._set_label_lists(label_lists)
        self.cls_list = []
        self.share = share
        if share:
            # Essentially a single classifier, reused for every code position.
            self.cls_list.append(KNeighborsClassifier(n_neighbors=k))
            self.cls_list = self.cls_list * self.code_len
        else:
            for _ in range(code_len):
                self.cls_list.append(KNeighborsClassifier(n_neighbors=k))


class CNN(WABLBasicModel):
    def __init__(self, base_model, code_len, label_lists, share=True):
        assert share == True, "Not implemented"
        label_lists = [sorted(list(set(label_list))) for label_list in label_lists]
        self.label_lists = label_lists
        self.code_len = code_len
        self.cls_list = []
        self.share = share
        if share:
            self.cls_list.append(base_model)

    def train(self, X, Y, n_epoch=100):
        #self.label_lists = []
        if self.share:
            # All positions share the same classifier, so it suffices to merge
            # the data and train that single model once.
            data_X, _ = merge_data(X)
            data_Y, _ = merge_data(Y)
            self.cls_list[0].fit(X=data_X, y=data_Y, n_epoch=n_epoch)
            #self.label_lists = [sorted(list(set(data_Y)))] * self.code_len
        else:
            for i in range(self.code_len):
                data_X = get_part_data(X, i)
                data_Y = get_part_data(Y, i)
                self.cls_list[i].fit(data_X, data_Y)
                #self.label_lists.append(sorted(list(set(data_Y))))
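

# --- Illustrative sketch (an assumption, not part of the original file) -----
# The CNN wrapper above expects base_model to expose a scikit-learn-like
# interface, extended with an n_epoch keyword on fit. A minimal skeleton of
# that assumed interface would look like:
#
#     class BaseCNNModel:
#         def fit(self, X, y, n_epoch=100):
#             """Train the network on the flattened data for n_epoch epochs."""
#
#         def predict_proba(self, X):
#             """Return per-class probabilities, one row per input."""
#
#         def score(self, X, y):
#             """Return mean accuracy on the given data."""
# -----------------------------------------------------------------------------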


if __name__ == "__main__":
    #data_path = "utils/hamming_data/generated_data/hamming_7_3_0.20.pk"
    data_path = "datasets/generated_data/0_code_7_2_0.00.pk"
    codes, data, labels = pk.load(open(data_path, "rb"))

    # KNN expects one list of candidate symbols per code position; derive it
    # from the per-example label sequences.
    label_lists = list(zip(*labels))
    cls = KNN(7, label_lists, k=3)
    cls.train(data, labels)
    print(cls.valid(data, labels))

    # predict() returns a dict with per-example "prob" and "cls" entries.
    result = cls.predict(data)
    print(result["prob"][0])
    print(result["cls"][0])
    print("Trained")

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.
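
The sketch below illustrates what that integration typically looks like in an abductive learning loop, using the KNN model from wabl_models.py as the machine-learning module. It is illustrative only: the kb_abduce function, the initial pseudo_labels, and the loop itself are hypothetical placeholders, not part of this repository's API.

    # Sketch of an abductive learning (ABL) loop, assuming `data` follows the
    # example-major layout described above and `kb_abduce(pred)` returns the
    # label sequence closest to `pred` that the knowledge base accepts.
    def abl_loop(model, data, pseudo_labels, kb_abduce, n_rounds=5):
        for _ in range(n_rounds):
            model.train(data, pseudo_labels)              # fit the ML module on current pseudo-labels
            pred = model.predict(data)["cls"]             # per-example predicted symbol sequences
            pseudo_labels = [kb_abduce(p) for p in pred]  # logical abduction revises inconsistent ones
        return model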