
simple_bridge.py 14 kB

import os.path as osp
from typing import Any, List, Optional, Tuple, Union

from numpy import ndarray

from ..data.evaluation import BaseMetric
from ..data.structures import ListData
from ..learning import ABLModel
from ..reasoning import Reasoner
from ..utils import print_log
from .base_bridge import BaseBridge


class SimpleBridge(BaseBridge):
    """
    A basic implementation for bridging machine learning and reasoning parts.

    This class implements the typical pipeline of Abductive Learning, which involves
    the following five steps:

    - Predict class probabilities and indices for the given data examples.
    - Map indices into pseudo-labels.
    - Revise pseudo-labels based on abductive reasoning.
    - Map the revised pseudo-labels to indices.
    - Train the model.

    Parameters
    ----------
    model : ABLModel
        The machine learning model wrapped in ``ABLModel``, which is mainly used for
        prediction and model training.
    reasoner : Reasoner
        The reasoning part wrapped in ``Reasoner``, which is used for pseudo-label revision.
    metric_list : List[BaseMetric]
        A list of metrics used for evaluating the model's performance.
    """

    def __init__(
        self,
        model: ABLModel,
        reasoner: Reasoner,
        metric_list: List[BaseMetric],
    ) -> None:
        super().__init__(model, reasoner)
        self.metric_list = metric_list

    def predict(self, data_examples: ListData) -> Tuple[List[ndarray], List[ndarray]]:
        """
        Predict class indices and probabilities (if ``predict_proba`` is implemented in
        ``self.model.base_model``) on the given data examples.

        Parameters
        ----------
        data_examples : ListData
            Data examples on which predictions are to be made.

        Returns
        -------
        Tuple[List[ndarray], List[ndarray]]
            A tuple containing lists of predicted indices and probabilities.
        """
        self.model.predict(data_examples)
        return data_examples.pred_idx, data_examples.pred_prob

    def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        """
        Revise predicted pseudo-labels of the given data examples using abduction.

        Parameters
        ----------
        data_examples : ListData
            Data examples containing predicted pseudo-labels.

        Returns
        -------
        List[List[Any]]
            A list of abduced pseudo-labels for the given data examples.
        """
        self.reasoner.batch_abduce(data_examples)
        return data_examples.abduced_pseudo_label

    def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
        """
        Map indices of data examples into pseudo-labels.

        Parameters
        ----------
        data_examples : ListData
            Data examples containing the indices.

        Returns
        -------
        List[List[Any]]
            A list of pseudo-labels converted from indices.
        """
        pred_idx = data_examples.pred_idx
        data_examples.pred_pseudo_label = [
            [self.reasoner.idx_to_label[_idx] for _idx in sub_list] for sub_list in pred_idx
        ]
        return data_examples.pred_pseudo_label

    def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
        """
        Map pseudo-labels of data examples into indices.

        Parameters
        ----------
        data_examples : ListData
            Data examples containing pseudo-labels.

        Returns
        -------
        List[List[Any]]
            A list of indices converted from pseudo-labels.
        """
        abduced_idx = [
            [
                self.reasoner.label_to_idx[_abduced_pseudo_label]
                for _abduced_pseudo_label in sub_list
            ]
            for sub_list in data_examples.abduced_pseudo_label
        ]
        data_examples.abduced_idx = abduced_idx
        return data_examples.abduced_idx

    def data_preprocess(
        self,
        prefix: str,
        data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
    ) -> ListData:
        """
        Transform data in the form of (X, gt_pseudo_label, Y) into ListData.

        Parameters
        ----------
        prefix : str
            A prefix indicating the type of data processing (e.g., 'train', 'test').
        data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
            Data to be preprocessed. Can be ListData or a tuple of lists.

        Returns
        -------
        ListData
            The preprocessed ListData object.
        """
        if isinstance(data, ListData):
            data_examples = data
            if not (
                hasattr(data_examples, "X")
                and hasattr(data_examples, "gt_pseudo_label")
                and hasattr(data_examples, "Y")
            ):
                raise ValueError(
                    f"{prefix} data should have X, gt_pseudo_label and Y attributes, but "
                    f"only {data_examples.all_keys()} are provided."
                )
        else:
            X, gt_pseudo_label, Y = data
            data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y)
        return data_examples

    def concat_data_examples(
        self, unlabel_data_examples: ListData, label_data_examples: Optional[ListData]
    ) -> ListData:
        """
        Concatenate unlabeled and labeled data examples. ``abduced_pseudo_label`` of unlabeled data
        examples and ``gt_pseudo_label`` of labeled data examples will be used to train the model.

        Parameters
        ----------
        unlabel_data_examples : ListData
            Unlabeled data examples to concatenate.
        label_data_examples : ListData, optional
            Labeled data examples to concatenate, if available.

        Returns
        -------
        ListData
            Concatenated data examples.
        """
        if label_data_examples is None:
            return unlabel_data_examples

        unlabel_data_examples.X = unlabel_data_examples.X + label_data_examples.X
        unlabel_data_examples.abduced_pseudo_label = (
            unlabel_data_examples.abduced_pseudo_label + label_data_examples.gt_pseudo_label
        )
        unlabel_data_examples.Y = unlabel_data_examples.Y + label_data_examples.Y
        return unlabel_data_examples

    def train(
        self,
        train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
        label_data: Optional[
            Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]]
        ] = None,
        val_data: Optional[
            Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]
        ] = None,
        loops: int = 50,
        segment_size: Union[int, float] = 1.0,
        eval_interval: int = 1,
        save_interval: Optional[int] = None,
        save_dir: Optional[str] = None,
    ):
        """
        A typical training pipeline of Abductive Learning.

        Parameters
        ----------
        train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
            Training data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
            object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes.

            - ``X`` is a list of sublists representing the input data.
            - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel``,
              not to train it. ``gt_pseudo_label`` can be ``None``.
            - ``Y`` is a list representing the ground truth reasoning result for each sublist
              in ``X``.
        label_data : Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]], optional
            Labeled data should be in the same format as ``train_data``. The only difference is
            that the ``gt_pseudo_label`` in ``label_data`` should not be ``None`` and will be
            utilized to train the model. Defaults to None.
        val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional  # noqa: E501
            Validation data should be in the same format as ``train_data``. Both ``gt_pseudo_label``
            and ``Y`` may be ``None``, depending on the evaluation metrics in ``self.metric_list``.
            If ``val_data`` is None, ``train_data`` will be used to validate the model during
            training. Defaults to None.
        loops : int
            The machine learning part and the reasoning part will be iteratively optimized
            ``loops`` times, by default 50.
        segment_size : Union[int, float]
            Data will be split into segments of this size, and the data in each segment
            will be used together to train the model, by default 1.0.
        eval_interval : int
            The model will be evaluated every ``eval_interval`` loops during training,
            by default 1.
        save_interval : int, optional
            The model will be saved every ``save_interval`` loops during training, by
            default None.
        save_dir : str, optional
            Directory to save the model, by default None.
        """
        data_examples = self.data_preprocess("train", train_data)

        if label_data is not None:
            label_data_examples = self.data_preprocess("label", label_data)
        else:
            label_data_examples = None

        if val_data is not None:
            val_data_examples = self.data_preprocess("val", val_data)
        else:
            val_data_examples = data_examples

        if isinstance(segment_size, int):
            if segment_size <= 0:
                raise ValueError("segment_size should be positive.")
        elif isinstance(segment_size, float):
            if 0 < segment_size <= 1:
                segment_size = int(segment_size * len(data_examples))
            else:
                raise ValueError("segment_size should be in (0, 1].")
        else:
            raise ValueError("segment_size should be int or float.")

        for loop in range(loops):
            for seg_idx in range((len(data_examples) - 1) // segment_size + 1):
                print_log(
                    f"loop(train) [{loop + 1}/{loops}] segment(train) "
                    f"[{(seg_idx + 1)}/{(len(data_examples) - 1) // segment_size + 1}] ",
                    logger="current",
                )
                sub_data_examples = data_examples[
                    seg_idx * segment_size : (seg_idx + 1) * segment_size
                ]
                self.predict(sub_data_examples)
                self.idx_to_pseudo_label(sub_data_examples)
                self.abduce_pseudo_label(sub_data_examples)
                self.filter_pseudo_label(sub_data_examples)
                self.concat_data_examples(sub_data_examples, label_data_examples)
                self.pseudo_label_to_idx(sub_data_examples)
                self.model.train(sub_data_examples)

            if (loop + 1) % eval_interval == 0 or loop == loops - 1:
                print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current")
                self._valid(val_data_examples)

            if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
                print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
                self.model.save(
                    save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")
                )

    def _valid(self, data_examples: ListData) -> None:
        """
        Internal method for validating the model with given data examples.

        Parameters
        ----------
        data_examples : ListData
            Data examples to be used for validation.
        """
        self.predict(data_examples)
        self.idx_to_pseudo_label(data_examples)

        for metric in self.metric_list:
            metric.process(data_examples)

        res = dict()
        for metric in self.metric_list:
            res.update(metric.evaluate())
        msg = "Evaluation ended, "
        for k, v in res.items():
            msg += k + f": {v:.3f} "
        print_log(msg, logger="current")

    def valid(
        self,
        val_data: Union[
            ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]
        ],
    ) -> None:
        """
        Validate the model with the given validation data.

        Parameters
        ----------
        val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]  # noqa: E501
            Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
            object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label``
            and ``Y`` may be ``None``, depending on the evaluation metrics in ``self.metric_list``.
        """
        val_data_examples = self.data_preprocess("val", val_data)
        self._valid(val_data_examples)

    def test(
        self,
        test_data: Union[
            ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]
        ],
    ) -> None:
        """
        Test the model with the given test data.

        Parameters
        ----------
        test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]  # noqa: E501
            Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object
            with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y``
            may be ``None``, depending on the evaluation metrics in ``self.metric_list``.
        """
        print_log("Test start:", logger="current")
        test_data_examples = self.data_preprocess("test", test_data)
        self._valid(test_data_examples)
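
As a side note on the training loop above: the way ``segment_size`` turns into a number of per-loop segments can be seen in isolation. The sketch below repeats that arithmetic on plain integers; it is not part of the file, ``num_segments`` is a hypothetical helper, and ``len_examples`` stands in for ``len(data_examples)``.

# Standalone sketch (not part of simple_bridge.py): how train() derives the
# number of segments processed in each loop from segment_size.
from typing import Union

def num_segments(len_examples: int, segment_size: Union[int, float]) -> int:
    if isinstance(segment_size, float):
        if not 0 < segment_size <= 1:
            raise ValueError("segment_size should be in (0, 1].")
        segment_size = int(segment_size * len_examples)
    elif not (isinstance(segment_size, int) and segment_size > 0):
        raise ValueError("segment_size should be a positive int or a float in (0, 1].")
    # Same ceiling division used by the inner loop of train().
    return (len_examples - 1) // segment_size + 1

print(num_segments(1000, 1.0))   # 1 -> the whole dataset is one segment
print(num_segments(1000, 0.25))  # 4 -> four segments of 250 examples
print(num_segments(1000, 128))   # 8 -> seven segments of 128, one of 104

A float in (0, 1] is treated as a fraction of the dataset and a positive int as an absolute segment length; either way, the last segment simply takes whatever examples remain.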

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.
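
To illustrate how the machine learning and reasoning parts interlock without relying on the toolkit's own classes, the following self-contained toy walks through the same five steps as ``SimpleBridge.train`` on a digit-sum task. Everything here is a deliberate stand-in: the random ``predict`` replaces ``ABLModel.predict``, the sum-matching ``abduce`` replaces ``Reasoner.batch_abduce``, and the data are made up.

# Self-contained toy (no ablkit classes): the same five steps as
# SimpleBridge.train(), with trivial stand-ins for the model and reasoner.
import random

idx_to_label = {i: str(i) for i in range(10)}   # pseudo-labels are digit strings
label_to_idx = {v: k for k, v in idx_to_label.items()}

def predict(x):
    # Stand-in for ABLModel.predict: random per-symbol class indices.
    return [random.randrange(10) for _ in x]

def abduce(pseudo_label, y):
    # Stand-in for Reasoner.batch_abduce: revise at most one label so that
    # the digit sum equals the reasoning result y.
    labels = [int(lab) for lab in pseudo_label]
    for pos in range(len(labels)):
        for cand in range(10):
            if sum(labels[:pos] + [cand] + labels[pos + 1:]) == y:
                labels[pos] = cand
                return [str(lab) for lab in labels]
    return list(pseudo_label)  # no single-label fix found: keep the prediction

X = [["img_a", "img_b"], ["img_c", "img_d"]]  # opaque inputs, never inspected here
Y = [7, 11]                                   # ground-truth reasoning results (digit sums)

for x, y in zip(X, Y):
    pred_idx = predict(x)                                   # 1. predict indices
    pred_pseudo = [idx_to_label[i] for i in pred_idx]       # 2. indices -> pseudo-labels
    abduced = abduce(pred_pseudo, y)                        # 3. abductive revision
    abduced_idx = [label_to_idx[lab] for lab in abduced]    # 4. pseudo-labels -> indices
    print(x, pred_idx, "->", abduced_idx)                   # 5. a real model would now retrain

In the file above, these roles are played by ``ABLModel.predict``, ``Reasoner.batch_abduce``, the two index/pseudo-label mapping methods, and ``ABLModel.train``, respectively.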