You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

histogram_container.py 9.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Histogram data container."""
  16. import math
  17. from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Summary
  18. from mindinsight.utils.exceptions import ParamValueError
  19. from mindinsight.datavisual.utils.utils import calc_histogram_bins
  20. def _mask_invalid_number(num):
  21. """Mask invalid number to 0."""
  22. if math.isnan(num) or math.isinf(num):
  23. return type(num)(0)
  24. return num
  25. class Bucket:
  26. """
  27. Bucket data class.
  28. Args:
  29. left (double): Left edge of the histogram bucket.
  30. width (double): Width of the histogram bucket.
  31. count (int): Count of numbers fallen in the histogram bucket.
  32. """
  33. def __init__(self, left, width, count):
  34. self._left = left
  35. self._width = width
  36. self._count = count
  37. @property
  38. def left(self):
  39. """Gets left edge of the histogram bucket."""
  40. return self._left
  41. @property
  42. def count(self):
  43. """Gets count of numbers fallen in the histogram bucket."""
  44. return self._count
  45. @property
  46. def width(self):
  47. """Gets width of the histogram bucket."""
  48. return self._width
  49. @property
  50. def right(self):
  51. """Gets right edge of the histogram bucket."""
  52. return self._left + self._width
  53. def as_tuple(self):
  54. """Gets the bucket as tuple."""
  55. return self._left, self._width, self._count
  56. def __repr__(self):
  57. """Returns repr(self)."""
  58. return "Bucket(left={}, width={}, count={})".format(self._left, self._width, self._count)
  59. class HistogramContainer:
  60. """
  61. Histogram data container.
  62. Args:
  63. histogram_message (Summary.Histogram): Histogram message in summary file.
  64. """
  65. # Max quantity of original buckets.
  66. MAX_ORIGINAL_BUCKETS_COUNT = 90
  67. def __init__(self, histogram_message: Summary.Histogram):
  68. self._msg = histogram_message
  69. original_buckets = [Bucket(bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets]
  70. # Ensure buckets are sorted from min to max.
  71. original_buckets.sort(key=lambda bucket: bucket.left)
  72. self._original_buckets = tuple(original_buckets)
  73. self._count = sum(bucket.count for bucket in self._original_buckets)
  74. self._max = _mask_invalid_number(histogram_message.max)
  75. self._min = _mask_invalid_number(histogram_message.min)
  76. self._visual_max = self._max
  77. self._visual_min = self._min
  78. # default bin number
  79. self._visual_bins = calc_histogram_bins(self._count)
  80. # Note that tuple is immutable, so sharing tuple is often safe.
  81. self._re_sampled_buckets = ()
  82. @property
  83. def max(self):
  84. """Gets max value of the tensor."""
  85. return self._max
  86. @property
  87. def min(self):
  88. """Gets min value of the tensor."""
  89. return self._min
  90. @property
  91. def count(self):
  92. """Gets valid number count of the tensor."""
  93. return self._count
  94. @property
  95. def original_msg(self):
  96. """Gets original proto message."""
  97. return self._msg
  98. @property
  99. def original_buckets_count(self):
  100. """Gets original buckets quantity."""
  101. return len(self._original_buckets)
  102. def set_visual_range(self, max_val: float, min_val: float, bins: int) -> None:
  103. """
  104. Sets visual range for later re-sampling.
  105. It's caller's duty to ensure input is valid.
  106. Why we need visual range for histograms? Aligned buckets between steps can help users know about the trend of
  107. tensors. Miss aligned buckets between steps might miss-lead users about the trend of a tensor. Because for
  108. given tensor, if you have thinner buckets, count of every bucket will get lower, however, if you have
  109. thicker buckets, count of every bucket will get higher. When they are displayed together, user might think
  110. the histogram with thicker buckets has more values. This is miss-leading. So we need to unify buckets across
  111. steps. Visual range for histogram is a technology for unifying buckets.
  112. Args:
  113. max_val (float): Max value for visual histogram.
  114. min_val (float): Min value for visual histogram.
  115. bins (int): Bins number for visual histogram.
  116. """
  117. if max_val < min_val:
  118. raise ParamValueError(
  119. "Invalid input. max_val({}) is less or equal than min_val({}).".format(max_val, min_val))
  120. if bins < 1:
  121. raise ParamValueError("Invalid input bins({}). Must be greater than 0.".format(bins))
  122. self._visual_max = max_val
  123. self._visual_min = min_val
  124. self._visual_bins = bins
  125. # mark _re_sampled_buckets to empty
  126. self._re_sampled_buckets = ()
  127. def _calc_intersection_len(self, max1, min1, max2, min2):
  128. """Calculates intersection length of [min1, max1] and [min2, max2]."""
  129. if max1 < min1:
  130. raise ParamValueError(
  131. "Invalid input. max1({}) is less than min1({}).".format(max1, min1))
  132. if max2 < min2:
  133. raise ParamValueError(
  134. "Invalid input. max2({}) is less than min2({}).".format(max2, min2))
  135. if min1 <= min2:
  136. if max1 <= min2:
  137. # return value must be calculated by max1.__sub__
  138. return max1 - max1
  139. if max1 <= max2:
  140. return max1 - min2
  141. # max1 > max2
  142. return max2 - min2
  143. # min1 > min2
  144. if max2 <= min1:
  145. return max2 - max2
  146. if max2 <= max1:
  147. return max2 - min1
  148. return max1 - min1
  149. def _re_sample_buckets(self):
  150. """Re-samples buckets according to visual_max, visual_min and visual_bins."""
  151. if self._visual_max == self._visual_min:
  152. # Adjust visual range if max equals min.
  153. self._visual_max += 0.5
  154. self._visual_min -= 0.5
  155. width = (self._visual_max - self._visual_min) / self._visual_bins
  156. if not self.count:
  157. self._re_sampled_buckets = tuple(
  158. Bucket(self._visual_min + width * i, width, 0)
  159. for i in range(self._visual_bins))
  160. return
  161. re_sampled = []
  162. original_pos = 0
  163. original_bucket = self._original_buckets[original_pos]
  164. for i in range(self._visual_bins):
  165. cur_left = self._visual_min + width * i
  166. cur_right = cur_left + width
  167. cur_estimated_count = 0.0
  168. # Skip no bucket range.
  169. if cur_right <= original_bucket.left:
  170. re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count)))
  171. continue
  172. # Skip no intersect range.
  173. while cur_left >= original_bucket.right:
  174. original_pos += 1
  175. if original_pos >= len(self._original_buckets):
  176. break
  177. original_bucket = self._original_buckets[original_pos]
  178. # entering with this condition: cur_right > original_bucket.left and cur_left < original_bucket.right
  179. while True:
  180. if original_pos >= len(self._original_buckets):
  181. break
  182. original_bucket = self._original_buckets[original_pos]
  183. intersection = self._calc_intersection_len(
  184. min1=cur_left, max1=cur_right,
  185. min2=original_bucket.left, max2=original_bucket.right)
  186. estimated_count = (intersection / original_bucket.width) * original_bucket.count
  187. cur_estimated_count += estimated_count
  188. if cur_right > original_bucket.right:
  189. # Need to sample next original bucket to this visual bucket.
  190. original_pos += 1
  191. else:
  192. # Current visual bucket has taken all intersect buckets into account.
  193. break
  194. re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count)))
  195. self._re_sampled_buckets = tuple(re_sampled)
  196. def buckets(self, convert_to_tuple=True):
  197. """
  198. Get visual buckets instead of original buckets.
  199. Args:
  200. convert_to_tuple (bool): Whether convert bucket object to tuple.
  201. Returns:
  202. tuple, contains buckets.
  203. """
  204. if not self._re_sampled_buckets:
  205. self._re_sample_buckets()
  206. if not convert_to_tuple:
  207. return self._re_sampled_buckets
  208. return tuple(bucket.as_tuple() for bucket in self._re_sampled_buckets)

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中，可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中，通过MindInsight可视化页面进行查看及分析。