From 8bb3851e491cafbf4f0ab2f4fe3b27ce9b83cb1b Mon Sep 17 00:00:00 2001 From: jiangshuqiang Date: Mon, 12 Apr 2021 16:09:35 +0800 Subject: [PATCH] add offline debugger feature --- .../backend/conditionmgr/conditionmgr_api.py | 61 -- mindinsight/backend/config/gunicorn_conf.py | 2 + mindinsight/backend/data_manager/__init__.py | 9 +- mindinsight/backend/debugger/debugger_api.py | 232 ++++--- .../datavisual/data_transform/data_manager.py | 21 +- .../data_transform/graph/msgraph.py | 6 +- .../data_transform/summary_watcher.py | 54 +- .../processors/train_task_manager.py | 5 +- .../debugger/common/exceptions/error_code.py | 18 +- .../debugger/common/exceptions/exceptions.py | 57 +- mindinsight/debugger/common/utils.py | 31 +- .../debugger/conditionmgr/recommender.py | 41 +- mindinsight/debugger/debugger_cache.py | 20 +- .../debugger/debugger_folder_analyzer.py | 41 ++ .../debugger/debugger_services/__init__.py | 15 + .../debugger_grpc_server.py | 44 +- .../debugger_offline_server.py | 613 ++++++++++++++++++ .../debugger_online_server.py | 58 ++ .../debugger_services/debugger_server_base.py | 58 ++ .../debugger_server_factory.py | 92 +++ ...debugger_server.py => debugger_session.py} | 188 +++--- mindinsight/debugger/proto/debug_grpc.proto | 3 + mindinsight/debugger/proto/debug_grpc_pb2.py | 55 +- .../debugger/proto/debug_grpc_pb2_grpc.py | 37 +- mindinsight/debugger/proto/ms_graph.proto | 3 + mindinsight/debugger/proto/ms_graph_pb2.py | 83 +-- mindinsight/debugger/session_manager.py | 172 +++++ .../debugger/stream_cache/data_loader.py | 210 ++++++ mindinsight/debugger/stream_cache/tensor.py | 29 +- .../debugger/stream_cache/watchpoint.py | 52 +- .../debugger/stream_handler/__init__.py | 13 +- .../debugger/stream_handler/device_handler.py | 198 ++++++ .../debugger/stream_handler/graph_handler.py | 96 +-- .../stream_handler/metadata_handler.py | 50 +- .../debugger/stream_handler/tensor_handler.py | 93 ++- .../stream_handler/watchpoint_handler.py | 106 ++- .../stream_operator/tensor_detail_info.py | 47 +- .../training_control_operator.py | 50 +- .../stream_operator/watchpoint_operator.py | 38 +- mindinsight/ui/src/app.vue | 2 +- .../ui/src/components/debugger-tensor.vue | 12 +- mindinsight/ui/src/locales/en-us.json | 22 +- mindinsight/ui/src/locales/zh-cn.json | 22 +- mindinsight/ui/src/mixins/debugger-mixin.vue | 599 ++++++++++++----- mindinsight/ui/src/router.js | 4 + mindinsight/ui/src/services/fetcher.js | 9 +- .../ui/src/services/request-service.js | 93 +-- .../ui/src/views/debugger/debugger.vue | 153 ++++- .../src/views/train-manage/summary-manage.vue | 170 ++++- .../__init__.py => utils/folder_analyzer.py} | 17 +- requirements.txt | 4 +- tests/st/func/debugger/conftest.py | 13 +- .../debugger/debugger_services/__init__.py | 20 + .../debugger_services/mock_dbg_services.py | 141 ++++ .../test_debugger_services.py | 77 +++ .../Ascend/async/.metadata/data_dump.json | 15 + .../Ascend/async/.metadata/hccl.json | 23 + .../Ascend/sync/.metadata/data_dump.json | 15 + .../Ascend/sync/.metadata/hccl.json | 23 + .../GPU/sync/.metadata/data_dump.json | 15 + .../load_device_info_ascend.json | 21 + .../restful_results/multi_next_node.json | 2 +- .../restful_results/multi_retrieve_all.json | 2 +- .../restful_results/retrieve_all.json | 2 +- .../retrieve_next_node_on_gpu.json | 2 +- .../retrieve_tensor_value.json | 30 +- .../restful_results/version_mismatch.json | 2 +- tests/st/func/debugger/test_data_loader.py | 149 +++++ tests/st/func/debugger/test_restful_api.py | 30 +- tests/st/func/debugger/utils.py | 67 +- .../debugger_server/retrieve_all.json | 2 +- .../watchpoint/watchpoint_handler_get_0.json | 81 ++- .../stream_handler/test_graph_handler.py | 14 - .../stream_handler/test_tensor_handler.py | 4 +- .../stream_handler/test_watchpoint_handler.py | 7 +- .../test_training_control_operator.py | 12 +- .../ut/debugger/test_debugger_grpc_server.py | 4 +- tests/ut/debugger/test_debugger_server.py | 13 +- tests/utils/tools.py | 7 + 79 files changed, 3951 insertions(+), 950 deletions(-) delete mode 100644 mindinsight/backend/conditionmgr/conditionmgr_api.py create mode 100644 mindinsight/debugger/debugger_folder_analyzer.py create mode 100644 mindinsight/debugger/debugger_services/__init__.py rename mindinsight/debugger/{ => debugger_services}/debugger_grpc_server.py (95%) create mode 100644 mindinsight/debugger/debugger_services/debugger_offline_server.py create mode 100644 mindinsight/debugger/debugger_services/debugger_online_server.py create mode 100644 mindinsight/debugger/debugger_services/debugger_server_base.py create mode 100644 mindinsight/debugger/debugger_services/debugger_server_factory.py rename mindinsight/debugger/{debugger_server.py => debugger_session.py} (86%) create mode 100644 mindinsight/debugger/session_manager.py create mode 100644 mindinsight/debugger/stream_cache/data_loader.py create mode 100644 mindinsight/debugger/stream_handler/device_handler.py rename mindinsight/{backend/conditionmgr/__init__.py => utils/folder_analyzer.py} (67%) create mode 100644 tests/st/func/debugger/debugger_services/__init__.py create mode 100644 tests/st/func/debugger/debugger_services/mock_dbg_services.py create mode 100644 tests/st/func/debugger/debugger_services/test_debugger_services.py create mode 100644 tests/st/func/debugger/dump_files/Ascend/async/.metadata/data_dump.json create mode 100644 tests/st/func/debugger/dump_files/Ascend/async/.metadata/hccl.json create mode 100644 tests/st/func/debugger/dump_files/Ascend/sync/.metadata/data_dump.json create mode 100644 tests/st/func/debugger/dump_files/Ascend/sync/.metadata/hccl.json create mode 100644 tests/st/func/debugger/dump_files/GPU/sync/.metadata/data_dump.json create mode 100644 tests/st/func/debugger/expect_results/offline_debugger/load_device_info_ascend.json create mode 100644 tests/st/func/debugger/test_data_loader.py diff --git a/mindinsight/backend/conditionmgr/conditionmgr_api.py b/mindinsight/backend/conditionmgr/conditionmgr_api.py deleted file mode 100644 index 936cd78d..00000000 --- a/mindinsight/backend/conditionmgr/conditionmgr_api.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Conditionmgr restful api.""" -import json - -from flask import Blueprint, request - -from mindinsight.conf import settings -from mindinsight.utils.exceptions import ParamValueError -from mindinsight.utils.exceptions import ParamMissError -from mindinsight.backend.debugger.debugger_api import BACKEND_SERVER, _wrap_reply - -BLUEPRINT = Blueprint("conditionmgr", __name__, - url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX) - - -@BLUEPRINT.route("/conditionmgr/train-jobs//condition-collections", methods=["GET"]) -def get_condition_collections(train_id): - """get condition collections""" - reply = _wrap_reply(BACKEND_SERVER.get_condition_collections, train_id) - return reply - - -@BLUEPRINT.route("/conditionmgr/train-jobs//set-recommended-watch-points", methods=["POST"]) -def set_recommended_watch_points(train_id): - """set recommended watch points.""" - body = request.stream.read() - try: - body = json.loads(body if body else "{}") - except json.JSONDecodeError: - raise ParamValueError("Json data parse failed.") - - request_body = body.get('requestBody') - if request_body is None: - raise ParamMissError('requestBody') - - set_recommended = request_body.get('set_recommended') - reply = _wrap_reply(BACKEND_SERVER.set_recommended_watch_points, set_recommended, train_id) - return reply - - -def init_module(app): - """ - Init module entry. - - Args: - app (Flask): The application obj. - """ - app.register_blueprint(BLUEPRINT) diff --git a/mindinsight/backend/config/gunicorn_conf.py b/mindinsight/backend/config/gunicorn_conf.py index c4ae6ce3..c26b65de 100644 --- a/mindinsight/backend/config/gunicorn_conf.py +++ b/mindinsight/backend/config/gunicorn_conf.py @@ -26,6 +26,7 @@ import psutil import gunicorn from mindinsight.utils.computing_resource_mgr import terminate +from mindinsight.debugger.session_manager import SessionManager gunicorn.SERVER_SOFTWARE = 'unknown' @@ -110,4 +111,5 @@ def worker_int(worker): global LISTEN_PROCESS if LISTEN_PROCESS is not None: LISTEN_PROCESS.terminate() + SessionManager.get_instance().exit() worker.log.info("Worker int processed.") diff --git a/mindinsight/backend/data_manager/__init__.py b/mindinsight/backend/data_manager/__init__.py index 9e8f797d..a91574e2 100644 --- a/mindinsight/backend/data_manager/__init__.py +++ b/mindinsight/backend/data_manager/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,11 @@ from mindinsight.conf import settings from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater +from mindinsight.debugger.debugger_folder_analyzer import DebuggerFolderAnalyzer + +ANALYZERS = { + "debugger_folder_analyzer": DebuggerFolderAnalyzer() +} def init_module(app): @@ -31,6 +36,8 @@ def init_module(app): """ # Just to suppress pylint warning about unused arg. logger.debug("App: %s", type(app)) + for analyzer in ANALYZERS.values(): + DATA_MANAGER.register_folder_analyzer(analyzer) DATA_MANAGER.register_brief_cache_item_updater(LineageCacheItemUpdater()) # Let gunicorn load other modules first. time.sleep(1) diff --git a/mindinsight/backend/debugger/debugger_api.py b/mindinsight/backend/debugger/debugger_api.py index 0e8aea63..67da2912 100644 --- a/mindinsight/backend/debugger/debugger_api.py +++ b/mindinsight/backend/debugger/debugger_api.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,22 +19,13 @@ from urllib.parse import unquote from flask import Blueprint, jsonify, request from mindinsight.conf import settings -from mindinsight.debugger.debugger_server import DebuggerServer -from mindinsight.utils.exceptions import ParamValueError +from mindinsight.debugger.session_manager import SessionManager +from mindinsight.utils.exceptions import ParamMissError, ParamValueError BLUEPRINT = Blueprint("debugger", __name__, url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX) -def _initialize_debugger_server(): - """Initialize a debugger server instance.""" - enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False - server = None - if enable_debugger: - server = DebuggerServer() - return server - - def _unquote_param(param): """ Decode parameter value. @@ -77,8 +68,8 @@ def _wrap_reply(func, *args, **kwargs): return jsonify(reply) -@BLUEPRINT.route("/debugger/poll-data", methods=["GET"]) -def poll_data(): +@BLUEPRINT.route("/debugger/sessions//poll-data", methods=["GET"]) +def poll_data(session_id): """ Wait for data to be updated on UI. @@ -88,17 +79,17 @@ def poll_data(): str, the updated data. Examples: - >>> Get http://xxxx/v1/mindinsight/debugger/poll-data?pos=xx + >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/poll-data?pos=xx """ pos = request.args.get('pos') - reply = _wrap_reply(BACKEND_SERVER.poll_data, pos) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).poll_data, pos) return reply -@BLUEPRINT.route("/debugger/search", methods=["GET"]) -def search(): +@BLUEPRINT.route("/debugger/sessions//search", methods=["GET"]) +def search(session_id): """ Search nodes in specified watchpoint. @@ -106,42 +97,25 @@ def search(): str, the required data. Examples: - >>> Get http://xxxx/v1/mindinsight/debugger/search?name=mock_name&watch_point_id=1 + >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/search?name=mock_name&watch_point_id=1 """ name = request.args.get('name') graph_name = request.args.get('graph_name') watch_point_id = int(request.args.get('watch_point_id', 0)) node_category = request.args.get('node_category') - reply = _wrap_reply(BACKEND_SERVER.search, {'name': name, - 'graph_name': graph_name, - 'watch_point_id': watch_point_id, - 'node_category': node_category}) + rank_id = int(request.args.get('rank_id', 0)) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).search, + {'name': name, + 'graph_name': graph_name, + 'watch_point_id': watch_point_id, + 'node_category': node_category, + 'rand_id': rank_id}) return reply -@BLUEPRINT.route("/debugger/retrieve_node_by_bfs", methods=["GET"]) -def retrieve_node_by_bfs(): - """ - Search node by bfs. - - Returns: - str, the required data. - - Examples: - >>> Get http://xxxx/v1/mindinsight/debugger/retrieve_node_by_bfs?name=node_name&ascend=true - """ - name = request.args.get('name') - graph_name = request.args.get('graph_name') - ascend = request.args.get('ascend', 'false') - ascend = ascend == 'true' - reply = _wrap_reply(BACKEND_SERVER.retrieve_node_by_bfs, name, graph_name, ascend) - - return reply - - -@BLUEPRINT.route("/debugger/tensor-comparisons", methods=["GET"]) -def tensor_comparisons(): +@BLUEPRINT.route("/debugger/sessions//tensor-comparisons", methods=["GET"]) +def tensor_comparisons(session_id): """ Get tensor comparisons. @@ -149,19 +123,21 @@ def tensor_comparisons(): str, the required data. Examples: - >>> Get http://xxxx/v1/mindinsight/debugger/tensor-comparisons + >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-comparisons """ name = request.args.get('name') detail = request.args.get('detail', 'data') shape = _unquote_param(request.args.get('shape')) tolerance = request.args.get('tolerance', '0') - reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance) + rank_id = int(request.args.get('rank_id', 0)) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).tensor_comparisons, name, shape, detail, + tolerance, rank_id) return reply -@BLUEPRINT.route("/debugger/retrieve", methods=["POST"]) -def retrieve(): +@BLUEPRINT.route("/debugger/sessions//retrieve", methods=["POST"]) +def retrieve(session_id): """ Retrieve data according to mode and params. @@ -169,17 +145,17 @@ def retrieve(): str, the required data. Examples: - >>> POST http://xxxx/v1/mindinsight/debugger/retrieve + >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/retrieve """ body = _read_post_request(request) mode = body.get('mode') params = body.get('params') - reply = _wrap_reply(BACKEND_SERVER.retrieve, mode, params) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve, mode, params) return reply -@BLUEPRINT.route("/debugger/tensor-history", methods=["POST"]) -def retrieve_tensor_history(): +@BLUEPRINT.route("/debugger/sessions//tensor-history", methods=["POST"]) +def retrieve_tensor_history(session_id): """ Retrieve data according to mode and params. @@ -187,17 +163,19 @@ def retrieve_tensor_history(): str, the required data. Examples: - >>> POST http://xxxx/v1/mindinsight/debugger/tensor-history + >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-history """ body = _read_post_request(request) name = body.get('name') graph_name = body.get('graph_name') - reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_history, name, graph_name) + rank_id = body.get('rank_id') + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_history, name, graph_name, + rank_id) return reply -@BLUEPRINT.route("/debugger/tensors", methods=["GET"]) -def retrieve_tensor_value(): +@BLUEPRINT.route("/debugger/sessions//tensors", methods=["GET"]) +def retrieve_tensor_value(session_id): """ Retrieve tensor value according to name and shape. @@ -205,20 +183,22 @@ def retrieve_tensor_value(): str, the required data. Examples: - >>> GET http://xxxx/v1/mindinsight/debugger/tensors?name=tensor_name&detail=data&shape=[1,1,:,:] + >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensors?name=tensor_name&detail=data&shape=[1,1,:,:] """ name = request.args.get('name') detail = request.args.get('detail') shape = _unquote_param(request.args.get('shape')) graph_name = request.args.get('graph_name') prev = bool(request.args.get('prev') == 'true') + rank_id = int(request.args.get('rank_id', 0)) - reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_value, name, detail, shape, graph_name, prev) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_value, name, detail, + shape, graph_name, prev, rank_id) return reply -@BLUEPRINT.route("/debugger/create-watchpoint", methods=["POST"]) -def create_watchpoint(): +@BLUEPRINT.route("/debugger/sessions//create-watchpoint", methods=["POST"]) +def create_watchpoint(session_id): """ Create watchpoint. @@ -229,16 +209,16 @@ def create_watchpoint(): MindInsightException: If method fails to be called. Examples: - >>> POST http://xxxx/v1/mindinsight/debugger/create-watchpoint + >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/create-watchpoint """ params = _read_post_request(request) params['watch_condition'] = params.pop('condition', None) - reply = _wrap_reply(BACKEND_SERVER.create_watchpoint, params) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).create_watchpoint, params) return reply -@BLUEPRINT.route("/debugger/update-watchpoint", methods=["POST"]) -def update_watchpoint(): +@BLUEPRINT.route("/debugger/sessions//update-watchpoint", methods=["POST"]) +def update_watchpoint(session_id): """ Update watchpoint. @@ -249,17 +229,17 @@ def update_watchpoint(): MindInsightException: If method fails to be called. Examples: - >>> POST http://xxxx/v1/mindinsight/debugger/update-watchpoint + >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/update-watchpoint """ params = _read_post_request(request) - reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, params) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).update_watchpoint, params) return reply -@BLUEPRINT.route("/debugger/delete-watchpoint", methods=["POST"]) -def delete_watchpoint(): +@BLUEPRINT.route("/debugger/sessions//delete-watchpoint", methods=["POST"]) +def delete_watchpoint(session_id): """ - delete watchpoint. + Delete watchpoint. Returns: str, reply message. @@ -268,19 +248,19 @@ def delete_watchpoint(): MindInsightException: If method fails to be called. Examples: - >>> POST http://xxxx/v1/mindinsight/debugger/delete-watchpoint + >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/delete-watchpoint """ body = _read_post_request(request) watch_point_id = body.get('watch_point_id') - reply = _wrap_reply(BACKEND_SERVER.delete_watchpoint, watch_point_id) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).delete_watchpoint, watch_point_id) return reply -@BLUEPRINT.route("/debugger/control", methods=["POST"]) -def control(): +@BLUEPRINT.route("/debugger/sessions//control", methods=["POST"]) +def control(session_id): """ Control request. @@ -291,16 +271,16 @@ def control(): MindInsightException: If method fails to be called. Examples: - >>> POST http://xxxx/v1/mindinsight/debugger/control + >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/control """ params = _read_post_request(request) - reply = _wrap_reply(BACKEND_SERVER.control, params) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).control, params) return reply -@BLUEPRINT.route("/debugger/recheck", methods=["POST"]) -def recheck(): +@BLUEPRINT.route("/debugger/sessions//recheck", methods=["POST"]) +def recheck(session_id): """ Recheck request. @@ -311,15 +291,15 @@ def recheck(): MindInsightException: If method fails to be called. Examples: - >>> POST http://xxxx/v1/mindinsight/debugger/recheck + >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/recheck """ - reply = _wrap_reply(BACKEND_SERVER.recheck) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).recheck) return reply -@BLUEPRINT.route("/debugger/tensor-graphs", methods=["GET"]) -def retrieve_tensor_graph(): +@BLUEPRINT.route("/debugger/sessions//tensor-graphs", methods=["GET"]) +def retrieve_tensor_graph(session_id): """ Retrieve tensor value according to name and shape. @@ -327,16 +307,18 @@ def retrieve_tensor_graph(): str, the required data. Examples: - >>> GET http://xxxx/v1/mindinsight/debugger/tensor-graphs?tensor_name=tensor_name&graph_name=graph_name + >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-graphs?tensor_name=xxx&graph_name=xxx """ tensor_name = request.args.get('tensor_name') graph_name = request.args.get('graph_name') - reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_graph, tensor_name, graph_name) + rank_id = int(request.args.get('rank_id', 0)) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_graph, tensor_name, + graph_name, rank_id) return reply -@BLUEPRINT.route("/debugger/tensor-hits", methods=["GET"]) -def retrieve_tensor_hits(): +@BLUEPRINT.route("/debugger/sessions//tensor-hits", methods=["GET"]) +def retrieve_tensor_hits(session_id): """ Retrieve tensor value according to name and shape. @@ -344,16 +326,18 @@ def retrieve_tensor_hits(): str, the required data. Examples: - >>> GET http://xxxx/v1/mindinsight/debugger/tensor-hits?tensor_name=tensor_name&graph_name=graph_name + >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-hits?tensor_name=xxx&graph_name=xxx """ tensor_name = request.args.get('tensor_name') graph_name = request.args.get('graph_name') - reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_hits, tensor_name, graph_name) + rank_id = int(request.args.get('rank_id', 0)) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_hits, tensor_name, + graph_name, rank_id) return reply -@BLUEPRINT.route("/debugger/search-watchpoint-hits", methods=["POST"]) -def search_watchpoint_hits(): +@BLUEPRINT.route("/debugger/sessions//search-watchpoint-hits", methods=["POST"]) +def search_watchpoint_hits(session_id): """ Search watchpoint hits by group condition. @@ -361,15 +345,75 @@ def search_watchpoint_hits(): str, the required data. Examples: - >>> POST http://xxxx/v1/mindinsight/debugger/search-watchpoint-hits + >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/search-watchpoint-hits """ body = _read_post_request(request) group_condition = body.get('group_condition') - reply = _wrap_reply(BACKEND_SERVER.search_watchpoint_hits, group_condition) + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).search_watchpoint_hits, group_condition) + return reply + + +@BLUEPRINT.route("/debugger/sessions//condition-collections", methods=["GET"]) +def get_condition_collections(session_id): + """Get condition collections.""" + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).get_condition_collections) + return reply + + +@BLUEPRINT.route("/debugger/sessions//set-recommended-watch-points", methods=["POST"]) +def set_recommended_watch_points(session_id): + """Set recommended watch points.""" + body = _read_post_request(request) + request_body = body.get('requestBody') + if request_body is None: + raise ParamMissError('requestBody') + + set_recommended = request_body.get('set_recommended') + reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).set_recommended_watch_points, + set_recommended) return reply -BACKEND_SERVER = _initialize_debugger_server() +@BLUEPRINT.route("/debugger/sessions", methods=["POST"]) +def creat_session(): + """ + Get session id if session exist, else create a session. + + Returns: + str, session id. + + Examples: + >>> POST http://xxxx/v1/mindinsight/debugger/get-session + """ + body = _read_post_request(request) + summary_dir = body.get('dump_dir') + session_type = body.get('session_type') + reply = _wrap_reply(SessionManager.get_instance().creat_session, session_type, summary_dir) + return reply + + +@BLUEPRINT.route("/debugger/sessions", methods=["GET"]) +def get_sessions(): + """ + Check the cuurent active sessions. + + Examples: + >>> POST http://xxxx/v1/mindinsight/debugger/check-sessions + """ + reply = _wrap_reply(SessionManager.get_instance().get_sessions) + return reply + + +@BLUEPRINT.route("/debugger/sessions//delete", methods=["POST"]) +def delete_session(session_id): + """ + Delete session by session id. + + Examples: + >>> POST http://xxxx/v1/mindinsight/debugger/xxx/delete-session + """ + reply = _wrap_reply(SessionManager.get_instance().delete_session, session_id) + return reply def init_module(app): @@ -380,5 +424,3 @@ def init_module(app): app (Flask): The application obj. """ app.register_blueprint(BLUEPRINT) - if BACKEND_SERVER: - BACKEND_SERVER.start() diff --git a/mindinsight/datavisual/data_transform/data_manager.py b/mindinsight/datavisual/data_transform/data_manager.py index 633997f3..ab990591 100644 --- a/mindinsight/datavisual/data_transform/data_manager.py +++ b/mindinsight/datavisual/data_transform/data_manager.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -112,6 +112,11 @@ class _BasicTrainJob: """Get the lineage files count in the summary dir.""" return self._entry['lineage_files'] + @property + def dump_dir(self): + """Get the dump file path in the summary dir.""" + return self._entry.get('dump_dir', None) + class CachedTrainJob: """ @@ -369,6 +374,10 @@ class _BaseCacheManager: class _BriefCacheManager(_BaseCacheManager): """A cache manager that holds all disk train jobs on disk.""" + def __init__(self, summary_base_dir): + super(_BriefCacheManager, self).__init__(summary_base_dir) + self._summary_watcher = SummaryWatcher() + def cache_train_job(self, train_id): """ Cache given train job. @@ -386,7 +395,7 @@ class _BriefCacheManager(_BaseCacheManager): def update_cache(self, executor): """Update cache.""" logger.info('Start to update BriefCacheManager.') - summaries_info = SummaryWatcher().list_summary_directories(self._summary_base_dir) + summaries_info = self._summary_watcher.list_summary_directories(self._summary_base_dir) basic_train_jobs = [] for info in summaries_info: @@ -425,6 +434,10 @@ class _BriefCacheManager(_BaseCacheManager): return new_cache_items + def register_folder_analyzer(self, analyzer): + """Register folder analyzer.""" + self._summary_watcher.register_folder_analyzer(analyzer) + @property def cache_items(self): """Get cache items.""" @@ -1028,6 +1041,10 @@ class DataManager: """Register brief cache item updater for brief cache manager.""" self._brief_cache.register_cache_item_updater(updater) + def register_folder_analyzer(self, analyzer): + """Register folder analyzer.""" + self._brief_cache.register_folder_analyzer(analyzer) + def get_brief_cache(self): """Get brief cache.""" return self._brief_cache diff --git a/mindinsight/datavisual/data_transform/graph/msgraph.py b/mindinsight/datavisual/data_transform/graph/msgraph.py index 814d8bc6..db1fd8d5 100644 --- a/mindinsight/datavisual/data_transform/graph/msgraph.py +++ b/mindinsight/datavisual/data_transform/graph/msgraph.py @@ -254,22 +254,24 @@ class MSGraph(Graph): return searched_list - def search_leaf_nodes_by_pattern(self, pattern): + def search_leaf_nodes_by_pattern(self, pattern, scope_pattern=False): """ Search leaf node by a given pattern. Args: pattern (Union[str, None]): The pattern of the node to search, if None, return all node names. + scope_pattern (bool): If true, return the children nodes of the scope. Default: False. Returns: list[Node], a list of nodes. """ + is_match = lambda x, y: x.lower().startswith(y) if scope_pattern else y in x.lower() if pattern is not None: pattern = pattern.lower() searched_nodes = [ node for name, node in self._leaf_nodes.items() - if pattern in name.lower() + if is_match(name, pattern) ] else: searched_nodes = [node for node in self._leaf_nodes.values()] diff --git a/mindinsight/datavisual/data_transform/summary_watcher.py b/mindinsight/datavisual/data_transform/summary_watcher.py index bf8f8a63..53454df4 100644 --- a/mindinsight/datavisual/data_transform/summary_watcher.py +++ b/mindinsight/datavisual/data_transform/summary_watcher.py @@ -29,6 +29,7 @@ from mindinsight.utils.exceptions import FileSystemPermissionError LINEAGE_SUMMARY_SUFFIX = '_lineage' EXPLAIN_SUMMARY_SUFFIX = '_explain' +DUMP_FILE_PREFIX = 'dump_' class SummaryWatcher: @@ -45,6 +46,13 @@ class SummaryWatcher: # to avoid long-time blocking MAX_SCAN_COUNT = 20000 + def __init__(self): + self._analyzers = [] + + def register_folder_analyzer(self, analyzer): + """Register folder analyzer.""" + self._analyzers.append(analyzer) + def list_summary_directories(self, summary_base_dir, overall=True, list_explain=False): """ List summary directories within base directory. @@ -104,7 +112,7 @@ class SummaryWatcher: elif entry.is_dir(): self._update_summary_dict(summary_dict, summary_base_dir, relative_path, entry, list_explain) entry_path = os.path.realpath(os.path.join(summary_base_dir, entry.name)) - self._scan_subdir_entries(summary_dict, summary_base_dir, entry_path, entry.name, counter, list_explain) + self._scan_subdir_entries(summary_dict, summary_base_dir, entry_path, entry, counter, list_explain) directories = [] for key, value in summary_dict.items(): @@ -119,7 +127,7 @@ class SummaryWatcher: return directories - def _scan_subdir_entries(self, summary_dict, summary_base_dir, entry_path, entry_name, counter, list_explain): + def _scan_subdir_entries(self, summary_dict, summary_base_dir, entry_path, entry, counter, list_explain): """ Scan subdir entries. @@ -134,7 +142,7 @@ class SummaryWatcher: try: subdir_entries = os.scandir(entry_path) except PermissionError: - logger.warning('Path of %s under summary base directory is not accessible.', entry_name) + logger.warning('Path of %s under summary base directory is not accessible.', entry.name) return # sort in ascending order according to modification time. @@ -149,11 +157,14 @@ class SummaryWatcher: logger.info('Stop further scanning due to overall is False and ' 'number of scanned files exceeds upper limit.') break - subdir_relative_path = os.path.join('.', entry_name) + subdir_relative_path = os.path.join('.', entry.name) if subdir_entry.is_symlink(): pass self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry, list_explain) + relative_path = './' + self._check_by_analyzers(entry, summary_base_dir, relative_path, summary_dict) + def _is_valid_summary_directory(self, summary_base_dir, relative_path): """ Check if the given summary directory is valid. @@ -198,13 +209,11 @@ class SummaryWatcher: list_explain (bool): Indicates whether to list only the mindexplain folder. """ try: - stat = entry.stat() + ctime, mtime = self._get_stat_time(entry) except FileNotFoundError: logger.warning('File %s not found', entry.name) return - ctime = datetime.datetime.fromtimestamp(stat.st_ctime).astimezone() - mtime = datetime.datetime.fromtimestamp(stat.st_mtime).astimezone() if entry.is_file(): summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name) pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name) @@ -238,7 +247,10 @@ class SummaryWatcher: summary_dict[relative_path]['explain_files'] += 1 else: summary_dict[relative_path]['summary_files'] += 1 + self._check_by_analyzers(entry, summary_base_dir, relative_path, summary_dict) elif entry.is_dir(): + self._check_by_analyzers(entry, summary_base_dir, relative_path, summary_dict) + if list_explain: return @@ -261,6 +273,28 @@ class SummaryWatcher: else: summary_dict[relative_path] = _new_entry(ctime, mtime, profiler) + def _check_by_analyzers(self, entry, summary_base_dir, relative_path, summary_dict): + """Check by all analyzers.""" + try: + ctime, mtime = self._get_stat_time(entry) + except FileNotFoundError: + logger.warning('File %s not found', entry.name) + return + + for analyzer in self._analyzers: + register_info = analyzer.analyze(entry, summary_base_dir, relative_path) + if register_info: + if relative_path not in summary_dict: + summary_dict[relative_path] = _new_entry(ctime, mtime) + summary_dict[relative_path].update(register_info) + + def _get_stat_time(self, entry): + """Get ctime and mtime.""" + stat = entry.stat() + ctime = datetime.datetime.fromtimestamp(stat.st_ctime).astimezone() + mtime = datetime.datetime.fromtimestamp(stat.st_mtime).astimezone() + return ctime, mtime + def _find_profiler_dir(self, entry, summary_base_dir, relative_path): """Find profiler dir by the given relative path.""" profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name) @@ -342,6 +376,9 @@ class SummaryWatcher: if self._is_valid_profiler_directory(full_path)[0] or \ self._is_valid_cluster_profiler_directory(full_path)[0]: return True + if os.path.exists(os.path.join(summary_directory, os.path.join(entry.name, ".metadata"))): + return True + return False def _is_valid_profiler_directory(self, directory): @@ -515,7 +552,8 @@ def _new_entry(ctime, mtime, profiler=None): 'lineage_files': 0, 'explain_files': 0, 'graph_files': 0, - 'profiler': profiler + 'profiler': profiler, + 'dump_dir': None } diff --git a/mindinsight/datavisual/processors/train_task_manager.py b/mindinsight/datavisual/processors/train_task_manager.py index 2d025ec2..654d00b4 100644 --- a/mindinsight/datavisual/processors/train_task_manager.py +++ b/mindinsight/datavisual/processors/train_task_manager.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -150,7 +150,8 @@ class TrainTaskManager(BaseProcessor): profiler_type=basic_info.profiler_type, summary_files=basic_info.summary_files, graph_files=basic_info.graph_files, - lineage_files=basic_info.lineage_files + lineage_files=basic_info.lineage_files, + dump_dir=basic_info.dump_dir ) if train_job.cache_status != CacheStatus.NOT_IN_CACHE: diff --git a/mindinsight/debugger/common/exceptions/error_code.py b/mindinsight/debugger/common/exceptions/error_code.py index e87460b4..d4355b95 100644 --- a/mindinsight/debugger/common/exceptions/error_code.py +++ b/mindinsight/debugger/common/exceptions/error_code.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,8 @@ from mindinsight.utils.constant import DebuggerErrors as DebuggerErrorCodes _PARAM_ERROR_MASK = 0b00001 << 7 _DEBUGGER_GRAPH_ERROR = 0b00010 << 7 _DEBUGGER_RUNNING_ERROR = 0b00011 << 7 +_DEBUGGER_SERVER_ERROR = 0b00100 << 7 +_DEBUGGER_SESSION_ERROR = 0b00101 << 7 @unique @@ -44,6 +46,13 @@ class DebuggerErrors(DebuggerErrorCodes): TENSOR_HIT_ERROR = 8 | _DEBUGGER_RUNNING_ERROR SET_RECOMMEND_WATCHPOINT_ERROR = 9 | _DEBUGGER_RUNNING_ERROR + DEBUGGER_SERVER_RUNNING_ERROR = 0 | _DEBUGGER_SERVER_ERROR + DEVICE_ID_UNREGISTERED = 1 | _DEBUGGER_SERVER_ERROR + MODULE_NOT_FOUND_ERROR = 2 | _DEBUGGER_SERVER_ERROR + + DEBUGGER_SESSION_OVER_BOUND_ERROR = 0 | _DEBUGGER_SESSION_ERROR + DEBUGGER_SESSION_NOT_FOUND_ERROR = 1 | _DEBUGGER_SESSION_ERROR + @unique class DebuggerErrorMsg(Enum): @@ -63,3 +72,10 @@ class DebuggerErrorMsg(Enum): TENSOR_GRAPH_ERROR = "Get tensor graphs failed." TENSOR_HIT_ERROR = "Get tensor hits failed." SET_RECOMMEND_WATCHPOINT_ERROR = "Set Recommend Watchpoints failed." + + DEBUGGER_SERVER_RUNNING_ERROR = "Debugger server running error. {}" + DEVICE_ID_UNREGISTERED = "Device id unregistered. Device id: {}" + MODULE_NOT_FOUND_ERROR = "{} module not found." + + DEBUGGER_SESSION_OVER_BOUND_ERROR = "The amount of sessions is over limitation." + DEBUGGER_SESSION_NOT_FOUND_ERROR = "Session {} not found." diff --git a/mindinsight/debugger/common/exceptions/exceptions.py b/mindinsight/debugger/common/exceptions/exceptions.py index 060325fe..7cefd13d 100644 --- a/mindinsight/debugger/common/exceptions/exceptions.py +++ b/mindinsight/debugger/common/exceptions/exceptions.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -190,3 +190,58 @@ class DebuggerConditionUnavailableError(MindInsightException): message=DebuggerErrorMsg.DEBUGGER_CONDITION_UNAVAILABLE_ERROR.value.format(msg), http_code=400 ) + + +class DebuggerServerRunningError(MindInsightException): + """The condition unavailable error in debugger module.""" + + def __init__(self, msg): + super(DebuggerServerRunningError, self).__init__( + error=DebuggerErrors.DEBUGGER_SERVER_RUNNING_ERROR, + message=DebuggerErrorMsg.DEBUGGER_SERVER_RUNNING_ERROR.value.format(msg), + http_code=500 + ) + + +class DeviceIdUnregistered(MindInsightException): + """The condition unavailable error in debugger module.""" + + def __init__(self, msg): + super(DeviceIdUnregistered, self).__init__( + error=DebuggerErrors.DEVICE_ID_UNREGISTERED, + message=DebuggerErrorMsg.DEVICE_ID_UNREGISTERED.value.format(msg), + http_code=400 + ) + + +class DebuggerModuleNotFoundError(MindInsightException): + """The condition unavailable error in debugger module.""" + + def __init__(self, msg): + super(DebuggerModuleNotFoundError, self).__init__( + error=DebuggerErrors.MODULE_NOT_FOUND_ERROR, + message=DebuggerErrorMsg.MODULE_NOT_FOUND_ERROR.value.format(msg), + http_code=500 + ) + + +class DebuggerSessionNumOverBoundError(MindInsightException): + """The condition unavailable error in debugger module.""" + + def __init__(self): + super(DebuggerSessionNumOverBoundError, self).__init__( + error=DebuggerErrors.DEBUGGER_SESSION_OVER_BOUND_ERROR, + message=DebuggerErrorMsg.DEBUGGER_SESSION_OVER_BOUND_ERROR.value, + http_code=400 + ) + + +class DebuggerSessionNotFoundError(MindInsightException): + """The condition unavailable error in debugger module.""" + + def __init__(self, msg): + super(DebuggerSessionNotFoundError, self).__init__( + error=DebuggerErrors.DEBUGGER_SESSION_NOT_FOUND_ERROR, + message=DebuggerErrorMsg.DEBUGGER_SESSION_NOT_FOUND_ERROR.value.format(msg), + http_code=400 + ) diff --git a/mindinsight/debugger/common/utils.py b/mindinsight/debugger/common/utils.py index 24c313b5..cdb19ac8 100644 --- a/mindinsight/debugger/common/utils.py +++ b/mindinsight/debugger/common/utils.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -38,9 +38,12 @@ NUMPY_TYPE_MAP = { 'DT_FLOAT32': np.float32, 'DT_FLOAT64': np.float64, - 'DT_STRING': np.str + 'DT_STRING': np.str, + 'DT_TYPE': np.str } +MS_VERSION = '1.0.x' + @enum.unique class ReplyStates(enum.Enum): @@ -71,6 +74,7 @@ class Streams(enum.Enum): TENSOR = 'tensor' WATCHPOINT = 'watchpoint' WATCHPOINT_HIT = 'watchpoint_hit' + DEVICE = 'device' class RunLevel(enum.Enum): @@ -152,3 +156,26 @@ def is_scope_type(node_type): def is_cst_type(node_type): """Judge whether the type is const type.""" return node_type == NodeTypeEnum.CONST.value + + +def version_match(ms_version, mi_version): + """Judge if the version of Mindinsight and Mindspore is matched.""" + if not ms_version: + ms_version = MS_VERSION + mi_major, mi_minor = mi_version.split('.')[:2] + ms_major, ms_minor = ms_version.split('.')[:2] + return mi_major == ms_major and mi_minor == ms_minor + + +@enum.unique +class DebuggerServerMode(enum.Enum): + """Debugger Server Mode.""" + ONLINE = 'online' + OFFLINE = 'offline' + + +class DumpSettings(enum.Enum): + """Dump settings.""" + E2E_DUMP_SETTINGS = 'e2e_dump_settings' + COMMON_DUMP_SETTINGS = 'common_dump_settings' + ASYNC_DUMP_SETTINGS = 'async_dump_settings' diff --git a/mindinsight/debugger/conditionmgr/recommender.py b/mindinsight/debugger/conditionmgr/recommender.py index 7f82d0a4..c8bec7fa 100644 --- a/mindinsight/debugger/conditionmgr/recommender.py +++ b/mindinsight/debugger/conditionmgr/recommender.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -64,13 +64,13 @@ class _ConditionParameterValue: return self.parameter.name -def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_context): +def recommend_watchpoints(condition_mgr: ConditionMgr, multi_card_graph_stream, condition_context): """ Recommend watchpoints. Args: condition_mgr (ConditionMgr): Condition manager instance. - graph_stream (GraphHandler): Graph handler instance. + multi_card_graph_stream (GraphHandler): Multi card graph handler instance. condition_context (ConditionContext): Context for condition. Returns: @@ -78,7 +78,7 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c """ watch_points = [] - if not graph_stream.graph: + if not multi_card_graph_stream.has_graph: logger.warning("Given graph is None.") return watch_points @@ -86,7 +86,7 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c return watch_points # add weight watch points - merged_info = get_basic_node_info(TargetTypeEnum.WEIGHT.value, graph_stream) + merged_info = get_basic_node_info(TargetTypeEnum.WEIGHT.value, multi_card_graph_stream) _recommend_weight_initialization(merged_info, condition_mgr, watch_points, condition_context) _recommend_weight_change_too_large(merged_info, condition_mgr, watch_points, condition_context) @@ -97,25 +97,27 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c _recommend_weight_change_too_small(condition_mgr, trainable_weight_nodes, watch_points, condition_context) # add gradient watch points - merged_info = get_basic_node_info(TargetTypeEnum.GRADIENT.value, graph_stream) + merged_info = get_basic_node_info(TargetTypeEnum.GRADIENT.value, multi_card_graph_stream) _recommend_gradient_vanishing(merged_info, condition_mgr, watch_points, condition_context) # add tensor watch points - merged_info = get_basic_node_info(TargetTypeEnum.TENSOR.value, graph_stream) + merged_info = get_basic_node_info(TargetTypeEnum.TENSOR.value, multi_card_graph_stream) _recommend_operator_overflow(merged_info, condition_mgr, watch_points, condition_context) _recommend_tensor_overflow(merged_info, condition_mgr, watch_points, condition_context) _recommend_tensor_all_zero(merged_info, condition_mgr, watch_points, condition_context) # add activation watch points - merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.TANH.value) + merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream, + ActivationFuncEnum.TANH.value) _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, ActivationFuncEnum.TANH.value) - merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.SIGMOID.value) + merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream, + ActivationFuncEnum.SIGMOID.value) _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, ActivationFuncEnum.SIGMOID.value) - merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, + merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream, [ActivationFuncEnum.RELU.value, ActivationFuncEnum.RELUV2.value]) _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, ActivationFuncEnum.RELU.value) @@ -318,12 +320,21 @@ def _recommend_activation_range(basic_info_nodes, condition_mgr, watch_points, c watch_points.append(activation_range_watchpoint) -def get_basic_node_info(node_category, graph_stream, activation_func=None): +def get_basic_node_info(node_category, multi_card_graph_stream, activation_func=None): """Get node merged info.""" - basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func) - merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph) - merged_info = _add_graph_name(merged_info, graph_stream) - return merged_info + nodes_for_devices = {} + has_node = False + for rank_id, graph_stream in multi_card_graph_stream.graph_handlers.items(): + basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func) + merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph) + merged_info = _add_graph_name(merged_info, graph_stream) + nodes_for_devices[rank_id] = merged_info + has_node = has_node or merged_info + + if has_node: + return nodes_for_devices + + return {} def _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func=None): diff --git a/mindinsight/debugger/debugger_cache.py b/mindinsight/debugger/debugger_cache.py index 05e285f9..830cf87e 100644 --- a/mindinsight/debugger/debugger_cache.py +++ b/mindinsight/debugger/debugger_cache.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,17 +17,19 @@ import sys from mindinsight.debugger.common.log import LOGGER as log from mindinsight.debugger.common.utils import Streams -from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, GraphHandler, \ - TensorHandler, WatchpointHandler, WatchpointHitHandler +from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, MultiCardGraphHandler, \ + MultiCardTensorHandler, WatchpointHandler, MultiCardWatchpointHitHandler +from mindinsight.debugger.stream_handler.device_handler import DeviceHandler STREAM_HANDLER_MAP = { Streams.COMMAND.value: EventHandler, Streams.DATA.value: EventHandler, Streams.METADATA.value: MetadataHandler, - Streams.GRAPH.value: GraphHandler, - Streams.TENSOR.value: TensorHandler, + Streams.GRAPH.value: MultiCardGraphHandler, + Streams.TENSOR.value: MultiCardTensorHandler, Streams.WATCHPOINT.value: WatchpointHandler, - Streams.WATCHPOINT_HIT.value: WatchpointHitHandler + Streams.WATCHPOINT_HIT.value: MultiCardWatchpointHitHandler, + Streams.DEVICE.value: DeviceHandler } @@ -40,10 +42,8 @@ class DebuggerCache: def initialize(self): """Initialize the stream handlers.""" self._stream_handler = {} - for stream in Streams: - mode = stream.value - stream_handler = STREAM_HANDLER_MAP.get(mode) - self._stream_handler[mode] = stream_handler() + for mode, stream_class in STREAM_HANDLER_MAP.items(): + self._stream_handler[mode] = stream_class() def clean(self): """Clean cache for all stream.""" diff --git a/mindinsight/debugger/debugger_folder_analyzer.py b/mindinsight/debugger/debugger_folder_analyzer.py new file mode 100644 index 00000000..9909732d --- /dev/null +++ b/mindinsight/debugger/debugger_folder_analyzer.py @@ -0,0 +1,41 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Debugger train job register.""" + +import os + +from mindinsight.utils.folder_analyzer import FolderAnalyzer +from mindinsight.datavisual.common.log import logger + +class DebuggerFolderAnalyzer(FolderAnalyzer): + """Debugger train job register.""" + def analyze(self, entry, summary_base_dir, relative_path): + """Check dir by debugger register.""" + update_info = {} + if entry.is_dir(): + sub_relative_path = os.path.join(relative_path, entry.name) + entry_path = os.path.join(summary_base_dir, sub_relative_path) + try: + subdir_entries = os.scandir(entry_path) + except PermissionError: + logger.warning('Path of %s under summary base directory is not accessible.', entry.name) + return update_info + subdir_entries = [subdir_entry for subdir_entry in subdir_entries if not subdir_entry.is_symlink()] + subdir_entries = sorted(subdir_entries, key=lambda x: x.stat().st_mtime) + for subdir_entry in subdir_entries: + if subdir_entry.is_dir() and subdir_entry.name.startswith(".metadata"): + update_info = {'dump_dir': sub_relative_path} + return update_info + return update_info diff --git a/mindinsight/debugger/debugger_services/__init__.py b/mindinsight/debugger/debugger_services/__init__.py new file mode 100644 index 00000000..fecf22f8 --- /dev/null +++ b/mindinsight/debugger/debugger_services/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Debugger server module.""" diff --git a/mindinsight/debugger/debugger_grpc_server.py b/mindinsight/debugger/debugger_services/debugger_grpc_server.py similarity index 95% rename from mindinsight/debugger/debugger_grpc_server.py rename to mindinsight/debugger/debugger_services/debugger_grpc_server.py index 9fd54329..625d1c85 100644 --- a/mindinsight/debugger/debugger_grpc_server.py +++ b/mindinsight/debugger/debugger_services/debugger_grpc_server.py @@ -19,7 +19,7 @@ from functools import wraps import mindinsight from mindinsight.debugger.common.log import LOGGER as log from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ - Streams, RunLevel + Streams, RunLevel, version_match from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum, ParamNameEnum from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto @@ -117,9 +117,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): # clean cache data at the beginning of new step or node has been changed. if is_new_step or is_new_node: self._cache_store.clean_data() - self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step) + self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).clean_tensors( + request.cur_step) if is_new_step: - self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() + self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(0).clean() # receive graph at the beginning of the training if self._status == ServerStatus.RECEIVE_GRAPH: self._send_graph_flag(metadata_stream) @@ -141,7 +142,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): self._status = ServerStatus.WAITING metadata_stream.state = ServerStatus.WAITING.value metadata = metadata_stream.get() - res = self._cache_store.get_stream_handler(Streams.GRAPH).get() + res = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).get() res.update(metadata) self._cache_store.put_data(res) log.debug("Put graph into data queue.") @@ -157,7 +158,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): # put new metadata into cache metadata_stream.put(metadata_proto) # update current node name and graph name - graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH) + graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0) full_name = metadata_proto.cur_node graph_name = graph_stream.get_graph_id_by_full_name( full_name) if full_name else metadata_stream.graph_name @@ -182,7 +183,8 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): def _send_watchpoint_hit_flag(self): """Send Watchpoint hit flag.""" - watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) + watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id( + 0) if not self._received_hit: return watchpoint_hits = self._received_hit @@ -344,7 +346,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): run_cmd.node_name = '' # clean watchpoint hit cache if run_cmd.run_level == RunLevel.RECHECK.value: - self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() + self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(0).clean() log.debug("Receive RunCMD. Clean watchpoint hit cache.") # update metadata state from sending to running metadata_stream.state = ServerStatus.RUNNING.value @@ -365,8 +367,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): log.info("The training from %s has finished.", client_ip) else: ms_version = request.ms_version - if not ms_version: - ms_version = '1.0.x' if version_match(ms_version, mindinsight.__version__) is False: log.info("Version is mismatched, mindspore is: %s, mindinsight is: %s", ms_version, mindinsight.__version__) @@ -403,8 +403,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): graph = GraphProto.FromString(serial_graph) log.debug("Deserialize the graph %s. Receive %s nodes", graph.name, len(graph.node)) graph_dict = {graph.name: graph} - self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict) - self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals) + self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).put(graph_dict) + self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals( + graph.const_vals) self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name self._record_parameter_names() self._status = ServerStatus.RECEIVE_GRAPH @@ -429,10 +430,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): log.debug("Deserialize the graph %s. Receive %s nodes", sub_graph.name, len(sub_graph.node)) serial_graph = b"" - self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals( + self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals( sub_graph.const_vals) - self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict) + self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).put(graph_dict) self._record_parameter_names() self._status = ServerStatus.RECEIVE_GRAPH log.debug("Send the reply for graph.") @@ -440,9 +441,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): def _record_parameter_names(self): """Record parameter full names in tensor handler.""" - parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).search_in_graph( - pattern={'node_category': TargetTypeEnum.PARAMETER.value}) - tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR) + parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0)\ + .search_in_graph(pattern={'node_category': TargetTypeEnum.PARAMETER.value}) + tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0) for node in parameter_nodes: tensor_name = [node.full_name + ':0'] tensor_stream.record_parameter_names(tensor_name) @@ -452,7 +453,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): """Send tensors into DebuggerCache.""" log.info("Received tensor.") tensor_contents = [] - tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR) + tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0) metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) step = metadata_stream.step for tensor in request_iterator: @@ -482,7 +483,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): # save the watchpoint_hits data watchpoint_hits = [] watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT) - graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH) + graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0) for watchpoint_hit_proto in request_iterator: node_full_name = watchpoint_hit_proto.tensor.node_name graph_name = graph_stream.get_graph_id_by_full_name(node_full_name) @@ -517,10 +518,3 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): self._received_hit = watchpoint_hits reply = get_ack_reply() return reply - - -def version_match(mi_version, ms_version): - """Judge if the version of Mindinsight and Mindspore is matched""" - mi_major, mi_minor = mi_version.split('.')[:2] - ms_major, ms_minor = ms_version.split('.')[:2] - return mi_major == ms_major and mi_minor == ms_minor diff --git a/mindinsight/debugger/debugger_services/debugger_offline_server.py b/mindinsight/debugger/debugger_services/debugger_offline_server.py new file mode 100644 index 00000000..fb97f273 --- /dev/null +++ b/mindinsight/debugger/debugger_services/debugger_offline_server.py @@ -0,0 +1,613 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Debugger Offline server.""" +import copy +from collections import defaultdict +from importlib import import_module +from threading import Event +from multiprocessing import Process, Manager + +import mindinsight +from mindinsight.datavisual.data_transform.graph import NodeTypeEnum +from mindinsight.debugger.common.exceptions.exceptions import DebuggerModuleNotFoundError +from mindinsight.debugger.common.log import LOGGER as log +from mindinsight.debugger.common.utils import Streams, ServerStatus, version_match, DebuggerServerMode, get_ack_reply, \ + RunLevel +from mindinsight.debugger.conditionmgr.condition import ParamNameEnum +from mindinsight.debugger.debugger_services.debugger_server_base import DebuggerServerBase, debugger_server_wrap +from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply +from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto +from mindinsight.debugger.stream_cache.data_loader import DataLoader +from mindinsight.utils.exceptions import MindInsightException + + +class DebuggerOfflineServer(DebuggerServerBase): + """Debugger Offline Server.""" + _MAX_TRY_EXCEPT_COUNT = 500 + + def __init__(self, cache_store, context): + super(DebuggerOfflineServer, self).__init__(cache_store, context) + self._offline_server_manager = DebuggerOfflineManager(cache_store, context.dbg_dir) + self._running = Event() + self._running.clear() + + def run(self): + """Start the debugger offline server.""" + log.info("Initialize Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir) + self._offline_server_manager.initialize() + self._running.set() + log.info("Start Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir) + try_count = 0 + while self._running.is_set() and try_count < self._MAX_TRY_EXCEPT_COUNT: + try: + self._offline_server_manager.wait_for_termination() + if not self._offline_server_manager.is_runnable(): + break + except MindInsightException as err: + log.exception(err) + log.warning("Error happens during listening on user commands. Restart listening again.") + finally: + try_count += 1 + # protect server from too much failure commands. + if try_count == self._MAX_TRY_EXCEPT_COUNT: + self._cache_store.clean() + metadata = self._cache_store.get_stream_handler(Streams.METADATA).get() + self._cache_store.put_data(metadata) + log.warning("Exception exceed %d times, stop server.", try_count) + + def stop(self): + """Stop offline debugger server.""" + log.debug("Start to wait for thread started.") + self._running.wait() + log.info("Start to stop offline debugger server.") + self._running.clear() + self._offline_server_manager.stop() + self.join() + + +class DebuggerOfflineManager: + """Debugger offline manager which is used to handle user commands.""" + + def __init__(self, cache_store, dbg_dir): + cache_store.initialize() + self._cache_store = cache_store + self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) + + self._dbg_dir = dbg_dir + self._dbg_services_module = self._get_dbg_service_module() + self._dbg_service = None + + self._command_listener = CommandListener(cache_store) + self._data_loader = DataLoader(dbg_dir) + self._is_running_flag = False + self._old_run_cmd = {} + + def stop(self): + """Stop server.""" + self._is_running_flag = False + self._command_listener.stop() + self._cache_store.clean() + event = get_ack_reply() + event.exit = True + self._cache_store.put_command(event) + log.info("Stop debugger offline manager.") + + def is_runnable(self): + """Check if the offline manager is runnable.""" + state = self._metadata_stream.state + flag = self._is_running_flag and state not in [ServerStatus.MISMATCH.value, ServerStatus.PENDING.value] + if not flag: + log.debug("The offline manager is not runnable, is_running_flag: %s, metadata state: %s", + self._is_running_flag, state) + return flag + + @staticmethod + def _get_dbg_service_module(): + """Get dbg service module from MindSpore.""" + try: + dbg_services_module = import_module('mindspore.offline_debug.dbg_services') + except ModuleNotFoundError as err: + log.error("Failed to find module dbg_services. %s", err) + raise DebuggerModuleNotFoundError("dbg_services") + return dbg_services_module + + @debugger_server_wrap + def initialize(self): + """Start to load offline debugger data.""" + self._data_loader.initialize() + is_sync = self._data_loader.get_sync_flag() + net_name = self._data_loader.get_net_name() + net_dir = self._data_loader.get_net_dir() + self._dbg_service = self._dbg_services_module.DbgServices(net_dir) + self._dbg_service.initialize(net_name=net_name, is_sync_mode=is_sync) + self._cache_store.clean() + self._command_listener.start() + self._is_running_flag = True + self._check_version() + if self._metadata_stream.state == ServerStatus.MISMATCH.value: + log.info("The MindSpore and MindInsight version are mismatched. Failed to initialize offline server.") + return + self._load_metadata() + self._load_graphs() + log.info("Success initialize offline server for %s", self._dbg_dir) + + def _check_version(self): + """Check version.""" + ms_version = self._dbg_services_module.get_version() + mi_version = mindinsight.__version__ + self._metadata_stream.debugger_version = {'ms': ms_version, 'mi': mi_version} + if version_match(ms_version, mi_version) is False: + log.info("Version is mismatched, dbg_services is: %s, mindinsight is: %s", + ms_version, mi_version) + self._metadata_stream.state = ServerStatus.MISMATCH.value + metadata = self._metadata_stream.get(['state', 'debugger_version']) + self._cache_store.put_data(metadata) + + def _load_metadata(self): + """Load metadata.""" + self._metadata_stream.debugger_type = DebuggerServerMode.OFFLINE.value + device_info = self._data_loader.load_device_info() + # The backend referred to the running environment on which the offline debugger + # data was generated. + # Currently supported options: `GPU`, `Ascend` + backend = device_info.get('device_target', 'Ascend') + self._metadata_stream.backend = backend + device_stream = self._cache_store.get_stream_handler(Streams.DEVICE) + device_stream.put(device_info.get('server_list')) + rank_id = 0 + rank_0_info = device_stream.get(rank_id)['devices'][0] + self._metadata_stream.client_ip = rank_0_info.get('server_id') + # get step number per device. dict(device_id, step_num), may be increased with time goes by + step_num_per_device = self._data_loader.load_step_number() + device_stream.add_step_num_info(step_num_per_device) + self._metadata_stream.max_step_num = max(step_num_per_device.values()) + + def _load_graphs(self): + """Load graphs.""" + # the format of graphs is a list of {'device_id': int, 'graph_protos': [GraphProto]}} + graphs = self._data_loader.load_graphs() + device_stream = self._cache_store.get_stream_handler(Streams.DEVICE) + graph_per_rank = {} + for graph in graphs: + device_id = int(graph.get('device_id')) + rank_id = device_stream.get_rank_id_by_device_id(device_id) + graph_per_rank[rank_id] = {} + tensor_stream_per_rank = self._cache_store.get_stream_handler(Streams.TENSOR).\ + get_tensor_handler_by_rank_id(rank_id, create_if_not_exit=True) + for graph_proto in graph.get('graph_protos'): + graph_per_rank[rank_id][graph_proto.name] = graph_proto + tensor_stream_per_rank.put_const_vals(graph_proto.const_vals) + # the graph_per_rank is format like: Dict[, Dict[, ]] + self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_per_rank) + device_stream.add_graph_name_info(graph_per_rank) + self._metadata_stream.state = ServerStatus.RECEIVE_GRAPH.value + + @debugger_server_wrap + def wait_for_termination(self): + """Begin to listen on command event.""" + log.info("Begin to listen for user commands.") + self._send_graph() + while self.is_runnable(): + if not self._command_listener.has_new_command() and self._old_run_cmd: + self._deal_with_old_run_cmd() + continue + cmd = self._command_listener.get_next_command() + self.deal_with_cmd(cmd) + + def _send_graph(self): + """Put graph and metadata info into data queue.""" + if not self.is_runnable(): + return + self._metadata_stream.state = ServerStatus.WAITING.value + metadata = self._metadata_stream.get() + res = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).get() + res.update(metadata) + self._cache_store.put_data(res) + + def _deal_with_old_run_cmd(self): + """Deal with old run command.""" + left_step_count = self._old_run_cmd.get('left_step_count') + if left_step_count: + self._execute_one_step() + # if old_run_cmd is not cleared due to hit. + if self._old_run_cmd: + self._old_run_cmd['left_step_count'] = left_step_count - 1 if left_step_count > 0 else -1 + if not self._old_run_cmd.get('left_step_count'): + self._old_run_cmd.clear() + + def deal_with_cmd(self, cmd): + """Deal with command.""" + if cmd is None: + return + if isinstance(cmd, dict): + self._deal_with_view_cmd(cmd) + elif isinstance(cmd, EventReply): + self._on_event(cmd) + + def _on_event(self, event): + """ + Deal with different command event. + + Args: + event (EventReply): Command Event. + """ + if event.HasField('run_cmd'): + self._deal_with_run_cmd(event) + elif event.HasField('exit'): + self._cache_store.clean() + self._update_state(ServerStatus.PENDING) + log.debug("Clean cache for exit cmd.") + else: + self._deal_with_set_cmd(event) + log.debug("Deal with set cmd.") + + def _deal_with_view_cmd(self, event): + """ + Deal with view cmd. + + Args: + event (dict): View command params. + + - view_cmd (EventReply): EventReply with view command. + - node_name (str): The center node name for view command. + - tensor_name (str): The center tensor name for view command. + - graph_name (str): The graph name of center node. + - rank_id (int): The device id of the tensor. + """ + view_cmd = event.pop('view_cmd', None).view_cmd + node_info = event + log.debug("Receive view cmd for node: %s.", event) + if not (view_cmd and node_info): + log.info("Invalid view command. Ignore it.") + return + # read tensor value by dbg_service + rank_id = node_info.get('rank_id', 0) + device_id = self._cache_store.get_stream_handler(Streams.DEVICE).get_device_id_by_rank_id(rank_id) + cur_step = self._metadata_stream.step + tensor_protos = view_cmd.tensors + root_graph_id = self.get_root_graph_id() + tensor_infos = [ + self._dbg_services_module.TensorInfo( + node_name=tensor_proto.node_name, + slot=int(tensor_proto.slot), + iteration=cur_step - 1 if tensor_proto.iter == 'prev' else cur_step, + device_id=device_id, + is_parameter=tensor_proto.truncate, + root_graph_id=root_graph_id + ) for tensor_proto in tensor_protos] + res = self._dbg_service.read_tensors(tensor_infos) + # put tensor into cache + for tensor_proto, tensor_data in zip(tensor_protos, res): + log.debug("Tensor name: %s:%s, tensor type: %s, tensor size: %s", tensor_proto.node_name, tensor_proto.slot, + tensor_data.dtype, tensor_data.data_size) + tensor_proto.tensor_content = tensor_data.data_ptr + tensor_proto.ClearField('dims') + tensor_proto.dims.extend(tensor_data.shape) + tensor_proto.data_type = tensor_data.dtype + self._put_tensor_value_into_cache(cur_step, node_info, rank_id, tensor_protos) + log.info("Put tensor value into cache.") + + def get_root_graph_id(self): + """Get root graph id.""" + is_sync = self._data_loader.get_sync_flag() + graph_id = 0 if is_sync else 1 + return graph_id + + def _put_tensor_value_into_cache(self, cur_step, node_info, rank_id, tensor_protos): + """Put tensor value into tensor cache.""" + + tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR). \ + get_tensor_handler_by_rank_id(rank_id) + update_data_flag = False + for tensor_proto in tensor_protos: + if not tensor_proto.tensor_content: + log.warning("Tensor %s:%s is empty.", + tensor_proto.node_name, tensor_proto.slot) + try: + has_update = tensor_stream.put({ + 'step': cur_step, + 'tensor_proto': tensor_proto, + 'tensor_contents': [tensor_proto.tensor_content] + }) + except ValueError as err: + log.warning("Failed to put %s:%s into cache. Ignore it. %s", + tensor_proto.node_name, tensor_proto.slot, str(err)) + continue + if has_update: + update_data_flag = True + if update_data_flag: + # send message to frontend + metadata = self._metadata_stream.get(['step', 'state']) + ret = {'receive_tensor': node_info.copy()} + ret.update(metadata) + self._cache_store.put_data(ret) + + def _deal_with_run_cmd(self, event): + """Deal with run cmd.""" + run_cmd = event.run_cmd + parsed_run_cmd = self._get_parsed_run_cmd(run_cmd) + if parsed_run_cmd.run_steps > 0: + self._execute_one_step() + elif run_cmd.run_level == RunLevel.RECHECK.value: + log.info("Deal with recheck command.") + self._check_watchpoint(self._metadata_stream.step) + + def _execute_one_step(self): + """Execute on step.""" + new_step = self._metadata_stream.step + 1 + if new_step > self._metadata_stream.max_step_num: + self._old_run_cmd.clear() + log.info("The server is already at the last step. %s", self._metadata_stream.max_step_num) + return + log.info("Go to next step: %s.", new_step) + self._check_watchpoint(new_step) + self._metadata_stream.step = new_step + self._cache_store.get_stream_handler(Streams.TENSOR).set_step(new_step) + self._cache_store.put_data(self._metadata_stream.get('step')) + + def _get_parsed_run_cmd(self, run_cmd): + """Get parsed run command.""" + if run_cmd.run_level == RunLevel.STEP.value: + # receive pause cmd + if not run_cmd.run_steps: + log.debug("Pause training and wait for next command.") + self._old_run_cmd.clear() + # update metadata state from sending to waiting + self._update_state(ServerStatus.WAITING) + return run_cmd + # receive step cmd + left_steps = run_cmd.run_steps - 1 + run_cmd.run_steps = 1 + if left_steps: + self._old_run_cmd['left_step_count'] = left_steps if left_steps > 0 else -1 + elif run_cmd.node_name: + self._old_run_cmd['node_name'] = run_cmd.node_name + run_cmd.node_name = '' + return run_cmd + + def _check_watchpoint(self, step): + """Save watchpoint hits into cache.""" + self._update_state(ServerStatus.RUNNING) + # Clean watchpoint_hits in cache + multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) + multi_card_hit_streams.clean() + hits = Manager().list() + check_watchpoints_process = Process(target=self._check_watchpoint_work, args=(hits, step,)) + check_watchpoints_process.start() + check_watchpoints_process.join() + log.info("finish check watchpoint of %s", step) + if hits: + log.info("Received WatchpointHits. Left run cmd %s change to empty.", self._old_run_cmd) + self._old_run_cmd.clear() + self._update_state(ServerStatus.WAITING) + self._save_watchpoint_hits(hits) + + def _save_watchpoint_hits(self, hits): + """Save watchpoint hits.""" + multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) + multi_card_graph_streams = self._cache_store.get_stream_handler(Streams.GRAPH) + device_stream = self._cache_store.get_stream_handler(Streams.DEVICE) + watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT) + + watchpoint_hits = defaultdict(list) + for hit in hits: + log.info("Received hit\n: " + "name:%s, slot:%s, condition:%s, " + "watchpoint_id:%s" + "error_code:%s, device_id:%s", + hit['name'], hit['slot'], hit['condition'], + hit['watchpoint_id'], hit['error_code'], hit['device_id']) + rank_id = device_stream.get_rank_id_by_device_id(hit['device_id']) + watchpoint_hit = {} + self._add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit) + if not watchpoint_hit: + continue + self._add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit) + watchpoint_hit['error_code'] = hit['error_code'] + watchpoint_hits[rank_id].append(watchpoint_hit) + # save hit info into cache + multi_card_hit_streams.put(watchpoint_hits) + self._cache_store.put_data({'receive_watchpoint_hits': True}) + log.debug("Send the watchpoint hits to DataQueue.") + + @staticmethod + def _add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit): + """Add hit node info.""" + graph_stream = multi_card_graph_streams.get_graph_handler_by_rank_id(rank_id) + node_full_name = hit['name'] + graph_name = graph_stream.get_graph_id_by_full_name(node_full_name) + if not graph_name: + log.warning("Cannot find node %s in graph. Skip it.", node_full_name) + return + ui_node_name = graph_stream.get_node_name_by_full_name(node_full_name, graph_name) + log.debug("Receive watch point hit: %s:%s", node_full_name, hit['slot']) + if not ui_node_name: + log.info("Not support to show %s on graph.", node_full_name) + return + watchpoint_hit.update({ + 'tensor_proto': TensorProto(node_name=node_full_name, slot=str(hit['slot'])), + 'node_name': ui_node_name, + 'graph_name': graph_name + }) + + @staticmethod + def _add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit): + """Add watchpoint hit info.""" + watchpoint = copy.deepcopy(watchpoint_stream.get_watchpoint_by_id(hit['watchpoint_id'])) + hit_params = {} + # get hit actual value + for param in hit['parameters']: + if param['name'] not in (ParamNameEnum.RTOL.value, ParamNameEnum.RANGE_START_INCLUSIVE.value, + ParamNameEnum.RANGE_END_INCLUSIVE.value) \ + and hit['error_code'] == 0: + hit_params[param['name']] = param['actual_value'] + # update actual value into watchpoint + watchpoint_condition_params = watchpoint.condition['params'] + for i, param in enumerate(watchpoint_condition_params): + name = param['name'] + if name in hit_params.keys(): + watchpoint_condition_params[i]['actual_value'] = hit_params[name] + else: + watchpoint_condition_params[i]['actual_value'] = None + + watchpoint_hit['watchpoint'] = watchpoint + + def _deal_with_set_cmd(self, event): + """ + Deal with set cmd. + + Args: + event (EventReply): User command event including set_cmd. + """ + set_cmd = event.set_cmd + set_cmd_id = set_cmd.id + delete = set_cmd.delete + if not delete: + log.info("Add watchpoint by using dbg_server.") + watch_condition = set_cmd.watch_condition + param_list = [] + for param in watch_condition.params: + param_list.append( + self._dbg_services_module.Parameter(param.name, param.disabled, param.value)) + watch_nodes = set_cmd.watch_nodes + check_nodes = self._get_check_nodes(watch_nodes) + log.debug("Watchpoint %s, condition: %s, watch nodes: %s", + set_cmd_id, watch_condition.condition, check_nodes) + self._dbg_service.add_watchpoint(set_cmd_id, watch_condition.condition, check_nodes, param_list) + else: + log.info("Remove watchpoint by using dbg_server.") + self._dbg_service.remove_watchpoint(set_cmd_id) + + def _get_check_nodes(self, watch_nodes): + """Get check nodes format""" + check_nodes = {} + device_stream = self._cache_store.get_stream_handler(Streams.DEVICE) + root_graph_id = self.get_root_graph_id() + for watch_node in watch_nodes: + node_name = watch_node.node_name + rank_id = watch_node.rank_id + device_id = device_stream.get_device_id_by_rank_id(rank_id) + if node_name not in check_nodes: + is_parameter = bool(watch_node.node_type == NodeTypeEnum.PARAMETER.value) + check_nodes[node_name] = { + "device_id": [device_id], + "is_parameter": is_parameter, + "root_graph_id": [root_graph_id] + } + else: + check_nodes[node_name]["device_id"].append(device_id) + return check_nodes + + def _update_state(self, server_status): + """ + Update state in metadata stream. + + Args: + server_status (ServerStatus): The enum value in ServerStatus. + """ + if self._metadata_stream.state != server_status.value: + self._metadata_stream.state = server_status.value + self._cache_store.put_data(self._metadata_stream.get()) + + def _check_watchpoint_work(self, hits, step): + """The check WatchPoint function work in another process.""" + log.info("Start checking WatchPointHit process.") + res = self._dbg_service.check_watchpoints(step) + for watchpoint_hit in res: + hit_dict = convert_watchpointhit(watchpoint_hit) + hits.append(hit_dict) + log.info("Checking WatchPointHit process is finished.") + + +class CommandListener: + """Event listener.""" + + def __init__(self, cache_store): + self._cache_store = cache_store + self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) + # the next position of command queue to be queried + self._pos = '0' + self._is_waiting = Event() + + def start(self): + """Start event listener.""" + self._pos = '0' + self._is_waiting.set() + + def stop(self): + """Stop event listener.""" + # stop waiting for new user commands but can still get old commands. + self._is_waiting.clear() + + def has_new_command(self): + """Check if there is new command in command queue.""" + return self._cache_store.has_command(self._pos) + + def get_next_command(self): + """Get next command.""" + event = None + while event is None and self.has_new_command(): + self._pos, event = self._cache_store.get_command(self._pos) + log.debug("Deal with old %s-th command:\n%s.", self._pos, event) + if event is None: + event = self._wait_for_next_command() + return event + + def _wait_for_next_command(self): + """ + Wait for next command. + + Returns: + EventReply, the command event. + """ + if not self._is_waiting.is_set(): + self._metadata_stream.state = ServerStatus.PENDING.value + return None + log.info("Start to wait for command.") + if self._metadata_stream.state != ServerStatus.WAITING.value: + self._metadata_stream.state = ServerStatus.WAITING.value + self._cache_store.put_data(self._metadata_stream.get()) + log.debug("Wait for %s-th command", self._pos) + event = None + while event is None and self._is_waiting.is_set(): + self._pos, event = self._cache_store.get_command(self._pos) + return event + + +def convert_watchpointhit(watchpointhit): + """Convert watchpointhit object to dict.""" + parameters = watchpointhit.parameters + param_list = [] + for param in parameters: + param_dict = convert_param(param) + param_list.append(param_dict) + watchpointhit_dict = {'condition': watchpointhit.condition, + 'device_id': watchpointhit.device_id, + 'error_code': watchpointhit.error_code, + 'name': watchpointhit.name, + 'parameters': param_list, + 'slot': watchpointhit.slot, + 'watchpoint_id': watchpointhit.watchpoint_id} + return watchpointhit_dict + + +def convert_param(param): + """Convert parameter object to dict""" + param_dict = {'actual_value': param.actual_value, + 'disabled': param.disabled, + 'hit': param.hit, + 'name': param.name, + 'value': param.value} + return param_dict diff --git a/mindinsight/debugger/debugger_services/debugger_online_server.py b/mindinsight/debugger/debugger_services/debugger_online_server.py new file mode 100644 index 00000000..6d1cfde6 --- /dev/null +++ b/mindinsight/debugger/debugger_services/debugger_online_server.py @@ -0,0 +1,58 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Debugger Online server.""" +from concurrent import futures + +import grpc +from mindinsight.debugger.common.log import LOGGER as log +from mindinsight.conf import settings +from mindinsight.debugger.debugger_services.debugger_grpc_server import DebuggerGrpcServer +from mindinsight.debugger.debugger_services.debugger_server_base import DebuggerServerBase +from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base + + +def get_debugger_hostname(): + """Get hostname for online debugger server.""" + grpc_port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else 50051 + host = settings.HOST if hasattr(settings, 'HOST') else '[::]' + hostname = "{}:{}".format(host, grpc_port) + return hostname + + +class DebuggerOnlineServer(DebuggerServerBase): + """Debugger Online Server.""" + + def __init__(self, cache_store, context): + super(DebuggerOnlineServer, self).__init__(cache_store, context) + self._grpc_server_manager = self.get_grpc_server_manager() + + def run(self): + self._grpc_server_manager.start() + log.info("Start grpc server %s", self._context.hostname) + self._grpc_server_manager.wait_for_termination() + + def get_grpc_server_manager(self): + """Get grpc server instance according to hostname.""" + if self._context.hostname is None: + self._context.hostname = get_debugger_hostname() + grpc_server = DebuggerGrpcServer(self._cache_store, None) + grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + grpc_server_base.add_EventListenerServicer_to_server(grpc_server, grpc_server_manager) + grpc_server_manager.add_insecure_port(self._context.hostname) + return grpc_server_manager + + def stop(self): + self._grpc_server_manager.stop(grace=None) + self.join() diff --git a/mindinsight/debugger/debugger_services/debugger_server_base.py b/mindinsight/debugger/debugger_services/debugger_server_base.py new file mode 100644 index 00000000..cfe17c7b --- /dev/null +++ b/mindinsight/debugger/debugger_services/debugger_server_base.py @@ -0,0 +1,58 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""DebuggerServerBase.""" +import threading +from abc import abstractmethod +from functools import wraps + +from mindinsight.debugger.common.exceptions.exceptions import DebuggerServerRunningError +from mindinsight.debugger.common.log import LOGGER as log + + +def debugger_server_wrap(func): + """Wrapper for catch exception.""" + @wraps(func) + def record_log(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as err: + log.exception(err) + raise DebuggerServerRunningError(str(err)) + return record_log + + +class DebuggerServerBase(threading.Thread): + """ + Debugger Server Base. + + Args: + cache_store (DebuggerCacheStore): Cache store for debugger server. + context (DebuggerServerContext): Context for initialize debugger server. + """ + + def __init__(self, cache_store, context): + super(DebuggerServerBase, self).__init__() + self._cache_store = cache_store + self._context = context + + @abstractmethod + @debugger_server_wrap + def run(self): + """Function that should be called when thread started.""" + + @abstractmethod + @debugger_server_wrap + def stop(self): + """Stop debugger server.""" diff --git a/mindinsight/debugger/debugger_services/debugger_server_factory.py b/mindinsight/debugger/debugger_services/debugger_server_factory.py new file mode 100644 index 00000000..ce805523 --- /dev/null +++ b/mindinsight/debugger/debugger_services/debugger_server_factory.py @@ -0,0 +1,92 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Debugger server factory.""" +import threading + +from mindinsight.debugger.common.utils import DebuggerServerMode +from mindinsight.debugger.debugger_services.debugger_offline_server import DebuggerOfflineServer +from mindinsight.debugger.debugger_services.debugger_online_server import DebuggerOnlineServer + + +class DebuggerServerFactory: + """Create debugger server according to debugger mode.""" + + _lock = threading.Lock() + _instance = None + + def __init__(self): + self._server_map = { + DebuggerServerMode.ONLINE.value: DebuggerOnlineServer, + DebuggerServerMode.OFFLINE.value: DebuggerOfflineServer + } + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls, *args, **kwargs) + return cls._instance + + def get_debugger_server(self, cache_store, context): + """ + Get debugger server according to debugger_context and cache_store. + + Args: + cache_store (DebuggerCacheStore): Cache store for debugger server. + context (DebuggerServerContext): Context for initialize debugger server. + + Returns: + DebuggerServerBase, Debugger server object. + """ + dbg_server = None + dbg_server_class = self._server_map.get(context.dbg_mode) + if dbg_server_class: + dbg_server = dbg_server_class(cache_store, context) + return dbg_server + + +class DebuggerServerContext: + """ + Debugger server context. + + Args: + dbg_mode (str): The debugger mode. Optional: `online` or `offline`. + train_job (str): The relative directory of debugger dump data for one training. + Used only when dbg_mode is `offline.` + dbg_dir (str): The base directory of debugger dump data for one training. + Used only when dbg_mode is `offline.` + hostname (str): The hostname used for online debugger server. + Used only when dbg_mode is `online.` + """ + def __init__(self, dbg_mode, train_job=None, dbg_dir=None, hostname=None): + self._dbg_mode = dbg_mode + self._train_job = train_job + self._dbg_dir = dbg_dir + self.hostname = hostname + + @property + def dbg_mode(self): + """Property of debugger mode.""" + return self._dbg_mode + + @property + def dbg_dir(self): + """Property of debugger mode.""" + return self._dbg_dir + + @property + def train_job(self): + """The property of train job.""" + return self._train_job diff --git a/mindinsight/debugger/debugger_server.py b/mindinsight/debugger/debugger_session.py similarity index 86% rename from mindinsight/debugger/debugger_server.py rename to mindinsight/debugger/debugger_session.py index c42ba3a7..6429aa8e 100644 --- a/mindinsight/debugger/debugger_server.py +++ b/mindinsight/debugger/debugger_session.py @@ -13,17 +13,8 @@ # limitations under the License. # ============================================================================ """Implement the debugger server.""" -import signal -from concurrent import futures from functools import wraps -from threading import Thread -import grpc - -from mindinsight.debugger.conditionmgr.condition import ConditionContext -from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr -from mindinsight.debugger.conditionmgr.recommender import recommend_watchpoints -from mindinsight.conf import settings from mindinsight.datavisual.data_transform.graph import NodeTypeEnum from mindinsight.datavisual.utils.tools import to_float from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ @@ -32,9 +23,11 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue from mindinsight.debugger.common.log import LOGGER as log from mindinsight.debugger.common.utils import ServerStatus, \ create_view_event_from_tensor_basic_info, Streams +from mindinsight.debugger.conditionmgr.condition import ConditionContext +from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr +from mindinsight.debugger.conditionmgr.recommender import recommend_watchpoints from mindinsight.debugger.debugger_cache import DebuggerCache -from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer -from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base +from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerFactory from mindinsight.debugger.stream_operator.tensor_detail_info import TensorDetailInfo from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator from mindinsight.debugger.stream_operator.watchpoint_operator import WatchpointOperator @@ -57,25 +50,29 @@ def try_except(func): return send_latest_metadata -class DebuggerServer: +class DebuggerSession: """The server manager of debugger.""" - def __init__(self): + def __init__(self, context): self.condition_mgr = ConditionMgr() self.cache_store = DebuggerCache() - self.grpc_server = DebuggerGrpcServer(self.cache_store, self.condition_mgr) - self.grpc_server_manager = None - self.back_server = None + self.context = context + self.back_server = DebuggerServerFactory().get_debugger_server(self.cache_store, context) + + @property + def train_job(self): + """The property of train job.""" + return self.context.train_job - def get_condition_collections(self, train_id): + def get_condition_collections(self, train_id=""): """Get default condition_collections""" metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step) log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend) return self.condition_mgr.get_all_collections(condition_context) - def set_recommended_watch_points(self, set_recommended, train_id): - """set recommended watch points.""" + def set_recommended_watch_points(self, set_recommended, train_id=""): + """Set recommended watch points.""" if not isinstance(set_recommended, bool): log.error("Bool param should be given for set_recommended") raise DebuggerParamValueError("Bool param should be given.") @@ -97,38 +94,28 @@ class DebuggerServer: def _add_recommended_watchpoints(self, condition_context): """Add predefined watchpoints.""" log.debug("Add predefined watchpoints.") - graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) - watchpoints = recommend_watchpoints(self.condition_mgr, graph_stream, condition_context) + multi_card_graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + watchpoints = recommend_watchpoints(self.condition_mgr, multi_card_graph_stream, condition_context) watch_point_stream_handler = self.cache_store.get_stream_handler(Streams.WATCHPOINT) + device_stream = self.cache_store.get_stream_handler(Streams.DEVICE) watch_points_ids = [] for watchpoint in watchpoints: watch_points_id = watch_point_stream_handler.create_watchpoint( watch_condition=watchpoint.get_watch_condition_dict(), watch_nodes=watchpoint.watch_nodes, name=watchpoint.name, - condition_mgr=self.condition_mgr + condition_mgr=self.condition_mgr, + device_amount=device_stream.device_amount ) watch_points_ids.append(watch_points_id) return watch_points_ids def start(self): """Start server.""" - grpc_port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else 50051 - host = settings.HOST if hasattr(settings, 'HOST') else '[::]' - hostname = "{}:{}".format(host, grpc_port) - # initialize a grpc server - grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - grpc_server_base.add_EventListenerServicer_to_server(self.grpc_server, grpc_server_manager) - grpc_server_manager.add_insecure_port(hostname) - grpc_server_manager.start() - my_server_thread = Thread(target=grpc_server_manager.wait_for_termination) - # start grpc server - my_server_thread.start() - self.back_server = my_server_thread - self.grpc_server_manager = grpc_server_manager + self.back_server.start() # register stop server handler - signal.signal(signal.SIGINT, self._stop_handler) - log.info("Start grpc server %s", hostname) + #signal.signal(signal.SIGINT, self._stop_handler) + log.info("Start debugger backend server.") def _stop_handler(self, signum, frame): """Register stop server handler.""" @@ -139,8 +126,7 @@ class DebuggerServer: """Stop debugger server.""" log.info("Send terminate info to client.") self.control({'mode': 'terminate'}) - self.grpc_server_manager.stop(grace=None) - self.back_server.join() + self.back_server.stop() log.info("Stop debugger server.") def poll_data(self, pos): @@ -172,6 +158,7 @@ class DebuggerServer: - graph_name (str): The graph name. - watch_point_id (int): The id of watchpoint. Default: 0. - node_category (str): The node_category. Default: None + - rank_id (int): The id of rank. Default: 0. Returns: dict, the searched nodes. @@ -179,19 +166,20 @@ class DebuggerServer: log.info("receive search request with filter_condition: %s", filter_condition) # validate watchpoint id watch_point_id = filter_condition.pop('watch_point_id', 0) + rank_id = filter_condition.pop('rank_id', 0) watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) watchpoint_stream.validate_watchpoint_id(watch_point_id) # validate and update graph name - graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name')) filter_condition['graph_name'] = graph_name # get searched graph graph = graph_stream.search_nodes(filter_condition) # add watched label to graph - watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, graph_name) + watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, graph_name, rank_id) return graph - def tensor_comparisons(self, name, shape, detail='data', tolerance='0'): + def tensor_comparisons(self, name, shape, detail='data', tolerance='0', rank_id=0): """ Get tensor comparisons data for given name, detail, shape and tolerance. @@ -202,6 +190,7 @@ class DebuggerServer: shape (str): Specify concrete dimensions of shape. tolerance (str): Specify tolerance of difference between current step tensor and previous step tensor. Default value is 0. + rank_id (int): The id of rank. Default: 0. Raises: DebuggerParamValueError, If node type is not parameter or value of detail is not support. @@ -220,9 +209,10 @@ class DebuggerServer: parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR) node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name) tolerance = to_float(tolerance, 'tolerance') - tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) + tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id) + cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step if node_type == NodeTypeEnum.PARAMETER.value: - reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance) + reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance, cur_step) else: raise DebuggerParamValueError( "The node type must be parameter, but got {}.".format(node_type)) @@ -270,10 +260,18 @@ class DebuggerServer: self.cache_store.clean_data() log.info("Clean data queue cache when retrieve all request.") result = {} - for stream in [Streams.METADATA, Streams.GRAPH]: + for stream in [Streams.METADATA, Streams.GRAPH, Streams.DEVICE]: sub_res = self.cache_store.get_stream_handler(stream).get() result.update(sub_res) + devices = result['devices'] + if not devices: + graph = result['graph'] + metadata = result['metadata'] + device = {'rank_id': 0, 'server_ip': metadata.get('ip', 'localhost'), + 'device_id': metadata.get('device_name', ''), + 'graph_names': graph.get('graph_names', [])} + devices.append(device) sub_res = self._hide_parameters_for_ui() result.update(sub_res) @@ -298,7 +296,8 @@ class DebuggerServer: log.debug("Retrieve node %s.", filter_condition) # validate node name node_name = filter_condition.get('name') - graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + rank_id = filter_condition.get('rank_id', 0) + graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name')) if node_name: # validate node name @@ -325,24 +324,27 @@ class DebuggerServer: dict, reply with graph. """ # validate watch_point_id + rank_id = filter_condition.get('rank_id', 0) watch_point_id = filter_condition.get('watch_point_id', 0) watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) watchpoint_stream.validate_watchpoint_id(watch_point_id) # get graph - graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) reply = graph_stream.get(filter_condition) graph = reply.get('graph') # add watched label to graph - watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, filter_condition.get('graph_name')) + watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, filter_condition.get('graph_name'), + rank_id) return reply - def retrieve_tensor_history(self, node_name, graph_name=None): + def retrieve_tensor_history(self, node_name, graph_name=None, rank_id=0): """ Retrieve tensor history for leaf node. Args: node_name (str): The name of leaf node. graph_name (str): The graph name. Default: None. + rank_id (int): The id of rank. Default: 0. Returns: dict, the tensor history and metadata. @@ -352,34 +354,34 @@ class DebuggerServer: if metadata_stream.state == ServerStatus.PENDING.value: log.info("The backend is in pending status.") return metadata_stream.get(['state', 'step']) - res = self._get_tensor_history(node_name, graph_name) + res = self._get_tensor_history(node_name, graph_name, rank_id) return res - def _get_tensor_history(self, node_name, graph_name=None): + def _get_tensor_history(self, node_name, graph_name=None, rank_id=0): """ Get tensor history for single node. Args: node_name (str): The name of leaf node. graph_name (str): The graph name. Default: None. + rank_id (int): The id of rank. Default: 0. Returns: dict, the tensor history and metadata. """ # get basic tensor history - graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) tensor_history = graph_stream.get_tensor_history(node_name, graph_name) # add tensor value for tensor history - self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name) + self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name, rank_id) # add hit label for tensor history - watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) - watchpoint_hit_stream.update_tensor_history(tensor_history) + self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).update_tensor_history(tensor_history, rank_id) # add metadata metadata = self.cache_store.get_stream_handler(Streams.METADATA).get(['step']) tensor_history.update(metadata) return tensor_history - def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name): + def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name, rank_id): """ Add tensor value for_tensor_history and send ViewCMD if tensor value missed. @@ -387,48 +389,53 @@ class DebuggerServer: tensor_history (list[dict]): A list of tensor info, including name and type. node_name (str): The UI node name. graph_name (str): The graph name. Default: None. + rank_id (int): The id of rank. Default: 0. Returns: dict, the tensor info. """ - tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) - missed_tensors = tensor_stream.update_tensor_history(tensor_history) + tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id) + cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step + missed_tensors = tensor_stream.update_tensor_history(tensor_history, cur_step) if missed_tensors: view_cmd = create_view_event_from_tensor_basic_info(missed_tensors) - self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name}) + self.cache_store.put_command( + {'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name, 'rank_id': rank_id}) log.debug("Send view cmd.") - def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False): + def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False, rank_id=0): """Retrieve the tensor value.""" log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape) self.validate_tensor_param(name, detail) # Limit to query max two dimensions for tensor in table view. parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR) - node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name, graph_name) + node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name, graph_name, rank_id) reply = self.cache_store.get_stream_handler(Streams.TENSOR).get( {'name': tensor_name, 'node_type': node_type, 'shape': parsed_shape, - 'prev': prev} + 'prev': prev}, + rank_id ) reply['tensor_value']['name'] = name return reply - def _get_tensor_name_and_type_by_ui_name(self, name, graph_name=None): + def _get_tensor_name_and_type_by_ui_name(self, name, graph_name=None, rank_id=0): """ Get inner tensor name and type by UI name. Args: name (str): Node name shown in UI. graph_name (Union[str, None]): The graph name, default is: None. + rank_id (int): The id of rank. Default: 0. Returns: str, full name of tensor. str, node type of tensor. """ node_name, slot = name.rsplit(':', 1) - graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) graph_name = graph_name if graph_name else graph_stream.get_graph_id_by_name(node_name) node_type = graph_stream.get_node_type(node_name, graph_name) full_name = graph_stream.get_full_name(node_name, graph_name) @@ -483,6 +490,7 @@ class DebuggerServer: - offset (int): The offset of current page. - node_name (str): The retrieved node name. - graph_name (str): The retrieved graph name. + - rank_id (int): The rank id. Returns: dict, watch point list or relative graph. @@ -496,7 +504,13 @@ class DebuggerServer: log.info("The backend is in pending status.") return metadata_stream.get() - reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).group_by(group_condition) + rank_id = group_condition.pop('rank_id', 0) + reply = {} + multi_watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) + if multi_watchpoint_hit_stream.check_rank_id(rank_id): + watchpoint_hit_stream = multi_watchpoint_hit_stream.get_hit_handler_by_rank_id(rank_id) + reply = watchpoint_hit_stream.group_by(group_condition) + reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable() return reply @@ -591,40 +605,6 @@ class DebuggerServer: training_controller.validate_mode(mode) return training_controller.control(mode, params) - def retrieve_node_by_bfs(self, node_name, graph_name=None, ascend=False): - """ - Get the graph of the next node according to node_name. - - Args: - node_name (str): The name of current chosen leaf node. - graph_name (str): The graph name. - ascend (bool): If True, traverse the input nodes; - If False, traverse the output nodes. Default is True. - - Returns: - dict, the next node information. - """ - log.info("Retrieve node <%s> by bfs, `ascend` is :%s", - node_name, ascend) - reply = {} - graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) - graph_name = graph_stream.validate_graph_name(graph_name) - next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend) - # no next node - if next_node_name is None: - return reply - # add graph and tensor history for next node - filter_condition = { - 'name': next_node_name, - 'graph_name': graph_name, - 'single_node': True - } - search_graph = self._get_nodes_info(filter_condition) - reply = {'name': next_node_name} - reply.update(search_graph) - - return reply - @try_except def recheck(self): """ @@ -635,13 +615,14 @@ class DebuggerServer: """ return TrainingControlOperator(self.cache_store).recheck() - def retrieve_tensor_graph(self, tensor_name, graph_name): + def retrieve_tensor_graph(self, tensor_name, graph_name, rank_id=0): """ Retrieve tensor graph. Args: tensor_name (str): The tensor name from UI. graph_name (str): The graph name. + rank_id (int): The id of rank. Default: 0. Returns: dict, tensor graph object. @@ -650,16 +631,17 @@ class DebuggerServer: log.error("Failed to get tensor graph the MindSpore is not in waiting state.") raise DebuggerTensorGraphError log.info("Retrieve tensor graph for %s from %s", tensor_name, graph_name) - tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name) + tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name, rank_id) return tensor_graph_ops - def retrieve_tensor_hits(self, tensor_name, graph_name): + def retrieve_tensor_hits(self, tensor_name, graph_name, rank_id=0): """ Retrieve tensor hit information. Args: tensor_name (str): The tensor name from UI. graph_name (str): The graph name. + rank_id (int): The id of rank. Default: 0. Returns: dict, tensor hit info. @@ -668,7 +650,7 @@ class DebuggerServer: log.error("Failed to get tensor hits as the MindSpore is not in waiting state.") raise DebuggerTensorHitError log.info("Retrieve tensor hits for %s from %s", tensor_name, graph_name) - watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name) + watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name, rank_id) return {'watch_points': watch_points} def _hide_parameters_for_ui(self): diff --git a/mindinsight/debugger/proto/debug_grpc.proto b/mindinsight/debugger/proto/debug_grpc.proto index a364391f..71efd984 100644 --- a/mindinsight/debugger/proto/debug_grpc.proto +++ b/mindinsight/debugger/proto/debug_grpc.proto @@ -122,6 +122,9 @@ message WatchCondition { message WatchNode { string node_name = 1; string node_type = 2; + string graph_name = 3; + int32 rank_id = 4; + int32 device_id = 5; } message WatchpointHit { diff --git a/mindinsight/debugger/proto/debug_grpc_pb2.py b/mindinsight/debugger/proto/debug_grpc_pb2.py index c15ae513..f336142e 100644 --- a/mindinsight/debugger/proto/debug_grpc_pb2.py +++ b/mindinsight/debugger/proto/debug_grpc_pb2.py @@ -2,8 +2,6 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # source: mindinsight/debugger/proto/debug_grpc.proto -import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection @@ -21,7 +19,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='debugger', syntax='proto3', serialized_options=None, - serialized_pb=_b('\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"\x92\x01\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\x12\x11\n\tgraph_num\x18\x06 \x01(\x05\x12\x12\n\nms_version\x18\x07 \x01(\t\")\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\"\x87\x02\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\x12\x19\n\x0fversion_matched\x18\x06 \x01(\x08H\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\x81\x04\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x32\n\x06params\x18\x04 \x03(\x0b\x32\".debugger.WatchCondition.Parameter\x1a]\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64isabled\x18\x02 \x01(\x08\x12\r\n\x05value\x18\x03 \x01(\x01\x12\x0b\n\x03hit\x18\x04 \x01(\x08\x12\x14\n\x0c\x61\x63tual_value\x18\x05 \x01(\x01\"\x95\x02\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x0c\n\x08overflow\x10\x02\x12\t\n\x05sd_gt\x10\x0b\x12\t\n\x05sd_lt\x10\x0c\x12\x1b\n\x17tensor_general_overflow\x10\r\x12\x19\n\x15tensor_initialization\x10\x0e\x12\x14\n\x10tensor_too_large\x10\x0f\x12\x14\n\x10tensor_too_small\x10\x10\x12\x13\n\x0ftensor_all_zero\x10\x11\x12\x1b\n\x17tensor_change_too_large\x10\x12\x12\x1b\n\x17tensor_change_too_small\x10\x13\x12\x16\n\x12tensor_not_changed\x10\x14\x12\x10\n\x0ctensor_range\x10\x15\"1\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\"\x89\x01\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x12\x12\n\nerror_code\x18\x04 \x01(\x05\x32\x81\x03\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x12<\n\x0fSendMultiGraphs\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3') + serialized_pb=b'\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"\x92\x01\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\x12\x11\n\tgraph_num\x18\x06 \x01(\x05\x12\x12\n\nms_version\x18\x07 \x01(\t\")\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\"\x87\x02\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\x12\x19\n\x0fversion_matched\x18\x06 \x01(\x08H\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\x81\x04\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x32\n\x06params\x18\x04 \x03(\x0b\x32\".debugger.WatchCondition.Parameter\x1a]\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64isabled\x18\x02 \x01(\x08\x12\r\n\x05value\x18\x03 \x01(\x01\x12\x0b\n\x03hit\x18\x04 \x01(\x08\x12\x14\n\x0c\x61\x63tual_value\x18\x05 \x01(\x01\"\x95\x02\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x0c\n\x08overflow\x10\x02\x12\t\n\x05sd_gt\x10\x0b\x12\t\n\x05sd_lt\x10\x0c\x12\x1b\n\x17tensor_general_overflow\x10\r\x12\x19\n\x15tensor_initialization\x10\x0e\x12\x14\n\x10tensor_too_large\x10\x0f\x12\x14\n\x10tensor_too_small\x10\x10\x12\x13\n\x0ftensor_all_zero\x10\x11\x12\x1b\n\x17tensor_change_too_large\x10\x12\x12\x1b\n\x17tensor_change_too_small\x10\x13\x12\x16\n\x12tensor_not_changed\x10\x14\x12\x10\n\x0ctensor_range\x10\x15\"i\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\x12\x12\n\ngraph_name\x18\x03 \x01(\t\x12\x0f\n\x07rank_id\x18\x04 \x01(\x05\x12\x11\n\tdevice_id\x18\x05 \x01(\x05\"\x89\x01\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x12\x12\n\nerror_code\x18\x04 \x01(\x05\x32\x81\x03\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x12<\n\x0fSendMultiGraphs\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3' , dependencies=[mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.DESCRIPTOR,]) @@ -130,7 +128,7 @@ _METADATA = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='device_name', full_name='debugger.Metadata.device_name', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -144,14 +142,14 @@ _METADATA = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='backend', full_name='debugger.Metadata.backend', index=2, number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( name='cur_node', full_name='debugger.Metadata.cur_node', index=3, number=4, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -172,7 +170,7 @@ _METADATA = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='ms_version', full_name='debugger.Metadata.ms_version', index=6, number=7, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -203,7 +201,7 @@ _CHUNK = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='buffer', full_name='debugger.Chunk.buffer', index=0, number=1, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=_b(""), + has_default_value=False, default_value=b"", message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -311,7 +309,7 @@ _RUNCMD = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='run_level', full_name='debugger.RunCMD.run_level', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -325,7 +323,7 @@ _RUNCMD = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='node_name', full_name='debugger.RunCMD.node_name', index=2, number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -442,7 +440,7 @@ _WATCHCONDITION_PARAMETER = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='name', full_name='debugger.WatchCondition.Parameter.name', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -546,14 +544,35 @@ _WATCHNODE = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='node_name', full_name='debugger.WatchNode.node_name', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( name='node_type', full_name='debugger.WatchNode.node_type', index=1, number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='graph_name', full_name='debugger.WatchNode.graph_name', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='rank_id', full_name='debugger.WatchNode.rank_id', index=3, + number=4, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='device_id', full_name='debugger.WatchNode.device_id', index=4, + number=5, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -570,7 +589,7 @@ _WATCHNODE = _descriptor.Descriptor( oneofs=[ ], serialized_start=1335, - serialized_end=1384, + serialized_end=1440, ) @@ -621,8 +640,8 @@ _WATCHPOINTHIT = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=1387, - serialized_end=1524, + serialized_start=1443, + serialized_end=1580, ) _EVENTREPLY.fields_by_name['status'].enum_type = _EVENTREPLY_STATUS @@ -750,8 +769,8 @@ _EVENTLISTENER = _descriptor.ServiceDescriptor( file=DESCRIPTOR, index=0, serialized_options=None, - serialized_start=1527, - serialized_end=1912, + serialized_start=1583, + serialized_end=1968, methods=[ _descriptor.MethodDescriptor( name='WaitCMD', diff --git a/mindinsight/debugger/proto/debug_grpc_pb2_grpc.py b/mindinsight/debugger/proto/debug_grpc_pb2_grpc.py index 8c5940c3..8cfa23c5 100644 --- a/mindinsight/debugger/proto/debug_grpc_pb2_grpc.py +++ b/mindinsight/debugger/proto/debug_grpc_pb2_grpc.py @@ -1,5 +1,4 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" import grpc from mindinsight.debugger.proto import debug_grpc_pb2 as mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2 @@ -7,7 +6,7 @@ from mindinsight.debugger.proto import ms_graph_pb2 as mindinsight_dot_debugger_ class EventListenerStub(object): - """Missing associated documentation comment in .proto file.""" + """Missing associated documentation comment in .proto file""" def __init__(self, channel): """Constructor. @@ -48,40 +47,40 @@ class EventListenerStub(object): class EventListenerServicer(object): - """Missing associated documentation comment in .proto file.""" + """Missing associated documentation comment in .proto file""" def WaitCMD(self, request, context): - """Missing associated documentation comment in .proto file.""" + """Missing associated documentation comment in .proto file""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') def SendMetadata(self, request, context): - """Missing associated documentation comment in .proto file.""" + """Missing associated documentation comment in .proto file""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') def SendGraph(self, request_iterator, context): - """Missing associated documentation comment in .proto file.""" + """Missing associated documentation comment in .proto file""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') def SendTensors(self, request_iterator, context): - """Missing associated documentation comment in .proto file.""" + """Missing associated documentation comment in .proto file""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') def SendWatchpointHits(self, request_iterator, context): - """Missing associated documentation comment in .proto file.""" + """Missing associated documentation comment in .proto file""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') def SendMultiGraphs(self, request_iterator, context): - """Missing associated documentation comment in .proto file.""" + """Missing associated documentation comment in .proto file""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') @@ -127,7 +126,7 @@ def add_EventListenerServicer_to_server(servicer, server): # This class is part of an EXPERIMENTAL API. class EventListener(object): - """Missing associated documentation comment in .proto file.""" + """Missing associated documentation comment in .proto file""" @staticmethod def WaitCMD(request, @@ -135,7 +134,6 @@ class EventListener(object): options=(), channel_credentials=None, call_credentials=None, - insecure=False, compression=None, wait_for_ready=None, timeout=None, @@ -144,7 +142,7 @@ class EventListener(object): mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendMetadata(request, @@ -152,7 +150,6 @@ class EventListener(object): options=(), channel_credentials=None, call_credentials=None, - insecure=False, compression=None, wait_for_ready=None, timeout=None, @@ -161,7 +158,7 @@ class EventListener(object): mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendGraph(request_iterator, @@ -169,7 +166,6 @@ class EventListener(object): options=(), channel_credentials=None, call_credentials=None, - insecure=False, compression=None, wait_for_ready=None, timeout=None, @@ -178,7 +174,7 @@ class EventListener(object): mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString, mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendTensors(request_iterator, @@ -186,7 +182,6 @@ class EventListener(object): options=(), channel_credentials=None, call_credentials=None, - insecure=False, compression=None, wait_for_ready=None, timeout=None, @@ -195,7 +190,7 @@ class EventListener(object): mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.SerializeToString, mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendWatchpointHits(request_iterator, @@ -203,7 +198,6 @@ class EventListener(object): options=(), channel_credentials=None, call_credentials=None, - insecure=False, compression=None, wait_for_ready=None, timeout=None, @@ -212,7 +206,7 @@ class EventListener(object): mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.SerializeToString, mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def SendMultiGraphs(request_iterator, @@ -220,7 +214,6 @@ class EventListener(object): options=(), channel_credentials=None, call_credentials=None, - insecure=False, compression=None, wait_for_ready=None, timeout=None, @@ -229,4 +222,4 @@ class EventListener(object): mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString, mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/mindinsight/debugger/proto/ms_graph.proto b/mindinsight/debugger/proto/ms_graph.proto index 4c30c30f..9e9f70a0 100644 --- a/mindinsight/debugger/proto/ms_graph.proto +++ b/mindinsight/debugger/proto/ms_graph.proto @@ -229,6 +229,9 @@ message NodeProto { // full name with scope optional string full_name = 8; + + // The corresponding source code for this node. + optional string source_address = 9; } // Models diff --git a/mindinsight/debugger/proto/ms_graph_pb2.py b/mindinsight/debugger/proto/ms_graph_pb2.py index d2a791d1..9b693691 100644 --- a/mindinsight/debugger/proto/ms_graph_pb2.py +++ b/mindinsight/debugger/proto/ms_graph_pb2.py @@ -2,8 +2,6 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # source: mindinsight/debugger/proto/ms_graph.proto -import sys -_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message @@ -21,7 +19,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='debugger', syntax='proto2', serialized_options=None, - serialized_pb=_b('\n)mindinsight/debugger/proto/ms_graph.proto\x12\x08\x64\x65\x62ugger\"\xab\x04\n\nValueProto\x12!\n\x05\x64type\x18\x01 \x01(\x0e\x32\x12.debugger.DataType\x12\x10\n\x08\x62ool_val\x18\x02 \x01(\x08\x12\x0f\n\x07int_val\x18\x03 \x01(\x03\x12\x10\n\x08uint_val\x18\x04 \x01(\x04\x12\x11\n\tfloat_val\x18\x05 \x01(\x02\x12\x12\n\ndouble_val\x18\x06 \x01(\x01\x12\x0f\n\x07str_val\x18\x07 \x01(\t\x12)\n\ntensor_val\x18\x08 \x01(\x0b\x32\x15.debugger.TensorProto\x12#\n\x05graph\x18\t \x01(\x0b\x32\x14.debugger.GraphProto\x12\x11\n\tbool_vals\x18\n \x03(\x08\x12\x10\n\x08int_vals\x18\x0b \x03(\x03\x12\x11\n\tuint_vals\x18\x0c \x03(\x04\x12\x12\n\nfloat_vals\x18\r \x03(\x02\x12\x13\n\x0b\x64ouble_vals\x18\x0e \x03(\x01\x12\x10\n\x08str_vals\x18\x0f \x03(\t\x12*\n\x0btensor_vals\x18\x10 \x03(\x0b\x32\x15.debugger.TensorProto\x12$\n\x06graphs\x18\x11 \x03(\x0b\x32\x14.debugger.GraphProto\x12$\n\x06values\x18\x12 \x03(\x0b\x32\x14.debugger.ValueProto\x12+\n\x08\x64ict_val\x18\x13 \x03(\x0b\x32\x19.debugger.NamedValueProto\x12%\n\x08type_val\x18\x14 \x01(\x0b\x32\x13.debugger.TypeProto\"C\n\x0e\x41ttributeProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.debugger.ValueProto\"C\n\x0fNamedValueProto\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.debugger.ValueProto\"n\n\x10TensorShapeProto\x12\x31\n\x03\x64im\x18\x01 \x03(\x0b\x32$.debugger.TensorShapeProto.Dimension\x1a\'\n\tDimension\x12\x0c\n\x04size\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\"\xb6\x02\n\tTypeProto\x12%\n\tdata_type\x18\x01 \x01(\x0e\x32\x12.debugger.DataType\x12\x31\n\x0btensor_type\x18\x02 \x01(\x0b\x32\x1a.debugger.TypeProto.TensorH\x00\x12\x35\n\rsequence_type\x18\x03 \x01(\x0b\x32\x1c.debugger.TypeProto.SequenceH\x00\x1aZ\n\x06Tensor\x12%\n\telem_type\x18\x01 \x01(\x0e\x32\x12.debugger.DataType\x12)\n\x05shape\x18\x02 \x01(\x0b\x32\x1a.debugger.TensorShapeProto\x1a\x33\n\x08Sequence\x12\'\n\nelem_types\x18\x01 \x03(\x0b\x32\x13.debugger.TypeProtoB\x07\n\x05value\"l\n\x0eParameterProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12!\n\x04type\x18\x02 \x01(\x0b\x32\x13.debugger.TypeProto\x12)\n\x0b\x64\x65\x66\x61ult_val\x18\x03 \x01(\x0b\x32\x14.debugger.ValueProto\">\n\x0bOutputProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12!\n\x04type\x18\x02 \x01(\x0b\x32\x13.debugger.TypeProto\"t\n\nInputProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12+\n\x04type\x18\x02 \x01(\x0e\x32\x1d.debugger.InputProto.EdgeType\"+\n\x08\x45\x64geType\x12\r\n\tDATA_EDGE\x10\x00\x12\x10\n\x0c\x43ONTROL_EDGE\x10\x01\"\xda\x01\n\tNodeProto\x12#\n\x05input\x18\x01 \x03(\x0b\x32\x14.debugger.InputProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0f\n\x07op_type\x18\x03 \x01(\t\x12\r\n\x05scope\x18\x04 \x01(\t\x12+\n\tattribute\x18\x05 \x03(\x0b\x32\x18.debugger.AttributeProto\x12(\n\x0boutput_type\x18\x06 \x01(\x0b\x32\x13.debugger.TypeProto\x12\x10\n\x08output_i\x18\x07 \x01(\x04\x12\x11\n\tfull_name\x18\x08 \x01(\t\"\xa4\x01\n\nModelProto\x12\x12\n\nir_version\x18\x01 \x01(\x03\x12\x0e\n\x06\x64omain\x18\x02 \x01(\t\x12\x15\n\rmodel_version\x18\x03 \x01(\x03\x12#\n\x05graph\x18\x04 \x01(\x0b\x32\x14.debugger.GraphProto\x12\x36\n\x12metadata_operators\x18\x05 \x01(\x0b\x32\x1a.debugger.OperatorSetProto\"?\n\rOperatorProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06\x63onfig\x18\x02 \x01(\x0c\x12\x10\n\x08obj_info\x18\x03 \x01(\x0c\"O\n\x10OperatorSetProto\x12*\n\toperators\x18\x01 \x03(\x0b\x32\x17.debugger.OperatorProto\x12\x0f\n\x07summary\x18\x02 \x01(\t\"\xc2\x01\n\nGraphProto\x12!\n\x04node\x18\x01 \x03(\x0b\x32\x13.debugger.NodeProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12,\n\nparameters\x18\x03 \x03(\x0b\x32\x18.debugger.ParameterProto\x12&\n\x07outputs\x18\x04 \x03(\x0b\x32\x15.debugger.OutputProto\x12-\n\nconst_vals\x18\x05 \x03(\x0b\x32\x19.debugger.NamedValueProto\"\xad\x01\n\x0bTensorProto\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x0c\n\x04slot\x18\x02 \x01(\t\x12\x16\n\x0etensor_content\x18\x03 \x01(\x0c\x12\x0c\n\x04\x64ims\x18\x04 \x03(\x03\x12%\n\tdata_type\x18\x05 \x01(\x0e\x32\x12.debugger.DataType\x12\x10\n\x08\x66inished\x18\x06 \x01(\x08\x12\x0c\n\x04iter\x18\x07 \x01(\t\x12\x10\n\x08truncate\x18\x08 \x01(\x08*/\n\x07Version\x12\x14\n\x10UNKNOWWN_VERSION\x10\x00\x12\x0e\n\nIR_VERSION\x10\x01*\x96\x05\n\x08\x44\x61taType\x12\x10\n\x0c\x44T_UNDEFINED\x10\x00\x12\x0b\n\x07\x44T_BOOL\x10\x01\x12\x0b\n\x07\x44T_INT8\x10\x02\x12\x0c\n\x08\x44T_INT16\x10\x03\x12\x0c\n\x08\x44T_INT32\x10\x04\x12\x0c\n\x08\x44T_INT64\x10\x05\x12\x0c\n\x08\x44T_UINT8\x10\x06\x12\r\n\tDT_UINT16\x10\x07\x12\r\n\tDT_UINT32\x10\x08\x12\r\n\tDT_UINT64\x10\t\x12\x0e\n\nDT_FLOAT16\x10\n\x12\x0e\n\nDT_FLOAT32\x10\x0b\x12\x0e\n\nDT_FLOAT64\x10\x0c\x12\r\n\tDT_STRING\x10\r\x12\r\n\tDT_TENSOR\x10\x0e\x12\x0c\n\x08\x44T_GRAPH\x10\x0f\x12\x0c\n\x08\x44T_BOOLS\x10\x10\x12\x0c\n\x08\x44T_INTS8\x10\x11\x12\r\n\tDT_INTS16\x10\x12\x12\r\n\tDT_INTS32\x10\x13\x12\r\n\tDT_INTS64\x10\x14\x12\r\n\tDT_UINTS8\x10\x15\x12\x0e\n\nDT_UINTS16\x10\x16\x12\x0e\n\nDT_UINTS32\x10\x17\x12\x0e\n\nDT_UINTS64\x10\x18\x12\x0f\n\x0b\x44T_FLOATS16\x10\x19\x12\x0f\n\x0b\x44T_FLOATS32\x10\x1a\x12\x0f\n\x0b\x44T_FLOATS64\x10\x1b\x12\x0e\n\nDT_STRINGS\x10\x1c\x12\x0e\n\nDT_TENSORS\x10\x1d\x12\r\n\tDT_GRAPHS\x10\x1e\x12\x0c\n\x08\x44T_TUPLE\x10\x1f\x12\x0b\n\x07\x44T_LIST\x10 \x12\x0b\n\x07\x44T_DICT\x10!\x12\x0b\n\x07\x44T_NONE\x10\"\x12\x0f\n\x0b\x44T_SYM_INST\x10#\x12\x0f\n\x0b\x44T_BASE_INT\x10$\x12\x10\n\x0c\x44T_BASE_UINT\x10%\x12\x11\n\rDT_BASE_FLOAT\x10&\x12\x0b\n\x07\x44T_TYPE\x10\'\x12\x0f\n\x0b\x44T_ANYTHING\x10(\x12\r\n\tDT_REFKEY\x10)\x12\n\n\x06\x44T_REF\x10*') + serialized_pb=b'\n)mindinsight/debugger/proto/ms_graph.proto\x12\x08\x64\x65\x62ugger\"\xab\x04\n\nValueProto\x12!\n\x05\x64type\x18\x01 \x01(\x0e\x32\x12.debugger.DataType\x12\x10\n\x08\x62ool_val\x18\x02 \x01(\x08\x12\x0f\n\x07int_val\x18\x03 \x01(\x03\x12\x10\n\x08uint_val\x18\x04 \x01(\x04\x12\x11\n\tfloat_val\x18\x05 \x01(\x02\x12\x12\n\ndouble_val\x18\x06 \x01(\x01\x12\x0f\n\x07str_val\x18\x07 \x01(\t\x12)\n\ntensor_val\x18\x08 \x01(\x0b\x32\x15.debugger.TensorProto\x12#\n\x05graph\x18\t \x01(\x0b\x32\x14.debugger.GraphProto\x12\x11\n\tbool_vals\x18\n \x03(\x08\x12\x10\n\x08int_vals\x18\x0b \x03(\x03\x12\x11\n\tuint_vals\x18\x0c \x03(\x04\x12\x12\n\nfloat_vals\x18\r \x03(\x02\x12\x13\n\x0b\x64ouble_vals\x18\x0e \x03(\x01\x12\x10\n\x08str_vals\x18\x0f \x03(\t\x12*\n\x0btensor_vals\x18\x10 \x03(\x0b\x32\x15.debugger.TensorProto\x12$\n\x06graphs\x18\x11 \x03(\x0b\x32\x14.debugger.GraphProto\x12$\n\x06values\x18\x12 \x03(\x0b\x32\x14.debugger.ValueProto\x12+\n\x08\x64ict_val\x18\x13 \x03(\x0b\x32\x19.debugger.NamedValueProto\x12%\n\x08type_val\x18\x14 \x01(\x0b\x32\x13.debugger.TypeProto\"C\n\x0e\x41ttributeProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.debugger.ValueProto\"C\n\x0fNamedValueProto\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.debugger.ValueProto\"n\n\x10TensorShapeProto\x12\x31\n\x03\x64im\x18\x01 \x03(\x0b\x32$.debugger.TensorShapeProto.Dimension\x1a\'\n\tDimension\x12\x0c\n\x04size\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\"\xb6\x02\n\tTypeProto\x12%\n\tdata_type\x18\x01 \x01(\x0e\x32\x12.debugger.DataType\x12\x31\n\x0btensor_type\x18\x02 \x01(\x0b\x32\x1a.debugger.TypeProto.TensorH\x00\x12\x35\n\rsequence_type\x18\x03 \x01(\x0b\x32\x1c.debugger.TypeProto.SequenceH\x00\x1aZ\n\x06Tensor\x12%\n\telem_type\x18\x01 \x01(\x0e\x32\x12.debugger.DataType\x12)\n\x05shape\x18\x02 \x01(\x0b\x32\x1a.debugger.TensorShapeProto\x1a\x33\n\x08Sequence\x12\'\n\nelem_types\x18\x01 \x03(\x0b\x32\x13.debugger.TypeProtoB\x07\n\x05value\"l\n\x0eParameterProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12!\n\x04type\x18\x02 \x01(\x0b\x32\x13.debugger.TypeProto\x12)\n\x0b\x64\x65\x66\x61ult_val\x18\x03 \x01(\x0b\x32\x14.debugger.ValueProto\">\n\x0bOutputProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12!\n\x04type\x18\x02 \x01(\x0b\x32\x13.debugger.TypeProto\"t\n\nInputProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12+\n\x04type\x18\x02 \x01(\x0e\x32\x1d.debugger.InputProto.EdgeType\"+\n\x08\x45\x64geType\x12\r\n\tDATA_EDGE\x10\x00\x12\x10\n\x0c\x43ONTROL_EDGE\x10\x01\"\xf2\x01\n\tNodeProto\x12#\n\x05input\x18\x01 \x03(\x0b\x32\x14.debugger.InputProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0f\n\x07op_type\x18\x03 \x01(\t\x12\r\n\x05scope\x18\x04 \x01(\t\x12+\n\tattribute\x18\x05 \x03(\x0b\x32\x18.debugger.AttributeProto\x12(\n\x0boutput_type\x18\x06 \x01(\x0b\x32\x13.debugger.TypeProto\x12\x10\n\x08output_i\x18\x07 \x01(\x04\x12\x11\n\tfull_name\x18\x08 \x01(\t\x12\x16\n\x0esource_address\x18\t \x01(\t\"\xa4\x01\n\nModelProto\x12\x12\n\nir_version\x18\x01 \x01(\x03\x12\x0e\n\x06\x64omain\x18\x02 \x01(\t\x12\x15\n\rmodel_version\x18\x03 \x01(\x03\x12#\n\x05graph\x18\x04 \x01(\x0b\x32\x14.debugger.GraphProto\x12\x36\n\x12metadata_operators\x18\x05 \x01(\x0b\x32\x1a.debugger.OperatorSetProto\"?\n\rOperatorProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06\x63onfig\x18\x02 \x01(\x0c\x12\x10\n\x08obj_info\x18\x03 \x01(\x0c\"O\n\x10OperatorSetProto\x12*\n\toperators\x18\x01 \x03(\x0b\x32\x17.debugger.OperatorProto\x12\x0f\n\x07summary\x18\x02 \x01(\t\"\xc2\x01\n\nGraphProto\x12!\n\x04node\x18\x01 \x03(\x0b\x32\x13.debugger.NodeProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12,\n\nparameters\x18\x03 \x03(\x0b\x32\x18.debugger.ParameterProto\x12&\n\x07outputs\x18\x04 \x03(\x0b\x32\x15.debugger.OutputProto\x12-\n\nconst_vals\x18\x05 \x03(\x0b\x32\x19.debugger.NamedValueProto\"\xad\x01\n\x0bTensorProto\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x0c\n\x04slot\x18\x02 \x01(\t\x12\x16\n\x0etensor_content\x18\x03 \x01(\x0c\x12\x0c\n\x04\x64ims\x18\x04 \x03(\x03\x12%\n\tdata_type\x18\x05 \x01(\x0e\x32\x12.debugger.DataType\x12\x10\n\x08\x66inished\x18\x06 \x01(\x08\x12\x0c\n\x04iter\x18\x07 \x01(\t\x12\x10\n\x08truncate\x18\x08 \x01(\x08*/\n\x07Version\x12\x14\n\x10UNKNOWWN_VERSION\x10\x00\x12\x0e\n\nIR_VERSION\x10\x01*\x96\x05\n\x08\x44\x61taType\x12\x10\n\x0c\x44T_UNDEFINED\x10\x00\x12\x0b\n\x07\x44T_BOOL\x10\x01\x12\x0b\n\x07\x44T_INT8\x10\x02\x12\x0c\n\x08\x44T_INT16\x10\x03\x12\x0c\n\x08\x44T_INT32\x10\x04\x12\x0c\n\x08\x44T_INT64\x10\x05\x12\x0c\n\x08\x44T_UINT8\x10\x06\x12\r\n\tDT_UINT16\x10\x07\x12\r\n\tDT_UINT32\x10\x08\x12\r\n\tDT_UINT64\x10\t\x12\x0e\n\nDT_FLOAT16\x10\n\x12\x0e\n\nDT_FLOAT32\x10\x0b\x12\x0e\n\nDT_FLOAT64\x10\x0c\x12\r\n\tDT_STRING\x10\r\x12\r\n\tDT_TENSOR\x10\x0e\x12\x0c\n\x08\x44T_GRAPH\x10\x0f\x12\x0c\n\x08\x44T_BOOLS\x10\x10\x12\x0c\n\x08\x44T_INTS8\x10\x11\x12\r\n\tDT_INTS16\x10\x12\x12\r\n\tDT_INTS32\x10\x13\x12\r\n\tDT_INTS64\x10\x14\x12\r\n\tDT_UINTS8\x10\x15\x12\x0e\n\nDT_UINTS16\x10\x16\x12\x0e\n\nDT_UINTS32\x10\x17\x12\x0e\n\nDT_UINTS64\x10\x18\x12\x0f\n\x0b\x44T_FLOATS16\x10\x19\x12\x0f\n\x0b\x44T_FLOATS32\x10\x1a\x12\x0f\n\x0b\x44T_FLOATS64\x10\x1b\x12\x0e\n\nDT_STRINGS\x10\x1c\x12\x0e\n\nDT_TENSORS\x10\x1d\x12\r\n\tDT_GRAPHS\x10\x1e\x12\x0c\n\x08\x44T_TUPLE\x10\x1f\x12\x0b\n\x07\x44T_LIST\x10 \x12\x0b\n\x07\x44T_DICT\x10!\x12\x0b\n\x07\x44T_NONE\x10\"\x12\x0f\n\x0b\x44T_SYM_INST\x10#\x12\x0f\n\x0b\x44T_BASE_INT\x10$\x12\x10\n\x0c\x44T_BASE_UINT\x10%\x12\x11\n\rDT_BASE_FLOAT\x10&\x12\x0b\n\x07\x44T_TYPE\x10\'\x12\x0f\n\x0b\x44T_ANYTHING\x10(\x12\r\n\tDT_REFKEY\x10)\x12\n\n\x06\x44T_REF\x10*' ) _VERSION = _descriptor.EnumDescriptor( @@ -41,8 +39,8 @@ _VERSION = _descriptor.EnumDescriptor( ], containing_type=None, serialized_options=None, - serialized_start=2375, - serialized_end=2422, + serialized_start=2399, + serialized_end=2446, ) _sym_db.RegisterEnumDescriptor(_VERSION) @@ -228,8 +226,8 @@ _DATATYPE = _descriptor.EnumDescriptor( ], containing_type=None, serialized_options=None, - serialized_start=2425, - serialized_end=3087, + serialized_start=2449, + serialized_end=3111, ) _sym_db.RegisterEnumDescriptor(_DATATYPE) @@ -356,7 +354,7 @@ _VALUEPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='str_val', full_name='debugger.ValueProto.str_val', index=6, number=7, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -478,7 +476,7 @@ _ATTRIBUTEPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='name', full_name='debugger.AttributeProto.name', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -516,7 +514,7 @@ _NAMEDVALUEPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='key', full_name='debugger.NamedValueProto.key', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -561,7 +559,7 @@ _TENSORSHAPEPROTO_DIMENSION = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='name', full_name='debugger.TensorShapeProto.Dimension.name', index=1, number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -737,7 +735,7 @@ _PARAMETERPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='name', full_name='debugger.ParameterProto.name', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -782,7 +780,7 @@ _OUTPUTPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='name', full_name='debugger.OutputProto.name', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -820,7 +818,7 @@ _INPUTPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='name', full_name='debugger.InputProto.name', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -866,21 +864,21 @@ _NODEPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='name', full_name='debugger.NodeProto.name', index=1, number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( name='op_type', full_name='debugger.NodeProto.op_type', index=2, number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( name='scope', full_name='debugger.NodeProto.scope', index=3, number=4, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -908,7 +906,14 @@ _NODEPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='full_name', full_name='debugger.NodeProto.full_name', index=7, number=8, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='source_address', full_name='debugger.NodeProto.source_address', index=8, + number=9, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -925,7 +930,7 @@ _NODEPROTO = _descriptor.Descriptor( oneofs=[ ], serialized_start=1469, - serialized_end=1687, + serialized_end=1711, ) @@ -946,7 +951,7 @@ _MODELPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='domain', full_name='debugger.ModelProto.domain', index=1, number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -983,8 +988,8 @@ _MODELPROTO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=1690, - serialized_end=1854, + serialized_start=1714, + serialized_end=1878, ) @@ -998,21 +1003,21 @@ _OPERATORPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='name', full_name='debugger.OperatorProto.name', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( name='config', full_name='debugger.OperatorProto.config', index=1, number=2, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=_b(""), + has_default_value=False, default_value=b"", message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( name='obj_info', full_name='debugger.OperatorProto.obj_info', index=2, number=3, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=_b(""), + has_default_value=False, default_value=b"", message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -1028,8 +1033,8 @@ _OPERATORPROTO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=1856, - serialized_end=1919, + serialized_start=1880, + serialized_end=1943, ) @@ -1050,7 +1055,7 @@ _OPERATORSETPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='summary', full_name='debugger.OperatorSetProto.summary', index=1, number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -1066,8 +1071,8 @@ _OPERATORSETPROTO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=1921, - serialized_end=2000, + serialized_start=1945, + serialized_end=2024, ) @@ -1088,7 +1093,7 @@ _GRAPHPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='name', full_name='debugger.GraphProto.name', index=1, number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -1125,8 +1130,8 @@ _GRAPHPROTO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=2003, - serialized_end=2197, + serialized_start=2027, + serialized_end=2221, ) @@ -1140,21 +1145,21 @@ _TENSORPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='node_name', full_name='debugger.TensorProto.node_name', index=0, number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( name='slot', full_name='debugger.TensorProto.slot', index=1, number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), _descriptor.FieldDescriptor( name='tensor_content', full_name='debugger.TensorProto.tensor_content', index=2, number=3, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=_b(""), + has_default_value=False, default_value=b"", message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -1182,7 +1187,7 @@ _TENSORPROTO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='iter', full_name='debugger.TensorProto.iter', index=6, number=7, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=_b("").decode('utf-8'), + has_default_value=False, default_value=b"".decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), @@ -1205,8 +1210,8 @@ _TENSORPROTO = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=2200, - serialized_end=2373, + serialized_start=2224, + serialized_end=2397, ) _VALUEPROTO.fields_by_name['dtype'].enum_type = _DATATYPE diff --git a/mindinsight/debugger/session_manager.py b/mindinsight/debugger/session_manager.py new file mode 100644 index 00000000..68c2dbcd --- /dev/null +++ b/mindinsight/debugger/session_manager.py @@ -0,0 +1,172 @@ +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Implement the session manager.""" +import os +import threading +from urllib.parse import unquote + +import _thread + +from mindinsight.conf import settings +from mindinsight.debugger.common.log import LOGGER as logger +from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \ + DebuggerSessionNotFoundError +from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext +from mindinsight.debugger.debugger_session import DebuggerSession + + +class SessionManager: + """The server manager of debugger.""" + + ONLINE_TYPE = "ONLINE" + MAX_SESSION_NUM = 2 + ONLINE_SESSION_ID = "0" + _instance = None + _cls_lock = threading.Lock() + + def __init__(self): + self.train_jobs = {} + self.sessions = {} + self.session_id = 1 + self.online_session = None + self._lock = threading.Lock() + self._exiting = False + enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False + if enable_debugger: + self.creat_session(self.ONLINE_TYPE) + + @classmethod + def get_instance(cls): + """Get the singleton instance.""" + with cls._cls_lock: + if cls._instance is None: + cls._instance = SessionManager() + return cls._instance + + def exit(self): + """ + Called when the gunicorn worker process is exiting. + """ + with self._lock: + logger.info("Start to exit sessions.") + self._exiting = True + for session in self.sessions: + session.stop() + self.online_session.stop() + logger.info("Exited.") + + def get_session(self, session_id): + """ + Get session by session id or get all session info. + + Args: + session_id (Union[None, str]: The id of session. + + Returns: + DebuggerSession, debugger session object. + """ + with self._lock: + if session_id == self.ONLINE_SESSION_ID and self.online_session is not None: + return self.online_session + + if session_id in self.sessions: + return self.sessions.get(session_id) + + raise DebuggerSessionNotFoundError("{}".format(session_id)) + + def creat_session(self, session_type, train_job=None): + """ + Create session by the train job info. + + Args: + session_type (str): The session_type. + train_job (str): The train job info. + + Returns: + str, session id. + """ + with self._lock: + if self._exiting: + logger.info( + "System is exiting, will terminate the thread.") + _thread.exit() + + if session_type == self.ONLINE_TYPE: + if self.online_session is None: + context = DebuggerServerContext(dbg_mode='online') + self.online_session = DebuggerSession(context) + self.online_session.start() + return self.ONLINE_SESSION_ID + + if train_job in self.train_jobs: + return self.train_jobs.get(train_job) + + self._check_session_num() + summary_base_dir = settings.SUMMARY_BASE_DIR + unquote_path = unquote(train_job, errors='strict') + whole_path = os.path.join(summary_base_dir, unquote_path) + normalized_path = validate_and_normalize_path(whole_path) + context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path) + session = DebuggerSession(context) + session.start() + session_id = str(self.session_id) + self.sessions[session_id] = session + self.train_jobs[train_job] = session_id + self.session_id += 1 + return session_id + + def delete_session(self, session_id): + """Delete session by session id.""" + with self._lock: + if session_id == self.ONLINE_SESSION_ID: + self.online_session.stop() + self.online_session = None + return + + if session_id not in self.sessions: + raise DebuggerSessionNotFoundError("session id {}".format(session_id)) + + session = self.sessions.get(session_id) + session.stop() + self.sessions.pop(session_id) + self.train_jobs.pop(session.train_job) + return + + def get_sessions(self): + """get all sessions""" + return {"train_jobs": self.train_jobs} + + def _check_session_num(self): + """Check the amount of sessions.""" + if len(self.sessions) >= self.MAX_SESSION_NUM: + raise DebuggerSessionNumOverBoundError() + + +def validate_and_normalize_path(path): + """Validate and normalize_path""" + if not path: + raise ValueError("The path is invalid!") + + path_str = str(path) + + if not path_str.startswith("/"): + raise ValueError("The path is invalid!") + + try: + normalized_path = os.path.realpath(path) + except ValueError: + raise ValueError("The path is invalid!") + + return normalized_path diff --git a/mindinsight/debugger/stream_cache/data_loader.py b/mindinsight/debugger/stream_cache/data_loader.py new file mode 100644 index 00000000..c46c1662 --- /dev/null +++ b/mindinsight/debugger/stream_cache/data_loader.py @@ -0,0 +1,210 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""This file is used to define the DataLoader.""" +import os +import json +from mindinsight.debugger.proto.ms_graph_pb2 import ModelProto +from mindinsight.debugger.common.log import LOGGER as log +from mindinsight.utils.exceptions import ParamValueError +from mindinsight.debugger.common.utils import DumpSettings + + +class DataLoader: + """The DataLoader object provides interface to load graphs and device information from base_dir.""" + def __init__(self, base_dir): + self._debugger_base_dir = base_dir + self._graph_protos = [] + self._device_info = {} + self._step_num = {} + # flag for whether the data is from sync dump or async dump, True for sync dump, False for async dump. + self._is_sync = None + self._net_dir = "" + self._net_name = "" + self.initialize() + + def initialize(self): + """Initialize the data_mode and net_dir of DataLoader.""" + dump_config_file = os.path.join(self._debugger_base_dir, os.path.join(".metadata", "data_dump.json")) + with open(dump_config_file, 'r') as load_f: + dump_config = json.load(load_f) + common_settings = dump_config.get(DumpSettings.COMMON_DUMP_SETTINGS.value) + if not common_settings: + raise ParamValueError('common_dump_settings not found in dump_config file.') + self._net_name = common_settings['net_name'] + if dump_config.get(DumpSettings.E2E_DUMP_SETTINGS.value) and \ + dump_config[DumpSettings.E2E_DUMP_SETTINGS.value]['enable']: + self._is_sync = True + self._net_dir = os.path.join(self._debugger_base_dir, self._net_name) + elif dump_config.get(DumpSettings.ASYNC_DUMP_SETTINGS.value) and \ + dump_config[DumpSettings.ASYNC_DUMP_SETTINGS.value]['enable']: + self._is_sync = False + self._net_dir = self._debugger_base_dir + else: + raise ParamValueError('The data must be generated from sync dump or async dump.') + + def load_graphs(self): + """Load graphs from the debugger_base_dir.""" + files = os.listdir(self._net_dir) + for file in files: + if not self.is_device_dir(file): + continue + device_id, device_dir = self.get_device_id_and_dir(file) + graphs_dir = os.path.join(device_dir, 'graphs') + if not os.path.exists(graphs_dir) or not os.path.isdir(graphs_dir): + log.debug("Directory '%s' not exist.", graphs_dir) + self._graph_protos.append({'device_id': device_id, 'graph_protos': []}) + continue + graph_protos = get_graph_protos_from_dir(graphs_dir) + self._graph_protos.append({'device_id': device_id, 'graph_protos': graph_protos}) + + return self._graph_protos + + def load_device_info(self): + """Load device_info from file""" + hccl_json_file = os.path.join(self._debugger_base_dir, '.metadata/hccl.json') + if not os.path.isfile(hccl_json_file): + device = [] + device_ids = self.get_all_device_id() + device_ids.sort() + for i, device_id in enumerate(device_ids): + rank_id = i + device.append({'device_id': str(device_id), 'rank_id': str(rank_id)}) + device_target = 'Ascend' + self._device_info = {'device_target': device_target, + 'server_list': [{'server_id': 'localhost', 'device': device}]} + else: + with open(hccl_json_file, 'r') as load_f: + load_dict = json.load(load_f) + self._device_info = {'device_target': 'Ascend', 'server_list': load_dict['server_list']} + return self._device_info + + def load_step_number(self): + """Load step number in the directory""" + files = os.listdir(self._net_dir) + for file in files: + if not self.is_device_dir(file): + continue + device_id, device_dir = self.get_device_id_and_dir(file) + max_step = 0 + files_in_device = os.listdir(device_dir) + if self._is_sync: + for file_in_device in files_in_device: + abs_file_in_device = os.path.join(device_dir, file_in_device) + if os.path.isdir(abs_file_in_device) and file_in_device.startswith("iteration_"): + step_id_str = file_in_device.split('_')[-1] + max_step = update_max_step(step_id_str, max_step) + self._step_num[str(device_id)] = max_step + else: + net_graph_dir = [] + for file_in_device in files_in_device: + abs_file_in_device = os.path.join(device_dir, file_in_device) + if os.path.isdir(abs_file_in_device) and file_in_device.startswith(self._net_name): + net_graph_dir.append(abs_file_in_device) + if len(net_graph_dir) > 1: + log.warning("There are more than one graph directory in device_dir: %s. " + "OfflineDebugger use data in %s.", device_dir, net_graph_dir[0]) + net_graph_dir_to_use = net_graph_dir[0] + graph_id = net_graph_dir_to_use.split('_')[-1] + graph_id_dir = os.path.join(net_graph_dir_to_use, graph_id) + step_ids = os.listdir(graph_id_dir) + for step_id_str in step_ids: + max_step = update_max_step(step_id_str, max_step) + self._step_num[str(device_id)] = max_step + + return self._step_num + + def is_device_dir(self, file_name): + """Judge if the file_name is a sub directory named 'device_x'.""" + if not file_name.startswith("device_"): + return False + id_str = file_name.split("_")[-1] + if not id_str.isdigit(): + return False + device_dir = os.path.join(self._net_dir, file_name) + if not os.path.isdir(device_dir): + return False + return True + + def get_device_id_and_dir(self, file_name): + """Get device_id and absolute directory of file_name.""" + id_str = file_name.split("_")[-1] + device_id = int(id_str) + device_dir = os.path.join(self._net_dir, file_name) + return device_id, device_dir + + def get_all_device_id(self): + """Get all device_id int the debugger_base_dir""" + device_ids = [] + files = os.listdir(self._net_dir) + for file in files: + if not self.is_device_dir(file): + continue + id_str = file.split("_")[-1] + device_id = int(id_str) + device_ids.append(device_id) + return device_ids + + def get_net_dir(self): + """Get graph_name directory of the data.""" + return self._net_dir + + def get_sync_flag(self): + """Get the sync flag of the data.""" + return self._is_sync + + def get_net_name(self): + """Get net_name of the data.""" + return self._net_name + + +def load_graph_from_file(graph_file_name): + """Load graph from file.""" + with open(graph_file_name, 'rb') as file_handler: + model_bytes = file_handler.read() + model = ModelProto.FromString(model_bytes) + graph = model.graph + + return graph + + +def get_graph_protos_from_dir(graphs_dir): + """ + Get graph from graph directory. + + Args: + graph_dir (str): The absolute directory of graph files. + + Returns: + list, list of 'GraphProto' object. + """ + files_in_graph_dir = os.listdir(graphs_dir) + graph_protos = [] + pre_file_name = "ms_output_trace_code_graph_" + for file_in_device in files_in_graph_dir: + if file_in_device.startswith(pre_file_name) and file_in_device.endswith(".pb"): + abs_graph_file = os.path.join(graphs_dir, file_in_device) + graph_proto = load_graph_from_file(abs_graph_file) + graph_protos.append(graph_proto) + return graph_protos + + +def update_max_step(step_id_str, max_step): + """Update max_step by compare step_id_str and max_step.""" + res = max_step + if step_id_str.isdigit(): + step_id = int(step_id_str) + if step_id > max_step: + res = step_id + return res diff --git a/mindinsight/debugger/stream_cache/tensor.py b/mindinsight/debugger/stream_cache/tensor.py index e6f4b2bf..34295ac7 100644 --- a/mindinsight/debugger/stream_cache/tensor.py +++ b/mindinsight/debugger/stream_cache/tensor.py @@ -14,7 +14,6 @@ # ============================================================================ """The definition of tensor stream.""" from abc import abstractmethod, ABC - import numpy as np from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError @@ -149,7 +148,10 @@ class OpTensor(BaseTensor): @property def shape(self): """The property of tensor shape.""" - return list(self._tensor_proto.dims) + dims = list(self._tensor_proto.dims) + if dims == [0]: + dims = [] + return dims @property def value(self): @@ -254,12 +256,13 @@ class OpTensor(BaseTensor): class ConstTensor(BaseTensor): """Tensor data structure for Const Node.""" _STRING_TYPE = 'DT_STRING' + _DT_TYPE = 'DT_TYPE' def __init__(self, const_proto): # the type of const_proto is NamedValueProto super(ConstTensor, self).__init__() self._const_proto = const_proto - self._value = self.generate_value_from_proto(const_proto) + self._value = self.generate_value_from_proto(const_proto.value) def set_step(self, step): """Set step value.""" @@ -295,16 +298,25 @@ class ConstTensor(BaseTensor): Returns: Union[None, str, np.ndarray], the value of the tensor. """ - fields = tensor_proto.value.ListFields() + fields = tensor_proto.ListFields() if len(fields) != 2: log.warning("Unexpected const proto <%s>.\n Please check offline.", tensor_proto) tensor_value = None for field_obj, field_value in fields: if field_obj.name != 'dtype': - tensor_value = field_value + if tensor_proto.dtype == DataType.DT_TUPLE: + tensor_values = [] + for field_value_element in field_value: + value_element = self.generate_value_from_proto(field_value_element) + tensor_values.append(value_element) + tensor_value = tensor_values + elif tensor_proto.dtype == DataType.DT_TYPE: + tensor_value = DataType.Name(field_value.data_type) + else: + tensor_value = field_value break - if tensor_value is not None and self.dtype != self._STRING_TYPE: - tensor_value = np.array(tensor_value, dtype=NUMPY_TYPE_MAP.get(self.dtype)) + if tensor_value is not None and tensor_proto.dtype != self._STRING_TYPE: + tensor_value = np.array(tensor_value, dtype=NUMPY_TYPE_MAP.get(tensor_proto.dtype)) return tensor_value def get_tensor_value_by_shape(self, shape=None): @@ -328,7 +340,8 @@ class ConstTensor(BaseTensor): Returns: dict, overall statistics. """ - if self.empty or self.dtype == self._STRING_TYPE: + if self.empty or self.dtype == self._STRING_TYPE or self.dtype == self._DT_TYPE: + log.debug("The tensor dtype is: %s, skip getting statistics.", self.dtype) return {} stats = TensorUtils.get_statistics_from_tensor(self.value) statistics = TensorUtils.get_overall_statistic_dict(stats) diff --git a/mindinsight/debugger/stream_cache/watchpoint.py b/mindinsight/debugger/stream_cache/watchpoint.py index 41131f48..d6af2c1b 100644 --- a/mindinsight/debugger/stream_cache/watchpoint.py +++ b/mindinsight/debugger/stream_cache/watchpoint.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -184,7 +184,7 @@ class Watchpoint: def __init__(self, watchpoint_id, watch_condition, name=None): self._id = watchpoint_id self._condition = watch_condition - self._watch_node = WatchNodeTree() + self._watch_node = {0: WatchNodeTree()} self.name = name @property @@ -214,32 +214,36 @@ class Watchpoint: else: self._watch_node = other_watchpoint.nodes - def add_nodes(self, nodes): + def add_nodes(self, nodes, rank_id): """Add node into watchpoint.""" if not nodes: log.warning("Add empty nodes.") return - + if rank_id not in self._watch_node: + self._watch_node[rank_id] = WatchNodeTree() if not isinstance(nodes, list): nodes = [nodes] for node in nodes: - self._watch_node.add_node(node.name, node.type, node.full_name) + watch_node = self._watch_node.get(rank_id) + watch_node.add_node(node.name, node.type, node.full_name) - def remove_nodes(self, nodes): + def remove_nodes(self, nodes, rank_id): """Remove nodes from watchpoint.""" if not nodes: return + self.validate_rank_id(rank_id) if not isinstance(nodes, list): nodes = [nodes] for node in nodes: - self._watch_node.remove_node(node.name) + self._watch_node.get(rank_id).remove_node(node.name) - def get_node_status(self, node_name, node_type, full_name): + def get_node_status(self, node_name, node_type, full_name, rank_id): """Judge if the node is in watch nodes.""" if is_cst_type(node_type): return WatchNodeTree.INVALID scope_names = node_name.split('/') - cur_node = self._watch_node + self.validate_rank_id(rank_id) + cur_node = self._watch_node.get(rank_id) status = 1 for scope_name in scope_names: cur_node = cur_node.get(scope_name) @@ -250,7 +254,7 @@ class Watchpoint: status = WatchNodeTree.TOTAL_WATCH break if status == WatchNodeTree.TOTAL_WATCH and cur_node.node_name != node_name: - self._watch_node.add_node(node_name, node_type, full_name) + self._watch_node.get(rank_id).add_node(node_name, node_type, full_name) return status @@ -278,11 +282,14 @@ class Watchpoint: Returns: list[NodeBasicInfo], the list of watch node basic infos. """ - watch_nodes = [] - self._get_watch_node(self._watch_node, watch_nodes) - return watch_nodes - - def get_pending_cmd(self, watch_nodes): + watch_nodes_for_devices = {} + for rank_id, watch_node_tree in self._watch_node.items(): + watch_nodes = [] + self._get_watch_node(watch_node_tree, watch_nodes) + watch_nodes_for_devices[rank_id] = watch_nodes + return watch_nodes_for_devices + + def get_pending_cmd(self, watch_nodes_for_devices): """Return the watchpoint in proto format.""" # construct SetCMD condition_id = self._condition.get('id') @@ -309,10 +316,12 @@ class Watchpoint: param_proto.name = param_name param_proto.disabled = True - for watch_node in watch_nodes: - event_node = set_cmd.watch_nodes.add() - event_node.node_name = watch_node.full_name - event_node.node_type = watch_node.type + for rank_id, watch_nodes in watch_nodes_for_devices.items(): + for watch_node in watch_nodes: + event_node = set_cmd.watch_nodes.add() + event_node.node_name = watch_node.full_name + event_node.node_type = watch_node.type + event_node.rank_id = rank_id return set_cmd def get_watch_condition_info(self): @@ -325,6 +334,11 @@ class Watchpoint: watchpoint_info['name'] = self.name return watchpoint_info + def validate_rank_id(self, rank_id): + if rank_id not in self._watch_node: + log.warning("Rank_id not exist") + return + class WatchpointHit: """The watchpoint hit structure.""" diff --git a/mindinsight/debugger/stream_handler/__init__.py b/mindinsight/debugger/stream_handler/__init__.py index 73bc29ab..284e6a6f 100644 --- a/mindinsight/debugger/stream_handler/__init__.py +++ b/mindinsight/debugger/stream_handler/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,9 +15,10 @@ """Import the streams handlers.""" from .event_handler import EventHandler from .metadata_handler import MetadataHandler -from .graph_handler import GraphHandler -from .tensor_handler import TensorHandler -from .watchpoint_handler import WatchpointHandler, WatchpointHitHandler +from .graph_handler import GraphHandler, MultiCardGraphHandler +from .tensor_handler import TensorHandler, MultiCardTensorHandler +from .watchpoint_handler import WatchpointHandler, WatchpointHitHandler, MultiCardWatchpointHitHandler -__all__ = ['EventHandler', 'MetadataHandler', 'GraphHandler', 'TensorHandler', - 'WatchpointHandler', 'WatchpointHitHandler'] +__all__ = ['EventHandler', 'MetadataHandler', 'GraphHandler', 'TensorHandler', 'WatchpointHitHandler', + 'MultiCardGraphHandler', 'MultiCardTensorHandler', + 'WatchpointHandler', 'MultiCardWatchpointHitHandler'] diff --git a/mindinsight/debugger/stream_handler/device_handler.py b/mindinsight/debugger/stream_handler/device_handler.py new file mode 100644 index 00000000..c8ac9177 --- /dev/null +++ b/mindinsight/debugger/stream_handler/device_handler.py @@ -0,0 +1,198 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Define the Device stream handler.""" +from collections import defaultdict + +from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, DeviceIdUnregistered, \ + DebuggerParamTypeError +from mindinsight.debugger.common.log import LOGGER as log +from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase + + +class DeviceHandler(StreamHandlerBase): + """Metadata Handler.""" + + def __init__(self): + # contains all device infos, the format is like Dict[int(, )] + self._rank_info = defaultdict(DeviceInfo) + self._device_rank_map = {} + + @property + def rank_ids(self): + """The rank ids.""" + return list(self._rank_info) + + @property + def device_amount(self): + """The rank ids.""" + return len(self._rank_info) + + def put(self, value): + """ + Put value into device info cache. + + Args: + value (list): The list of server info. Each item is format like: + { + "server_id": str, + "device": list[] + }, + The format of is like: + { + "device_id": str, + "device_ip": str, + "rank_id": str + }. + """ + if not isinstance(value, list): + log.error("Invalid input type. list object is expected.") + raise DebuggerParamTypeError("List object is expected.") + try: + self._extract_rank_info(value) + except TypeError as err: + log.exception(err) + log.error("Invalid Device info.") + raise DebuggerParamValueError("Invalid device info.") + log.debug("Put Device into cache") + + def _extract_rank_info(self, value): + """Extract rank info and save.""" + for server_info in value: + server_ip = server_info.get('server_id') + for device_info in server_info.get('device', []): + rank_id = int(device_info.get('rank_id')) + if rank_id in self._rank_info: + log.error("Repeated rank info for rank_id: %d", rank_id) + raise DebuggerParamValueError("Repeated rank info.") + device_info_obj = self._rank_info[rank_id] + device_info_obj.rank_id = rank_id + device_info_obj.server_ip = server_ip + device_info_obj.device_id = int(device_info.get('device_id')) + device_info_obj.device_ip = device_info.get('device_ip') + self._device_rank_map[device_info_obj.device_id] = rank_id + + def add_step_num_info(self, step_info): + """ + Add step number information for each device. + + Args: + step_info (dict): Step info per device. The key is the device id, the value + is the relative step number. + """ + if not step_info: + log.warning("No step number information.") + return + if len(step_info) == 1 and not self._rank_info: + device_id = int(list(step_info)[0]) + log.info("Default registered device %d as rank 0.", device_id) + self._rank_info[0].device_id = device_id + if len(step_info) > 1 and not self._rank_info: + log.error("Missing device info for multi-card training.") + raise DeviceIdUnregistered("all") + + for device_id, step_num in step_info.items(): + device_id = int(device_id) + rank_id = self.get_rank_id_by_device_id(device_id) + self._rank_info[rank_id].step_num = step_num + + def add_graph_name_info(self, graphs): + """ + Add graph name per device. + + Args: + graphs (dict): Graph infos of all rank id. Each item is format like + """ + for rank_id, graph_info in graphs.items(): + graph_names = list(graph_info) + self._rank_info[rank_id].graph_names = graph_names + + def get(self, filter_condition=None): + """ + Get device information according to filter_condition. + + Args: + filter_condition (list): The rank id. + + Returns: + dict, the device info. + """ + if filter_condition is None: + filter_condition = self.rank_ids + if not isinstance(filter_condition, list): + filter_condition = [filter_condition] + device_infos = [] + for rank_id in filter_condition: + device_info = self._rank_info.get(rank_id) + if device_info is None: + log.error("Invalid rank id.") + raise DeviceIdUnregistered(rank_id) + device_infos.append(device_info.to_dict()) + return {'devices': device_infos} + + def get_rank_id_by_device_id(self, device_id): + """ + Get rank id by device id. + + Args: + device_id (int): The device id. + + Returns: + int, the rank id. + """ + rank_id = self._device_rank_map.get(device_id) + if rank_id is None: + log.error("Failed to find rank_id for device_id %s", device_id) + raise DeviceIdUnregistered(device_id) + return rank_id + + def get_device_id_by_rank_id(self, rank_id): + """ + Get device id by rank id. + + Args: + rank_id (int): The rank id. + + Returns: + int, the device id. + """ + device_info = self._rank_info.get(rank_id) + if device_info: + return device_info.device_id + log.error("Failed to find device id according to rank_id %s", rank_id) + raise DeviceIdUnregistered(rank_id) + + +class DeviceInfo: + """Device info object.""" + + def __init__(self): + self.rank_id = 0 + self.device_id = 0 + self.server_ip = '' + self.graph_names = [] + self.device_ip = '' + self.step_num = 0 + + def to_dict(self): + """Convert device info to dict.""" + res = { + 'rank_id': self.rank_id, + 'server_ip': self.server_ip, + 'device_id': self.device_id, + 'device_ip': self.device_ip, + 'graph_names': self.graph_names, + 'total_step_num': self.step_num + } + return res diff --git a/mindinsight/debugger/stream_handler/graph_handler.py b/mindinsight/debugger/stream_handler/graph_handler.py index ce28f871..9ea9a141 100644 --- a/mindinsight/debugger/stream_handler/graph_handler.py +++ b/mindinsight/debugger/stream_handler/graph_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,55 @@ from mindinsight.debugger.stream_cache.debugger_multigraph import DebuggerMultiG from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase +class MultiCardGraphHandler: + """Multi-card Graph Handler.""" + + def __init__(self): + self._graph_handlers = {0: GraphHandler()} + + @property + def graph_handlers(self): + """The property of whole_graph.""" + return self._graph_handlers + + def get_graph_handler_by_rank_id(self, rank_id=0): + """Get handler by rank id""" + if rank_id in self._graph_handlers: + return self._graph_handlers.get(rank_id) + log.error("There is no rank id %d.", rank_id) + raise ValueError + + def put(self, value): + """put graphs into graph_handlers""" + for rank_id, graph in value.items(): + if rank_id not in self._graph_handlers: + self._graph_handlers[rank_id] = GraphHandler() + self._graph_handlers[rank_id].put(graph) + + def get(self, filter_condition=None, rank_id=0): + """Get the graph of specific node for specific device.""" + if rank_id in self._graph_handlers: + return self._graph_handlers.get(rank_id).get(filter_condition) + log.error("There is no rank id %d.", rank_id) + raise ValueError + + def has_graph(self): + """check if has graph""" + res = False + for graph_handler in self._graph_handlers: + res = res or graph_handler.graph + return res + + def register_graph_handler(self, rank_id, graph_handler): + """Register graph handler.""" + self._graph_handlers[rank_id] = graph_handler + + def clean(self): + """Clean cache.""" + self.__init__() + + + class GraphHandler(StreamHandlerBase): """Metadata Handler.""" @@ -68,7 +117,7 @@ class GraphHandler(StreamHandlerBase): Put value into graph cache. Called by grpc server. Args: - value (GraphProto): The Graph proto message. + value (dict): The Graph proto message. Each item is format like (, GraphProto). """ log.info("Put graph into cache.") sorted_value_list = self._sort_graph(value) @@ -430,8 +479,8 @@ class GraphHandler(StreamHandlerBase): graph_name, node_name = self._parse_node_name(scope_name, graph_name) graph = self._get_graph(graph_name) # to make sure fully match the scope name - node_name = node_name + '/' if not node_name.endswith('/') else node_name - nodes = graph.search_leaf_nodes_by_pattern(node_name) + node_name = node_name + '/' if node_name and not node_name.endswith('/') else node_name + nodes = graph.search_leaf_nodes_by_pattern(node_name, True) res = [self.construct_node_basic_info(full_name=node.full_name, graph_name=graph_name, node_name=node.name, @@ -448,45 +497,6 @@ class GraphHandler(StreamHandlerBase): log.debug("Get empty full name.") return node_name - def get_node_by_bfs_order(self, node_name=None, ascend=True): - """ - Traverse the graph in order of breath-first search by given node. - - Args: - node_name (str): The name of current chosen leaf node. - ascend (bool): If True, traverse the input nodes; - If False, traverse the output nodes. Default is True. - Returns: - Union[None, dict], the next node object in dict type or None. - """ - bfs_order = self.bfs_order - length = len(bfs_order) - - if not bfs_order: - log.error('Cannot get the BFS order of the graph!') - msg = 'Cannot get the BFS order of the graph!' - raise DebuggerParamValueError(msg) - - if node_name is None: - if ascend is False: - next_node = None - else: - next_node = bfs_order[0] - else: - try: - index = bfs_order.index(node_name) - log.debug("The index of the node in BFS list is: %d", index) - except ValueError as err: - log.error('Cannot find the node: %s. Please check ' - 'the node name: %s', node_name, err) - msg = f'Cannot find the node: {node_name}. ' \ - f'Please check the node name {err}.' - raise DebuggerParamValueError(msg) - - next_node = self._get_next_node_in_bfs(index, length, ascend) - - return next_node - def _get_next_node_in_bfs(self, index, length, ascend): """ Get the next node in bfs order. diff --git a/mindinsight/debugger/stream_handler/metadata_handler.py b/mindinsight/debugger/stream_handler/metadata_handler.py index 1f161aeb..80536279 100644 --- a/mindinsight/debugger/stream_handler/metadata_handler.py +++ b/mindinsight/debugger/stream_handler/metadata_handler.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,9 @@ # limitations under the License. # ============================================================================ """Define the metadata stream handler.""" + from mindinsight.debugger.common.log import LOGGER as log -from mindinsight.debugger.common.utils import ServerStatus +from mindinsight.debugger.common.utils import ServerStatus, DebuggerServerMode from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase @@ -24,28 +25,36 @@ class MetadataHandler(StreamHandlerBase): def __init__(self): self._state = ServerStatus.PENDING self._device_name = "" - self._step = 0 + self.step = 0 self._client_ip = "" self._cur_node_name = "" self._cur_full_name = "" - self._backend = "" + self.backend = "" self._enable_recheck = False self._cur_graph_name = "" # If recommendation_confirmed is true, it only means the user has answered yes or no to the question, # it does not necessarily mean that the user will use the recommended watch points. self._recommendation_confirmed = False self._debugger_version = {} + # maximum step number among all devices + self._max_step_num = 0 + self._debugger_type = DebuggerServerMode.ONLINE.value + + @property + def debugger_type(self): + """The property of debugger_type.""" + return self._debugger_type + + @debugger_type.setter + def debugger_type(self, debugger_type): + """The property of debugger_type.""" + self._debugger_type = debugger_type @property def device_name(self): """The property of device name.""" return self._device_name - @property - def step(self): - """The property of current step.""" - return self._step - @property def node_name(self): """The property of current node name.""" @@ -71,11 +80,6 @@ class MetadataHandler(StreamHandlerBase): """The property of current node name.""" return self._cur_full_name - @property - def backend(self): - """The property of current backend.""" - return self._backend - @property def state(self): """The property of state.""" @@ -152,6 +156,16 @@ class MetadataHandler(StreamHandlerBase): """ self._debugger_version = value + @property + def max_step_num(self): + """The property of max_step_num.""" + return self._max_step_num + + @max_step_num.setter + def max_step_num(self, max_step_num): + """Set the property of max_step_num.""" + self._max_step_num = max_step_num + def put(self, value): """ Put value into metadata cache. Called by grpc server. @@ -160,10 +174,10 @@ class MetadataHandler(StreamHandlerBase): value (MetadataProto): The Metadata proto message. """ self._device_name = value.device_name.split(':')[0] - self._step = value.cur_step + self.step = value.cur_step self._cur_full_name = value.cur_node - self._backend = value.backend if value.backend else "Ascend" - log.debug("Put metadata into cache at the %d-th step.", self._step) + self.backend = value.backend if value.backend else "Ascend" + log.debug("Put metadata into cache at the %d-th step.", self.step) def get(self, filter_condition=None): """ @@ -190,6 +204,8 @@ class MetadataHandler(StreamHandlerBase): 'recommendation_confirmed': self._recommendation_confirmed, 'debugger_version': self.debugger_version } + if self.debugger_type == 'offline': + metadata['total_step_num'] = self.max_step_num else: if not isinstance(filter_condition, list): filter_condition = [filter_condition] diff --git a/mindinsight/debugger/stream_handler/tensor_handler.py b/mindinsight/debugger/stream_handler/tensor_handler.py index b4ba2f8e..81d82037 100644 --- a/mindinsight/debugger/stream_handler/tensor_handler.py +++ b/mindinsight/debugger/stream_handler/tensor_handler.py @@ -28,6 +28,46 @@ from mindinsight.utils.tensor import TensorUtils, TensorComparison TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter']) +class MultiCardTensorHandler: + """Multi-card Tensor Handler.""" + def __init__(self): + self.tensor_handlers = {0: TensorHandler()} + + def set_step(self, step_id): + """Set step id.""" + for tensor_handler in self.tensor_handlers.values(): + tensor_handler.cur_step = step_id + + def get_tensor_handler_by_rank_id(self, rank_id=0, create_if_not_exit=False): + """get handler by rank id""" + if rank_id in self.tensor_handlers: + return self.tensor_handlers.get(rank_id) + if create_if_not_exit: + tensor_handler = TensorHandler() + self.tensor_handlers[rank_id] = tensor_handler + return tensor_handler + log.error("There is no rank id %d in MultiCardTensorHandler.", rank_id) + raise ValueError + + def put(self, value): + """put graphs into graph_handlers""" + for rank_id, tensor in value: + if rank_id not in self.tensor_handlers: + self.tensor_handlers[rank_id] = TensorHandler() + self.tensor_handlers[rank_id].put(tensor) + + def get(self, filter_condition=None, rank_id=0): + """Get the graph of specific node for specific device.""" + if rank_id in self.tensor_handlers: + return self.tensor_handlers.get(rank_id).get(filter_condition) + log.error("There is no rank id %d.", rank_id) + raise ValueError + + def clean(self): + """Clean cache.""" + self.__init__() + + class TensorHandler(StreamHandlerBase): """Metadata Handler.""" @@ -46,6 +86,11 @@ class TensorHandler(StreamHandlerBase): """The property of current step.""" return self._cur_step + @cur_step.setter + def cur_step(self, step_id): + """The property of current step.""" + self._cur_step = step_id + @property def prev_step(self): """The property of previous step.""" @@ -172,7 +217,7 @@ class TensorHandler(StreamHandlerBase): log.error("No tensor named %s at the step %s", name, step) raise DebuggerParamValueError("No tensor named {}".format(name)) tensor_info = tensor.get_full_info(shape) - self._update_has_prev_step_field(tensor_info, name, node_type) + self._update_has_prev_step_field(tensor_info, name, node_type, self.cur_step) return {'tensor_value': tensor_info} def _get_tensor(self, tensor_name, node_type=None, step=None): @@ -198,20 +243,21 @@ class TensorHandler(StreamHandlerBase): return tensor - def _get_basic_info(self, tensor_name, node_type=None): + def _get_basic_info(self, tensor_name, node_type, step): """Get the latest basic tensor info by tensor name.""" - tensor = self._get_tensor(tensor_name, node_type) + tensor = self._get_tensor(tensor_name, node_type, step) if tensor: return tensor.get_basic_info() return None - def update_tensor_history(self, tensor_history): + def update_tensor_history(self, tensor_history, step=None): """ Add tensor basic info in tensor_history. Args: tensor_history (dict): Tensor history, including a list of tensor name and type. + step (int): The step of tensor info. Default: None. Returns: list[dict], the list of tensor basic info cache. @@ -220,9 +266,9 @@ class TensorHandler(StreamHandlerBase): for tensor_info in tensor_history.get('tensor_history'): tensor_name = tensor_info.get('full_name') node_type = tensor_info.get('node_type') - basic_info = self._get_basic_info(tensor_name, node_type) + basic_info = self._get_basic_info(tensor_name, node_type, step) # add `has_prev_step` field to tensor basic info. - missing_tensors_info = self._update_has_prev_step_field(basic_info, tensor_name, node_type) + missing_tensors_info = self._update_has_prev_step_field(basic_info, tensor_name, node_type, step) if basic_info: tensor_info.update(basic_info) if missing_tensors_info: @@ -230,14 +276,14 @@ class TensorHandler(StreamHandlerBase): return missed_tensors - def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type): + def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type, step=None): """Update has_prev_step field in tensor info.""" - missing_tensors_info = self._get_missing_tensor_info(tensor_name, node_type) - if not missing_tensors_info and node_type == NodeTypeEnum.PARAMETER.value and self.cur_step > 0: + missing_tensors_info = self._get_missing_tensor_info(tensor_name, node_type, step) + if not missing_tensors_info and node_type == NodeTypeEnum.PARAMETER.value and step > 0: tensor_info['has_prev_step'] = True return missing_tensors_info - def _get_missing_tensor_info(self, tensor_name, node_type): + def _get_missing_tensor_info(self, tensor_name, node_type, step): """ Get missing tensor infos. @@ -248,7 +294,6 @@ class TensorHandler(StreamHandlerBase): Returns: list, list of missing tensor basic information. """ - step = self.cur_step missing_tensors_info = [] # check the current step value is missing if self._is_tensor_value_missing(tensor_name, step): @@ -278,13 +323,13 @@ class TensorHandler(StreamHandlerBase): tensor = self._get_tensor(tensor_name, step=step) return bool(not tensor or tensor.empty) - def get_valid_tensor_by_name(self, tensor_name, prev=False): + def get_valid_tensor_by_name(self, tensor_name, step, prev=False): """Get tensor value by name in numpy type.""" - step = self.prev_step if prev else self.cur_step - if step < 0: - log.warning("%d step has no previous value for tensor: %s", self.cur_step, tensor_name) + target_step = step - 1 if prev else step + if target_step < 0: + log.warning("Step %d has no previous value for tensor: %s", target_step, tensor_name) return None - tensor = self._get_tensor(tensor_name, step=step) + tensor = self._get_tensor(tensor_name, step=target_step) if tensor and tensor.empty: log.warning("%s has empty value.", tensor_name) return None @@ -316,9 +361,9 @@ class TensorHandler(StreamHandlerBase): self._tensors.pop(param) log.debug("Clean param %s in cache.", param) - def get_tensors_diff(self, tensor_name, shape, tolerance=0): + def get_tensors_diff(self, tensor_name, shape, tolerance=0, step=None): """ - Get tensor comparisons data for given name, detail, shape and tolerance. + Get tensor comparisons data for given name, detail, shape and tolerance. Args: tensor_name (str): The name of tensor for cache. @@ -329,6 +374,7 @@ class TensorHandler(StreamHandlerBase): calculate the min value and max value of the result of the current step tensor subtract the previous step tensor. If the absolute value of result is less than or equal to boundary value, the result will set to be zero. + step (int): The step of the tensor. Default: None. Raises: DebuggerParamValueError, If get current step node and previous step node failed or @@ -337,8 +383,8 @@ class TensorHandler(StreamHandlerBase): Returns: dict, the retrieved data. """ - curr_tensor = self.get_valid_tensor_by_name(tensor_name) - prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True) + curr_tensor = self.get_valid_tensor_by_name(tensor_name, step=step) + prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True, step=step) if not (curr_tensor and prev_tensor): log.error("Get current step and previous step for this tensor name %s failed.", tensor_name) raise DebuggerParamValueError(f"Get current step and previous step for this tensor name " @@ -386,22 +432,23 @@ class TensorHandler(StreamHandlerBase): stats_info['statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=diff_tensor_stats) return stats_info - def get_tensor_info_for_tensor_graph(self, tensor_name, node_type): + def get_tensor_info_for_tensor_graph(self, tensor_name, node_type, step): """ Get Tensor info for tensor graphs. Args: tensor_name (str): Tensor name, format like `node_name:slot`. node_type (str): Node type. + step (int): The step of tensor info. Returns: dict, tensor infos, including overall statistics, tensor shape and has_prev_step info. list, list of missing tensor basic information. """ res = {} - tensor = self._get_tensor(tensor_name, node_type) + tensor = self._get_tensor(tensor_name, node_type, step) if tensor and not tensor.empty: res['statistics'] = tensor.get_tensor_statistics() res['shape'] = tensor.shape - missing_tensors = self._update_has_prev_step_field(res, tensor_name, node_type) + missing_tensors = self._update_has_prev_step_field(res, tensor_name, node_type, step) return res, missing_tensors diff --git a/mindinsight/debugger/stream_handler/watchpoint_handler.py b/mindinsight/debugger/stream_handler/watchpoint_handler.py index 2c18f474..ed76891f 100644 --- a/mindinsight/debugger/stream_handler/watchpoint_handler.py +++ b/mindinsight/debugger/stream_handler/watchpoint_handler.py @@ -105,12 +105,12 @@ class WatchpointHandler(StreamHandlerBase): return {'watch_points': reply} - def get_pending_commands(self, graph_stream): + def get_pending_commands(self, multi_card_graph_stream): """ Get all watchpoint in SetCMD proto format. Args: - graph_stream (GraphHandler): Graph handler. + multi_card_graph_stream (MultiCardGraphHandler): Multi card graph handler. Returns: list[SetCMD], updated watchpoint to be sent to MindSpore. @@ -118,9 +118,13 @@ class WatchpointHandler(StreamHandlerBase): newly_set_cmds = [] for _, watchpoint in self._updated_watchpoints.items(): # construct set command with leaf nodes - watch_nodes = watchpoint.get_watch_nodes() - leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes) - newly_set_cmds.append(watchpoint.get_pending_cmd(leaf_watch_nodes)) + watch_nodes_for_devices = watchpoint.get_watch_nodes() + leaf_watch_nodes_for_devices = {} + for rank_id, watch_nodes in watch_nodes_for_devices.items(): + graph_stream = multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id) + leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes) + leaf_watch_nodes_for_devices[rank_id] = leaf_watch_nodes + newly_set_cmds.append(watchpoint.get_pending_cmd(leaf_watch_nodes_for_devices)) newly_set_cmds.extend(self._deleted_watchpoints) self.sync_set_cmd(newly_set_cmds) @@ -161,7 +165,7 @@ class WatchpointHandler(StreamHandlerBase): """ return self._outdated - def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None): + def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None, rank_id=0): """ set watch nodes for graph. @@ -170,23 +174,24 @@ class WatchpointHandler(StreamHandlerBase): graph_stream (GraphHandler): The graph handler. watch_point_id (int): The id of watchpoint. graph_name (str): The graph name. + rank_id (int): The rank id. """ if not (watch_point_id and graph): return log.debug("add watch flags") watchpoint = self._watchpoints.get(watch_point_id) - self._set_watch_status_recursively(graph, graph_stream, watchpoint, graph_name) + self._set_watch_status_recursively(graph, graph_stream, watchpoint, graph_name, rank_id) - def _set_watch_status_recursively(self, graph, graph_stream, watchpoint, graph_name=None): + def _set_watch_status_recursively(self, graph, graph_stream, watchpoint, graph_name=None, rank_id=0): """Set watch status to graph.""" if graph.get('children'): self._set_watch_status_recursively( - graph.get('children'), graph_stream, watchpoint, graph_name) + graph.get('children'), graph_stream, watchpoint, graph_name, rank_id=0) if graph.get('nodes'): - _ = self._set_watch_state_for_nodes(graph['nodes'], graph_stream, watchpoint, graph_name) + _ = self._set_watch_state_for_nodes(graph['nodes'], graph_stream, watchpoint, graph_name, rank_id) - def _set_watch_state_for_nodes(self, nodes, graph_stream, watchpoint, graph_name): + def _set_watch_state_for_nodes(self, nodes, graph_stream, watchpoint, graph_name, rank_id=0): """ Set watch state for nodes. @@ -204,11 +209,11 @@ class WatchpointHandler(StreamHandlerBase): node_name = node.get('name') # search result could have `nodes` in nodes object if node.get('nodes'): - flag = self._set_watch_state_for_nodes(node.get('nodes'), graph_stream, watchpoint, graph_name) + flag = self._set_watch_state_for_nodes(node.get('nodes'), graph_stream, watchpoint, graph_name, rank_id) else: full_name = graph_stream.get_full_name(node_name, graph_name) new_node_name = node_name if graph_name is None else '/'.join([graph_name, node_name]) - flag = watchpoint.get_node_status(new_node_name, node.get('type'), full_name) + flag = watchpoint.get_node_status(new_node_name, node.get('type'), full_name, rank_id) node['watched'] = flag if flag == WatchNodeTree.NOT_WATCH: continue @@ -224,7 +229,8 @@ class WatchpointHandler(StreamHandlerBase): state = WatchNodeTree.TOTAL_WATCH return state - def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None, name=None): + def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None, name=None, + device_amount=8): """ Create watchpoint. Args: @@ -241,9 +247,10 @@ class WatchpointHandler(StreamHandlerBase): } - id (str): Id of condition. - param (list[dict]): The list of param for this condition. - watch_nodes (list[NodeBasicInfo]): The list of node basic info. + watch_nodes (dict[list[NodeBasicInfo]]): The list of node basic info. watch_point_id (int): The id of watchpoint. name (str): The name of watchpoint. + device_amount (int): The amount of devices. Returns: int, the new id of watchpoint. @@ -253,7 +260,9 @@ class WatchpointHandler(StreamHandlerBase): new_id = self._latest_id + 1 watchpoint = Watchpoint(new_id, watch_condition, name) if watch_nodes: - watchpoint.add_nodes(watch_nodes) + for rank_id, watch_nodes_for_device in watch_nodes.items(): + validate_rank_id(rank_id, device_amount) + watchpoint.add_nodes(watch_nodes_for_device, rank_id) elif watch_point_id: self.validate_watchpoint_id(watch_point_id) watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id)) @@ -261,7 +270,7 @@ class WatchpointHandler(StreamHandlerBase): self._outdated = True return new_id - def update_watchpoint(self, watch_point_id, watch_nodes, watched=False): + def update_watchpoint(self, watch_point_id, watch_nodes, watched=False, rank_id=0): """ Update watchpoint. @@ -270,13 +279,14 @@ class WatchpointHandler(StreamHandlerBase): watch_nodes (list[NodeBasicInfo]): The list of node basic info. watched (bool): The update operator on nodes. If False, remove nodes from watch nodes. If True, add nodes to watch nodes. Default: False. + rank_id (int): The rank id. """ self.validate_watchpoint_id(watch_point_id) watchpoint = self._watchpoints.get(watch_point_id) if watched: - watchpoint.add_nodes(watch_nodes) + watchpoint.add_nodes(watch_nodes, rank_id) else: - watchpoint.remove_nodes(watch_nodes) + watchpoint.remove_nodes(watch_nodes, rank_id) self._updated_watchpoints[watch_point_id] = watchpoint self._outdated = True log.debug("Update watchpoint %d in cache.", watch_point_id) @@ -328,6 +338,58 @@ class WatchpointHandler(StreamHandlerBase): raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id)) +class MultiCardWatchpointHitHandler: + """Multi-card Watchpoint-hit Handler.""" + + def __init__(self): + self.watchpoint_hit_handlers = {0: WatchpointHitHandler()} + + def get_hit_handler_by_rank_id(self, rank_id=0): + """Get handler by rank id.""" + if rank_id in self.watchpoint_hit_handlers: + return self.watchpoint_hit_handlers.get(rank_id) + log.error("There is no rank id %d.", rank_id) + raise ValueError + + def put(self, value): + """Put watchpoint hit into cache.""" + for rank_id, tensor_hit_values in value.items(): + if rank_id not in self.watchpoint_hit_handlers: + self.watchpoint_hit_handlers[rank_id] = WatchpointHitHandler() + cur_hit_handler = self.watchpoint_hit_handlers[rank_id] + for tensor_hit_value in tensor_hit_values: + cur_hit_handler.put(tensor_hit_value) + + def get(self, filter_condition=None, rank_id=0): + """Get the graph of specific node for specific device.""" + if rank_id in self.watchpoint_hit_handlers: + return self.watchpoint_hit_handlers.get(rank_id).get(filter_condition) + log.error("There is no rank id %d.", rank_id) + raise ValueError + + def update_tensor_history(self, tensor_history, rank_id): + """ + Add hit flag to tensor history. + + Args: + tensor_history (dict): The tensor history. + rank_id (int): The rank id. + """ + if rank_id in self.watchpoint_hit_handlers: + self.watchpoint_hit_handlers[rank_id].update_tensor_history(tensor_history) + else: + for tensor_info in tensor_history.get('tensor_history'): + tensor_info['is_hit'] = False + + def check_rank_id(self, rank_id): + """check if has the rank id.""" + return rank_id in self.watchpoint_hit_handlers + + def clean(self): + """Clean cache.""" + self.__init__() + + class WatchpointHitHandler(StreamHandlerBase): """Watchpoint hit handler.""" @@ -743,3 +805,9 @@ def _get_error_list(error_code): error_list.append(error_str) return error_list + + +def validate_rank_id(rank_id, device_amount): + """validate rank id""" + if rank_id >= device_amount: + log.debug("The rank id %d over device amount.", rank_id) diff --git a/mindinsight/debugger/stream_operator/tensor_detail_info.py b/mindinsight/debugger/stream_operator/tensor_detail_info.py index 72a15cd4..16991bb9 100644 --- a/mindinsight/debugger/stream_operator/tensor_detail_info.py +++ b/mindinsight/debugger/stream_operator/tensor_detail_info.py @@ -23,17 +23,19 @@ class TensorDetailInfo: def __init__(self, cache): self._put_command = cache.put_command - self._tensor_stream = cache.get_stream_handler(Streams.TENSOR) - self._graph_stream = cache.get_stream_handler(Streams.GRAPH) - self._hit_stream = cache.get_stream_handler(Streams.WATCHPOINT_HIT) + self._metadata_stream = cache.get_stream_handler(Streams.METADATA) + self._multi_card_tensor_stream = cache.get_stream_handler(Streams.TENSOR) + self._multi_card_graph_stream = cache.get_stream_handler(Streams.GRAPH) + self._multi_card_hit_stream = cache.get_stream_handler(Streams.WATCHPOINT_HIT) - def validate_tensor_name(self, tensor_name, graph_name): + def validate_tensor_name(self, tensor_name, graph_name, rank_id): """ Get the graph id of the tensor. Args: tensor_name (str): The tensor name on UI. graph_name (str): The graph name. + rank_id (int): The rank id. """ # validate tensor name format if not isinstance(tensor_name, str) or ':' not in tensor_name: @@ -41,15 +43,17 @@ class TensorDetailInfo: raise DebuggerParamValueError("Invalid tensor name.") node_name, _ = tensor_name.rsplit(':', 1) # check if the node name is in graph - self._graph_stream.validate_node_name(node_name=node_name, graph_name=graph_name) + self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).validate_node_name(node_name=node_name, + graph_name=graph_name) - def get_tensor_graph(self, tensor_name, graph_name): + def get_tensor_graph(self, tensor_name, graph_name, rank_id=0): """ Get the graph related to specific tensor. Args: tensor_name (str): The ui name of tensor. Format like {node_name}:{slot}. graph_name (str): The graph name. + rank_id (int): The rank id. Returns: dict, tensor graph, format is {'nodes': [Node object]}. @@ -68,8 +72,9 @@ class TensorDetailInfo: 'slot_mapping': list[pair], }. """ - self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name) - graph = self._graph_stream.get_tensor_graph(tensor_name, graph_name) + self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name, rank_id=rank_id) + graph = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_tensor_graph(tensor_name, + graph_name) # add watchpoint hits info and statistics info for each tensor in tensor graph. # record missing tensor basic info nodes = graph.get('graph', {}).get('nodes', []) @@ -77,13 +82,13 @@ class TensorDetailInfo: for node in nodes: node['graph_name'] = graph_name for slot_info in node.get('slots', []): - self._add_watchpoint_hit_info(slot_info, node, graph_name) - self._add_tensor_info(slot_info, node, missing_tensors) + self._add_watchpoint_hit_info(slot_info, node, graph_name, rank_id) + self._add_tensor_info(slot_info, node, missing_tensors, rank_id) # query missing tensor values from client self._ask_for_missing_tensor_value(missing_tensors, tensor_name, graph_name) return graph - def _add_watchpoint_hit_info(self, slot_info, node, graph_name): + def _add_watchpoint_hit_info(self, slot_info, node, graph_name, rank_id): """ Add watchpoint hit info for the tensor. @@ -93,9 +98,12 @@ class TensorDetailInfo: graph_name (str): Graph name. """ tensor_name = ':'.join([node.get('name'), slot_info.get('slot')]) - slot_info.update(self._hit_stream.get_tensor_hit_infos(tensor_name, graph_name)) + if self._multi_card_hit_stream.check_rank_id(rank_id=rank_id): + slot_info.update( + self._multi_card_hit_stream.get_hit_handler_by_rank_id(rank_id).get_tensor_hit_infos(tensor_name, + graph_name)) - def _add_tensor_info(self, slot_info, node, missing_tensors): + def _add_tensor_info(self, slot_info, node, missing_tensors, rank_id): """ Add the tensor info and query for missed tensors. @@ -106,7 +114,8 @@ class TensorDetailInfo: """ tensor_name = ':'.join([node.get('full_name'), slot_info.get('slot')]) node_type = node.get('type') - tensor_info, cur_missing_tensors = self._tensor_stream.get_tensor_info_for_tensor_graph(tensor_name, node_type) + tensor_info, cur_missing_tensors = self._multi_card_tensor_stream.get_tensor_handler_by_rank_id( + rank_id).get_tensor_info_for_tensor_graph(tensor_name, node_type, self._metadata_stream.step) slot_info.update(tensor_info) if cur_missing_tensors: log.debug("Get missing tensor basic infos for %s", tensor_name) @@ -128,20 +137,24 @@ class TensorDetailInfo: self._put_command({'view_cmd': view_cmd, 'tensor_name': tensor_name, 'graph_name': graph_name}) log.debug("Send view cmd for tensor-graphs.") - def get_tensor_watch_points(self, tensor_name, graph_name): + def get_tensor_watch_points(self, tensor_name, graph_name, rank_id=0): """ Get all watchpoints that the tensor hit. Args: tensor_name (str): Tensor name from UI. graph_name (str): The graph name. + rank_id (int): The rank id. Returns: list, watchpoint hit infos. """ # validate tensor_name - self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name) + self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name, rank_id=rank_id) # get watchpoint info that the tensor hit - tensor_hit_info = self._hit_stream.get_tensor_hit_infos(tensor_name, graph_name) + if not self._multi_card_hit_stream.check_rank_id(rank_id=rank_id): + return [] + tensor_hit_info = self._multi_card_hit_stream.get_hit_handler_by_rank_id(rank_id).get_tensor_hit_infos( + tensor_name, graph_name) watch_points = tensor_hit_info.get('watch_points', []) return watch_points diff --git a/mindinsight/debugger/stream_operator/training_control_operator.py b/mindinsight/debugger/stream_operator/training_control_operator.py index f19c9c3d..5fa49924 100644 --- a/mindinsight/debugger/stream_operator/training_control_operator.py +++ b/mindinsight/debugger/stream_operator/training_control_operator.py @@ -18,7 +18,8 @@ import enum from mindinsight.debugger.common.exceptions.exceptions import DebuggerContinueError, DebuggerParamValueError, \ DebuggerPauseError, DebuggerRecheckError, DebuggerStepNumError from mindinsight.debugger.common.log import LOGGER as log -from mindinsight.debugger.common.utils import Streams, get_ack_reply, ServerStatus, RunLevel, is_scope_type +from mindinsight.debugger.common.utils import Streams, get_ack_reply, ServerStatus, RunLevel, is_scope_type, \ + DebuggerServerMode from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD from mindinsight.utils.exceptions import MindInsightException @@ -29,6 +30,7 @@ class ControlTypeEnum(enum.Enum): CONTINUE = 'continue' # continue to run training PAUSE = 'pause' # suspend training TERMINATE = 'terminate' # terminate training + RESET = 'reset' # reset the step_id in offline debugger class TrainingControlOperator: @@ -39,7 +41,7 @@ class TrainingControlOperator: def __init__(self, cache_store): self._cache_store = cache_store self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT) - self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH) + self._multi_card_graph_stream = cache_store.get_stream_handler(Streams.GRAPH) self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) @staticmethod @@ -71,6 +73,9 @@ class TrainingControlOperator: """ if mode == ControlTypeEnum.CONTINUE.value: reply = self.continue_training(params) + elif mode == ControlTypeEnum.RESET.value: + step_id = params['steps'] + reply = self.reset_training_step(step_id) else: mode_mapping = { ControlTypeEnum.PAUSE.value: self.pause_training, @@ -150,13 +155,15 @@ class TrainingControlOperator: if level == RunLevel.NODE.value: node_name = params.get('name') graph_name = params.get('graph_name') - self._validate_continue_node_name(node_name, graph_name) + rank_id = params.get('rank_id', 0) + self._validate_continue_node_name(node_name, graph_name, rank_id) - def _validate_continue_node_name(self, node_name, graph_name): + def _validate_continue_node_name(self, node_name, graph_name, rank_id): """Validate if the node is a leaf node.""" if not node_name: return - node_type = self._graph_stream.get_node_type(node_name, graph_name) + node_type = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_node_type(node_name, + graph_name) if is_scope_type(node_type): log.error("Scope type node has no tensor history.") raise DebuggerParamValueError("Invalid leaf node name.") @@ -188,7 +195,9 @@ class TrainingControlOperator: name = params.get('name', '') graph_name = params.get('graph_name') if name: - name = self._cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name, graph_name) + rank_id = params.get('rank_id', 0) + name = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_full_name(name, + graph_name) run_cmd = RunCMD(run_level='node', node_name=name) else: run_cmd = RunCMD(run_level='recheck') @@ -199,7 +208,7 @@ class TrainingControlOperator: def _send_watchpoints(self): """Send watchpoints to client.""" - set_commands = self._watchpoint_stream.get_pending_commands(self._graph_stream) + set_commands = self._watchpoint_stream.get_pending_commands(self._multi_card_graph_stream) if not set_commands: return for set_cmd in set_commands: @@ -274,3 +283,30 @@ class TrainingControlOperator: else: log.debug("Send the recheck to command queue.") return metadata_stream.get(['state', 'enable_recheck']) + + def reset_training_step(self, step_id): + """ + Reset the training step. + + Args: + step_id (int): The target step_id. + + Returns: + dict, metadata info. + """ + metadata_stream = self._metadata_stream + if metadata_stream.debugger_type == DebuggerServerMode.ONLINE.value: + log.error("'step_id' can not be changed manually in online debugger.") + return metadata_stream.get(['state', 'enable_recheck', 'step']) + if step_id > metadata_stream.max_step_num: + log.error("Invalid step_id, step_id should be less than %d.", metadata_stream.max_step_num) + raise DebuggerParamValueError("Invalid step_id.") + metadata_stream.state = ServerStatus.SENDING.value + metadata_stream.step = step_id + self._cache_store.get_stream_handler(Streams.TENSOR).set_step(step_id) + self._cache_store.clean_data() + self._cache_store.clean_command() + metadata_stream.enable_recheck = False + metadata_stream.state = ServerStatus.WAITING.value + log.debug("Send the Change_training_step CMD.") + return metadata_stream.get(['state', 'enable_recheck', 'step']) diff --git a/mindinsight/debugger/stream_operator/watchpoint_operator.py b/mindinsight/debugger/stream_operator/watchpoint_operator.py index 63157c3c..cb011535 100644 --- a/mindinsight/debugger/stream_operator/watchpoint_operator.py +++ b/mindinsight/debugger/stream_operator/watchpoint_operator.py @@ -31,8 +31,9 @@ class WatchpointOperator: def __init__(self, cache_store, condition_mgr): self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT) - self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH) + self._multi_card_graph_stream = cache_store.get_stream_handler(Streams.GRAPH) self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) + self._device_stream = cache_store.get_stream_handler(Streams.DEVICE) self._condition_mgr = condition_mgr def create_watchpoint(self, params): @@ -70,11 +71,6 @@ class WatchpointOperator: "Failed to create watchpoint as the MindSpore is not in waiting state.") self._validate_watch_condition(watch_condition) - watch_nodes = self._get_watch_node_with_basic_info( - node_names=params.get('watch_nodes'), - search_pattern=params.get('search_pattern'), - graph_name=params.get('graph_name')) - validate_watch_condition(self._condition_mgr, watch_condition) condition_id = watch_condition.get('id') condition = self._condition_mgr.get_condition(condition_id) @@ -84,10 +80,11 @@ class WatchpointOperator: raise DebuggerConditionUnavailableError( "Failed to create watchpoint as the condition is not available.") - watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._graph_stream).copy() + watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._multi_card_graph_stream) watchpoint_stream = self._watchpoint_stream - watch_point_id = watchpoint_stream.create_watchpoint( - self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id')) + watch_point_id = watchpoint_stream.create_watchpoint(self._condition_mgr, watch_condition, watch_nodes, + params.get('watch_point_id'), + self._device_stream.device_amount) log.info("Create watchpoint %d", watch_point_id) metadata_stream.enable_recheck = watchpoint_stream.is_recheckable() @@ -115,6 +112,7 @@ class WatchpointOperator: 1 for add nodes to watch nodes. - search_pattern (dict): The search pattern. - graph_name (str): The relative graph_name of the watched node. + - rank_id (int): The rank id. Returns: dict, the metadata info. @@ -137,13 +135,14 @@ class WatchpointOperator: watch_nodes = self._get_watch_node_with_basic_info( node_names=params.get('watch_nodes'), search_pattern=params.get('search_pattern'), - graph_name=params.get('graph_name')) - watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, params.get('mode')) + graph_name=params.get('graph_name'), + rank_id=params.get('rank_id', 0)) + watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, params.get('mode'), params.get('rank_id', 0)) metadata_stream.enable_recheck = watchpoint_stream.is_recheckable() log.info("Update watchpoint with id: %d", watch_point_id) return metadata_stream.get(['state', 'enable_recheck']) - def _get_watch_node_with_basic_info(self, node_names, search_pattern=None, graph_name=None): + def _get_watch_node_with_basic_info(self, node_names, search_pattern=None, graph_name=None, rank_id=0): """ Get watch node with basic info. @@ -151,20 +150,21 @@ class WatchpointOperator: node_names (list[str]): A list of node names. search_pattern (dict): Get watch node with search pattern. Default: None graph_name (str): The relative graph_name of the watched node. Default: None. + rank_id (int): The rank id. Returns: list[NodeBasicInfo], a list of node basic infos. """ if not node_names: return [] - graph_name = self._graph_stream.validate_graph_name(graph_name) + graph_name = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).validate_graph_name(graph_name) if search_pattern is not None: - watch_nodes = self._get_watch_nodes_by_search(node_names, search_pattern, graph_name) + watch_nodes = self._get_watch_nodes_by_search(node_names, search_pattern, graph_name, rank_id) else: - watch_nodes = self._get_node_basic_infos(node_names, graph_name=graph_name) + watch_nodes = self._get_node_basic_infos(node_names, graph_name=graph_name, rank_id=rank_id) return watch_nodes - def _get_watch_nodes_by_search(self, node_names, search_pattern, graph_name): + def _get_watch_nodes_by_search(self, node_names, search_pattern, graph_name, rank_id): """ Get watched leaf nodes by search name. @@ -180,7 +180,7 @@ class WatchpointOperator: list[NodeBasicInfo], a list of node basic infos. """ search_pattern['graph_name'] = graph_name - search_nodes = self._graph_stream.search_nodes(search_pattern) + search_nodes = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).search_nodes(search_pattern) watch_node_names = set() for name in node_names: names = self._get_watch_names_by_search(search_nodes, name) @@ -260,7 +260,7 @@ class WatchpointOperator: log.info("Delete watchpoint with id: %s", watch_point_id) return metadata_stream.get(['state', 'enable_recheck']) - def _get_node_basic_infos(self, node_names, graph_name=None): + def _get_node_basic_infos(self, node_names, graph_name=None, rank_id=0): """ Get watch node info according to node names. @@ -273,7 +273,7 @@ class WatchpointOperator: """ if not node_names: return [] - graph_stream = self._graph_stream + graph_stream = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id) node_infos = [] for node_name in node_names: node_info = graph_stream.get_node_basic_info(node_name, graph_name) diff --git a/mindinsight/ui/src/app.vue b/mindinsight/ui/src/app.vue index 2949a6fd..72cd11c2 100644 --- a/mindinsight/ui/src/app.vue +++ b/mindinsight/ui/src/app.vue @@ -26,7 +26,7 @@ limitations under the License.
- +
diff --git a/mindinsight/ui/src/components/debugger-tensor.vue b/mindinsight/ui/src/components/debugger-tensor.vue index 777f80c8..ce00df2b 100644 --- a/mindinsight/ui/src/components/debugger-tensor.vue +++ b/mindinsight/ui/src/components/debugger-tensor.vue @@ -362,8 +362,9 @@ export default { const params = { tensor_name: this.curRowObj.name, graph_name: this.curRowObj.graph_name, + rank_id: this.curRowObj.rank_id, }; - RequestService.getTensorGraphData(params).then( + RequestService.getTensorGraphData(params, this.curRowObj.sessionId).then( (res) => { if (res && res.data && res.data.graph && res.data.graph.nodes && res.data.graph.nodes.length) { this.graphShow = true; @@ -419,8 +420,9 @@ export default { const params = { tensor_name: this.curRowObj.name, graph_name: this.curRowObj.graph_name, + rank_id: this.curRowObj.rank_id, }; - RequestService.tensorHitsData(params).then( + RequestService.tensorHitsData(params, this.curRowObj.sessionId).then( (res) => { if (res && res.data && res.data.watch_points && res.data.watch_points.length) { this.leftDataShow = true; @@ -995,11 +997,12 @@ export default { shape: encodeURIComponent(shape), tolerance: this.tolerance / 100, graph_name: row.graph_name, + rank_id: row.rank_id, }; if (loadingFlag) { this.loadingInstance = this.$loading(this.loadingOption); } - RequestService.tensorComparisons(params).then( + RequestService.tensorComparisons(params, row.sessionId).then( (res) => { if (res && res.data && res.data.tensor_value) { if (row.shape === '[]') { @@ -1088,11 +1091,12 @@ export default { shape: encodeURIComponent(shape), graph_name: row.graph_name, prev: this.gridType === 'preStep' ? true : false, + rank_id: row.rank_id, }; if (loadingFlag) { this.loadingInstance = this.$loading(this.loadingOption); } - RequestService.tensors(params).then( + RequestService.tensors(params, row.sessionId).then( (res) => { if (row.shape === '[]') { this.showFilterInput = false; diff --git a/mindinsight/ui/src/locales/en-us.json b/mindinsight/ui/src/locales/en-us.json index f570858f..61918b35 100644 --- a/mindinsight/ui/src/locales/en-us.json +++ b/mindinsight/ui/src/locales/en-us.json @@ -24,7 +24,9 @@ "dataLoading": "Loading data...", "notice": "Information", "caseMode": "Not case sensitive", - "all": "All" + "all": "All", + "details": "Details", + "delete": "Delete" }, "symbols": { "leftbracket": "(", @@ -52,12 +54,14 @@ "operation": "Operation", "viewDashboard": "Training Dashboard", "viewProfiler": "Profiling", + "viewOfflineDebugger": "Offline Debugger", "modelTraceback": "Model Lineage", "dataTraceback": "Dataset Lineage", "comparePlate": "Comparison Dashboard", "disableProfilerTip": "Failed to view profiling because no profiler log is available.", "disableDashboardTip": "Failed to view training dashboard because no summary log or pb files are available.", "disableParameterTip": "Failed to view parameter details because no lineage log is available.", + "disableOfflineDebugger": "Failed to view offline debugger because no debugger log is available.", "openNewTab": "Open Link in New Tab", "paramDetails": "Parameter Details", "trainingParamDetails": "Training Parameter Details", @@ -80,7 +84,12 @@ "tensorUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#tensor-visualization", "graphUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#computational-graph-visualization", "dataProcessUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#dataset-graph-visualization", - "imageUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#image-visualization" + "imageUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#image-visualization", + "sessionLimit": "The number of sessions of the offline debugger exceeds the number of online sessions", + "sessionLimitNum": "At most 2 exist at the same time", + "sessionLists": "List of currently existing sessions", + "deleteSessionConfirm": "This operation will delete the current session, do you want to continue?", + "deleteSessionSuccess": "Delete session successfully!" }, "modelTraceback": { "summaryPath": "Summary Path", @@ -561,7 +570,7 @@ "terminate": "TERMINATE", "selectCondition": "Select a condition", "inputStep": "Enter a step value", - "inputTip": "A positive integer less than 2147483648", + "inputTip": "A positive integer less than or equal to {total_step_num}", "curHitNode": "Watch Point Hit List", "backstageStatus": "The backend running status is ", "view": "View", @@ -830,7 +839,9 @@ "allPositive": "he parameter value must be greater than 0.", "watchOverflow": "The asynchronous full overflow watching function must be enabled before the training starts." }, - "paramValueTip": "Preset Value: {value}" + "paramValueTip": "Preset Value: {value}", + "logicCard": "Logic card", + "inpStepTip": "Step:0~{total_step_num}" }, "explain": { "explain": "Model Explanation", @@ -952,6 +963,7 @@ "5054B183": "Backend training is in progress or has ended. Please try again later", "5054B184": "The operation is too fast, the backend service has been suspended.", "5054B189": "Do not set the value repeatedly.", - "5054B083": "Failed to create the watchpoint. Do not use invalid rules." + "5054B083": "Failed to create the watchpoint. Do not use invalid rules.", + "5054B202": "The debugger offline server module was not found" } } \ No newline at end of file diff --git a/mindinsight/ui/src/locales/zh-cn.json b/mindinsight/ui/src/locales/zh-cn.json index 3949c452..3c03f592 100644 --- a/mindinsight/ui/src/locales/zh-cn.json +++ b/mindinsight/ui/src/locales/zh-cn.json @@ -24,7 +24,9 @@ "dataLoading": "数据加载中", "notice": "提示", "caseMode": "不区分大小写", - "all": "全部" + "all": "全部", + "details": "详情", + "delete": "删除" }, "symbols": { "leftbracket": "(", @@ -52,12 +54,14 @@ "operation": "操作", "viewDashboard": "训练看板", "viewProfiler": "性能分析", + "viewOfflineDebugger": "离线调试器", "modelTraceback": "模型溯源", "dataTraceback": "数据溯源", "comparePlate": "对比看板", "disableProfilerTip": "无profiler日志,无法查看性能分析", "disableDashboardTip": "无summary日志或pb文件,无法查看训练看板", "disableParameterTip": "无lineage日志,无法查看参数详情", + "disableOfflineDebugger": "无Debugger日志,无法查看离线调试器", "openNewTab": "打开新页签", "paramDetails": "参数详情", "trainingParamDetails": "训练参数详情", @@ -80,7 +84,12 @@ "tensorUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id8", "graphUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id5", "dataProcessUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id6", - "imageUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id7" + "imageUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id7", + "sessionLimit": "离线调试器的session个数超过上线", + "sessionLimitNum": "最多同时存在2个", + "sessionLists": "目前存在的session列表", + "deleteSessionConfirm": "此操作将删除当前session, 是否继续?", + "deleteSessionSuccess": "删除session成功!" }, "modelTraceback": { "summaryPath": "训练日志路径", @@ -560,7 +569,7 @@ "terminate": "结束", "selectCondition": "请选择条件", "inputStep": "请输入轮次值", - "inputTip": "小于2147483648的正整数", + "inputTip": "小于等于{total_step_num}的正整数", "curHitNode": "命中的监测点", "backstageStatus": "后台运行状态是", "view": "查看", @@ -825,7 +834,9 @@ "allPositive": "此参数值必须大于0", "watchOverflow": "训练开始前需开启异步全量溢出监测功能" }, - "paramValueTip": "设置值为:{value}" + "paramValueTip": "设置值为:{value}", + "logicCard": "逻辑卡", + "inpStepTip": "可输入当前轮次:0~{total_step_num}" }, "explain": { "explain": "模型解释", @@ -947,6 +958,7 @@ "5054B183": "后台训练运行中,请稍后重试", "5054B184": "操作过快,后台服务已暂停。", "5054B189": "请勿重复设置。", - "5054B083": "监测点创建失败,请勿使用已失效规则。" + "5054B083": "监测点创建失败,请勿使用已失效规则。", + "5054B202": "未找到调试器离线服务器模块" } } \ No newline at end of file diff --git a/mindinsight/ui/src/mixins/debugger-mixin.vue b/mindinsight/ui/src/mixins/debugger-mixin.vue index 6392e099..3a208a98 100644 --- a/mindinsight/ui/src/mixins/debugger-mixin.vue +++ b/mindinsight/ui/src/mixins/debugger-mixin.vue @@ -25,6 +25,65 @@ export default { }; }, methods: { + editStep() { + this.isShowInp = true; + this.newStep = this.metadata.step; + }, + newStepChange(val) { + if (val === '') { + return; + } + val = val.replace(/[^0-9]+/g, ''); + if (Number(val) <= this.metadata.total_step_num) { + this.newStep = Number(val); + } else { + this.newStep = this.metadata.total_step_num; + } + }, + saveStepValue() { + this.isShowInp = false; + if (this.newStep === '' || this.newStep === this.metadata.step) { + return; + } + this.metadata.step = this.newStep; + const params = { + mode: 'reset', + level: 'step', + steps: this.metadata.step, + graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, + }; + if (this.graphFiles.value === this.$t('debugger.all')) { + delete params.graph_name; + } + RequestService.control(params, this.sessionId).then( + (res) => { + this.queryTensorHistory(); + }, + (err) => { + this.showErrorMsg(err); + }, + ); + }, + getSession() { + const params = { + dump_dir: null, + session_type: 'ONLINE', + }; + RequestService.getSession(params).then((res) => { + if (res) { + this.sessionId = res.data; + this.retrieveAll(); + } + }); + }, + deleteSession() { + RequestService.deleteSession(this.sessionId).then((res) => { + this.$router.push({ + path: '/summary-manage', + }); + }); + }, handleCurrentChange(page) { this.pagination.currentPage = page; this.searchWatchpointHits(false); @@ -37,10 +96,13 @@ export default { * Initialize the condition */ initCondition() { - if (this.metadata.state === this.state.running || this.metadata.state === this.state.sending) { + if ( + this.metadata.state === this.state.running || + this.metadata.state === this.state.sending + ) { return; } - RequestService.queryConditions(this.trainId).then((res) => { + RequestService.queryConditions(this.sessionId).then((res) => { if (res && res.data) { this.conditionCollections = res.data; this.addWatchPoint(); @@ -85,7 +147,7 @@ export default { if (this.step === '') { return; } - const maxStep = 2147483648; + const maxStep = this.metadata.total_step_num; this.step = this.step .toString() .replace(/[^\.\d]/g, '') @@ -95,7 +157,7 @@ export default { this.step = 1; } if (this.step >= maxStep) { - this.step = maxStep - 1; + this.step = maxStep; } }, /** @@ -123,7 +185,7 @@ export default { } params.params.graph_name = this.graphFiles.value; } - RequestService.retrieve(params).then( + RequestService.retrieve(params, this.sessionId).then( (res) => { if (res.data) { if (res.data.metadata) { @@ -136,7 +198,11 @@ export default { this.origialTree = graph.nodes.map((val) => { return { label: val.name.split('/').pop(), - leaf: val.type === 'name_scope' || val.type === 'aggregation_scope' ? false : true, + leaf: + val.type === 'name_scope' || + val.type === 'aggregation_scope' + ? false + : true, ...val, showCheckbox: val.watched !== -1, }; @@ -156,10 +222,22 @@ export default { this.allGraphData = {}; d3.select('#graph svg').remove(); this.selectedNode.name = ''; - this.packageDataToObject('', true, JSON.parse(JSON.stringify(graph.nodes))); - this.querySingleNode(JSON.parse(JSON.stringify(graph)), name, true); + this.packageDataToObject( + '', + true, + JSON.parse(JSON.stringify(graph.nodes)), + ); + this.querySingleNode( + JSON.parse(JSON.stringify(graph)), + name, + true, + ); } else { - this.querySingleNode(JSON.parse(JSON.stringify(graph)), name, true); + this.querySingleNode( + JSON.parse(JSON.stringify(graph)), + name, + true, + ); } if (graph.children) { this.dealTreeData(graph.children, name); @@ -188,11 +266,12 @@ export default { level: 'node', name: '', graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.graph_name; } - RequestService.control(params).then( + RequestService.control(params, this.sessionId).then( (res) => {}, (err) => { this.showErrorMsg(err); @@ -207,7 +286,8 @@ export default { const data = this.$refs.tree.getCurrentNode(); let name = this.$refs.tree.getCurrentKey(); if ( - (data && (data.type === 'name_scope' || data.type === 'aggregation_scope')) || + (data && + (data.type === 'name_scope' || data.type === 'aggregation_scope')) || this.curLeafNodeName === null ) { name = this.curLeafNodeName; @@ -232,7 +312,10 @@ export default { this.dealTreeData(graph.children, name); this.defaultCheckedArr = this.$refs.tree.getCheckedKeys(); } - this.querySingleNode(JSON.parse(JSON.stringify(graph)), res.data.name); + this.querySingleNode( + JSON.parse(JSON.stringify(graph)), + res.data.name, + ); } else if (ascend) { this.$message.success(this.$t('debugger.nextNodeTip')); } else { @@ -248,13 +331,21 @@ export default { * Terminate current training */ terminate() { - this.$confirm(this.$t('debugger.ternimateConfirm'), this.$t('public.notice'), { - confirmButtonText: this.$t('public.sure'), - cancelButtonText: this.$t('public.cancel'), - type: 'warning', - }).then( + this.$confirm( + this.$t('debugger.ternimateConfirm'), + this.$t('public.notice'), + { + confirmButtonText: this.$t('public.sure'), + cancelButtonText: this.$t('public.cancel'), + type: 'warning', + }, + ).then( () => { - this.control(2); + if (this.trainId) { + this.deleteSession(); + } else { + this.control(2); + } }, (err) => { this.showErrorMsg(err); @@ -331,7 +422,12 @@ export default { if (this.graphFiles.value === this.$t('debugger.all')) { if (data.name.includes('/')) { const graphName = data.name.split('/')[0]; - this.queryAllTreeData(data.name.replace(`${graphName}/`, ''), true, graphName, true); + this.queryAllTreeData( + data.name.replace(`${graphName}/`, ''), + true, + graphName, + true, + ); } else { this.queryAllTreeData(data.name, true, data.name, true); } @@ -348,13 +444,14 @@ export default { retrieveTensorHistory(data, graphName) { const params = { name: data.name, + rank_id: this.logicCard.value, }; if (this.graphFiles.value === this.$t('debugger.all')) { params.name = `${graphName}/${data.name}`; } else { params.graph_name = graphName; } - RequestService.retrieveTensorHistory(params).then( + RequestService.retrieveTensorHistory(params, this.sessionId).then( (res) => { if (res.data && res.data.metadata) { this.dealMetadata(res.data.metadata); @@ -437,7 +534,8 @@ export default { (nodeName !== this.currentNodeName && nodeName !== '') || this.metadata.step !== metadata.step || (this.metadata.state === this.state.waiting && - (temState === this.state.sending || temState === this.state.running)) + (temState === this.state.sending || + temState === this.state.running)) ) { if (nodeName) { if (this.metadata.state !== this.state.running) { @@ -447,8 +545,14 @@ export default { } this.metadata.step = metadata.step; - let graphName = this.graphFiles.value === this.$t('debugger.all') ? '' : this.graphFiles.value; - if (this.graphFiles.value === this.$t('debugger.all') && this.selectedNode.name) { + let graphName = + this.graphFiles.value === this.$t('debugger.all') + ? '' + : this.graphFiles.value; + if ( + this.graphFiles.value === this.$t('debugger.all') && + this.selectedNode.name + ) { graphName = this.selectedNode.name.split('/')[0]; } if (metadata.graph_name) { @@ -486,10 +590,16 @@ export default { const path = this.selectedNode.name.split('^'); const type = this.allGraphData[path[0].replace('_unfold', '')].type; const ignoreType = ['name_scope', 'aggregation_scope']; - if (!this.selectedNode.name.includes('more...') && !ignoreType.includes(type)) { + if ( + !this.selectedNode.name.includes('more...') && + !ignoreType.includes(type) + ) { const name = path[0].replace('_unfold', ''); if (this.graphFiles.value === this.$t('debugger.all')) { - this.retrieveTensorHistory({name: name.replace(`${name.split('/')[0]}/`, '')}, name.split('/')[0]); + this.retrieveTensorHistory( + {name: name.replace(`${name.split('/')[0]}/`, '')}, + name.split('/')[0], + ); } else { this.retrieveTensorHistory( { @@ -507,11 +617,12 @@ export default { const params = { pos: this.metadata.pos, graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.graph_name; } - RequestService.pollData(params).then( + RequestService.pollData(params, this.sessionId).then( (res) => { if (res.data) { if (res.data.metadata) { @@ -555,7 +666,10 @@ export default { ) { const debTensor = this.$refs['deb-tensor']; if (debTensor) { - debTensor.updateGraphData(res.data.receive_tensor.graph_name, res.data.receive_tensor.tensor_name); + debTensor.updateGraphData( + res.data.receive_tensor.graph_name, + res.data.receive_tensor.tensor_name, + ); } } this.pollData(); @@ -594,7 +708,7 @@ export default { } else if (type === 3) { params.mode = 'pause'; } - RequestService.control(params).then( + RequestService.control(params, this.sessionId).then( (res) => { if (res.data && res.data.metadata) { setTimeout(() => { @@ -604,7 +718,9 @@ export default { } else if (this.metadata.state === this.state.running) { msg = this.$t('debugger.stateMsg.running'); } else { - msg = `${this.$t('debugger.backstageStatus')}${this.metadata.state}`; + msg = `${this.$t('debugger.backstageStatus')}${ + this.metadata.state + }`; } this.$message(msg); }, 500); @@ -632,7 +748,7 @@ export default { if (!this.enableRecheck) { return; } - RequestService.recheckWatchPoints().then( + RequestService.recheckWatchPoints(this.sessionId).then( (res) => { if (res && res.data && res.data.metadata) { if (res.data.metadata.enable_recheck !== undefined) { @@ -691,14 +807,16 @@ export default { return; } if ((item && item.id) || !item) { - const msg = item ? this.$t('debugger.deleteWatchpointConfirm') : this.$t('debugger.clearWatchpointConfirm'); + const msg = item + ? this.$t('debugger.deleteWatchpointConfirm') + : this.$t('debugger.clearWatchpointConfirm'); this.$confirm(msg, this.$t('public.notice'), { confirmButtonText: this.$t('public.sure'), cancelButtonText: this.$t('public.cancel'), type: 'warning', }).then(() => { const params = {watch_point_id: item ? item.id : null}; - RequestService.deleteWatchpoint(params).then( + RequestService.deleteWatchpoint(params, this.sessionId).then( (res) => { if (!item) { this.curWatchPointId = null; @@ -707,7 +825,12 @@ export default { this.loadOriginalTree(); this.queryWatchPoints(); this.$message.success(this.$t('debugger.successDeleteWP')); - if (res && res.data && res.data.metadata && res.data.metadata.enable_recheck !== undefined) { + if ( + res && + res.data && + res.data.metadata && + res.data.metadata.enable_recheck !== undefined + ) { this.enableRecheck = res.data.metadata.enable_recheck; } this.curWatchPointId = null; @@ -793,6 +916,7 @@ export default { }, watch_nodes: [], graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.graph_name; @@ -802,7 +926,10 @@ export default { params.condition.params = [ { name: item.param.name, - value: item.param.type === 'BOOL' ? Boolean(item.param.value) : Number(item.param.value), + value: + item.param.type === 'BOOL' + ? Boolean(item.param.value) + : Number(item.param.value), }, ]; } @@ -814,12 +941,17 @@ export default { }); }); } - RequestService.createWatchpoint(params).then( + RequestService.createWatchpoint(params, this.sessionId).then( (res) => { this.createWatchPointArr = []; this.createWPDialogVisible = false; this.$message.success(this.$t('debugger.successCreateWP')); - if (res && res.data && res.data.metadata && res.data.metadata.enable_recheck !== undefined) { + if ( + res && + res.data && + res.data.metadata && + res.data.metadata.enable_recheck !== undefined + ) { this.enableRecheck = res.data.metadata.enable_recheck; } @@ -902,9 +1034,11 @@ export default { })[0]; if (param.required_params && param.required_params.length) { - item.compositeParams.selections = item.compositeParams.options.filter((i) => { - return param.required_params.includes(i.name); - }); + item.compositeParams.selections = item.compositeParams.options.filter( + (i) => { + return param.required_params.includes(i.name); + }, + ); item.compositeParams.selections.forEach((i) => { i.value = i.type === 'BOOL' ? true : ''; }); @@ -968,14 +1102,20 @@ export default { watch_nodes: watchNodes, mode: node.indeterminate || check ? 1 : 0, graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.graph_name; } - RequestService.updateWatchpoint(params).then( + RequestService.updateWatchpoint(params, this.sessionId).then( (res) => { this.defaultCheckedArr = checkedKeys; - if (res && res.data && res.data.metadata && res.data.metadata.enable_recheck !== undefined) { + if ( + res && + res.data && + res.data.metadata && + res.data.metadata.enable_recheck !== undefined + ) { this.enableRecheck = res.data.metadata.enable_recheck; } }, @@ -990,7 +1130,9 @@ export default { const parent = node.parent; if ( parent && - !parent.childNodes.filter((val) => val.data.watched !== -1).find((val) => val.checked === false) + !parent.childNodes + .filter((val) => val.data.watched !== -1) + .find((val) => val.checked === false) ) { parent.checked = true; parent.indeterminate = false; @@ -1034,6 +1176,7 @@ export default { watch_nodes: watchNodes, mode: check ? 1 : 0, graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, search_pattern: {name: this.searchedWord}, }; if (this.graphFiles.value === this.$t('debugger.all')) { @@ -1042,10 +1185,15 @@ export default { if (this.nodeTypes.value !== 'all') { params.search_pattern.node_category = this.nodeTypes.value; } - RequestService.updateWatchpoint(params).then( + RequestService.updateWatchpoint(params, this.sessionId).then( (res) => { this.searchCheckedArr = checkedKeys; - if (res && res.data && res.data.metadata && res.data.metadata.enable_recheck !== undefined) { + if ( + res && + res.data && + res.data.metadata && + res.data.metadata.enable_recheck !== undefined + ) { this.enableRecheck = res.data.metadata.enable_recheck; } }, @@ -1097,7 +1245,9 @@ export default { this.treeFlag = true; this.$nextTick(() => { setTimeout(() => { - const dom = document.querySelector('.el-tree-node.is-current.is-focusable'); + const dom = document.querySelector( + '.el-tree-node.is-current.is-focusable', + ); if (dom) { dom.scrollIntoView(); } @@ -1116,6 +1266,7 @@ export default { name: this.searchWord, watch_point_id: this.curWatchPointId ? this.curWatchPointId : 0, graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.graph_name; @@ -1124,7 +1275,7 @@ export default { params.node_category = this.nodeTypes.value; } const loadingInstance = this.$loading(this.loadingOption); - RequestService.search(params).then( + RequestService.search(params, this.sessionId).then( (res) => { loadingInstance.close(); if (res.data && res.data.nodes) { @@ -1181,7 +1332,10 @@ export default { children.forEach((val) => { const node = this.$refs.searchTree.getNode(val.parentName); val.label = val.name.split('/').pop(); - val.leaf = val.type === 'name_scope' || val.type === 'aggregation_scope' ? false : true; + val.leaf = + val.type === 'name_scope' || val.type === 'aggregation_scope' + ? false + : true; val.showCheckbox = val.watched !== -1; this.$refs.searchTree.append(val, node); node.expanded = true; @@ -1221,88 +1375,120 @@ export default { val.label = val.name.split('/').pop(); }); }, + retrieveAll() { + this.loadingInstance = this.$loading(this.loadingOption); + const params = { + mode: 'all', + }; + RequestService.retrieve(params, this.sessionId).then( + (res) => { + this.initFail = false; + this.dialogVisible = false; + if (res.data) { + if (res.data.graph && res.data.graph.nodes) { + this.origialTree = res.data.graph.nodes.map((val) => { + return { + label: val.name.split('/').pop(), + leaf: + val.type === 'name_scope' || + val.type === 'aggregation_scope' + ? false + : true, + ...val, + }; + }); + this.resolve(this.origialTree); + this.dealGraphData( + JSON.parse(JSON.stringify(res.data.graph.nodes)), + ); + } + if (res.data.devices && res.data.devices.length) { + this.devices = res.data.devices; + this.logicCard.value = this.devices[0].rank_id; + this.graphFiles.options = JSON.parse( + JSON.stringify(this.devices[0].graph_names), + ); + if (this.graphFiles.options.length > 1) { + this.graphFiles.options.unshift(this.$t('debugger.all')); + } + this.graphFiles.value = this.graphFiles.options[0]; + this.logicCard.options = this.devices.map((val) => val.rank_id); + } + if (res.data.watch_points) { + this.watchPointArr = res.data.watch_points.map((val) => { + return { + id: val.id, + condition: val.watch_condition.id, + params: val.watch_condition.params || [], + selected: false, + }; + }); + } + if (res.data.metadata) { + if (res.data.metadata.debugger_version) { + this.debuggerVersion = res.data.metadata.debugger_version; + } + this.metadata = res.data.metadata; + if ( + res && + res.data && + res.data.metadata && + res.data.metadata.enable_recheck !== undefined + ) { + this.enableRecheck = res.data.metadata.enable_recheck; + } + if (this.metadata.backend) { + this.version = this.metadata.backend; + } + if ( + !res.data.metadata.recommendation_confirmed && + this.sessionId && + this.metadata.state === this.state.waiting + ) { + this.recommendWatchPointDialog = true; + } + + this.nodeName = this.metadata.node_name; + this.currentNodeName = this.nodeName; + if ( + this.metadata.state === this.state.pending || + this.metadata.state === this.state.mismatch + ) { + this.loadingInstance.close(); + } + if (this.pollInit) { + this.pollData(); + this.pollInit = false; + } + if (this.devices && this.devices.length) { + this.metadata.ip = this.devices[0].server_ip; + this.metadata.device_name = this.devices[0].device_id; + } + } + } + }, + (err) => { + this.initFail = true; + this.dialogVisible = true; + this.loadingInstance.close(); + }, + ); + }, /** * Draw the tree * @param {Object} node tree root node * @param {Function} resolve callback function ,return next node data */ loadNode(node, resolve) { - this.loadingInstance = this.$loading(this.loadingOption); if (node.level === 0) { node.childNodes = []; if (!this.node && !this.resolve) { this.node = node; this.resolve = resolve; } - const params = { - mode: 'all', - }; - RequestService.retrieve(params).then( - (res) => { - this.initFail = false; - this.dialogVisible = false; - if (res.data) { - if (res.data.graph && res.data.graph.nodes) { - this.graphFiles.options = res.data.graph.graph_names || []; - if (this.graphFiles.options.length > 1) { - this.graphFiles.options.unshift(this.$t('debugger.all')); - } - this.graphFiles.value = this.graphFiles.options[0]; - this.origialTree = res.data.graph.nodes.map((val) => { - return { - label: val.name.split('/').pop(), - leaf: val.type === 'name_scope' || val.type === 'aggregation_scope' ? false : true, - ...val, - }; - }); - resolve(this.origialTree); - this.dealGraphData(JSON.parse(JSON.stringify(res.data.graph.nodes))); - } - if (res.data.watch_points) { - this.watchPointArr = res.data.watch_points.map((val) => { - return { - id: val.id, - condition: val.watch_condition.id, - params: val.watch_condition.params || [], - selected: false, - }; - }); - } - if (res.data.metadata) { - if (res.data.metadata.debugger_version) { - this.debuggerVersion = res.data.metadata.debugger_version; - } - this.metadata = res.data.metadata; - if (res && res.data && res.data.metadata && res.data.metadata.enable_recheck !== undefined) { - this.enableRecheck = res.data.metadata.enable_recheck; - } - if (this.metadata.backend) { - this.version = this.metadata.backend; - } - this.trainId = encodeURIComponent(res.data.metadata.ip); - if (!res.data.metadata.recommendation_confirmed && this.trainId) { - this.recommendWatchPointDialog = true; - } - - this.nodeName = this.metadata.node_name; - this.currentNodeName = this.nodeName; - if (this.metadata.state === this.state.pending || this.metadata.state === this.state.mismatch) { - this.loadingInstance.close(); - } - if (this.pollInit) { - this.pollData(); - this.pollInit = false; - } - } - } - }, - (err) => { - this.initFail = true; - this.dialogVisible = true; - this.loadingInstance.close(); - }, - ); + resolve([]); } else if (node.level >= 1) { + this.loadingInstance = this.$loading(this.loadingOption); this.isIntoView = false; const curHalfCheckedKeys = this.$refs.tree.getHalfCheckedKeys(); const params = { @@ -1312,12 +1498,13 @@ export default { watch_point_id: this.curWatchPointId ? this.curWatchPointId : 0, name: node.data.name, graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, }, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.params.graph_name; } - RequestService.retrieve(params).then( + RequestService.retrieve(params, this.sessionId).then( (res) => { if (res.data && res.data.metadata) { this.dealMetadata(res.data.metadata); @@ -1327,7 +1514,11 @@ export default { this.curNodeData = graph.nodes.map((val) => { return { label: val.name.split('/').pop(), - leaf: val.type === 'name_scope' || val.type === 'aggregation_scope' ? false : true, + leaf: + val.type === 'name_scope' || + val.type === 'aggregation_scope' + ? false + : true, ...val, showCheckbox: val.watched !== -1, }; @@ -1363,12 +1554,21 @@ export default { val.checked = false; } }); - [...new Set(curHalfCheckedKeys.concat(this.$refs.tree.getHalfCheckedKeys()))].forEach((val) => { + [ + ...new Set( + curHalfCheckedKeys.concat( + this.$refs.tree.getHalfCheckedKeys(), + ), + ), + ].forEach((val) => { this.$refs.tree.getNode(val).indeterminate = true; }); this.selectedNode.name = node.data.name; if (!this.allGraphData[node.data.name].isUnfold) { - this.dealGraphData(JSON.parse(JSON.stringify(graph.nodes)), node.data.name); + this.dealGraphData( + JSON.parse(JSON.stringify(graph.nodes)), + node.data.name, + ); } else { this.selectNode(true); } @@ -1412,12 +1612,13 @@ export default { watch_point_id: this.curWatchPointId ? this.curWatchPointId : 0, name: node.data.name, graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, }, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.params.graph_name; } - RequestService.retrieve(params).then((res) => { + RequestService.retrieve(params, this.sessionId).then((res) => { if (res.data && res.data.metadata) { this.dealMetadata(res.data.metadata); } @@ -1425,7 +1626,10 @@ export default { this.curNodeData = res.data.graph.nodes.map((val) => { return { label: val.name.split('/').pop(), - leaf: val.type === 'name_scope' || val.type === 'aggregation_scope' ? false : true, + leaf: + val.type === 'name_scope' || val.type === 'aggregation_scope' + ? false + : true, ...val, showCheckbox: val.watched !== -1, }; @@ -1453,7 +1657,13 @@ export default { node.indeterminate = true; } }); - [...new Set(curHalfCheckedKeys.concat(this.$refs.searchTree.getHalfCheckedKeys()))].forEach((val) => { + [ + ...new Set( + curHalfCheckedKeys.concat( + this.$refs.searchTree.getHalfCheckedKeys(), + ), + ), + ].forEach((val) => { this.$refs.searchTree.getNode(val).indeterminate = true; }); } @@ -1463,20 +1673,19 @@ export default { initRecommendWatchPoints(value) { this.recommendWatchPointDialog = false; const params = { - trainId: this.trainId, - body: { - requestBody: { - set_recommended: value, - }, + requestBody: { + set_recommended: value, }, }; - RequestService.setRecommendWatchPoints(params).then((res) => { - if (res && res.data) { - if (value) { - this.queryWatchPoints(false); - } - } - }); + RequestService.setRecommendWatchPoints(params, this.sessionId).then( + (res) => { + if (res && res.data) { + if (value) { + this.queryWatchPoints(false); + } + } + }, + ); }, /** * Show data of current selected watchpoint @@ -1512,11 +1721,12 @@ export default { const params = { mode: 'watchpoint', graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.graph_name; } - RequestService.retrieve(params).then( + RequestService.retrieve(params, this.sessionId).then( (res) => { if (res.data.watch_points) { this.watchPointArr = res.data.watch_points.map((val) => { @@ -1531,7 +1741,9 @@ export default { if (focusLast) { this.selectWatchPoint(this.watchPointArr.length - 1); this.$nextTick(() => { - const newWatchPointDom = document.querySelector('#watch-point-list>li:last-child'); + const newWatchPointDom = document.querySelector( + '#watch-point-list>li:last-child', + ); if (newWatchPointDom) { newWatchPointDom.scrollIntoView(); } @@ -1566,7 +1778,11 @@ export default { }); // watched 0:unchecked 1:indeterminate 2:checked -1:no checkbox node.childNodes.forEach((val) => { - if (node.checked && !node.childNodes.find((val) => val.data.watched !== 2) && val.data.watched !== -1) { + if ( + node.checked && + !node.childNodes.find((val) => val.data.watched !== 2) && + val.data.watched !== -1 + ) { val.checked = true; } if (val.data.watched === this.checkboxStatus.checked) { @@ -1575,7 +1791,10 @@ export default { if (val.data.watched === this.checkboxStatus.indeterminate) { val.indeterminate = true; } - if (val.data.type !== 'name_scope' && val.data.type !== 'aggregation_scope') { + if ( + val.data.type !== 'name_scope' && + val.data.type !== 'aggregation_scope' + ) { val.isLeaf = true; } }); @@ -1586,14 +1805,18 @@ export default { this.$nextTick(() => { if ( node.indeterminate && - !node.childNodes.filter((val) => val.data.watched !== -1).find((val) => val.checked === false) + !node.childNodes + .filter((val) => val.data.watched !== -1) + .find((val) => val.checked === false) ) { node.indeterminate = false; node.checked = true; this.dealParentNode(node); } setTimeout(() => { - const dom = document.querySelector('.el-tree-node.is-current.is-focusable'); + const dom = document.querySelector( + '.el-tree-node.is-current.is-focusable', + ); if (dom) { dom.scrollIntoView(); } @@ -1606,12 +1829,16 @@ export default { searchWatchpointHits(type) { if (this.radio1 === 'hit') { const params = {}; - const condition = {}; + const condition = { + rank_id: this.logicCard.value, + }; if (type) { if (this.selectedNode.name) { if (this.graphFiles.value === this.$t('debugger.all')) { const arr = this.selectedNode.name.split('/'); - condition.node_name = arr[1] ? this.selectedNode.name.replace(`${arr[0]}/`, '') : arr[0]; + condition.node_name = arr[1] + ? this.selectedNode.name.replace(`${arr[0]}/`, '') + : arr[0]; condition.graph_name = arr[0]; } else { condition.node_name = this.selectedNode.name; @@ -1625,13 +1852,13 @@ export default { } condition.limit = this.pagination.pageSize; params.group_condition = condition; - RequestService.searchWatchpointHits(params).then( + RequestService.searchWatchpointHits(params, this.sessionId).then( (res) => { if (res.data.metadata) { this.dealMetadata(res.data.metadata); } + this.hitsOutdated = res.data.outdated; if (res.data && res.data.watch_point_hits) { - this.hitsOutdated = res.data.outdated; this.watchPointHits = []; this.pagination.total = res.data.total; this.pagination.currentPage = res.data.offset + 1; @@ -1659,7 +1886,9 @@ export default { } else { this.$nextTick(() => { setTimeout(() => { - const dom = document.querySelector('.el-tree-node.is-current.is-focusable'); + const dom = document.querySelector( + '.el-tree-node.is-current.is-focusable', + ); if (dom) { dom.scrollIntoView(); } @@ -1676,18 +1905,25 @@ export default { selected: false, id: hit.node_name, graph_name: hit.graph_name, + rank_id: this.logicCard.value, }; if (hit.tensors && hit.tensors.length) { hit.tensors.forEach((i) => { const tensorName = `slot: ${i.slot}, `; if (i.watch_points && i.watch_points.length) { i.watch_points.forEach((j, key) => { - let item = `${tensorName}${this.$t('debugger.watchPoint')} ${j.id}, `; + let item = `${tensorName}${this.$t('debugger.watchPoint')} ${ + j.id + }, `; let params = []; if (j.watch_condition) { item += ` ${this.transCondition(j.watch_condition.id)}`; - this.formateWatchpointParams(j.watch_condition.params || []); - params = JSON.parse(JSON.stringify(j.watch_condition.params)); + this.formateWatchpointParams( + j.watch_condition.params || [], + ); + params = JSON.parse( + JSON.stringify(j.watch_condition.params), + ); } obj.lists.push({ name: item, @@ -1699,7 +1935,8 @@ export default { .map((i) => { return this.$t('debugger.checkTips')[i]; }) - .join('') + this.$t('debugger.checkTips').cannotCheck + .join('') + + this.$t('debugger.checkTips').cannotCheck : '', }); }); @@ -1719,7 +1956,10 @@ export default { if (this.selectedNode.name) { let selectedNodeName = this.selectedNode.name; if (this.graphFiles.value === this.$t('debugger.all')) { - selectedNodeName = selectedNodeName.replace(`${selectedNodeName.split('/')[0]}/`, ''); + selectedNodeName = selectedNodeName.replace( + `${selectedNodeName.split('/')[0]}/`, + '', + ); } this.expandKeys = []; let focused = false; @@ -1761,6 +2001,7 @@ export default { single_node: true, watch_point_id: this.curWatchPointId ? this.curWatchPointId : 0, graph_name: currentHit.graph_name, + rank_id: this.logicCard.value, }, }; if (this.graphFiles.value === this.$t('debugger.all')) { @@ -1775,12 +2016,15 @@ export default { } }); this.watchPointHits = JSON.parse(JSON.stringify(this.watchPointHits)); - RequestService.retrieve(params).then( + RequestService.retrieve(params, this.sessionId).then( (res) => { if (res.data.metadata) { this.dealMetadata(res.data.metadata); } - this.retrieveTensorHistory({name: this.nodeName}, currentHit.graph_name); + this.retrieveTensorHistory( + {name: this.nodeName}, + currentHit.graph_name, + ); if (res.data && res.data.graph) { const graph = res.data.graph; @@ -1791,7 +2035,11 @@ export default { this.graphFiles.value = currentHit.graph_name; this.resetAllData(graph, params.params.name); } else { - this.querySingleNode(JSON.parse(JSON.stringify(graph)), params.params.name, true); + this.querySingleNode( + JSON.parse(JSON.stringify(graph)), + params.params.name, + true, + ); } if (graph.children) { this.dealTreeData(graph.children, name); @@ -1825,7 +2073,11 @@ export default { watch_point_id: this.curWatchPointId ? this.curWatchPointId : 0, }, }; - if (this.graphFiles.value === this.$t('debugger.all') && graphName && name) { + if ( + this.graphFiles.value === this.$t('debugger.all') && + graphName && + name + ) { if (name !== graphName) { name = `${graphName}/${name}`; params.params.name = name; @@ -1833,7 +2085,7 @@ export default { } else { params.params.graph_name = graphName; } - RequestService.retrieve(params).then( + RequestService.retrieve(params, this.sessionId).then( (res) => { if (res.data && res.data.metadata) { this.dealMetadata(res.data.metadata); @@ -1844,7 +2096,11 @@ export default { this.resetAllData(graph, name); this.isCurrentGraph = true; } else { - this.querySingleNode(JSON.parse(JSON.stringify(graph)), name, true); + this.querySingleNode( + JSON.parse(JSON.stringify(graph)), + name, + true, + ); } if (graph.children) { this.dealTreeData(graph.children, name); @@ -1866,7 +2122,8 @@ export default { if (children.nodes) { if ( (children.nodes.length > this.nodesCountLimit && - this.$refs.tree.getNode(children.scope_name).data.type === 'name_scope') || + this.$refs.tree.getNode(children.scope_name).data.type === + 'name_scope') || this.allGraphData[children.scope_name].maxChainNum > this.maxChainNum ) { return; @@ -1881,7 +2138,11 @@ export default { data.forEach((val) => { const node = this.$refs.tree.getNode(children.scope_name); if (node.childNodes) { - if (node.childNodes.map((value) => value.data.name).indexOf(val.name) === -1) { + if ( + node.childNodes + .map((value) => value.data.name) + .indexOf(val.name) === -1 + ) { this.$refs.tree.append(val, node); } } else { @@ -1897,7 +2158,10 @@ export default { if (val.data.watched === this.checkboxStatus.indeterminate) { val.indeterminate = true; } - if (val.data.type !== 'name_scope' && val.data.type !== 'aggregation_scope') { + if ( + val.data.type !== 'name_scope' && + val.data.type !== 'aggregation_scope' + ) { val.isLeaf = true; } }); @@ -1907,7 +2171,9 @@ export default { this.$refs.tree.setCurrentKey(name); this.$nextTick(() => { setTimeout(() => { - const dom = document.querySelector('.el-tree-node.is-current.is-focusable'); + const dom = document.querySelector( + '.el-tree-node.is-current.is-focusable', + ); if (dom) { dom.scrollIntoView(); } @@ -1923,7 +2189,10 @@ export default { this.origialTree = graph.nodes.map((val) => { return { label: val.name.split('/').pop(), - leaf: val.type === 'name_scope' || val.type === 'aggregation_scope' ? false : true, + leaf: + val.type === 'name_scope' || val.type === 'aggregation_scope' + ? false + : true, ...val, showCheckbox: val.watched !== -1, }; @@ -1941,7 +2210,11 @@ export default { this.firstFloorNodes = []; this.allGraphData = {}; d3.select('#graph svg').remove(); - this.packageDataToObject('', true, JSON.parse(JSON.stringify(graph.nodes))); + this.packageDataToObject( + '', + true, + JSON.parse(JSON.stringify(graph.nodes)), + ); if (name) { this.querySingleNode(JSON.parse(JSON.stringify(graph)), name, true); } else { diff --git a/mindinsight/ui/src/router.js b/mindinsight/ui/src/router.js index f1a09794..5ad692a2 100644 --- a/mindinsight/ui/src/router.js +++ b/mindinsight/ui/src/router.js @@ -157,6 +157,10 @@ export default new Router({ path: '/debugger', component: () => import('./views/debugger/debugger.vue'), }, + { + path: '/offline-debugger', + component: () => import('./views/debugger/debugger.vue'), + }, { path: '/explain', component: () => import('./views/explain/summary-list.vue'), diff --git a/mindinsight/ui/src/services/fetcher.js b/mindinsight/ui/src/services/fetcher.js index b2622af0..1a5f55ba 100644 --- a/mindinsight/ui/src/services/fetcher.js +++ b/mindinsight/ui/src/services/fetcher.js @@ -62,7 +62,14 @@ axios.interceptors.response.use( const errorData = i18n.messages[i18n.locale].error; const path = router.currentRoute.path; - if (path === '/debugger') { + if (path === '/debugger' || path === '/offline-debugger') { + if ( + error.response && + error.response.data && + error.response.data.error_code === '5054B281' + ) { + router.push('/'); + } return Promise.reject(error); } // error returned by backend diff --git a/mindinsight/ui/src/services/request-service.js b/mindinsight/ui/src/services/request-service.js index 4c710213..bd8bae79 100644 --- a/mindinsight/ui/src/services/request-service.js +++ b/mindinsight/ui/src/services/request-service.js @@ -309,55 +309,74 @@ export default { }); }, // debugger - pollData(params) { + getSession(params) { + return axios({ + method: 'post', + url: 'v1/mindinsight/debugger/sessions', + data: params, + }); + }, + deleteSession(sessionId) { + return axios({ + method: 'post', + url: `v1/mindinsight/debugger/sessions/${sessionId}/delete`, + }); + }, + checkSessions() { return axios({ method: 'get', - url: 'v1/mindinsight/debugger/poll-data', + url: `v1/mindinsight/debugger/sessions`, + }); + }, + pollData(params, sessionId) { + return axios({ + method: 'get', + url: `v1/mindinsight/debugger/sessions/${sessionId}/poll-data`, params: params, headers: { ignoreError: true, }, }); }, - retrieve(params) { + retrieve(params, sessionId) { return axios({ method: 'post', - url: 'v1/mindinsight/debugger/retrieve', + url: `v1/mindinsight/debugger/sessions/${sessionId}/retrieve`, data: params, }); }, - createWatchpoint(params) { + createWatchpoint(params, sessionId) { return axios({ method: 'post', - url: 'v1/mindinsight/debugger/create-watchpoint', + url: `v1/mindinsight/debugger/sessions/${sessionId}/create-watchpoint`, data: params, }); }, - updateWatchpoint(params) { + updateWatchpoint(params, sessionId) { return axios({ method: 'post', - url: 'v1/mindinsight/debugger/update-watchpoint', + url: `v1/mindinsight/debugger/sessions/${sessionId}/update-watchpoint`, data: params, }); }, - deleteWatchpoint(params) { + deleteWatchpoint(params, sessionId) { return axios({ method: 'post', - url: 'v1/mindinsight/debugger/delete-watchpoint', + url: `v1/mindinsight/debugger/sessions/${sessionId}/delete-watchpoint`, data: params, }); }, - control(params) { + control(params, sessionId) { return axios({ method: 'post', - url: 'v1/mindinsight/debugger/control', + url: `v1/mindinsight/debugger/sessions/${sessionId}/control`, data: params, }); }, - search(params) { + search(params, sessionId) { return axios({ method: 'get', - url: 'v1/mindinsight/debugger/search', + url: `v1/mindinsight/debugger/sessions/${sessionId}/search`, params: params, }); }, @@ -368,43 +387,43 @@ export default { params: params, }); }, - tensorComparisons(params) { + tensorComparisons(params, sessionId) { return axios({ method: 'get', - url: 'v1/mindinsight/debugger/tensor-comparisons', + url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-comparisons`, params: params, }); }, - tensors(params) { + tensors(params, sessionId) { return axios({ method: 'get', - url: 'v1/mindinsight/debugger/tensors', + url: `v1/mindinsight/debugger/sessions/${sessionId}/tensors`, params: params, }); }, - retrieveTensorHistory(params) { + retrieveTensorHistory(params, sessionId) { return axios({ method: 'post', - url: 'v1/mindinsight/debugger/tensor-history', + url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-history`, data: params, }); }, - queryConditions(trainId) { + queryConditions(sessionId) { return axios({ method: 'get', - url: `v1/mindinsight/conditionmgr/train-jobs/${trainId}/condition-collections`, + url: `v1/mindinsight/debugger/sessions/${sessionId}/condition-collections`, }); }, - recheckWatchPoints() { + recheckWatchPoints(sessionId) { return axios({ method: 'post', - url: `v1/mindinsight/debugger/recheck`, + url: `v1/mindinsight/debugger/sessions/${sessionId}/recheck`, }); }, - searchWatchpointHits(params) { + searchWatchpointHits(params, sessionId) { return axios({ method: 'post', - url: `v1/mindinsight/debugger/search-watchpoint-hits`, + url: `v1/mindinsight/debugger/sessions/${sessionId}/search-watchpoint-hits`, data: params, }); }, @@ -447,33 +466,25 @@ export default { data: params, }); }, - tensorHitsData(params) { + tensorHitsData(params, sessionId) { return axios({ method: 'get', - url: 'v1/mindinsight/debugger/tensor-hits', + url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-hits`, params: params, }); }, - getTensorGraphData(params) { + getTensorGraphData(params, sessionId) { return axios({ method: 'get', - url: 'v1/mindinsight/debugger/tensor-graphs', + url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-graphs`, params: params, }); }, - getCpuUtilization(params) { - return axios({ - method: 'post', - url: 'v1/mindinsight/profile/minddata-cpu-utilization-summary', - params: params.params, - data: params.body, - }); - }, - setRecommendWatchPoints(params) { + setRecommendWatchPoints(params, sessionId) { return axios({ method: 'post', - url: `v1/mindinsight/conditionmgr/train-jobs/${params.trainId}/set-recommended-watch-points`, - data: params.body, + url: `v1/mindinsight/debugger/sessions/${sessionId}/set-recommended-watch-points`, + data: params, }); }, // memory-datail apis diff --git a/mindinsight/ui/src/views/debugger/debugger.vue b/mindinsight/ui/src/views/debugger/debugger.vue index b24c01b5..516b4cdb 100644 --- a/mindinsight/ui/src/views/debugger/debugger.vue +++ b/mindinsight/ui/src/views/debugger/debugger.vue @@ -46,6 +46,17 @@ limitations under the License.
+
+
{{ $t('debugger.logicCard') }}
+ + + + +
{{ $t('debugger.graphFile') }}
+
+
{{ $t('debugger.logicCard') }}
+ + + + +
+ + + + + +
@@ -505,7 +546,7 @@ limitations under the License. :close-on-click-modal="false" :modal-append-to-body="false" class="creat-watch-point-dialog" - width="890px"> + width="930px">
val.rank_id === this.logicCard.value).graph_names), + ); + if (this.graphFiles.options.length > 1) { + this.graphFiles.options.unshift(this.$t('debugger.all')); + } + this.graphFiles.value = this.graphFiles.options[0]; + const device = this.devices.find((val) => val.rank_id === this.logicCard.value); + this.metadata.ip = device.server_ip; + this.metadata.device_name = device.device_id; + this.queryGraphByFile(); + }, queryGraphByFile() { this.searchWord = ''; this.nodeTypes.value = 'all'; @@ -931,12 +1001,13 @@ export default { params: { watch_point_id: this.curWatchPointId ? this.curWatchPointId : 0, graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, }, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.params.graph_name; } - RequestService.retrieve(params).then( + RequestService.retrieve(params, this.sessionId).then( (res) => { if (res.data && res.data.metadata) { this.dealMetadata(res.data.metadata); @@ -975,6 +1046,7 @@ export default { d3.select('#graph svg').remove(); this.selectedNode.name = ''; this.dealGraphData(JSON.parse(JSON.stringify(graph.nodes))); + this.tableData = []; } }, (err) => { @@ -1015,11 +1087,12 @@ export default { watch_nodes: watchNodes, mode: type ? 1 : 0, graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.graph_name; } - RequestService.updateWatchpoint(params).then( + RequestService.updateWatchpoint(params, this.sessionId).then( (res) => { this.defaultCheckedArr = this.$refs.tree.getCheckedKeys(); if (res && res.data && res.data.metadata && res.data.metadata.enable_recheck !== undefined) { @@ -1049,12 +1122,16 @@ export default { queryGraphByWatchpoint(id) { const params = { mode: 'watchpoint', - params: {watch_point_id: id, graph_name: this.graphFiles.value}, + params: { + watch_point_id: id, + graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, + }, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.params.graph_name; } - RequestService.retrieve(params).then( + RequestService.retrieve(params, this.sessionId).then( (res) => { if (res.data && res.data.graph) { const graph = res.data.graph; @@ -1306,11 +1383,12 @@ export default { level: 'node', name: this.selectedNode.name.replace('_unfold', ''), graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.graph_name; } - RequestService.control(params).then( + RequestService.control(params, this.sessionId).then( (res) => { if (res && res.data) { } @@ -1387,12 +1465,13 @@ export default { node_type: type, single_node: false, graph_name: this.graphFiles.value, + rank_id: this.logicCard.value, }; if (this.graphFiles.value === this.$t('debugger.all')) { delete params.params.graph_name; } } - RequestService.retrieve(params) + RequestService.retrieve(params, this.sessionId) .then( (response) => { if (response && response.data && response.data.graph) { @@ -1560,7 +1639,12 @@ export default { graphName = key.split('/')[0]; key = key.replace(`${graphName}/`, ''); } - const obj = {name: key, IOType: 'output', graph_name: graphName}; + const obj = { + name: key, + IOType: 'output', + graph_name: graphName, + rank_id: this.logicCard.value, + }; IOInfo.push(obj); this.selectedNode.outputNum++; }); @@ -1572,7 +1656,12 @@ export default { graphName = key.split('/')[0]; key = key.replace(`${graphName}/`, ''); } - const obj = {name: key, IOType: 'input', graph_name: graphName}; + const obj = { + name: key, + IOType: 'input', + graph_name: graphName, + rank_id: this.logicCard.value, + }; IOInfo.push(obj); this.selectedNode.inputNum++; }); @@ -1606,11 +1695,7 @@ export default { `translate(${this.graph.transform.x},` + `${this.graph.transform.y}) scale(${this.graph.transform.k})`, ); - const transitionTime = Math.min( - Math.abs(screenChange.x) * 2, - Math.abs(screenChange.y) * 2, - needDelay ? 800 : 0, - ); + const transitionTime = Math.min(Math.abs(screenChange.x) * 2, Math.abs(screenChange.y) * 2, needDelay ? 800 : 0); this.graph.dom.style.transition = `${transitionTime / 1000}s`; this.graph.dom.style['transition-timing-function'] = 'linear'; @@ -1829,8 +1914,8 @@ export default { height: calc(100% - 145px); } .deb-wrap .left-wrap .left .content .node-type { - height: 50px; - padding: 15px 15px 0 15px; + height: 40px; + padding: 10px 15px 0 15px; } .deb-wrap .left-wrap .left .content .node-type .label { display: inline-block; @@ -1855,7 +1940,7 @@ export default { font-size: 12px; } .deb-wrap .left-wrap .left .content .tree-wrap { - height: calc(70% - 155px); + height: calc(70% - 172px); overflow-y: auto; padding: 0 15px 15px; position: relative; @@ -1973,12 +2058,13 @@ export default { color: red; } .deb-wrap .left-wrap .left .content .hit-list-wrap { - height: 100%; + height: calc(100% - 40px); padding: 10px; } .deb-wrap .left-wrap .left .content .hit-list-wrap .watchpoint-table { max-height: calc(100% - 45px); overflow: auto; + margin-top: 10px; } .deb-wrap .left-wrap .left .content .hit-list-wrap .el-table::before { height: 0; @@ -2096,7 +2182,7 @@ export default { /* Opera */ } .deb-wrap .right .header { - padding: 15px; + line-height: 51px; border-bottom: 1px solid #ebeef5; position: relative; background: #fff; @@ -2113,6 +2199,25 @@ export default { .deb-wrap .right .header .item + .item { margin-left: 15px; } +.deb-wrap .right .header .el-icon-edit { + margin-left: 5px; +} +.deb-wrap .right .header i { + font-size: 18px; + margin: 0 2px; + color: #00a5a7; + cursor: pointer; +} +.deb-wrap .right .header .el-icon-close { + color: #f56c6c; +} +.deb-wrap .right .header .el-input { + width: 45px; +} +.deb-wrap .right .header .el-input input { + padding: 0; + text-align: center; +} .deb-wrap .right .header .tooltip { margin-left: 5px; cursor: pointer; @@ -2343,13 +2448,13 @@ export default { display: none; } .deb-wrap .creat-watch-point-dialog .conditions-container .collection { - width: 200px; + width: 210px; } .deb-wrap .creat-watch-point-dialog .conditions-container .condition, .deb-wrap .creat-watch-point-dialog .conditions-container .param, .deb-wrap .creat-watch-point-dialog .conditions-container .param-value { margin-left: 10px; - width: 200px; + width: 210px; } .deb-wrap .creat-watch-point-dialog .conditions-container .percent-sign { display: inline-block; diff --git a/mindinsight/ui/src/views/train-manage/summary-manage.vue b/mindinsight/ui/src/views/train-manage/summary-manage.vue index f66c2899..c8111339 100644 --- a/mindinsight/ui/src/views/train-manage/summary-manage.vue +++ b/mindinsight/ui/src/views/train-manage/summary-manage.vue @@ -96,6 +96,16 @@ limitations under the License. :title="$t('summaryManage.disableProfilerTip')"> {{$t('summaryManage.viewProfiler')}} + + {{$t('summaryManage.viewOfflineDebugger')}} + + {{$t('summaryManage.viewOfflineDebugger')}} + @@ -157,6 +167,45 @@ limitations under the License.
  • {{$t('summaryManage.openNewTab')}}
  • + + + {{ debuggerDialog.title }} + + + + +
    {{ $t('summaryManage.sessionLists') }}
    + + + + + + + + + + +
    @@ -223,7 +272,12 @@ export default { type: 0, }, tableDom: null, - operateWidth: localStorage.getItem('milang') === 'en-us' ? 400 : 290, + operateWidth: localStorage.getItem('milang') === 'en-us' ? 550 : 400, + debuggerDialog: { + title: this.$t('summaryManage.sessionLimit'), + showDialogModel: false, + trainJobs: [], + }, }; }, computed: {}, @@ -286,6 +340,7 @@ export default { i.update_time = i.update_time ? i.update_time : '--'; i.viewProfiler = i.profiler_dir && i.profiler_dir.length; i.viewDashboard = i.summary_files || i.graph_files || i.lineage_files; + i.viewOfflineDebugger = i.dump_dir; i.paramDetails = i.lineage_files; }); this.currentFolder = res.data.name ? res.data.name : '--'; @@ -363,7 +418,83 @@ export default { }, }); }, - + /** + * go to Offline Debugger + * @param {Object} row select row + */ + goToOfflineDebugger(row) { + this.contextMenu.show = false; + const debuggerDir = row.dump_dir; + const params = { + session_type: 'OFFLINE', + dump_dir: debuggerDir, + }; + this.getSessionId(params).then((value) => { + if (value !== undefined) { + this.$router.push({ + path: '/offline-debugger', + query: { + dir: debuggerDir, + sessionId: value, + }, + }); + } + }); + }, + getSessionId(params) { + return RequestService.getSession(params).then( + (res) => { + if (res && res.data) { + const sessionId = res.data; + return sessionId; + } + }, + (error) => { + if (error && error.response && error.response.data && error.response.data.error_code === '5054B280') { + this.checkSessions(); + } + }, + ); + }, + deleteSession(sessionId) { + this.$confirm(this.$t('summaryManage.deleteSessionConfirm'), this.$t('public.notice'), { + confirmButtonText: this.$t('public.sure'), + cancelButtonText: this.$t('public.cancel'), + type: 'warning', + }).then(() => { + RequestService.deleteSession(sessionId).then((res) => { + this.$message({ + type: 'success', + message: this.$t('summaryManage.deleteSessionSuccess'), + }); + this.checkSessions(); + }); + }); + }, + checkSessions() { + RequestService.checkSessions().then((res) => { + if (res && res.data && res.data.train_jobs) { + const trainJobs = res.data.train_jobs; + this.debuggerDialog.trainJobs = Object.keys(trainJobs).map((val) => { + return { + relative_path: decodeURIComponent(val), + session_id: trainJobs[val], + }; + }); + this.debuggerDialog.showDialogModel = true; + } + }); + }, + viewSession(row) { + const dir = row.relative_path; + this.$router.push({ + path: '/offline-debugger', + query: { + dir, + sessionId: row.session_id, + }, + }); + }, rightClick(row, event, type) { const maxWidth = 175; this.contextMenu.data = row; @@ -380,7 +511,28 @@ export default { if (!row) { return; } - if (this.contextMenu.type) { + if (this.contextMenu.type === 2) { + // open offline debugger + this.contextMenu.show = false; + const debuggerDir = row.dump_dir; + const params = { + session_type: 'OFFLINE', + dump_dir: debuggerDir, + }; + this.getSessionId(params).then((value) => { + if (value !== undefined) { + const routeUrl = this.$router.resolve({ + path: '/offline-debugger', + query: { + dir: debuggerDir, + sessionId: value, + }, + }); + window.open(routeUrl.href, '_blank'); + } + }); + } else if (this.contextMenu.type === 1) { + // open profiling this.contextMenu.show = false; const profilerDir = encodeURIComponent(row.profiler_dir); const trainId = encodeURIComponent(row.train_id); @@ -400,7 +552,7 @@ export default { }, }); window.open(routeUrl.href, '_blank'); - } else { + } else { // open training dashboard this.contextMenu.show = false; const trainId = encodeURIComponent(row.train_id); @@ -693,6 +845,16 @@ export default { #cl-summary-manage .details-data-list .el-dialog__body .details-data-title { margin-bottom: 20px; } +#cl-summary-manage .details-data-list .sessionMsg { + color: #333; + font-weight: bold; + font-size: 16px; + margin-right: 5px; +} +#cl-summary-manage .details-data-list .session-title { + margin-bottom: 10px; + color: #333; +} #cl-summary-manage .is-disabled.custom-btn { background-color: #f5f5f6; border: 1px solid #dfe1e6 !important; diff --git a/mindinsight/backend/conditionmgr/__init__.py b/mindinsight/utils/folder_analyzer.py similarity index 67% rename from mindinsight/backend/conditionmgr/__init__.py rename to mindinsight/utils/folder_analyzer.py index 5924bd78..495ec063 100644 --- a/mindinsight/backend/conditionmgr/__init__.py +++ b/mindinsight/utils/folder_analyzer.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,15 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Module init file.""" -from mindinsight.backend.conditionmgr.conditionmgr_api import init_module as init_query_module +"""Train job register.""" -def init_module(app): - """ - Init module entry. - - Args: - app (Flask): A Flask instance. - """ - init_query_module(app) +class FolderAnalyzer: + """Train job register. The subclass should implement the analyze method and return update info.""" + def analyze(self, entry, summary_base_dir, relative_path): + """Analyze file.""" diff --git a/requirements.txt b/requirements.txt index 396be80f..723b9b27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,6 @@ six>=1.12.0 Werkzeug>=1.0.0 pandas>=1.0.4 yapf>=0.30.0 -grpcio>=1.27.3 \ No newline at end of file +treelib>=1.6.1 +grpcio>=1.27.3 +XlsxWriter>=1.2.9 \ No newline at end of file diff --git a/tests/st/func/debugger/conftest.py b/tests/st/func/debugger/conftest.py index 7246dea7..b8bde9a4 100644 --- a/tests/st/func/debugger/conftest.py +++ b/tests/st/func/debugger/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,13 +25,15 @@ from mindinsight.conf import settings from mindinsight.datavisual.utils import tools from mindinsight.debugger.proto import ms_graph_pb2 from mindinsight.debugger.stream_handler.graph_handler import GraphHandler +from mindinsight.debugger.session_manager import SessionManager GRAPH_PROTO_FILE = os.path.join( os.path.dirname(__file__), '../../../utils/resource/graph_pb/lenet.pb' ) -DEBUGGER_BASE_URL = '/v1/mindinsight/debugger' -DEBUGGER_EXPECTED_RESULTS = os.path.join(os.path.dirname(__file__), 'expect_results') +DEBUGGER_BASE_URL = '/v1/mindinsight/debugger/sessions/0/' +DEBUGGER_TEST_BASE_DIR = os.path.dirname(__file__) +DEBUGGER_EXPECTED_RESULTS = os.path.join(DEBUGGER_TEST_BASE_DIR, 'expect_results') def init_graph_handler(): @@ -51,14 +53,13 @@ def init_graph_handler(): @pytest.fixture(scope='session') def app_client(): """This fixture is flask server.""" - packages = ["mindinsight.backend.debugger", "mindinsight.backend.conditionmgr"] + packages = ["mindinsight.backend.debugger"] settings.ENABLE_DEBUGGER = True mock_obj = Mock(return_value=packages) tools.find_app_package = mock_obj from mindinsight.backend.application import APP - from mindinsight.backend.debugger.debugger_api import BACKEND_SERVER APP.response_class = Response client = APP.test_client() original_val = settings.ENABLE_RECOMMENDED_WATCHPOINTS @@ -67,4 +68,4 @@ def app_client(): yield client finally: settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_val - BACKEND_SERVER.stop() + SessionManager.get_instance().online_session.stop() diff --git a/tests/st/func/debugger/debugger_services/__init__.py b/tests/st/func/debugger/debugger_services/__init__.py new file mode 100644 index 00000000..3dc1bea3 --- /dev/null +++ b/tests/st/func/debugger/debugger_services/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Function: + Test debugger services. +Usage: + pytest tests/st/func/debugger +""" diff --git a/tests/st/func/debugger/debugger_services/mock_dbg_services.py b/tests/st/func/debugger/debugger_services/mock_dbg_services.py new file mode 100644 index 00000000..cbe043b0 --- /dev/null +++ b/tests/st/func/debugger/debugger_services/mock_dbg_services.py @@ -0,0 +1,141 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +The module DbgServices provides offline debugger APIs. +""" +from unittest.mock import MagicMock + +import numpy as np + +import mindinsight +from mindinsight.debugger.common.log import LOGGER as log + + +def get_version(): + """Get version.""" + return mindinsight.__version__ + + +class DbgServices: + """ + DbgServices. + + Args: + dump_file_path (str): dir where the dump files are saved. + """ + def __init__(self, dump_file_path, verbose=True): + self._verbose = verbose + self.dump_file_path = dump_file_path + self.dbg_instance = MagicMock() + self._watchpoints = {} + self.print_mes("in Python __init__, file path is {}".format(dump_file_path)) + self.version = get_version() + self.initialized = False + self.is_sync = True + + def print_mes(self, mes): + """Print message.""" + if self._verbose: + log.info(mes) + + def initialize(self, net_name, is_sync_mode): + """Initialize.""" + self.print_mes(" Python Initialize dump_file_path: {}, is_sync: {}".format(net_name, is_sync_mode)) + self.initialized = True + + def add_watchpoint(self, watchpoint_id, watch_condition, check_node_list, parameter_list): + """Add watchpoint.""" + self.print_mes("Add watchpoint with watchpoint id: {}".format(watchpoint_id)) + self._watchpoints[watchpoint_id] = {'watch_condition': watch_condition, + 'check_nodes': check_node_list, + 'parameter_list': parameter_list} + return 0 + + def remove_watchpoint(self, watchpoint_id): + """Remove watchpoints.""" + self.print_mes("Remove watchpoint with watchpoint id: {}".format(watchpoint_id)) + return self._watchpoints.pop(watchpoint_id) + + def check_watchpoints(self, iteration): + """Check watchpoints.""" + self.print_mes("Check watchpoint at iteration: {}".format(iteration)) + watch_hits = [] + for watchpoint_id, watchpoint in self._watchpoints.items(): + # add param hit info + for param in watchpoint.get('parameter_list'): + param.hit = True + param.value = 0.0 + for watch_node_name, node_info in watchpoint.get('check_nodes'): + for device_id in node_info.get('device_id'): + hit = WatchpointHit(watch_node_name, + 0, + watchpoint.get('watch_condition'), + watchpoint_id, + watchpoint.get('parameter_list'), + 0, + device_id) + watch_hits.append(hit) + + return watch_hits + + def read_tensors(self, info): + """Read tensor values.""" + value = np.asarray(list(range(12)), dtype=np.int32).tobytes() + info_list_inst = [] + for _ in range(info): + tensor_data = TensorData(value, len(value), 4, [2, 2, 3]) + info_list_inst.append(tensor_data) + return info_list_inst + + +class TensorInfo: + """Tensor Information.""" + def __init__(self, node_name, slot, iteration, device_id, is_parameter): + self.node_name = node_name + self.slot = slot + self.iteration = iteration + self.device_id = device_id + self.is_parameter = is_parameter + + +class TensorData: + """Tensor data structure.""" + def __init__(self, data_ptr, data_size, dtype, shape): + self.data_ptr = data_ptr + self.data_size = data_size + self.dtype = dtype + self.shape = shape + + +class Parameter: + """Parameter structure.""" + def __init__(self, name, disabled, value, hit=False, actual_value=0.0): + self.name = name + self.disabled = disabled + self.value = value + self.hit = hit + self.actual_value = actual_value + + +class WatchpointHit: + """Watchpoint hit structure.""" + def __init__(self, name, slot, condition, watchpoint_id, parameters, error_code, device_id): + self.name = name + self.slot = slot + self.condition = condition + self.watchpoint_id = watchpoint_id + self.parameters = parameters + self.error_code = error_code + self.device_id = device_id diff --git a/tests/st/func/debugger/debugger_services/test_debugger_services.py b/tests/st/func/debugger/debugger_services/test_debugger_services.py new file mode 100644 index 00000000..7baf48c6 --- /dev/null +++ b/tests/st/func/debugger/debugger_services/test_debugger_services.py @@ -0,0 +1,77 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Function: + Test query debugger services. +Usage: + pytest tests/st/func/debugger +""" + +import os +import shutil +from unittest import mock + +import pytest + +from mindinsight.debugger.debugger_cache import DebuggerCache +from mindinsight.debugger.debugger_services.debugger_server_factory import \ + DebuggerServerFactory, DebuggerServerContext +from tests.st.func.debugger.utils import build_dump_file_structure +from tests.st.func.debugger.debugger_services import mock_dbg_services + + +class TestDebuggerServerFactory: + """Test debugger on Ascend backend.""" + + @classmethod + def setup_class(cls): + """Setup class.""" + cls.debugger_tmp_dir, cls.dump_files_dir = build_dump_file_structure() + cls._dbg_dir = os.path.join(cls.dump_files_dir, 'Ascend/sync') + cls._dbg_server_factory = DebuggerServerFactory() + + @classmethod + def teardown_class(cls): + """Run after test this class.""" + if os.path.exists(cls.debugger_tmp_dir): + shutil.rmtree(cls.debugger_tmp_dir) + + @pytest.mark.level0 + @pytest.mark.env_single + @pytest.mark.platform_x86_cpu + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + def test_get_dbg_online_server(self): + """Get debugger online server""" + context = DebuggerServerContext(dbg_mode='online') + server_obj = self._dbg_server_factory.get_debugger_server(DebuggerCache(), context) + server_obj.start() + server_obj.stop() + + @pytest.mark.level0 + @pytest.mark.env_single + @pytest.mark.platform_x86_cpu + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + @mock.patch('mindinsight.debugger.debugger_services.debugger_offline_server.import_module') + def test_get_dbg_offline_server(self, mock_import): + """Get debugger offline server""" + mock_import.return_value = mock_dbg_services + context = DebuggerServerContext(dbg_mode='offline', dbg_dir=self._dbg_dir) + server_obj = self._dbg_server_factory.get_debugger_server(DebuggerCache(), context) + server_obj.start() + server_obj.stop() diff --git a/tests/st/func/debugger/dump_files/Ascend/async/.metadata/data_dump.json b/tests/st/func/debugger/dump_files/Ascend/async/.metadata/data_dump.json new file mode 100644 index 00000000..63628b2a --- /dev/null +++ b/tests/st/func/debugger/dump_files/Ascend/async/.metadata/data_dump.json @@ -0,0 +1,15 @@ +{ + "common_dump_settings": { + "dump_mode": 0, + "path": "/absolute_path", + "net_name": "Lenet", + "iteration": 0, + "input_output": 0, + "kernels": ["Default/Conv-op12"], + "support_device": [0,1,2,3,4,5,6,7] + }, + "async_dump_settings": { + "enable": true, + "op_debug_mode": 0 + } +} diff --git a/tests/st/func/debugger/dump_files/Ascend/async/.metadata/hccl.json b/tests/st/func/debugger/dump_files/Ascend/async/.metadata/hccl.json new file mode 100644 index 00000000..e620387f --- /dev/null +++ b/tests/st/func/debugger/dump_files/Ascend/async/.metadata/hccl.json @@ -0,0 +1,23 @@ +{ + "version": "1.0", + "server_count": "1", + "server_list": [ + { + "server_id": "0.0.0.0", + "device": [ + { + "device_id": "0", + "device_ip": "0.0.0.1", + "rank_id": "0" + }, + { + "device_id": "1", + "device_ip": "0.0.0.2", + "rank_id": "1" + } + ], + "host_nic_ip": "reserve" + } + ], + "status": "completed" +} diff --git a/tests/st/func/debugger/dump_files/Ascend/sync/.metadata/data_dump.json b/tests/st/func/debugger/dump_files/Ascend/sync/.metadata/data_dump.json new file mode 100644 index 00000000..14a043a7 --- /dev/null +++ b/tests/st/func/debugger/dump_files/Ascend/sync/.metadata/data_dump.json @@ -0,0 +1,15 @@ +{ + "common_dump_settings": { + "dump_mode": 0, + "path": "/absolute_path", + "net_name": "Lenet", + "iteration": 0, + "input_output": 0, + "kernels": ["Default/Conv-op12"], + "support_device": [0,1,2,3,4,5,6,7] + }, + "e2e_dump_settings": { + "enable": true, + "trans_flag": false + } +} diff --git a/tests/st/func/debugger/dump_files/Ascend/sync/.metadata/hccl.json b/tests/st/func/debugger/dump_files/Ascend/sync/.metadata/hccl.json new file mode 100644 index 00000000..e620387f --- /dev/null +++ b/tests/st/func/debugger/dump_files/Ascend/sync/.metadata/hccl.json @@ -0,0 +1,23 @@ +{ + "version": "1.0", + "server_count": "1", + "server_list": [ + { + "server_id": "0.0.0.0", + "device": [ + { + "device_id": "0", + "device_ip": "0.0.0.1", + "rank_id": "0" + }, + { + "device_id": "1", + "device_ip": "0.0.0.2", + "rank_id": "1" + } + ], + "host_nic_ip": "reserve" + } + ], + "status": "completed" +} diff --git a/tests/st/func/debugger/dump_files/GPU/sync/.metadata/data_dump.json b/tests/st/func/debugger/dump_files/GPU/sync/.metadata/data_dump.json new file mode 100644 index 00000000..14a043a7 --- /dev/null +++ b/tests/st/func/debugger/dump_files/GPU/sync/.metadata/data_dump.json @@ -0,0 +1,15 @@ +{ + "common_dump_settings": { + "dump_mode": 0, + "path": "/absolute_path", + "net_name": "Lenet", + "iteration": 0, + "input_output": 0, + "kernels": ["Default/Conv-op12"], + "support_device": [0,1,2,3,4,5,6,7] + }, + "e2e_dump_settings": { + "enable": true, + "trans_flag": false + } +} diff --git a/tests/st/func/debugger/expect_results/offline_debugger/load_device_info_ascend.json b/tests/st/func/debugger/expect_results/offline_debugger/load_device_info_ascend.json new file mode 100644 index 00000000..48a0dcbe --- /dev/null +++ b/tests/st/func/debugger/expect_results/offline_debugger/load_device_info_ascend.json @@ -0,0 +1,21 @@ +{ + "device_target": "Ascend", + "server_list": [ + { + "server_id": "0.0.0.0", + "device": [ + { + "device_id": "0", + "device_ip": "0.0.0.1", + "rank_id": "0" + }, + { + "device_id": "1", + "device_ip": "0.0.0.2", + "rank_id": "1" + } + ], + "host_nic_ip": "reserve" + } + ] +} \ No newline at end of file diff --git a/tests/st/func/debugger/expect_results/restful_results/multi_next_node.json b/tests/st/func/debugger/expect_results/restful_results/multi_next_node.json index 0a52dbe6..0bbfb3ed 100644 --- a/tests/st/func/debugger/expect_results/restful_results/multi_next_node.json +++ b/tests/st/func/debugger/expect_results/restful_results/multi_next_node.json @@ -1 +1 @@ -{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false, "debugger_version": {"ms": "1.2.0"}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []} \ No newline at end of file +{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": ["graph_0", "graph_1"]}], "watch_points": []} \ No newline at end of file diff --git a/tests/st/func/debugger/expect_results/restful_results/multi_retrieve_all.json b/tests/st/func/debugger/expect_results/restful_results/multi_retrieve_all.json index 4357418d..bb54d897 100644 --- a/tests/st/func/debugger/expect_results/restful_results/multi_retrieve_all.json +++ b/tests/st/func/debugger/expect_results/restful_results/multi_retrieve_all.json @@ -1 +1 @@ -{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {"ms": "1.2.0"}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []} \ No newline at end of file +{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": ["graph_0", "graph_1"]}], "watch_points": []} \ No newline at end of file diff --git a/tests/st/func/debugger/expect_results/restful_results/retrieve_all.json b/tests/st/func/debugger/expect_results/restful_results/retrieve_all.json index ab7240bc..fee09f2a 100644 --- a/tests/st/func/debugger/expect_results/restful_results/retrieve_all.json +++ b/tests/st/func/debugger/expect_results/restful_results/retrieve_all.json @@ -1 +1 @@ -{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "graph_0", "recommendation_confirmed": false, "debugger_version": {"ms": "1.2.0"}}, "graph": {"graph_names": ["graph_0"], "nodes": [{"name": "Default", "type": "name_scope", "attr": {}, "input": {"Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradBiasAdd/BiasAddGrad-op21": {"shape": [[10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op24": {"shape": [[10, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradBiasAdd/BiasAddGrad-op29": {"shape": [[84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op32": {"shape": [[84, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradBiasAdd/BiasAddGrad-op37": {"shape": [[120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op40": {"shape": [[120, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/gradConv2D/Conv2DBackpropFilter-op48": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Conv2DBackpropFilter-op55": {"shape": [[6, 1, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}}, "output": {"Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/gradConv2D/Conv2DBackpropInput-op52": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Conv2DBackpropFilter-op55": {"shape": [[32, 1, 32, 32]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/gradMaxPoolWithArgmax/MaxPoolGradWithArgmax-op53": {"shape": [[32, 6, 4, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/gradMaxPoolWithArgmax/MaxPoolGradWithArgmax-op46": {"shape": [[32, 16, 4, 3]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op40": {"shape": [[32, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGrad-op36": {"shape": [[32, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op32": {"shape": [[32, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGrad-op28": {"shape": [[32, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op24": {"shape": [[32, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/gradConv2D/Conv2DBackpropFilter-op48": {"shape": [[32, 6, 14, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/gradSoftmaxCrossEntropyWithLogits/Mul-op20": {"shape": [[32, 10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92": {"shape": [[32, 1, 10, 10, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op94": {"shape": [[32, 1, 28, 28, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 7, "independent_layout": false}, {"name": "Gradients", "type": "name_scope", "attr": {}, "input": {"Default/tuple_getitem[10]_0/tuple_getitem-op210": {"shape": [[32, 10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op15": {"shape": [[32, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op12": {"shape": [[32, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op205": {"shape": [[32, 16, 10, 10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op206": {"shape": [[32, 16, 4, 3]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op202": {"shape": [[32, 1, 10, 10, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op197": {"shape": [[32, 6, 14, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/Cast-op188": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op195": {"shape": [[32, 6, 28, 28]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op196": {"shape": [[32, 6, 4, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op192": {"shape": [[32, 1, 28, 28, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}, "Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190": {"shape": [[32, 1, 32, 32]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/flatten-Flatten/Reshape-op9": {"shape": [[32, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}}, "output": {"Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op22": {"shape": [[10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op30": {"shape": [[84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op38": {"shape": [[120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op49": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op56": {"shape": [[6, 1, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op25": {"shape": [[10, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op33": {"shape": [[84, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op41": {"shape": [[120, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 1, "independent_layout": false}]}, "watch_points": []} \ No newline at end of file +{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "graph_0", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {"graph_names": ["graph_0"], "nodes": [{"name": "Default", "type": "name_scope", "attr": {}, "input": {"Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradBiasAdd/BiasAddGrad-op21": {"shape": [[10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op24": {"shape": [[10, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradBiasAdd/BiasAddGrad-op29": {"shape": [[84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op32": {"shape": [[84, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradBiasAdd/BiasAddGrad-op37": {"shape": [[120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op40": {"shape": [[120, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/gradConv2D/Conv2DBackpropFilter-op48": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Conv2DBackpropFilter-op55": {"shape": [[6, 1, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}}, "output": {"Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/gradConv2D/Conv2DBackpropInput-op52": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Conv2DBackpropFilter-op55": {"shape": [[32, 1, 32, 32]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/gradMaxPoolWithArgmax/MaxPoolGradWithArgmax-op53": {"shape": [[32, 6, 4, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/gradMaxPoolWithArgmax/MaxPoolGradWithArgmax-op46": {"shape": [[32, 16, 4, 3]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op40": {"shape": [[32, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGrad-op36": {"shape": [[32, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op32": {"shape": [[32, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGrad-op28": {"shape": [[32, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op24": {"shape": [[32, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/gradConv2D/Conv2DBackpropFilter-op48": {"shape": [[32, 6, 14, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/gradSoftmaxCrossEntropyWithLogits/Mul-op20": {"shape": [[32, 10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92": {"shape": [[32, 1, 10, 10, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op94": {"shape": [[32, 1, 28, 28, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 7, "independent_layout": false}, {"name": "Gradients", "type": "name_scope", "attr": {}, "input": {"Default/tuple_getitem[10]_0/tuple_getitem-op210": {"shape": [[32, 10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op15": {"shape": [[32, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op12": {"shape": [[32, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op205": {"shape": [[32, 16, 10, 10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op206": {"shape": [[32, 16, 4, 3]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op202": {"shape": [[32, 1, 10, 10, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op197": {"shape": [[32, 6, 14, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/Cast-op188": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op195": {"shape": [[32, 6, 28, 28]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op196": {"shape": [[32, 6, 4, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op192": {"shape": [[32, 1, 28, 28, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}, "Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190": {"shape": [[32, 1, 32, 32]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/flatten-Flatten/Reshape-op9": {"shape": [[32, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}}, "output": {"Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op22": {"shape": [[10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op30": {"shape": [[84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op38": {"shape": [[120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op49": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op56": {"shape": [[6, 1, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op25": {"shape": [[10, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op33": {"shape": [[84, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op41": {"shape": [[120, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 1, "independent_layout": false}]}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": ["graph_0"]}], "watch_points": []} \ No newline at end of file diff --git a/tests/st/func/debugger/expect_results/restful_results/retrieve_next_node_on_gpu.json b/tests/st/func/debugger/expect_results/restful_results/retrieve_next_node_on_gpu.json index 0b46576f..bf4506e6 100644 --- a/tests/st/func/debugger/expect_results/restful_results/retrieve_next_node_on_gpu.json +++ b/tests/st/func/debugger/expect_results/restful_results/retrieve_next_node_on_gpu.json @@ -1 +1 @@ -{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/TransData-op99", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_0", "recommendation_confirmed": false, "debugger_version": {"ms": "1.2.0"}}, "graph": {"graph_names": ["graph_0"], "nodes": [{"name": "Default", "type": "name_scope", "attr": {}, "input": {"Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradBiasAdd/BiasAddGrad-op21": {"shape": [[10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op24": {"shape": [[10, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradBiasAdd/BiasAddGrad-op29": {"shape": [[84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op32": {"shape": [[84, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradBiasAdd/BiasAddGrad-op37": {"shape": [[120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op40": {"shape": [[120, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/gradConv2D/Conv2DBackpropFilter-op48": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Conv2DBackpropFilter-op55": {"shape": [[6, 1, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}}, "output": {"Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/gradConv2D/Conv2DBackpropInput-op52": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Conv2DBackpropFilter-op55": {"shape": [[32, 1, 32, 32]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/gradMaxPoolWithArgmax/MaxPoolGradWithArgmax-op53": {"shape": [[32, 6, 4, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/gradMaxPoolWithArgmax/MaxPoolGradWithArgmax-op46": {"shape": [[32, 16, 4, 3]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op40": {"shape": [[32, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGrad-op36": {"shape": [[32, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op32": {"shape": [[32, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGrad-op28": {"shape": [[32, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op24": {"shape": [[32, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/gradConv2D/Conv2DBackpropFilter-op48": {"shape": [[32, 6, 14, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/gradSoftmaxCrossEntropyWithLogits/Mul-op20": {"shape": [[32, 10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92": {"shape": [[32, 1, 10, 10, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op94": {"shape": [[32, 1, 28, 28, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 7, "independent_layout": false}, {"name": "Gradients", "type": "name_scope", "attr": {}, "input": {"Default/tuple_getitem[10]_0/tuple_getitem-op210": {"shape": [[32, 10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op15": {"shape": [[32, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op12": {"shape": [[32, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op205": {"shape": [[32, 16, 10, 10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op206": {"shape": [[32, 16, 4, 3]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op202": {"shape": [[32, 1, 10, 10, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op197": {"shape": [[32, 6, 14, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/Cast-op188": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op195": {"shape": [[32, 6, 28, 28]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op196": {"shape": [[32, 6, 4, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op192": {"shape": [[32, 1, 28, 28, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}, "Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190": {"shape": [[32, 1, 32, 32]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/flatten-Flatten/Reshape-op9": {"shape": [[32, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}}, "output": {"Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op22": {"shape": [[10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op30": {"shape": [[84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op38": {"shape": [[120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op49": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op56": {"shape": [[6, 1, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op25": {"shape": [[10, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op33": {"shape": [[84, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op41": {"shape": [[120, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 1, "independent_layout": false}]}, "watch_points": []} \ No newline at end of file +{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/TransData-op99", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_0", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {"graph_names": ["graph_0"], "nodes": [{"name": "Default", "type": "name_scope", "attr": {}, "input": {"Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradBiasAdd/BiasAddGrad-op21": {"shape": [[10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op24": {"shape": [[10, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradBiasAdd/BiasAddGrad-op29": {"shape": [[84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op32": {"shape": [[84, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradBiasAdd/BiasAddGrad-op37": {"shape": [[120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op40": {"shape": [[120, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/gradConv2D/Conv2DBackpropFilter-op48": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Conv2DBackpropFilter-op55": {"shape": [[6, 1, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}}, "output": {"Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/gradConv2D/Conv2DBackpropInput-op52": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/gradConv2D/Conv2DBackpropFilter-op55": {"shape": [[32, 1, 32, 32]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/gradMaxPoolWithArgmax/MaxPoolGradWithArgmax-op53": {"shape": [[32, 6, 4, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/gradMaxPoolWithArgmax/MaxPoolGradWithArgmax-op46": {"shape": [[32, 16, 4, 3]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op40": {"shape": [[32, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGrad-op36": {"shape": [[32, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op32": {"shape": [[32, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGrad-op28": {"shape": [[32, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/gradMatMul/MatMul[6]_5/MatMul-op24": {"shape": [[32, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/gradConv2D/Conv2DBackpropFilter-op48": {"shape": [[32, 6, 14, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/gradSoftmaxCrossEntropyWithLogits/Mul-op20": {"shape": [[32, 10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92": {"shape": [[32, 1, 10, 10, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}, "Gradients/Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op94": {"shape": [[32, 1, 28, 28, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 7, "independent_layout": false}, {"name": "Gradients", "type": "name_scope", "attr": {}, "input": {"Default/tuple_getitem[10]_0/tuple_getitem-op210": {"shape": [[32, 10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op15": {"shape": [[32, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLU-op12": {"shape": [[32, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op205": {"shape": [[32, 16, 10, 10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op206": {"shape": [[32, 16, 4, 3]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op202": {"shape": [[32, 1, 10, 10, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op197": {"shape": [[32, 6, 14, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/conv2-Conv2d/Cast-op188": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op195": {"shape": [[32, 6, 28, 28]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op196": {"shape": [[32, 6, 4, 14]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT16]"}, "Default/tuple_getitem[10]_0/tuple_getitem-op192": {"shape": [[32, 1, 28, 28, 2]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_UINT8]"}, "Default/network-WithLossCell/_backbone-LeNet5/conv1-Conv2d/Cast-op190": {"shape": [[32, 1, 32, 32]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/network-WithLossCell/_backbone-LeNet5/flatten-Flatten/Reshape-op9": {"shape": [[32, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}}, "output": {"Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op22": {"shape": [[10]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op30": {"shape": [[84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op38": {"shape": [[120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op49": {"shape": [[16, 6, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op56": {"shape": [[6, 1, 5, 5]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op25": {"shape": [[10, 84]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op33": {"shape": [[84, 120]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}, "Default/optimizer-Momentum/ApplyMomentum[8]_1/ApplyMomentum-op41": {"shape": [[120, 400]], "edge_type": "data", "independent_layout": false, "data_type": "DT_TENSOR[DT_FLOAT32]"}}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 1, "independent_layout": false}]}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": ["graph_0"]}], "watch_points": []} \ No newline at end of file diff --git a/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_value.json b/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_value.json index 2f67ce50..44969fdb 100644 --- a/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_value.json +++ b/tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_value.json @@ -1,29 +1 @@ -{ - "tensor_value": { - "full_name": "Default/TransData-op99:0", - "step": 1, - "dtype": "DT_FLOAT32", - "shape": [ - 2, - 3 - ], - "has_prev_step": false, - "statistics": { - "overall_max": 6.0, - "overall_min": 1.0, - "overall_avg": 3.5, - "overall_count": 6, - "overall_nan_count": 0, - "overall_neg_inf_count": 0, - "overall_pos_inf_count": 0, - "overall_zero_count": 0.0, - "overall_neg_zero_count": 0.0, - "overall_pos_zero_count": 6.0 - }, - "value": [ - 5.0, - 6.0 - ], - "name": "Default/TransData-op99:0" - } -} \ No newline at end of file +{"tensor_value": {"full_name": "Default/TransData-op99:0", "step": 1, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "value": [5.0, 6.0], "statistics": {"overall_max": 6.0, "overall_min": 1.0, "overall_avg": 3.5, "overall_count": 6, "overall_nan_count": 0, "overall_neg_inf_count": 0, "overall_pos_inf_count": 0, "overall_zero_count": 0.0, "overall_neg_zero_count": 0.0, "overall_pos_zero_count": 6.0}, "name": "Default/TransData-op99:0"}} \ No newline at end of file diff --git a/tests/st/func/debugger/expect_results/restful_results/version_mismatch.json b/tests/st/func/debugger/expect_results/restful_results/version_mismatch.json index be42e1f5..54e099f4 100644 --- a/tests/st/func/debugger/expect_results/restful_results/version_mismatch.json +++ b/tests/st/func/debugger/expect_results/restful_results/version_mismatch.json @@ -1 +1 @@ -{"metadata": {"state": "mismatch", "step": 0, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {"ms": "1.0.0"}}, "graph": {}, "watch_points": []} \ No newline at end of file +{"metadata": {"state": "mismatch", "step": 0, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": []}], "watch_points": []} \ No newline at end of file diff --git a/tests/st/func/debugger/test_data_loader.py b/tests/st/func/debugger/test_data_loader.py new file mode 100644 index 00000000..d5d11176 --- /dev/null +++ b/tests/st/func/debugger/test_data_loader.py @@ -0,0 +1,149 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test DataLoader of offline debugger.""" +import os +import shutil + +import pytest + +from mindinsight.debugger.stream_cache.data_loader import DataLoader +from tests.st.func.debugger.conftest import GRAPH_PROTO_FILE +from tests.st.func.debugger.utils import build_dump_file_structure +from tests.utils.tools import compare_result_with_file, compare_result_with_binary_file + + +class TestDataLoader: + """Test DataLoader.""" + + @classmethod + def setup_class(cls): + """Init TestDataLoader for DataLoader unittest.""" + cls.debugger_tmp_dir, cls.dump_files_dir = build_dump_file_structure() + cls.expected_results_dir = os.path.join(os.path.dirname(__file__), + 'expect_results/offline_debugger') + cls.dump_files_dir_ascend = os.path.join(cls.dump_files_dir, + 'Ascend/sync') + cls.data_loader_ascend = DataLoader(cls.dump_files_dir_ascend) + cls.data_loader_ascend.initialize() + cls.dump_files_dir_gpu = os.path.join(cls.dump_files_dir, + 'GPU/sync') + cls.data_loader_gpu = DataLoader(cls.dump_files_dir_gpu) + cls.data_loader_gpu.initialize() + cls.dump_files_dir_ascend_async = os.path.join(cls.dump_files_dir, + 'Ascend/async') + cls.data_loader_ascend_async = DataLoader(cls.dump_files_dir_ascend_async) + cls.data_loader_ascend_async.initialize() + + @classmethod + def teardown_class(cls): + """Run after test this class.""" + if os.path.exists(cls.debugger_tmp_dir): + shutil.rmtree(cls.debugger_tmp_dir) + + @pytest.mark.level + @pytest.mark.env_single + @pytest.mark.platform_x86_cpu + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + def test_load_graphs_ascend(self): + """Test load_graphs function of offline-debugger.""" + res = self.data_loader_ascend.load_graphs() + expected_result0 = GRAPH_PROTO_FILE + res0 = res[0]['graph_protos'][0].SerializeToString() + compare_result_with_binary_file(res0, expected_result0) + + @pytest.mark.level0 + @pytest.mark.env_single + @pytest.mark.platform_x86_cpu + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + def test_load_device_info_ascend(self): + """Test load_device_info of ascend chip for offline-debugger.""" + res = self.data_loader_ascend.load_device_info() + expected_result = os.path.join(self.expected_results_dir, 'load_device_info_ascend.json') + compare_result_with_file(res, expected_result) + + @pytest.mark.level0 + @pytest.mark.env_single + @pytest.mark.platform_x86_cpu + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + def test_load_step_num_ascend(self): + """Test load_step_num of ascend chip for offline-debugger.""" + res = self.data_loader_ascend.load_step_number() + expected_result = {"0": 4, "1": 4} + assert res == expected_result + + @pytest.mark.level0 + @pytest.mark.env_single + @pytest.mark.platform_x86_cpu + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + def test_get_net_name_ascend(self): + """Test get_net_name of ascend chip for offline-debugger.""" + res = self.data_loader_ascend.get_net_name() + assert res == 'Lenet' + + @pytest.mark.level0 + @pytest.mark.env_single + @pytest.mark.platform_x86_cpu + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + def test_get_sync_flag(self): + """Test get_sync_flag of ascend chip for offline-debugger.""" + res = self.data_loader_ascend.get_sync_flag() + assert res + + @pytest.mark.level0 + @pytest.mark.env_single + @pytest.mark.platform_x86_cpu + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + def test_load_graphs_gpu(self): + """Test load_graphs function of offline-debugger.""" + res = self.data_loader_gpu.load_graphs() + expected_result0 = GRAPH_PROTO_FILE + res0 = res[0]['graph_protos'][0].SerializeToString() + compare_result_with_binary_file(res0, expected_result0) + + @pytest.mark.level0 + @pytest.mark.env_single + @pytest.mark.platform_x86_cpu + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + def test_load_step_num_gpu(self): + """Test load_step_num of ascend chip for offline-debugger.""" + res = self.data_loader_gpu.load_step_number() + expected_result = {"0": 3, "1": 3} + assert res == expected_result + + @pytest.mark.level0 + @pytest.mark.env_single + @pytest.mark.platform_x86_cpu + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + def test_load_step_num_ascend_async(self): + """Test load_step_num of ascend chip for offline-debugger.""" + res = self.data_loader_ascend_async.load_step_number() + expected_result = {"0": 3, "1": 3} + assert res == expected_result diff --git a/tests/st/func/debugger/test_restful_api.py b/tests/st/func/debugger/test_restful_api.py index 16b624fd..07c79e34 100644 --- a/tests/st/func/debugger/test_restful_api.py +++ b/tests/st/func/debugger/test_restful_api.py @@ -84,7 +84,7 @@ class TestAscendDebugger: def test_get_conditions(self, app_client): """Test get conditions for ascend.""" - url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/condition-collections' + url = '/v1/mindinsight/debugger/sessions/0/condition-collections' body_data = {} expect_file = 'get_conditions_for_ascend.json' with self._debugger_client.get_thread_instance(): @@ -191,7 +191,7 @@ class TestAscendDebugger: check_state(app_client) # prepare tensor value url = 'tensor-history' - body_data = {'name': node_name} + body_data = {'name': node_name, 'rank_id': 0} expect_file = 'retrieve_empty_tensor_history.json' send_and_compare_result(app_client, url, body_data, expect_file) # check full tensor history from poll data @@ -229,7 +229,7 @@ class TestAscendDebugger: get_request_result(app_client, url, body_data) check_state(app_client) get_request_result( - app_client=app_client, url='tensor-history', body_data={'name': node_name}) + app_client=app_client, url='tensor-history', body_data={'name': node_name, 'rank_id': 0}) res = get_request_result( app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get') assert res.get('receive_tensor', {}).get('node_name') == node_name @@ -239,30 +239,12 @@ class TestAscendDebugger: 'name': node_name + ':0', 'detail': 'data', 'shape': quote('[:, :]'), - 'tolerance': 1 - } + 'tolerance': 1, + 'rank_id': 0} expect_file = 'compare_tensors.json' send_and_compare_result(app_client, url, body_data, expect_file, method='get') send_terminate_cmd(app_client) - @pytest.mark.level0 - @pytest.mark.env_single - @pytest.mark.platform_x86_cpu - @pytest.mark.platform_arm_ascend_training - @pytest.mark.platform_x86_gpu_training - @pytest.mark.platform_x86_ascend_training - @pytest.mark.parametrize("body_data, expect_file", [ - ({'ascend': True}, 'retrieve_node_by_bfs_ascend.json'), - ({'name': 'Default/args0', 'ascend': False}, 'retrieve_node_by_bfs.json') - ]) - def test_retrieve_bfs_node(self, app_client, body_data, expect_file): - """Test retrieve bfs node.""" - with self._debugger_client.get_thread_instance(): - check_state(app_client) - # prepare tensor values - url = 'retrieve_node_by_bfs' - send_and_compare_result(app_client, url, body_data, expect_file, method='get') - send_terminate_cmd(app_client) @pytest.mark.level0 @pytest.mark.env_single @@ -441,7 +423,7 @@ class TestGPUDebugger: def test_get_conditions(self, app_client): """Test get conditions for gpu.""" - url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/condition-collections' + url = '/v1/mindinsight/debugger/sessions/0/condition-collections' body_data = {} expect_file = 'get_conditions_for_gpu.json' with self._debugger_client.get_thread_instance(): diff --git a/tests/st/func/debugger/utils.py b/tests/st/func/debugger/utils.py index 275e6a31..9593811e 100644 --- a/tests/st/func/debugger/utils.py +++ b/tests/st/func/debugger/utils.py @@ -16,8 +16,10 @@ import json import os import time - -from tests.st.func.debugger.conftest import DEBUGGER_EXPECTED_RESULTS, DEBUGGER_BASE_URL +import shutil +import tempfile +from mindinsight.debugger.proto import ms_graph_pb2 +from tests.st.func.debugger.conftest import DEBUGGER_EXPECTED_RESULTS, DEBUGGER_BASE_URL, GRAPH_PROTO_FILE from tests.utils.tools import compare_result_with_file, get_url @@ -74,10 +76,57 @@ def send_and_save_result(app_client, url, body_data, file_path, method='post'): def delete_random_items(res): """delete the random items in metadata.""" - if isinstance(res, dict) and res.get('metadata'): - if res['metadata'].get('ip'): - res['metadata'].pop('ip') - if res['metadata'].get('pos'): - res['metadata'].pop('pos') - if res['metadata'].get('debugger_version') and res['metadata']['debugger_version'].get('mi'): - res['metadata']['debugger_version'].pop('mi') + if isinstance(res, dict): + if res.get('metadata'): + if res['metadata'].get('ip'): + res['metadata'].pop('ip') + if res['metadata'].get('pos'): + res['metadata'].pop('pos') + if res['metadata'].get('debugger_version') and res['metadata']['debugger_version'].get('mi'): + res['metadata']['debugger_version'].pop('mi') + res['metadata']['debugger_version'].pop('ms') + if res.get('devices'): + for device in res.get('devices'): + if device.get('server_ip'): + device.pop('server_ip') + + +def build_dump_file_structure(): + """Build the dump file structure.""" + async_file_structure = { + "Ascend/async/device_0/Lenet_graph_1/1": 3, + "Ascend/async/device_1/Lenet_graph_1/1": 3 + } + + sync_file_structure = { + "Ascend/sync/Lenet/device_0": 4, + "Ascend/sync/Lenet/device_1": 4, + "GPU/sync/Lenet/device_0": 3, + "GPU/sync/Lenet/device_1": 3 + } + + debugger_tmp_dir = tempfile.mkdtemp(suffix='debugger_tmp') + dump_files_dir = os.path.join(debugger_tmp_dir, 'dump_files') + shutil.copytree(os.path.join(os.path.dirname(__file__), 'dump_files'), dump_files_dir) + + for sub_dir, steps in async_file_structure.items(): + for step in range(1, steps + 1): + os.makedirs(os.path.join(os.path.join(dump_files_dir, sub_dir), str(step)), exist_ok=True) + + for sub_dir, steps in sync_file_structure.items(): + for step in range(1, steps + 1): + os.makedirs(os.path.join(os.path.join(dump_files_dir, sub_dir), 'iteration_' + str(step)), + exist_ok=True) + graph_dir_path = os.path.join(os.path.join(dump_files_dir, sub_dir), 'graphs') + os.makedirs(graph_dir_path, exist_ok=True) + graph_path = os.path.join(graph_dir_path, 'ms_output_trace_code_graph_0.pb') + with open(GRAPH_PROTO_FILE, 'rb') as file_handler: + content = file_handler.read() + + model = ms_graph_pb2.ModelProto() + model.graph.ParseFromString(content) + model_str = model.SerializeToString() + with open(graph_path, 'wb') as file_handler: + file_handler.write(model_str) + + return debugger_tmp_dir, dump_files_dir diff --git a/tests/ut/debugger/expected_results/debugger_server/retrieve_all.json b/tests/ut/debugger/expected_results/debugger_server/retrieve_all.json index e79218b7..accf6228 100644 --- a/tests/ut/debugger/expected_results/debugger_server/retrieve_all.json +++ b/tests/ut/debugger/expected_results/debugger_server/retrieve_all.json @@ -1 +1 @@ -{"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "watch_points": []} \ No newline at end of file +{"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "devices": [{"rank_id": 0, "server_ip": "", "device_id": "", "graph_names": []}], "watch_points": []} \ No newline at end of file diff --git a/tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_0.json b/tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_0.json index bd64a13f..473de836 100644 --- a/tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_0.json +++ b/tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_0.json @@ -1,5 +1,80 @@ [ - {"watchCondition": {"condition": "tensor_too_small", "value": 1.0, "params": [{"name": "abs_mean_lt", "disabled": true}, {"name": "max_lt", "value": 1.0}, {"name": "min_lt", "disabled": true}, {"name": "mean_lt", "disabled": true}]}, "id": 1, "watch_nodes_num": 0}, - {"watchCondition": {"condition": "tensor_too_small", "value": 1.0, "params": [{"name": "abs_mean_lt", "disabled": true}, {"name": "max_lt", "disabled": true}, {"name": "min_lt", "value": 1.0}, {"name": "mean_lt", "disabled": true}]}, "id": 2, "watch_nodes_num": 172}, - {"watchCondition": {"condition": "tensor_too_large", "value": 1.0, "params": [{"name": "abs_mean_gt", "disabled": true}, {"name": "max_gt", "value": 1.0}, {"name": "min_gt", "disabled": true}, {"name": "mean_gt", "disabled": true}]}, "id": 3, "watch_nodes_num": 1} + { + "watchCondition": { + "condition": "tensor_too_small", + "value": 1.0, + "params": [ + { + "name": "abs_mean_lt", + "disabled": true + }, + { + "name": "max_lt", + "value": 1.0 + }, + { + "name": "min_lt", + "disabled": true + }, + { + "name": "mean_lt", + "disabled": true + } + ] + }, + "id": 1, + "watch_nodes_num": 0 + }, + { + "watchCondition": { + "condition": "tensor_too_small", + "value": 1.0, + "params": [ + { + "name": "abs_mean_lt", + "disabled": true + }, + { + "name": "max_lt", + "disabled": true + }, + { + "name": "min_lt", + "value": 1.0 + }, + { + "name": "mean_lt", + "disabled": true + } + ] + }, + "id": 2, + "watch_nodes_num": 142 + }, + { + "watchCondition": { + "condition": "tensor_too_large", + "value": 1.0, + "params": [ + { + "name": "abs_mean_gt", + "disabled": true + }, + { + "name": "max_gt", + "value": 1.0 + }, + { + "name": "min_gt", + "disabled": true + }, + { + "name": "mean_gt", + "disabled": true + } + ] + }, + "id": 3, + "watch_nodes_num": 1 + } ] \ No newline at end of file diff --git a/tests/ut/debugger/stream_handler/test_graph_handler.py b/tests/ut/debugger/stream_handler/test_graph_handler.py index b840a1cd..f247b16f 100644 --- a/tests/ut/debugger/stream_handler/test_graph_handler.py +++ b/tests/ut/debugger/stream_handler/test_graph_handler.py @@ -111,20 +111,6 @@ class TestGraphHandler: node_name = self.graph_handler.get_node_name_by_full_name(full_name, 'kernel_graph_0') assert node_name == expect_node_name - @pytest.mark.parametrize("node_name, ascend, expect_next", [ - (None, True, - "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0"), - (None, False, None), - ("Default/tuple_getitem[10]_0/tuple_getitem-op206", True, - "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89"), - ("Default/tuple_getitem[10]_0/tuple_getitem-op206", False, - "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op205") - ]) - def test_get_node_by_bfs_order(self, node_name, ascend, expect_next): - """Test get node by BFS order.""" - next_node = self.graph_handler.get_node_by_bfs_order(node_name, ascend) - assert next_node == expect_next - @pytest.mark.parametrize("tensor_name, expect_file", [ ("Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0:0", "get_tensor_graph-0.json"), ("Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89:1", "get_tensor_graph-1.json"), diff --git a/tests/ut/debugger/stream_handler/test_tensor_handler.py b/tests/ut/debugger/stream_handler/test_tensor_handler.py index 50ce2d9f..1e411d33 100644 --- a/tests/ut/debugger/stream_handler/test_tensor_handler.py +++ b/tests/ut/debugger/stream_handler/test_tensor_handler.py @@ -40,7 +40,7 @@ class TestTensorHandler: def test_get_tensor_value_by_name_none(self): """Test get_tensor_value_by_name.""" - res = self.tensor_handler.get_valid_tensor_by_name('tensor_name', True) + res = self.tensor_handler.get_valid_tensor_by_name('tensor_name', step=0, prev=True) assert res is None @mock.patch.object(log, "error") @@ -49,5 +49,5 @@ class TestTensorHandler: """Test get_tensors_diff.""" mock_error.return_value = None with pytest.raises(DebuggerParamValueError) as ex: - self.tensor_handler.get_tensors_diff(tensor_name, {1, 1}) + self.tensor_handler.get_tensors_diff(tensor_name, {1, 1}, step=0) assert f"Get current step and previous step for this tensor name {tensor_name} failed." in str(ex.value) diff --git a/tests/ut/debugger/stream_handler/test_watchpoint_handler.py b/tests/ut/debugger/stream_handler/test_watchpoint_handler.py index 08e37be6..86033e36 100644 --- a/tests/ut/debugger/stream_handler/test_watchpoint_handler.py +++ b/tests/ut/debugger/stream_handler/test_watchpoint_handler.py @@ -30,6 +30,7 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue DebuggerParamTypeError from mindinsight.debugger.common.log import LOGGER as log from mindinsight.debugger.stream_cache.watchpoint import Watchpoint +from mindinsight.debugger.stream_handler import MultiCardGraphHandler from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHandler, \ WatchpointHitHandler, validate_watch_condition, validate_watch_condition_params from tests.ut.debugger.configurations import init_graph_handler, mock_tensor_proto, \ @@ -48,7 +49,9 @@ class TestWatchpointHandler: '../expected_results/watchpoint') cls.graph_results_dir = os.path.join(os.path.dirname(__file__), '../expected_results/graph') + cls.multi_graph_stream = MultiCardGraphHandler() cls.graph_stream = init_graph_handler() + cls.multi_graph_stream.register_graph_handler(0, cls.graph_stream) cls.conditionmgr = None cls.handler = None @@ -69,7 +72,7 @@ class TestWatchpointHandler: ] for watch_condition, watch_nodes, watch_point_id, expect_new_id in watchpoints: watch_nodes = get_node_basic_infos(watch_nodes) - watch_point_id = self.handler.create_watchpoint(self.conditionmgr, watch_condition, watch_nodes, + watch_point_id = self.handler.create_watchpoint(self.conditionmgr, watch_condition, {0: watch_nodes}, watch_point_id) assert watch_point_id == expect_new_id @@ -105,7 +108,7 @@ class TestWatchpointHandler: file_path = os.path.join(self.results_dir, result_file) with open(file_path, 'r') as file_handler: contents = json.load(file_handler) - protos = self.handler.get_pending_commands(self.graph_stream) + protos = self.handler.get_pending_commands(self.multi_graph_stream) for proto in protos: msg_dict = json_format.MessageToDict(proto) msg_dict['watch_nodes_num'] = len(msg_dict.pop('watchNodes', [])) diff --git a/tests/ut/debugger/stream_operator/test_training_control_operator.py b/tests/ut/debugger/stream_operator/test_training_control_operator.py index 1d1d5ef6..27553497 100644 --- a/tests/ut/debugger/stream_operator/test_training_control_operator.py +++ b/tests/ut/debugger/stream_operator/test_training_control_operator.py @@ -48,7 +48,8 @@ class TestTrainingControlOperator: """Test validate leaf name.""" args[0].return_value = 'name_scope' with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'): - self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name') + self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name', + rank_id=0) @pytest.mark.parametrize('mode, cur_state, state', [ ('continue', 'waiting', 'sending'), @@ -64,3 +65,12 @@ class TestTrainingControlOperator: """Test construct run event.""" res = self._server._construct_run_event({'level': 'node'}) assert res.run_cmd == RunCMD(run_level='node', node_name='') + + @pytest.mark.parametrize('mode, state', [ + ('reset', 'waiting')]) + def test_control_reset_step(self, mode, state): + """Test control request, in 'reset' mode.""" + with mock.patch.object(MetadataHandler, 'max_step_num', 10), \ + mock.patch.object(MetadataHandler, 'debugger_type', 'offline'): + res = self._server.control(mode=mode, params={'steps': 9}) + assert res == {'metadata': {'enable_recheck': False, 'state': state, 'step': 9}} diff --git a/tests/ut/debugger/test_debugger_grpc_server.py b/tests/ut/debugger/test_debugger_grpc_server.py index c503843a..4a707bc4 100644 --- a/tests/ut/debugger/test_debugger_grpc_server.py +++ b/tests/ut/debugger/test_debugger_grpc_server.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ import numpy as np from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus from mindinsight.debugger.debugger_cache import DebuggerCache -from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer +from mindinsight.debugger.debugger_services.debugger_grpc_server import DebuggerGrpcServer from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply, SetCMD, Chunk, WatchpointHit from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType from mindinsight.debugger.stream_handler import WatchpointHitHandler, GraphHandler, \ diff --git a/tests/ut/debugger/test_debugger_server.py b/tests/ut/debugger/test_debugger_server.py index 6a08babe..a42e7a8a 100644 --- a/tests/ut/debugger/test_debugger_server.py +++ b/tests/ut/debugger/test_debugger_server.py @@ -30,11 +30,11 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue DebuggerCompareTensorError, DebuggerCreateWatchPointError, DebuggerDeleteWatchPointError from mindinsight.debugger.common.utils import Streams from mindinsight.debugger.debugger_cache import DebuggerCache -from mindinsight.debugger.debugger_server import DebuggerServer -from mindinsight.debugger.debugger_server import grpc_server_base -from mindinsight.debugger.stream_operator import watchpoint_operator +from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext +from mindinsight.debugger.debugger_session import DebuggerSession as DebuggerServer from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \ TensorHandler +from mindinsight.debugger.stream_operator import watchpoint_operator from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history @@ -48,12 +48,12 @@ class TestDebuggerServer: def setup_method(self): """Prepare debugger server object.""" - self._server = DebuggerServer() + context = DebuggerServerContext(dbg_mode='online') + self._server = DebuggerServer(context) @mock.patch.object(signal, 'signal') @mock.patch.object(Thread, 'join') @mock.patch.object(Thread, 'start') - @mock.patch.object(grpc_server_base, 'add_EventListenerServicer_to_server') @mock.patch.object(grpc, 'server') def test_stop_server(self, *args): """Test stop debugger server.""" @@ -62,7 +62,6 @@ class TestDebuggerServer: self._server.start() self._server._stop_handler(MagicMock(), MagicMock()) assert self._server.back_server is not None - assert self._server.grpc_server_manager == mock_grpc_server_manager @mock.patch.object(DebuggerCache, 'get_data') def test_poll_data(self, *args): @@ -186,7 +185,6 @@ class TestDebuggerServer: self._server.create_watchpoint({'watch_condition': {'id': 'inf'}}) @mock.patch.object(MetadataHandler, 'state', 'waiting') - @mock.patch.object(MetadataHandler, 'backend', 'GPU') @mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=MagicMock()) @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope') @mock.patch.object(watchpoint_operator, 'get_basic_node_info', return_value=MagicMock()) @@ -194,6 +192,7 @@ class TestDebuggerServer: def test_create_watchpoint(self, *args): """Test create watchpoint.""" args[0].return_value = 1 + self._server.cache_store.get_stream_handler((Streams.METADATA)).backend = 'GPU' res = self._server.create_watchpoint( {'watch_condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]}, 'watch_nodes': ['watch_node_name']}) diff --git a/tests/utils/tools.py b/tests/utils/tools.py index e1695c7d..6d270bd0 100644 --- a/tests/utils/tools.py +++ b/tests/utils/tools.py @@ -68,6 +68,13 @@ def compare_result_with_file(result, expected_file_path): assert result == expected_results +def compare_result_with_binary_file(result, expected_file_path): + """Compare result with binary file which contain the expected results.""" + with open(expected_file_path, 'rb') as file: + expected_results = file.read() + assert result == expected_results + + def deal_float_for_dict(res: dict, expected_res: dict, decimal_num): """ Deal float rounded to specified decimals in dict.