You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 3.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import argparse
  2. import os.path as osp
  3. import numpy as np
  4. from sklearn.ensemble import RandomForestClassifier
  5. from ablkit.bridge import SimpleBridge
  6. from ablkit.data.evaluation import ReasoningMetric, SymbolAccuracy
  7. from ablkit.learning import ABLModel
  8. from ablkit.reasoning import Reasoner
  9. from ablkit.utils import ABLLogger, avg_confidence_dist, print_log, tab_data_to_tuple
  10. from get_dataset import load_and_preprocess_dataset, split_dataset
  11. from kb import ZooKB
  12. def consitency(data_example, candidates, candidate_idxs, reasoning_results):
  13. pred_prob = data_example.pred_prob
  14. model_scores = avg_confidence_dist(pred_prob, candidate_idxs)
  15. rule_scores = np.array(reasoning_results)
  16. scores = model_scores + rule_scores
  17. return scores
  18. def main():
  19. parser = argparse.ArgumentParser(description="Zoo example")
  20. parser.add_argument(
  21. "--loops", type=int, default=3, help="number of loop iterations (default : 3)"
  22. )
  23. args = parser.parse_args()
  24. # Build logger
  25. print_log("Abductive Learning on the ZOO example.", logger="current")
  26. # -- Working with Data ------------------------------
  27. print_log("Working with Data.", logger="current")
  28. X, y = load_and_preprocess_dataset(dataset_id=62)
  29. X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)
  30. label_data = tab_data_to_tuple(X_label, y_label)
  31. test_data = tab_data_to_tuple(X_test, y_test)
  32. train_data = tab_data_to_tuple(X_unlabel, y_unlabel)
  33. # -- Building the Learning Part ---------------------
  34. print_log("Building the Learning Part.", logger="current")
  35. # Build base model
  36. base_model = RandomForestClassifier()
  37. # Build ABLModel
  38. model = ABLModel(base_model)
  39. # -- Building the Reasoning Part --------------------
  40. print_log("Building the Reasoning Part.", logger="current")
  41. # Build knowledge base
  42. kb = ZooKB()
  43. # Create reasoner
  44. reasoner = Reasoner(kb, dist_func=consitency)
  45. # -- Building Evaluation Metrics --------------------
  46. print_log("Building Evaluation Metrics.", logger="current")
  47. metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")]
  48. # -- Bridging Learning and Reasoning ----------------
  49. print_log("Bridge Learning and Reasoning.", logger="current")
  50. bridge = SimpleBridge(model, reasoner, metric_list)
  51. # Retrieve the directory of the Log file and define the directory for saving the model weights.
  52. log_dir = ABLLogger.get_current_instance().log_dir
  53. weights_dir = osp.join(log_dir, "weights")
  54. # Performing training and testing
  55. print_log("------- Use labeled data to pretrain the model -----------", logger="current")
  56. base_model.fit(X_label, y_label)
  57. print_log("------- Test the initial model -----------", logger="current")
  58. bridge.test(test_data)
  59. print_log("------- Use ABL to train the model -----------", logger="current")
  60. bridge.train(
  61. train_data=train_data,
  62. label_data=label_data,
  63. loops=args.loops,
  64. segment_size=len(X_unlabel),
  65. save_dir=weights_dir,
  66. )
  67. print_log("------- Test the final model -----------", logger="current")
  68. bridge.test(test_data)
  69. if __name__ == "__main__":
  70. main()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.