You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import argparse
  2. import os.path as osp
  3. import torch
  4. import torch.nn as nn
  5. from ablkit.learning import ABLModel, BasicNN
  6. from ablkit.utils import ABLLogger, print_log
  7. from bridge import HedBridge
  8. from consistency_metric import ConsistencyMetric
  9. from datasets import get_dataset, split_equation
  10. from models.nn import SymbolNet
  11. from reasoning import HedKB, HedReasoner
  12. def main():
  13. parser = argparse.ArgumentParser(description="Handwritten Equation Decipherment example")
  14. parser.add_argument(
  15. "--no-cuda", action="store_true", default=False, help="disables CUDA training"
  16. )
  17. parser.add_argument(
  18. "--epochs",
  19. type=int,
  20. default=1,
  21. help="number of epochs in each learning loop iteration (default : 1)",
  22. )
  23. parser.add_argument(
  24. "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
  25. )
  26. parser.add_argument(
  27. "--weight-decay", type=float, default=1e-4, help="weight decay (default : 0.0001)"
  28. )
  29. parser.add_argument(
  30. "--batch-size", type=int, default=32, help="base model batch size (default : 32)"
  31. )
  32. parser.add_argument(
  33. "--segment_size", type=int, default=1000, help="segment size (default : 1000)"
  34. )
  35. parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
  36. parser.add_argument(
  37. "--max-revision",
  38. type=int,
  39. default=10,
  40. help="maximum revision in reasoner (default : 10)",
  41. )
  42. args = parser.parse_args()
  43. # Build logger
  44. print_log("Abductive Learning on the HED example.", logger="current")
  45. # -- Working with Data ------------------------------
  46. print_log("Working with Data.", logger="current")
  47. total_train_data = get_dataset(train=True)
  48. train_data, val_data = split_equation(total_train_data, 3, 1)
  49. test_data = get_dataset(train=False)
  50. # -- Building the Learning Part ---------------------
  51. print_log("Building the Learning Part.", logger="current")
  52. # Build necessary components for BasicNN
  53. cls = SymbolNet(num_classes=4)
  54. loss_fn = nn.CrossEntropyLoss()
  55. optimizer = torch.optim.RMSprop(cls.parameters(), lr=args.lr, weight_decay=args.weight_decay)
  56. use_cuda = not args.no_cuda and torch.cuda.is_available()
  57. device = torch.device("cuda" if use_cuda else "cpu")
  58. # Build BasicNN
  59. base_model = BasicNN(
  60. cls,
  61. loss_fn,
  62. optimizer,
  63. device=device,
  64. batch_size=args.batch_size,
  65. num_epochs=args.epochs,
  66. stop_loss=None,
  67. )
  68. # Build ABLModel
  69. model = ABLModel(base_model)
  70. # -- Building the Reasoning Part --------------------
  71. print_log("Building the Reasoning Part.", logger="current")
  72. # Build knowledge base
  73. kb = HedKB()
  74. # Create reasoner
  75. reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=args.max_revision)
  76. # -- Building Evaluation Metrics --------------------
  77. print_log("Building Evaluation Metrics.", logger="current")
  78. metric_list = [ConsistencyMetric(kb=kb)]
  79. # -- Bridging Learning and Reasoning ----------------
  80. print_log("Bridge Learning and Reasoning.", logger="current")
  81. bridge = HedBridge(model, reasoner, metric_list)
  82. # Retrieve the directory of the Log file and define the directory for saving the model weights.
  83. log_dir = ABLLogger.get_current_instance().log_dir
  84. weights_dir = osp.join(log_dir, "weights")
  85. bridge.pretrain(weights_dir)
  86. bridge.train(train_data, val_data, save_dir=weights_dir)
  87. bridge.test(test_data)
  88. if __name__ == "__main__":
  89. main()

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.