You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 5.5 kB

(line numbers 1–179)
  1. import argparse
  2. import os.path as osp
  3. import numpy as np
  4. import torch
  5. from torch import nn
  6. from abl.bridge import SimpleBridge
  7. from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
  8. from abl.learning import ABLModel, BasicNN
  9. from abl.reasoning import GroundKB, KBBase, Reasoner
  10. from abl.utils import ABLLogger, print_log
  11. from datasets import get_dataset
  12. from models.nn import SymbolNet
  13. class HwfKB(KBBase):
  14. def __init__(
  15. self,
  16. pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
  17. max_err=1e-10,
  18. ):
  19. super().__init__(pseudo_label_list, max_err)
  20. def _valid_candidate(self, formula):
  21. if len(formula) % 2 == 0:
  22. return False
  23. for i in range(len(formula)):
  24. if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
  25. return False
  26. if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
  27. return False
  28. return True
  29. # Implement the deduction function
  30. def logic_forward(self, formula):
  31. if not self._valid_candidate(formula):
  32. return np.inf
  33. return eval("".join(formula))
  34. class HwfGroundKB(GroundKB):
  35. def __init__(
  36. self,
  37. pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
  38. GKB_len_list=[1, 3, 5, 7],
  39. max_err=1e-10,
  40. ):
  41. super().__init__(pseudo_label_list, GKB_len_list, max_err)
  42. def _valid_candidate(self, formula):
  43. if len(formula) % 2 == 0:
  44. return False
  45. for i in range(len(formula)):
  46. if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
  47. return False
  48. if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
  49. return False
  50. return True
  51. # Implement the deduction function
  52. def logic_forward(self, formula):
  53. if not self._valid_candidate(formula):
  54. return np.inf
  55. return eval("".join(formula))
  56. def main():
  57. parser = argparse.ArgumentParser(description="Handwritten Formula example")
  58. parser.add_argument(
  59. "--no-cuda", action="store_true", default=False, help="disables CUDA training"
  60. )
  61. parser.add_argument(
  62. "--epochs",
  63. type=int,
  64. default=3,
  65. help="number of epochs in each learning loop iteration (default : 3)",
  66. )
  67. parser.add_argument(
  68. "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
  69. )
  70. parser.add_argument(
  71. "--batch-size", type=int, default=128, help="base model batch size (default : 128)"
  72. )
  73. parser.add_argument(
  74. "--loops", type=int, default=5, help="number of loop iterations (default : 5)"
  75. )
  76. parser.add_argument(
  77. "--segment_size", type=int or float, default=1000, help="segment size (default : 1000)"
  78. )
  79. parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
  80. parser.add_argument(
  81. "--max-revision",
  82. type=int or float,
  83. default=-1,
  84. help="maximum revision in reasoner (default : -1)",
  85. )
  86. parser.add_argument(
  87. "--require-more-revision",
  88. type=int,
  89. default=5,
  90. help="require more revision in reasoner (default : 0)",
  91. )
  92. parser.add_argument(
  93. "--ground", action="store_true", default=False, help="use GroundKB (default: False)"
  94. )
  95. parser.add_argument(
  96. "--max-err",
  97. type=float,
  98. default=1e-10,
  99. help="max tolerance during abductive reasoning (default : 1e-10)",
  100. )
  101. args = parser.parse_args()
  102. ### Working with Data
  103. train_data = get_dataset(train=True, get_pseudo_label=True)
  104. test_data = get_dataset(train=False, get_pseudo_label=True)
  105. ### Building the Learning Part
  106. # Build necessary components for BasicNN
  107. cls = SymbolNet(num_classes=13, image_size=(45, 45, 1))
  108. loss_fn = nn.CrossEntropyLoss()
  109. optimizer = torch.optim.Adam(cls.parameters(), lr=args.lr)
  110. use_cuda = not args.no_cuda and torch.cuda.is_available()
  111. device = torch.device("cuda" if use_cuda else "cpu")
  112. # Build BasicNN
  113. base_model = BasicNN(
  114. cls,
  115. loss_fn,
  116. optimizer,
  117. device=device,
  118. batch_size=args.batch_size,
  119. num_epochs=args.epochs,
  120. )
  121. # Build ABLModel
  122. model = ABLModel(base_model)
  123. ### Building the Reasoning Part
  124. # Build knowledge base
  125. if args.ground:
  126. kb = HwfGroundKB()
  127. else:
  128. kb = HwfKB()
  129. # Create reasoner
  130. reasoner = Reasoner(
  131. kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
  132. )
  133. ### Building Evaluation Metrics
  134. metric_list = [SymbolAccuracy(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]
  135. ### Bridge Learning and Reasoning
  136. bridge = SimpleBridge(model, reasoner, metric_list)
  137. # Build logger
  138. print_log("Abductive Learning on the HWF example.", logger="current")
  139. # Retrieve the directory of the Log file and define the directory for saving the model weights.
  140. log_dir = ABLLogger.get_current_instance().log_dir
  141. weights_dir = osp.join(log_dir, "weights")
  142. # Train and Test
  143. bridge.train(
  144. train_data,
  145. loops=args.loops,
  146. segment_size=args.segment_size,
  147. save_interval=args.save_interval,
  148. save_dir=weights_dir,
  149. )
  150. bridge.test(test_data)
  151. if __name__ == "__main__":
  152. main()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.