
main.py

import argparse
import os.path as osp

import torch
from torch import nn
from torch.optim import RMSprop, lr_scheduler

from datasets import get_dataset
from models.nn import LeNet5

from abl.learning import ABLModel, BasicNN
from abl.reasoning import GroundKB, KBBase, PrologKB, Reasoner
from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
from abl.utils import ABLLogger, print_log
from abl.bridge import SimpleBridge
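
# Both knowledge bases below encode the same rule: the label of an image pair
# is the sum of the two digits, e.g., logic_forward([3, 5]) returns 8.
# AddKB derives consistent revisions on the fly through KBBase, while
# AddGroundKB pre-builds a ground knowledge base for the equation lengths
# listed in GKB_len_list.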
class AddKB(KBBase):
    def __init__(self, pseudo_label_list=list(range(10))):
        super().__init__(pseudo_label_list)

    def logic_forward(self, nums):
        return sum(nums)


class AddGroundKB(GroundKB):
    def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
        super().__init__(pseudo_label_list, GKB_len_list)

    def logic_forward(self, nums):
        return sum(nums)

def int_or_float(value):
    # Parse a command-line value as an int when possible, otherwise as a float.
    # (The original `type=int or float` evaluates to just `int`, which would
    # reject fractional values such as 0.01 passed on the command line.)
    try:
        return int(value)
    except ValueError:
        return float(value)


def main():
    parser = argparse.ArgumentParser(description="MNIST Addition example")
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="number of epochs in each learning loop iteration (default: 1)",
    )
    parser.add_argument(
        "--lr", type=float, default=3e-4, help="base model learning rate (default: 0.0003)"
    )
    parser.add_argument("--alpha", type=float, default=0.9, help="alpha in RMSprop (default: 0.9)")
    parser.add_argument(
        "--batch-size", type=int, default=32, help="base model batch size (default: 32)"
    )
    parser.add_argument(
        "--loops", type=int, default=2, help="number of loop iterations (default: 2)"
    )
    parser.add_argument(
        "--segment_size", type=int_or_float, default=0.01, help="segment size (default: 0.01)"
    )
    parser.add_argument("--save_interval", type=int, default=1, help="save interval (default: 1)")
    parser.add_argument(
        "--max-revision",
        type=int_or_float,
        default=-1,
        help="maximum revision in reasoner (default: -1)",
    )
    parser.add_argument(
        "--require-more-revision",
        type=int,
        default=0,
        help="require more revision in reasoner (default: 0)",
    )
    kb_type = parser.add_mutually_exclusive_group()
    kb_type.add_argument(
        "--prolog", action="store_true", default=False, help="use PrologKB (default: False)"
    )
    kb_type.add_argument(
        "--ground", action="store_true", default=False, help="use GroundKB (default: False)"
    )
    args = parser.parse_args()
    ### Working with Data
    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

    ### Building the Learning Part
    # Build necessary components for BasicNN
    cls = LeNet5(num_classes=10)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2)
    optimizer = RMSprop(cls.parameters(), lr=args.lr, alpha=args.alpha)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    scheduler = lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        pct_start=0.15,
        epochs=args.loops,
        steps_per_epoch=int(1 / args.segment_size),
    )
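    # Note: the one-cycle schedule is sized to the ABL loop rather than to
    # ordinary epochs: with a fractional segment_size, each loop visits about
    # 1 / segment_size data segments, so the schedule spans
    # args.loops * int(1 / args.segment_size) scheduler steps in total.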
    # Build BasicNN
    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        scheduler=scheduler,
        device=device,
        batch_size=args.batch_size,
        num_epochs=args.epochs,
    )

    # Build ABLModel
    model = ABLModel(base_model)
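    # BasicNN wraps the raw PyTorch module with a fit/predict-style training
    # interface, and ABLModel adapts it to the pseudo-label-based training that
    # the abductive learning loop expects.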
    ### Building the Reasoning Part
    # Build knowledge base
    if args.prolog:
        kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")
    elif args.ground:
        kb = AddGroundKB()
    else:
        kb = AddKB()

    # Create reasoner
    reasoner = Reasoner(
        kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
    )
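    # The reasoner revises (abduces) the model's pseudo-labels so that
    # logic_forward on the revised labels matches the given sum; max_revision
    # bounds how many labels may be changed (-1 means no bound).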
    ### Building Evaluation Metrics
    metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

    ### Bridge Learning and Reasoning
    bridge = SimpleBridge(model, reasoner, metric_list)

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")

    # Retrieve the directory of the log file and define the directory for saving the model weights.
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    # Train and Test
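    # Each loop consumes the training data in segments of segment_size (a
    # fraction of the dataset when it is a float, a sample count when it is an
    # int); weights are saved every save_interval loops under weights_dir.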
    bridge.train(
        train_data,
        loops=args.loops,
        segment_size=args.segment_size,
        save_interval=args.save_interval,
        save_dir=weights_dir,
    )
    bridge.test(test_data)


if __name__ == "__main__":
    main()
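
Typical invocations, assuming the bundled datasets module can fetch MNIST and, for the --prolog variant, that SWI-Prolog and the add.pl rule file are available alongside the script:

    python main.py                 # default KBBase-backed knowledge base
    python main.py --ground        # pre-grounded knowledge base
    python main.py --prolog        # Prolog-backed knowledge base (needs add.pl)
    python main.py --loops 3 --segment_size 0.05 --max-revision 2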

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.