You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

reservoir.py 7.1 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """A reservoir sampling on the values."""
  16. import random
  17. import threading
  18. from mindinsight.datavisual.common.log import logger
  19. from mindinsight.datavisual.common.enums import PluginNameEnum
  20. from mindinsight.utils.exceptions import ParamValueError
  21. from mindinsight.datavisual.utils.utils import calc_histogram_bins
  22. class Reservoir:
  23. """
  24. A container based on Reservoir Sampling algorithm.
  25. The newly added sample will be preserved. If the container is full, an old
  26. sample will be replaced randomly. The probability of each sample being
  27. replaced is the same.
  28. """
  29. def __init__(self, size):
  30. """
  31. A Container constructor which create a new Reservoir.
  32. Args:
  33. size (int): Container Size. If the size is 0, the container is not limited.
  34. Raises:
  35. ValueError: If size is negative integer.
  36. """
  37. if not isinstance(size, (int,)) or size < 0:
  38. raise ParamValueError('size must be nonnegative integer, was %s' % size)
  39. self._samples_max_size = size
  40. self._samples = []
  41. self._sample_counter = 0
  42. self._sample_selector = random.Random(0)
  43. self._mutex = threading.Lock()
  44. def samples(self):
  45. """Return all stored samples."""
  46. with self._mutex:
  47. return list(self._samples)
  48. def add_sample(self, sample):
  49. """
  50. Add a sample to Reservoir.
  51. Replace the old sample when the capacity is full.
  52. New added samples are guaranteed to be added to the reservoir.
  53. Args:
  54. sample (Any): The sample to add to the Reservoir.
  55. """
  56. with self._mutex:
  57. if len(self._samples) < self._samples_max_size or self._samples_max_size == 0:
  58. self._samples.append(sample)
  59. else:
  60. # Use the Reservoir Sampling algorithm to replace the old sample.
  61. rand_int = self._sample_selector.randint(
  62. 0, self._sample_counter)
  63. if rand_int < self._samples_max_size:
  64. self._samples.pop(rand_int)
  65. self._samples.append(sample)
  66. else:
  67. self._samples[-1] = sample
  68. self._sample_counter += 1
  69. def remove_sample(self, filter_fun):
  70. """
  71. Remove the samples from Reservoir that do not meet the filter criteria.
  72. Args:
  73. filter_fun (Callable[..., Any]): Determines whether a sample meets
  74. the deletion condition.
  75. Returns:
  76. int, the number of samples removed.
  77. """
  78. remove_size = 0
  79. with self._mutex:
  80. before_remove_size = len(self._samples)
  81. if before_remove_size > 0:
  82. # remove samples that meet the filter criteria.
  83. self._samples = list(filter(filter_fun, self._samples))
  84. after_remove_size = len(self._samples)
  85. remove_size = before_remove_size - after_remove_size
  86. if remove_size > 0:
  87. # update _sample_counter when samples has been removed.
  88. sample_remaining_rate = float(
  89. after_remove_size) / before_remove_size
  90. self._sample_counter = int(
  91. round(self._sample_counter * sample_remaining_rate))
  92. return remove_size
  93. class _VisualRange:
  94. """Simple helper class to merge visual ranges."""
  95. def __init__(self):
  96. self._max = 0.0
  97. self._min = 0.0
  98. self._updated = False
  99. def update(self, max_val: float, min_val: float) -> None:
  100. """
  101. Merge visual range with given range.
  102. Args:
  103. max_val (float): Max value of given range.
  104. min_val (float): Min value of given range.
  105. """
  106. if not self._updated:
  107. self._max = max_val
  108. self._min = min_val
  109. self._updated = True
  110. return
  111. if max_val > self._max:
  112. self._max = max_val
  113. if min_val < self._min:
  114. self._min = min_val
  115. @property
  116. def max(self):
  117. """Gets max value of current range."""
  118. return self._max
  119. @property
  120. def min(self):
  121. """Gets min value of current range."""
  122. return self._min
  123. class HistogramReservoir(Reservoir):
  124. """
  125. Reservoir for histogram, which needs updating range over all steps.
  126. Args:
  127. size (int): Container Size. If the size is 0, the container is not limited.
  128. """
  129. def samples(self):
  130. """Return all stored samples."""
  131. with self._mutex:
  132. # calc visual range
  133. visual_range = _VisualRange()
  134. max_count = 0
  135. for sample in self._samples:
  136. histogram = sample.value
  137. if histogram.count == 0:
  138. # ignore empty tensor
  139. continue
  140. max_count = max(histogram.count, max_count)
  141. visual_range.update(histogram.max, histogram.min)
  142. if visual_range.max == visual_range.min and not max_count:
  143. logger.warning("Max equals to min, however, count is zero. Please check mindspore "
  144. "does write max and min values to histogram summary file.")
  145. bins = calc_histogram_bins(max_count)
  146. # update visual range
  147. logger.info("Visual histogram: min %s, max %s, bins %s, max_count %s.",
  148. visual_range.min,
  149. visual_range.max,
  150. bins,
  151. max_count)
  152. for sample in self._samples:
  153. histogram = sample.value
  154. histogram.set_visual_range(visual_range.max, visual_range.min, bins)
  155. return list(self._samples)
  156. class ReservoirFactory:
  157. """Factory class to get reservoir instances."""
  158. def create_reservoir(self, plugin_name: str, size: int) -> Reservoir:
  159. """
  160. Creates reservoir for given plugin name.
  161. Args:
  162. plugin_name (str): Plugin name
  163. size (int): Container Size. If the size is 0, the container is not limited.
  164. Returns:
  165. Reservoir, reservoir instance for given plugin name.
  166. """
  167. if plugin_name == PluginNameEnum.HISTOGRAM.value:
  168. return HistogramReservoir(size)
  169. return Reservoir(size)

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中,可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中,通过MindInsight可视化页面进行查看及分析。