You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import os.path as osp
  2. import numpy as np
  3. from sklearn.ensemble import RandomForestClassifier
  4. from sklearn.metrics import accuracy_score
  5. from z3 import Solver, Int, If, Not, Implies, Sum, sat
  6. import openml
  7. from abl.learning import ABLModel
  8. from abl.reasoning import KBBase, Reasoner
  9. from abl.data.evaluation import ReasoningMetric, SymbolMetric
  10. from abl.bridge import SimpleBridge
  11. from abl.utils.utils import confidence_dist
  12. from abl.utils import ABLLogger, print_log
  13. # Build logger
  14. print_log("Abductive Learning on the Zoo example.", logger="current")
  15. # Retrieve the directory of the Log file and define the directory for saving the model weights.
  16. log_dir = ABLLogger.get_current_instance().log_dir
  17. weights_dir = osp.join(log_dir, "weights")
  18. # Learning Part
  19. rf = RandomForestClassifier()
  20. model = ABLModel(rf)
  21. # %% [markdown]
  22. # ### Logic Part
  23. # %%
  24. class ZooKB(KBBase):
  25. def __init__(self):
  26. super().__init__(pseudo_label_list=list(range(7)), use_cache=False)
  27. # Use z3 solver
  28. self.solver = Solver()
  29. # Load information of Zoo dataset
  30. dataset = openml.datasets.get_dataset(
  31. dataset_id=62,
  32. download_data=False,
  33. download_qualities=False,
  34. download_features_meta_data=False,
  35. )
  36. X, y, categorical_indicator, attribute_names = dataset.get_data(
  37. target=dataset.default_target_attribute
  38. )
  39. self.attribute_names = attribute_names
  40. self.target_names = y.cat.categories.tolist()
  41. # Define variables
  42. for name in self.attribute_names + self.target_names:
  43. exec(
  44. f"globals()['{name}'] = Int('{name}')"
  45. ) ## or use dict to create var and modify rules
  46. # Define rules
  47. rules = [
  48. Implies(milk == 1, mammal == 1),
  49. Implies(mammal == 1, milk == 1),
  50. Implies(mammal == 1, backbone == 1),
  51. Implies(mammal == 1, breathes == 1),
  52. Implies(feathers == 1, bird == 1),
  53. Implies(bird == 1, feathers == 1),
  54. Implies(bird == 1, eggs == 1),
  55. Implies(bird == 1, backbone == 1),
  56. Implies(bird == 1, breathes == 1),
  57. Implies(bird == 1, legs == 2),
  58. Implies(bird == 1, tail == 1),
  59. Implies(reptile == 1, backbone == 1),
  60. Implies(reptile == 1, breathes == 1),
  61. Implies(reptile == 1, tail == 1),
  62. Implies(fish == 1, aquatic == 1),
  63. Implies(fish == 1, toothed == 1),
  64. Implies(fish == 1, backbone == 1),
  65. Implies(fish == 1, Not(breathes == 1)),
  66. Implies(fish == 1, fins == 1),
  67. Implies(fish == 1, legs == 0),
  68. Implies(fish == 1, tail == 1),
  69. Implies(amphibian == 1, eggs == 1),
  70. Implies(amphibian == 1, aquatic == 1),
  71. Implies(amphibian == 1, backbone == 1),
  72. Implies(amphibian == 1, breathes == 1),
  73. Implies(amphibian == 1, legs == 4),
  74. Implies(insect == 1, eggs == 1),
  75. Implies(insect == 1, Not(backbone == 1)),
  76. Implies(insect == 1, legs == 6),
  77. Implies(invertebrate == 1, Not(backbone == 1)),
  78. ]
  79. # Define weights and sum of violated weights
  80. self.weights = {rule: 1 for rule in rules}
  81. self.total_violation_weight = Sum(
  82. [If(Not(rule), self.weights[rule], 0) for rule in self.weights]
  83. )
  84. def logic_forward(self, pseudo_label, data_point):
  85. attribute_names, target_names = self.attribute_names, self.target_names
  86. solver = self.solver
  87. total_violation_weight = self.total_violation_weight
  88. pseudo_label, data_point = pseudo_label[0], data_point[0]
  89. self.solver.reset()
  90. for name, value in zip(attribute_names, data_point):
  91. solver.add(eval(f"{name} == {value}"))
  92. for cate, name in zip(self.pseudo_label_list, target_names):
  93. value = 1 if (cate == pseudo_label) else 0
  94. solver.add(eval(f"{name} == {value}"))
  95. if solver.check() == sat:
  96. model = solver.model()
  97. total_weight = model.evaluate(total_violation_weight)
  98. return total_weight.as_long()
  99. else:
  100. # No solution found
  101. return 1e10
  102. def consitency(data_example, candidates, candidate_idxs, reasoning_results):
  103. pred_prob = data_example.pred_prob
  104. model_scores = confidence_dist(pred_prob, candidate_idxs)
  105. rule_scores = np.array(reasoning_results)
  106. scores = model_scores + rule_scores
  107. return scores
  108. kb = ZooKB()
  109. reasoner = Reasoner(kb, dist_func=consitency)
  110. # %% [markdown]
  111. # ### Datasets and Evaluation Metrics
  112. # %%
  113. # Function to load and preprocess the dataset
  114. def load_and_preprocess_dataset(dataset_id):
  115. dataset = openml.datasets.get_dataset(
  116. dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False
  117. )
  118. X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
  119. # Convert data types
  120. for col in X.select_dtypes(include="bool").columns:
  121. X[col] = X[col].astype(int)
  122. y = y.cat.codes.astype(int)
  123. X, y = X.to_numpy(), y.to_numpy()
  124. return X, y
  125. # Function to split data (one shot)
  126. def split_dataset(X, y, test_size=0.3):
  127. # For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1)
  128. label_indices, unlabel_indices, test_indices = [], [], []
  129. for class_label in np.unique(y):
  130. idxs = np.where(y == class_label)[0]
  131. np.random.shuffle(idxs)
  132. n_train_unlabel = int((1 - test_size) * (len(idxs) - 1))
  133. label_indices.append(idxs[0])
  134. unlabel_indices.extend(idxs[1 : 1 + n_train_unlabel])
  135. test_indices.extend(idxs[1 + n_train_unlabel :])
  136. X_label, y_label = X[label_indices], y[label_indices]
  137. X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices]
  138. X_test, y_test = X[test_indices], y[test_indices]
  139. return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test
  140. # Load and preprocess the Zoo dataset
  141. X, y = load_and_preprocess_dataset(dataset_id=62)
  142. # Split data into labeled/unlabeled/test data
  143. X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)
  144. # Transform tabular data to the format required by ABL, which is a tuple of (X, ground truth of X, reasoning results)
  145. # For tabular data in abl, each example contains a single instance (a row from the dataset).
  146. # For these tabular data examples, the reasoning results are expected to be 0, indicating no rules are violated.
  147. def transform_tab_data(X, y):
  148. return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))
  149. label_data = transform_tab_data(X_label, y_label)
  150. test_data = transform_tab_data(X_test, y_test)
  151. train_data = transform_tab_data(X_unlabel, y_unlabel)
  152. # %%
  153. # Set up metrics
  154. metric_list = [SymbolMetric(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")]
  155. # %% [markdown]
  156. # ### Bridge Machine Learning and Logic Reasoning
  157. # %%
  158. bridge = SimpleBridge(model, reasoner, metric_list)
  159. # %% [markdown]
  160. # ### Train and Test
  161. # %%
  162. # Pre-train the machine learning model
  163. rf.fit(X_label, y_label)
  164. # %%
  165. # Test the initial model
  166. print("------- Test the initial model -----------")
  167. bridge.test(test_data)
  168. print("------- Use ABL to train the model -----------")
  169. # Use ABL to train the model
  170. bridge.train(
  171. train_data=train_data,
  172. label_data=label_data,
  173. loops=3,
  174. segment_size=len(X_unlabel),
  175. save_dir=weights_dir,
  176. )
  177. print("------- Test the final model -----------")
  178. # Test the final model
  179. bridge.test(test_data)

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.