You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

zoo.py 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import numpy as np
  2. from sklearn.ensemble import RandomForestClassifier
  3. from sklearn.metrics import accuracy_score
  4. from z3 import Solver, Int, If, Not, Implies, Sum, sat
  5. import openml
  6. from abl.learning import ABLModel
  7. from abl.reasoning import KBBase, Reasoner
  8. from abl.evaluation import ReasoningMetric, SymbolMetric
  9. from abl.bridge import SimpleBridge
  10. from abl.utils.utils import confidence_dist
  11. class ZooKB(KBBase):
  12. def __init__(self):
  13. super().__init__(pseudo_label_list=list(range(7)), use_cache=False)
  14. # Use z3 solver
  15. self.solver = Solver()
  16. # Load information of Zoo dataset
  17. dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False)
  18. X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
  19. self.attribute_names = attribute_names
  20. self.target_names = y.cat.categories.tolist()
  21. print("Attribute names are: ", self.attribute_names)
  22. print("Target names are: ", self.target_names)
  23. # self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"]
  24. # self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"]
  25. # Define variables
  26. for name in self.attribute_names+self.target_names:
  27. exec(f"globals()['{name}'] = Int('{name}')") ## or use dict to create var and modify rules
  28. # Define rules
  29. rules = [
  30. Implies(milk == 1, mammal == 1),
  31. Implies(mammal == 1, milk == 1),
  32. Implies(mammal == 1, backbone == 1),
  33. Implies(mammal == 1, breathes == 1),
  34. Implies(feathers == 1, bird == 1),
  35. Implies(bird == 1, feathers == 1),
  36. Implies(bird == 1, eggs == 1),
  37. Implies(bird == 1, backbone == 1),
  38. Implies(bird == 1, breathes == 1),
  39. Implies(bird == 1, legs == 2),
  40. Implies(bird == 1, tail == 1),
  41. Implies(reptile == 1, backbone == 1),
  42. Implies(reptile == 1, breathes == 1),
  43. Implies(reptile == 1, tail == 1),
  44. Implies(fish == 1, aquatic == 1),
  45. Implies(fish == 1, toothed == 1),
  46. Implies(fish == 1, backbone == 1),
  47. Implies(fish == 1, Not(breathes == 1)),
  48. Implies(fish == 1, fins == 1),
  49. Implies(fish == 1, legs == 0),
  50. Implies(fish == 1, tail == 1),
  51. Implies(amphibian == 1, eggs == 1),
  52. Implies(amphibian == 1, aquatic == 1),
  53. Implies(amphibian == 1, backbone == 1),
  54. Implies(amphibian == 1, breathes == 1),
  55. Implies(amphibian == 1, legs == 4),
  56. Implies(insect == 1, eggs == 1),
  57. Implies(insect == 1, Not(backbone == 1)),
  58. Implies(insect == 1, legs == 6),
  59. Implies(invertebrate == 1, Not(backbone == 1))
  60. ]
  61. # Define weights and sum of violated weights
  62. self.weights = {rule: 1 for rule in rules}
  63. self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights])
  64. def logic_forward(self, pseudo_label, data_point):
  65. attribute_names, target_names = self.attribute_names, self.target_names
  66. solver = self.solver
  67. total_violation_weight = self.total_violation_weight
  68. pseudo_label, data_point = pseudo_label[0], data_point[0]
  69. self.solver.reset()
  70. for name, value in zip(attribute_names, data_point):
  71. solver.add(eval(f"{name} == {value}"))
  72. for cate, name in zip(self.pseudo_label_list,target_names):
  73. value = 1 if (cate == pseudo_label) else 0
  74. solver.add(eval(f"{name} == {value}"))
  75. if solver.check() == sat:
  76. model = solver.model()
  77. total_weight = model.evaluate(total_violation_weight)
  78. # violated_rules = [str(rule) for rule in self.weights if model.evaluate(Not(rule))]
  79. # print("Total violation weight for the given data point:", total_weight)
  80. # print("Violated rules:", violated_rules)
  81. return total_weight.as_long()
  82. else:
  83. # No solution found
  84. return 1e10
  85. def consitency(data_sample, candidates, candidate_idxs, reasoning_results):
  86. pred_prob = data_sample.pred_prob
  87. model_scores = confidence_dist(pred_prob, candidate_idxs)
  88. rule_scores = np.array(reasoning_results)
  89. scores = model_scores + rule_scores
  90. return scores
  91. # Function to load and preprocess the dataset
  92. def load_and_preprocess_dataset(dataset_id):
  93. dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False)
  94. X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
  95. # Convert data types
  96. for col in X.select_dtypes(include='bool').columns:
  97. X[col] = X[col].astype(int)
  98. y = y.cat.codes.astype(int)
  99. X, y = X.to_numpy(), y.to_numpy()
  100. return X, y
  101. # Function to split data (one shot)
  102. def split_dataset(X, y, test_size = 0.3):
  103. # For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1)
  104. label_indices, unlabel_indices, test_indices = [], [], []
  105. for class_label in np.unique(y):
  106. idxs = np.where(y == class_label)[0]
  107. np.random.shuffle(idxs)
  108. n_train_unlabel = int((1-test_size)*(len(idxs)-1))
  109. label_indices.append(idxs[0])
  110. unlabel_indices.extend(idxs[1:1+n_train_unlabel])
  111. test_indices.extend(idxs[1+n_train_unlabel:])
  112. X_label, y_label = X[label_indices], y[label_indices]
  113. X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices]
  114. X_test, y_test = X[test_indices], y[test_indices]
  115. return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test
  116. if __name__ == "__main__":
  117. '''
  118. Working with data
  119. '''
  120. # Load and preprocess the Zoo dataset
  121. X, y = load_and_preprocess_dataset(dataset_id=62)
  122. print("Shape of X and y:", X.shape, y.shape)
  123. print("First five elements of X:")
  124. print(X[:5])
  125. print("First five elements of y:")
  126. print(y[:5])
  127. # Split data into labeled/unlabeled/test data
  128. X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)
  129. # Transform tabluar data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results)
  130. # For tabular data in abl, each sample contains a single instance (a row from the dataset).
  131. # For these tabular data samples, the reasoning results are expected to be 0, indicating no rules are violated.
  132. def transform_tab_data(X, y):
  133. return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))
  134. label_data = transform_tab_data(X_label, y_label)
  135. test_data = transform_tab_data(X_test, y_test)
  136. train_data = transform_tab_data(X_unlabel, y_unlabel)
  137. '''
  138. Building the learning part
  139. '''
  140. rf = RandomForestClassifier()
  141. # Pre-train the machine learning model
  142. rf.fit(X_label, y_label)
  143. # # Test the initial model
  144. # y_test_pred = rf.predict(X_test)
  145. # labeled_test_acc = accuracy_score(y_test, y_test_pred)
  146. # print(labeled_test_acc)
  147. model = ABLModel(rf)
  148. '''
  149. Building the reasoning part
  150. '''
  151. # Create the knowledge base for Zoo
  152. kb = ZooKB()
  153. # # Test ZooKB
  154. # pseudo_label = [0]
  155. # data_point = [np.array([1,0,0,1,0,0,1,1,1,1,0,0,4,0,0,1,1])]
  156. # print(kb.logic_forward(pseudo_label, data_point))
  157. # for x,y_item in zip(X, y):
  158. # print(x,y_item)
  159. # print(kb.logic_forward([y_item], [x]))
  160. reasoner = Reasoner(kb, dist_func=consitency)
  161. '''
  162. Building evaluation metrics
  163. '''
  164. metric_list = [SymbolMetric(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")]
  165. '''
  166. Bridging Learning and Reasoning
  167. '''
  168. bridge = SimpleBridge(model, reasoner, metric_list)
  169. # Test the initial model
  170. print("------- Test the initial model -----------")
  171. bridge.test(test_data)
  172. print("------- Use ABL to train the model -----------")
  173. # Use ABL to train the model
  174. bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel))
  175. print("------- Test the final model -----------")
  176. # Test the final model
  177. bridge.test(test_data)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.