
main.py 6.0 kB

import argparse
import os.path as osp

import numpy as np
import torch
from torch import nn

from ablkit.bridge import SimpleBridge
from ablkit.data.evaluation import ReasoningMetric, SymbolAccuracy
from ablkit.learning import ABLModel, BasicNN
from ablkit.reasoning import GroundKB, KBBase, Reasoner
from ablkit.utils import ABLLogger, print_log

from datasets import get_dataset
from models.nn import SymbolNet


class HwfKB(KBBase):
    def __init__(
        self,
        pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
        max_err=1e-10,
    ):
        super().__init__(pseudo_label_list, max_err)

    def _valid_candidate(self, formula):
        if len(formula) % 2 == 0:
            return False
        for i in range(len(formula)):
            if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
                return False
            if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
                return False
        return True

    # Implement the deduction function
    def logic_forward(self, formula):
        if not self._valid_candidate(formula):
            return np.inf
        return eval("".join(formula))


class HwfGroundKB(GroundKB):
    def __init__(
        self,
        pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
        GKB_len_list=[1, 3, 5, 7],
        max_err=1e-10,
    ):
        super().__init__(pseudo_label_list, GKB_len_list, max_err)

    def _valid_candidate(self, formula):
        if len(formula) % 2 == 0:
            return False
        for i in range(len(formula)):
            if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
                return False
            if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
                return False
        return True

    # Implement the deduction function
    def logic_forward(self, formula):
        if not self._valid_candidate(formula):
            return np.inf
        return eval("".join(formula))


def main():
    parser = argparse.ArgumentParser(description="Handwritten Formula example")
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="number of epochs in each learning loop iteration (default : 3)",
    )
    parser.add_argument(
        "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=128, help="base model batch size (default : 128)"
    )
    parser.add_argument(
        "--loops", type=int, default=5, help="number of loop iterations (default : 5)"
    )
    parser.add_argument(
        "--segment_size", type=int, default=1000, help="segment size (default : 1000)"
    )
    parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
    parser.add_argument(
        "--max-revision",
        type=int,
        default=-1,
        help="maximum revision in reasoner (default : -1)",
    )
    parser.add_argument(
        "--require-more-revision",
        type=int,
        default=0,
        help="require more revision in reasoner (default : 0)",
    )
    parser.add_argument(
        "--ground", action="store_true", default=False, help="use GroundKB (default: False)"
    )
    parser.add_argument(
        "--max-err",
        type=float,
        default=1e-10,
        help="max tolerance during abductive reasoning (default : 1e-10)",
    )

    args = parser.parse_args()

    # Build logger
    print_log("Abductive Learning on the HWF example.", logger="current")

    # -- Working with Data ------------------------------
    print_log("Working with Data.", logger="current")
    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

    # -- Building the Learning Part ---------------------
    print_log("Building the Learning Part.", logger="current")

    # Build necessary components for BasicNN
    cls = SymbolNet(num_classes=13, image_size=(45, 45, 1))
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cls.parameters(), lr=args.lr)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Build BasicNN
    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        device=device,
        batch_size=args.batch_size,
        num_epochs=args.epochs,
    )

    # Build ABLModel
    model = ABLModel(base_model)

    # -- Building the Reasoning Part --------------------
    print_log("Building the Reasoning Part.", logger="current")

    # Build knowledge base
    if args.ground:
        kb = HwfGroundKB()
    else:
        kb = HwfKB()

    # Create reasoner
    reasoner = Reasoner(
        kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
    )

    # -- Building Evaluation Metrics --------------------
    print_log("Building Evaluation Metrics.", logger="current")
    metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]

    # -- Bridging Learning and Reasoning ----------------
    print_log("Bridge Learning and Reasoning.", logger="current")
    bridge = SimpleBridge(model, reasoner, metric_list)

    # Retrieve the directory of the Log file and define the directory for saving the model weights.
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    # Train and Test
    bridge.train(
        train_data,
        loops=args.loops,
        segment_size=args.segment_size,
        save_interval=args.save_interval,
        save_dir=weights_dir,
    )
    bridge.test(test_data)


if __name__ == "__main__":
    main()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.
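
For reference, below is a minimal sketch of the deduction step that the HwfKB class in main.py performs. It assumes the file is importable as a module named main (a hypothetical name for illustration) and simply exercises the validity check and eval-based logic_forward shown in the listing above; it is not part of the example script itself.

    # Illustrative only: assumes main.py is importable as `main`.
    import numpy as np
    from main import HwfKB

    kb = HwfKB()

    # A well-formed formula alternates digits and operators and is evaluated directly.
    print(kb.logic_forward(["3", "*", "4", "+", "1"]))      # 13

    # An ill-formed candidate (even length or misplaced operator) is rejected with inf,
    # so abductive revision never selects it.
    print(kb.logic_forward(["3", "*"]) == np.inf)           # True

The training pipeline itself is launched by running the script (e.g., python main.py, optionally with --ground to use the HwfGroundKB variant), with the remaining command-line flags listed in the argparse section above.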