You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

reservoir.py 8.4 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """A reservoir sampling on the values."""
  16. import random
  17. import threading
  18. from mindinsight.datavisual.common.log import logger
  19. from mindinsight.datavisual.common.enums import PluginNameEnum
  20. from mindinsight.utils.exceptions import ParamValueError
  21. from mindinsight.datavisual.utils.utils import calc_histogram_bins
  def binary_search(samples, target):
      """
      Locate `target` among the `step` values of `samples` by bisection.

      Args:
          samples (list): Samples exposing a `step` attribute. The bisection
              only gives a meaningful answer when they are ordered by
              ascending step.
          target (Any): Step value to look for; compared against `step` with
              `<` and `>`.

      Returns:
          int, the index of a sample whose step matches `target` when one is
          hit, otherwise the index at which `target` has to be inserted to
          keep the steps ordered.
      """
      low, high = 0, len(samples) - 1
      while low <= high:
          middle = (low + high) // 2
          step = samples[middle].step
          if target < step:
              high = middle - 1
              continue
          if target > step:
              low = middle + 1
              continue
          return middle
      # No match. The loop always leaves `low == high + 1`, which is the
      # insertion point: 0 when target is below the first step, and for steps
      # [1, 2, 4] with target 3 it is 2, i.e. right before the 4.
      return low
  class Reservoir:
      """
      A step-ordered sample container based on the Reservoir Sampling idea.

      A newly added sample is always stored. When the container is already
      full, one stored sample is evicted first: a pseudo-random index is drawn
      from [0, number of samples counted so far]; if it falls inside the
      capacity, the sample at that index is dropped, otherwise the last stored
      sample is dropped.
      """

      def __init__(self, size):
          """
          Create an empty reservoir.

          Args:
              size (int): Container size. If the size is 0, the container is
                  not limited.

          Raises:
              ParamValueError: If size is not an int or is negative.
          """
          size_is_valid = isinstance(size, int) and size >= 0
          if not size_is_valid:
              raise ParamValueError('size must be nonnegative integer, was %s' % size)

          self._samples_max_size = size
          self._samples = []
          # Counts the samples offered so far; rescaled by remove_sample().
          self._sample_counter = 0
          # Fixed seed: the eviction sequence is reproducible across runs.
          self._sample_selector = random.Random(0)
          self._mutex = threading.Lock()

      def samples(self):
          """Return a copy of the list of stored samples."""
          with self._mutex:
              return self._samples[:]

      def add_sample(self, sample):
          """
          Store a sample, evicting an old one first when the reservoir is full.

          The new sample is always kept.

          Args:
              sample (Any): The sample to add; it must expose a `step`
                  attribute, which determines its position.
          """
          with self._mutex:
              capacity = self._samples_max_size
              # A capacity of 0 means "unlimited", so such a reservoir is never full.
              is_full = capacity != 0 and len(self._samples) >= capacity
              if is_full:
                  victim = self._sample_selector.randint(0, self._sample_counter)
                  if victim < capacity:
                      self._samples.pop(victim)
                  else:
                      # Draw fell outside the reservoir: drop the last stored sample.
                      self._samples.pop()
              self._add_sample(sample)
              self._sample_counter += 1

      def _add_sample(self, sample):
          """Insert the sample at the position that keeps steps ascending."""
          stored = self._samples
          # Fast path: a step beyond the current last one simply goes to the end.
          goes_last = not stored or sample.step > stored[-1].step
          position = len(stored) if goes_last else binary_search(stored, sample.step)
          # Inserting at len(stored) is the same as appending.
          stored.insert(position, sample)

      def remove_sample(self, filter_fun):
          """
          Keep only the samples accepted by `filter_fun` and drop the rest.

          Args:
              filter_fun (Callable[..., Any]): Predicate handed to the builtin
                  `filter`; samples for which it yields a true value are kept,
                  all others are removed.

          Returns:
              int, the number of samples removed.
          """
          with self._mutex:
              count_before = len(self._samples)
              if count_before == 0:
                  return 0

              self._samples = list(filter(filter_fun, self._samples))
              count_after = len(self._samples)
              removed = count_before - count_after
              if removed > 0:
                  # Shrink the counter by the same proportion as the stored samples.
                  remaining_rate = float(count_after) / count_before
                  self._sample_counter = int(round(self._sample_counter * remaining_rate))
              return removed
  class _VisualRange:
      """Helper that merges several (max, min) ranges into one enclosing range."""

      def __init__(self):
          # Until the first update() the range reports 0.0 for both bounds.
          self._updated = False
          self._min = 0.0
          self._max = 0.0

      def update(self, max_val: float, min_val: float) -> None:
          """
          Widen the current range so that it also covers the given range.

          The first call adopts the given bounds as they are; later calls only
          ever raise the maximum or lower the minimum.

          Args:
              max_val (float): Max value of given range.
              min_val (float): Min value of given range.
          """
          if self._updated:
              # Builtins here; the `max`/`min` properties below live in class
              # scope and are not visible from inside a method body.
              self._max = max(self._max, max_val)
              self._min = min(self._min, min_val)
          else:
              self._max, self._min = max_val, min_val
              self._updated = True

      @property
      def max(self):
          """Upper bound of the merged range."""
          return self._max

      @property
      def min(self):
          """Lower bound of the merged range."""
          return self._min
  class HistogramReservoir(Reservoir):
      """
      Reservoir for histogram, which needs updating range over all steps.

      The visual range (merged max/min and bin count) is recomputed lazily in
      `samples()` and cached until the set of stored samples changes.

      Args:
          size (int): Container Size. If the size is 0, the container is not limited.
      """

      def __init__(self, size):
          super().__init__(size)
          # Marker to avoid redundant calc for unchanged histograms.
          self._visual_range_up_to_date = False

      def add_sample(self, sample):
          """Adds sample, see parent class for details."""
          super().add_sample(sample)
          self._visual_range_up_to_date = False

      def remove_sample(self, filter_fun):
          """
          Removes samples, see parent class for details.

          Removing samples changes the set the visual range is computed over,
          so the cached range is invalidated whenever something was removed.

          Args:
              filter_fun (Callable[..., Any]): Predicate forwarded to the parent
                  class implementation.

          Returns:
              int, the number of samples removed.
          """
          removed = super().remove_sample(filter_fun)
          if removed:
              self._visual_range_up_to_date = False
          return removed

      def samples(self):
          """Return all stored samples, refreshing their visual range if it is stale."""
          with self._mutex:
              if not self._visual_range_up_to_date:
                  self._update_visual_range()
                  self._visual_range_up_to_date = True
              return list(self._samples)

      def _update_visual_range(self):
          """
          Merge the ranges of all stored samples and hand the result to each histogram.

          Must be called with `self._mutex` held.
          """
          # calc visual range
          visual_range = _VisualRange()
          max_count = 0
          for sample in self._samples:
              histogram_container = sample.value
              if histogram_container.count == 0:
                  # ignore empty tensor
                  continue
              max_count = max(histogram_container.count, max_count)
              visual_range.update(histogram_container.max, histogram_container.min)

          if visual_range.max == visual_range.min and not max_count:
              logger.debug("Max equals to min. Count is zero.")

          bins = calc_histogram_bins(max_count)

          # update visual range
          logger.debug(
              "Visual histogram: min %s, max %s, bins %s, max_count %s.",
              visual_range.min,
              visual_range.max,
              bins,
              max_count)
          for sample in self._samples:
              histogram = sample.value.histogram
              histogram.set_visual_range(visual_range.max, visual_range.min, bins)
  class ReservoirFactory:
      """Factory that picks the reservoir class matching a plugin."""

      def create_reservoir(self, plugin_name: str, size: int) -> Reservoir:
          """
          Build a reservoir suited to the given plugin.

          Args:
              plugin_name (str): Plugin name.
              size (int): Container Size. If the size is 0, the container is not limited.

          Returns:
              Reservoir, a HistogramReservoir for the histogram and tensor
              plugins, a plain Reservoir for every other plugin name.
          """
          range_tracking_plugins = (
              PluginNameEnum.HISTOGRAM.value,
              PluginNameEnum.TENSOR.value,
          )
          reservoir_cls = HistogramReservoir if plugin_name in range_tracking_plugins else Reservoir
          return reservoir_cls(size)