You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 6.1 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. from typing import List, Any, Union, Tuple, Optional
  2. import numpy as np
  def flatten(nested_list: List[Union[Any, List[Any], Tuple[Any, ...]]]) -> List[Any]:
      """
      Flatten a list by exactly one level.

      Parameters
      ----------
      nested_list : List[Union[Any, List[Any], Tuple[Any, ...]]]
          Items to flatten. Every item that is a list or a tuple has its
          elements spliced into the result; any other item is kept as a
          single element. Deeper nesting is left untouched.

      Returns
      -------
      List[Any]
          A new list flattened by one level. If ``nested_list`` is not a
          list, it is returned unchanged.
      """
      # Non-list inputs pass straight through.
      if isinstance(nested_list, list):
          return [
              leaf
              for entry in nested_list
              for leaf in (entry if isinstance(entry, (list, tuple)) else (entry,))
          ]
      return nested_list
  def reform_list(
      flattened_list: List[Any], structured_list: List[Union[Any, List[Any], Tuple[Any, ...]]]
  ) -> List[List[Any]]:
      """
      Reform the list based on the structure of ``structured_list``.

      Consecutive runs of ``flattened_list`` are regrouped so that the i-th
      group has the same length as ``structured_list[i]``. For a
      ``structured_list`` whose elements are all lists/tuples,
      ``reform_list(flatten(structured_list), structured_list)`` reproduces
      its grouping (with every group as a list).

      Parameters
      ----------
      flattened_list : List[Any]
          A flattened list of elements.
      structured_list : List[Union[Any, List[Any], Tuple[Any, ...]]]
          A list that reflects the desired structure, which may contain sublists or tuples.

      Returns
      -------
      List[List[Any]]
          A reformed list that mimics the structure of ``structured_list``.
          If ``structured_list`` is empty, or its first element is not a
          list/tuple, ``flattened_list`` is returned unchanged.
      """
      # An empty structure has nothing to mimic. Checking it first avoids the
      # IndexError that ``structured_list[0]`` would raise. ``len`` is used
      # rather than truthiness so that numpy arrays are also accepted.
      if len(structured_list) == 0 or not isinstance(structured_list[0], (list, tuple)):
          return flattened_list
      reformed_list = []
      start = 0
      for group in structured_list:
          stop = start + len(group)
          reformed_list.append(flattened_list[start:stop])
          start = stop
      return reformed_list
  def hamming_dist(pred_pseudo_label: List[Any], candidates: List[List[Any]]) -> np.ndarray:
      """
      Compute the Hamming distance between a pseudo-label sequence and each candidate.

      Parameters
      ----------
      pred_pseudo_label : List[Any]
          Pseudo-labels of an example.
      candidates : List[List[Any]]
          Multiple possible candidates, each expected to have the same length
          as ``pred_pseudo_label``.

      Returns
      -------
      np.ndarray
          One entry per candidate: the number of positions at which that
          candidate differs from ``pred_pseudo_label``. The array is empty
          when ``candidates`` is empty.
      """
      # An empty candidate set converts to shape (0,), which does not broadcast
      # against a (1, k) label row. With k == 1 it would broadcast and produce a
      # spurious single distance, so return an empty result explicitly.
      if len(candidates) == 0:
          return np.zeros(0, dtype=int)
      pred_pseudo_label = np.asarray(pred_pseudo_label)
      candidates = np.asarray(candidates)
      # Add a leading axis so the single label row broadcasts across all candidates.
      pred_pseudo_label = np.expand_dims(pred_pseudo_label, 0)
      return np.sum(pred_pseudo_label != candidates, axis=1)
  def confidence_dist(pred_prob: np.ndarray, candidates_idxs: List[List[Any]]) -> np.ndarray:
      """
      Compute the confidence distance between prediction probabilities and candidates,
      where the confidence distance is defined as 1 - the product of prediction probabilities.

      Parameters
      ----------
      pred_prob : np.ndarray
          Prediction probability distributions, each element is an array
          representing the probability distribution of a particular prediction.
      candidates_idxs : List[List[Any]]
          Multiple possible candidates' indices. ``candidates_idxs[i][j]`` is
          the class index that candidate ``i`` assigns to position ``j``.

      Returns
      -------
      np.ndarray
          Confidence distances computed for each candidate. The array is empty
          when ``candidates_idxs`` is empty.
      """
      # ``candidates_idxs[0]`` below would raise IndexError on an empty
      # candidate set, which has no distances to report.
      if len(candidates_idxs) == 0:
          return np.zeros(0)
      # Clip probabilities into [1e-9, 1], presumably so that no factor in the
      # product is exactly zero.
      pred_prob = np.clip(pred_prob, 1e-9, 1)
      # positions has shape (1, k). It broadcasts against candidates_idxs of
      # shape (m, k), so gathered[i, j] == pred_prob[j, candidates_idxs[i][j]].
      positions = np.arange(len(candidates_idxs[0]))[None, :]
      gathered = pred_prob[positions, candidates_idxs]
      return 1 - np.prod(gathered, axis=1)
  def avg_confidence_dist(pred_prob: np.ndarray, candidates_idxs: List[List[Any]]) -> np.ndarray:
      """
      Compute the average confidence distance between prediction probabilities and candidates,
      where the confidence distance is defined as 1 - the average of prediction probabilities.

      Parameters
      ----------
      pred_prob : np.ndarray
          Prediction probability distributions, each element is an array
          representing the probability distribution of a particular prediction.
          Any array-like (e.g. a nested list) is accepted.
      candidates_idxs : List[List[Any]]
          Multiple possible candidates' indices. ``candidates_idxs[i][j]`` is
          the class index that candidate ``i`` assigns to position ``j``.

      Returns
      -------
      np.ndarray
          Confidence distances computed for each candidate. The array is empty
          when ``candidates_idxs`` is empty.
      """
      # ``candidates_idxs[0]`` below would raise IndexError on an empty
      # candidate set, which has no distances to report.
      if len(candidates_idxs) == 0:
          return np.zeros(0)
      # Convert to an array so that tuple fancy-indexing also works for list input.
      pred_prob = np.asarray(pred_prob)
      # positions has shape (1, k). It broadcasts against candidates_idxs of
      # shape (m, k), so gathered[i, j] == pred_prob[j, candidates_idxs[i][j]].
      positions = np.arange(len(candidates_idxs[0]))[None, :]
      gathered = pred_prob[positions, candidates_idxs]
      return 1 - np.average(gathered, axis=1)
  def to_hashable(x: Union[List[Any], Any]) -> Union[Tuple[Any, ...], Any]:
      """
      Recursively turn lists into tuples so the result can be used as a hash key.

      Parameters
      ----------
      x : Union[List[Any], Any]
          A value that may be a (possibly nested) list.

      Returns
      -------
      Union[Tuple[Any, ...], Any]
          A tuple mirroring ``x`` (with every nested list likewise converted)
          if ``x`` is a list. Otherwise ``x`` itself, untouched; note that
          tuples are not descended into.
      """
      if not isinstance(x, list):
          return x
      return tuple(map(to_hashable, x))
  def restore_from_hashable(x: Union[Tuple[Any, ...], Any]) -> Union[List[Any], Any]:
      """
      Convert a nested tuple back to a nested list.

      Parameters
      ----------
      x : Union[Tuple[Any, ...], Any]
          A potentially nested tuple to convert to a list.

      Returns
      -------
      Union[List[Any], Any]
          The input converted to a list if it was a tuple (every nested tuple
          is converted too), otherwise the original input.
      """
      if isinstance(x, tuple):
          return [restore_from_hashable(item) for item in x]
      return x
  def tab_data_to_tuple(
      X: Union[List[Any], Any], y: Union[List[Any], Any], reasoning_result: Optional[Any] = 0
  ) -> Optional[Tuple[List[List[Any]], List[List[Any]], List[Any]]]:
      """
      Convert tabular data to a tuple by adding a dimension to each element of
      X and y.

      The tuple contains three elements: data, label, and reasoning result.

      Parameters
      ----------
      X : Union[List[Any], Any]
          The data. If None, the function returns None.
      y : Union[List[Any], Any]
          The label. Must have the same length as ``X``.
      reasoning_result : Any, optional
          The reasoning result, by default 0. It is repeated once per label.

      Returns
      -------
      Optional[Tuple[List[List[Any]], List[List[Any]], List[Any]]]
          A tuple of (data, label, reasoning_result), or None if ``X`` is None.

      Raises
      ------
      ValueError
          If ``X`` and ``y`` have different lengths.
      """
      if X is None:
          return None
      if len(X) != len(y):
          raise ValueError(
              "The length of X and y should be the same, but got {} and {}.".format(len(X), len(y))
          )
      # Wrap each sample and each label in its own single-element list.
      data = [[x] for x in X]
      labels = [[y_item] for y_item in y]
      # List repetition: every entry references the same reasoning_result object.
      return (data, labels, [reasoning_result] * len(y))

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.