You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 6.0 kB

2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. from itertools import chain
  2. import numpy as np
  3. def flatten(nested_list):
  4. """
  5. Flattens a nested list.
  6. Parameters
  7. ----------
  8. nested_list : list
  9. A list which might contain sublists or tuples.
  10. Returns
  11. -------
  12. list
  13. A flattened version of the input list.
  14. Raises
  15. ------
  16. TypeError
  17. If the input object is not a list.
  18. """
  19. # if not isinstance(nested_list, list):
  20. # raise TypeError("Input must be of type list.")
  21. if not isinstance(nested_list, list) or not isinstance(nested_list[0], (list, tuple)):
  22. return nested_list
  23. return list(chain.from_iterable(nested_list))
  24. def reform_list(flattened_list, structured_list):
  25. """
  26. Reform the index based on structured_list structure.
  27. Parameters
  28. ----------
  29. flattened_list : list
  30. A flattened list of predictions.
  31. structured_list : list
  32. A list containing saved predictions, which could be nested lists or tuples.
  33. Returns
  34. -------
  35. list
  36. A reformed list that mimics the structure of structured_list.
  37. """
  38. # if not isinstance(flattened_list, list):
  39. # raise TypeError("Input must be of type list.")
  40. if not isinstance(structured_list[0], (list, tuple)):
  41. return flattened_list
  42. reformed_list = []
  43. idx_start = 0
  44. for elem in structured_list:
  45. idx_end = idx_start + len(elem)
  46. reformed_list.append(flattened_list[idx_start:idx_end])
  47. idx_start = idx_end
  48. return reformed_list
  49. def hamming_dist(pred_pseudo_label, candidates):
  50. """
  51. Compute the Hamming distance between two arrays.
  52. Parameters
  53. ----------
  54. pred_pseudo_label : list
  55. First array to compare.
  56. candidates : list
  57. Second array to compare, expected to have shape (n, m)
  58. where n is the number of rows, m is the length of pred_pseudo_label.
  59. Returns
  60. -------
  61. numpy.ndarray
  62. Hamming distances.
  63. """
  64. pred_pseudo_label = np.array(pred_pseudo_label)
  65. candidates = np.array(candidates)
  66. # Ensuring that pred_pseudo_label is broadcastable to the shape of candidates
  67. pred_pseudo_label = np.expand_dims(pred_pseudo_label, 0)
  68. return np.sum(pred_pseudo_label != candidates, axis=1)
  69. def confidence_dist(pred_prob, candidates):
  70. """
  71. Compute the confidence distance between prediction probabilities and candidates.
  72. Parameters
  73. ----------
  74. pred_prob : list of numpy.ndarray
  75. Prediction probability distributions, each element is an ndarray
  76. representing the probability distribution of a particular prediction.
  77. candidates : list of list of int
  78. Index of candidate labels, each element is a list of indexes being considered
  79. as a candidate correction.
  80. Returns
  81. -------
  82. numpy.ndarray
  83. Confidence distances computed for each candidate.
  84. """
  85. pred_prob = np.clip(pred_prob, 1e-9, 1)
  86. _, cols = np.indices((len(candidates), len(candidates[0])))
  87. return 1 - np.prod(pred_prob[cols, candidates], axis=1)
  88. def block_sample(X, Z, Y, sample_num, seg_idx):
  89. """
  90. Extract a block of samples from lists X, Z, and Y.
  91. Parameters
  92. ----------
  93. X, Z, Y : list
  94. Input lists from which to extract the samples.
  95. sample_num : int
  96. The number of samples per block.
  97. seg_idx : int
  98. The block index to extract.
  99. Returns
  100. -------
  101. tuple of lists
  102. The extracted block samples from X, Z, and Y.
  103. Example
  104. -------
  105. >>> X = [1, 2, 3, 4, 5, 6]
  106. >>> Z = ['a', 'b', 'c', 'd', 'e', 'f']
  107. >>> Y = [10, 11, 12, 13, 14, 15]
  108. >>> block_sample(X, Z, Y, 2, 1)
  109. ([3, 4], ['c', 'd'], [12, 13])
  110. """
  111. start_idx = sample_num * seg_idx
  112. end_idx = sample_num * (seg_idx + 1)
  113. return (data[start_idx:end_idx] for data in (X, Z, Y))
  114. def to_hashable(x):
  115. """
  116. Convert a nested list to a nested tuple so it is hashable.
  117. Parameters
  118. ----------
  119. x : list or other type
  120. A potentially nested list to convert to a tuple.
  121. Returns
  122. -------
  123. tuple or other type
  124. The input converted to a tuple if it was a list,
  125. otherwise the original input.
  126. """
  127. if isinstance(x, list):
  128. return tuple(to_hashable(item) for item in x)
  129. return x
  130. def restore_from_hashable(x):
  131. """
  132. Convert a nested tuple back to a nested list.
  133. Parameters
  134. ----------
  135. x : tuple or other type
  136. A potentially nested tuple to convert to a list.
  137. Returns
  138. -------
  139. list or other type
  140. The input converted to a list if it was a tuple,
  141. otherwise the original input.
  142. """
  143. if isinstance(x, tuple):
  144. return [restore_from_hashable(item) for item in x]
  145. return x
  146. def calculate_revision_num(parameter, total_length):
  147. """
  148. Convert a float parameter to an integer, based on a total length.
  149. Parameters
  150. ----------
  151. parameter : int or float
  152. The parameter to convert. If float, it should be between 0 and 1.
  153. If int, it should be non-negative. If -1, it will be replaced with total_length.
  154. total_length : int
  155. The total length to calculate the parameter from if it's a fraction.
  156. Returns
  157. -------
  158. int
  159. The calculated parameter.
  160. Raises
  161. ------
  162. TypeError
  163. If parameter is not an int or a float.
  164. ValueError
  165. If parameter is a float not in [0, 1] or an int below 0.
  166. """
  167. if not isinstance(parameter, (int, float)):
  168. raise TypeError("Parameter must be of type int or float.")
  169. if parameter == -1:
  170. return total_length
  171. elif isinstance(parameter, float):
  172. if not (0 <= parameter <= 1):
  173. raise ValueError("If parameter is a float, it must be between 0 and 1.")
  174. return round(total_length * parameter)
  175. else:
  176. if parameter < 0:
  177. raise ValueError("If parameter is an int, it must be non-negative.")
  178. return parameter

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.