`Learn the Basics <Basics.html>`_ ||
**Quick Start** ||
`Dataset & Data Structure <Datasets.html>`_ ||
`Learning Part <Learning.html>`_ ||
`Reasoning Part <Reasoning.html>`_ ||
`Evaluation Metrics <Evaluation.html>`_ ||
`Bridge <Bridge.html>`_

Quick Start
===========

We use the MNIST Addition task as a quick start example. In this task, pairs of MNIST handwritten images and their sums are given, along with a domain knowledge base that contains information on how to perform addition. The objective is to take a pair of handwritten images as input and accurately determine their sum. Follow the links in each section to dive deeper.

Working with Data
-----------------

ABL Kit requires data in the format of ``(X, gt_pseudo_label, Y)``, where ``X`` is a list of input examples, each containing instances, ``gt_pseudo_label`` contains the ground-truth pseudo-labels for each example in ``X``, and ``Y`` is the ground-truth reasoning result of each example in ``X``. Note that ``gt_pseudo_label`` is used only to evaluate the machine learning model's performance, not to train it.

In the MNIST Addition task, the data loading looks like this:

.. code:: python

    # The 'datasets' module below is located in 'examples/mnist_add/'
    from datasets import get_dataset

    # train_data and test_data are tuples in the format of (X, gt_pseudo_label, Y)
    train_data = get_dataset(train=True)
    test_data = get_dataset(train=False)

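To make the ``(X, gt_pseudo_label, Y)`` layout concrete, here is a toy sketch with two hand-written examples. The strings such as ``"img_of_7"`` are hypothetical placeholders for real image tensors, and ``toy_data`` is an illustrative name, not something returned by ``get_dataset``.

```python
# Toy stand-in for MNIST Addition data; strings replace real image tensors.
X = [["img_of_7", "img_of_5"], ["img_of_0", "img_of_3"]]  # two examples, two instances each
gt_pseudo_label = [[7, 5], [0, 3]]  # ground-truth digit of every instance
Y = [12, 3]  # ground-truth sum of every example

toy_data = (X, gt_pseudo_label, Y)

# The three lists are aligned: entry i of each list describes example i.
for images, digits, total in zip(*toy_data):
    assert len(images) == len(digits)
    assert sum(digits) == total
```

Each example carries its own list of instances, so examples of different lengths would fit the same layout.
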
Read more about `preparing datasets <Datasets.html>`_.

Building the Learning Part
--------------------------

The learning part is constructed by first defining a base model for machine learning. ABL Kit offers considerable flexibility: it supports any base model that follows the scikit-learn style (that is, one that implements ``fit`` and ``predict`` methods), as well as any PyTorch-based neural network (one that defines its architecture and implements a ``forward`` method).

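As a rough illustration of what "scikit-learn style" means, the sketch below defines a hypothetical ``MajorityClassifier`` that exposes ``fit`` and ``predict``. It only shows the shape of the interface; it is not part of ABL Kit, and a real base model may need further methods.

```python
from collections import Counter


class MajorityClassifier:
    """Hypothetical toy model: always predicts the most frequent training label."""

    def fit(self, X, y):
        # Remember the label that occurs most often in the training targets.
        self.majority_ = Counter(y).most_common(1)[0][0]
        return self

    def predict(self, X):
        # Ignore the features and return the stored label for every input row.
        return [self.majority_ for _ in X]


toy_model = MajorityClassifier().fit([[0.1], [0.4], [0.9]], [1, 1, 0])
print(toy_model.predict([[0.5], [0.7]]))  # [1, 1]
```
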
For the MNIST Addition task, we build a simple LeNet5 network as the base model.

.. code:: python

    # The 'models' module below is located in 'examples/mnist_add/'
    from models.nn import LeNet5

    cls = LeNet5(num_classes=10)

To facilitate uniform processing, ABL Kit provides the ``BasicNN`` class to convert a PyTorch-based neural network into a format compatible with scikit-learn models. To construct a ``BasicNN`` instance, aside from the network itself, we also need to define a loss function, an optimizer, and the computing device.

.. code:: python

    import torch
    from ablkit.learning import BasicNN

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_model = BasicNN(model=cls, loss_fn=loss_fn, optimizer=optimizer, device=device)

The base model built above is trained to make predictions on instance-level data (e.g., a single image), while ABL deals with example-level data. To bridge this gap, we wrap ``base_model`` in an instance of ``ABLModel``. This class serves as a unified wrapper for base models, enabling the learning part to train, test, and predict on example-level data (e.g., the images that make up an equation).

.. code:: python

    from ablkit.learning import ABLModel

    model = ABLModel(base_model)

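The gap between instance-level and example-level data can be pictured as a flatten-then-regroup step. The sketch below is a conceptual illustration in plain Python, not ``ABLModel``'s actual implementation; ``flatten``, ``regroup``, and ``predict_instance`` are hypothetical helpers.

```python
def flatten(examples):
    """Concatenate the instances of all examples into one flat list."""
    return [instance for example in examples for instance in example]


def regroup(flat_items, examples):
    """Split a flat list back into groups matching the sizes of `examples`."""
    grouped, start = [], 0
    for example in examples:
        grouped.append(flat_items[start:start + len(example)])
        start += len(example)
    return grouped


def predict_instance(instance):
    # Hypothetical stand-in for an instance-level model: "img_of_7" -> 7
    return int(instance.rsplit("_", 1)[1])


examples = [["img_of_7", "img_of_5"], ["img_of_0", "img_of_3"]]
flat_predictions = [predict_instance(i) for i in flatten(examples)]
example_predictions = regroup(flat_predictions, examples)
print(example_predictions)  # [[7, 5], [0, 3]]
```
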
Read more about `building the learning part <Learning.html>`_.

Building the Reasoning Part
---------------------------

To build the reasoning part, we first define a knowledge base by creating a subclass of ``KBBase``. In the subclass, we initialize the ``pseudo_label_list`` parameter and override the ``logic_forward`` method, which specifies how to perform (deductive) reasoning that maps the pseudo-labels of an example to the corresponding reasoning result. For the MNIST Addition task, ``logic_forward`` computes the sum.

.. code:: python

    from ablkit.reasoning import KBBase

    class AddKB(KBBase):
        def __init__(self, pseudo_label_list=list(range(10))):
            super().__init__(pseudo_label_list)

        def logic_forward(self, nums):
            return sum(nums)

    kb = AddKB()

Next, we create a reasoner by instantiating the ``Reasoner`` class, passing the knowledge base as a parameter. Because abductive reasoning is nondeterministic, multiple candidate pseudo-labels may be compatible with the knowledge base. In such cases, the reasoner minimizes inconsistency and returns the pseudo-labels with the highest consistency.

.. code:: python

    from ablkit.reasoning import Reasoner

    reasoner = Reasoner(kb)

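To see why several candidates can be compatible with the knowledge base, consider a brute-force sketch of abduction for the addition task. This is an illustration in plain Python, not how ``Reasoner`` is implemented. The ``abduce`` helper is hypothetical, and Hamming distance to the model's prediction is assumed here as one possible consistency measure.

```python
from itertools import product


def logic_forward(nums):
    # Same rule as the knowledge base above: the reasoning result is the sum.
    return sum(nums)


def abduce(predicted, target, pseudo_label_list=range(10)):
    """Return the candidate consistent with `target` that is closest to `predicted`."""
    candidates = [
        list(c)
        for c in product(pseudo_label_list, repeat=len(predicted))
        if logic_forward(c) == target
    ]

    def hamming(candidate):
        # Number of positions where the candidate disagrees with the prediction.
        return sum(a != b for a, b in zip(candidate, predicted))

    return min(candidates, key=hamming) if candidates else None


# The model predicted [7, 3], but the known sum is 12.
print(abduce([7, 3], 12))  # [7, 5]
```

Seven digit pairs sum to 12. Both ``[7, 5]`` and ``[9, 3]`` differ from the prediction in one position, and this sketch simply returns the first one found.
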
Read more about `building the reasoning part <Reasoning.html>`_.

Building Evaluation Metrics
---------------------------

ABL Kit provides two basic metrics, ``SymbolAccuracy`` and ``ReasoningMetric``, which evaluate the accuracy of the machine learning model's predictions and the accuracy of the ``logic_forward`` results, respectively.

.. code:: python

    from ablkit.data.evaluation import ReasoningMetric, SymbolAccuracy

    metric_list = [SymbolAccuracy(), ReasoningMetric(kb=kb)]

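The difference between the two kinds of accuracy can be sketched in plain Python. The functions below are hypothetical illustrations of the idea, not the ``SymbolAccuracy`` or ``ReasoningMetric`` implementations.

```python
def logic_forward(nums):
    # Reasoning rule of the addition task: sum the pseudo-labels.
    return sum(nums)


def symbol_accuracy(pred_labels, gt_labels):
    """Fraction of individual pseudo-labels that match the ground truth."""
    pairs = [(p, g) for ps, gs in zip(pred_labels, gt_labels) for p, g in zip(ps, gs)]
    return sum(p == g for p, g in pairs) / len(pairs)


def reasoning_accuracy(pred_labels, Y):
    """Fraction of examples whose reasoning result matches the ground truth."""
    return sum(logic_forward(ps) == y for ps, y in zip(pred_labels, Y)) / len(Y)


pred = [[7, 5], [0, 8], [4, 2]]
gt = [[7, 5], [0, 3], [2, 4]]
Y = [12, 3, 6]

print(symbol_accuracy(pred, gt))    # 0.5
print(reasoning_accuracy(pred, Y))  # two of three sums are correct
```

In the third example both digits are wrong, yet their sum is right, so the two metrics can disagree.
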
Read more about `building evaluation metrics <Evaluation.html>`_.

Bridging Learning and Reasoning
-------------------------------

Now, we use ``SimpleBridge`` to combine learning and reasoning in a unified ABL framework.

.. code:: python

    from ablkit.bridge import SimpleBridge

    bridge = SimpleBridge(model, reasoner, metric_list)

Finally, we proceed with training and testing.

.. code:: python

    bridge.train(train_data, loops=1, segment_size=0.01)
    bridge.test(test_data)

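This page does not say what ``segment_size=0.01`` means. The sketch below assumes that a fractional value is a fraction of the training examples handled per segment; this is an assumption, not confirmed behavior. ``split_into_segments`` is a hypothetical helper, and the dataset size of 1,000 is made up for illustration.

```python
def split_into_segments(num_examples, segment_size):
    """Yield (start, end) index pairs that cover `num_examples` items."""
    if isinstance(segment_size, float):
        # Assumption: a float in (0, 1) is a fraction of the dataset.
        segment_size = max(1, int(num_examples * segment_size))
    for start in range(0, num_examples, segment_size):
        yield start, min(start + segment_size, num_examples)


# Hypothetical dataset of 1,000 examples with segment_size=0.01.
segments = list(split_into_segments(1000, 0.01))
print(len(segments), segments[0], segments[-1])
```

Under this assumption, one loop over the data would visit 100 segments of 10 examples each.
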
Read more about `bridging machine learning and reasoning <Bridge.html>`_.

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.