You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 2.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import argparse
  2. import os.path as osp
  3. import numpy as np
  4. from sklearn.ensemble import RandomForestClassifier
  5. from abl.bridge import SimpleBridge
  6. from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
  7. from abl.learning import ABLModel
  8. from abl.reasoning import Reasoner
  9. from abl.utils import ABLLogger, confidence_dist, print_log
  10. from get_dataset import load_and_preprocess_dataset, split_dataset
  11. from kb import ZooKB
  12. def transform_tab_data(X, y):
  13. return ([[x] for x in X], [[y_item] for y_item in y], [0] * len(y))
  14. def consitency(data_example, candidates, candidate_idxs, reasoning_results):
  15. pred_prob = data_example.pred_prob
  16. model_scores = confidence_dist(pred_prob, candidate_idxs)
  17. rule_scores = np.array(reasoning_results)
  18. scores = model_scores + rule_scores
  19. return scores
  20. def main():
  21. parser = argparse.ArgumentParser(description="Zoo example")
  22. parser.add_argument(
  23. "--loops", type=int, default=3, help="number of loop iterations (default : 3)"
  24. )
  25. args = parser.parse_args()
  26. ### Working with Data
  27. X, y = load_and_preprocess_dataset(dataset_id=62)
  28. X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)
  29. label_data = transform_tab_data(X_label, y_label)
  30. test_data = transform_tab_data(X_test, y_test)
  31. train_data = transform_tab_data(X_unlabel, y_unlabel)
  32. ### Building the Learning Part
  33. base_model = RandomForestClassifier()
  34. # Build ABLModel
  35. model = ABLModel(base_model)
  36. ### Building the Reasoning Part
  37. # Build knowledge base
  38. kb = ZooKB()
  39. # Create reasoner
  40. reasoner = Reasoner(kb, dist_func=consitency)
  41. ### Building Evaluation Metrics
  42. metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")]
  43. # Build logger
  44. print_log("Abductive Learning on the ZOO example.", logger="current")
  45. log_dir = ABLLogger.get_current_instance().log_dir
  46. weights_dir = osp.join(log_dir, "weights")
  47. ### Bridging learning and reasoning
  48. bridge = SimpleBridge(model, reasoner, metric_list)
  49. # Performing training and testing
  50. print_log("------- Use labeled data to pretrain the model -----------", logger="current")
  51. base_model.fit(X_label, y_label)
  52. print_log("------- Test the initial model -----------", logger="current")
  53. bridge.test(test_data)
  54. print_log("------- Use ABL to train the model -----------", logger="current")
  55. bridge.train(train_data=train_data, label_data=label_data, loops=args.loops, segment_size=len(X_unlabel), save_dir=weights_dir)
  56. print_log("------- Test the final model -----------", logger="current")
  57. bridge.test(test_data)
  58. if __name__ == "__main__":
  59. main()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.