
tensor_handler.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the tensor stream handler."""
from collections import namedtuple

import numpy as np

from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.proto.ms_graph_pb2 import DataType
from mindinsight.debugger.stream_cache.tensor import OpTensor, ConstTensor
from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
from mindinsight.utils.tensor import TensorUtils, TensorComparison

TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter'])

class TensorHandler(StreamHandlerBase):
    """Tensor handler."""

    def __init__(self):
        # the collection of parameter full names
        self._param_names = set()
        # const value objects, the format is like: dict[<const name>, <OpTensor object>]
        self._const_vals = {}
        # tensor values, the format is like:
        # dict[<tensor full name>, dict[<step_num>, <OpTensor object>]]
        self._tensors = {}
        self._cur_step = 0

    @property
    def cur_step(self):
        """The property of current step."""
        return self._cur_step

    @property
    def prev_step(self):
        """The property of previous step."""
        return self._cur_step - 1

    def put(self, value):
        """
        Put value into tensor cache. Called by the grpc server.

        Args:
            value (dict): The Tensor proto message.

                - step (int): The current step of tensor.
                - tensor_protos (list[TensorProto]): The tensor protos.

        Returns:
            bool, whether the tensor value was updated in the cache.
        """
        tensor_protos = value.get('tensor_protos')
        merged_tensor = self._get_merged_tensor(tensor_protos)
        step = value.get('step', 0)
        if merged_tensor.iter and step > 0:
            log.debug("Received previous tensor.")
            step -= 1
        tensor = OpTensor(merged_tensor, step)
        flag = self._put_tensor_into_cache(tensor, step)
        log.info("Put tensor %s of step: %d, into cache. Flag: %s", tensor.name, step, flag)
        return flag
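
    # A minimal sketch (hypothetical values) of the payload that put() expects,
    # assuming `proto` is a TensorProto received from the debugger gRPC stream:
    #
    #     handler = TensorHandler()
    #     handler.put({'step': 3, 'tensor_protos': [proto]})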

    @staticmethod
    def _get_merged_tensor(tensor_protos):
        """
        Merge a list of parsed tensor values into one.

        Args:
            tensor_protos (list[TensorProto]): List of tensor protos.

        Returns:
            TensorProto, the merged tensor proto.
        """
        merged_tensor = tensor_protos[-1]
        if len(tensor_protos) > 1:
            tensor_value = bytes()
            for tensor_proto in tensor_protos:
                if not tensor_proto.tensor_content:
                    log.warning("No tensor value found for %s:%s",
                                tensor_proto.node_name, tensor_proto.slot)
                    break
                tensor_value += tensor_proto.tensor_content
            merged_tensor.tensor_content = tensor_value
            log.debug("Merge multiple tensor values into one.")
        return merged_tensor

    def _put_tensor_into_cache(self, tensor, step):
        """
        Put tensor into cache.

        Args:
            tensor (OpTensor): The tensor value.
            step (int): The step of tensor.

        Returns:
            bool, whether the tensor value was updated in the cache.
        """
        cache_tensor = self._tensors.get(tensor.name)
        if cache_tensor is None:
            cache_tensor = {}
            self._tensors[tensor.name] = cache_tensor
        old_tensor = cache_tensor.get(step)
        if old_tensor and not self._is_value_diff(old_tensor.value, tensor.value):
            log.debug("Tensor %s of step %s has no change. Ignore it.", tensor.name, step)
            return False
        cache_tensor[step] = tensor
        log.debug("Put updated tensor value for %s of step %s.", tensor.name, step)
        return True

    @staticmethod
    def _is_value_diff(old_value, new_value):
        """Check whether the old and new tensor values differ."""
        log.debug("old value type: %s, new_value type: %s", type(old_value), type(new_value))
        if old_value is None and new_value is None:
            return False
        flag = old_value != new_value
        if isinstance(flag, np.ndarray):
            return flag.any()
        return flag

    def put_const_vals(self, const_vals):
        """
        Put const values into tensor cache.

        Args:
            const_vals (list[NamedValueProto]): List of const values.
        """
        for const_val in const_vals:
            if not (const_val.value and const_val.key):
                continue
            if DataType.Name(const_val.value.dtype) == "DT_TENSOR":
                tensor_proto = const_val.value.tensor_val
                tensor_proto.node_name = const_val.key
                tensor_proto.slot = '0'
                const_tensor = OpTensor(tensor_proto)
            else:
                const_tensor = ConstTensor(const_val)
            self._const_vals[const_tensor.name] = const_tensor

    def record_parameter_names(self, names):
        """
        Record parameter names.

        Note:
            Parameter values could change during a step, so their cached values
            must be cleaned after each step.

        Args:
            names (list[str]): List of tensor full names.
        """
        self._param_names.update(names)
        log.debug("Record %d parameters in cache. Total parameter number: %d", len(names), len(self._param_names))

    def get(self, filter_condition=None):
        """
        Get full tensor value.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The full name of tensor.
                - node_type (str): The type of the node.
                - shape (tuple): Specify concrete dimensions of shape.
                - prev (bool): Whether to get the previous step tensor.

        Returns:
            dict, the tensor_value.
        """
        name = filter_condition.get('name')
        node_type = filter_condition.get('node_type')
        shape = filter_condition.get('shape')
        if filter_condition.get('prev'):
            step = self.prev_step
        else:
            step = self.cur_step
        tensor = self._get_tensor(name, node_type, step)
        if not tensor:
            log.error("No tensor named %s at the step %s", name, step)
            raise DebuggerParamValueError("No tensor named {}".format(name))
        tensor_info = tensor.get_full_info(shape)
        self._update_has_prev_step_field(tensor_info, name, node_type)
        return {'tensor_value': tensor_info}
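
    # A minimal sketch (hypothetical tensor name) of a get() call; the returned
    # dict carries the full tensor value under the 'tensor_value' key:
    #
    #     res = handler.get({'name': 'Default/network/conv1/weight:0',
    #                        'node_type': NodeTypeEnum.PARAMETER.value,
    #                        'prev': False})
    #     tensor_value = res['tensor_value']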

    def _get_tensor(self, tensor_name, node_type=None, step=None):
        """
        Get tensor according to tensor name and node_type.

        Args:
            tensor_name (str): Tensor name, format like `node_name:slot`.
            node_type (str): Node type.
            step (int): The step of tensor info. Default: None.

        Returns:
            Union[OpTensor, ConstTensor], the tensor object.
        """
        if step is None:
            step = self._cur_step
        tensor = self._tensors.get(tensor_name, {}).get(step)
        if not tensor and node_type == NodeTypeEnum.CONST.value:
            const_name = tensor_name.rsplit('/', 1)[-1]
            tensor = self._const_vals.get(const_name)
            if tensor:
                self._tensors[tensor_name] = {step: tensor}
        return tensor

    def _get_basic_info(self, tensor_name, node_type=None):
        """Get the latest basic tensor info by tensor name."""
        tensor = self._get_tensor(tensor_name, node_type)
        if tensor:
            return tensor.get_basic_info()
        return None

    def update_tensor_history(self, tensor_history):
        """
        Add tensor basic info in tensor_history.

        Args:
            tensor_history (dict): Tensor history, including a list of tensor names and types.

        Returns:
            list[dict], the list of tensor basic info cache.
        """
        missed_tensors = []
        for tensor_info in tensor_history.get('tensor_history'):
            tensor_name = tensor_info.get('full_name')
            node_type = tensor_info.get('node_type')
            basic_info = self._get_basic_info(tensor_name, node_type)
            # add `has_prev_step` field to tensor basic info.
            missing_tensors_info = self._update_has_prev_step_field(basic_info, tensor_name, node_type)
            if basic_info:
                tensor_info.update(basic_info)
            if missing_tensors_info:
                missed_tensors.extend(missing_tensors_info)
        return missed_tensors

    def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type):
        """Update the `has_prev_step` field in tensor info."""
        missing_tensors_info = self._get_missing_tensor_info(tensor_name, node_type)
        if not missing_tensors_info and node_type == NodeTypeEnum.PARAMETER.value and self.cur_step > 0:
            tensor_info['has_prev_step'] = True
        return missing_tensors_info

    def _get_missing_tensor_info(self, tensor_name, node_type):
        """
        Get missing tensor infos.

        Args:
            tensor_name (str): The full name of Tensor.
            node_type (str): The type of the relative node.

        Returns:
            list, list of missing tensor basic information.
        """
        step = self.cur_step
        missing_tensors_info = []
        # check whether the current step value is missing
        if self._is_tensor_value_missing(tensor_name, step):
            missing_tensors_info.append(TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter=''))
            log.debug("Add current step view cmd for %s", tensor_name)
        # check whether the previous step value is missing
        if node_type == NodeTypeEnum.PARAMETER.value and self._is_tensor_value_missing(tensor_name, step - 1):
            missing_tensors_info.append(TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter='prev'))
            log.debug("Add previous view cmd for %s", tensor_name)
        return missing_tensors_info

    def _is_tensor_value_missing(self, tensor_name, step):
        """
        Check whether the tensor value of the given step is missing.

        Args:
            tensor_name (str): Tensor name.
            step (int): The step of the tensor.

        Returns:
            Union[None, bool], the status of the tensor value. If False, there is a valid
                tensor value. If True, the tensor value should be queried from the client.
                If None, ignore it.
        """
        if step < 0:
            return None
        tensor = self._get_tensor(tensor_name, step=step)
        return bool(not tensor or tensor.empty)

    def get_valid_tensor_by_name(self, tensor_name, prev=False):
        """Get tensor value by name in numpy type."""
        step = self.prev_step if prev else self.cur_step
        if step < 0:
            log.warning("Step %d has no previous value for tensor %s.", self.cur_step, tensor_name)
            return None
        tensor = self._get_tensor(tensor_name, step=step)
        if tensor and tensor.empty:
            log.warning("%s has empty value.", tensor_name)
            return None
        return tensor

    def clean_tensors(self, cur_step):
        """Clean the tensor cache."""
        if cur_step != self._cur_step:
            self._cur_step = cur_step
            self._clean_expired_tensors(cur_step)
        self._clean_parameters()

    def _clean_expired_tensors(self, cur_step):
        """Clean expired tensors whose step is earlier than the previous step."""
        expired_tensor = []
        for tensor_name, tensor in self._tensors.items():
            expired_step = [step for step in tensor.keys() if step <= cur_step - 2]
            for step in expired_step:
                tensor.pop(step)
            if not tensor:
                expired_tensor.append(tensor_name)
        for tensor_name in expired_tensor:
            self._tensors.pop(tensor_name)

    def _clean_parameters(self):
        """Clean parameter cache."""
        for param in self._param_names:
            if param in self._tensors:
                self._tensors.pop(param)
                log.debug("Clean param %s in cache.", param)

    def get_tensors_diff(self, tensor_name, shape, tolerance=0):
        """
        Get tensor comparison data for the given name, shape and tolerance.

        Args:
            tensor_name (str): The name of tensor in cache.
            shape (tuple): Specify concrete dimensions of shape.
            tolerance (str): Specify the tolerance of the difference between the current step
                tensor and the previous step tensor. Default value is 0. It is a percentage.
                The boundary value is max(abs(min), abs(max)) * tolerance, where min and max
                are taken over the element-wise difference (current step tensor minus previous
                step tensor). If the absolute value of a difference is less than or equal to
                the boundary value, that difference is set to zero. A worked numeric example
                follows this method.

        Raises:
            DebuggerParamValueError: If retrieving the current step or previous step tensor
                failed, or the type of the tensor value is not numpy.ndarray.

        Returns:
            dict, the retrieved data.
        """
        curr_tensor = self.get_valid_tensor_by_name(tensor_name)
        prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True)
        if not (curr_tensor and prev_tensor):
            log.error("Get current step and previous step for this tensor name %s failed.", tensor_name)
            raise DebuggerParamValueError(f"Get current step and previous step for this tensor name "
                                          f"{tensor_name} failed.")
        curr_tensor_slice = curr_tensor.get_tensor_value_by_shape(shape)
        prev_tensor_slice = prev_tensor.get_tensor_value_by_shape(shape)
        # get tensor comparison basic info
        tensor_info = curr_tensor.get_basic_info()
        tensor_info.pop('has_prev_step')
        tensor_info.pop('value')
        # calculate tensor comparison object
        tensor_comparison = curr_tensor.tensor_comparison
        if not tensor_comparison or tensor_comparison.tolerance != tolerance:
            if curr_tensor.value.shape != prev_tensor.value.shape:
                raise DebuggerParamValueError("The shape of these two step tensors is not the same.")
            tensor_diff = TensorUtils.calc_diff_between_two_tensor(curr_tensor.value, prev_tensor.value, tolerance)
            stats = TensorUtils.get_statistics_from_tensor(tensor_diff)
            tensor_comparison = TensorComparison(tolerance, stats, tensor_diff)
            curr_tensor.update_tensor_comparisons(tensor_comparison)
        # calculate diff value
        # the type of curr_tensor_slice is one of np.ndarray or str
        if isinstance(curr_tensor_slice, np.ndarray) and isinstance(prev_tensor_slice, np.ndarray):
            if not shape:
                tensor_diff_slice = tensor_comparison.value
            else:
                tensor_diff_slice = tensor_comparison.value[shape]
            result = np.stack([prev_tensor_slice, curr_tensor_slice, tensor_diff_slice], axis=-1)
            tensor_info['diff'] = result.tolist()
        elif isinstance(curr_tensor_slice, str):
            tensor_info['diff'] = curr_tensor_slice
        # add comparison statistics
        tensor_info.update(self._get_comparison_statistics(curr_tensor, prev_tensor))
        reply = {'tensor_value': tensor_info}
        return reply
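
    # Worked example for the `tolerance` argument of get_tensors_diff() above,
    # taking the documented formula literally (hypothetical numbers): if the
    # element-wise difference between the current and previous step tensors has
    # min = -0.2 and max = 0.5, and tolerance = 0.1, then the boundary value is
    # max(abs(-0.2), abs(0.5)) * 0.1 = 0.05, so any difference whose absolute
    # value is <= 0.05 is reported as zero.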

    @staticmethod
    def _get_comparison_statistics(curr_tensor, prev_tensor):
        """Get comparison statistics."""
        stats_info = {}
        diff_tensor_stats = curr_tensor.tensor_comparison.stats
        curr_tensor_stats = TensorUtils.get_statistics_from_tensor(curr_tensor.value)
        prev_tensor_stats = TensorUtils.get_statistics_from_tensor(prev_tensor.value)
        stats_info['curr_step_statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=curr_tensor_stats)
        stats_info['prev_step_statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=prev_tensor_stats)
        stats_info['statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=diff_tensor_stats)
        return stats_info

    def get_tensor_info_for_tensor_graph(self, tensor_name, node_type):
        """
        Get tensor info for tensor graphs.

        Args:
            tensor_name (str): Tensor name, format like `node_name:slot`.
            node_type (str): Node type.

        Returns:
            dict, tensor info, including overall statistics, tensor shape and has_prev_step info.
            list, list of missing tensor basic information.
        """
        res = {}
        tensor = self._get_tensor(tensor_name, node_type)
        if tensor and not tensor.empty:
            res['statistics'] = tensor.get_tensor_statistics()
            res['shape'] = tensor.shape
        missing_tensors = self._update_has_prev_step_field(res, tensor_name, node_type)
        return res, missing_tensors
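
# A minimal end-to-end usage sketch (names and values are hypothetical; `proto` and
# `const_list` would normally come from the debugger gRPC stream):
#
#     handler = TensorHandler()
#     handler.put_const_vals(const_list)
#     handler.record_parameter_names(['Default/network/conv1/weight:0'])
#     handler.put({'step': 1, 'tensor_protos': [proto]})
#     info = handler.get({'name': 'Default/network/conv1/weight:0',
#                         'node_type': NodeTypeEnum.PARAMETER.value})
#     handler.clean_tensors(cur_step=2)   # drop expired steps and cached parameter values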