You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_handler.py 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the graph stream handler."""
  16. from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo
  17. from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum as CategoryTypeEnum
  18. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
  19. DebuggerNodeNotInGraphError, DebuggerGraphNotExistError
  20. from mindinsight.debugger.common.log import LOGGER as log
  21. from mindinsight.debugger.common.utils import is_scope_type
  22. from mindinsight.debugger.stream_cache.debugger_graph import DebuggerGraph
  23. from mindinsight.debugger.stream_cache.debugger_multigraph import DebuggerMultiGraph
  24. from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
  25. class MultiCardGraphHandler:
  26. """Multi-card Graph Handler."""
  27. def __init__(self):
  28. self._graph_handlers = {0: GraphHandler()}
  29. @property
  30. def graph_handlers(self):
  31. """The property of whole_graph."""
  32. return self._graph_handlers
  33. def get_graph_handler_by_rank_id(self, rank_id=0):
  34. """Get handler by rank id"""
  35. if rank_id in self._graph_handlers:
  36. return self._graph_handlers.get(rank_id)
  37. log.error("There is no rank id %d.", rank_id)
  38. raise ValueError
  39. def put(self, value):
  40. """put graphs into graph_handlers"""
  41. for rank_id, graph in value.items():
  42. if rank_id not in self._graph_handlers:
  43. self._graph_handlers[rank_id] = GraphHandler()
  44. self._graph_handlers[rank_id].put(graph)
  45. def get(self, filter_condition=None, rank_id=0):
  46. """Get the graph of specific node for specific device."""
  47. if rank_id in self._graph_handlers:
  48. return self._graph_handlers.get(rank_id).get(filter_condition)
  49. log.error("There is no rank id %d.", rank_id)
  50. raise ValueError
  51. def has_graph(self):
  52. """check if has graph"""
  53. res = False
  54. for graph_handler in self._graph_handlers:
  55. res = res or graph_handler.graph
  56. return res
  57. def register_graph_handler(self, rank_id, graph_handler):
  58. """Register graph handler."""
  59. self._graph_handlers[rank_id] = graph_handler
  60. def clean(self):
  61. """Clean cache."""
  62. self.__init__()
  63. class GraphHandler(StreamHandlerBase):
  64. """Metadata Handler."""
  65. def __init__(self):
  66. # dict of <graph_name, GraphProto object>
  67. self._graph_proto = {}
  68. # dict of <graph_name, DebuggerGraph object>
  69. self._graph = {}
  70. self._searched_node_list = {}
  71. # list of node names in bfs order
  72. self.bfs_order = []
  73. # dict of <node full name, graph_name>
  74. self.graph_node_map = {}
  75. # dict of <node ui name, Node object> for all graphs
  76. self._all_leaf_nodes = {}
  77. # the whole graph
  78. self._whole_graph = None
  79. @property
  80. def whole_graph(self):
  81. """The property of whole_graph."""
  82. return self._whole_graph
  83. @property
  84. def graph(self):
  85. """The property of graph."""
  86. return self._graph_proto
  87. @property
  88. def graph_names(self):
  89. """The property of graph names."""
  90. return list(self._graph)
  91. @property
  92. def debugger_graph_obj(self):
  93. """The property of graph object."""
  94. return self._graph
  95. def put(self, value):
  96. """
  97. Put value into graph cache. Called by grpc server.
  98. Args:
  99. value (dict): The Graph proto message. Each item is format like (<graph_name>, GraphProto).
  100. """
  101. log.info("Put graph into cache.")
  102. sorted_value_list = self._sort_graph(value)
  103. for graph_name, graph_value in sorted_value_list:
  104. self._graph_proto[graph_name] = graph_value
  105. # build sub graph
  106. graph = DebuggerGraph()
  107. graph.build_graph(graph_value)
  108. self._graph[graph_name] = graph
  109. self.bfs_order.extend(graph.get_bfs_order())
  110. leaf_nodes = graph.leaf_nodes
  111. self._all_leaf_nodes.update(leaf_nodes)
  112. for _, node in leaf_nodes.items():
  113. self.graph_node_map[node.full_name] = graph_name
  114. # build whole graph
  115. graph = DebuggerMultiGraph()
  116. graph.add_graph(self._graph)
  117. self._whole_graph = graph
  118. def get(self, filter_condition=None):
  119. """
  120. Get the graph of specific node.
  121. Args:
  122. filter_condition (dict):
  123. - name (str): The full debug node name.
  124. - graph_name (str): The relative graph_name of the node.
  125. - single_node (bool): If True, return the graph from root
  126. to the specific node; else, return the sublayer of the
  127. graph. Default: False.
  128. Returns:
  129. dict, the metadata.
  130. """
  131. try:
  132. self._graph_exists()
  133. except DebuggerGraphNotExistError:
  134. log.warning('The graph is empty. To view a graph, '
  135. 'please start the training script first.')
  136. return {'graph': {}}
  137. graph = {}
  138. if filter_condition is None:
  139. filter_condition = {}
  140. graph = {'graph_names': self.graph_names}
  141. single_node = filter_condition.get('single_node', False)
  142. name = filter_condition.get('name')
  143. graph_name = filter_condition.get('graph_name')
  144. if single_node is True:
  145. nodes = self._get_single_node(name, graph_name)
  146. else:
  147. nodes = self._list_nodes(name, graph_name)
  148. graph.update(nodes)
  149. return {'graph': graph}
  150. def _get_single_node(self, name, graph_name=None):
  151. """
  152. Search node, and return every layer nodes until this node.
  153. Args:
  154. graph_name(str): The graph_name.
  155. name (str): The name of node.
  156. Returns:
  157. dict, every layer nodes until this node.
  158. """
  159. if graph_name:
  160. graph = self._get_graph(graph_name=graph_name)
  161. searched_graph = graph.search_single_node(name)
  162. else:
  163. searched_graph = self._whole_graph.search_single_node(name)
  164. return searched_graph
  165. def _list_nodes(self, scope, graph_name):
  166. """
  167. Get the nodes of every layer in graph.
  168. Args:
  169. scope (str): The name of a scope.
  170. graph_name(str): The graph name.
  171. Returns:
  172. TypedDict{'nodes': ['Node_1', ...], 'graph_names': ['graph_name_1', ...]},
  173. format is {'nodes': [<NodeObject>], 'graph_names': [<str>]}.
  174. example:
  175. {
  176. "nodes" : [
  177. {
  178. "attr" :
  179. {
  180. "index" : "i: 0\n"
  181. },
  182. "input" : {},
  183. "name" : "input_tensor",
  184. "output" :
  185. {
  186. "Default/TensorAdd-op17" :
  187. {
  188. "edge_type" : "data",
  189. "scope" : "name_scope",
  190. "shape" : [1, 16, 128, 128]
  191. }
  192. },
  193. "output_i" : -1,
  194. "proxy_input" : {},
  195. "proxy_output" : {},
  196. "independent_layout" : False,
  197. "subnode_count" : 0,
  198. "type" : "Data"
  199. }
  200. ]
  201. }
  202. """
  203. if graph_name:
  204. graph = self._get_graph(graph_name, scope)
  205. nodes = graph.list_node_by_scope(scope=scope)
  206. res = {'nodes': nodes}
  207. else:
  208. nodes = self._whole_graph.list_node_by_scope(scope=scope)
  209. res = {'nodes': nodes}
  210. return res
  211. def get_tensor_history(self, node_name, graph_name=None, depth=0):
  212. """
  213. Get the tensor history of a specified node.
  214. Args:
  215. node_name (str): The debug name of the node.
  216. graph_name (str): The graph_name. Default: None.
  217. depth (int): The number of layers the user
  218. wants to trace. Default is 0.
  219. Returns:
  220. dict, basic tensor history, only including tensor name and tensor type and node type.
  221. """
  222. graph_name, node_name = self._parse_node_name(node_name, graph_name)
  223. graph = self._get_graph(graph_name=graph_name, node_name=node_name)
  224. # validate node type, scope node has no tensor history
  225. node_type = graph.get_node_type(node_name)
  226. if is_scope_type(node_type):
  227. log.error("Scope type node has no tensor history.")
  228. raise DebuggerParamValueError("Invalid leaf node name.")
  229. # get tensor history
  230. tensor_history, cur_outputs_nums = graph.get_tensor_history(node_name, depth)
  231. # add the tensor type for tensor history
  232. self._update_tensor_history(tensor_history[0:cur_outputs_nums], 'output', graph_name)
  233. self._update_tensor_history(tensor_history[cur_outputs_nums:], 'input', graph_name)
  234. log.debug("Get %d tensors in tensor history for node <%s>.", len(tensor_history), node_name)
  235. return {'tensor_history': tensor_history}
  236. @staticmethod
  237. def _update_tensor_history(tensor_history, tensor_type, graph_name):
  238. """
  239. Add tensor source type for tensor history.
  240. Args:
  241. tensor_history (list[dict]): Tensor history from Graph stream. Each element has two
  242. keys: `node_type` and `name`. `node_type` refers to the type of the node which
  243. the tensor come from. `name` refers to the tensor name.
  244. tensor_type (str): The source type of the tensor. `input` or `output`.
  245. graph_name (str): The graph name.
  246. """
  247. for single_tensor_info in tensor_history:
  248. single_tensor_info['type'] = tensor_type
  249. single_tensor_info['graph_name'] = graph_name
  250. def search_nodes(self, pattern):
  251. """
  252. Search nodes by given pattern.
  253. Args:
  254. pattern (dict): Filter condition.
  255. - name (str): The name pattern.
  256. - graph_name (str): The graph name.
  257. - node_category (str): The node_category. Default: None
  258. - condition (dict): The additional filter condition.
  259. Returns:
  260. dict, the searched node.
  261. """
  262. graph_name = pattern.pop('graph_name', None)
  263. search_nodes = self.search_in_graph(pattern, graph_name)
  264. # construct to search tree
  265. graph = self._get_graph(graph_name=graph_name)
  266. format_nodes = graph.get_nodes(search_nodes)
  267. return {'nodes': format_nodes}
  268. def search_in_graph(self, pattern, graph_name=None):
  269. """
  270. Search nodes by given pattern.
  271. Args:
  272. pattern (dict): Filter condition.
  273. - name (str): The name pattern.
  274. - node_category (str): The node_category. Default: None.
  275. - condition (dict): The additional filter condition.
  276. graph_name (str): The graph name.
  277. Returns:
  278. list, the searched node list.
  279. """
  280. temp_node_list = []
  281. node_category = pattern.get('node_category')
  282. graph = self._get_graph(graph_name=graph_name)
  283. # filter nodes by name
  284. if pattern.get('name'):
  285. if node_category:
  286. # get leaf nodes for forward filter
  287. temp_node_list = graph.search_leaf_nodes_by_pattern(pattern.get('name'))
  288. else:
  289. # optimize search nodes
  290. temp_node_list = graph.search_nodes_by_pattern(pattern.get('name'))
  291. if not temp_node_list:
  292. log.debug("No node named %s", pattern.get('name'))
  293. return []
  294. # filter nodes by category
  295. if node_category:
  296. node_category = self._get_inner_node_category(node_category)
  297. condition = pattern['condition'].copy() if pattern.get('condition') else {}
  298. condition['search_range'] = temp_node_list
  299. temp_node_list = graph.search_nodes_by_category(node_category, condition=condition)
  300. return temp_node_list
  301. @staticmethod
  302. def _get_inner_node_category(node_category):
  303. """
  304. Get inner node category.
  305. Args:
  306. node_category (str): The node category supported in
  307. mindinsight.conditionmgr.condition.TargetTypeEnum.
  308. Returns:
  309. CategoryTypeEnum, the translated value.
  310. """
  311. try:
  312. res = CategoryTypeEnum(node_category)
  313. except ValueError as err:
  314. log.error("Invalid node category. %s", err)
  315. raise DebuggerParamValueError("Invalid node_category.")
  316. return res
  317. def get_graph_id_by_name(self, node_name):
  318. """
  319. Get graph id by full name.
  320. Args:
  321. node_name (str): The name of the node.
  322. Returns:
  323. str, the graph name of the node.
  324. Raises:
  325. DebuggerNodeNotInGraphError: If can not find the node in all graphs.
  326. """
  327. if node_name:
  328. for graph_name, sub_graph in self._graph.items():
  329. if sub_graph.exist_node(name=node_name):
  330. return graph_name
  331. log.error('Failed to find node %s in graph. Please make sure the graph has been sent and '
  332. 'the node name is correct, and try again.', node_name)
  333. raise DebuggerGraphNotExistError
  334. def get_graph_id_by_full_name(self, node_name):
  335. """
  336. Get graph id by full name.
  337. Args:
  338. node_name (str): The full name of the node.
  339. Returns:
  340. str, the graph name of the node.
  341. Raises:
  342. DebuggerNodeNotInGraphError: If can not find the node in all graphs.
  343. """
  344. graph_id = self.graph_node_map.get(node_name) if node_name else None
  345. if not graph_id:
  346. log.warning("Failed to get graph id by full name: %s", node_name)
  347. return graph_id
  348. def get_node_type(self, node_name, graph_name=None):
  349. """
  350. Get the type of the specified node.
  351. Args:
  352. node_name (str): The debug name of the node.
  353. graph_name (str): The relative graph_name of the node. Default: None.
  354. Returns:
  355. A string of the node type, name_scope or leaf.
  356. """
  357. if graph_name:
  358. graph = self._get_graph(node_name=node_name, graph_name=graph_name)
  359. else:
  360. graph = self._whole_graph
  361. node_type = graph.get_node_type(node_name)
  362. return node_type
  363. def get_full_name(self, node_name, graph_name=None):
  364. """Get full name according to ui node name."""
  365. full_name = ''
  366. if node_name:
  367. graph = self._get_graph(node_name=node_name, graph_name=graph_name)
  368. full_name = graph.get_full_name_by_node_name(node_name)
  369. return full_name
  370. def get_node_basic_info(self, node_name, graph_name):
  371. """Get node basic info with graph scope."""
  372. graph_name, node_name = self._parse_node_name(node_name=node_name, graph_name=graph_name)
  373. graph = self._get_graph(graph_name, node_name)
  374. full_name = graph.get_full_name_by_node_name(node_name)
  375. node_type = graph.get_node_type(node_name)
  376. return self.construct_node_basic_info(full_name, graph_name, node_name, node_type)
  377. def get_tensor_graph(self, tensor_name, graph_name):
  378. """
  379. Get tensor graph according to node name.
  380. Args:
  381. tensor_name (str): Tensor name from UI, format is "node_name:slot".
  382. graph_name (str): The relative graph_name of the node. Default: None.
  383. Returns:
  384. dict, relative node.
  385. """
  386. node_name, _ = tensor_name.rsplit(':', 1)
  387. graph = self._get_graph(graph_name=graph_name, node_name=node_name)
  388. tensor_graph = graph.get_tensor_graph(node_name)
  389. return {'graph': tensor_graph}
  390. @staticmethod
  391. def construct_node_basic_info(full_name, graph_name, node_name, node_type):
  392. """Construct node basic info."""
  393. node_name_with_graph_scope = '/'.join([graph_name, node_name]) if node_name else graph_name
  394. return NodeBasicInfo(name=node_name_with_graph_scope, full_name=full_name, type=node_type)
  395. def get_node_basic_info_by_scope(self, scope_name, graph_name):
  396. """
  397. Get node by a given scope name.
  398. Args:
  399. scope_name (str): The name of scope.
  400. graph_name (str): The relative graph_name of the watched node. Default: None.
  401. Returns:
  402. list[NodeBasicInfo], a list of node.
  403. """
  404. graph_name, node_name = self._parse_node_name(scope_name, graph_name)
  405. graph = self._get_graph(graph_name)
  406. # to make sure fully match the scope name
  407. node_name = node_name + '/' if node_name and not node_name.endswith('/') else node_name
  408. nodes = graph.search_leaf_nodes_by_pattern(node_name, True)
  409. res = [self.construct_node_basic_info(full_name=node.full_name,
  410. graph_name=graph_name,
  411. node_name=node.name,
  412. node_type=node.type) for node in nodes]
  413. return res
  414. def get_node_name_by_full_name(self, full_name, graph_name):
  415. """Get UI node name by full name and graph name."""
  416. if graph_name and full_name:
  417. graph = self._get_graph(graph_name)
  418. node_name = graph.get_node_name_by_full_name(full_name)
  419. else:
  420. node_name = ''
  421. log.debug("Get empty full name.")
  422. return node_name
  423. def _get_next_node_in_bfs(self, index, length, ascend):
  424. """
  425. Get the next node in bfs order.
  426. Args:
  427. index (int): The current index.
  428. length (int): The number of all leaf nodes.
  429. ascend (bool): Whether get the node in ascend order or not.
  430. Returns:
  431. Union[None, dict], the next node object in dict type or None.
  432. """
  433. next_node = None
  434. if 0 <= index < length:
  435. if ascend is True and index < length - 1:
  436. next_node = self.bfs_order[index + 1]
  437. elif ascend is False and index > 0:
  438. next_node = self.bfs_order[index - 1]
  439. return next_node
  440. def _graph_exists(self):
  441. """
  442. Check if the graph has been loaded in the debugger cache.
  443. Raises:
  444. DebuggerGraphNotExistError: If the graph does not exist.
  445. """
  446. if not self._graph:
  447. log.error('The graph does not exist. Please start the '
  448. 'training script and try again.')
  449. raise DebuggerGraphNotExistError
  450. def _get_graph(self, graph_name=None, node_name=None):
  451. """
  452. Get the graph object according to graph name and node name.
  453. Args:
  454. graph_name (str): The graph name.
  455. node_name (str): The node name.
  456. Returns:
  457. DebuggerGraph, the graph object.
  458. Raises:
  459. DebuggerGraphNotExistError: If the graph does not exist.
  460. """
  461. graph = self._graph.get(graph_name) if graph_name else self._whole_graph
  462. # get graph according to graph name and check the node
  463. if graph and (not node_name or graph.exist_node(name=node_name)):
  464. return graph
  465. log.error('The graph %s does not exist node %s.', graph_name, node_name)
  466. raise DebuggerGraphNotExistError
  467. def _has_graph_scope(self, graph_name):
  468. """Check if query with graph_scope."""
  469. return bool(graph_name is None and len(self._graph) > 1)
  470. def validate_graph_name(self, graph_name):
  471. """Validate graph_name."""
  472. if graph_name and self._graph.get(graph_name) is None:
  473. log.error("No graph named %s in debugger cache.", graph_name)
  474. raise DebuggerGraphNotExistError
  475. if not graph_name and len(self._graph) == 1:
  476. graph_name = self.graph_names[0]
  477. return graph_name
  478. def _add_graph_scope_for_nodes(self, nodes, graph_name):
  479. """
  480. Add graph scope for nodes.
  481. Args:
  482. nodes (list[Node]): List of nodes object.
  483. graph_name (str): The graph name.
  484. """
  485. def _get_updated_node_info(cur_node, node_type):
  486. """Add graph scope in key."""
  487. old_node = cur_node.get(node_type)
  488. if not old_node:
  489. return
  490. new_values = {}
  491. for old_name, node_info in old_node.items():
  492. new_name = '/'.join([graph_name, old_name]) if old_name else graph_name
  493. new_values[new_name] = node_info
  494. cur_node[node_type] = new_values
  495. for node in nodes:
  496. node['name'] = '/'.join([graph_name, node['name']]) if node['name'] else graph_name
  497. _get_updated_node_info(node, 'input')
  498. _get_updated_node_info(node, 'output')
  499. if node.get('nodes'):
  500. self._add_graph_scope_for_nodes(node.get('nodes'), graph_name)
  501. def _parse_node_name(self, node_name, graph_name):
  502. """
  503. Check if the node name should have graph scope.
  504. Args:
  505. node_name (str): The ui node name.
  506. graph_name (str): The graph name.
  507. Returns:
  508. str, parsed graph name.
  509. str, parsed node name.
  510. """
  511. node_name = '' if node_name is None else node_name
  512. if self._has_graph_scope(graph_name):
  513. names = node_name.split("/", 1)
  514. graph_name = names[0]
  515. node_name = names[1] if len(names) == 2 else ''
  516. if graph_name is None and len(self._graph) == 1:
  517. graph_name = self.graph_names[0]
  518. return graph_name, node_name
  519. def validate_node_name(self, node_name, graph_name):
  520. """
  521. Validate the graph exist the specified node.
  522. Args:
  523. node_name (str): The ui node name.
  524. graph_name (str): The graph name.
  525. Raises:
  526. DebuggerNodeNotInGraphError: If can not find the node in all graphs.
  527. """
  528. graph = self._get_graph(graph_name=graph_name)
  529. if not graph.exist_node(name=node_name):
  530. log.error("graph %s doesn't find node: %s.", graph_name, node_name)
  531. raise DebuggerNodeNotInGraphError(node_name)
  532. @staticmethod
  533. def _sort_graph(graphs):
  534. """
  535. Sort graph by graph_name.
  536. Args:
  537. graphs(dict): <graph_name, GraphProto object>.
  538. """
  539. if len(graphs) == 1:
  540. return graphs.items()
  541. sorted_graphs = sorted(graphs.items(), key=lambda x: get_graph_number(x[0]))
  542. return sorted_graphs
  543. def get_graph_number(graph_name):
  544. number = graph_name.split("_")[-1]
  545. return int(number)