You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Tensor utils."""
  16. import numpy as np
  17. from mindinsight.utils.exceptions import ParamValueError
  18. from mindinsight.utils.exceptions import ParamTypeError
  19. from mindinsight.utils.log import setup_logger
  20. F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max
  21. MAX_DIMENSIONS_FOR_TENSOR = 2
  22. class Statistics:
  23. """Statistics data class.
  24. Args:
  25. stats (dict): Statistic info of tensor data.
  26. - is_bool (bool): If the tensor is bool type.
  27. - max_value (float): Max value of tensor data.
  28. - min_value (float): Min value of tensor data.
  29. - avg_value (float): Avg value of tensor data.
  30. - count (int): Total count of tensor data.
  31. - nan_count (int): Count of NAN.
  32. - neg_zero_count (int): Count of negative zero.
  33. - pos_zero_count (int): Count of positive zero.
  34. - zero_count (int): Count of zero.
  35. - neg_inf_count (int): Count of negative INF.
  36. - pos_inf_count (int): Count of positive INF.
  37. """
  38. def __init__(self, stats):
  39. self._stats = stats
  40. @property
  41. def max(self):
  42. """Get max value of tensor."""
  43. return float(self._stats.get('max_value', 0))
  44. @property
  45. def min(self):
  46. """Get min value of tensor."""
  47. return float(self._stats.get('min_value', 0))
  48. @property
  49. def avg(self):
  50. """Get avg value of tensor."""
  51. return float(self._stats.get('avg_value', 0))
  52. @property
  53. def count(self):
  54. """Get total count of tensor."""
  55. return int(self._stats.get('count', 0))
  56. @property
  57. def nan_count(self):
  58. """Get count of NAN."""
  59. return int(self._stats.get('nan_count', 0))
  60. @property
  61. def neg_inf_count(self):
  62. """Get count of negative INF."""
  63. return int(self._stats.get('neg_inf_count', 0))
  64. @property
  65. def pos_inf_count(self):
  66. """Get count of positive INF."""
  67. return int(self._stats.get('pos_inf_count', 0))
  68. @property
  69. def neg_zero_count(self):
  70. """Get count of negative zero."""
  71. return int(self._stats.get('neg_zero_count', 0))
  72. @property
  73. def pos_zero_count(self):
  74. """Get count of positive zero."""
  75. return int(self._stats.get('pos_zero_count', 0))
  76. @property
  77. def zero_count(self):
  78. """Get count of zero."""
  79. return int(self._stats.get('zero_count', 0))
  80. @property
  81. def true_count(self):
  82. """Get count of False."""
  83. return self.pos_zero_count if self.is_bool else 0
  84. @property
  85. def false_count(self):
  86. """Get count of True."""
  87. return self.zero_count if self.is_bool else 0
  88. @property
  89. def is_bool(self):
  90. """Whether the tensor is bool type."""
  91. return self._stats.get('is_bool', False)
  92. class TensorComparison:
  93. """TensorComparison class.
  94. Args:
  95. tolerance (float): Tolerance for calculating tensor diff.
  96. stats (float): Statistics of tensor diff.
  97. value (numpy.ndarray): Tensor diff.
  98. """
  99. def __init__(self, tolerance=0, stats=None, value=None):
  100. self._tolerance = tolerance
  101. self._stats = stats
  102. self._value = value
  103. @property
  104. def tolerance(self):
  105. """Get tolerance of TensorComparison."""
  106. return self._tolerance
  107. @property
  108. def stats(self):
  109. """Get stats of tensor diff."""
  110. return self._stats
  111. def update(self, tolerance, value):
  112. """update tensor comparisons."""
  113. self._tolerance = tolerance
  114. self._value = value
  115. @property
  116. def value(self):
  117. """Get value of tensor diff."""
  118. return self._value
  119. def str_to_slice_or_int(input_str):
  120. """
  121. Translate param from string to slice or int.
  122. Args:
  123. input_str (str): The string to be translated.
  124. Returns:
  125. Union[int, slice], the transformed param.
  126. """
  127. try:
  128. if ':' in input_str:
  129. ret = slice(*map(lambda x: int(x.strip()) if x.strip() else None, input_str.split(':')))
  130. else:
  131. ret = int(input_str)
  132. except ValueError:
  133. raise ParamValueError("Invalid shape. Convert int from str failed. input_str: {}".format(input_str))
  134. return ret
  135. class TensorUtils:
  136. """Tensor Utils class."""
  137. @staticmethod
  138. def parse_shape(shape, limit=0):
  139. """
  140. Parse shape from str.
  141. Args:
  142. shape (str): Specify shape of tensor.
  143. limit (int): The max dimensions specified. Default value is 0 which means that there is no limitation.
  144. Returns:
  145. Union[None, tuple], a string like this: "[0, 0, 1:10, :]" will convert to this value:
  146. (0, 0, slice(1, 10, None), slice(None, None, None)].
  147. Raises:
  148. ParamValueError, If type of shape is not str or format is not correct or exceed specified dimensions.
  149. """
  150. if shape is None:
  151. return shape
  152. if not (isinstance(shape, str) and shape.strip().startswith('[') and shape.strip().endswith(']')):
  153. raise ParamValueError("Invalid shape. The type of shape should be str and start with `[` and "
  154. "end with `]`. Received: {}.".format(shape))
  155. shape = shape.strip()[1:-1]
  156. dimension_size = sum(1 for dim in shape.split(',') if dim.count(':'))
  157. if limit and dimension_size > limit:
  158. raise ParamValueError("Invalid shape. At most {} dimensions are specified. Received: {}"
  159. .format(limit, shape))
  160. parsed_shape = tuple(
  161. str_to_slice_or_int(dim.strip()) for dim in shape.split(',')) if shape else tuple()
  162. return parsed_shape
  163. @staticmethod
  164. def get_specific_dims_data(ndarray, dims):
  165. """
  166. Get specific dims data.
  167. Args:
  168. ndarray (numpy.ndarray): An ndarray of numpy.
  169. dims (tuple): A tuple of specific dims.
  170. Returns:
  171. numpy.ndarray, an ndarray of specific dims tensor data.
  172. Raises:
  173. ParamValueError, If the length of param dims is not equal to the length of tensor dims.
  174. IndexError, If the param dims and tensor shape is unmatched.
  175. """
  176. if len(ndarray.shape) != len(dims):
  177. raise ParamValueError("Invalid dims. The length of param dims and tensor shape should be the same.")
  178. try:
  179. result = ndarray[dims]
  180. except IndexError:
  181. raise ParamValueError("Invalid shape. Shape unmatched. Received: {}, tensor shape: {}"
  182. .format(dims, ndarray.shape))
  183. # Make sure the return type is numpy.ndarray.
  184. if not isinstance(result, np.ndarray):
  185. result = np.array(result)
  186. return result
  187. @staticmethod
  188. def get_statistics_from_tensor(tensors):
  189. """
  190. Calculates statistics data of tensor.
  191. Args:
  192. tensors (numpy.ndarray): An numpy.ndarray of tensor data.
  193. Returns:
  194. Statistics, an instance of Statistics.
  195. """
  196. ma_value = np.ma.masked_invalid(tensors)
  197. total, valid = tensors.size, ma_value.count()
  198. invalids = []
  199. for isfn in np.isnan, np.isposinf, np.isneginf:
  200. if total - valid > sum(invalids):
  201. count = np.count_nonzero(isfn(tensors))
  202. invalids.append(count)
  203. else:
  204. invalids.append(0)
  205. nan_count, pos_inf_count, neg_inf_count = invalids
  206. logger = setup_logger("utils", "utils")
  207. if not valid:
  208. logger.warning('There are no valid values in the tensors(size=%d, shape=%s)', total, tensors.shape)
  209. statistics = Statistics({'max_value': 0,
  210. 'min_value': 0,
  211. 'avg_value': 0,
  212. 'count': total,
  213. 'nan_count': nan_count,
  214. 'neg_inf_count': neg_inf_count,
  215. 'pos_inf_count': pos_inf_count})
  216. return statistics
  217. # BUG: max of a masked array with dtype np.float16 returns inf
  218. # See numpy issue#15077
  219. if issubclass(tensors.dtype.type, np.floating):
  220. tensor_min = ma_value.min(fill_value=np.PINF)
  221. tensor_max = ma_value.max(fill_value=np.NINF)
  222. if tensor_min < F32_MIN or tensor_max > F32_MAX:
  223. logger.warning('Values(%f, %f) are too large, you may encounter some undefined '
  224. 'behaviours hereafter.', tensor_min, tensor_max)
  225. else:
  226. tensor_min = ma_value.min()
  227. tensor_max = ma_value.max()
  228. tensor_sum = ma_value.sum(dtype=np.float64)
  229. with np.errstate(invalid='ignore'):
  230. neg_zero_count = np.sum(ma_value < 0)
  231. with np.errstate(invalid='ignore'):
  232. pos_zero_count = np.sum(ma_value > 0)
  233. with np.errstate(invalid='ignore'):
  234. zero_count = np.sum(ma_value == 0)
  235. statistics = Statistics({'is_bool': tensors.dtype == np.bool,
  236. 'max_value': tensor_max,
  237. 'min_value': tensor_min,
  238. 'avg_value': tensor_sum / valid,
  239. 'count': total,
  240. 'neg_zero_count': neg_zero_count,
  241. 'pos_zero_count': pos_zero_count,
  242. 'zero_count': zero_count,
  243. 'nan_count': nan_count,
  244. 'neg_inf_count': neg_inf_count,
  245. 'pos_inf_count': pos_inf_count})
  246. return statistics
  247. @staticmethod
  248. def get_statistics_dict(stats, overall_stats):
  249. """
  250. Get statistics dict according to statistics value.
  251. Args:
  252. stats (Statistics): An instance of Statistics for sliced tensor.
  253. overall_stats (Statistics): An instance of Statistics for whole tensor.
  254. Returns:
  255. dict, a dict including 'max', 'min', 'avg', 'count',
  256. 'nan_count', 'neg_inf_count', 'pos_inf_count', 'overall_max', 'overall_min'.
  257. """
  258. statistics = {
  259. "max": float(stats.max),
  260. "min": float(stats.min),
  261. "avg": float(stats.avg),
  262. "count": stats.count,
  263. "nan_count": stats.nan_count,
  264. "neg_inf_count": stats.neg_inf_count,
  265. "pos_inf_count": stats.pos_inf_count}
  266. overall_statistics = TensorUtils.get_overall_statistic_dict(overall_stats)
  267. statistics.update(overall_statistics)
  268. return statistics
  269. @staticmethod
  270. def get_overall_statistic_dict(overall_stats):
  271. """
  272. Get overall statistics dict according to statistics value.
  273. Args:
  274. overall_stats (Statistics): An instance of Statistics for whole tensor.
  275. Returns:
  276. dict, overall statistics.
  277. """
  278. if not overall_stats:
  279. return {}
  280. if overall_stats.is_bool:
  281. res = {
  282. 'overall_count': overall_stats.count,
  283. 'overall_true_count': overall_stats.true_count,
  284. 'overall_false_count': overall_stats.false_count
  285. }
  286. else:
  287. res = {
  288. "overall_max": float(overall_stats.max),
  289. "overall_min": float(overall_stats.min),
  290. "overall_avg": float(overall_stats.avg),
  291. "overall_count": overall_stats.count,
  292. "overall_nan_count": overall_stats.nan_count,
  293. "overall_neg_inf_count": overall_stats.neg_inf_count,
  294. "overall_pos_inf_count": overall_stats.pos_inf_count,
  295. "overall_zero_count": float(overall_stats.zero_count),
  296. "overall_neg_zero_count": float(overall_stats.neg_zero_count),
  297. "overall_pos_zero_count": float(overall_stats.pos_zero_count)
  298. }
  299. return res
  300. @staticmethod
  301. def calc_diff_between_two_tensor(first_tensor, second_tensor, tolerance):
  302. """
  303. Calculate the difference between the first tensor and the second tensor.
  304. Args:
  305. first_tensor (numpy.ndarray): Specify the first tensor.
  306. second_tensor (numpy.ndarray): Specify the second tensor.
  307. tolerance (float): The tolerance of difference between the first tensor and the second tensor.
  308. Its is a percentage. The boundary value is equal to max(abs(min),abs(max)) * tolerance.
  309. The function of min and max is being used to calculate the min value and max value of
  310. the result of the first tensor subtract the second tensor. If the absolute value of
  311. result is less than or equal to boundary value, the result will set to be zero.
  312. Returns:
  313. tuple[numpy.ndarray, OverallDiffMetric], numpy.ndarray indicates the value of the first tensor
  314. subtract the second tensor and set the value to be zero when its less than or equal to tolerance.
  315. Raises:
  316. ParamTypeError: If the type of these two tensors is not the numpy.ndarray.
  317. ParamValueError: If the shape or dtype is not the same of these two tensors or
  318. the tolerance should be between 0 and 1.
  319. """
  320. if not isinstance(first_tensor, np.ndarray):
  321. raise ParamTypeError('first_tensor', np.ndarray)
  322. if not isinstance(second_tensor, np.ndarray):
  323. raise ParamTypeError('second_tensor', np.ndarray)
  324. if first_tensor.shape != second_tensor.shape:
  325. raise ParamValueError("the shape: {} of first tensor is not equal to shape: {} of second tensor."
  326. .format(first_tensor.shape, second_tensor.shape))
  327. if first_tensor.dtype != second_tensor.dtype:
  328. raise ParamValueError("the dtype: {} of first tensor is not equal to dtype: {} of second tensor."
  329. .format(first_tensor.dtype, second_tensor.dtype))
  330. # Make sure tolerance is between 0 and 1.
  331. if tolerance < 0 or tolerance > 1:
  332. raise ParamValueError("the tolerance should be between 0 and 1, but got {}".format(tolerance))
  333. diff_tensor = np.subtract(first_tensor, second_tensor)
  334. stats = TensorUtils.get_statistics_from_tensor(diff_tensor)
  335. boundary_value = max(abs(stats.max), abs(stats.min)) * tolerance
  336. is_close = np.isclose(first_tensor, second_tensor, atol=boundary_value, rtol=0)
  337. result = np.multiply(diff_tensor, ~is_close)
  338. return result