`Learn the Basics <Basics.html>`_ ||
`Quick Start <Quick-Start.html>`_ ||
`Dataset & Data Structure <Datasets.html>`_ ||
**Learning Part** ||
`Reasoning Part <Reasoning.html>`_ ||
`Evaluation Metrics <Evaluation.html>`_ ||
`Bridge <Bridge.html>`_

Learning Part
=============

In this section, we will look at how to build the learning part.

In ABLkit, building the learning part involves two steps:

1. Build a machine learning base model used to make predictions on instance-level data.
2. Instantiate an ``ABLModel`` with the base model, which enables the learning part to process example-level data.

.. code:: python

    import sklearn.neighbors
    import torch
    import torchvision
    from ablkit.learning import BasicNN, ABLModel

Building a base model
---------------------

ABLkit allows the base model to take one of the following forms:

1. Any machine learning model that follows the scikit-learn style, i.e., a model that implements the ``fit`` and ``predict`` methods;
2. A PyTorch-based neural network, provided that it defines its architecture and implements the ``forward`` method.

For a scikit-learn model, we can use the model itself directly as the base model. For example, we can use a KNN classifier as our base model:

.. code:: python

    base_model = sklearn.neighbors.KNeighborsClassifier(n_neighbors=3)

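
Any object that exposes scikit-learn-style ``fit`` and ``predict`` methods can play this role, not only the estimators shipped with scikit-learn. As an illustration, the hypothetical ``MeanCentroidClassifier`` below (plain Python, not part of ABLkit or scikit-learn) implements only these two methods. It is a sketch that assumes ``fit(X, y)`` receives a list of feature vectors together with their labels and that ``predict(X)`` returns one label per vector.

.. code:: python

    from collections import defaultdict


    class MeanCentroidClassifier:
        """Hypothetical scikit-learn-style classifier: predicts the label of the nearest class mean."""

        def fit(self, X, y):
            # Accumulate per-class sums of the feature vectors and per-class counts.
            sums = defaultdict(lambda: [0.0] * len(X[0]))
            counts = defaultdict(int)
            for features, label in zip(X, y):
                sums[label] = [s + f for s, f in zip(sums[label], features)]
                counts[label] += 1
            # Represent each class by the mean of its training vectors.
            self.centroids_ = {
                label: [s / counts[label] for s in total] for label, total in sums.items()
            }
            return self  # scikit-learn convention: fit returns the estimator itself

        def predict(self, X):
            # Squared Euclidean distance between two vectors.
            def sq_dist(a, b):
                return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

            # Assign each vector to the class whose centroid is closest.
            return [
                min(self.centroids_, key=lambda label: sq_dist(x, self.centroids_[label]))
                for x in X
            ]


    base_model = MeanCentroidClassifier()
    base_model.fit([[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [4.8, 5.1]], [0, 0, 1, 1])
    print(base_model.predict([[0.1, 0.0], [5.2, 4.9]]))  # [0, 1]

Because it exposes the same ``fit``/``predict`` interface as ``KNeighborsClassifier``, such a class matches the first of the two forms listed above.
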
For a PyTorch-based neural network, we need to wrap it in a ``BasicNN`` object to create a base model. For example, we can use a pre-trained ResNet-18 as our base model:

.. code:: python

    # Load a PyTorch-based neural network with ImageNet-pretrained weights
    cls = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
    # The loss function and optimizer are used for training
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cls.parameters())
    # Wrap the network, loss function and optimizer into a single base model
    base_model = BasicNN(cls, loss_fn, optimizer)

BasicNN
^^^^^^^

``BasicNN`` is a wrapper class for PyTorch-based neural networks that enables them to work like scikit-learn models. It bundles the neural network, loss function, optimizer, and other components into a single object that can be used as a base model.

Besides the ``fit`` and ``predict`` methods required to instantiate an ``ABLModel``, ``BasicNN`` also implements the following methods:

+-------------------------------+------------------------------------------+
| Method                        | Function                                 |
+===============================+==========================================+
| ``train_epoch(data_loader)``  | Train the neural network for one epoch.  |
+-------------------------------+------------------------------------------+
| ``predict_proba(X)``          | Predict the class probabilities of ``X``.|
+-------------------------------+------------------------------------------+
| ``score(X, y)``               | Calculate the accuracy of the model on   |
|                               | test data.                               |
+-------------------------------+------------------------------------------+
| ``save(epoch_id, save_path)`` | Save the model.                          |
+-------------------------------+------------------------------------------+
| ``load(load_path)``           | Load the model.                          |
+-------------------------------+------------------------------------------+

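
The table does not spell out how ``predict_proba``, ``predict`` and ``score`` relate to one another. Under the usual classifier convention, and assuming the network outputs one row of raw scores (logits) per instance, as a classifier trained with ``CrossEntropyLoss`` does, the relationship can be sketched in plain Python: class probabilities are a softmax over each row, the predicted label is the most probable class, and the score is the fraction of correct predictions. The helper functions below are hypothetical and are not ``BasicNN``'s actual implementation.

.. code:: python

    import math


    def predict_proba(logits):
        """Softmax over each row of raw network outputs (hypothetical helper)."""
        probs = []
        for row in logits:
            m = max(row)  # subtract the row maximum for numerical stability
            exps = [math.exp(v - m) for v in row]
            total = sum(exps)
            probs.append([e / total for e in exps])
        return probs


    def predict(logits):
        """Predicted label: index of the most probable class in each row."""
        return [max(range(len(p)), key=p.__getitem__) for p in predict_proba(logits)]


    def score(logits, y):
        """Accuracy: fraction of instances whose predicted label equals the true label."""
        preds = predict(logits)
        return sum(p == t for p, t in zip(preds, y)) / len(y)


    logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.5, 0.0]]
    print(predict(logits))           # [0, 2, 1]
    print(score(logits, [0, 2, 0]))  # 0.6666666666666666

Since softmax is monotonic, taking the argmax of the probabilities yields the same label as taking the argmax of the logits; the probabilities matter when a downstream component needs a confidence value for each class.
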
Instantiating an ABLModel
-------------------------

Typically, a base model is trained to make predictions on instance-level data and cannot directly process example-level data, which makes it unsuitable on its own for most neural-symbolic tasks. ABLkit provides ``ABLModel`` to solve this problem. This class serves as a unified wrapper for all base models and enables the learning part to train, test, and predict on example-level data.

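
The core idea of such a wrapper can be sketched in plain Python: flatten the instances of all examples into one list, let the instance-level base model predict on that list, and regroup the predictions by example. The ``ExampleLevelWrapper`` and ``ParityModel`` classes below are hypothetical; they only illustrate this flatten-and-regroup step, are not ``ABLModel``'s API, and assume that each example is a list of instances and that the base model exposes a scikit-learn-style ``predict``.

.. code:: python

    class ExampleLevelWrapper:
        """Hypothetical wrapper: lifts an instance-level model to example-level data."""

        def __init__(self, base_model):
            self.base_model = base_model

        def predict(self, examples):
            # Flatten: concatenate the instances of all examples, remembering each example's length.
            lengths = [len(example) for example in examples]
            instances = [inst for example in examples for inst in example]
            flat_preds = self.base_model.predict(instances)
            # Regroup: split the flat predictions back into one list per example.
            grouped, start = [], 0
            for n in lengths:
                grouped.append(list(flat_preds[start:start + n]))
                start += n
            return grouped


    class ParityModel:
        """Toy instance-level base model: labels each integer instance by its parity."""

        def predict(self, X):
            return [x % 2 for x in X]


    model = ExampleLevelWrapper(ParityModel())
    # Two examples with three and two instances respectively.
    print(model.predict([[3, 4, 7], [2, 8]]))  # [[1, 0, 1], [0, 0]]

Recording the length of each example before flattening is what allows examples with different numbers of instances to be regrouped correctly.
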
Generally, we can simply instantiate an ``ABLModel`` as follows:

.. code:: python

    # Instantiate an ABLModel
    model = ABLModel(base_model)
