You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

watchpoint_handler.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the watchpoint stream handler."""
  16. import numpy as np
  17. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
  18. DebuggerParamTypeError
  19. from mindinsight.debugger.common.log import logger as log
  20. from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD
  21. from mindinsight.debugger.stream_cache.watchpoint import Watchpoint, WatchpointHit, \
  22. WATCHPOINT_CONDITION_MAPPING
  23. from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
  24. class WatchpointHandler(StreamHandlerBase):
  25. """watchpoint Handler."""
  26. def __init__(self):
  27. self._watchpoints = {}
  28. self._deleted_watchpoints = []
  29. self._updated_watchpoints = {}
  30. self._latest_id = 0
  31. def put(self, value):
  32. """
  33. Put Watchpoint into watchpoint handler.
  34. Args:
  35. value (Watchpoint): The name of nodes that have been chosen.
  36. """
  37. new_id = value.watchpoint_id
  38. self._watchpoints[new_id] = value
  39. self._updated_watchpoints[new_id] = value
  40. self._latest_id = new_id
  41. log.debug("Put watchpoint %d into cache.", new_id)
  42. def sync_set_cmd(self):
  43. """Clean temp watchpoints."""
  44. self._deleted_watchpoints = []
  45. self._updated_watchpoints = {}
  46. def get_watchpoint_by_id(self, watchpoint_id):
  47. """Get watchpoint by watchpoint id."""
  48. watchpoint = self._watchpoints.get(watchpoint_id)
  49. if not watchpoint:
  50. log.error("Invalid watchpoint id %d", watchpoint_id)
  51. raise DebuggerParamValueError("Invalid watchpoint id {}".format(watchpoint_id))
  52. return watchpoint
  53. def get(self, filter_condition=False):
  54. """
  55. Get the watchpoints.
  56. Args:
  57. filter_condition (bool): If True, get all watchpoints without nodes. If False,
  58. get updated watchpoints in SetCMD proto format. Default: False.
  59. Returns:
  60. dict, the watchpoints.
  61. """
  62. reply = []
  63. if not filter_condition:
  64. # get watch condition list
  65. for _, watchpoint in self._watchpoints.items():
  66. watchpoint_info = watchpoint.get_watch_condition_info()
  67. reply.append(watchpoint_info)
  68. else:
  69. # get updated watchpoint list
  70. for _, watchpoint in self._updated_watchpoints.items():
  71. set_cmd = watchpoint.get_set_cmd()
  72. reply.append(set_cmd)
  73. reply.extend(self._deleted_watchpoints)
  74. log.debug("get the watch points with filter_condition:%s", filter_condition)
  75. return {'watch_points': reply}
  76. def set_watch_nodes(self, graph, watch_point_id):
  77. """
  78. set watch nodes for graph.
  79. Args:
  80. graph (dict): The graph with list of nodes.
  81. watch_point_id (int): The id of watchpoint.
  82. """
  83. if not (watch_point_id and graph):
  84. return
  85. log.debug("add watch flags")
  86. watchpoint = self._watchpoints.get(watch_point_id)
  87. self._set_watch_status_recursively(graph, watchpoint)
  88. def _set_watch_status_recursively(self, graph, watchpoint):
  89. """Set watch status to graph."""
  90. if not isinstance(graph, dict):
  91. log.warning("The graph is not dict.")
  92. return
  93. if graph.get('children'):
  94. self._set_watch_status_recursively(graph.get('children'), watchpoint)
  95. for node in graph.get('nodes', []):
  96. if not isinstance(node, dict):
  97. log.warning("The node is not dict.")
  98. return
  99. node_name = node.get('name')
  100. if not node_name:
  101. continue
  102. flag = watchpoint.get_node_status(node_name, node.get('type'), node.get('full_name'))
  103. node['watched'] = flag
  104. if node.get('nodes'):
  105. self._set_watch_status_recursively(node, watchpoint)
  106. def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None):
  107. """
  108. Create watchpoint.
  109. Args:
  110. watch_condition (dict): The watch condition.
  111. - condition (str): Accept `INF` or `NAN`.
  112. - param (list[float]): Not defined yet.
  113. watch_nodes (list[NodeBasicInfo]): The list of node basic info.
  114. watch_point_id (int): The id of watchpoint.
  115. Returns:
  116. int, the new id of watchpoint.
  117. """
  118. validate_watch_condition(watch_condition)
  119. new_id = self._latest_id + 1
  120. watchpoint = Watchpoint(new_id, watch_condition)
  121. if watch_nodes:
  122. watchpoint.add_nodes(watch_nodes)
  123. elif watch_point_id:
  124. self.validate_watchpoint_id(watch_point_id)
  125. watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id))
  126. self.put(watchpoint)
  127. return new_id
  128. def update_watchpoint(self, watch_point_id, watch_nodes, watched=False):
  129. """
  130. Update watchpoint.
  131. Args:
  132. watch_point_id (int): The id of watchpoint.
  133. watch_nodes (list[str]): The list of node names.
  134. watched (bool): The update operator on nodes. If False, remove nodes from watch nodes.
  135. If True, add nodes to watch nodes. Default: False.
  136. Returns:
  137. dict, empty response.
  138. """
  139. self.validate_watchpoint_id(watch_point_id)
  140. watchpoint = self._watchpoints.get(watch_point_id)
  141. if watched:
  142. watchpoint.add_nodes(watch_nodes)
  143. else:
  144. watchpoint.remove_nodes(watch_nodes)
  145. self._updated_watchpoints[watch_point_id] = watchpoint
  146. log.debug("Update watchpoint %d in cache.", watch_point_id)
  147. def delete_watchpoint(self, watch_point_id):
  148. """
  149. Delete watchpoint.
  150. Args:
  151. watch_point_id (int): The id of watchpoint.
  152. Returns:
  153. dict, empty response.
  154. """
  155. self.validate_watchpoint_id(watch_point_id)
  156. self._watchpoints.pop(watch_point_id)
  157. set_cmd = SetCMD()
  158. set_cmd.id = watch_point_id
  159. set_cmd.delete = True
  160. self._deleted_watchpoints.append(set_cmd)
  161. log.debug("Delete watchpoint %d in cache.", watch_point_id)
  162. def validate_watchpoint_id(self, watch_point_id):
  163. """Validate watchpoint id."""
  164. if not isinstance(watch_point_id, int):
  165. log.error("Invalid watchpoint id %s. The watch point id should be int.", watch_point_id)
  166. raise DebuggerParamTypeError("Watchpoint id should be int type.")
  167. if watch_point_id and watch_point_id not in self._watchpoints:
  168. log.error("Invalid watchpoint id: %d.", watch_point_id)
  169. raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))
  170. class WatchpointHitHandler(StreamHandlerBase):
  171. """Watchpoint hit handler."""
  172. def __init__(self):
  173. self._hits = {}
  174. def put(self, value):
  175. """
  176. Put value into watchpoint hit cache. Called by grpc server.
  177. Args:
  178. value (dict): The watchpoint hit info.
  179. - tensor_proto (TensorProto): The message about hit tensor.
  180. - watchpoint (Watchpoint): The Watchpoint that a node hit.
  181. """
  182. watchpoint_hit = WatchpointHit(
  183. tensor_proto=value.get('tensor_proto'),
  184. watchpoint=value.get('watchpoint'),
  185. node_name=value.get('node_name')
  186. )
  187. node_name = value.get('node_name')
  188. hit_tensors = self._hits.get(node_name)
  189. if hit_tensors is None:
  190. hit_tensors = []
  191. self._hits[node_name] = hit_tensors
  192. if watchpoint_hit not in hit_tensors:
  193. hit_tensors.append(watchpoint_hit)
  194. def get(self, filter_condition=None):
  195. """
  196. Get watchpoint hit list.
  197. Args:
  198. filter_condition (str): Get the watchpoint hit according to specifiled node name.
  199. If not given, get all watchpoint hits. Default: None.
  200. Returns:
  201. dict, the watchpoint hit list.
  202. """
  203. if filter_condition is None:
  204. log.debug("Get all watchpoint hit list.")
  205. reply = self.get_watchpoint_hits()
  206. else:
  207. log.debug("Get the watchpoint for node: <%s>.", filter_condition)
  208. reply = self._hits.get(filter_condition)
  209. return reply
  210. def get_watchpoint_hits(self):
  211. """Return the list of watchpoint hits."""
  212. watch_point_hits = []
  213. for node_name, watchpoint_hits in self._hits.items():
  214. watch_points = [watchpoint_hit.watchpoint for watchpoint_hit in watchpoint_hits]
  215. watch_point_hits.append({
  216. 'node_name': node_name,
  217. 'watch_points': watch_points
  218. })
  219. return {'watch_point_hits': watch_point_hits}
  220. def _is_tensor_hit(self, tensor_name):
  221. """Check if the tensor is record in hit cache."""
  222. node_name = tensor_name.split(':')[0]
  223. watchpoint_hits = self.get(node_name)
  224. if watchpoint_hits is None:
  225. return False
  226. for watchpoint_hit in watchpoint_hits:
  227. if tensor_name == watchpoint_hit.tensor_name:
  228. return True
  229. return False
  230. def update_tensor_history(self, tensor_history):
  231. """
  232. Add hit flag to tensor history.
  233. Args:
  234. tensor_history (dict): The tensor history.
  235. """
  236. if not self._hits:
  237. return
  238. # add hit tensor names to `tensor_names`
  239. for tensor_info in tensor_history.get('tensor_history'):
  240. tensor_name = tensor_info['full_name']
  241. hit_flag = self._is_tensor_hit(tensor_name)
  242. tensor_info['is_hit'] = hit_flag
  243. def validate_watch_condition(watch_condition):
  244. """Validate watch condition."""
  245. if not isinstance(watch_condition, dict):
  246. log.error("<watch_condition> should be dict. %s received.", watch_condition)
  247. raise DebuggerParamTypeError("<watch_condition> should be dict.")
  248. # validate condition
  249. condition = watch_condition.get('condition')
  250. if condition not in WATCHPOINT_CONDITION_MAPPING.keys():
  251. log.error("Invalid watch condition. Acceptable values are <%s>.",
  252. str(WATCHPOINT_CONDITION_MAPPING.keys()))
  253. raise DebuggerParamValueError("Invalid watch condition value.")
  254. # validate param
  255. validate_watch_condition_params(watch_condition)
  256. def validate_watch_condition_params(watch_condition):
  257. """
  258. Validate watch condition parameters.
  259. Args:
  260. watch_condition (dict): Watch condition.
  261. - condition (str): Condition type. Should be in WATCHPOINT_CONDITION_MAPPING.
  262. - param (list): Condition value. Should be given for comparison condition. The value will
  263. be translated to np.float32.
  264. """
  265. condition = watch_condition.get('condition')
  266. param = watch_condition.get('param')
  267. if condition in ['NAN', 'INF', 'OVERFLOW']:
  268. if param:
  269. log.error("No param is expected for %s condition.", condition)
  270. raise DebuggerParamValueError("No param is expected.")
  271. else:
  272. if not isinstance(param, (float, int)):
  273. log.error("Number param should be given for condition <%s>.",
  274. condition)
  275. raise DebuggerParamValueError("Number param should be given.")
  276. if np.isinf(np.float32(param)):
  277. log.error("Condition param should be float32.")
  278. raise DebuggerParamValueError("The value of condition param should be within float32.")