You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

watchpoint_handler.py 12 kB


  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the watchpoint stream handler."""
  16. import numpy as np
  17. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
  18. DebuggerParamTypeError
  19. from mindinsight.debugger.common.log import logger as log
  20. from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD
  21. from mindinsight.debugger.stream_cache.watchpoint import Watchpoint, WatchpointHit, \
  22. WATCHPOINT_CONDITION_MAPPING
  23. from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
  24. class WatchpointHandler(StreamHandlerBase):
  25. """watchpoint Handler."""
  26. def __init__(self):
  27. self._watchpoints = {}
  28. self._deleted_watchpoints = []
  29. self._updated_watchpoints = {}
  30. self._latest_id = 0
  31. def put(self, value):
  32. """
  33. Put Watchpoint into watchpoint handler.
  34. Args:
  35. value (Watchpoint): The name of nodes that have been chosen.
  36. """
  37. new_id = value.watchpoint_id
  38. self._watchpoints[new_id] = value
  39. self._updated_watchpoints[new_id] = value
  40. self._latest_id = new_id
  41. log.debug("Put watchpoint %d into cache.", new_id)
  42. def sync_set_cmd(self):
  43. """Clean temp watchpoints."""
  44. self._deleted_watchpoints = []
  45. self._updated_watchpoints = {}
  46. def get_watchpoint_by_id(self, watchpoint_id):
  47. """Get watchpoint by watchpoint id."""
  48. watchpoint = self._watchpoints.get(watchpoint_id)
  49. if not watchpoint:
  50. log.error("Invalid watchpoint id %d", watchpoint_id)
  51. raise DebuggerParamValueError("Invalid watchpoint id {}".format(watchpoint_id))
  52. return watchpoint
  53. def get(self, filter_condition=False):
  54. """
  55. Get the watchpoints.
  56. Args:
  57. filter_condition (bool): If True, get all watchpoints without nodes. If False,
  58. get updated watchpoints in SetCMD proto format. Default: False.
  59. Returns:
  60. dict, the watchpoints.
  61. """
  62. reply = []
  63. if not filter_condition:
  64. # get watch condition list
  65. for _, watchpoint in self._watchpoints.items():
  66. watchpoint_info = watchpoint.get_watch_condition_info()
  67. reply.append(watchpoint_info)
  68. else:
  69. # get updated watchpoint list
  70. for _, watchpoint in self._updated_watchpoints.items():
  71. set_cmd = watchpoint.get_set_cmd()
  72. reply.append(set_cmd)
  73. reply.extend(self._deleted_watchpoints)
  74. log.debug("get the watch points with filter_condition:%s", filter_condition)
  75. return {'watch_points': reply}
  76. def set_watch_nodes(self, graph, watch_point_id):
  77. """
  78. set watch nodes for graph.
  79. Args:
  80. graph (dict): The graph with list of nodes.
  81. watch_point_id (int): The id of watchpoint.
  82. """
  83. if not (watch_point_id and graph):
  84. return
  85. log.debug("add watch flags")
  86. watchpoint = self._watchpoints.get(watch_point_id)
  87. self._set_watch_status_recursively(graph, watchpoint)
  88. def _set_watch_status_recursively(self, graph, watchpoint):
  89. """Set watch status to graph."""
  90. if not isinstance(graph, dict):
  91. log.warning("The graph is not dict.")
  92. return
  93. if graph.get('children'):
  94. self._set_watch_status_recursively(graph.get('children'), watchpoint)
  95. for node in graph.get('nodes', []):
  96. if not isinstance(node, dict):
  97. log.warning("The node is not dict.")
  98. return
  99. node_name = node.get('name')
  100. if not node_name:
  101. continue
  102. flag = watchpoint.get_node_status(node_name, node.get('type'), node.get('full_name'))
  103. node['watched'] = flag
  104. if node.get('nodes'):
  105. self._set_watch_status_recursively(node, watchpoint)
  106. def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None):
  107. """
  108. Create watchpoint.
  109. Args:
  110. watch_condition (dict): The watch condition.
  111. - condition (str): Accept `INF` or `NAN`.
  112. - param (list[float]): Not defined yet.
  113. watch_nodes (list[NodeBasicInfo]): The list of node basic info.
  114. watch_point_id (int): The id of watchpoint.
  115. Returns:
  116. int, the new id of watchpoint.
  117. """
  118. validate_watch_condition(watch_condition)
  119. new_id = self._latest_id + 1
  120. watchpoint = Watchpoint(new_id, watch_condition)
  121. if watch_nodes:
  122. watchpoint.add_nodes(watch_nodes)
  123. elif watch_point_id:
  124. self.validate_watchpoint_id(watch_point_id)
  125. watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id))
  126. self.put(watchpoint)
  127. return new_id
  128. def update_watchpoint(self, watch_point_id, watch_nodes, watched=False):
  129. """
  130. Update watchpoint.
  131. Args:
  132. watch_point_id (int): The id of watchpoint.
  133. watch_nodes (list[str]): The list of node names.
  134. watched (bool): The update operator on nodes. If False, remove nodes from watch nodes.
  135. If True, add nodes to watch nodes. Default: False.
  136. Returns:
  137. dict, empty response.
  138. """
  139. self.validate_watchpoint_id(watch_point_id)
  140. watchpoint = self._watchpoints.get(watch_point_id)
  141. if watched:
  142. watchpoint.add_nodes(watch_nodes)
  143. else:
  144. watchpoint.remove_nodes(watch_nodes)
  145. self._updated_watchpoints[watch_point_id] = watchpoint
  146. log.debug("Update watchpoint %d in cache.", watch_point_id)
  147. def delete_watchpoint(self, watch_point_id):
  148. """
  149. Delete watchpoint.
  150. Args:
  151. watch_point_id (int): The id of watchpoint.
  152. Returns:
  153. dict, empty response.
  154. """
  155. self.validate_watchpoint_id(watch_point_id)
  156. self._watchpoints.pop(watch_point_id)
  157. set_cmd = SetCMD()
  158. set_cmd.id = watch_point_id
  159. set_cmd.delete = True
  160. self._deleted_watchpoints.append(set_cmd)
  161. log.debug("Delete watchpoint %d in cache.", watch_point_id)
  162. def validate_watchpoint_id(self, watch_point_id):
  163. """Validate watchpoint id."""
  164. if not isinstance(watch_point_id, int):
  165. log.error("Invalid watchpoint id %s. The watch point id should be int.", watch_point_id)
  166. raise DebuggerParamTypeError("Watchpoint id should be int type.")
  167. if watch_point_id and watch_point_id not in self._watchpoints:
  168. log.error("Invalid watchpoint id: %d.", watch_point_id)
  169. raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))
  170. class WatchpointHitHandler(StreamHandlerBase):
  171. """Watchpoint hit handler."""
  172. def __init__(self):
  173. self._hits = {}
  174. @property
  175. def empty(self):
  176. """Whether the watchpoint hit is empty."""
  177. return not self._hits
  178. def put(self, value):
  179. """
  180. Put value into watchpoint hit cache. Called by grpc server.
  181. Args:
  182. value (dict): The watchpoint hit info.
  183. - tensor_proto (TensorProto): The message about hit tensor.
  184. - watchpoint (Watchpoint): The Watchpoint that a node hit.
  185. """
  186. watchpoint_hit = WatchpointHit(
  187. tensor_proto=value.get('tensor_proto'),
  188. watchpoint=value.get('watchpoint'),
  189. node_name=value.get('node_name')
  190. )
  191. node_name = value.get('node_name')
  192. hit_tensors = self._hits.get(node_name)
  193. if hit_tensors is None:
  194. hit_tensors = []
  195. self._hits[node_name] = hit_tensors
  196. if watchpoint_hit not in hit_tensors:
  197. hit_tensors.append(watchpoint_hit)
  198. def get(self, filter_condition=None):
  199. """
  200. Get watchpoint hit list.
  201. Args:
  202. filter_condition (str): Get the watchpoint hit according to specified node name.
  203. If not given, get all watchpoint hits. Default: None.
  204. Returns:
  205. dict, the watchpoint hit list.
  206. """
  207. if filter_condition is None:
  208. log.debug("Get all watchpoint hit list.")
  209. reply = self.get_watchpoint_hits()
  210. else:
  211. log.debug("Get the watchpoint for node: <%s>.", filter_condition)
  212. reply = self._hits.get(filter_condition)
  213. return reply
  214. def get_watchpoint_hits(self):
  215. """Return the list of watchpoint hits."""
  216. watch_point_hits = []
  217. for node_name, watchpoint_hits in self._hits.items():
  218. watch_points = [watchpoint_hit.watchpoint for watchpoint_hit in watchpoint_hits]
  219. watch_point_hits.append({
  220. 'node_name': node_name,
  221. 'watch_points': watch_points
  222. })
  223. return {'watch_point_hits': watch_point_hits}
  224. def _is_tensor_hit(self, tensor_name):
  225. """Check if the tensor is record in hit cache."""
  226. node_name = tensor_name.split(':')[0]
  227. watchpoint_hits = self.get(node_name)
  228. if watchpoint_hits is None:
  229. return False
  230. for watchpoint_hit in watchpoint_hits:
  231. if tensor_name == watchpoint_hit.tensor_name:
  232. return True
  233. return False
  234. def update_tensor_history(self, tensor_history):
  235. """
  236. Add hit flag to tensor history.
  237. Args:
  238. tensor_history (dict): The tensor history.
  239. """
  240. if not self._hits:
  241. return
  242. # add hit tensor names to `tensor_names`
  243. for tensor_info in tensor_history.get('tensor_history'):
  244. tensor_name = tensor_info['full_name']
  245. hit_flag = self._is_tensor_hit(tensor_name)
  246. tensor_info['is_hit'] = hit_flag
  247. def validate_watch_condition(watch_condition):
  248. """Validate watch condition."""
  249. if not isinstance(watch_condition, dict):
  250. log.error("<watch_condition> should be dict. %s received.", watch_condition)
  251. raise DebuggerParamTypeError("<watch_condition> should be dict.")
  252. # validate condition
  253. condition = watch_condition.get('condition')
  254. if condition not in WATCHPOINT_CONDITION_MAPPING.keys():
  255. log.error("Invalid watch condition. Acceptable values are <%s>.",
  256. str(WATCHPOINT_CONDITION_MAPPING.keys()))
  257. raise DebuggerParamValueError("Invalid watch condition value.")
  258. # validate param
  259. validate_watch_condition_params(watch_condition)
  260. def validate_watch_condition_params(watch_condition):
  261. """
  262. Validate watch condition parameters.
  263. Args:
  264. watch_condition (dict): Watch condition.
  265. - condition (str): Condition type. Should be in WATCHPOINT_CONDITION_MAPPING.
  266. - param (list): Condition value. Should be given for comparison condition. The value will
  267. be translated to np.float32.
  268. """
  269. condition = watch_condition.get('condition')
  270. param = watch_condition.get('param')
  271. if condition in ['NAN', 'INF', 'OVERFLOW']:
  272. if param:
  273. log.error("No param is expected for %s condition.", condition)
  274. raise DebuggerParamValueError("No param is expected.")
  275. else:
  276. if not isinstance(param, (float, int)):
  277. log.error("Number param should be given for condition <%s>.",
  278. condition)
  279. raise DebuggerParamValueError("Number param should be given.")
  280. if np.isinf(np.float32(param)):
  281. log.error("Condition param should be float32.")
  282. raise DebuggerParamValueError("The value of condition param should be within float32.")