You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

main.py 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import os
  2. import os.path as osp
  3. import argparse
  4. import numpy as np
  5. import torch
  6. from torch import nn
  7. from examples.hwf.datasets import get_dataset
  8. from examples.models.nn import SymbolNet
  9. from abl.learning import ABLModel, BasicNN
  10. from abl.reasoning import KBBase, GroundKB, Reasoner
  11. from abl.data.evaluation import ReasoningMetric, SymbolMetric
  12. from abl.utils import ABLLogger, print_log
  13. from abl.bridge import SimpleBridge
  14. class HwfKB(KBBase):
  15. def __init__(
  16. self,
  17. pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
  18. max_err=1e-10,
  19. ):
  20. super().__init__(pseudo_label_list, max_err)
  21. def _valid_candidate(self, formula):
  22. if len(formula) % 2 == 0:
  23. return False
  24. for i in range(len(formula)):
  25. if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
  26. return False
  27. if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
  28. return False
  29. return True
  30. # Implement the deduction function
  31. def logic_forward(self, formula):
  32. if not self._valid_candidate(formula):
  33. return np.inf
  34. return eval("".join(formula))
  35. class HwfGroundKB(GroundKB):
  36. def __init__(
  37. self,
  38. pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
  39. GKB_len_list=[1, 3, 5, 7],
  40. max_err=1e-10,
  41. ):
  42. super().__init__(pseudo_label_list, GKB_len_list, max_err)
  43. def _valid_candidate(self, formula):
  44. if len(formula) % 2 == 0:
  45. return False
  46. for i in range(len(formula)):
  47. if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
  48. return False
  49. if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
  50. return False
  51. return True
  52. # Implement the deduction function
  53. def logic_forward(self, formula):
  54. if not self._valid_candidate(formula):
  55. return np.inf
  56. return eval("".join(formula))
  57. def main():
  58. parser = argparse.ArgumentParser(description="MNIST Addition example")
  59. parser.add_argument(
  60. "--no-cuda", action="store_true", default=False, help="disables CUDA training"
  61. )
  62. parser.add_argument(
  63. "--epochs",
  64. type=int,
  65. default=3,
  66. help="number of epochs in each learning loop iteration (default : 3)",
  67. )
  68. parser.add_argument(
  69. "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
  70. )
  71. parser.add_argument(
  72. "--batch-size", type=int, default=128, help="base model batch size (default : 128)"
  73. )
  74. parser.add_argument(
  75. "--loops", type=int, default=5, help="number of loop iterations (default : 5)"
  76. )
  77. parser.add_argument(
  78. "--segment_size", type=int or float, default=1000, help="segment size (default : 1000)"
  79. )
  80. parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
  81. parser.add_argument(
  82. "--max-revision",
  83. type=int or float,
  84. default=-1,
  85. help="maximum revision in reasoner (default : -1)",
  86. )
  87. parser.add_argument(
  88. "--require-more-revision",
  89. type=int,
  90. default=5,
  91. help="require more revision in reasoner (default : 0)",
  92. )
  93. parser.add_argument(
  94. "--ground", action="store_true", default=False, help="use GroundKB (default: False)"
  95. )
  96. parser.add_argument(
  97. "--max-err",
  98. type=float,
  99. default=1e-10,
  100. help="max tolerance during abductive reasoning (default : 1e-10)",
  101. )
  102. args = parser.parse_args()
  103. ### Working with Data
  104. train_data = get_dataset(train=True, get_pseudo_label=True)
  105. test_data = get_dataset(train=False, get_pseudo_label=True)
  106. ### Building the Learning Part
  107. # Build necessary components for BasicNN
  108. cls = SymbolNet(num_classes=14, image_size=(45, 45, 1))
  109. loss_fn = nn.CrossEntropyLoss()
  110. optimizer = torch.optim.Adam(cls.parameters(), lr=args.lr)
  111. use_cuda = not args.no_cuda and torch.cuda.is_available()
  112. device = torch.device("cuda" if use_cuda else "cpu")
  113. # Build BasicNN
  114. base_model = BasicNN(
  115. cls,
  116. loss_fn,
  117. optimizer,
  118. device=device,
  119. batch_size=args.batch_size,
  120. num_epochs=args.epochs,
  121. )
  122. # Build ABLModel
  123. model = ABLModel(base_model)
  124. ### Building the Reasoning Part
  125. # Build knowledge base
  126. if args.ground:
  127. kb = HwfGroundKB()
  128. else:
  129. kb = HwfKB()
  130. # Create reasoner
  131. reasoner = Reasoner(
  132. kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
  133. )
  134. ### Building Evaluation Metrics
  135. metric_list = [SymbolMetric(prefix="hwf"), ReasoningMetric(kb=kb, prefix="hwf")]
  136. ### Bridge Learning and Reasoning
  137. bridge = SimpleBridge(model, reasoner, metric_list)
  138. # Build logger
  139. print_log("Abductive Learning on the HWF example.", logger="current")
  140. # Retrieve the directory of the Log file and define the directory for saving the model weights.
  141. log_dir = ABLLogger.get_current_instance().log_dir
  142. weights_dir = osp.join(log_dir, "weights")
  143. # Train and Test
  144. bridge.train(
  145. train_data,
  146. loops=args.loops,
  147. segment_size=args.segment_size,
  148. save_interval=args.save_interval,
  149. save_dir=weights_dir,
  150. )
  151. bridge.test(test_data)
  152. if __name__ == "__main__":
  153. main()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.