You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they may include dashes ('-') and can be up to 35 characters long.

watchpoint_handler.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the watchpoint stream handler."""
  16. import numpy as np
  17. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
  18. DebuggerParamTypeError
  19. from mindinsight.debugger.common.log import logger as log
  20. from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD
  21. from mindinsight.debugger.stream_cache.watchpoint import Watchpoint, WatchpointHit, \
  22. WATCHPOINT_CONDITION_MAPPING
  23. from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
  24. class WatchpointHandler(StreamHandlerBase):
  25. """watchpoint Handler."""
  26. def __init__(self):
  27. self._watchpoints = {}
  28. self._deleted_watchpoints = []
  29. self._updated_watchpoints = {}
  30. self._latest_id = 0
  31. def put(self, value):
  32. """
  33. Put Watchpoint into watchpoint handler.
  34. Args:
  35. value (Watchpoint): The name of nodes that have been chosen.
  36. """
  37. new_id = value.watchpoint_id
  38. self._watchpoints[new_id] = value
  39. self._updated_watchpoints[new_id] = value
  40. self._latest_id = new_id
  41. log.debug("Put watchpoint %d into cache.", new_id)
  42. def sync_set_cmd(self):
  43. """Clean temp watchpoints."""
  44. self._deleted_watchpoints = []
  45. self._updated_watchpoints = {}
  46. def get_watchpoint_by_id(self, watchpoint_id):
  47. """Get watchpoint by watchpoint id."""
  48. watchpoint = self._watchpoints.get(watchpoint_id)
  49. if not watchpoint:
  50. log.error("Invalid watchpoint id %d", watchpoint_id)
  51. raise DebuggerParamValueError("Invalid watchpoint id {}".format(watchpoint_id))
  52. return watchpoint
  53. def get(self, filter_condition=False):
  54. """
  55. Get the watchpoints.
  56. Args:
  57. filter_condition (bool): If True, get all watchpoints without nodes. If False,
  58. get updated watchpoints in SetCMD proto format. Default: False.
  59. Returns:
  60. dict, the watchpoints.
  61. """
  62. reply = []
  63. if not filter_condition:
  64. # get watch condition list
  65. for _, watchpoint in self._watchpoints.items():
  66. watchpoint_info = watchpoint.get_watch_condition_info()
  67. reply.append(watchpoint_info)
  68. else:
  69. # get updated watchpoint list
  70. for _, watchpoint in self._updated_watchpoints.items():
  71. set_cmd = watchpoint.get_set_cmd()
  72. reply.append(set_cmd)
  73. reply.extend(self._deleted_watchpoints)
  74. log.debug("get the watch points with filter_condition:%s", filter_condition)
  75. return {'watch_points': reply}
  76. def set_watch_nodes(self, graph, watch_point_id):
  77. """
  78. set watch nodes for graph.
  79. Args:
  80. graph (dict): The graph with list of nodes.
  81. watch_point_id (int): The id of watchpoint.
  82. """
  83. if not (watch_point_id and graph):
  84. return
  85. self._validate_watchpoint_id(watch_point_id)
  86. log.debug("add watch flags")
  87. watchpoint = self._watchpoints.get(watch_point_id)
  88. self._set_watch_status_recursively(graph, watchpoint)
  89. def _set_watch_status_recursively(self, graph, watchpoint):
  90. """Set watch status to graph."""
  91. if not isinstance(graph, dict):
  92. log.warning("The graph is not dict.")
  93. return
  94. if graph.get('children'):
  95. self._set_watch_status_recursively(graph.get('children'), watchpoint)
  96. for node in graph.get('nodes', []):
  97. if not isinstance(node, dict):
  98. log.warning("The node is not dict.")
  99. return
  100. node_name = node.get('name')
  101. if not node_name:
  102. continue
  103. flag = watchpoint.get_node_status(node_name, node.get('type'), node.get('full_name'))
  104. node['watched'] = flag
  105. if node.get('nodes'):
  106. self._set_watch_status_recursively(node, watchpoint)
  107. def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None):
  108. """
  109. Create watchpoint.
  110. Args:
  111. watch_condition (dict): The watch condition.
  112. - condition (str): Accept `INF` or `NAN`.
  113. - param (list[float]): Not defined yet.
  114. watch_nodes (list[NodeBasicInfo]): The list of node basic info.
  115. watch_point_id (int): The id of watchpoint.
  116. Returns:
  117. int, the new id of watchpoint.
  118. """
  119. validate_watch_condition(watch_condition)
  120. new_id = self._latest_id + 1
  121. watchpoint = Watchpoint(new_id, watch_condition)
  122. if watch_nodes:
  123. watchpoint.add_nodes(watch_nodes)
  124. elif watch_point_id:
  125. self._validate_watchpoint_id(watch_point_id)
  126. watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id))
  127. self.put(watchpoint)
  128. return new_id
  129. def update_watchpoint(self, watch_point_id, watch_nodes, watched=False):
  130. """
  131. Update watchpoint.
  132. Args:
  133. watch_point_id (int): The id of watchpoint.
  134. watch_nodes (list[str]): The list of node names.
  135. watched (bool): The update operator on nodes. If False, remove nodes from watch nodes.
  136. If True, add nodes to watch nodes. Default: False.
  137. Returns:
  138. dict, empty response.
  139. """
  140. self._validate_watchpoint_id(watch_point_id)
  141. watchpoint = self._watchpoints.get(watch_point_id)
  142. if watched:
  143. watchpoint.add_nodes(watch_nodes)
  144. else:
  145. watchpoint.remove_nodes(watch_nodes)
  146. self._updated_watchpoints[watch_point_id] = watchpoint
  147. log.debug("Update watchpoint %d in cache.", watch_point_id)
  148. def delete_watchpoint(self, watch_point_id):
  149. """
  150. Delete watchpoint.
  151. Args:
  152. watch_point_id (int): The id of watchpoint.
  153. Returns:
  154. dict, empty response.
  155. """
  156. self._validate_watchpoint_id(watch_point_id)
  157. self._watchpoints.pop(watch_point_id)
  158. set_cmd = SetCMD()
  159. set_cmd.id = watch_point_id
  160. set_cmd.delete = True
  161. self._deleted_watchpoints.append(set_cmd)
  162. log.debug("Delete watchpoint %d in cache.", watch_point_id)
  163. def _validate_watchpoint_id(self, watch_point_id):
  164. """Validate watchpoint id."""
  165. if watch_point_id and watch_point_id not in self._watchpoints:
  166. log.error("Invalid watchpoint id: %d.", watch_point_id)
  167. raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))
  168. class WatchpointHitHandler(StreamHandlerBase):
  169. """Watchpoint hit handler."""
  170. def __init__(self):
  171. self._hits = {}
  172. def put(self, value):
  173. """
  174. Put value into watchpoint hit cache. Called by grpc server.
  175. Args:
  176. value (dict): The watchpoint hit info.
  177. - tensor_proto (TensorProto): The message about hit tensor.
  178. - watchpoint (Watchpoint): The Watchpoint that a node hit.
  179. """
  180. watchpoint_hit = WatchpointHit(
  181. tensor_proto=value.get('tensor_proto'),
  182. watchpoint=value.get('watchpoint'),
  183. node_name=value.get('node_name')
  184. )
  185. node_name = value.get('node_name')
  186. hit_tensors = self._hits.get(node_name)
  187. if hit_tensors is None:
  188. hit_tensors = []
  189. self._hits[node_name] = hit_tensors
  190. if watchpoint_hit not in hit_tensors:
  191. hit_tensors.append(watchpoint_hit)
  192. def get(self, filter_condition=None):
  193. """
  194. Get watchpoint hit list.
  195. Args:
  196. filter_condition (str): Get the watchpoint hit according to specifiled node name.
  197. If not given, get all watchpoint hits. Default: None.
  198. Returns:
  199. dict, the watchpoint hit list.
  200. """
  201. if filter_condition is None:
  202. log.debug("Get all watchpoint hit list.")
  203. reply = self.get_watchpoint_hits()
  204. else:
  205. log.debug("Get the watchpoint for node: <%s>.", filter_condition)
  206. reply = self._hits.get(filter_condition)
  207. return reply
  208. def get_watchpoint_hits(self):
  209. """Return the list of watchpoint hits."""
  210. watch_point_hits = []
  211. for node_name, watchpoint_hits in self._hits.items():
  212. watch_points = [watchpoint_hit.watchpoint for watchpoint_hit in watchpoint_hits]
  213. watch_point_hits.append({
  214. 'node_name': node_name,
  215. 'watch_points': watch_points
  216. })
  217. return {'watch_point_hits': watch_point_hits}
  218. def _is_tensor_hit(self, tensor_name):
  219. """Check if the tensor is record in hit cache."""
  220. node_name = tensor_name.split(':')[0]
  221. watchpoint_hits = self.get(node_name)
  222. if watchpoint_hits is None:
  223. return False
  224. for watchpoint_hit in watchpoint_hits:
  225. if tensor_name == watchpoint_hit.tensor_name:
  226. return True
  227. return False
  228. def update_tensor_history(self, tensor_history):
  229. """
  230. Add hit flag to tensor history.
  231. Args:
  232. tensor_history (dict): The tensor history.
  233. """
  234. if not self._hits:
  235. return
  236. # add hit tensor names to `tensor_names`
  237. for tensor_info in tensor_history.get('tensor_history'):
  238. tensor_name = tensor_info['full_name']
  239. hit_flag = self._is_tensor_hit(tensor_name)
  240. tensor_info['is_hit'] = hit_flag
  241. def validate_watch_condition(watch_condition):
  242. """Validate watch condition."""
  243. if not isinstance(watch_condition, dict):
  244. log.error("<watch_condition> should be dict. %s received.", watch_condition)
  245. raise DebuggerParamTypeError("<watch_condition> should be dict.")
  246. # validate condition
  247. condition = watch_condition.get('condition')
  248. if condition not in WATCHPOINT_CONDITION_MAPPING.keys():
  249. log.error("Invalid watch condition. Acceptable values are <%s>.",
  250. str(WATCHPOINT_CONDITION_MAPPING.keys()))
  251. raise DebuggerParamValueError("Invalid watch condition value.")
  252. # validate param
  253. validate_watch_condition_params(watch_condition)
  254. def validate_watch_condition_params(watch_condition):
  255. """
  256. Validate watch condition parameters.
  257. Args:
  258. watch_condition (dict): Watch condition.
  259. - condition (str): Condition type. Should be in WATCHPOINT_CONDITION_MAPPING.
  260. - param (list): Condition value. Should be given for comparison condition. The value will
  261. be translated to np.float32.
  262. """
  263. condition = watch_condition.get('condition')
  264. param = watch_condition.get('param')
  265. if condition in ['NAN', 'INF', 'OVERFLOW']:
  266. if param:
  267. log.error("No param is expected for %s condition.", condition)
  268. raise DebuggerParamValueError("No param is expected.")
  269. else:
  270. if not isinstance(param, (float, int)):
  271. log.error("Number param should be given for condition <%s>.",
  272. condition)
  273. raise DebuggerParamValueError("Number param should be given.")
  274. if np.isinf(np.float32(param)):
  275. log.error("Condition param should be float32.")
  276. raise DebuggerParamValueError("The value of condition param should be within float32.")