debugger_server.py 30 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement the debugger server."""
import signal
from concurrent import futures
from threading import Thread

import grpc

from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.utils.tools import to_float
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
    DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \
    DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, DebuggerCompareTensorError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
    create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo, \
    str_to_slice_or_int
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
from mindinsight.utils.exceptions import MindInsightException


class DebuggerServer:
    """The server manager of debugger."""

    def __init__(self, grpc_port=None):
        self.grpc_port = grpc_port
        self.cache_store = DebuggerCache()
        self.grpc_server = DebuggerGrpcServer(self.cache_store)
        self.grpc_server_manager = None
        self.back_server = None
        self._watch_point_id = 0

    def start(self):
        """Start server."""
        grpc_port = self.grpc_port if self.grpc_port else "50051"
        host = settings.HOST if hasattr(settings, 'HOST') else '[::]'
        hostname = "{}:{}".format(host, grpc_port)
        # initialize a grpc server
        grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        grpc_server_base.add_EventListenerServicer_to_server(self.grpc_server, grpc_server_manager)
        grpc_server_manager.add_insecure_port(hostname)
        grpc_server_manager.start()
        my_server_thread = Thread(target=grpc_server_manager.wait_for_termination)
        # start grpc server
        my_server_thread.start()
        self.back_server = my_server_thread
        self.grpc_server_manager = grpc_server_manager
        # register stop server handler
        signal.signal(signal.SIGINT, self._stop_handler)
        log.info("Start grpc server %s", hostname)
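
    # Illustrative lifecycle sketch (values are the defaults from the code above):
    #
    #     server = DebuggerServer(grpc_port="50051")
    #     server.start()   # serve gRPC on settings.HOST (or '[::]') port 50051
    #     ...              # interact through retrieve()/control()/poll_data()
    #     server.stop()    # or send SIGINT, which is routed to _stop_handler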

    def _stop_handler(self, signum, frame):
        """Handle stop signal and stop the server."""
        self.stop()
        log.debug("Deal with stop signal: %s, %s", signum, frame)

    def stop(self):
        """Stop debugger server."""
        self.grpc_server_manager.stop(grace=None)
        self.back_server.join()
        log.info("Stop debugger server.")

    def poll_data(self, pos):
        """
        Get the pos-th data from DebuggerCache.

        Args:
            pos (str): The index of data.

        Returns:
            dict, the data to be updated.
        """
        if not isinstance(pos, str):
            log.error("Pos should be string. Received: %s", pos)
            raise DebuggerParamValueError("Pos should be string.")
        reply = self.cache_store.get_data(pos)
        return reply

    def search(self, name, watch_point_id):
        """Search for single node in graph."""
        log.info("Receive search request for node: %s, in watchpoint: %d", name, watch_point_id)
        graph = self.cache_store.get_stream_handler(Streams.GRAPH).search_nodes(name)
        self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes(
            graph, watch_point_id)
        return graph

    def tensor_comparisons(self, name, shape, detail='data', tolerance='0'):
        """
        Get tensor comparisons data for given name, detail, shape and tolerance.

        Args:
            name (str): The name of tensor for UI.
            detail (str): Specify which data to query. Current available value is 'data' which means
                concrete tensor data. Histogram or unique count can be supported in the future.
            shape (str): Specify concrete dimensions of shape.
            tolerance (str): Specify tolerance of difference between current step tensor and previous
                step tensor. Default value is 0.

        Raises:
            DebuggerParamValueError: If the node type is not parameter or the value of detail is not supported.
            DebuggerCompareTensorError: If MindSpore is not in waiting state.

        Returns:
            dict, the retrieved data.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error("Failed to compare tensors as MindSpore is not in waiting state.")
            raise DebuggerCompareTensorError(
                "Failed to compare tensors as MindSpore is not in waiting state."
            )
        self.validate_tensor_param(name, detail)
        parsed_shape = self.parse_shape(shape)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
        tolerance = to_float(tolerance, 'tolerance')
        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
        if detail == 'data':
            if node_type == NodeTypeEnum.PARAMETER.value:
                reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance)
            else:
                raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type))
        else:
            raise DebuggerParamValueError("The value of detail: {} is not supported.".format(detail))
        return reply
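
    # Illustrative comparison request (the tensor name is hypothetical); only
    # detail='data' is accepted, and the node must be a parameter:
    #
    #     server.tensor_comparisons(
    #         name='Default/network/fc1.weight:0',  # UI name, "<node>:<slot>"
    #         shape='[0:2, 0:2]',                   # slice string passed to parse_shape
    #         detail='data',
    #         tolerance='0.01')                     # diff threshold vs. previous step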

    def retrieve(self, mode, filter_condition=None):
        """
        Retrieve data according to mode and params.

        Args:
            mode (str): The type of info message.
            filter_condition (dict): The filter condition.

        Returns:
            dict, the retrieved data.
        """
        log.info("Receive retrieve request for mode: %s, filter_condition: %s", mode,
                 filter_condition)
        mode_mapping = {
            'all': self._retrieve_all,
            'node': self._retrieve_node,
            'watchpoint': self._retrieve_watchpoint,
            'watchpoint_hit': self._retrieve_watchpoint_hit
        }
        # validate param <mode>
        if mode not in mode_mapping.keys():
            log.error("Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint', "
                      "'watchpoint_hit'], but got %s.", mode)
            raise DebuggerParamTypeError("Invalid mode.")
        filter_condition = {} if filter_condition is None else filter_condition
        self._watch_point_id = filter_condition.get('watch_point_id', 0)
        reply = mode_mapping[mode](filter_condition)
        return reply
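
    # Illustrative dispatch (node name is hypothetical): each mode maps to a
    # private handler through mode_mapping above.
    #
    #     server.retrieve('all')                       # metadata + graph + watchpoint list
    #     server.retrieve('node', {'name': 'Default/conv1-Conv2d',
    #                              'watch_point_id': 1,
    #                              'single_node': False})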

    def _retrieve_all(self, filter_condition=None):
        """Retrieve metadata, root graph and watchpoint list."""
        if filter_condition:
            log.error("No filter condition required for retrieve all request.")
            raise DebuggerParamTypeError("filter_condition should be empty.")
        result = {}
        self.cache_store.clean_data()
        log.info("Clean data queue cache when retrieve all request.")
        self.cache_store.put_command({'reset': True})
        for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]:
            sub_res = self.cache_store.get_stream_handler(stream).get()
            result.update(sub_res)
        return result

    def _retrieve_node(self, filter_condition):
        """
        Retrieve node info.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of single node.
                - watch_point_id (int): The id of watchpoint.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                  the node list from root node to single node.

        Returns:
            dict, the node info.
        """
        log.info("Retrieve node %s.", filter_condition)
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        # validate parameters
        node_name = filter_condition.get('name')
        if not node_name:
            node_type = NodeTypeEnum.NAME_SCOPE.value
        else:
            node_type = graph_stream.get_node_type(node_name)
        filter_condition['node_type'] = node_type
        filter_condition['single_node'] = bool(filter_condition.get('single_node'))
        # get graph for scope node
        if is_scope_type(node_type):
            reply = self._get_nodes_info(filter_condition)
        # get tensor history for leaf node
        else:
            reply = self._get_tensor_history(node_name)
            if filter_condition.get('single_node'):
                graph = self._get_nodes_info(filter_condition)
                reply.update(graph)
        return reply

    def _get_nodes_info(self, filter_condition):
        """
        Get nodes info.

        Args:
            filter_condition (dict): The filter condition.

                - name (str): The node name.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                  the node list from root node to single node.
                - watch_point_id (int): The id of watchpoint.

        Returns:
            dict, reply with graph.
        """
        # get graph
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        reply = graph_stream.get(filter_condition)
        graph = reply.get('graph')
        # add watched label
        self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes(
            graph, self._watch_point_id)
        return reply

    def retrieve_tensor_history(self, node_name):
        """
        Retrieve tensor history for leaf node.

        Args:
            node_name (str): The name of leaf node.

        Returns:
            dict, the tensor history and metadata.
        """
        log.info("Retrieve tensor history for node: %s.", node_name)
        self._validate_leaf_name(node_name)
        res = self._get_tensor_history(node_name)
        return res

    def _validate_leaf_name(self, node_name):
        """Validate if the node is a leaf node."""
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        node_type = graph_stream.get_node_type(node_name)
        if is_scope_type(node_type):
            log.error("Scope type node has no tensor history.")
            raise DebuggerParamValueError("Invalid leaf node name.")

    def _get_tensor_history(self, node_name):
        """
        Get tensor history for single node.

        Args:
            node_name (str): The name of leaf node.

        Returns:
            dict, the tensor history and metadata.
        """
        # get basic tensor history
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        tensor_history = graph_stream.get_tensor_history(node_name)
        # set the view event
        self.cache_store.put_command(
            {'reset': True,
             'node_name': node_name,
             'tensor_history': tensor_history.get('tensor_history')})
        # add tensor value for tensor history
        self._add_tensor_value_for_tensor_history(tensor_history)
        # add hit label for tensor history
        watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        watchpoint_hit_stream.update_tensor_history(tensor_history)
        # add metadata
        metadata = self.cache_store.get_stream_handler(Streams.METADATA).get()
        tensor_history.update(metadata)
        return tensor_history

    def _add_tensor_value_for_tensor_history(self, tensor_history):
        """
        Add tensor value for tensor history and send ViewCMD if tensor value is missing.

        Args:
            tensor_history (list[dict]): A list of tensor info, including name and type.

        Returns:
            dict, the tensor info.
        """
        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
        missed_tensors = tensor_stream.update_tensor_history(tensor_history)
        if missed_tensors:
            view_cmd = create_view_event_from_tensor_history(missed_tensors)
            self.cache_store.put_command(view_cmd)
            log.debug("Send view cmd.")

    def retrieve_tensor_value(self, name, detail, shape):
        """Retrieve the tensor value."""
        log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape)
        self.validate_tensor_param(name, detail)
        parsed_shape = self.parse_shape(shape)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
        reply = self.cache_store.get_stream_handler(Streams.TENSOR).get(
            {'name': tensor_name,
             'node_type': node_type,
             'shape': parsed_shape}
        )
        reply['tensor_value']['name'] = name
        return reply

    def _get_tensor_name_and_type_by_ui_name(self, name):
        """
        Get inner tensor name and type by UI name.

        Args:
            name (str): Node name shown in UI.

        Returns:
            str, node type of the tensor.
            str, full name of the tensor.
        """
        node_name, slot = name.rsplit(':', 1)
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        node_type = graph_stream.get_node_type(node_name)
        full_name = graph_stream.get_full_name(node_name)
        tensor_name = full_name + ':' + slot
        return node_type, tensor_name
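
    # Illustrative mapping (all names hypothetical): a UI tensor name is
    # "<display node name>:<slot>"; the inner name swaps in the graph's full name.
    #
    #     _get_tensor_name_and_type_by_ui_name('Default/conv1:0')
    #     # -> ('Conv2D', 'Default/network/conv1-Conv2d/Conv2D-op89:0')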

    @staticmethod
    def validate_tensor_param(name, detail):
        """Validate params for retrieve tensor request."""
        # validate name
        if not isinstance(name, str) or ':' not in name:
            log.error("Invalid tensor name. Received: %s", name)
            raise DebuggerParamValueError("Invalid tensor name.")
        # validate data
        if detail != 'data':
            log.error("Invalid detail value. Received: %s", detail)
            raise DebuggerParamValueError("Invalid detail value.")

    @staticmethod
    def parse_shape(shape):
        """Parse shape."""
        if shape is None:
            return shape
        if not (isinstance(shape, str) and shape.startswith('[') and shape.endswith(']')):
            log.error("Invalid shape. Received: %s", shape)
            raise DebuggerParamValueError("Invalid shape.")
        shape = shape.strip('[]')
        if shape.count(':') > 2:
            log.error("Invalid shape. At most two dimensions are specified.")
            raise DebuggerParamValueError("Invalid shape.")
        parsed_shape = tuple(
            str_to_slice_or_int(dim) for dim in shape.split(',')) if shape else tuple()
        log.info("Parsed shape: %s from %s", parsed_shape, shape)
        return parsed_shape
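
    # Illustrative parses (assuming str_to_slice_or_int turns "a:b" into
    # slice(a, b) and a bare digit into an int):
    #
    #     parse_shape('[0:2, 1]')  # -> (slice(0, 2), 1)
    #     parse_shape('[]')        # -> ()
    #     parse_shape(None)        # -> None, i.e. the whole tensor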

    def _retrieve_watchpoint(self, filter_condition):
        """
        Retrieve watchpoint.

        Args:
            filter_condition (dict): Filter condition.

                - watch_point_id (int): The id of watchpoint. If not given, return all watchpoints.
                - name (str): The name of single node.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                  the node list from root node to single node.

        Returns:
            dict, watch point list or relative graph.
        """
        watchpoint_id = filter_condition.get('watch_point_id')
        if watchpoint_id is None:
            reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get()
            log.debug("Get condition of watchpoints.")
        else:
            reply = self._retrieve_node(filter_condition)
            log.debug("Get graph of %d-th watchpoint.", watchpoint_id)
        return reply

    def _retrieve_watchpoint_hit(self, filter_condition):
        """
        Retrieve watchpoint hit.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of single node.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                  the node list from root node to single node.

        Returns:
            dict, watch point list or relative graph.
        """
        node_name = filter_condition.get('name')
        # get watchpoint hit list
        if node_name is None:
            reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get()
            return reply
        self._validate_leaf_name(node_name)
        # get tensor history
        reply = self._get_tensor_history(node_name)
        log.debug("Get tensor history for watchpoint hit node.")
        # get single graph
        if filter_condition.get('single_node'):
            graph = self._get_nodes_info(filter_condition)
            reply.update(graph)
            log.debug("Get single graph for watchpoint hit node.")
        return reply

    def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None):
        """
        Create watchpoint.

        Args:
            watch_condition (dict): The watch condition.

                - condition (str): Accept `INF` or `NAN`.
                - param (list[float]): Not defined yet.
            watch_nodes (list[str]): The list of node names.
            watch_point_id (int): The id of watchpoint.

        Returns:
            dict, the id of new watchpoint.
        """
        log.info("Received create watchpoint request. WatchCondition: %s", watch_condition)
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state != ServerStatus.WAITING.value:
            log.error("Failed to create watchpoint as MindSpore is not in waiting state.")
            raise DebuggerCreateWatchPointError(
                "Failed to create watchpoint as MindSpore is not in waiting state."
            )
        if metadata_stream.backend == 'GPU' and watch_condition.get('condition') == 'OVERFLOW':
            log.error("GPU doesn't support OVERFLOW watch condition.")
            raise DebuggerParamValueError("GPU doesn't support OVERFLOW watch condition.")
        watch_nodes = self._get_node_basic_infos(watch_nodes)
        watch_point_id = self.cache_store.get_stream_handler(Streams.WATCHPOINT).create_watchpoint(
            watch_condition, watch_nodes, watch_point_id)
        log.info("Create watchpoint %d", watch_point_id)
        return {'id': watch_point_id}
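
    # Illustrative call (node name is hypothetical): watch the listed nodes for
    # NAN values; the reply carries the new watchpoint id.
    #
    #     server.create_watchpoint(
    #         watch_condition={'condition': 'NAN'},
    #         watch_nodes=['Default/network/conv1-Conv2d'])
    #     # -> {'id': 1}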

    def update_watchpoint(self, watch_point_id, watch_nodes, mode, name=None):
        """
        Update watchpoint.

        Args:
            watch_point_id (int): The id of watchpoint.
            watch_nodes (list[str]): The list of node names.
            mode (int): The update operation on nodes. 0 for removing nodes from watch nodes,
                1 for adding nodes to watch nodes.
            name (str): The search name. Default: None.

        Returns:
            dict, empty response.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error("Failed to update watchpoint as MindSpore is not in waiting state.")
            raise DebuggerUpdateWatchPointError(
                "Failed to update watchpoint as MindSpore is not in waiting state."
            )
        # validate
        if not watch_nodes or not watch_point_id:
            log.error("Invalid parameter for update watchpoint.")
            raise DebuggerParamValueError("Invalid parameter for update watchpoint.")
        # update watch node
        if name is not None:
            watch_nodes = self._get_watch_nodes_by_search(watch_nodes)
        elif mode == 1:
            watch_nodes = self._get_node_basic_infos(watch_nodes)
        self.cache_store.get_stream_handler(Streams.WATCHPOINT).update_watchpoint(
            watch_point_id, watch_nodes, mode)
        log.info("Update watchpoint with id: %d", watch_point_id)
        return {}
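
    # Illustrative calls (node name is hypothetical): mode=1 adds the nodes to
    # watchpoint 1, mode=0 removes them.
    #
    #     server.update_watchpoint(1, ['Default/network/fc1-Dense'], mode=1)
    #     server.update_watchpoint(1, ['Default/network/fc1-Dense'], mode=0)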

    def _get_watch_nodes_by_search(self, watch_nodes):
        """Get watched leaf nodes by search name."""
        watched_leaf_nodes = []
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        for search_name in watch_nodes:
            search_nodes = graph_stream.get_searched_node_list()
            search_node_names = [
                NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
                for node in search_nodes
                if node.name.startswith(search_name)]
            watched_leaf_nodes.extend(search_node_names)
        log.debug("Update nodes: %s", watched_leaf_nodes)
        return watched_leaf_nodes

    def delete_watchpoint(self, watch_point_id):
        """
        Delete watchpoint.

        Args:
            watch_point_id (int): The id of watchpoint.

        Returns:
            dict, empty response.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error("Failed to delete watchpoint as MindSpore is not in waiting state.")
            raise DebuggerDeleteWatchPointError(
                "Failed to delete watchpoint as MindSpore is not in waiting state."
            )
        self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint(
            watch_point_id)
        log.info("Delete watchpoint with id: %d", watch_point_id)
        return {}

    def _get_node_basic_infos(self, node_names):
        """Get node info according to node names."""
        if not node_names:
            return []
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        node_infos = []
        for node_name in node_names:
            node_type = graph_stream.get_node_type(node_name)
            # optimize later
            if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value:
                sub_nodes = graph_stream.get_nodes(node_name)
                sub_infos = [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
                             for node in sub_nodes]
                node_infos.extend(sub_infos)
                continue
            full_name = graph_stream.get_full_name(node_name)
            node_infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type))
        return node_infos

    def control(self, params=None):
        """
        Control the training process.

        Args:
            params (dict): The control params.

                - mode (str): Acceptable control command, including `continue`,
                  `pause` and `terminate`.
                - level (str): The control granularity, `node` level or `step` level.
                  Default: `step`.
                - steps (int): Specify the steps that training should run.
                  Used when `level` is `step`.
                - name (str): Specify the name of the node. Used when `level` is `node`.

        Returns:
            dict, the response.
        """
        log.info("Receive control request: %s.", params)
        mode = params.get('mode')
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if mode == 'continue':
            reply = self._continue(metadata_stream, params)
        elif mode in ['pause', 'terminate']:
            mode_mapping = {
                'pause': self._pause,
                'terminate': self._terminate
            }
            reply = mode_mapping.get(mode)(metadata_stream)
        else:
            log.error("Invalid control mode %s", mode)
            raise DebuggerParamValueError("Invalid control mode.")
        return reply
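
    # Illustrative control requests (node name is hypothetical):
    #
    #     server.control({'mode': 'continue', 'level': 'step', 'steps': 10})
    #     server.control({'mode': 'continue', 'level': 'node',
    #                     'name': 'Default/network/conv1-Conv2d'})
    #     server.control({'mode': 'pause'})      # RunCMD(run_level='step', run_steps=0)
    #     server.control({'mode': 'terminate'})  # exit event, state -> 'pending'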

    def _continue(self, metadata_stream, params):
        """
        Send RunCMD to MindSpore.

        Args:
            metadata_stream (MetadataHandler): The metadata stream handler.
            params (dict): The control params.
        """
        if metadata_stream.state != ServerStatus.WAITING.value:
            log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state)
            raise DebuggerContinueError(
                "MindSpore is not ready to run or is running currently."
            )
        metadata_stream.state = ServerStatus.RUNNING.value
        current_state = ServerStatus.RUNNING.value
        try:
            event = self._construct_run_event(params)
            self._send_watchpoints()
            self.cache_store.put_command(event)
        except MindInsightException as err:
            log.error("Failed to send run event.")
            log.exception(err)
            current_state = ServerStatus.WAITING.value
            metadata_stream.state = current_state
            raise DebuggerContinueError("Failed to send run command.")
        else:
            log.debug("Send the RunCMD to command queue.")
        return {'metadata': {'state': current_state}}

    def _validate_node_type(self, node_name):
        """Check the node type in node control."""
        if not node_name:
            return
        node_type = self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name)
        unsupported_types = [item.value for item in list(NodeTypeEnum)]
        if node_type in unsupported_types:
            log.error("Invalid node type. %s", node_name)
            raise DebuggerParamValueError(f"The type of node {node_name} is unsupported for "
                                          "the continue-to-node command.")

    def _construct_run_event(self, params):
        """
        Construct run cmd from input control params.

        Args:
            params (dict): The control params.

                - level (str): The control granularity, `node` level or `step` level.
                  Default: `step`.
                - steps (int): Specify the steps that training should run.
                  Used when `level` is `step`.
                - full_name (str): Specify the name of the node. Used when `level` is `node`.

        Returns:
            EventReply, control event with run command.
        """
        level = params.get('level', 'step')
        event = get_ack_reply()
        if level == 'step':
            steps = params.get('steps')
            if not steps:
                steps = 1
            run_cmd = RunCMD(run_level='step', run_steps=steps)
        elif level == 'node':
            self._validate_node_type(params.get('name'))
            name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(
                params['name'])
            if not name:
                name = ''
            run_cmd = RunCMD(run_level='node', node_name=name)
        else:
            log.error("Invalid value. `level` should be `step` or `node`. Got %s", level)
            raise DebuggerParamValueError("`level` should be `step` or `node`.")
        event.run_cmd.CopyFrom(run_cmd)
        log.debug("Construct run event. %s", event)
        return event

    def _send_watchpoints(self):
        """Set watchpoints."""
        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
        watchpoints = watchpoint_stream.get(filter_condition=True).get('watch_points')
        if watchpoints:
            for watchpoint in watchpoints:
                event = get_ack_reply()
                event.set_cmd.CopyFrom(watchpoint)
                self.cache_store.put_command(event)
            watchpoint_stream.sync_set_cmd()
            log.debug("Send SetCMD to MindSpore. %s", event)

    def _pause(self, metadata_stream):
        """
        Pause the training.

        Args:
            metadata_stream (MetadataHandler): The metadata stream handler.
        """
        if metadata_stream.state != ServerStatus.RUNNING.value:
            log.error("MindSpore is not running.")
            raise DebuggerPauseError("MindSpore is not running.")
        metadata_stream.state = 'waiting'
        event = get_ack_reply()
        event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
        self.cache_store.put_command(event)
        log.debug("Send the Pause command.")
        return {'metadata': {'state': 'waiting'}}

    def _terminate(self, metadata_stream):
        """
        Terminate the training.

        Args:
            metadata_stream (MetadataHandler): The metadata stream handler.
        """
        metadata_stream.state = 'pending'
        event = get_ack_reply()
        event.exit = True
        self.cache_store.put_command(event)
        log.debug("Send the ExitCMD.")
        return {'metadata': {'state': 'pending'}}

    def retrieve_node_by_bfs(self, node_name, ascend=False):
        """Get the graph and tensor history of the next node name according to node_name."""
        log.info("Retrieve node <%s> by bfs, `ascend` is: %s",
                 node_name, ascend)
        reply = {}
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
        # no next node
        if next_node_name is None:
            return reply
        # add graph and tensor history for next node
        filter_condition = {
            'name': next_node_name,
            'single_node': True
        }
        search_graph = self._get_nodes_info(filter_condition)
        tensor_history = self._get_tensor_history(next_node_name)
        reply = {'name': next_node_name}
        reply.update(search_graph)
        reply.update(tensor_history)
        return reply