You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_bridge.py 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. """
  2. This module contains the base class for the Bridge part.
  3. Copyright (c) 2024 LAMDA. All rights reserved.
  4. """
  5. from abc import ABCMeta, abstractmethod
  6. from typing import Any, List, Optional, Tuple, Union
  7. from ..data.structures import ListData
  8. from ..learning import ABLModel
  9. from ..reasoning import Reasoner
  10. class BaseBridge(metaclass=ABCMeta):
  11. """
  12. A base class for bridging learning and reasoning parts.
  13. This class provides necessary methods that need to be overridden in subclasses
  14. to construct a typical pipeline of Abductive Learning (corresponding to ``train``),
  15. which involves the following four methods:
  16. - predict: Predict class indices on the given data examples.
  17. - idx_to_pseudo_label: Map indices into pseudo-labels.
  18. - abduce_pseudo_label: Revise pseudo-labels based on abdutive reasoning.
  19. - pseudo_label_to_idx: Map revised pseudo-labels back into indices.
  20. Parameters
  21. ----------
  22. model : ABLModel
  23. The machine learning model wrapped in ``ABLModel``, which is mainly used for
  24. prediction and model training.
  25. reasoner : Reasoner
  26. The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision.
  27. """
  28. def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
  29. if not isinstance(model, ABLModel):
  30. raise TypeError(f"Expected an instance of ABLModel, but received type: {type(model)}")
  31. if not isinstance(reasoner, Reasoner):
  32. raise TypeError(
  33. f"Expected an instance of Reasoner, but received type: {type(reasoner)}"
  34. )
  35. self.model = model
  36. self.reasoner = reasoner
  37. @abstractmethod
  38. def predict(self, data_examples: ListData) -> Tuple[List[List[Any]], List[List[Any]]]:
  39. """Placeholder for predicting class indices from input."""
  40. @abstractmethod
  41. def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
  42. """Placeholder for revising pseudo-labels based on abdutive reasoning."""
  43. @abstractmethod
  44. def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
  45. """Placeholder for mapping indices to pseudo-labels."""
  46. @abstractmethod
  47. def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
  48. """Placeholder for mapping pseudo-labels to indices."""
  49. def filter_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
  50. """Default filter function for pseudo-label."""
  51. non_empty_idx = [
  52. i
  53. for i in range(len(data_examples.abduced_pseudo_label))
  54. if data_examples.abduced_pseudo_label[i]
  55. ]
  56. data_examples.update(data_examples[non_empty_idx])
  57. return data_examples
  58. @abstractmethod
  59. def train(
  60. self,
  61. train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
  62. ):
  63. """Placeholder for training loop of ABductive Learning."""
  64. @abstractmethod
  65. def valid(
  66. self,
  67. val_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
  68. ) -> None:
  69. """Placeholder for model test."""
  70. @abstractmethod
  71. def test(
  72. self,
  73. test_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
  74. ) -> None:
  75. """Placeholder for model validation."""

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.