You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_bridge.py 3.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. from abc import ABCMeta, abstractmethod
  2. from typing import Any, List, Optional, Tuple, Union
  3. from ..data.structures import ListData
  4. from ..learning import ABLModel
  5. from ..reasoning import Reasoner
  6. class BaseBridge(metaclass=ABCMeta):
  7. """
  8. A base class for bridging learning and reasoning parts.
  9. This class provides necessary methods that need to be overridden in subclasses
  10. to construct a typical pipeline of Abductive Learning (corresponding to ``train``),
  11. which involves the following four methods:
  12. - predict: Predict class indices on the given data examples.
  13. - idx_to_pseudo_label: Map indices into pseudo-labels.
  14. - abduce_pseudo_label: Revise pseudo-labels based on abdutive reasoning.
  15. - pseudo_label_to_idx: Map revised pseudo-labels back into indices.
  16. Parameters
  17. ----------
  18. model : ABLModel
  19. The machine learning model wrapped in ``ABLModel``, which is mainly used for
  20. prediction and model training.
  21. reasoner : Reasoner
  22. The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision.
  23. """
  24. def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
  25. if not isinstance(model, ABLModel):
  26. raise TypeError(
  27. "Expected an instance of ABLModel, but received type: {}".format(type(model))
  28. )
  29. if not isinstance(reasoner, Reasoner):
  30. raise TypeError(
  31. "Expected an instance of Reasoner, but received type: {}".format(type(reasoner))
  32. )
  33. self.model = model
  34. self.reasoner = reasoner
  35. @abstractmethod
  36. def predict(self, data_examples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]:
  37. """Placeholder for predicting class indices from input."""
  38. @abstractmethod
  39. def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
  40. """Placeholder for revising pseudo-labels based on abdutive reasoning."""
  41. @abstractmethod
  42. def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
  43. """Placeholder for mapping indices to pseudo-labels."""
  44. @abstractmethod
  45. def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
  46. """Placeholder for mapping pseudo-labels to indices."""
  47. def filter_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
  48. """Default filter function for pseudo-label."""
  49. non_empty_idx = [
  50. i
  51. for i in range(len(data_examples.abduced_pseudo_label))
  52. if data_examples.abduced_pseudo_label[i]
  53. ]
  54. data_examples.update(data_examples[non_empty_idx])
  55. return data_examples
  56. @abstractmethod
  57. def train(
  58. self,
  59. train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
  60. ):
  61. """Placeholder for training loop of ABductive Learning."""
  62. @abstractmethod
  63. def valid(
  64. self,
  65. val_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
  66. ) -> None:
  67. """Placeholder for model test."""
  68. @abstractmethod
  69. def test(
  70. self,
  71. test_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
  72. ) -> None:
  73. """Placeholder for model validation."""

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.