You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger_offline_server.py 27 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Debugger Offline server."""
  16. import copy
  17. from collections import defaultdict
  18. from importlib import import_module
  19. from threading import Event
  20. from multiprocessing import Process, Manager
  21. import mindinsight
  22. from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
  23. from mindinsight.debugger.common.exceptions.exceptions import DebuggerModuleNotFoundError
  24. from mindinsight.debugger.common.log import LOGGER as log
  25. from mindinsight.debugger.common.utils import Streams, ServerStatus, version_match, DebuggerServerMode, get_ack_reply, \
  26. RunLevel
  27. from mindinsight.debugger.conditionmgr.condition import ParamNameEnum
  28. from mindinsight.debugger.debugger_services.debugger_server_base import DebuggerServerBase, debugger_server_wrap
  29. from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply
  30. from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto
  31. from mindinsight.debugger.stream_cache.data_loader import DataLoader
  32. from mindinsight.utils.exceptions import MindInsightException
  33. class DebuggerOfflineServer(DebuggerServerBase):
  34. """Debugger Offline Server."""
  35. _MAX_TRY_EXCEPT_COUNT = 500
  36. def __init__(self, cache_store, context):
  37. super(DebuggerOfflineServer, self).__init__(cache_store, context)
  38. self._offline_server_manager = DebuggerOfflineManager(cache_store, context.dbg_dir)
  39. self._running = Event()
  40. self._running.clear()
  41. def run(self):
  42. """Start the debugger offline server."""
  43. log.info("Initialize Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir)
  44. self._offline_server_manager.initialize()
  45. self._running.set()
  46. log.info("Start Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir)
  47. try_count = 0
  48. while self._running.is_set() and try_count < self._MAX_TRY_EXCEPT_COUNT:
  49. try:
  50. self._offline_server_manager.wait_for_termination()
  51. if not self._offline_server_manager.is_runnable():
  52. break
  53. except MindInsightException as err:
  54. log.exception(err)
  55. log.warning("Error happens during listening on user commands. Restart listening again.")
  56. finally:
  57. try_count += 1
  58. # protect server from too much failure commands.
  59. if try_count == self._MAX_TRY_EXCEPT_COUNT:
  60. self._cache_store.clean()
  61. metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
  62. self._cache_store.put_data(metadata)
  63. log.warning("Exception exceed %d times, stop server.", try_count)
  64. def stop(self):
  65. """Stop offline debugger server."""
  66. log.debug("Start to wait for thread started.")
  67. self._running.wait()
  68. log.info("Start to stop offline debugger server.")
  69. self._running.clear()
  70. self._offline_server_manager.stop()
  71. self.join()
  72. class DebuggerOfflineManager:
  73. """Debugger offline manager which is used to handle user commands."""
  74. def __init__(self, cache_store, dbg_dir):
  75. cache_store.initialize()
  76. self._cache_store = cache_store
  77. self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA)
  78. self._dbg_dir = dbg_dir
  79. self._dbg_services_module = self._get_dbg_service_module()
  80. self._dbg_service = None
  81. self._command_listener = CommandListener(cache_store)
  82. self._data_loader = DataLoader(dbg_dir)
  83. self._is_running_flag = False
  84. self._old_run_cmd = {}
  85. def stop(self):
  86. """Stop server."""
  87. self._is_running_flag = False
  88. self._command_listener.stop()
  89. self._cache_store.clean()
  90. event = get_ack_reply()
  91. event.exit = True
  92. self._cache_store.put_command(event)
  93. log.info("Stop debugger offline manager.")
  94. def is_runnable(self):
  95. """Check if the offline manager is runnable."""
  96. state = self._metadata_stream.state
  97. flag = self._is_running_flag and state not in [ServerStatus.MISMATCH.value, ServerStatus.PENDING.value]
  98. if not flag:
  99. log.debug("The offline manager is not runnable, is_running_flag: %s, metadata state: %s",
  100. self._is_running_flag, state)
  101. return flag
  102. @staticmethod
  103. def _get_dbg_service_module():
  104. """Get dbg service module from MindSpore."""
  105. try:
  106. dbg_services_module = import_module('mindspore.offline_debug.dbg_services')
  107. except (ModuleNotFoundError, ImportError) as err:
  108. log.error("Failed to find module dbg_services. %s", err)
  109. raise DebuggerModuleNotFoundError("dbg_services")
  110. return dbg_services_module
  111. @debugger_server_wrap
  112. def initialize(self):
  113. """Start to load offline debugger data."""
  114. self._data_loader.initialize()
  115. is_sync = self._data_loader.get_sync_flag()
  116. net_name = self._data_loader.get_net_name()
  117. net_dir = self._data_loader.get_net_dir()
  118. self._dbg_service = self._dbg_services_module.DbgServices(net_dir)
  119. self._dbg_service.initialize(net_name=net_name, is_sync_mode=is_sync)
  120. self._cache_store.clean()
  121. self._command_listener.start()
  122. self._is_running_flag = True
  123. self._check_version()
  124. if self._metadata_stream.state == ServerStatus.MISMATCH.value:
  125. log.info("The MindSpore and MindInsight version are mismatched. Failed to initialize offline server.")
  126. return
  127. self._load_metadata()
  128. self._load_graphs()
  129. log.info("Success initialize offline server for %s", self._dbg_dir)
  130. def _check_version(self):
  131. """Check version."""
  132. ms_version = self._dbg_services_module.get_version()
  133. mi_version = mindinsight.__version__
  134. self._metadata_stream.debugger_version = {'ms': ms_version, 'mi': mi_version}
  135. if version_match(ms_version, mi_version) is False:
  136. log.info("Version is mismatched, dbg_services is: %s, mindinsight is: %s",
  137. ms_version, mi_version)
  138. self._metadata_stream.state = ServerStatus.MISMATCH.value
  139. metadata = self._metadata_stream.get(['state', 'debugger_version'])
  140. self._cache_store.put_data(metadata)
  141. def _load_metadata(self):
  142. """Load metadata."""
  143. self._metadata_stream.debugger_type = DebuggerServerMode.OFFLINE.value
  144. device_info = self._data_loader.load_device_info()
  145. # The backend referred to the running environment on which the offline debugger
  146. # data was generated.
  147. # Currently supported options: `GPU`, `Ascend`
  148. backend = device_info.get('device_target', 'Ascend')
  149. self._metadata_stream.backend = backend
  150. device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)
  151. device_stream.put(device_info.get('server_list'))
  152. rank_id = 0
  153. rank_0_info = device_stream.get(rank_id)['devices'][0]
  154. self._metadata_stream.client_ip = rank_0_info.get('server_id')
  155. # get step number per device. dict(device_id, step_num), may be increased with time goes by
  156. step_num_per_device = self._data_loader.load_step_number()
  157. device_stream.add_step_num_info(step_num_per_device)
  158. self._metadata_stream.max_step_num = max(step_num_per_device.values())
  159. def _load_graphs(self):
  160. """Load graphs."""
  161. # the format of graphs is a list of {'device_id': int, 'graph_protos': [GraphProto]}}
  162. graphs = self._data_loader.load_graphs()
  163. device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)
  164. graph_per_rank = {}
  165. for graph in graphs:
  166. device_id = int(graph.get('device_id'))
  167. rank_id = device_stream.get_rank_id_by_device_id(device_id)
  168. graph_per_rank[rank_id] = {}
  169. tensor_stream_per_rank = self._cache_store.get_stream_handler(Streams.TENSOR).\
  170. get_tensor_handler_by_rank_id(rank_id, create_if_not_exit=True)
  171. for graph_proto in graph.get('graph_protos'):
  172. graph_per_rank[rank_id][graph_proto.name] = graph_proto
  173. tensor_stream_per_rank.put_const_vals(graph_proto.const_vals)
  174. # the graph_per_rank is format like: Dict[<rank_id>, Dict[<graph_name>, <GraphProto>]]
  175. self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_per_rank)
  176. device_stream.add_graph_name_info(graph_per_rank)
  177. self._metadata_stream.state = ServerStatus.RECEIVE_GRAPH.value
  178. @debugger_server_wrap
  179. def wait_for_termination(self):
  180. """Begin to listen on command event."""
  181. log.info("Begin to listen for user commands.")
  182. self._send_graph()
  183. while self.is_runnable():
  184. if not self._command_listener.has_new_command() and self._old_run_cmd:
  185. self._deal_with_old_run_cmd()
  186. continue
  187. cmd = self._command_listener.get_next_command()
  188. self.deal_with_cmd(cmd)
  189. def _send_graph(self):
  190. """Put graph and metadata info into data queue."""
  191. if not self.is_runnable():
  192. return
  193. self._metadata_stream.state = ServerStatus.WAITING.value
  194. metadata = self._metadata_stream.get()
  195. res = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).get()
  196. res.update(metadata)
  197. self._cache_store.put_data(res)
  198. def _deal_with_old_run_cmd(self):
  199. """Deal with old run command."""
  200. left_step_count = self._old_run_cmd.get('left_step_count')
  201. if left_step_count:
  202. self._execute_one_step()
  203. # if old_run_cmd is not cleared due to hit.
  204. if self._old_run_cmd:
  205. self._old_run_cmd['left_step_count'] = left_step_count - 1 if left_step_count > 0 else -1
  206. if not self._old_run_cmd.get('left_step_count'):
  207. self._old_run_cmd.clear()
  208. def deal_with_cmd(self, cmd):
  209. """Deal with command."""
  210. if cmd is None:
  211. return
  212. if isinstance(cmd, dict):
  213. self._deal_with_view_cmd(cmd)
  214. elif isinstance(cmd, EventReply):
  215. self._on_event(cmd)
  216. def _on_event(self, event):
  217. """
  218. Deal with different command event.
  219. Args:
  220. event (EventReply): Command Event.
  221. """
  222. if event.HasField('run_cmd'):
  223. self._deal_with_run_cmd(event)
  224. elif event.HasField('exit'):
  225. self._cache_store.clean()
  226. self._update_state(ServerStatus.PENDING)
  227. log.debug("Clean cache for exit cmd.")
  228. else:
  229. self._deal_with_set_cmd(event)
  230. log.debug("Deal with set cmd.")
  231. def _deal_with_view_cmd(self, event):
  232. """
  233. Deal with view cmd.
  234. Args:
  235. event (dict): View command params.
  236. - view_cmd (EventReply): EventReply with view command.
  237. - node_name (str): The center node name for view command.
  238. - tensor_name (str): The center tensor name for view command.
  239. - graph_name (str): The graph name of center node.
  240. - rank_id (int): The device id of the tensor.
  241. """
  242. view_cmd = event.pop('view_cmd', None).view_cmd
  243. node_info = event
  244. log.debug("Receive view cmd for node: %s.", event)
  245. if not (view_cmd and node_info):
  246. log.info("Invalid view command. Ignore it.")
  247. return
  248. # read tensor value by dbg_service
  249. rank_id = node_info.get('rank_id', 0)
  250. device_id = self._cache_store.get_stream_handler(Streams.DEVICE).get_device_id_by_rank_id(rank_id)
  251. cur_step = self._metadata_stream.step
  252. tensor_protos = view_cmd.tensors
  253. root_graph_id = self.get_root_graph_id()
  254. tensor_infos = [
  255. self._dbg_services_module.TensorInfo(
  256. node_name=tensor_proto.node_name,
  257. slot=int(tensor_proto.slot),
  258. iteration=cur_step - 1 if tensor_proto.iter == 'prev' else cur_step,
  259. device_id=device_id,
  260. is_parameter=tensor_proto.truncate,
  261. root_graph_id=root_graph_id
  262. ) for tensor_proto in tensor_protos]
  263. res = self._dbg_service.read_tensors(tensor_infos)
  264. # put tensor into cache
  265. for tensor_proto, tensor_data in zip(tensor_protos, res):
  266. log.debug("Tensor name: %s:%s, tensor type: %s, tensor size: %s", tensor_proto.node_name, tensor_proto.slot,
  267. tensor_data.dtype, tensor_data.data_size)
  268. tensor_proto.tensor_content = tensor_data.data_ptr
  269. tensor_proto.ClearField('dims')
  270. tensor_proto.dims.extend(tensor_data.shape)
  271. tensor_proto.data_type = tensor_data.dtype
  272. self._put_tensor_value_into_cache(cur_step, node_info, rank_id, tensor_protos)
  273. log.info("Put tensor value into cache.")
  274. def get_root_graph_id(self):
  275. """Get root graph id."""
  276. is_sync = self._data_loader.get_sync_flag()
  277. graph_id = 0 if is_sync else 1
  278. return graph_id
  279. def _put_tensor_value_into_cache(self, cur_step, node_info, rank_id, tensor_protos):
  280. """Put tensor value into tensor cache."""
  281. tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR). \
  282. get_tensor_handler_by_rank_id(rank_id)
  283. update_data_flag = False
  284. for tensor_proto in tensor_protos:
  285. if not tensor_proto.tensor_content:
  286. log.warning("Tensor %s:%s is empty.",
  287. tensor_proto.node_name, tensor_proto.slot)
  288. try:
  289. has_update = tensor_stream.put({
  290. 'step': cur_step,
  291. 'tensor_proto': tensor_proto,
  292. 'tensor_contents': [tensor_proto.tensor_content]
  293. })
  294. except ValueError as err:
  295. log.warning("Failed to put %s:%s into cache. Ignore it. %s",
  296. tensor_proto.node_name, tensor_proto.slot, str(err))
  297. continue
  298. if has_update:
  299. update_data_flag = True
  300. if update_data_flag:
  301. # send message to frontend
  302. metadata = self._metadata_stream.get(['step', 'state'])
  303. ret = {'receive_tensor': node_info.copy()}
  304. ret.update(metadata)
  305. self._cache_store.put_data(ret)
  306. def _deal_with_run_cmd(self, event):
  307. """Deal with run cmd."""
  308. run_cmd = event.run_cmd
  309. parsed_run_cmd = self._get_parsed_run_cmd(run_cmd)
  310. if parsed_run_cmd.run_steps > 0:
  311. self._execute_one_step()
  312. elif run_cmd.run_level == RunLevel.RECHECK.value:
  313. log.info("Deal with recheck command.")
  314. self._check_watchpoint(self._metadata_stream.step)
  315. def _execute_one_step(self):
  316. """Execute on step."""
  317. new_step = self._metadata_stream.step + 1
  318. if new_step > self._metadata_stream.max_step_num:
  319. self._old_run_cmd.clear()
  320. log.info("The server is already at the last step. %s", self._metadata_stream.max_step_num)
  321. return
  322. log.info("Go to next step: %s.", new_step)
  323. self._check_watchpoint(new_step)
  324. self._metadata_stream.step = new_step
  325. self._cache_store.get_stream_handler(Streams.TENSOR).set_step(new_step)
  326. self._cache_store.put_data(self._metadata_stream.get('step'))
  327. def _get_parsed_run_cmd(self, run_cmd):
  328. """Get parsed run command."""
  329. if run_cmd.run_level == RunLevel.STEP.value:
  330. # receive pause cmd
  331. if not run_cmd.run_steps:
  332. log.debug("Pause training and wait for next command.")
  333. self._old_run_cmd.clear()
  334. # update metadata state from sending to waiting
  335. self._update_state(ServerStatus.WAITING)
  336. return run_cmd
  337. # receive step cmd
  338. left_steps = run_cmd.run_steps - 1
  339. run_cmd.run_steps = 1
  340. if left_steps:
  341. self._old_run_cmd['left_step_count'] = left_steps if left_steps > 0 else -1
  342. elif run_cmd.node_name:
  343. self._old_run_cmd['node_name'] = run_cmd.node_name
  344. run_cmd.node_name = ''
  345. return run_cmd
  346. def _check_watchpoint(self, step):
  347. """Save watchpoint hits into cache."""
  348. self._update_state(ServerStatus.RUNNING)
  349. # Clean watchpoint_hits in cache
  350. multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
  351. multi_card_hit_streams.clean()
  352. hits = Manager().list()
  353. check_watchpoints_process = Process(target=self._check_watchpoint_work, args=(hits, step,))
  354. check_watchpoints_process.start()
  355. check_watchpoints_process.join()
  356. log.info("finish check watchpoint of %s", step)
  357. if hits:
  358. log.info("Received WatchpointHits. Left run cmd %s change to empty.", self._old_run_cmd)
  359. self._old_run_cmd.clear()
  360. self._update_state(ServerStatus.WAITING)
  361. self._save_watchpoint_hits(hits)
  362. def _save_watchpoint_hits(self, hits):
  363. """Save watchpoint hits."""
  364. multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
  365. multi_card_graph_streams = self._cache_store.get_stream_handler(Streams.GRAPH)
  366. device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)
  367. watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
  368. watchpoint_hits = defaultdict(list)
  369. for hit in hits:
  370. log.info("Received hit\n: "
  371. "name:%s, slot:%s, condition:%s, "
  372. "watchpoint_id:%s"
  373. "error_code:%s, device_id:%s",
  374. hit['name'], hit['slot'], hit['condition'],
  375. hit['watchpoint_id'], hit['error_code'], hit['device_id'])
  376. rank_id = device_stream.get_rank_id_by_device_id(hit['device_id'])
  377. watchpoint_hit = {}
  378. self._add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit)
  379. if not watchpoint_hit:
  380. continue
  381. self._add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit)
  382. watchpoint_hit['error_code'] = hit['error_code']
  383. watchpoint_hits[rank_id].append(watchpoint_hit)
  384. # save hit info into cache
  385. multi_card_hit_streams.put(watchpoint_hits)
  386. self._cache_store.put_data({'receive_watchpoint_hits': True})
  387. log.debug("Send the watchpoint hits to DataQueue.")
  388. @staticmethod
  389. def _add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit):
  390. """Add hit node info."""
  391. graph_stream = multi_card_graph_streams.get_graph_handler_by_rank_id(rank_id)
  392. node_full_name = hit['name']
  393. graph_name = graph_stream.get_graph_id_by_full_name(node_full_name)
  394. if not graph_name:
  395. log.warning("Cannot find node %s in graph. Skip it.", node_full_name)
  396. return
  397. ui_node_name = graph_stream.get_node_name_by_full_name(node_full_name, graph_name)
  398. log.debug("Receive watch point hit: %s:%s", node_full_name, hit['slot'])
  399. if not ui_node_name:
  400. log.info("Not support to show %s on graph.", node_full_name)
  401. return
  402. watchpoint_hit.update({
  403. 'tensor_proto': TensorProto(node_name=node_full_name, slot=str(hit['slot'])),
  404. 'node_name': ui_node_name,
  405. 'graph_name': graph_name
  406. })
  407. @staticmethod
  408. def _add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit):
  409. """Add watchpoint hit info."""
  410. watchpoint = copy.deepcopy(watchpoint_stream.get_watchpoint_by_id(hit['watchpoint_id']))
  411. hit_params = {}
  412. # get hit actual value
  413. for param in hit['parameters']:
  414. if param['name'] not in (ParamNameEnum.RTOL.value, ParamNameEnum.RANGE_START_INCLUSIVE.value,
  415. ParamNameEnum.RANGE_END_INCLUSIVE.value) \
  416. and hit['error_code'] == 0:
  417. hit_params[param['name']] = param['actual_value']
  418. # update actual value into watchpoint
  419. watchpoint_condition_params = watchpoint.condition['params']
  420. for i, param in enumerate(watchpoint_condition_params):
  421. name = param['name']
  422. if name in hit_params.keys():
  423. watchpoint_condition_params[i]['actual_value'] = hit_params[name]
  424. else:
  425. watchpoint_condition_params[i]['actual_value'] = None
  426. watchpoint_hit['watchpoint'] = watchpoint
  427. def _deal_with_set_cmd(self, event):
  428. """
  429. Deal with set cmd.
  430. Args:
  431. event (EventReply): User command event including set_cmd.
  432. """
  433. set_cmd = event.set_cmd
  434. set_cmd_id = set_cmd.id
  435. delete = set_cmd.delete
  436. if not delete:
  437. log.info("Add watchpoint by using dbg_server.")
  438. watch_condition = set_cmd.watch_condition
  439. param_list = []
  440. for param in watch_condition.params:
  441. param_list.append(
  442. self._dbg_services_module.Parameter(param.name, param.disabled, param.value))
  443. watch_nodes = set_cmd.watch_nodes
  444. check_nodes = self._get_check_nodes(watch_nodes)
  445. log.debug("Watchpoint %s, condition: %s, watch nodes: %s",
  446. set_cmd_id, watch_condition.condition, check_nodes)
  447. self._dbg_service.add_watchpoint(set_cmd_id, watch_condition.condition, check_nodes, param_list)
  448. else:
  449. log.info("Remove watchpoint by using dbg_server.")
  450. self._dbg_service.remove_watchpoint(set_cmd_id)
  451. def _get_check_nodes(self, watch_nodes):
  452. """Get check nodes format"""
  453. check_nodes = {}
  454. device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)
  455. root_graph_id = self.get_root_graph_id()
  456. for watch_node in watch_nodes:
  457. node_name = watch_node.node_name
  458. rank_id = watch_node.rank_id
  459. device_id = device_stream.get_device_id_by_rank_id(rank_id)
  460. if node_name not in check_nodes:
  461. is_parameter = bool(watch_node.node_type == NodeTypeEnum.PARAMETER.value)
  462. check_nodes[node_name] = {
  463. "device_id": [device_id],
  464. "is_parameter": is_parameter,
  465. "root_graph_id": [root_graph_id]
  466. }
  467. else:
  468. check_nodes[node_name]["device_id"].append(device_id)
  469. return check_nodes
  470. def _update_state(self, server_status):
  471. """
  472. Update state in metadata stream.
  473. Args:
  474. server_status (ServerStatus): The enum value in ServerStatus.
  475. """
  476. if self._metadata_stream.state != server_status.value:
  477. self._metadata_stream.state = server_status.value
  478. self._cache_store.put_data(self._metadata_stream.get())
  479. def _check_watchpoint_work(self, hits, step):
  480. """The check WatchPoint function work in another process."""
  481. log.info("Start checking WatchPointHit process.")
  482. res = self._dbg_service.check_watchpoints(step)
  483. for watchpoint_hit in res:
  484. hit_dict = convert_watchpointhit(watchpoint_hit)
  485. hits.append(hit_dict)
  486. log.info("Checking WatchPointHit process is finished.")
  487. class CommandListener:
  488. """Event listener."""
  489. def __init__(self, cache_store):
  490. self._cache_store = cache_store
  491. self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA)
  492. # the next position of command queue to be queried
  493. self._pos = '0'
  494. self._is_waiting = Event()
  495. def start(self):
  496. """Start event listener."""
  497. self._pos = '0'
  498. self._is_waiting.set()
  499. def stop(self):
  500. """Stop event listener."""
  501. # stop waiting for new user commands but can still get old commands.
  502. self._is_waiting.clear()
  503. def has_new_command(self):
  504. """Check if there is new command in command queue."""
  505. return self._cache_store.has_command(self._pos)
  506. def get_next_command(self):
  507. """Get next command."""
  508. event = None
  509. while event is None and self.has_new_command():
  510. self._pos, event = self._cache_store.get_command(self._pos)
  511. log.debug("Deal with old %s-th command:\n%s.", self._pos, event)
  512. if event is None:
  513. event = self._wait_for_next_command()
  514. return event
  515. def _wait_for_next_command(self):
  516. """
  517. Wait for next command.
  518. Returns:
  519. EventReply, the command event.
  520. """
  521. if not self._is_waiting.is_set():
  522. self._metadata_stream.state = ServerStatus.PENDING.value
  523. return None
  524. log.info("Start to wait for command.")
  525. if self._metadata_stream.state != ServerStatus.WAITING.value:
  526. self._metadata_stream.state = ServerStatus.WAITING.value
  527. self._cache_store.put_data(self._metadata_stream.get())
  528. log.debug("Wait for %s-th command", self._pos)
  529. event = None
  530. while event is None and self._is_waiting.is_set():
  531. self._pos, event = self._cache_store.get_command(self._pos)
  532. return event
  533. def convert_watchpointhit(watchpointhit):
  534. """Convert watchpointhit object to dict."""
  535. parameters = watchpointhit.parameters
  536. param_list = []
  537. for param in parameters:
  538. param_dict = convert_param(param)
  539. param_list.append(param_dict)
  540. watchpointhit_dict = {'condition': watchpointhit.condition,
  541. 'device_id': watchpointhit.device_id,
  542. 'error_code': watchpointhit.error_code,
  543. 'name': watchpointhit.name,
  544. 'parameters': param_list,
  545. 'slot': watchpointhit.slot,
  546. 'watchpoint_id': watchpointhit.watchpoint_id}
  547. return watchpointhit_dict
  548. def convert_param(param):
  549. """Convert parameter object to dict"""
  550. param_dict = {'actual_value': param.actual_value,
  551. 'disabled': param.disabled,
  552. 'hit': param.hit,
  553. 'name': param.name,
  554. 'value': param.value}
  555. return param_dict