You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_handler.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the tensor stream handler."""
  16. import numpy as np
  17. from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum
  18. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
  19. from mindinsight.debugger.common.log import logger as log
  20. from mindinsight.debugger.proto.ms_graph_pb2 import DataType
  21. from mindinsight.debugger.stream_cache.tensor import OpTensor, ConstTensor
  22. from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
  23. from mindinsight.utils.tensor import TensorUtils
  24. class TensorHandler(StreamHandlerBase):
  25. """Metadata Handler."""
  26. def __init__(self):
  27. self._const_vals = {}
  28. self._tensors = {}
  29. self._cur_step = 0
  30. def put(self, value):
  31. """
  32. Put value into tensor cache. Called by grpc server.
  33. Args:
  34. value (dict): The Tensor proto message.
  35. - step (int): The current step of tensor.
  36. - tensor_protos (list[TensorProto]): The tensor proto.
  37. """
  38. tensor_protos = value.get('tensor_protos')
  39. merged_tensor = self._get_merged_tensor(tensor_protos)
  40. step = value.get('step', 0)
  41. if merged_tensor.iter and step > 0:
  42. log.debug("Received previous tensor.")
  43. step -= 1
  44. tensor = OpTensor(merged_tensor, step)
  45. self._put_tensor_into_cache(tensor, step)
  46. log.info("Put tensor %s of step: %d, into cache", tensor.name, step)
  47. @staticmethod
  48. def _get_merged_tensor(tensor_protos):
  49. """
  50. Merged list of parsed tensor value into one.
  51. Args:
  52. tensor_protos (list[TensorProto]): List of tensor proto.
  53. Returns:
  54. TensorProto, merged tensor proto.
  55. """
  56. merged_tensor = tensor_protos[-1]
  57. if len(tensor_protos) > 1:
  58. tensor_value = bytes()
  59. for tensor_proto in tensor_protos:
  60. if not tensor_proto.tensor_content:
  61. log.warning("Doesn't find tensor value for %s:%s",
  62. tensor_proto.node_name, tensor_proto.slot)
  63. break
  64. tensor_value += tensor_proto.tensor_content
  65. merged_tensor.tensor_content = tensor_value
  66. log.debug("Merge multi tensor values into one.")
  67. return merged_tensor
  68. def _put_tensor_into_cache(self, tensor, step):
  69. """
  70. Put tensor into cache.
  71. Args:
  72. tensor (OpTensor): The tensor value.
  73. """
  74. cache_tensor = self._tensors.get(tensor.name)
  75. if cache_tensor is None:
  76. cache_tensor = {}
  77. self._tensors[tensor.name] = cache_tensor
  78. cache_tensor[step] = tensor
  79. def put_const_vals(self, const_vals):
  80. """
  81. Put const value into tensor cache.
  82. Args:
  83. const_vals (list[NamedValueProto]): List of const values.
  84. """
  85. for const_val in const_vals:
  86. if not (const_val.value and const_val.key):
  87. continue
  88. if DataType.Name(const_val.value.dtype) == "DT_TENSOR":
  89. tensor_proto = const_val.value.tensor_val
  90. tensor_proto.node_name = const_val.key
  91. tensor_proto.slot = '0'
  92. const_tensor = OpTensor(tensor_proto)
  93. else:
  94. const_tensor = ConstTensor(const_val)
  95. self._const_vals[const_tensor.name] = const_tensor
  96. def get(self, filter_condition=None):
  97. """
  98. Get full tensor value.
  99. Args:
  100. filter_condition (dict): Filter condition.
  101. - name (str): The name of tensor.
  102. - node_type (str): The type of the node.
  103. Returns:
  104. dict, the tensor_value.
  105. """
  106. name = filter_condition.get('name')
  107. node_type = filter_condition.get('node_type')
  108. shape = filter_condition.get('shape')
  109. tensor = self._get_tensor(name, node_type)
  110. if not tensor:
  111. log.error("No tensor named %s", name)
  112. raise DebuggerParamValueError("No tensor named {}".format(name))
  113. tensor_info = tensor.get_full_info(shape)
  114. self._update_has_prev_step_field(tensor_info, name, node_type)
  115. return {'tensor_value': tensor_info}
  116. def _get_tensor(self, tensor_name, node_type=None, step=None):
  117. """
  118. Get tensor according to tensor name and node_type.
  119. Args:
  120. tensor_name (str): Tensor name, format like `node_name:slot`.
  121. node_type (str): Node type.
  122. step (int): The step of tensor info. Default: None. Noe
  123. Returns:
  124. Union[OPTensor, ConstTensor], the tensor object.
  125. """
  126. if step is None:
  127. step = self._cur_step
  128. tensor = self._tensors.get(tensor_name, {}).get(step)
  129. if not tensor and node_type == NodeTypeEnum.CONST.value:
  130. const_name = tensor_name.rsplit('/', 1)[-1]
  131. tensor = self._const_vals.get(const_name)
  132. self._tensors[tensor_name] = {step: tensor}
  133. return tensor
  134. def _get_basic_info(self, tensor_name, node_type=None):
  135. """Get the latest basic tensor info by tensor name."""
  136. tensor = self._get_tensor(tensor_name, node_type)
  137. if tensor:
  138. return tensor.get_basic_info()
  139. return None
  140. def update_tensor_history(self, tensor_history):
  141. """
  142. Add tensor basic info in tensor_history.
  143. Args:
  144. tensor_history (dict): Tensor history, including a list of tensor name and type.
  145. Returns:
  146. list[dict], the list of tensor basic info cache.
  147. """
  148. missed_tensors = []
  149. for tensor_info in tensor_history.get('tensor_history'):
  150. tensor_name = tensor_info.get('full_name')
  151. node_type = tensor_info.get('node_type')
  152. basic_info = self._get_basic_info(tensor_name, node_type)
  153. flag = self._update_has_prev_step_field(basic_info, tensor_name, node_type)
  154. if flag is False:
  155. missed_tensor = tensor_info.copy()
  156. missed_tensor['iter'] = 'prev'
  157. missed_tensors.append(missed_tensor)
  158. log.debug("Add previous view cmd for %s", tensor_name)
  159. # add `has_prev_step` field to tensor basic info.
  160. if basic_info:
  161. tensor_info.update(basic_info)
  162. if not basic_info.get('value'):
  163. missed_tensors.append(tensor_info)
  164. log.debug("Add view cmd for %s", tensor_name)
  165. else:
  166. missed_tensors.append(tensor_info)
  167. log.debug("Add view cmd for %s", tensor_name)
  168. return missed_tensors
  169. def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type):
  170. """Update has_prev_step field in tensor info."""
  171. flag = None
  172. if node_type == NodeTypeEnum.PARAMETER.value:
  173. flag = self._get_prev_tensor_value_status(tensor_name)
  174. if flag and tensor_info:
  175. tensor_info['has_prev_step'] = True
  176. return flag
  177. def _get_prev_tensor_value_status(self, tensor_name):
  178. """
  179. Get the status of tensor value of previous step.
  180. Args:
  181. tensor_name (str): Tensor name.
  182. Returns:
  183. Union[None, bool], the status of previous tensor value. If True, there is valid previous
  184. tensor value. If False, the tensor value should be queried from client.
  185. If None, ignore.
  186. """
  187. flag = None
  188. # check if the tensor has previous step value.
  189. prev_step = self._cur_step - 1
  190. if prev_step < 0:
  191. return flag
  192. tensor = self._get_tensor(tensor_name, step=prev_step)
  193. return bool(tensor and tensor.value)
  194. def get_tensor_value_by_name(self, tensor_name, prev=False):
  195. """Get tensor value by name in numpy type."""
  196. cur_step = self._cur_step
  197. step = cur_step - 1 if prev else cur_step
  198. if step < 0:
  199. log.warning("%d step has no previous value for tensor: %s", cur_step, tensor_name)
  200. return None
  201. tensor = self._get_tensor(tensor_name, step=step)
  202. return tensor
  203. def clean_tensors(self, cur_step):
  204. """Clean the tensor cache."""
  205. self._cur_step = cur_step
  206. expired_tensor = []
  207. for tensor_name, tensor in self._tensors.items():
  208. expired_step = [step for step in tensor.keys() if step <= cur_step - 2]
  209. for step in expired_step:
  210. tensor.pop(step)
  211. if not tensor:
  212. expired_tensor.append(tensor_name)
  213. for tensor_name in expired_tensor:
  214. self._tensors.pop(tensor_name)
  215. self._tensors = {}
  216. def get_tensors_diff(self, tensor_name, shape, tolerance=0):
  217. """
  218. Get tensor comparisons data for given name, detail, shape and tolerance.
  219. Args:
  220. tensor_name (str): The name of tensor for cache.
  221. shape (tuple): Specify concrete dimensions of shape.
  222. tolerance (str): Specify tolerance of difference between current step tensor and previous
  223. step tensor. Default value is 0. Its is a percentage. The boundary value is equal to
  224. max(abs(min),abs(max)) * tolerance. The function of min and max is being used to
  225. calculate the min value and max value of the result of the current step tensor subtract
  226. the previous step tensor. If the absolute value of result is less than or equal to
  227. boundary value, the result will set to be zero.
  228. Raises:
  229. DebuggerParamValueError, If get current step node and previous step node failed.
  230. Returns:
  231. dict, the retrieved data.
  232. """
  233. curr_tensor = self.get_tensor_value_by_name(tensor_name)
  234. prev_tensor = self.get_tensor_value_by_name(tensor_name, prev=True)
  235. if not (curr_tensor and prev_tensor):
  236. log.error("Get current step and previous step for this tensor name %s failed.", tensor_name)
  237. raise DebuggerParamValueError(f"Get current step and previous step for this tensor name "
  238. f"{tensor_name} failed.")
  239. curr_tensor_slice = curr_tensor.get_tensor_value_by_shape(shape)
  240. prev_tensor_slice = prev_tensor.get_tensor_value_by_shape(shape)
  241. tensor_info = curr_tensor.get_basic_info()
  242. if isinstance(tensor_info, dict):
  243. del tensor_info['has_prev_step']
  244. del tensor_info['value']
  245. # the type of curr_tensor_slice is one of None, np.ndarray or str
  246. if isinstance(curr_tensor_slice, np.ndarray) and isinstance(prev_tensor_slice, np.ndarray):
  247. diff_tensor = TensorUtils.calc_diff_between_two_tensor(curr_tensor_slice, prev_tensor_slice, tolerance)
  248. result = np.stack([prev_tensor_slice, curr_tensor_slice, diff_tensor], axis=-1)
  249. tensor_info['diff'] = result.tolist()
  250. stats = TensorUtils.get_statistics_from_tensor(diff_tensor)
  251. tensor_info['statistics'] = TensorUtils.get_statistics_dict(stats)
  252. elif isinstance(curr_tensor_slice, str):
  253. tensor_info['diff'] = curr_tensor_slice
  254. reply = {'tensor_value': tensor_info}
  255. return reply