

`Learn the Basics <Basics.html>`_ ||
`Quick Start <Quick-Start.html>`_ ||
`Dataset & Data Structure <Datasets.html>`_ ||
`Learning Part <Learning.html>`_ ||
`Reasoning Part <Reasoning.html>`_ ||
`Evaluation Metrics <Evaluation.html>`_ ||
**Bridge**

Bridge
======

In this section, we will look at how to bridge the learning and reasoning parts to train the model, which is the fundamental idea of Abductive Learning. ABL-Package implements a set of bridge classes to achieve this.

.. code:: python

    from abl.bridge import BaseBridge, SimpleBridge

``BaseBridge`` is an abstract class with the following initialization parameters:

- ``model`` is an object of type ``ABLModel``. The learning part is wrapped in this object.
- ``reasoner`` is an object of type ``Reasoner``. The reasoning part is wrapped in this object.

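
To illustrate what "abstract class" means in practice, here is a minimal sketch built on Python's standard ``abc`` module. ``SketchBridge`` and ``EchoBridge`` are hypothetical names, only ``train`` and ``test`` are declared abstract to keep it short, and it makes no claim about how ``BaseBridge`` is actually written.

.. code-block:: python

    from abc import ABC, abstractmethod
    from typing import Any


    class SketchBridge(ABC):
        """Hypothetical abstract bridge; a sketch, not ABL-Package's BaseBridge."""

        def __init__(self, model: Any, reasoner: Any) -> None:
            self.model = model        # in ABL-Package this would be an ABLModel
            self.reasoner = reasoner  # in ABL-Package this would be a Reasoner

        @abstractmethod
        def train(self, train_data: Any) -> Any:
            """Subclasses must implement the training procedure."""

        @abstractmethod
        def test(self, test_data: Any) -> Any:
            """Subclasses must implement evaluation."""


    class EchoBridge(SketchBridge):
        """Trivial concrete subclass that only shows overriding works."""

        def train(self, train_data: Any) -> Any:
            return f"training with {self.model} and {self.reasoner}"

        def test(self, test_data: Any) -> Any:
            return "testing"


    print(EchoBridge(model="model", reasoner="reasoner").train(None))
    # -> training with model and reasoner

Calling ``SketchBridge(model=None, reasoner=None)`` directly raises ``TypeError``, because Python refuses to instantiate a class that still has unimplemented abstract methods.
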
``BaseBridge`` has the following important methods that need to be overridden in subclasses:

+---------------------------------------+----------------------------------------------------+
| Method Signature                      | Description                                        |
+=======================================+====================================================+
| ``predict(data_samples)``             | Predicts class probabilities and indices           |
|                                       | for the given data samples.                        |
+---------------------------------------+----------------------------------------------------+
| ``abduce_pseudo_label(data_samples)`` | Abduces pseudo labels for the given data samples.  |
+---------------------------------------+----------------------------------------------------+
| ``idx_to_pseudo_label(data_samples)`` | Converts indices to pseudo labels using            |
|                                       | the provided or default mapping.                   |
+---------------------------------------+----------------------------------------------------+
| ``pseudo_label_to_idx(data_samples)`` | Converts pseudo labels to indices                  |
|                                       | using the provided or default remapping.           |
+---------------------------------------+----------------------------------------------------+
| ``train(train_data)``                 | Trains the model.                                  |
+---------------------------------------+----------------------------------------------------+
| ``test(test_data)``                   | Tests the model.                                   |
+---------------------------------------+----------------------------------------------------+

where ``train_data`` and ``test_data`` both take the form ``(X, gt_pseudo_label, Y)``. They are used to construct ``ListData`` instances, which are referred to as ``data_samples`` in the ``train`` and ``test`` methods, respectively. More details can be found in `preparing datasets <Datasets.html>`_.

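
As a concrete illustration of the ``(X, gt_pseudo_label, Y)`` form, the sketch below assumes a hypothetical MNIST-addition-style task in which each example is a pair of digit images labelled with their sum. Strings stand in for images, and ``to_records`` is a hypothetical, simplified stand-in for the conversion that ``data_preprocess`` performs into ``ListData``; it is not part of ABL-Package.

.. code-block:: python

    from typing import Any, Dict, List, Optional, Tuple

    # Hypothetical data: each example holds two "images" (strings stand in for
    # image tensors), their ground-truth digits, and the sum of those digits.
    X: List[List[Any]] = [["img_3", "img_5"], ["img_1", "img_2"]]
    gt_pseudo_label: Optional[List[List[Any]]] = [[3, 5], [1, 2]]
    Y: List[Any] = [8, 3]

    train_data: Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]] = (
        X,
        gt_pseudo_label,
        Y,
    )


    def to_records(X, gt_pseudo_label, Y) -> List[Dict[str, Any]]:
        """Group the three parallel lists into one record per example.

        ``gt_pseudo_label`` may be None when ground-truth pseudo labels are
        unavailable; every record then gets None in that field.
        """
        if gt_pseudo_label is None:
            gt_pseudo_label = [None] * len(X)
        if not (len(X) == len(gt_pseudo_label) == len(Y)):
            raise ValueError("X, gt_pseudo_label and Y must have the same length")
        return [
            {"X": x, "gt_pseudo_label": g, "Y": y}
            for x, g, y in zip(X, gt_pseudo_label, Y)
        ]


    records = to_records(*train_data)
    print(records[0])  # {'X': ['img_3', 'img_5'], 'gt_pseudo_label': [3, 5], 'Y': 8}

The three lists are parallel: position ``i`` in each of them describes the same example, which is why they must have equal lengths.
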
``SimpleBridge`` inherits from ``BaseBridge`` and provides a basic implementation. Besides ``model`` and ``reasoner``, ``SimpleBridge`` takes an extra initialization argument, ``metric_list``, which is used to evaluate model performance. Its training process runs several Abductive Learning loops, and each loop consists of the following five steps:

1. Predict class probabilities and indices for the given data samples.
2. Transform indices into pseudo labels.
3. Revise pseudo labels based on abductive reasoning.
4. Transform the revised pseudo labels into indices.
5. Train the model.

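
To make steps 2 to 4 concrete, here is a toy, self-contained sketch for the same kind of addition task. Everything in it is hypothetical: the identity index-to-label mapping, the plain-list data, and the brute-force ``abduce`` function, which picks the candidate consistent with ``Y`` that changes the fewest predicted labels and breaks ties by enumeration order. In ABL-Package this work is done by ``ABLModel`` and ``Reasoner``, which may use different strategies (for example, taking prediction confidence into account).

.. code-block:: python

    from itertools import product
    from typing import Dict, List

    # Hypothetical mapping between class indices (model outputs) and pseudo labels.
    idx_to_label: Dict[int, int] = {i: i for i in range(10)}  # digit classes 0-9
    label_to_idx: Dict[int, int] = {v: k for k, v in idx_to_label.items()}


    def abduce(pseudo_label: List[int], y: int) -> List[int]:
        """Return the candidate with sum(candidate) == y that differs from
        pseudo_label in the fewest positions; return pseudo_label unchanged
        if no candidate is consistent."""
        best = pseudo_label
        best_cost = None
        for candidate in product(idx_to_label.values(), repeat=len(pseudo_label)):
            if sum(candidate) != y:
                continue  # violates the toy knowledge base: labels must sum to y
            cost = sum(a != b for a, b in zip(candidate, pseudo_label))
            if best_cost is None or cost < best_cost:
                best, best_cost = list(candidate), cost
        return best


    # Step 1 (assumed output of the learning part): predicted class indices.
    pred_idx = [[3, 6], [1, 2]]
    Y = [8, 3]

    pred_pseudo_label = [[idx_to_label[i] for i in row] for row in pred_idx]  # step 2
    abduced = [abduce(pl, y) for pl, y in zip(pred_pseudo_label, Y)]         # step 3
    abduced_idx = [[label_to_idx[lab] for lab in row] for row in abduced]     # step 4
    print(abduced_idx)  # -> [[2, 6], [1, 2]]
    # Step 5 would train the model on X with abduced_idx as targets (omitted here).

The first prediction ``[3, 6]`` sums to 9 rather than 8, so abduction revises it. Both ``[2, 6]`` and ``[3, 5]`` need only one change, and this toy keeps whichever it enumerates first. The second prediction is already consistent and is left untouched.
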
The core of the ``train`` method is as follows:

.. code-block:: python

    def train(self, train_data, loops=50, segment_size=10000):
        """
        Parameters
        ----------
        train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
            Training data, either a ``ListData`` instance or a tuple of
            ``(X, gt_pseudo_label, Y)``.
        loops : int
            The machine learning part and the reasoning part will be
            iteratively optimized ``loops`` times.
        segment_size : Union[int, float]
            Data will be split into segments of this size, and the data in
            each segment will be used together to train the model. A float
            is interpreted as a fraction of the dataset size.
        """
        if isinstance(train_data, ListData):
            data_samples = train_data
        else:
            data_samples = self.data_preprocess(*train_data)

        if isinstance(segment_size, float):
            segment_size = int(segment_size * len(data_samples))

        for loop in range(loops):
            # Iterate over ceil(len(data_samples) / segment_size) segments.
            for seg_idx in range((len(data_samples) - 1) // segment_size + 1):
                sub_data_samples = data_samples[
                    seg_idx * segment_size : (seg_idx + 1) * segment_size
                ]
                self.predict(sub_data_samples)  # 1: predict probabilities and indices
                self.idx_to_pseudo_label(sub_data_samples)  # 2: indices -> pseudo labels
                self.abduce_pseudo_label(sub_data_samples)  # 3: revise via abduction
                self.pseudo_label_to_idx(sub_data_samples)  # 4: pseudo labels -> indices
                loss = self.model.train(sub_data_samples)  # 5: self.model is an ABLModel
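
The segment arithmetic in ``train`` is easy to misread, so here is a standalone sketch of it. The expression ``(len(data_samples) - 1) // segment_size + 1`` is ceiling division, so the last segment may be shorter than ``segment_size``, and a float ``segment_size`` is first turned into a count. The function ``segment_bounds`` is hypothetical: it only mirrors the slicing logic above, with ``min`` reproducing the clamping that Python slicing performs at the end of the data, and it is not part of ABL-Package.

.. code-block:: python

    from typing import List, Tuple, Union


    def segment_bounds(
        n_samples: int, segment_size: Union[int, float]
    ) -> List[Tuple[int, int]]:
        """Return the (start, end) slice bounds visited in one training loop."""
        if isinstance(segment_size, float):
            # A float is treated as a fraction of the dataset size.
            segment_size = int(segment_size * n_samples)
        if segment_size <= 0:
            raise ValueError("segment_size must resolve to a positive integer")
        n_segments = (n_samples - 1) // segment_size + 1  # ceil(n_samples / segment_size)
        return [
            (seg_idx * segment_size, min((seg_idx + 1) * segment_size, n_samples))
            for seg_idx in range(n_segments)
        ]


    print(segment_bounds(25, 10))   # [(0, 10), (10, 20), (20, 25)]
    print(segment_bounds(25, 0.5))  # [(0, 12), (12, 24), (24, 25)]

Note that a fraction of ``0.5`` yields three segments here rather than two, because ``int()`` truncates ``12.5`` to ``12`` and the one leftover sample forms its own segment.
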

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.