
main.py

import argparse
import os.path as osp

import torch
from torch import nn
from torch.optim import RMSprop, lr_scheduler

from ablkit.bridge import SimpleBridge
from ablkit.data.evaluation import ReasoningMetric, SymbolAccuracy
from ablkit.learning import ABLModel, BasicNN
from ablkit.reasoning import GroundKB, KBBase, PrologKB, Reasoner
from ablkit.utils import ABLLogger, print_log

from datasets import get_dataset
from models.nn import LeNet5


# Knowledge base for MNIST Addition: the reasoning result of two digits
# is their sum. KBBase abduces candidate labels by searching over
# revisions of the predictions and checking them with logic_forward.
class AddKB(KBBase):
    def __init__(self, pseudo_label_list=list(range(10))):
        super().__init__(pseudo_label_list)

    def logic_forward(self, nums):
        return sum(nums)


# GroundKB variant: enumerates all label sequences of the lengths in
# GKB_len_list ahead of time, trading memory for faster abduction.
class AddGroundKB(GroundKB):
    def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
        super().__init__(pseudo_label_list, GKB_len_list)

    def logic_forward(self, nums):
        return sum(nums)


def main():
    parser = argparse.ArgumentParser(description="MNIST Addition example")
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="number of epochs in each learning loop iteration (default : 1)",
    )
    parser.add_argument(
        "--label-smoothing",
        type=float,
        default=0.2,
        help="label smoothing in cross entropy loss (default : 0.2)",
    )
    parser.add_argument(
        "--lr", type=float, default=3e-4, help="base model learning rate (default : 0.0003)"
    )
    parser.add_argument("--alpha", type=float, default=0.9, help="alpha in RMSprop (default : 0.9)")
    parser.add_argument(
        "--batch-size", type=int, default=32, help="base model batch size (default : 32)"
    )
    parser.add_argument(
        "--loops", type=int, default=2, help="number of loop iterations (default : 2)"
    )
    parser.add_argument(
        "--segment_size",
        type=float,
        default=0.01,
        help="fraction of the training set per segment (default : 0.01)",
    )
    parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
    parser.add_argument(
        "--max-revision", type=int, default=-1, help="maximum revision in reasoner (default : -1)"
    )
    parser.add_argument(
        "--require-more-revision",
        type=int,
        default=0,
        help="require more revision in reasoner (default : 0)",
    )
    kb_type = parser.add_mutually_exclusive_group()
    kb_type.add_argument(
        "--prolog", action="store_true", default=False, help="use PrologKB (default: False)"
    )
    kb_type.add_argument(
        "--ground", action="store_true", default=False, help="use GroundKB (default: False)"
    )
    args = parser.parse_args()

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")

    # -- Working with Data ------------------------------
    print_log("Working with Data.", logger="current")
    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

    # -- Building the Learning Part ---------------------
    print_log("Building the Learning Part.", logger="current")

    # Build necessary components for BasicNN
    cls = LeNet5(num_classes=10)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    optimizer = RMSprop(cls.parameters(), lr=args.lr, alpha=args.alpha)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # The OneCycleLR schedule spans all loops: epochs=loops, with
    # int(1 / segment_size) segments per loop (100 with the default 0.01).
    scheduler = lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        pct_start=0.15,
        epochs=args.loops,
        steps_per_epoch=int(1 / args.segment_size),
    )
    # Build BasicNN
    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        scheduler=scheduler,
        device=device,
        batch_size=args.batch_size,
        num_epochs=args.epochs,
    )

    # Build ABLModel
    model = ABLModel(base_model)

    # -- Building the Reasoning Part --------------------
    print_log("Building the Reasoning Part.", logger="current")

    # Build knowledge base
    if args.prolog:
        kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")
    elif args.ground:
        kb = AddGroundKB()
    else:
        kb = AddKB()

    # Create reasoner
    reasoner = Reasoner(
        kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
    )

    # -- Building Evaluation Metrics --------------------
    print_log("Building Evaluation Metrics.", logger="current")
    metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

    # -- Bridging Learning and Reasoning ----------------
    print_log("Bridge Learning and Reasoning.", logger="current")
    bridge = SimpleBridge(model, reasoner, metric_list)

    # Retrieve the directory of the log file and define the directory for saving the model weights.
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    # Train and Test
    bridge.train(
        train_data,
        loops=args.loops,
        segment_size=args.segment_size,
        save_interval=args.save_interval,
        save_dir=weights_dir,
    )
    bridge.test(test_data)


if __name__ == "__main__":
    main()
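
To make the knowledge base's role concrete, the sketch below reproduces the abduction step for AddKB by brute force, outside of ablkit: given the network's predicted digits and the example's known sum, it enumerates every label assignment consistent with logic_forward and ranks the candidates by how few predictions they revise. The helper abduce_by_search is hypothetical and only illustrates the idea; ablkit's Reasoner performs this search internally, bounded by the --max-revision and --require-more-revision arguments.

from itertools import product

def abduce_by_search(kb, pseudo_label_list, pred, y):
    # Hypothetical helper, not an ablkit API: collect every assignment of
    # pseudo-labels whose reasoning result equals the target y ...
    candidates = [
        list(c)
        for c in product(pseudo_label_list, repeat=len(pred))
        if kb.logic_forward(list(c)) == y
    ]
    # ... preferring candidates that revise as few predicted digits as possible.
    return sorted(candidates, key=lambda c: sum(a != b for a, b in zip(c, pred)))

kb = AddKB()  # AddKB as defined in main.py above
# The model predicts [1, 1], but the label says the two digits sum to 8;
# the top-ranked candidates each revise only a single digit.
print(abduce_by_search(kb, list(range(10)), [1, 1], 8)[:4])
# -> [[1, 7], [7, 1], [0, 8], [2, 6]]

With the datasets and models modules from the example directory in place, the script itself runs with the default AddKB as python main.py; python main.py --ground switches to the prebuilt AddGroundKB, and python main.py --prolog loads a Prolog knowledge base, which requires an add.pl file encoding the same addition rule.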

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.
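
Because that framework alternates between perception and reasoning on every segment of data, abduction speed matters; this is what AddGroundKB above trades memory for. A runnable sketch of the precomputation idea, independent of ablkit:

from collections import defaultdict
from itertools import product

# Illustrative only, not ablkit code: precompute, like a GroundKB, the map
# from each reasoning result to all pseudo-label pairs that produce it.
ground_kb = defaultdict(list)
for pair in product(range(10), repeat=2):
    ground_kb[sum(pair)].append(list(pair))

# Abduction for a known sum then becomes a dictionary lookup, not a search.
print(ground_kb[8])  # [[0, 8], [1, 7], [2, 6], ..., [8, 0]]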