# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Debugger Offline server."""
import copy
from collections import defaultdict
from importlib import import_module
from threading import Event
from multiprocessing import Process, Manager

import mindinsight
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.debugger.common.exceptions.exceptions import DebuggerModuleNotFoundError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import Streams, ServerStatus, version_match, DebuggerServerMode, get_ack_reply, \
    RunLevel
from mindinsight.debugger.conditionmgr.condition import ParamNameEnum
from mindinsight.debugger.debugger_services.debugger_server_base import DebuggerServerBase, debugger_server_wrap
from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply
from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto
from mindinsight.debugger.stream_cache.data_loader import DataLoader
from mindinsight.utils.exceptions import MindInsightException
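

# The offline debugger is implemented in two layers: DebuggerOfflineServer drives
# the server thread life cycle, while DebuggerOfflineManager wraps MindSpore's
# offline `dbg_services` module and translates cached user commands into
# `dbg_services` calls.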
class DebuggerOfflineServer(DebuggerServerBase):
    """Debugger Offline Server."""
    _MAX_TRY_EXCEPT_COUNT = 500

    def __init__(self, cache_store, context):
        super(DebuggerOfflineServer, self).__init__(cache_store, context)
        self._offline_server_manager = DebuggerOfflineManager(cache_store, context.dbg_dir)
        self._running = Event()
        self._running.clear()

    def run(self):
        """Start the debugger offline server."""
        log.info("Initialize Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir)
        self._offline_server_manager.initialize()
        self._running.set()
        log.info("Start Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir)
        try_count = 0
        while self._running.is_set() and try_count < self._MAX_TRY_EXCEPT_COUNT:
            try:
                self._offline_server_manager.wait_for_termination()
                if not self._offline_server_manager.is_runnable():
                    break
            except MindInsightException as err:
                log.exception(err)
                log.warning("An error occurred while listening for user commands. Restart listening.")
            finally:
                try_count += 1
        # Protect the server from too many failed commands.
        if try_count == self._MAX_TRY_EXCEPT_COUNT:
            self._cache_store.clean()
            metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
            self._cache_store.put_data(metadata)
            log.warning("Exception occurred %d times, stop server.", try_count)

    def stop(self):
        """Stop the offline debugger server."""
        log.debug("Start to wait for the server thread to be started.")
        self._running.wait()
        log.info("Start to stop the offline debugger server.")
        self._running.clear()
        self._offline_server_manager.stop()
        self.join()
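

# DebuggerOfflineManager drives the ServerStatus state machine: PENDING ->
# RECEIVE_GRAPH -> WAITING/RUNNING, with MISMATCH as a terminal state entered
# when the dbg_services and MindInsight versions are incompatible.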
class DebuggerOfflineManager:
    """Debugger offline manager which is used to handle user commands."""

    def __init__(self, cache_store, dbg_dir):
        cache_store.initialize()
        self._cache_store = cache_store
        self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA)
        self._dbg_dir = dbg_dir
        self._dbg_services_module = self._get_dbg_service_module()
        self._dbg_service = None
        self._command_listener = CommandListener(cache_store)
        self._data_loader = DataLoader(dbg_dir)
        self._is_running_flag = False
        self._old_run_cmd = {}

    def stop(self):
        """Stop the manager."""
        self._is_running_flag = False
        self._command_listener.stop()
        self._cache_store.clean()
        event = get_ack_reply()
        event.exit = True
        self._cache_store.put_command(event)
        log.info("Stop debugger offline manager.")

    def is_runnable(self):
        """Check if the offline manager is runnable."""
        state = self._metadata_stream.state
        flag = self._is_running_flag and state not in [ServerStatus.MISMATCH.value, ServerStatus.PENDING.value]
        if not flag:
            log.debug("The offline manager is not runnable, is_running_flag: %s, metadata state: %s",
                      self._is_running_flag, state)
        return flag

    @staticmethod
    def _get_dbg_service_module():
        """Get the dbg_services module from MindSpore."""
        try:
            dbg_services_module = import_module('mindspore.offline_debug.dbg_services')
        except ModuleNotFoundError as err:
            log.error("Failed to find module dbg_services. %s", err)
            raise DebuggerModuleNotFoundError("dbg_services")
        return dbg_services_module

    @debugger_server_wrap
    def initialize(self):
        """Start to load offline debugger data."""
        self._data_loader.initialize()
        is_sync = self._data_loader.get_sync_flag()
        net_name = self._data_loader.get_net_name()
        net_dir = self._data_loader.get_net_dir()
        self._dbg_service = self._dbg_services_module.DbgServices(net_dir)
        self._dbg_service.initialize(net_name=net_name, is_sync_mode=is_sync)
        self._cache_store.clean()
        self._command_listener.start()
        self._is_running_flag = True
        self._check_version()
        if self._metadata_stream.state == ServerStatus.MISMATCH.value:
            log.info("The MindSpore and MindInsight versions are mismatched. Failed to initialize offline server.")
            return
        self._load_metadata()
        self._load_graphs()
        log.info("Successfully initialized offline server for %s", self._dbg_dir)

    def _check_version(self):
        """Check whether the MindSpore and MindInsight versions match."""
        ms_version = self._dbg_services_module.get_version()
        mi_version = mindinsight.__version__
        self._metadata_stream.debugger_version = {'ms': ms_version, 'mi': mi_version}
        if version_match(ms_version, mi_version) is False:
            log.info("Version mismatch, dbg_services version: %s, mindinsight version: %s", ms_version, mi_version)
            self._metadata_stream.state = ServerStatus.MISMATCH.value
            metadata = self._metadata_stream.get(['state', 'debugger_version'])
            self._cache_store.put_data(metadata)

    def _load_metadata(self):
        """Load metadata."""
        self._metadata_stream.debugger_type = DebuggerServerMode.OFFLINE.value
        device_info = self._data_loader.load_device_info()
        # The backend refers to the running environment on which the offline debugger
        # data was generated. Currently supported options: `GPU`, `Ascend`.
        backend = device_info.get('device_target', 'Ascend')
        self._metadata_stream.backend = backend
        device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)
        device_stream.put(device_info.get('server_list'))
        rank_id = 0
        rank_0_info = device_stream.get(rank_id)['devices'][0]
        self._metadata_stream.client_ip = rank_0_info.get('server_id')
        # Get the step number per device: Dict[device_id, step_num]. The value may
        # increase over time.
        step_num_per_device = self._data_loader.load_step_number()
        device_stream.add_step_num_info(step_num_per_device)
        self._metadata_stream.max_step_num = max(step_num_per_device.values())

    def _load_graphs(self):
        """Load graphs."""
        # The format of graphs is a list of {'device_id': int, 'graph_protos': [GraphProto]}.
        graphs = self._data_loader.load_graphs()
        device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)
        graph_per_rank = {}
        for graph in graphs:
            device_id = int(graph.get('device_id'))
            rank_id = device_stream.get_rank_id_by_device_id(device_id)
            graph_per_rank[rank_id] = {}
            tensor_stream_per_rank = self._cache_store.get_stream_handler(Streams.TENSOR). \
                get_tensor_handler_by_rank_id(rank_id, create_if_not_exit=True)
            for graph_proto in graph.get('graph_protos'):
                graph_per_rank[rank_id][graph_proto.name] = graph_proto
                tensor_stream_per_rank.put_const_vals(graph_proto.const_vals)
        # graph_per_rank has the format Dict[rank_id, Dict[graph_name, GraphProto]].
        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_per_rank)
        device_stream.add_graph_name_info(graph_per_rank)
        self._metadata_stream.state = ServerStatus.RECEIVE_GRAPH.value
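
    # Main command loop: a pending multi-step run command parked in `_old_run_cmd`
    # is drained before new commands are fetched from the cache store.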
    @debugger_server_wrap
    def wait_for_termination(self):
        """Begin to listen for command events."""
        log.info("Begin to listen for user commands.")
        self._send_graph()
        while self.is_runnable():
            if not self._command_listener.has_new_command() and self._old_run_cmd:
                self._deal_with_old_run_cmd()
                continue
            cmd = self._command_listener.get_next_command()
            self.deal_with_cmd(cmd)

    def _send_graph(self):
        """Put graph and metadata info into the data queue."""
        if not self.is_runnable():
            return
        self._metadata_stream.state = ServerStatus.WAITING.value
        metadata = self._metadata_stream.get()
        res = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).get()
        res.update(metadata)
        self._cache_store.put_data(res)

    def _deal_with_old_run_cmd(self):
        """Deal with the old run command."""
        left_step_count = self._old_run_cmd.get('left_step_count')
        if left_step_count:
            self._execute_one_step()
            # The old run command may have been cleared already due to a watchpoint hit.
            if self._old_run_cmd:
                self._old_run_cmd['left_step_count'] = left_step_count - 1 if left_step_count > 0 else -1
        if not self._old_run_cmd.get('left_step_count'):
            self._old_run_cmd.clear()

    def deal_with_cmd(self, cmd):
        """Deal with a command."""
        if cmd is None:
            return
        if isinstance(cmd, dict):
            self._deal_with_view_cmd(cmd)
        elif isinstance(cmd, EventReply):
            self._on_event(cmd)

    def _on_event(self, event):
        """
        Deal with different command events.

        Args:
            event (EventReply): Command event.
        """
        if event.HasField('run_cmd'):
            self._deal_with_run_cmd(event)
        elif event.HasField('exit'):
            self._cache_store.clean()
            self._update_state(ServerStatus.PENDING)
            log.debug("Clean cache for exit cmd.")
        else:
            self._deal_with_set_cmd(event)
            log.debug("Deal with set cmd.")

    def _deal_with_view_cmd(self, event):
        """
        Deal with a view command.

        Args:
            event (dict): View command params.

                - view_cmd (EventReply): EventReply with view command.
                - node_name (str): The center node name for view command.
                - tensor_name (str): The center tensor name for view command.
                - graph_name (str): The graph name of the center node.
                - rank_id (int): The rank id of the tensor.
        """
        # Guard against a missing 'view_cmd' key before dereferencing it.
        view_cmd = event.pop('view_cmd', None)
        if view_cmd is not None:
            view_cmd = view_cmd.view_cmd
        node_info = event
        log.debug("Receive view cmd for node: %s.", event)
        if not (view_cmd and node_info):
            log.info("Invalid view command. Ignore it.")
            return
        # Read tensor values through dbg_services.
        rank_id = node_info.get('rank_id', 0)
        device_id = self._cache_store.get_stream_handler(Streams.DEVICE).get_device_id_by_rank_id(rank_id)
        cur_step = self._metadata_stream.step
        tensor_protos = view_cmd.tensors
        root_graph_id = self.get_root_graph_id()
        tensor_infos = [
            self._dbg_services_module.TensorInfo(
                node_name=tensor_proto.node_name,
                slot=int(tensor_proto.slot),
                iteration=cur_step - 1 if tensor_proto.iter == 'prev' else cur_step,
                device_id=device_id,
                is_parameter=tensor_proto.truncate,
                root_graph_id=root_graph_id
            ) for tensor_proto in tensor_protos]
        res = self._dbg_service.read_tensors(tensor_infos)
        # Put the tensors into the cache.
        for tensor_proto, tensor_data in zip(tensor_protos, res):
            log.debug("Tensor name: %s:%s, tensor type: %s, tensor size: %s", tensor_proto.node_name,
                      tensor_proto.slot, tensor_data.dtype, tensor_data.data_size)
            tensor_proto.tensor_content = tensor_data.data_ptr
            tensor_proto.ClearField('dims')
            tensor_proto.dims.extend(tensor_data.shape)
            tensor_proto.data_type = tensor_data.dtype
        self._put_tensor_value_into_cache(cur_step, node_info, rank_id, tensor_protos)
        log.info("Put tensor value into cache.")

    def get_root_graph_id(self):
        """Get the root graph id."""
        is_sync = self._data_loader.get_sync_flag()
        graph_id = 0 if is_sync else 1
        return graph_id

    def _put_tensor_value_into_cache(self, cur_step, node_info, rank_id, tensor_protos):
        """Put tensor values into the tensor cache."""
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR). \
            get_tensor_handler_by_rank_id(rank_id)
        update_data_flag = False
        for tensor_proto in tensor_protos:
            if not tensor_proto.tensor_content:
                log.warning("Tensor %s:%s is empty.", tensor_proto.node_name, tensor_proto.slot)
            try:
                has_update = tensor_stream.put({
                    'step': cur_step,
                    'tensor_proto': tensor_proto,
                    'tensor_contents': [tensor_proto.tensor_content]
                })
            except ValueError as err:
                log.warning("Failed to put %s:%s into cache. Ignore it. %s",
                            tensor_proto.node_name, tensor_proto.slot, str(err))
                continue
            if has_update:
                update_data_flag = True
        if update_data_flag:
            # Send a message to the frontend.
            metadata = self._metadata_stream.get(['step', 'state'])
            ret = {'receive_tensor': node_info.copy()}
            ret.update(metadata)
            self._cache_store.put_data(ret)
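
    # Run commands are executed one step at a time. For a multi-step run command,
    # the remaining count is parked in `_old_run_cmd` as 'left_step_count'; a value
    # of -1 keeps stepping until the last step is reached.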
    def _deal_with_run_cmd(self, event):
        """Deal with a run command."""
        run_cmd = event.run_cmd
        parsed_run_cmd = self._get_parsed_run_cmd(run_cmd)
        if parsed_run_cmd.run_steps > 0:
            self._execute_one_step()
        elif run_cmd.run_level == RunLevel.RECHECK.value:
            log.info("Deal with recheck command.")
            self._check_watchpoint(self._metadata_stream.step)

    def _execute_one_step(self):
        """Execute one step."""
        new_step = self._metadata_stream.step + 1
        if new_step > self._metadata_stream.max_step_num:
            self._old_run_cmd.clear()
            log.info("The server is already at the last step: %s.", self._metadata_stream.max_step_num)
            return
        log.info("Go to next step: %s.", new_step)
        self._check_watchpoint(new_step)
        self._metadata_stream.step = new_step
        self._cache_store.get_stream_handler(Streams.TENSOR).set_step(new_step)
        self._cache_store.put_data(self._metadata_stream.get('step'))

    def _get_parsed_run_cmd(self, run_cmd):
        """Get the parsed run command."""
        if run_cmd.run_level == RunLevel.STEP.value:
            # Received a pause command.
            if not run_cmd.run_steps:
                log.debug("Pause training and wait for the next command.")
                self._old_run_cmd.clear()
                # Update the metadata state from sending to waiting.
                self._update_state(ServerStatus.WAITING)
                return run_cmd
            # Received a step command.
            left_steps = run_cmd.run_steps - 1
            run_cmd.run_steps = 1
            if left_steps:
                self._old_run_cmd['left_step_count'] = left_steps if left_steps > 0 else -1
        elif run_cmd.node_name:
            self._old_run_cmd['node_name'] = run_cmd.node_name
            run_cmd.node_name = ''
        return run_cmd
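
    # Watchpoint checking is delegated to a child process; hits are collected
    # through a multiprocessing.Manager list, which keeps the native
    # dbg_services call isolated from the server thread.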
    def _check_watchpoint(self, step):
        """Check watchpoints and save watchpoint hits into the cache."""
        self._update_state(ServerStatus.RUNNING)
        # Clean watchpoint_hits in the cache.
        multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        multi_card_hit_streams.clean()
        hits = Manager().list()
        check_watchpoints_process = Process(target=self._check_watchpoint_work, args=(hits, step,))
        check_watchpoints_process.start()
        check_watchpoints_process.join()
        log.info("Finish checking watchpoints of step %s.", step)
        if hits:
            log.info("Received watchpoint hits. Clear the left run cmd %s.", self._old_run_cmd)
            self._old_run_cmd.clear()
            self._update_state(ServerStatus.WAITING)
            self._save_watchpoint_hits(hits)

    def _save_watchpoint_hits(self, hits):
        """Save watchpoint hits."""
        multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        multi_card_graph_streams = self._cache_store.get_stream_handler(Streams.GRAPH)
        device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        watchpoint_hits = defaultdict(list)
        for hit in hits:
            log.info("Received hit:\n"
                     "name: %s, slot: %s, condition: %s, "
                     "watchpoint_id: %s, "
                     "error_code: %s, device_id: %s",
                     hit['name'], hit['slot'], hit['condition'],
                     hit['watchpoint_id'], hit['error_code'], hit['device_id'])
            rank_id = device_stream.get_rank_id_by_device_id(hit['device_id'])
            watchpoint_hit = {}
            self._add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit)
            if not watchpoint_hit:
                continue
            self._add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit)
            watchpoint_hit['error_code'] = hit['error_code']
            watchpoint_hits[rank_id].append(watchpoint_hit)
        # Save the hit info into the cache.
        multi_card_hit_streams.put(watchpoint_hits)
        self._cache_store.put_data({'receive_watchpoint_hits': True})
        log.debug("Send the watchpoint hits to DataQueue.")

    @staticmethod
    def _add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit):
        """Add hit node info."""
        graph_stream = multi_card_graph_streams.get_graph_handler_by_rank_id(rank_id)
        node_full_name = hit['name']
        graph_name = graph_stream.get_graph_id_by_full_name(node_full_name)
        if not graph_name:
            log.warning("Cannot find node %s in graph. Skip it.", node_full_name)
            return
        ui_node_name = graph_stream.get_node_name_by_full_name(node_full_name, graph_name)
        log.debug("Receive watchpoint hit: %s:%s", node_full_name, hit['slot'])
        if not ui_node_name:
            log.info("Cannot show %s on the graph.", node_full_name)
            return
        watchpoint_hit.update({
            'tensor_proto': TensorProto(node_name=node_full_name, slot=str(hit['slot'])),
            'node_name': ui_node_name,
            'graph_name': graph_name
        })

    @staticmethod
    def _add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit):
        """Add watchpoint hit info."""
        watchpoint = copy.deepcopy(watchpoint_stream.get_watchpoint_by_id(hit['watchpoint_id']))
        hit_params = {}
        # Get the actual values of the hit.
        for param in hit['parameters']:
            if param['name'] not in (ParamNameEnum.RTOL.value,
                                     ParamNameEnum.RANGE_START_INCLUSIVE.value,
                                     ParamNameEnum.RANGE_END_INCLUSIVE.value) \
                    and hit['error_code'] == 0:
                hit_params[param['name']] = param['actual_value']
        # Update the actual values in the watchpoint.
        watchpoint_condition_params = watchpoint.condition['params']
        for i, param in enumerate(watchpoint_condition_params):
            name = param['name']
            if name in hit_params.keys():
                watchpoint_condition_params[i]['actual_value'] = hit_params[name]
            else:
                watchpoint_condition_params[i]['actual_value'] = None
        watchpoint_hit['watchpoint'] = watchpoint

    def _deal_with_set_cmd(self, event):
        """
        Deal with a set command.

        Args:
            event (EventReply): User command event including set_cmd.
        """
        set_cmd = event.set_cmd
        set_cmd_id = set_cmd.id
        delete = set_cmd.delete
        if not delete:
            log.info("Add watchpoint by using dbg_server.")
            watch_condition = set_cmd.watch_condition
            param_list = []
            for param in watch_condition.params:
                param_list.append(
                    self._dbg_services_module.Parameter(param.name, param.disabled, param.value))
            watch_nodes = set_cmd.watch_nodes
            check_nodes = self._get_check_nodes(watch_nodes)
            log.debug("Watchpoint %s, condition: %s, watch nodes: %s",
                      set_cmd_id, watch_condition.condition, check_nodes)
            self._dbg_service.add_watchpoint(set_cmd_id, watch_condition.condition, check_nodes, param_list)
        else:
            log.info("Remove watchpoint by using dbg_server.")
            self._dbg_service.remove_watchpoint(set_cmd_id)

    def _get_check_nodes(self, watch_nodes):
        """Get the watch nodes in the check_nodes format required by dbg_services."""
        check_nodes = {}
        device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)
        root_graph_id = self.get_root_graph_id()
        for watch_node in watch_nodes:
            node_name = watch_node.node_name
            rank_id = watch_node.rank_id
            device_id = device_stream.get_device_id_by_rank_id(rank_id)
            if node_name not in check_nodes:
                is_parameter = bool(watch_node.node_type == NodeTypeEnum.PARAMETER.value)
                check_nodes[node_name] = {
                    "device_id": [device_id],
                    "is_parameter": is_parameter,
                    "root_graph_id": [root_graph_id]
                }
            else:
                check_nodes[node_name]["device_id"].append(device_id)
        return check_nodes

    def _update_state(self, server_status):
        """
        Update the state in the metadata stream.

        Args:
            server_status (ServerStatus): The enum value in ServerStatus.
        """
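

# CommandListener polls the cache store's command queue with a cursor (`_pos`).
# stop() only clears the waiting flag, so commands already in the queue can
# still be fetched afterwards.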
""" if self._metadata_stream.state != server_status.value: self._metadata_stream.state = server_status.value self._cache_store.put_data(self._metadata_stream.get()) def _check_watchpoint_work(self, hits, step): """The check WatchPoint function work in another process.""" log.info("Start checking WatchPointHit process.") res = self._dbg_service.check_watchpoints(step) for watchpoint_hit in res: hit_dict = convert_watchpointhit(watchpoint_hit) hits.append(hit_dict) log.info("Checking WatchPointHit process is finished.") class CommandListener: """Event listener.""" def __init__(self, cache_store): self._cache_store = cache_store self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) # the next position of command queue to be queried self._pos = '0' self._is_waiting = Event() def start(self): """Start event listener.""" self._pos = '0' self._is_waiting.set() def stop(self): """Stop event listener.""" # stop waiting for new user commands but can still get old commands. self._is_waiting.clear() def has_new_command(self): """Check if there is new command in command queue.""" return self._cache_store.has_command(self._pos) def get_next_command(self): """Get next command.""" event = None while event is None and self.has_new_command(): self._pos, event = self._cache_store.get_command(self._pos) log.debug("Deal with old %s-th command:\n%s.", self._pos, event) if event is None: event = self._wait_for_next_command() return event def _wait_for_next_command(self): """ Wait for next command. Returns: EventReply, the command event. """ if not self._is_waiting.is_set(): self._metadata_stream.state = ServerStatus.PENDING.value return None log.info("Start to wait for command.") if self._metadata_stream.state != ServerStatus.WAITING.value: self._metadata_stream.state = ServerStatus.WAITING.value self._cache_store.put_data(self._metadata_stream.get()) log.debug("Wait for %s-th command", self._pos) event = None while event is None and self._is_waiting.is_set(): self._pos, event = self._cache_store.get_command(self._pos) return event def convert_watchpointhit(watchpointhit): """Convert watchpointhit object to dict.""" parameters = watchpointhit.parameters param_list = [] for param in parameters: param_dict = convert_param(param) param_list.append(param_dict) watchpointhit_dict = {'condition': watchpointhit.condition, 'device_id': watchpointhit.device_id, 'error_code': watchpointhit.error_code, 'name': watchpointhit.name, 'parameters': param_list, 'slot': watchpointhit.slot, 'watchpoint_id': watchpointhit.watchpoint_id} return watchpointhit_dict def convert_param(param): """Convert parameter object to dict""" param_dict = {'actual_value': param.actual_value, 'disabled': param.disabled, 'hit': param.hit, 'name': param.name, 'value': param.value} return param_dict