You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 6.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. """
  2. Implementation of utilities used in ablkit.
  3. Copyright (c) 2024 LAMDA. All rights reserved.
  4. """
  5. from typing import List, Any, Union, Tuple, Optional
  6. import numpy as np
  7. def flatten(nested_list: List[Union[Any, List[Any], Tuple[Any, ...]]]) -> List[Any]:
  8. """
  9. Flattens a nested list at the first level.
  10. Parameters
  11. ----------
  12. nested_list : List[Union[Any, List[Any], Tuple[Any, ...]]]
  13. A list which might contain sublists or tuples at the first level.
  14. Returns
  15. -------
  16. List[Any]
  17. A flattened version of the input list, where only the first
  18. level of sublists and tuples are reduced.
  19. """
  20. if not isinstance(nested_list, list):
  21. return nested_list
  22. flattened_list = []
  23. for item in nested_list:
  24. if isinstance(item, (list, tuple)):
  25. flattened_list.extend(item)
  26. else:
  27. flattened_list.append(item)
  28. return flattened_list
  29. def reform_list(
  30. flattened_list: List[Any], structured_list: List[Union[Any, List[Any], Tuple[Any, ...]]]
  31. ) -> List[List[Any]]:
  32. """
  33. Reform the list based on the structure of ``structured_list``.
  34. Parameters
  35. ----------
  36. flattened_list : List[Any]
  37. A flattened list of elements.
  38. structured_list : List[Union[Any, List[Any], Tuple[Any, ...]]]
  39. A list that reflects the desired structure, which may contain sublists or tuples.
  40. Returns
  41. -------
  42. List[List[Any]]
  43. A reformed list that mimics the structure of ``structured_list``.
  44. """
  45. if not isinstance(structured_list[0], (list, tuple)):
  46. return flattened_list
  47. reformed_list = []
  48. idx_start = 0
  49. for elem in structured_list:
  50. idx_end = idx_start + len(elem)
  51. reformed_list.append(flattened_list[idx_start:idx_end])
  52. idx_start = idx_end
  53. return reformed_list
  54. def hamming_dist(pred_pseudo_label: List[Any], candidates: List[List[Any]]) -> np.ndarray:
  55. """
  56. Compute the Hamming distance between two arrays.
  57. Parameters
  58. ----------
  59. pred_pseudo_label : List[Any]
  60. Pseudo-labels of an example.
  61. candidates : List[List[Any]]
  62. Multiple possible candidates.
  63. Returns
  64. -------
  65. np.ndarray
  66. Hamming distances computed for each candidate.
  67. """
  68. pred_pseudo_label = np.array(pred_pseudo_label)
  69. candidates = np.array(candidates)
  70. # Ensuring that pred_pseudo_label is broadcastable to the shape of candidates
  71. pred_pseudo_label = np.expand_dims(pred_pseudo_label, 0)
  72. return np.sum(pred_pseudo_label != candidates, axis=1)
  73. def confidence_dist(pred_prob: np.ndarray, candidates_idxs: List[List[Any]]) -> np.ndarray:
  74. """
  75. Compute the confidence distance between prediction probabilities and candidates,
  76. where the confidence distance is defined as 1 - the product of prediction probabilities.
  77. Parameters
  78. ----------
  79. pred_prob : np.ndarray
  80. Prediction probability distributions, each element is an array
  81. representing the probability distribution of a particular prediction.
  82. candidates_idxs : List[List[Any]]
  83. Multiple possible candidates' indices.
  84. Returns
  85. -------
  86. np.ndarray
  87. Confidence distances computed for each candidate.
  88. """
  89. pred_prob = np.clip(pred_prob, 1e-9, 1)
  90. cols = np.arange(len(candidates_idxs[0]))[None, :]
  91. return 1 - np.prod(pred_prob[cols, candidates_idxs], axis=1)
  92. def avg_confidence_dist(pred_prob: np.ndarray, candidates_idxs: List[List[Any]]) -> np.ndarray:
  93. """
  94. Compute the average confidence distance between prediction probabilities and candidates,
  95. where the confidence distance is defined as 1 - the average of prediction probabilities.
  96. Parameters
  97. ----------
  98. pred_prob : np.ndarray
  99. Prediction probability distributions, each element is an array
  100. representing the probability distribution of a particular prediction.
  101. candidates_idxs : List[List[Any]]
  102. Multiple possible candidates' indices.
  103. Returns
  104. -------
  105. np.ndarray
  106. Confidence distances computed for each candidate.
  107. """
  108. cols = np.arange(len(candidates_idxs[0]))[None, :]
  109. return 1 - np.average(pred_prob[cols, candidates_idxs], axis=1)
  110. def to_hashable(x: Union[List[Any], Any]) -> Union[Tuple[Any, ...], Any]:
  111. """
  112. Convert a nested list to a nested tuple so it is hashable.
  113. Parameters
  114. ----------
  115. x : Union[List[Any], Any]
  116. A potentially nested list to convert to a tuple.
  117. Returns
  118. -------
  119. Union[Tuple[Any, ...], Any]
  120. The input converted to a tuple if it was a list,
  121. otherwise the original input.
  122. """
  123. if isinstance(x, list):
  124. return tuple(to_hashable(item) for item in x)
  125. return x
  126. def restore_from_hashable(x):
  127. """
  128. Convert a nested tuple back to a nested list.
  129. Parameters
  130. ----------
  131. x : Union[Tuple[Any, ...], Any]
  132. A potentially nested tuple to convert to a list.
  133. Returns
  134. -------
  135. Union[List[Any], Any]
  136. The input converted to a list if it was a tuple,
  137. otherwise the original input.
  138. """
  139. if isinstance(x, tuple):
  140. return [restore_from_hashable(item) for item in x]
  141. return x
  142. def tab_data_to_tuple(
  143. X: Union[List[Any], Any], y: Union[List[Any], Any], reasoning_result: Optional[Any] = 0
  144. ) -> Tuple[List[List[Any]], List[List[Any]], List[Any]]:
  145. """
  146. Convert a tabular data to a tuple by adding a dimension to each element of
  147. X and y. The tuple contains three elements: data, label, and reasoning result.
  148. If X is None, return None.
  149. Parameters
  150. ----------
  151. X : Union[List[Any], Any]
  152. The data.
  153. y : Union[List[Any], Any]
  154. The label.
  155. reasoning_result : Any, optional
  156. The reasoning result, by default 0.
  157. Returns
  158. -------
  159. Tuple[List[List[Any]], List[List[Any]], List[Any]]
  160. A tuple of (data, label, reasoning_result).
  161. """
  162. if X is None:
  163. return None
  164. if len(X) != len(y):
  165. raise ValueError(
  166. f"The length of X and y should be the same, but got {len(X)} and {len(y)}."
  167. )
  168. return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y))

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.