You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 8.1 kB

2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. from itertools import chain
  2. import numpy as np
  def flatten(nested_list):
      """
      Remove one level of nesting from a list of sublists or tuples.

      Whether the input is nested is decided by looking at its first element
      only: if that element is a list or tuple, every element is expanded
      into the result (so every element must then be iterable); otherwise the
      input is treated as already flat.

      Parameters
      ----------
      nested_list : list
          A list which might contain sublists or tuples.

      Returns
      -------
      list
          A new, one-level-flattened list, or ``nested_list`` itself (the very
          same object, not a copy) when it is empty or not nested.

      Raises
      ------
      TypeError
          If the input object is not a list.
      """
      if not isinstance(nested_list, list):
          raise TypeError("Input must be of type list.")
      is_nested = bool(nested_list) and isinstance(nested_list[0], (list, tuple))
      if not is_nested:
          return nested_list
      return [item for group in nested_list for item in group]
  def reform_idx(flattened_list, structured_list):
      """
      Regroup a flat list so that it mirrors the nesting of ``structured_list``.

      Consecutive slices of ``flattened_list`` are taken, the i-th slice
      having the same length as the i-th element of ``structured_list``.

      Parameters
      ----------
      flattened_list : list
          A flattened list of predictions.
      structured_list : list
          A list containing saved predictions, which could be nested lists or
          tuples. Only its element lengths are used.

      Returns
      -------
      list
          A list of slices of ``flattened_list`` that mimics the structure of
          ``structured_list``. If ``structured_list`` is empty or its first
          element is not a list/tuple, ``flattened_list`` is returned unchanged.
      """
      # len() rather than truthiness so that NumPy arrays are also accepted;
      # the emptiness check avoids an IndexError on structured_list[0].
      if len(structured_list) == 0 or not isinstance(structured_list[0], (list, tuple)):
          return flattened_list
      reformed_list = []
      idx_start = 0
      for elem in structured_list:
          idx_end = idx_start + len(elem)
          reformed_list.append(flattened_list[idx_start:idx_end])
          idx_start = idx_end
      return reformed_list
  def hamming_dist(pred_pseudo_label, candidates):
      """
      Compute the Hamming distance between a label sequence and each candidate.

      Parameters
      ----------
      pred_pseudo_label : list
          First array to compare, of length m.
      candidates : list
          Second array to compare, expected to have shape (n, m) where n is the
          number of candidates and m is the length of pred_pseudo_label. A
          single 1-D candidate of length m is treated as one row.

      Returns
      -------
      numpy.ndarray
          Integer array of shape (n,) with the number of mismatching positions
          per candidate. Empty when there are no candidates.
      """
      # With zero candidates the broadcast below is meaningless: for m == 1 it
      # would silently yield [0] (a phantom candidate), otherwise it raises.
      if len(candidates) == 0:
          return np.zeros(0, dtype=int)
      pred_pseudo_label = np.asarray(pred_pseudo_label)
      candidates = np.asarray(candidates)
      # A leading axis lets the label broadcast against every candidate row and
      # keeps axis=1 valid when a single 1-D candidate is passed.
      pred_pseudo_label = np.expand_dims(pred_pseudo_label, 0)
      return np.sum(pred_pseudo_label != candidates, axis=1)
  def confidence_dist(pred_prob, candidates):
      """
      Compute the confidence distance between prediction probabilities and candidates.

      For each candidate, the probabilities that the predictor assigned to the
      candidate's label at every position are multiplied together; the
      distance is one minus that product.

      Parameters
      ----------
      pred_prob : list of numpy.ndarray
          Prediction probability distributions, each element is an ndarray
          representing the probability distribution of a particular prediction
          (row i = position i, column = label index).
      candidates : list of list of int
          Index of candidate labels, each element is a list of indexes being
          considered as a candidate correction (one label index per position).

      Returns
      -------
      numpy.ndarray
          Confidence distances computed for each candidate, shape (n,). Empty
          when there are no candidates.
      """
      # Clip so a zero probability does not collapse the product to exactly 0.
      pred_prob = np.clip(pred_prob, 1e-9, 1)
      if len(candidates) == 0:
          # Previously an IndexError on candidates[0].
          return np.zeros(0, dtype=pred_prob.dtype)
      candidates = np.asarray(candidates)
      # positions (m,) broadcasts against candidates (n, m): element [i, j] is
      # pred_prob[j, candidates[i][j]].
      positions = np.arange(candidates.shape[1])
      return 1 - np.prod(pred_prob[positions, candidates], axis=1)
  def block_sample(X, Z, Y, sample_num, seg_idx):
      """
      Extract a block of samples from lists X, Z, and Y.

      Block ``seg_idx`` covers indices ``[sample_num * seg_idx,
      sample_num * (seg_idx + 1))`` of each input; blocks past the end are
      empty and the last block may be shorter than ``sample_num``.

      Parameters
      ----------
      X, Z, Y : list
          Input lists from which to extract the samples.
      sample_num : int
          The number of samples per block.
      seg_idx : int
          The block index to extract.

      Returns
      -------
      tuple of lists
          The extracted block samples from X, Z, and Y.

      Example
      -------
      >>> X = [1, 2, 3, 4, 5, 6]
      >>> Z = ['a', 'b', 'c', 'd', 'e', 'f']
      >>> Y = [10, 11, 12, 13, 14, 15]
      >>> block_sample(X, Z, Y, 2, 1)
      ([3, 4], ['c', 'd'], [12, 13])
      """
      start_idx = sample_num * seg_idx
      end_idx = sample_num * (seg_idx + 1)
      # Materialize as a tuple: a bare generator expression here returned a
      # generator object, contradicting the documented return type and doctest.
      return tuple(data[start_idx:end_idx] for data in (X, Z, Y))
  def check_equal(a, b, max_err=0):
      """
      Check whether two numbers a and b are equal within a maximum allowable error.

      Parameters
      ----------
      a, b : int, float or other real number
          The numbers to compare. Any ``numbers.Real`` instance is accepted,
          which includes NumPy integer and floating scalars (e.g. ``np.int64``,
          ``np.float32``) that a plain ``(int, float)`` check would reject.
      max_err : int or float, optional
          The maximum allowable absolute difference between a and b for them
          to be considered equal. Default is 0, meaning the numbers must be
          exactly equal.

      Returns
      -------
      bool
          True if ``abs(a - b) <= max_err``, False otherwise (a NumPy bool for
          NumPy scalar inputs).

      Raises
      ------
      TypeError
          If a or b is not a real number.
      """
      # NumPy registers its integer/floating scalar types with these ABCs.
      import numbers

      if not (isinstance(a, numbers.Real) and isinstance(b, numbers.Real)):
          raise TypeError("Input values must be int or float.")
      return abs(a - b) <= max_err
  def to_hashable(x):
      """
      Recursively turn lists into tuples so the result can be hashed.

      Only ``list`` instances are converted; any other value (including
      tuples, whose contents are left as they are) is returned untouched.

      Parameters
      ----------
      x : list or other type
          A potentially nested list to convert to a tuple.

      Returns
      -------
      tuple or other type
          A nested tuple if ``x`` was a list, otherwise ``x`` itself.
      """
      if not isinstance(x, list):
          return x
      return tuple(map(to_hashable, x))
  def hashable_to_list(x):
      """
      Recursively turn tuples back into lists.

      Only ``tuple`` instances are converted; any other value (including
      lists, whose contents are left as they are) is returned untouched.

      Parameters
      ----------
      x : tuple or other type
          A potentially nested tuple to convert to a list.

      Returns
      -------
      list or other type
          A nested list if ``x`` was a tuple, otherwise ``x`` itself.
      """
      if not isinstance(x, tuple):
          return x
      return list(map(hashable_to_list, x))
  def calculate_revision_num(parameter, total_length):
      """
      Resolve a revision parameter to an absolute count relative to a total length.

      Parameters
      ----------
      parameter : int or float
          ``-1`` (or ``-1.0``) is replaced with ``total_length``. Any other
          float is a fraction in [0, 1] of ``total_length``. Any other int is
          an absolute, non-negative count and is returned as-is.
      total_length : int
          The length the fraction is applied to.

      Returns
      -------
      int
          The resolved count. Fractions are rounded with the built-in
          ``round`` (ties go to the even integer).

      Raises
      ------
      TypeError
          If parameter is not an int or a float.
      ValueError
          If parameter is a float not in [0, 1] or an int below 0.
      """
      if not isinstance(parameter, (int, float)):
          raise TypeError("Parameter must be of type int or float.")
      if parameter == -1:
          return total_length

      is_fraction = isinstance(parameter, float)
      # The chained comparison is False for NaN, so NaN is rejected here too.
      if is_fraction and not 0 <= parameter <= 1:
          raise ValueError("If parameter is a float, it must be between 0 and 1.")
      if is_fraction:
          return round(total_length * parameter)

      if parameter < 0:
          raise ValueError("If parameter is an int, it must be non-negative.")
      return parameter
  if __name__ == "__main__":
      # Demo input: 3 positions, each a probability distribution over 13 labels.
      A = np.array(
          [
              [
                  0.18401675, 0.06797526, 0.06797541, 0.06801736, 0.06797528,
                  0.06797526, 0.06818808, 0.06797527, 0.06800033, 0.06797526,
                  0.06797526, 0.06797526, 0.06797526,
              ],
              [
                  0.07223161, 0.0685229, 0.06852708, 0.17227574, 0.06852163,
                  0.07018146, 0.06860291, 0.06852849, 0.06852163, 0.0685216,
                  0.0685216, 0.06852174, 0.0685216,
              ],
              [
                  0.06794382, 0.0679436, 0.06794395, 0.06794346, 0.06794346,
                  0.18467231, 0.06794345, 0.06794871, 0.06794345, 0.06794345,
                  0.06794345, 0.06794345, 0.06794345,
              ],
          ],
          dtype=np.float32,
      )
      # Two candidates, each giving one label index per position.
      B = [[0, 9, 3], [0, 11, 4]]
      # The former call to `ori_confidence_dist` referenced a function that is
      # not defined in this module and raised NameError before this line ran.
      print(confidence_dist(A, B))

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.