
main_with_model_converter.py

import argparse
import os.path as osp

from torch import nn
from torch.optim import RMSprop, lr_scheduler

from lambdaLearn.Algorithm.AbductiveLearning.bridge import SimpleBridge
from lambdaLearn.Algorithm.AbductiveLearning.data.evaluation import ReasoningMetric, SymbolAccuracy
from lambdaLearn.Algorithm.AbductiveLearning.learning import ABLModel
from lambdaLearn.Algorithm.AbductiveLearning.learning.model_converter import ModelConverter
from lambdaLearn.Algorithm.AbductiveLearning.reasoning import GroundKB, KBBase, PrologKB, Reasoner
from lambdaLearn.Algorithm.AbductiveLearning.utils import ABLLogger, print_log
from lambdaLearn.Algorithm.SemiSupervised.Classification.FixMatch import FixMatch

from datasets import get_dataset
from models.nn import LeNet5

# Knowledge base for MNIST Addition: the reasoning result of two digit images
# is the sum of their pseudo-labels.
class AddKB(KBBase):
    def __init__(self, pseudo_label_list=list(range(10))):
        super().__init__(pseudo_label_list)

    def logic_forward(self, nums):
        return sum(nums)


# Same addition rule, but as a GroundKB that pre-builds its ground knowledge
# base for label sequences of length 2.
class AddGroundKB(GroundKB):
    def __init__(self, pseudo_label_list=list(range(10)), GKB_len_list=[2]):
        super().__init__(pseudo_label_list, GKB_len_list)

    def logic_forward(self, nums):
        return sum(nums)

def main():
    parser = argparse.ArgumentParser(description="MNIST Addition example")
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="number of epochs in each learning loop iteration (default : 1)",
    )
    parser.add_argument(
        "--lr", type=float, default=3e-4, help="base model learning rate (default : 0.0003)"
    )
    parser.add_argument("--alpha", type=float, default=0.9, help="alpha in RMSprop (default : 0.9)")
    parser.add_argument(
        "--batch-size", type=int, default=32, help="base model batch size (default : 32)"
    )
    parser.add_argument(
        "--loops", type=int, default=2, help="number of loop iterations (default : 2)"
    )
    parser.add_argument(
        "--segment_size", type=float, default=0.01, help="segment size (default : 0.01)"
    )
    parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
    parser.add_argument(
        "--max-revision",
        type=int,
        default=-1,
        help="maximum revision in reasoner (default : -1)",
    )
    parser.add_argument(
        "--require-more-revision",
        type=int,
        default=0,
        help="require more revision in reasoner (default : 0)",
    )
    kb_type = parser.add_mutually_exclusive_group()
    kb_type.add_argument(
        "--prolog", action="store_true", default=False, help="use PrologKB (default: False)"
    )
    kb_type.add_argument(
        "--ground", action="store_true", default=False, help="use GroundKB (default: False)"
    )
    args = parser.parse_args()

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")

    # -- Working with Data ------------------------------
    print_log("Working with Data.", logger="current")
    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

    # -- Building the Learning Part ---------------------
    print_log("Building the Learning Part.", logger="current")

    # Build necessary components for BasicNN
    model = FixMatch(
        network=LeNet5(),
        threshold=0.95,
        lambda_u=1.0,
        mu=7,
        T=0.5,
        epoch=1,
        num_it_epoch=2**20,
        num_it_total=2**20,
        device="cuda",
    )
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2)
    optimizer_dict = dict(optimizer=RMSprop, lr=0.0003, alpha=0.9)
    scheduler_dict = dict(
        scheduler=lr_scheduler.OneCycleLR, max_lr=0.0003, pct_start=0.15, total_steps=200
    )

    converter = ModelConverter()
    base_model = converter.convert_lambdalearn_to_basicnn(
        model, loss_fn=loss_fn, optimizer_dict=optimizer_dict, scheduler_dict=scheduler_dict
    )

    # Build ABLModel
    model = ABLModel(base_model)

    # -- Building the Reasoning Part --------------------
    print_log("Building the Reasoning Part.", logger="current")

    # Build knowledge base
    if args.prolog:
        kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")
    elif args.ground:
        kb = AddGroundKB()
    else:
        kb = AddKB()

    # Create reasoner
    reasoner = Reasoner(
        kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
    )

    # -- Building Evaluation Metrics --------------------
    print_log("Building Evaluation Metrics.", logger="current")
    metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

    # -- Bridging Learning and Reasoning ----------------
    print_log("Bridge Learning and Reasoning.", logger="current")
    bridge = SimpleBridge(model, reasoner, metric_list)

    # Retrieve the directory of the Log file and define the directory for saving the model weights.
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    # Train and Test
    bridge.train(
        train_data,
        loops=args.loops,
        segment_size=args.segment_size,
        save_interval=args.save_interval,
        save_dir=weights_dir,
    )
    bridge.test(test_data)


if __name__ == "__main__":
    main()
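
The script is launched directly with the flags defined above, for example: python main_with_model_converter.py --ground --loops 2 (which selects AddGroundKB instead of the default AddKB). As a quick, dependency-free sanity check of the reasoning rule, the sketch below restates the logic_forward used by AddKB and AddGroundKB rather than importing it from the script; it only illustrates that the knowledge base maps the predicted digit pseudo-labels to their sum.

# Minimal sketch: the addition rule shared by AddKB and AddGroundKB above.
# logic_forward maps the pseudo-labels predicted for the two digit images
# to their sum, which the reasoner compares against the ground-truth result.
def logic_forward(nums):
    return sum(nums)

# Images recognized as "3" and "4" should yield the addition result 7.
assert logic_forward([3, 4]) == 7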

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.