You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

events_data.py 8.8 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Takes a generator of values, and collects them for a frontend."""
  16. import collections
  17. import threading
  18. from mindinsight.conf import settings
  19. from mindinsight.datavisual.common.enums import PluginNameEnum
  20. from mindinsight.datavisual.common.log import logger
  21. from mindinsight.datavisual.data_transform import reservoir
  22. # Type of the tensor event from external component
  23. _Tensor = collections.namedtuple('_Tensor', ['wall_time', 'step', 'value', 'filename'])
  24. TensorEvent = collections.namedtuple(
  25. 'TensorEvent', ['wall_time', 'step', 'tag', 'plugin_name', 'value', 'filename'])
  26. # config for `EventsData`
  27. _DEFAULT_STEP_SIZES_PER_TAG = settings.DEFAULT_STEP_SIZES_PER_TAG
  28. _MAX_DELETED_TAGS_SIZE = settings.MAX_TAG_SIZE_PER_EVENTS_DATA * 100
  29. CONFIG = {
  30. 'max_total_tag_sizes': settings.MAX_TAG_SIZE_PER_EVENTS_DATA,
  31. 'max_tag_sizes_per_plugin':
  32. {
  33. PluginNameEnum.GRAPH.value: settings.MAX_GRAPH_TAG_SIZE,
  34. PluginNameEnum.TENSOR.value: settings.MAX_TENSOR_TAG_SIZE
  35. },
  36. 'max_step_sizes_per_tag':
  37. {
  38. PluginNameEnum.SCALAR.value: settings.MAX_SCALAR_STEP_SIZE_PER_TAG,
  39. PluginNameEnum.IMAGE.value: settings.MAX_IMAGE_STEP_SIZE_PER_TAG,
  40. PluginNameEnum.GRAPH.value: settings.MAX_GRAPH_STEP_SIZE_PER_TAG,
  41. PluginNameEnum.HISTOGRAM.value: settings.MAX_HISTOGRAM_STEP_SIZE_PER_TAG,
  42. PluginNameEnum.TENSOR.value: settings.MAX_TENSOR_STEP_SIZE_PER_TAG
  43. }
  44. }
  45. class EventsData:
  46. """
  47. EventsData is an event data manager.
  48. It manages the log events generated during a training process.
  49. The log event records information such as graph, tag, and tensor.
  50. Data such as tensor can be retrieved based on its tag.
  51. """
  52. def __init__(self):
  53. self._config = CONFIG
  54. self._max_step_sizes_per_tag = self._config['max_step_sizes_per_tag']
  55. self._tags = list()
  56. self._deleted_tags = set()
  57. self._reservoir_by_tag = {}
  58. self._reservoir_mutex_lock = threading.Lock()
  59. self._tags_by_plugin = collections.defaultdict(list)
  60. self._tags_by_plugin_mutex_lock = collections.defaultdict(threading.Lock)
  61. def add_tensor_event(self, tensor_event):
  62. """
  63. Add a new tensor event to the tensors_data.
  64. Args:
  65. tensor_event (TensorEvent): Refer to `TensorEvent` object.
  66. """
  67. if not isinstance(tensor_event, TensorEvent):
  68. raise TypeError('Expect to get data of type `TensorEvent`.')
  69. tag = tensor_event.tag
  70. plugin_name = tensor_event.plugin_name
  71. if tag not in set(self._tags):
  72. deleted_tag = self._check_tag_out_of_spec(plugin_name)
  73. if deleted_tag is not None:
  74. if tag in self._deleted_tags:
  75. logger.debug("Tag is in deleted tags: %s.", tag)
  76. return
  77. self.delete_tensor_event(deleted_tag)
  78. self._tags.append(tag)
  79. with self._tags_by_plugin_mutex_lock[plugin_name]:
  80. if tag not in self._tags_by_plugin[plugin_name]:
  81. self._tags_by_plugin[plugin_name].append(tag)
  82. with self._reservoir_mutex_lock:
  83. if tag not in self._reservoir_by_tag:
  84. reservoir_size = self._get_reservoir_size(tensor_event.plugin_name)
  85. self._reservoir_by_tag[tag] = reservoir.ReservoirFactory().create_reservoir(
  86. plugin_name, reservoir_size
  87. )
  88. tensor = _Tensor(wall_time=tensor_event.wall_time,
  89. step=tensor_event.step,
  90. value=tensor_event.value,
  91. filename=tensor_event.filename)
  92. if self._is_out_of_order_step(tensor_event.step, tensor_event.tag):
  93. self.purge_reservoir_data(tensor_event.filename, tensor_event.step, self._reservoir_by_tag[tag])
  94. self._reservoir_by_tag[tag].add_sample(tensor)
  95. def delete_tensor_event(self, tag):
  96. """
  97. This function will delete tensor event by the given tag in memory record.
  98. Args:
  99. tag (str): The tag name.
  100. """
  101. if len(self._deleted_tags) < _MAX_DELETED_TAGS_SIZE:
  102. self._deleted_tags.add(tag)
  103. else:
  104. logger.warning(
  105. 'Too many deleted tags, %d upper limit reached, tags updating may not function hereafter',
  106. _MAX_DELETED_TAGS_SIZE)
  107. logger.info('%r and all related samples are going to be deleted', tag)
  108. self._tags.remove(tag)
  109. for plugin_name, lock in self._tags_by_plugin_mutex_lock.items():
  110. with lock:
  111. if tag in self._tags_by_plugin[plugin_name]:
  112. self._tags_by_plugin[plugin_name].remove(tag)
  113. break
  114. with self._reservoir_mutex_lock:
  115. if tag in self._reservoir_by_tag:
  116. self._reservoir_by_tag.pop(tag)
  117. def list_tags_by_plugin(self, plugin_name):
  118. """
  119. Return all the tag names of the plugin.
  120. Args:
  121. plugin_name (str): The Plugin name.
  122. Returns:
  123. list[str], tags of the plugin.
  124. Raises:
  125. KeyError: when plugin name could not be found.
  126. """
  127. if plugin_name not in self._tags_by_plugin:
  128. raise KeyError('Plugin %r could not be found.' % plugin_name)
  129. with self._tags_by_plugin_mutex_lock[plugin_name]:
  130. # Return a snapshot to avoid concurrent mutation and iteration issues.
  131. return sorted(list(self._tags_by_plugin[plugin_name]))
  132. def tensors(self, tag):
  133. """
  134. Return all tensors of the tag.
  135. Args:
  136. tag (str): The tag name.
  137. Returns:
  138. list[_Tensor], the list of tensors to the tag.
  139. """
  140. if tag not in self._reservoir_by_tag:
  141. raise KeyError('TAG %r could not be found.' % tag)
  142. return self._reservoir_by_tag[tag].samples()
  143. def _is_out_of_order_step(self, step, tag):
  144. """
  145. If the current step is smaller than the latest one, it is out-of-order step.
  146. Args:
  147. step (int): Check if the given step out of order.
  148. tag (str): The checked tensor of the given tag.
  149. Returns:
  150. bool, boolean value.
  151. """
  152. if self.tensors(tag):
  153. tensors = self.tensors(tag)
  154. last_step = tensors[-1].step
  155. if step <= last_step:
  156. return True
  157. return False
  158. @staticmethod
  159. def purge_reservoir_data(filename, start_step, tensor_reservoir):
  160. """
  161. Purge all tensor event that are out-of-order step after the given start step.
  162. Args:
  163. start_step (int): Urge start step. All previously seen events with
  164. a greater or equal to step will be purged.
  165. tensor_reservoir (Reservoir): A `Reservoir` object.
  166. Returns:
  167. int, the number of items removed.
  168. """
  169. cnt_out_of_order = tensor_reservoir.remove_sample(
  170. lambda x: x.step < start_step or (x.step > start_step and x.filename == filename))
  171. return cnt_out_of_order
  172. def _get_reservoir_size(self, plugin_name):
  173. max_step_sizes_per_tag = self._config['max_step_sizes_per_tag']
  174. return max_step_sizes_per_tag.get(plugin_name, _DEFAULT_STEP_SIZES_PER_TAG)
  175. def _check_tag_out_of_spec(self, plugin_name):
  176. """
  177. Check whether the tag is out of specification.
  178. Args:
  179. plugin_name (str): The given plugin name.
  180. Returns:
  181. Union[str, None], if out of specification, will return the first tag, else return None.
  182. """
  183. tag_specifications = self._config['max_tag_sizes_per_plugin'].get(plugin_name)
  184. if tag_specifications is not None and len(self._tags_by_plugin[plugin_name]) >= tag_specifications:
  185. deleted_tag = self._tags_by_plugin[plugin_name][0]
  186. return deleted_tag
  187. if len(self._tags) >= self._config['max_total_tag_sizes']:
  188. deleted_tag = self._tags[0]
  189. return deleted_tag
  190. return None