`Learn the Basics <Basics.html>`_ ||
**Quick Start** ||
`Dataset & Data Structure <Datasets.html>`_ ||
`Learning Part <Learning.html>`_ ||
`Reasoning Part <Reasoning.html>`_ ||
`Evaluation Metrics <Evaluation.html>`_ ||
`Bridge <Bridge.html>`_

Quick Start
===========

We use the MNIST Addition task as a quick start example. In this task, the inputs are pairs of MNIST handwritten digit images, and the outputs are their sums. Refer to the links in each section to dive deeper.

Working with Data
-----------------

ABL-Package assumes data to be in the form of ``(X, gt_pseudo_label, Y)``, where ``X`` is the input of the machine learning model,
``gt_pseudo_label`` is the ground-truth label of each element in ``X``, and ``Y`` is the ground-truth reasoning result of each instance in ``X``. Note that ``gt_pseudo_label`` is used only to evaluate the performance of the machine learning part, not to train the model. If the elements in ``X`` are unlabeled, ``gt_pseudo_label`` can be ``None``.
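
To make the format concrete, the sketch below builds three MNIST Addition instances by hand. The strings such as ``"img_7"`` are hypothetical placeholders for image tensors; the loop checks that each ``Y`` value equals the sum of the corresponding ground-truth pseudo-labels.

.. code:: python

   # Three MNIST Addition instances in (X, gt_pseudo_label, Y) form.
   # "img_7", "img_2", ... are hypothetical placeholders for image tensors.
   X = [["img_7", "img_2"], ["img_0", "img_5"], ["img_9", "img_9"]]
   gt_pseudo_label = [[7, 2], [0, 5], [9, 9]]  # could be None if the digits were unlabeled
   Y = [9, 5, 18]                              # one reasoning result per instance

   # Each instance's reasoning result is the sum of its ground-truth pseudo-labels.
   for x, labels, y in zip(X, gt_pseudo_label, Y):
       assert len(x) == len(labels)
       assert sum(labels) == y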

In the MNIST Addition task, the data are loaded as follows:

.. code:: python

   from examples.mnist_add.datasets.get_mnist_add import get_mnist_add

   # train_data and test_data are both tuples consisting of X, gt_pseudo_label and Y.
   # If get_pseudo_label is False, gt_pseudo_label will be None.
   train_data = get_mnist_add(train=True, get_pseudo_label=True)
   test_data = get_mnist_add(train=False, get_pseudo_label=True)

ABL-Package assumes ``X`` to be of type ``List[List[Any]]``, ``gt_pseudo_label`` to be either ``None`` or of type ``List[List[Any]]``, and ``Y`` to be of type ``List[Any]``. The following code shows the structure of the dataset used in MNIST Addition.

.. code:: python

   # Recursively replace every non-list element with its type name
   def describe_structure(lst):
       if not isinstance(lst, list):
           return type(lst).__name__
       return [describe_structure(item) for item in lst]

   X, gt_pseudo_label, Y = train_data

   print(f"Length of X List[List[Any]]: {len(X)}")
   print(f"Length of gt_pseudo_label List[List[Any]]: {len(gt_pseudo_label)}")
   print(f"Length of Y List[Any]: {len(Y)}\n")

   structure_X = describe_structure(X[:3])
   print(f"Structure of X: {structure_X}")
   structure_gt_pseudo_label = describe_structure(gt_pseudo_label[:3])
   print(f"Structure of gt_pseudo_label: {structure_gt_pseudo_label}")
   structure_Y = describe_structure(Y[:3])
   print(f"Structure of Y: {structure_Y}\n")

   print(f"Shape of X [C, H, W]: {X[0][0].shape}")

Out:
   .. code-block:: none
      :class: code-out

      Length of X List[List[Any]]: 30000
      Length of gt_pseudo_label List[List[Any]]: 30000
      Length of Y List[Any]: 30000

      Structure of X: [['Tensor', 'Tensor'], ['Tensor', 'Tensor'], ['Tensor', 'Tensor']]
      Structure of gt_pseudo_label: [['int', 'int'], ['int', 'int'], ['int', 'int']]
      Structure of Y: ['int', 'int', 'int']

      Shape of X [C, H, W]: torch.Size([1, 28, 28])
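
When preparing your own data, it can help to check these assumptions before training. The function ``check_abl_data`` below is a hypothetical helper, not part of ABL-Package; it verifies only the nesting and length constraints stated above, not the element types.

.. code:: python

   from typing import Any, List, Optional


   def check_abl_data(X: List[List[Any]],
                      gt_pseudo_label: Optional[List[List[Any]]],
                      Y: List[Any]) -> None:
       """Hypothetical helper (not part of ABL-Package): raise ValueError on format errors."""
       if not isinstance(X, list) or not all(isinstance(x, list) for x in X):
           raise ValueError("X must be a list of lists")
       if not isinstance(Y, list) or len(Y) != len(X):
           raise ValueError("Y must be a list with one entry per instance in X")
       if gt_pseudo_label is not None:
           if not isinstance(gt_pseudo_label, list) or len(gt_pseudo_label) != len(X):
               raise ValueError("gt_pseudo_label must have one entry per instance in X")
           for x, labels in zip(X, gt_pseudo_label):
               if not isinstance(labels, list) or len(labels) != len(x):
                   raise ValueError("each gt_pseudo_label entry must label every element of its X entry")


   # Both calls pass silently; the strings stand in for image tensors.
   check_abl_data([["img_a", "img_b"]], [[3, 4]], [7])
   check_abl_data([["img_a", "img_b"]], None, [7])

Given the output above (equal lengths and two labels per pair), calling ``check_abl_data(*train_data)`` on the MNIST Addition data should pass as well.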

ABL-Package provides several dataset classes for different purposes, including ``ClassificationDataset``, ``RegressionDataset``, and ``PredictionDataset``. However, it is not necessary to encapsulate data into these specific classes; we only need to structure our datasets in the formats described above.

Read more about `preparing datasets <Datasets.html>`_.

Building the Learning Part
--------------------------

The learning part is constructed by first defining a base machine learning model and then wrapping it into an instance of the ``ABLModel`` class.

ABL-Package is flexible about the base model: it can be any machine learning model that follows the scikit-learn style, i.e., implements the ``fit`` and ``predict`` methods, or a PyTorch-based neural network, provided it defines its architecture and implements the ``forward`` method.
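
To illustrate what "scikit-learn style" means here, the toy classifier below implements nothing but ``fit`` (returning ``self``, as scikit-learn estimators do) and ``predict``. ``ToyCentroidClassifier`` is a hypothetical stand-in written in plain Python, not a model shipped with ABL-Package.

.. code:: python

   import math


   class ToyCentroidClassifier:
       """Hypothetical scikit-learn style classifier: one centroid per class."""

       def fit(self, X, y):
           # Group feature vectors by label and average each group into a centroid.
           groups = {}
           for features, label in zip(X, y):
               groups.setdefault(label, []).append(features)
           self.centroids_ = {
               label: [sum(col) / len(rows) for col in zip(*rows)]
               for label, rows in groups.items()
           }
           return self  # scikit-learn convention: fit returns the estimator

       def predict(self, X):
           # Assign each vector the label of its nearest centroid (Euclidean distance).
           return [
               min(self.centroids_, key=lambda lab: math.dist(features, self.centroids_[lab]))
               for features in X
           ]


   clf = ToyCentroidClassifier().fit([[0.0, 0.0], [0.2, 0.1], [5.0, 5.0]], [0, 0, 1])
   print(clf.predict([[0.1, 0.0], [4.0, 6.0]]))  # [0, 1]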

In the MNIST Addition example, we build a simple LeNet5 network as the base model.

.. code:: python

   from examples.models.nn import LeNet5

   # The number of pseudo-labels is 10
   cls = LeNet5(num_classes=10)

To facilitate uniform processing, ABL-Package provides the ``BasicNN`` class to convert PyTorch-based neural networks into a format similar to scikit-learn models. To construct a ``BasicNN`` instance, we also need to define a loss function, an optimizer, and a device in addition to the network.

.. code:: python

   import torch
   from abl.learning import BasicNN

   loss_fn = torch.nn.CrossEntropyLoss()
   optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
   base_model = BasicNN(cls, loss_fn, optimizer, device)

The resulting ``base_model`` exposes scikit-learn style ``predict`` and ``predict_proba`` methods that take a list of instances. For example, with a batch of 32 random tensors shaped like MNIST images:

.. code:: python

   pred_idx = base_model.predict(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])
   print(f"Shape of pred_idx : {pred_idx.shape}")
   pred_prob = base_model.predict_proba(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])
   print(f"Shape of pred_prob : {pred_prob.shape}")

Out:
   .. code-block:: none
      :class: code-out

      Shape of pred_idx : (32,)
      Shape of pred_prob : (32, 10)
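
The two shapes reflect one class index per input in ``pred_idx`` and one probability vector over the 10 pseudo-labels per input in ``pred_prob``. The sketch below reads class indices off a hand-written probability matrix by taking the position of the largest probability in each row; that ``predict`` returns exactly this argmax is our assumption, not something stated above.

.. code:: python

   # Two hypothetical rows of predict_proba-style output over 10 pseudo-labels.
   example_prob = [
       [0.01, 0.02, 0.80, 0.03, 0.02, 0.02, 0.03, 0.03, 0.02, 0.02],
       [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.10, 0.50],
   ]

   # Class index = position of the highest probability in each row (argmax).
   example_idx = [max(range(len(row)), key=row.__getitem__) for row in example_prob]
   print(example_idx)  # [2, 9]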

Afterward, we wrap the scikit-learn style model, ``base_model``, into an instance of ``ABLModel``. This class serves as a unified wrapper for all base models, enabling the learning part to train, test, and predict on sample-level data, such as equations in the MNIST Addition task.

.. code:: python

   from abl.learning import ABLModel

   model = ABLModel(base_model)
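
The base model predicts one label per image, whereas a sample such as an equation contains several images. One way to picture the step between the two levels is to flatten samples into a list of instances, predict, and regroup the predictions by sample length. ``predict_samples`` and ``read_digit`` below are hypothetical helpers illustrating that idea; they are not ``ABLModel``'s actual code.

.. code:: python

   from typing import Any, Callable, List


   def predict_samples(predict_instances: Callable[[List[Any]], List[Any]],
                       X: List[List[Any]]) -> List[List[Any]]:
       """Hypothetical helper: apply an instance-level predictor to sample-level data."""
       flat = [x for sample in X for x in sample]  # flatten samples into instances
       flat_pred = predict_instances(flat)         # one prediction per instance
       grouped, start = [], 0
       for sample in X:                            # regroup by each sample's length
           grouped.append(list(flat_pred[start:start + len(sample)]))
           start += len(sample)
       return grouped


   def read_digit(instances):
       # Stand-in instance-level predictor: parses the digit out of a placeholder name.
       return [int(name.split("_")[1]) for name in instances]


   print(predict_samples(read_digit, [["img_3", "img_4"], ["img_8", "img_1"]]))
   # [[3, 4], [8, 1]]

Regrouping by length means samples do not all need the same number of elements.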

Read more about `building the learning part <Learning.html>`_.

Building the Reasoning Part
---------------------------

To build the reasoning part, we first define a knowledge base by
creating a subclass of ``KBBase``, which specifies how to map a pseudo-label
sample to its reasoning result. In the subclass, we initialize the
``pseudo_label_list`` parameter and override the ``logic_forward``
method, which specifies how to perform (deductive) reasoning.

.. code:: python

   from abl.reasoning import KBBase

   class AddKB(KBBase):
       def __init__(self, pseudo_label_list=list(range(10))):
           super().__init__(pseudo_label_list)

       # Deductive reasoning: the result of a sample is the sum of its pseudo-labels
       def logic_forward(self, nums):
           return sum(nums)

   kb = AddKB(pseudo_label_list=list(range(10)))
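
Because ``logic_forward`` maps a pseudo-label sample to its reasoning result, a knowledge base can also be searched in the other direction: find every sample whose result equals a given ``Y``. The plain-Python sketch below does this by brute force for two-digit addition. It illustrates the idea only; ``consistent_candidates`` is a hypothetical helper and not part of ``KBBase``.

.. code:: python

   from itertools import product

   pseudo_label_list = list(range(10))


   def logic_forward(nums):
       # Same rule as AddKB: the reasoning result is the sum of the digits.
       return sum(nums)


   def consistent_candidates(y, length=2):
       """Hypothetical helper: all pseudo-label tuples of this length whose result is y."""
       return [c for c in product(pseudo_label_list, repeat=length) if logic_forward(c) == y]


   print(consistent_candidates(3))        # [(0, 3), (1, 2), (2, 1), (3, 0)]
   print(len(consistent_candidates(18)))  # 1, only (9, 9)

Brute-force enumeration grows as 10 to the power of the sample length, which is fine for two digits but would need a smarter search for longer samples.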

Then, we create a reasoner by instantiating the
``Reasoner`` class and passing the knowledge base as a parameter.
The reasoner can be used to minimize inconsistencies between the
knowledge base and the predictions of the learning part.

.. code:: python

   from abl.reasoning import Reasoner

   reasoner = Reasoner(kb)
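
To give a feel for what minimizing inconsistencies means, the sketch below revises a predicted pseudo-label sample into the consistent candidate that changes the fewest positions (Hamming distance). This criterion is an assumption chosen for illustration; the actual ``Reasoner`` may use other distance measures, for example ones that take prediction confidence into account. ``abduce`` is a hypothetical name.

.. code:: python

   from itertools import product


   def abduce(pred, y, labels=range(10)):
       """Hypothetical sketch: return the sample with sum == y that differs from the
       predicted pseudo-labels in the fewest positions (ties: first in enumeration order)."""
       candidates = [c for c in product(labels, repeat=len(pred)) if sum(c) == y]
       if not candidates:
           return list(pred)  # nothing is consistent: keep the prediction unchanged
       best = min(candidates, key=lambda c: sum(a != b for a, b in zip(c, pred)))
       return list(best)


   # The model read "7 + 3", but the ground-truth sum is 16:
   print(abduce([7, 3], 16))  # [7, 9]: changing one digit restores consistency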

Read more about `building the reasoning part <Reasoning.html>`_.

Building Evaluation Metrics
---------------------------

ABL-Package provides two basic metrics, namely ``SymbolMetric`` and ``ReasoningMetric``, which are used to evaluate the accuracy of the machine learning model's predictions and the accuracy of the ``logic_forward`` results, respectively.

.. code:: python

   from abl.evaluation import ReasoningMetric, SymbolMetric

   metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]
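
For intuition, the sketch below computes two accuracies of the kind these metrics report: the fraction of individual pseudo-labels predicted correctly, and the fraction of samples whose predicted pseudo-labels yield the correct reasoning result. The function names and exact definitions are assumptions for illustration; see the evaluation page for how ``SymbolMetric`` and ``ReasoningMetric`` are actually computed.

.. code:: python

   def symbol_accuracy(pred_pseudo_label, gt_pseudo_label):
       """Fraction of individual pseudo-labels that match the ground truth."""
       pairs = [(p, g) for ps, gs in zip(pred_pseudo_label, gt_pseudo_label)
                for p, g in zip(ps, gs)]
       return sum(p == g for p, g in pairs) / len(pairs)


   def reasoning_accuracy(pred_pseudo_label, Y, logic_forward=sum):
       """Fraction of samples whose predicted pseudo-labels reason to the correct result."""
       hits = sum(logic_forward(ps) == y for ps, y in zip(pred_pseudo_label, Y))
       return hits / len(Y)


   toy_pred = [[7, 2], [0, 6], [3, 5]]
   toy_gt = [[7, 2], [0, 5], [5, 3]]
   toy_Y = [9, 5, 8]
   print(symbol_accuracy(toy_pred, toy_gt))    # 0.5 (3 of 6 digits correct)
   print(reasoning_accuracy(toy_pred, toy_Y))  # 0.666... (3 + 5 still sums to 8)

The third sample shows why the two numbers can differ: both digits are wrong, yet their sum is right, so it counts toward reasoning accuracy but not toward symbol accuracy.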

Read more about `building evaluation metrics <Evaluation.html>`_.

Bridging Learning and Reasoning
-------------------------------

Now, we use ``SimpleBridge`` to combine learning and reasoning in a unified model.

.. code:: python

   from abl.bridge import SimpleBridge

   bridge = SimpleBridge(model, reasoner, metric_list)
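
Conceptually, each round of abductive learning predicts pseudo-labels with the learning part, revises them with the reasoning part so that they agree with ``Y``, and retrains the learning part on the revised labels. The toy below runs one such round with a memorizing "model" over placeholder image names and a brute-force abduction step. All names in it (``MemoModel``, ``abduce``, ``train_one_round``) are hypothetical, and this is not how ``SimpleBridge`` is implemented.

.. code:: python

   from collections import Counter, defaultdict
   from itertools import product


   def abduce(pred, y, labels=range(10)):
       # Revise pred to the sample with sum == y that changes the fewest digits.
       candidates = [c for c in product(labels, repeat=len(pred)) if sum(c) == y]
       if not candidates:
           return list(pred)
       return list(min(candidates, key=lambda c: sum(a != b for a, b in zip(c, pred))))


   class MemoModel:
       """Toy learner: predicts the label most often seen for an input, or 0 if unseen."""

       def __init__(self):
           self.counts = defaultdict(Counter)

       def fit(self, xs, labels):
           for x, label in zip(xs, labels):
               self.counts[x][label] += 1
           return self

       def predict(self, xs):
           return [self.counts[x].most_common(1)[0][0] if self.counts[x] else 0 for x in xs]


   def train_one_round(model, X, Y):
       pred = [model.predict(sample) for sample in X]     # 1. learning part predicts
       revised = [abduce(p, y) for p, y in zip(pred, Y)]  # 2. reasoning part revises
       flat_x = [x for sample in X for x in sample]
       flat_labels = [label for sample in revised for label in sample]
       model.fit(flat_x, flat_labels)                     # 3. retrain on revised labels
       return revised


   toy_model = MemoModel()
   toy_X = [["img_nine", "img_nine"], ["img_zero", "img_zero"]]
   print(train_one_round(toy_model, toy_X, [18, 0]))  # [[9, 9], [0, 0]]
   print(toy_model.predict(["img_nine"]))             # [9]

Although the toy model starts out predicting 0 for everything, the sums 18 and 0 each admit only one consistent pair, so a single round is enough for it to learn both placeholder images.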

Finally, we proceed with training and testing.

.. code:: python

   bridge.train(train_data, loops=5, segment_size=1/3)
   bridge.test(test_data)

The training log should look similar to this:

.. code-block:: none
   :class: code-out

   abl - INFO - Abductive Learning on the MNIST Add example.
   abl - INFO - loop(train) [1/5] segment(train) [1/3]
   abl - INFO - model loss: 1.91761
   abl - INFO - loop(train) [1/5] segment(train) [2/3]
   abl - INFO - model loss: 1.59485
   abl - INFO - loop(train) [1/5] segment(train) [3/3]
   abl - INFO - model loss: 1.33183
   abl - INFO - Evaluation start: loop(val) [1]
   abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.450 mnist_add/reasoning_accuracy: 0.237
   abl - INFO - Saving model: loop(save) [1]
   abl - INFO - Checkpoints will be saved to results/work_dir/weights/model_checkpoint_loop_1.pth
   abl - INFO - loop(train) [2/5] segment(train) [1/3]
   abl - INFO - model loss: 1.00664
   abl - INFO - loop(train) [2/5] segment(train) [2/3]
   abl - INFO - model loss: 0.52233
   abl - INFO - loop(train) [2/5] segment(train) [3/3]
   abl - INFO - model loss: 0.11282
   abl - INFO - Evaluation start: loop(val) [2]
   abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.976 mnist_add/reasoning_accuracy: 0.954
   abl - INFO - Saving model: loop(save) [2]
   abl - INFO - Checkpoints will be saved to results/work_dir/weights/model_checkpoint_loop_2.pth
   ...
   abl - INFO - loop(train) [5/5] segment(train) [1/3]
   abl - INFO - model loss: 0.04030
   abl - INFO - loop(train) [5/5] segment(train) [2/3]
   abl - INFO - model loss: 0.03859
   abl - INFO - loop(train) [5/5] segment(train) [3/3]
   abl - INFO - model loss: 0.03423
   abl - INFO - Evaluation start: loop(val) [5]
   abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.992 mnist_add/reasoning_accuracy: 0.984
   abl - INFO - Saving model: loop(save) [5]
   abl - INFO - Checkpoints will be saved to results/work_dir/weights/model_checkpoint_loop_5.pth
   abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.975
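
In this log, ``loops=5`` and ``segment_size=1/3`` correspond to the ``loop(train) [i/5]`` and ``segment(train) [j/3]`` counters: each loop appears to process the training data in three segments. Assuming a fractional ``segment_size`` denotes a fraction of the training set, the hypothetical helper ``segment_bounds`` below reproduces that split for the 30,000 training instances; ABL-Package's actual rounding and edge-case handling may differ.

.. code:: python

   def segment_bounds(n, segment_size):
       """Hypothetical helper: (start, end) index pairs for consecutive segments.
       A float in (0, 1] is treated as a fraction of n; an int as an absolute size."""
       if isinstance(segment_size, float):
           if not 0 < segment_size <= 1:
               raise ValueError("fractional segment_size must be in (0, 1]")
           segment_size = max(1, round(n * segment_size))
       return [(start, min(start + segment_size, n)) for start in range(0, n, segment_size)]


   print(segment_bounds(30000, 1/3))  # [(0, 10000), (10000, 20000), (20000, 30000)]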

Read more about `bridging machine learning and reasoning <Bridge.html>`_.
