| @@ -1,61 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Conditionmgr restful api.""" | |||||
| import json | |||||
| from flask import Blueprint, request | |||||
| from mindinsight.conf import settings | |||||
| from mindinsight.utils.exceptions import ParamValueError | |||||
| from mindinsight.utils.exceptions import ParamMissError | |||||
| from mindinsight.backend.debugger.debugger_api import BACKEND_SERVER, _wrap_reply | |||||
| BLUEPRINT = Blueprint("conditionmgr", __name__, | |||||
| url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX) | |||||
| @BLUEPRINT.route("/conditionmgr/train-jobs/<train_id>/condition-collections", methods=["GET"]) | |||||
| def get_condition_collections(train_id): | |||||
| """get condition collections""" | |||||
| reply = _wrap_reply(BACKEND_SERVER.get_condition_collections, train_id) | |||||
| return reply | |||||
| @BLUEPRINT.route("/conditionmgr/train-jobs/<train_id>/set-recommended-watch-points", methods=["POST"]) | |||||
| def set_recommended_watch_points(train_id): | |||||
| """set recommended watch points.""" | |||||
| body = request.stream.read() | |||||
| try: | |||||
| body = json.loads(body if body else "{}") | |||||
| except json.JSONDecodeError: | |||||
| raise ParamValueError("Json data parse failed.") | |||||
| request_body = body.get('requestBody') | |||||
| if request_body is None: | |||||
| raise ParamMissError('requestBody') | |||||
| set_recommended = request_body.get('set_recommended') | |||||
| reply = _wrap_reply(BACKEND_SERVER.set_recommended_watch_points, set_recommended, train_id) | |||||
| return reply | |||||
| def init_module(app): | |||||
| """ | |||||
| Init module entry. | |||||
| Args: | |||||
| app (Flask): The application obj. | |||||
| """ | |||||
| app.register_blueprint(BLUEPRINT) | |||||
| @@ -26,6 +26,7 @@ import psutil | |||||
| import gunicorn | import gunicorn | ||||
| from mindinsight.utils.computing_resource_mgr import terminate | from mindinsight.utils.computing_resource_mgr import terminate | ||||
| from mindinsight.debugger.session_manager import SessionManager | |||||
| gunicorn.SERVER_SOFTWARE = 'unknown' | gunicorn.SERVER_SOFTWARE = 'unknown' | ||||
| @@ -110,4 +111,5 @@ def worker_int(worker): | |||||
| global LISTEN_PROCESS | global LISTEN_PROCESS | ||||
| if LISTEN_PROCESS is not None: | if LISTEN_PROCESS is not None: | ||||
| LISTEN_PROCESS.terminate() | LISTEN_PROCESS.terminate() | ||||
| SessionManager.get_instance().exit() | |||||
| worker.log.info("Worker int processed.") | worker.log.info("Worker int processed.") | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -19,6 +19,11 @@ from mindinsight.conf import settings | |||||
| from mindinsight.datavisual.common.log import logger | from mindinsight.datavisual.common.log import logger | ||||
| from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER | from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER | ||||
| from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater | from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater | ||||
| from mindinsight.debugger.debugger_folder_analyzer import DebuggerFolderAnalyzer | |||||
| ANALYZERS = { | |||||
| "debugger_folder_analyzer": DebuggerFolderAnalyzer() | |||||
| } | |||||
| def init_module(app): | def init_module(app): | ||||
| @@ -31,6 +36,8 @@ def init_module(app): | |||||
| """ | """ | ||||
| # Just to suppress pylint warning about unused arg. | # Just to suppress pylint warning about unused arg. | ||||
| logger.debug("App: %s", type(app)) | logger.debug("App: %s", type(app)) | ||||
| for analyzer in ANALYZERS.values(): | |||||
| DATA_MANAGER.register_folder_analyzer(analyzer) | |||||
| DATA_MANAGER.register_brief_cache_item_updater(LineageCacheItemUpdater()) | DATA_MANAGER.register_brief_cache_item_updater(LineageCacheItemUpdater()) | ||||
| # Let gunicorn load other modules first. | # Let gunicorn load other modules first. | ||||
| time.sleep(1) | time.sleep(1) | ||||
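The analyzers registered above are fanned out through `DATA_MANAGER.register_folder_analyzer` to the brief cache and, ultimately, to `SummaryWatcher` (see the later hunks in this diff). A minimal sketch of the contract an analyzer must satisfy; the class name and the `marker_dir` key are hypothetical, while the `analyze(entry, summary_base_dir, relative_path)` signature and the merge of the returned dict into the scan results are taken from this diff:

```python
import os

from mindinsight.utils.folder_analyzer import FolderAnalyzer


class MarkerFolderAnalyzer(FolderAnalyzer):
    """Hypothetical analyzer: flag train-job directories containing a '.marker' file."""

    def analyze(self, entry, summary_base_dir, relative_path):
        update_info = {}
        if entry.is_dir():
            marker_path = os.path.join(summary_base_dir, relative_path, entry.name, '.marker')
            if os.path.exists(marker_path):
                # Whatever is returned here is merged into the summary_dict entry
                # for this relative path by SummaryWatcher._check_by_analyzers.
                update_info = {'marker_dir': os.path.join(relative_path, entry.name)}
        return update_info
```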
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -19,22 +19,13 @@ from urllib.parse import unquote | |||||
| from flask import Blueprint, jsonify, request | from flask import Blueprint, jsonify, request | ||||
| from mindinsight.conf import settings | from mindinsight.conf import settings | ||||
| from mindinsight.debugger.debugger_server import DebuggerServer | |||||
| from mindinsight.utils.exceptions import ParamValueError | |||||
| from mindinsight.debugger.session_manager import SessionManager | |||||
| from mindinsight.utils.exceptions import ParamMissError, ParamValueError | |||||
| BLUEPRINT = Blueprint("debugger", __name__, | BLUEPRINT = Blueprint("debugger", __name__, | ||||
| url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX) | url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX) | ||||
| def _initialize_debugger_server(): | |||||
| """Initialize a debugger server instance.""" | |||||
| enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False | |||||
| server = None | |||||
| if enable_debugger: | |||||
| server = DebuggerServer() | |||||
| return server | |||||
| def _unquote_param(param): | def _unquote_param(param): | ||||
| """ | """ | ||||
| Decode parameter value. | Decode parameter value. | ||||
| @@ -77,8 +68,8 @@ def _wrap_reply(func, *args, **kwargs): | |||||
| return jsonify(reply) | return jsonify(reply) | ||||
| @BLUEPRINT.route("/debugger/poll-data", methods=["GET"]) | |||||
| def poll_data(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/poll-data", methods=["GET"]) | |||||
| def poll_data(session_id): | |||||
| """ | """ | ||||
| Wait for data to be updated on UI. | Wait for data to be updated on UI. | ||||
| @@ -88,17 +79,17 @@ def poll_data(): | |||||
| str, the updated data. | str, the updated data. | ||||
| Examples: | Examples: | ||||
| >>> Get http://xxxx/v1/mindinsight/debugger/poll-data?pos=xx | |||||
| >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/poll-data?pos=xx | |||||
| """ | """ | ||||
| pos = request.args.get('pos') | pos = request.args.get('pos') | ||||
| reply = _wrap_reply(BACKEND_SERVER.poll_data, pos) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).poll_data, pos) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/search", methods=["GET"]) | |||||
| def search(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/search", methods=["GET"]) | |||||
| def search(session_id): | |||||
| """ | """ | ||||
| Search nodes in specified watchpoint. | Search nodes in specified watchpoint. | ||||
| @@ -106,42 +97,25 @@ def search(): | |||||
| str, the required data. | str, the required data. | ||||
| Examples: | Examples: | ||||
| >>> Get http://xxxx/v1/mindinsight/debugger/search?name=mock_name&watch_point_id=1 | |||||
| >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/search?name=mock_name&watch_point_id=1 | |||||
| """ | """ | ||||
| name = request.args.get('name') | name = request.args.get('name') | ||||
| graph_name = request.args.get('graph_name') | graph_name = request.args.get('graph_name') | ||||
| watch_point_id = int(request.args.get('watch_point_id', 0)) | watch_point_id = int(request.args.get('watch_point_id', 0)) | ||||
| node_category = request.args.get('node_category') | node_category = request.args.get('node_category') | ||||
| reply = _wrap_reply(BACKEND_SERVER.search, {'name': name, | |||||
| 'graph_name': graph_name, | |||||
| 'watch_point_id': watch_point_id, | |||||
| 'node_category': node_category}) | |||||
| rank_id = int(request.args.get('rank_id', 0)) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).search, | |||||
| {'name': name, | |||||
| 'graph_name': graph_name, | |||||
| 'watch_point_id': watch_point_id, | |||||
| 'node_category': node_category, | |||||
| 'rank_id': rank_id}) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/retrieve_node_by_bfs", methods=["GET"]) | |||||
| def retrieve_node_by_bfs(): | |||||
| """ | |||||
| Search node by bfs. | |||||
| Returns: | |||||
| str, the required data. | |||||
| Examples: | |||||
| >>> Get http://xxxx/v1/mindinsight/debugger/retrieve_node_by_bfs?name=node_name&ascend=true | |||||
| """ | |||||
| name = request.args.get('name') | |||||
| graph_name = request.args.get('graph_name') | |||||
| ascend = request.args.get('ascend', 'false') | |||||
| ascend = ascend == 'true' | |||||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_node_by_bfs, name, graph_name, ascend) | |||||
| return reply | |||||
| @BLUEPRINT.route("/debugger/tensor-comparisons", methods=["GET"]) | |||||
| def tensor_comparisons(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-comparisons", methods=["GET"]) | |||||
| def tensor_comparisons(session_id): | |||||
| """ | """ | ||||
| Get tensor comparisons. | Get tensor comparisons. | ||||
| @@ -149,19 +123,21 @@ def tensor_comparisons(): | |||||
| str, the required data. | str, the required data. | ||||
| Examples: | Examples: | ||||
| >>> Get http://xxxx/v1/mindinsight/debugger/tensor-comparisons | |||||
| >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-comparisons | |||||
| """ | """ | ||||
| name = request.args.get('name') | name = request.args.get('name') | ||||
| detail = request.args.get('detail', 'data') | detail = request.args.get('detail', 'data') | ||||
| shape = _unquote_param(request.args.get('shape')) | shape = _unquote_param(request.args.get('shape')) | ||||
| tolerance = request.args.get('tolerance', '0') | tolerance = request.args.get('tolerance', '0') | ||||
| reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance) | |||||
| rank_id = int(request.args.get('rank_id', 0)) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).tensor_comparisons, name, shape, detail, | |||||
| tolerance, rank_id) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/retrieve", methods=["POST"]) | |||||
| def retrieve(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/retrieve", methods=["POST"]) | |||||
| def retrieve(session_id): | |||||
| """ | """ | ||||
| Retrieve data according to mode and params. | Retrieve data according to mode and params. | ||||
| @@ -169,17 +145,17 @@ def retrieve(): | |||||
| str, the required data. | str, the required data. | ||||
| Examples: | Examples: | ||||
| >>> POST http://xxxx/v1/mindinsight/debugger/retrieve | |||||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/retrieve | |||||
| """ | """ | ||||
| body = _read_post_request(request) | body = _read_post_request(request) | ||||
| mode = body.get('mode') | mode = body.get('mode') | ||||
| params = body.get('params') | params = body.get('params') | ||||
| reply = _wrap_reply(BACKEND_SERVER.retrieve, mode, params) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve, mode, params) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/tensor-history", methods=["POST"]) | |||||
| def retrieve_tensor_history(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-history", methods=["POST"]) | |||||
| def retrieve_tensor_history(session_id): | |||||
| """ | """ | ||||
| Retrieve tensor history according to node name and graph name. | Retrieve tensor history according to node name and graph name. | ||||
| @@ -187,17 +163,19 @@ def retrieve_tensor_history(): | |||||
| str, the required data. | str, the required data. | ||||
| Examples: | Examples: | ||||
| >>> POST http://xxxx/v1/mindinsight/debugger/tensor-history | |||||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-history | |||||
| """ | """ | ||||
| body = _read_post_request(request) | body = _read_post_request(request) | ||||
| name = body.get('name') | name = body.get('name') | ||||
| graph_name = body.get('graph_name') | graph_name = body.get('graph_name') | ||||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_history, name, graph_name) | |||||
| rank_id = body.get('rank_id') | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_history, name, graph_name, | |||||
| rank_id) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/tensors", methods=["GET"]) | |||||
| def retrieve_tensor_value(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/tensors", methods=["GET"]) | |||||
| def retrieve_tensor_value(session_id): | |||||
| """ | """ | ||||
| Retrieve tensor value according to name and shape. | Retrieve tensor value according to name and shape. | ||||
| @@ -205,20 +183,22 @@ def retrieve_tensor_value(): | |||||
| str, the required data. | str, the required data. | ||||
| Examples: | Examples: | ||||
| >>> GET http://xxxx/v1/mindinsight/debugger/tensors?name=tensor_name&detail=data&shape=[1,1,:,:] | |||||
| >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensors?name=tensor_name&detail=data&shape=[1,1,:,:] | |||||
| """ | """ | ||||
| name = request.args.get('name') | name = request.args.get('name') | ||||
| detail = request.args.get('detail') | detail = request.args.get('detail') | ||||
| shape = _unquote_param(request.args.get('shape')) | shape = _unquote_param(request.args.get('shape')) | ||||
| graph_name = request.args.get('graph_name') | graph_name = request.args.get('graph_name') | ||||
| prev = bool(request.args.get('prev') == 'true') | prev = bool(request.args.get('prev') == 'true') | ||||
| rank_id = int(request.args.get('rank_id', 0)) | |||||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_value, name, detail, shape, graph_name, prev) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_value, name, detail, | |||||
| shape, graph_name, prev, rank_id) | |||||
| return reply | return reply | ||||
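The GET endpoints in this hunk now accept an optional `rank_id` query parameter (defaulting to 0) for multi-card sessions. A hedged client-side sketch: host, port, session id and tensor name are placeholders, and the reply shape is whatever `_wrap_reply` serializes.

```python
import requests

BASE = "http://127.0.0.1:8080/v1/mindinsight/debugger"  # placeholder host/port

# Fetch a tensor slice from rank 1 of session "1"; rank_id falls back to 0 if omitted.
resp = requests.get(
    f"{BASE}/sessions/1/tensors",
    params={"name": "Default/conv1.weight:0", "detail": "data",
            "shape": "[1,1,:,:]", "rank_id": 1},
)
print(resp.json())
```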
| @BLUEPRINT.route("/debugger/create-watchpoint", methods=["POST"]) | |||||
| def create_watchpoint(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/create-watchpoint", methods=["POST"]) | |||||
| def create_watchpoint(session_id): | |||||
| """ | """ | ||||
| Create watchpoint. | Create watchpoint. | ||||
| @@ -229,16 +209,16 @@ def create_watchpoint(): | |||||
| MindInsightException: If method fails to be called. | MindInsightException: If method fails to be called. | ||||
| Examples: | Examples: | ||||
| >>> POST http://xxxx/v1/mindinsight/debugger/create-watchpoint | |||||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/create-watchpoint | |||||
| """ | """ | ||||
| params = _read_post_request(request) | params = _read_post_request(request) | ||||
| params['watch_condition'] = params.pop('condition', None) | params['watch_condition'] = params.pop('condition', None) | ||||
| reply = _wrap_reply(BACKEND_SERVER.create_watchpoint, params) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).create_watchpoint, params) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/update-watchpoint", methods=["POST"]) | |||||
| def update_watchpoint(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/update-watchpoint", methods=["POST"]) | |||||
| def update_watchpoint(session_id): | |||||
| """ | """ | ||||
| Update watchpoint. | Update watchpoint. | ||||
| @@ -249,17 +229,17 @@ def update_watchpoint(): | |||||
| MindInsightException: If method fails to be called. | MindInsightException: If method fails to be called. | ||||
| Examples: | Examples: | ||||
| >>> POST http://xxxx/v1/mindinsight/debugger/update-watchpoint | |||||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/update-watchpoint | |||||
| """ | """ | ||||
| params = _read_post_request(request) | params = _read_post_request(request) | ||||
| reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, params) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).update_watchpoint, params) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/delete-watchpoint", methods=["POST"]) | |||||
| def delete_watchpoint(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/delete-watchpoint", methods=["POST"]) | |||||
| def delete_watchpoint(session_id): | |||||
| """ | """ | ||||
| delete watchpoint. | |||||
| Delete watchpoint. | |||||
| Returns: | Returns: | ||||
| str, reply message. | str, reply message. | ||||
| @@ -268,19 +248,19 @@ def delete_watchpoint(): | |||||
| MindInsightException: If method fails to be called. | MindInsightException: If method fails to be called. | ||||
| Examples: | Examples: | ||||
| >>> POST http://xxxx/v1/mindinsight/debugger/delete-watchpoint | |||||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/delete-watchpoint | |||||
| """ | """ | ||||
| body = _read_post_request(request) | body = _read_post_request(request) | ||||
| watch_point_id = body.get('watch_point_id') | watch_point_id = body.get('watch_point_id') | ||||
| reply = _wrap_reply(BACKEND_SERVER.delete_watchpoint, watch_point_id) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).delete_watchpoint, watch_point_id) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/control", methods=["POST"]) | |||||
| def control(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/control", methods=["POST"]) | |||||
| def control(session_id): | |||||
| """ | """ | ||||
| Control request. | Control request. | ||||
| @@ -291,16 +271,16 @@ def control(): | |||||
| MindInsightException: If method fails to be called. | MindInsightException: If method fails to be called. | ||||
| Examples: | Examples: | ||||
| >>> POST http://xxxx/v1/mindinsight/debugger/control | |||||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/control | |||||
| """ | """ | ||||
| params = _read_post_request(request) | params = _read_post_request(request) | ||||
| reply = _wrap_reply(BACKEND_SERVER.control, params) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).control, params) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/recheck", methods=["POST"]) | |||||
| def recheck(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/recheck", methods=["POST"]) | |||||
| def recheck(session_id): | |||||
| """ | """ | ||||
| Recheck request. | Recheck request. | ||||
| @@ -311,15 +291,15 @@ def recheck(): | |||||
| MindInsightException: If method fails to be called. | MindInsightException: If method fails to be called. | ||||
| Examples: | Examples: | ||||
| >>> POST http://xxxx/v1/mindinsight/debugger/recheck | |||||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/recheck | |||||
| """ | """ | ||||
| reply = _wrap_reply(BACKEND_SERVER.recheck) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).recheck) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/tensor-graphs", methods=["GET"]) | |||||
| def retrieve_tensor_graph(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-graphs", methods=["GET"]) | |||||
| def retrieve_tensor_graph(session_id): | |||||
| """ | """ | ||||
| Retrieve tensor graph according to tensor name and graph name. | Retrieve tensor graph according to tensor name and graph name. | ||||
| @@ -327,16 +307,18 @@ def retrieve_tensor_graph(): | |||||
| str, the required data. | str, the required data. | ||||
| Examples: | Examples: | ||||
| >>> GET http://xxxx/v1/mindinsight/debugger/tensor-graphs?tensor_name=tensor_name&graph_name=graph_name | |||||
| >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-graphs?tensor_name=xxx&graph_name=xxx | |||||
| """ | """ | ||||
| tensor_name = request.args.get('tensor_name') | tensor_name = request.args.get('tensor_name') | ||||
| graph_name = request.args.get('graph_name') | graph_name = request.args.get('graph_name') | ||||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_graph, tensor_name, graph_name) | |||||
| rank_id = int(request.args.get('rank_id', 0)) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_graph, tensor_name, | |||||
| graph_name, rank_id) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/tensor-hits", methods=["GET"]) | |||||
| def retrieve_tensor_hits(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-hits", methods=["GET"]) | |||||
| def retrieve_tensor_hits(session_id): | |||||
| """ | """ | ||||
| Retrieve tensor hits according to tensor name and graph name. | Retrieve tensor hits according to tensor name and graph name. | ||||
| @@ -344,16 +326,18 @@ def retrieve_tensor_hits(): | |||||
| str, the required data. | str, the required data. | ||||
| Examples: | Examples: | ||||
| >>> GET http://xxxx/v1/mindinsight/debugger/tensor-hits?tensor_name=tensor_name&graph_name=graph_name | |||||
| >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-hits?tensor_name=xxx&graph_name=xxx | |||||
| """ | """ | ||||
| tensor_name = request.args.get('tensor_name') | tensor_name = request.args.get('tensor_name') | ||||
| graph_name = request.args.get('graph_name') | graph_name = request.args.get('graph_name') | ||||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_hits, tensor_name, graph_name) | |||||
| rank_id = int(request.args.get('rank_id', 0)) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_hits, tensor_name, | |||||
| graph_name, rank_id) | |||||
| return reply | return reply | ||||
| @BLUEPRINT.route("/debugger/search-watchpoint-hits", methods=["POST"]) | |||||
| def search_watchpoint_hits(): | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/search-watchpoint-hits", methods=["POST"]) | |||||
| def search_watchpoint_hits(session_id): | |||||
| """ | """ | ||||
| Search watchpoint hits by group condition. | Search watchpoint hits by group condition. | ||||
| @@ -361,15 +345,75 @@ def search_watchpoint_hits(): | |||||
| str, the required data. | str, the required data. | ||||
| Examples: | Examples: | ||||
| >>> POST http://xxxx/v1/mindinsight/debugger/search-watchpoint-hits | |||||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/search-watchpoint-hits | |||||
| """ | """ | ||||
| body = _read_post_request(request) | body = _read_post_request(request) | ||||
| group_condition = body.get('group_condition') | group_condition = body.get('group_condition') | ||||
| reply = _wrap_reply(BACKEND_SERVER.search_watchpoint_hits, group_condition) | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).search_watchpoint_hits, group_condition) | |||||
| return reply | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/condition-collections", methods=["GET"]) | |||||
| def get_condition_collections(session_id): | |||||
| """Get condition collections.""" | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).get_condition_collections) | |||||
| return reply | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/set-recommended-watch-points", methods=["POST"]) | |||||
| def set_recommended_watch_points(session_id): | |||||
| """Set recommended watch points.""" | |||||
| body = _read_post_request(request) | |||||
| request_body = body.get('requestBody') | |||||
| if request_body is None: | |||||
| raise ParamMissError('requestBody') | |||||
| set_recommended = request_body.get('set_recommended') | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).set_recommended_watch_points, | |||||
| set_recommended) | |||||
| return reply | return reply | ||||
| BACKEND_SERVER = _initialize_debugger_server() | |||||
| @BLUEPRINT.route("/debugger/sessions", methods=["POST"]) | |||||
| def create_session(): | |||||
| """ | |||||
| Get the session id if the session exists; otherwise, create a new session. | |||||
| Returns: | |||||
| str, session id. | |||||
| Examples: | |||||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions | |||||
| """ | |||||
| body = _read_post_request(request) | |||||
| summary_dir = body.get('dump_dir') | |||||
| session_type = body.get('session_type') | |||||
| reply = _wrap_reply(SessionManager.get_instance().create_session, session_type, summary_dir) | |||||
| return reply | |||||
| @BLUEPRINT.route("/debugger/sessions", methods=["GET"]) | |||||
| def get_sessions(): | |||||
| """ | |||||
| Get the current active sessions. | |||||
| Examples: | |||||
| >>> GET http://xxxx/v1/mindinsight/debugger/sessions | |||||
| """ | |||||
| reply = _wrap_reply(SessionManager.get_instance().get_sessions) | |||||
| return reply | |||||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/delete", methods=["POST"]) | |||||
| def delete_session(session_id): | |||||
| """ | |||||
| Delete session by session id. | |||||
| Examples: | |||||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/delete | |||||
| """ | |||||
| reply = _wrap_reply(SessionManager.get_instance().delete_session, session_id) | |||||
| return reply | |||||
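Taken together, the three endpoints above define the session lifecycle. A hedged end-to-end sketch: host/port and the `session_type`/`dump_dir` values are placeholders, and the exact reply payloads depend on `SessionManager`, which is not shown in this diff.

```python
import requests

BASE = "http://127.0.0.1:8080/v1/mindinsight/debugger"  # placeholder host/port

# Create (or reuse) a session for a dump directory; the reply carries the session id.
session_id = requests.post(
    f"{BASE}/sessions",
    json={"session_type": "OFFLINE", "dump_dir": "/path/to/dump_dir"},  # assumed values
).json()

# List the currently active sessions.
print(requests.get(f"{BASE}/sessions").json())

# Delete the session by id once finished.
requests.post(f"{BASE}/sessions/{session_id}/delete")
```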
| def init_module(app): | def init_module(app): | ||||
| @@ -380,5 +424,3 @@ def init_module(app): | |||||
| app (Flask): The application obj. | app (Flask): The application obj. | ||||
| """ | """ | ||||
| app.register_blueprint(BLUEPRINT) | app.register_blueprint(BLUEPRINT) | ||||
| if BACKEND_SERVER: | |||||
| BACKEND_SERVER.start() | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -112,6 +112,11 @@ class _BasicTrainJob: | |||||
| """Get the lineage files count in the summary dir.""" | """Get the lineage files count in the summary dir.""" | ||||
| return self._entry['lineage_files'] | return self._entry['lineage_files'] | ||||
| @property | |||||
| def dump_dir(self): | |||||
| """Get the dump file path in the summary dir.""" | |||||
| return self._entry.get('dump_dir', None) | |||||
| class CachedTrainJob: | class CachedTrainJob: | ||||
| """ | """ | ||||
| @@ -369,6 +374,10 @@ class _BaseCacheManager: | |||||
| class _BriefCacheManager(_BaseCacheManager): | class _BriefCacheManager(_BaseCacheManager): | ||||
| """A cache manager that holds all disk train jobs on disk.""" | """A cache manager that holds all disk train jobs on disk.""" | ||||
| def __init__(self, summary_base_dir): | |||||
| super(_BriefCacheManager, self).__init__(summary_base_dir) | |||||
| self._summary_watcher = SummaryWatcher() | |||||
| def cache_train_job(self, train_id): | def cache_train_job(self, train_id): | ||||
| """ | """ | ||||
| Cache given train job. | Cache given train job. | ||||
| @@ -386,7 +395,7 @@ class _BriefCacheManager(_BaseCacheManager): | |||||
| def update_cache(self, executor): | def update_cache(self, executor): | ||||
| """Update cache.""" | """Update cache.""" | ||||
| logger.info('Start to update BriefCacheManager.') | logger.info('Start to update BriefCacheManager.') | ||||
| summaries_info = SummaryWatcher().list_summary_directories(self._summary_base_dir) | |||||
| summaries_info = self._summary_watcher.list_summary_directories(self._summary_base_dir) | |||||
| basic_train_jobs = [] | basic_train_jobs = [] | ||||
| for info in summaries_info: | for info in summaries_info: | ||||
| @@ -425,6 +434,10 @@ class _BriefCacheManager(_BaseCacheManager): | |||||
| return new_cache_items | return new_cache_items | ||||
| def register_folder_analyzer(self, analyzer): | |||||
| """Register folder analyzer.""" | |||||
| self._summary_watcher.register_folder_analyzer(analyzer) | |||||
| @property | @property | ||||
| def cache_items(self): | def cache_items(self): | ||||
| """Get cache items.""" | """Get cache items.""" | ||||
| @@ -1028,6 +1041,10 @@ class DataManager: | |||||
| """Register brief cache item updater for brief cache manager.""" | """Register brief cache item updater for brief cache manager.""" | ||||
| self._brief_cache.register_cache_item_updater(updater) | self._brief_cache.register_cache_item_updater(updater) | ||||
| def register_folder_analyzer(self, analyzer): | |||||
| """Register folder analyzer.""" | |||||
| self._brief_cache.register_folder_analyzer(analyzer) | |||||
| def get_brief_cache(self): | def get_brief_cache(self): | ||||
| """Get brief cache.""" | """Get brief cache.""" | ||||
| return self._brief_cache | return self._brief_cache | ||||
| @@ -254,22 +254,24 @@ class MSGraph(Graph): | |||||
| return searched_list | return searched_list | ||||
| def search_leaf_nodes_by_pattern(self, pattern): | |||||
| def search_leaf_nodes_by_pattern(self, pattern, scope_pattern=False): | |||||
| """ | """ | ||||
| Search leaf node by a given pattern. | Search leaf node by a given pattern. | ||||
| Args: | Args: | ||||
| pattern (Union[str, None]): The pattern of the node to search, | pattern (Union[str, None]): The pattern of the node to search, | ||||
| if None, return all node names. | if None, return all node names. | ||||
| scope_pattern (bool): If true, return the children nodes of the scope. Default: False. | |||||
| Returns: | Returns: | ||||
| list[Node], a list of nodes. | list[Node], a list of nodes. | ||||
| """ | """ | ||||
| is_match = lambda x, y: x.lower().startswith(y) if scope_pattern else y in x.lower() | |||||
| if pattern is not None: | if pattern is not None: | ||||
| pattern = pattern.lower() | pattern = pattern.lower() | ||||
| searched_nodes = [ | searched_nodes = [ | ||||
| node for name, node in self._leaf_nodes.items() | node for name, node in self._leaf_nodes.items() | ||||
| if pattern in name.lower() | |||||
| if is_match(name, pattern) | |||||
| ] | ] | ||||
| else: | else: | ||||
| searched_nodes = [node for node in self._leaf_nodes.values()] | searched_nodes = [node for node in self._leaf_nodes.values()] | ||||
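To make the two matching modes of the new `is_match` lambda concrete (the node names below are illustrative MindSpore-style names):

```python
def _is_match(name, pattern, scope_pattern):
    """Mirror of the lambda above: prefix match for scope patterns, substring otherwise."""
    return name.lower().startswith(pattern) if scope_pattern else pattern in name.lower()

# Substring mode finds the pattern anywhere in the (lowered) node name.
assert _is_match('Default/conv1-Conv2d/BiasAdd-op99', 'biasadd', scope_pattern=False)
# Scope mode only keeps children of the given scope prefix.
assert _is_match('Default/conv1-Conv2d/BiasAdd-op99', 'default/conv1', scope_pattern=True)
assert not _is_match('Default/conv1-Conv2d/BiasAdd-op99', 'conv1', scope_pattern=True)
```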
| @@ -29,6 +29,7 @@ from mindinsight.utils.exceptions import FileSystemPermissionError | |||||
| LINEAGE_SUMMARY_SUFFIX = '_lineage' | LINEAGE_SUMMARY_SUFFIX = '_lineage' | ||||
| EXPLAIN_SUMMARY_SUFFIX = '_explain' | EXPLAIN_SUMMARY_SUFFIX = '_explain' | ||||
| DUMP_FILE_PREFIX = 'dump_' | |||||
| class SummaryWatcher: | class SummaryWatcher: | ||||
| @@ -45,6 +46,13 @@ class SummaryWatcher: | |||||
| # to avoid long-time blocking | # to avoid long-time blocking | ||||
| MAX_SCAN_COUNT = 20000 | MAX_SCAN_COUNT = 20000 | ||||
| def __init__(self): | |||||
| self._analyzers = [] | |||||
| def register_folder_analyzer(self, analyzer): | |||||
| """Register folder analyzer.""" | |||||
| self._analyzers.append(analyzer) | |||||
| def list_summary_directories(self, summary_base_dir, overall=True, list_explain=False): | def list_summary_directories(self, summary_base_dir, overall=True, list_explain=False): | ||||
| """ | """ | ||||
| List summary directories within base directory. | List summary directories within base directory. | ||||
| @@ -104,7 +112,7 @@ class SummaryWatcher: | |||||
| elif entry.is_dir(): | elif entry.is_dir(): | ||||
| self._update_summary_dict(summary_dict, summary_base_dir, relative_path, entry, list_explain) | self._update_summary_dict(summary_dict, summary_base_dir, relative_path, entry, list_explain) | ||||
| entry_path = os.path.realpath(os.path.join(summary_base_dir, entry.name)) | entry_path = os.path.realpath(os.path.join(summary_base_dir, entry.name)) | ||||
| self._scan_subdir_entries(summary_dict, summary_base_dir, entry_path, entry.name, counter, list_explain) | |||||
| self._scan_subdir_entries(summary_dict, summary_base_dir, entry_path, entry, counter, list_explain) | |||||
| directories = [] | directories = [] | ||||
| for key, value in summary_dict.items(): | for key, value in summary_dict.items(): | ||||
| @@ -119,7 +127,7 @@ class SummaryWatcher: | |||||
| return directories | return directories | ||||
| def _scan_subdir_entries(self, summary_dict, summary_base_dir, entry_path, entry_name, counter, list_explain): | |||||
| def _scan_subdir_entries(self, summary_dict, summary_base_dir, entry_path, entry, counter, list_explain): | |||||
| """ | """ | ||||
| Scan subdir entries. | Scan subdir entries. | ||||
| @@ -134,7 +142,7 @@ class SummaryWatcher: | |||||
| try: | try: | ||||
| subdir_entries = os.scandir(entry_path) | subdir_entries = os.scandir(entry_path) | ||||
| except PermissionError: | except PermissionError: | ||||
| logger.warning('Path of %s under summary base directory is not accessible.', entry_name) | |||||
| logger.warning('Path of %s under summary base directory is not accessible.', entry.name) | |||||
| return | return | ||||
| # sort in ascending order according to modification time. | # sort in ascending order according to modification time. | ||||
| @@ -149,11 +157,14 @@ class SummaryWatcher: | |||||
| logger.info('Stop further scanning due to overall is False and ' | logger.info('Stop further scanning due to overall is False and ' | ||||
| 'number of scanned files exceeds upper limit.') | 'number of scanned files exceeds upper limit.') | ||||
| break | break | ||||
| subdir_relative_path = os.path.join('.', entry_name) | |||||
| subdir_relative_path = os.path.join('.', entry.name) | |||||
| if subdir_entry.is_symlink(): | if subdir_entry.is_symlink(): | ||||
| pass | pass | ||||
| self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry, list_explain) | self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry, list_explain) | ||||
| relative_path = './' | |||||
| self._check_by_analyzers(entry, summary_base_dir, relative_path, summary_dict) | |||||
| def _is_valid_summary_directory(self, summary_base_dir, relative_path): | def _is_valid_summary_directory(self, summary_base_dir, relative_path): | ||||
| """ | """ | ||||
| Check if the given summary directory is valid. | Check if the given summary directory is valid. | ||||
| @@ -198,13 +209,11 @@ class SummaryWatcher: | |||||
| list_explain (bool): Indicates whether to list only the mindexplain folder. | list_explain (bool): Indicates whether to list only the mindexplain folder. | ||||
| """ | """ | ||||
| try: | try: | ||||
| stat = entry.stat() | |||||
| ctime, mtime = self._get_stat_time(entry) | |||||
| except FileNotFoundError: | except FileNotFoundError: | ||||
| logger.warning('File %s not found', entry.name) | logger.warning('File %s not found', entry.name) | ||||
| return | return | ||||
| ctime = datetime.datetime.fromtimestamp(stat.st_ctime).astimezone() | |||||
| mtime = datetime.datetime.fromtimestamp(stat.st_mtime).astimezone() | |||||
| if entry.is_file(): | if entry.is_file(): | ||||
| summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name) | summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name) | ||||
| pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name) | pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name) | ||||
| @@ -238,7 +247,10 @@ class SummaryWatcher: | |||||
| summary_dict[relative_path]['explain_files'] += 1 | summary_dict[relative_path]['explain_files'] += 1 | ||||
| else: | else: | ||||
| summary_dict[relative_path]['summary_files'] += 1 | summary_dict[relative_path]['summary_files'] += 1 | ||||
| self._check_by_analyzers(entry, summary_base_dir, relative_path, summary_dict) | |||||
| elif entry.is_dir(): | elif entry.is_dir(): | ||||
| self._check_by_analyzers(entry, summary_base_dir, relative_path, summary_dict) | |||||
| if list_explain: | if list_explain: | ||||
| return | return | ||||
| @@ -261,6 +273,28 @@ class SummaryWatcher: | |||||
| else: | else: | ||||
| summary_dict[relative_path] = _new_entry(ctime, mtime, profiler) | summary_dict[relative_path] = _new_entry(ctime, mtime, profiler) | ||||
| def _check_by_analyzers(self, entry, summary_base_dir, relative_path, summary_dict): | |||||
| """Check by all analyzers.""" | |||||
| try: | |||||
| ctime, mtime = self._get_stat_time(entry) | |||||
| except FileNotFoundError: | |||||
| logger.warning('File %s not found', entry.name) | |||||
| return | |||||
| for analyzer in self._analyzers: | |||||
| register_info = analyzer.analyze(entry, summary_base_dir, relative_path) | |||||
| if register_info: | |||||
| if relative_path not in summary_dict: | |||||
| summary_dict[relative_path] = _new_entry(ctime, mtime) | |||||
| summary_dict[relative_path].update(register_info) | |||||
| def _get_stat_time(self, entry): | |||||
| """Get ctime and mtime.""" | |||||
| stat = entry.stat() | |||||
| ctime = datetime.datetime.fromtimestamp(stat.st_ctime).astimezone() | |||||
| mtime = datetime.datetime.fromtimestamp(stat.st_mtime).astimezone() | |||||
| return ctime, mtime | |||||
| def _find_profiler_dir(self, entry, summary_base_dir, relative_path): | def _find_profiler_dir(self, entry, summary_base_dir, relative_path): | ||||
| """Find profiler dir by the given relative path.""" | """Find profiler dir by the given relative path.""" | ||||
| profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name) | profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name) | ||||
| @@ -342,6 +376,9 @@ class SummaryWatcher: | |||||
| if self._is_valid_profiler_directory(full_path)[0] or \ | if self._is_valid_profiler_directory(full_path)[0] or \ | ||||
| self._is_valid_cluster_profiler_directory(full_path)[0]: | self._is_valid_cluster_profiler_directory(full_path)[0]: | ||||
| return True | return True | ||||
| if os.path.exists(os.path.join(summary_directory, entry.name, ".metadata")): | |||||
| return True | |||||
| return False | return False | ||||
| def _is_valid_profiler_directory(self, directory): | def _is_valid_profiler_directory(self, directory): | ||||
| @@ -515,7 +552,8 @@ def _new_entry(ctime, mtime, profiler=None): | |||||
| 'lineage_files': 0, | 'lineage_files': 0, | ||||
| 'explain_files': 0, | 'explain_files': 0, | ||||
| 'graph_files': 0, | 'graph_files': 0, | ||||
| 'profiler': profiler | |||||
| 'profiler': profiler, | |||||
| 'dump_dir': None | |||||
| } | } | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -150,7 +150,8 @@ class TrainTaskManager(BaseProcessor): | |||||
| profiler_type=basic_info.profiler_type, | profiler_type=basic_info.profiler_type, | ||||
| summary_files=basic_info.summary_files, | summary_files=basic_info.summary_files, | ||||
| graph_files=basic_info.graph_files, | graph_files=basic_info.graph_files, | ||||
| lineage_files=basic_info.lineage_files | |||||
| lineage_files=basic_info.lineage_files, | |||||
| dump_dir=basic_info.dump_dir | |||||
| ) | ) | ||||
| if train_job.cache_status != CacheStatus.NOT_IN_CACHE: | if train_job.cache_status != CacheStatus.NOT_IN_CACHE: | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -20,6 +20,8 @@ from mindinsight.utils.constant import DebuggerErrors as DebuggerErrorCodes | |||||
| _PARAM_ERROR_MASK = 0b00001 << 7 | _PARAM_ERROR_MASK = 0b00001 << 7 | ||||
| _DEBUGGER_GRAPH_ERROR = 0b00010 << 7 | _DEBUGGER_GRAPH_ERROR = 0b00010 << 7 | ||||
| _DEBUGGER_RUNNING_ERROR = 0b00011 << 7 | _DEBUGGER_RUNNING_ERROR = 0b00011 << 7 | ||||
| _DEBUGGER_SERVER_ERROR = 0b00100 << 7 | |||||
| _DEBUGGER_SESSION_ERROR = 0b00101 << 7 | |||||
| @unique | @unique | ||||
| @@ -44,6 +46,13 @@ class DebuggerErrors(DebuggerErrorCodes): | |||||
| TENSOR_HIT_ERROR = 8 | _DEBUGGER_RUNNING_ERROR | TENSOR_HIT_ERROR = 8 | _DEBUGGER_RUNNING_ERROR | ||||
| SET_RECOMMEND_WATCHPOINT_ERROR = 9 | _DEBUGGER_RUNNING_ERROR | SET_RECOMMEND_WATCHPOINT_ERROR = 9 | _DEBUGGER_RUNNING_ERROR | ||||
| DEBUGGER_SERVER_RUNNING_ERROR = 0 | _DEBUGGER_SERVER_ERROR | |||||
| DEVICE_ID_UNREGISTERED = 1 | _DEBUGGER_SERVER_ERROR | |||||
| MODULE_NOT_FOUND_ERROR = 2 | _DEBUGGER_SERVER_ERROR | |||||
| DEBUGGER_SESSION_OVER_BOUND_ERROR = 0 | _DEBUGGER_SESSION_ERROR | |||||
| DEBUGGER_SESSION_NOT_FOUND_ERROR = 1 | _DEBUGGER_SESSION_ERROR | |||||
| @unique | @unique | ||||
| class DebuggerErrorMsg(Enum): | class DebuggerErrorMsg(Enum): | ||||
| @@ -63,3 +72,10 @@ class DebuggerErrorMsg(Enum): | |||||
| TENSOR_GRAPH_ERROR = "Get tensor graphs failed." | TENSOR_GRAPH_ERROR = "Get tensor graphs failed." | ||||
| TENSOR_HIT_ERROR = "Get tensor hits failed." | TENSOR_HIT_ERROR = "Get tensor hits failed." | ||||
| SET_RECOMMEND_WATCHPOINT_ERROR = "Set Recommend Watchpoints failed." | SET_RECOMMEND_WATCHPOINT_ERROR = "Set Recommend Watchpoints failed." | ||||
| DEBUGGER_SERVER_RUNNING_ERROR = "Debugger server running error. {}" | |||||
| DEVICE_ID_UNREGISTERED = "Device id unregistered. Device id: {}" | |||||
| MODULE_NOT_FOUND_ERROR = "{} module not found." | |||||
| DEBUGGER_SESSION_OVER_BOUND_ERROR = "The amount of sessions is over limitation." | |||||
| DEBUGGER_SESSION_NOT_FOUND_ERROR = "Session {} not found." | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -190,3 +190,58 @@ class DebuggerConditionUnavailableError(MindInsightException): | |||||
| message=DebuggerErrorMsg.DEBUGGER_CONDITION_UNAVAILABLE_ERROR.value.format(msg), | message=DebuggerErrorMsg.DEBUGGER_CONDITION_UNAVAILABLE_ERROR.value.format(msg), | ||||
| http_code=400 | http_code=400 | ||||
| ) | ) | ||||
| class DebuggerServerRunningError(MindInsightException): | |||||
| """The condition unavailable error in debugger module.""" | |||||
| def __init__(self, msg): | |||||
| super(DebuggerServerRunningError, self).__init__( | |||||
| error=DebuggerErrors.DEBUGGER_SERVER_RUNNING_ERROR, | |||||
| message=DebuggerErrorMsg.DEBUGGER_SERVER_RUNNING_ERROR.value.format(msg), | |||||
| http_code=500 | |||||
| ) | |||||
| class DeviceIdUnregistered(MindInsightException): | |||||
| """The condition unavailable error in debugger module.""" | |||||
| def __init__(self, msg): | |||||
| super(DeviceIdUnregistered, self).__init__( | |||||
| error=DebuggerErrors.DEVICE_ID_UNREGISTERED, | |||||
| message=DebuggerErrorMsg.DEVICE_ID_UNREGISTERED.value.format(msg), | |||||
| http_code=400 | |||||
| ) | |||||
| class DebuggerModuleNotFoundError(MindInsightException): | |||||
| """The condition unavailable error in debugger module.""" | |||||
| def __init__(self, msg): | |||||
| super(DebuggerModuleNotFoundError, self).__init__( | |||||
| error=DebuggerErrors.MODULE_NOT_FOUND_ERROR, | |||||
| message=DebuggerErrorMsg.MODULE_NOT_FOUND_ERROR.value.format(msg), | |||||
| http_code=500 | |||||
| ) | |||||
| class DebuggerSessionNumOverBoundError(MindInsightException): | |||||
| """The condition unavailable error in debugger module.""" | |||||
| def __init__(self): | |||||
| super(DebuggerSessionNumOverBoundError, self).__init__( | |||||
| error=DebuggerErrors.DEBUGGER_SESSION_OVER_BOUND_ERROR, | |||||
| message=DebuggerErrorMsg.DEBUGGER_SESSION_OVER_BOUND_ERROR.value, | |||||
| http_code=400 | |||||
| ) | |||||
| class DebuggerSessionNotFoundError(MindInsightException): | |||||
| """The condition unavailable error in debugger module.""" | |||||
| def __init__(self, msg): | |||||
| super(DebuggerSessionNotFoundError, self).__init__( | |||||
| error=DebuggerErrors.DEBUGGER_SESSION_NOT_FOUND_ERROR, | |||||
| message=DebuggerErrorMsg.DEBUGGER_SESSION_NOT_FOUND_ERROR.value.format(msg), | |||||
| http_code=400 | |||||
| ) | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -38,9 +38,12 @@ NUMPY_TYPE_MAP = { | |||||
| 'DT_FLOAT32': np.float32, | 'DT_FLOAT32': np.float32, | ||||
| 'DT_FLOAT64': np.float64, | 'DT_FLOAT64': np.float64, | ||||
| 'DT_STRING': np.str | |||||
| 'DT_STRING': np.str, | |||||
| 'DT_TYPE': np.str | |||||
| } | } | ||||
| MS_VERSION = '1.0.x' | |||||
| @enum.unique | @enum.unique | ||||
| class ReplyStates(enum.Enum): | class ReplyStates(enum.Enum): | ||||
| @@ -71,6 +74,7 @@ class Streams(enum.Enum): | |||||
| TENSOR = 'tensor' | TENSOR = 'tensor' | ||||
| WATCHPOINT = 'watchpoint' | WATCHPOINT = 'watchpoint' | ||||
| WATCHPOINT_HIT = 'watchpoint_hit' | WATCHPOINT_HIT = 'watchpoint_hit' | ||||
| DEVICE = 'device' | |||||
| class RunLevel(enum.Enum): | class RunLevel(enum.Enum): | ||||
| @@ -152,3 +156,26 @@ def is_scope_type(node_type): | |||||
| def is_cst_type(node_type): | def is_cst_type(node_type): | ||||
| """Judge whether the type is const type.""" | """Judge whether the type is const type.""" | ||||
| return node_type == NodeTypeEnum.CONST.value | return node_type == NodeTypeEnum.CONST.value | ||||
| def version_match(ms_version, mi_version): | |||||
| """Judge if the version of Mindinsight and Mindspore is matched.""" | |||||
| if not ms_version: | |||||
| ms_version = MS_VERSION | |||||
| mi_major, mi_minor = mi_version.split('.')[:2] | |||||
| ms_major, ms_minor = ms_version.split('.')[:2] | |||||
| return mi_major == ms_major and mi_minor == ms_minor | |||||
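`version_match` compares only the major and minor fields, and an empty `ms_version` falls back to the `MS_VERSION` default. A few worked cases:

```python
assert version_match('1.2.1', '1.2.0')       # same major.minor -> matched
assert not version_match('1.1.0', '1.2.0')   # minor version differs
assert version_match('', '1.0.1')            # empty falls back to MS_VERSION ('1.0.x')
```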
| @enum.unique | |||||
| class DebuggerServerMode(enum.Enum): | |||||
| """Debugger Server Mode.""" | |||||
| ONLINE = 'online' | |||||
| OFFLINE = 'offline' | |||||
| class DumpSettings(enum.Enum): | |||||
| """Dump settings.""" | |||||
| E2E_DUMP_SETTINGS = 'e2e_dump_settings' | |||||
| COMMON_DUMP_SETTINGS = 'common_dump_settings' | |||||
| ASYNC_DUMP_SETTINGS = 'async_dump_settings' | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -64,13 +64,13 @@ class _ConditionParameterValue: | |||||
| return self.parameter.name | return self.parameter.name | ||||
| def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_context): | |||||
| def recommend_watchpoints(condition_mgr: ConditionMgr, multi_card_graph_stream, condition_context): | |||||
| """ | """ | ||||
| Recommend watchpoints. | Recommend watchpoints. | ||||
| Args: | Args: | ||||
| condition_mgr (ConditionMgr): Condition manager instance. | condition_mgr (ConditionMgr): Condition manager instance. | ||||
| graph_stream (GraphHandler): Graph handler instance. | |||||
| multi_card_graph_stream (GraphHandler): Multi card graph handler instance. | |||||
| condition_context (ConditionContext): Context for condition. | condition_context (ConditionContext): Context for condition. | ||||
| Returns: | Returns: | ||||
| @@ -78,7 +78,7 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c | |||||
| """ | """ | ||||
| watch_points = [] | watch_points = [] | ||||
| if not graph_stream.graph: | |||||
| if not multi_card_graph_stream.has_graph: | |||||
| logger.warning("Given graph is None.") | logger.warning("Given graph is None.") | ||||
| return watch_points | return watch_points | ||||
| @@ -86,7 +86,7 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c | |||||
| return watch_points | return watch_points | ||||
| # add weight watch points | # add weight watch points | ||||
| merged_info = get_basic_node_info(TargetTypeEnum.WEIGHT.value, graph_stream) | |||||
| merged_info = get_basic_node_info(TargetTypeEnum.WEIGHT.value, multi_card_graph_stream) | |||||
| _recommend_weight_initialization(merged_info, condition_mgr, watch_points, condition_context) | _recommend_weight_initialization(merged_info, condition_mgr, watch_points, condition_context) | ||||
| _recommend_weight_change_too_large(merged_info, condition_mgr, watch_points, condition_context) | _recommend_weight_change_too_large(merged_info, condition_mgr, watch_points, condition_context) | ||||
| @@ -97,25 +97,27 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c | |||||
| _recommend_weight_change_too_small(condition_mgr, trainable_weight_nodes, watch_points, condition_context) | _recommend_weight_change_too_small(condition_mgr, trainable_weight_nodes, watch_points, condition_context) | ||||
| # add gradient watch points | # add gradient watch points | ||||
| merged_info = get_basic_node_info(TargetTypeEnum.GRADIENT.value, graph_stream) | |||||
| merged_info = get_basic_node_info(TargetTypeEnum.GRADIENT.value, multi_card_graph_stream) | |||||
| _recommend_gradient_vanishing(merged_info, condition_mgr, watch_points, condition_context) | _recommend_gradient_vanishing(merged_info, condition_mgr, watch_points, condition_context) | ||||
| # add tensor watch points | # add tensor watch points | ||||
| merged_info = get_basic_node_info(TargetTypeEnum.TENSOR.value, graph_stream) | |||||
| merged_info = get_basic_node_info(TargetTypeEnum.TENSOR.value, multi_card_graph_stream) | |||||
| _recommend_operator_overflow(merged_info, condition_mgr, watch_points, condition_context) | _recommend_operator_overflow(merged_info, condition_mgr, watch_points, condition_context) | ||||
| _recommend_tensor_overflow(merged_info, condition_mgr, watch_points, condition_context) | _recommend_tensor_overflow(merged_info, condition_mgr, watch_points, condition_context) | ||||
| _recommend_tensor_all_zero(merged_info, condition_mgr, watch_points, condition_context) | _recommend_tensor_all_zero(merged_info, condition_mgr, watch_points, condition_context) | ||||
| # add activation watch points | # add activation watch points | ||||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.TANH.value) | |||||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream, | |||||
| ActivationFuncEnum.TANH.value) | |||||
| _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, | _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, | ||||
| ActivationFuncEnum.TANH.value) | ActivationFuncEnum.TANH.value) | ||||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.SIGMOID.value) | |||||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream, | |||||
| ActivationFuncEnum.SIGMOID.value) | |||||
| _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, | _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, | ||||
| ActivationFuncEnum.SIGMOID.value) | ActivationFuncEnum.SIGMOID.value) | ||||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, | |||||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream, | |||||
| [ActivationFuncEnum.RELU.value, ActivationFuncEnum.RELUV2.value]) | [ActivationFuncEnum.RELU.value, ActivationFuncEnum.RELUV2.value]) | ||||
| _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, | _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, | ||||
| ActivationFuncEnum.RELU.value) | ActivationFuncEnum.RELU.value) | ||||
| @@ -318,12 +320,21 @@ def _recommend_activation_range(basic_info_nodes, condition_mgr, watch_points, c | |||||
| watch_points.append(activation_range_watchpoint) | watch_points.append(activation_range_watchpoint) | ||||
| def get_basic_node_info(node_category, graph_stream, activation_func=None): | |||||
| def get_basic_node_info(node_category, multi_card_graph_stream, activation_func=None): | |||||
| """Get node merged info.""" | """Get node merged info.""" | ||||
| basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func) | |||||
| merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph) | |||||
| merged_info = _add_graph_name(merged_info, graph_stream) | |||||
| return merged_info | |||||
| nodes_for_devices = {} | |||||
| has_node = False | |||||
| for rank_id, graph_stream in multi_card_graph_stream.graph_handlers.items(): | |||||
| basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func) | |||||
| merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph) | |||||
| merged_info = _add_graph_name(merged_info, graph_stream) | |||||
| nodes_for_devices[rank_id] = merged_info | |||||
| has_node = has_node or bool(merged_info) | |||||
| if has_node: | |||||
| return nodes_for_devices | |||||
| return {} | |||||
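`get_basic_node_info` now returns a mapping keyed by rank id instead of a single merged list, with an empty dict signalling that no rank contributed any node. A hedged consumer sketch (the node category string and the stream object are placeholders):

```python
# Iterate the per-rank mapping returned by get_basic_node_info.
nodes_for_devices = get_basic_node_info('weight', multi_card_graph_stream)
for rank_id, merged_info in nodes_for_devices.items():
    # merged_info has the shape the single-card code used to return for one graph stream.
    print('rank', rank_id, 'nodes:', len(merged_info))
```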
| def _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func=None): | def _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func=None): | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -17,17 +17,19 @@ import sys | |||||
| from mindinsight.debugger.common.log import LOGGER as log | from mindinsight.debugger.common.log import LOGGER as log | ||||
| from mindinsight.debugger.common.utils import Streams | from mindinsight.debugger.common.utils import Streams | ||||
| from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, GraphHandler, \ | |||||
| TensorHandler, WatchpointHandler, WatchpointHitHandler | |||||
| from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, MultiCardGraphHandler, \ | |||||
| MultiCardTensorHandler, WatchpointHandler, MultiCardWatchpointHitHandler | |||||
| from mindinsight.debugger.stream_handler.device_handler import DeviceHandler | |||||
| STREAM_HANDLER_MAP = { | STREAM_HANDLER_MAP = { | ||||
| Streams.COMMAND.value: EventHandler, | Streams.COMMAND.value: EventHandler, | ||||
| Streams.DATA.value: EventHandler, | Streams.DATA.value: EventHandler, | ||||
| Streams.METADATA.value: MetadataHandler, | Streams.METADATA.value: MetadataHandler, | ||||
| Streams.GRAPH.value: GraphHandler, | |||||
| Streams.TENSOR.value: TensorHandler, | |||||
| Streams.GRAPH.value: MultiCardGraphHandler, | |||||
| Streams.TENSOR.value: MultiCardTensorHandler, | |||||
| Streams.WATCHPOINT.value: WatchpointHandler, | Streams.WATCHPOINT.value: WatchpointHandler, | ||||
| Streams.WATCHPOINT_HIT.value: WatchpointHitHandler | |||||
| Streams.WATCHPOINT_HIT.value: MultiCardWatchpointHitHandler, | |||||
| Streams.DEVICE.value: DeviceHandler | |||||
| } | } | ||||
| @@ -40,10 +42,8 @@ class DebuggerCache: | |||||
| def initialize(self): | def initialize(self): | ||||
| """Initialize the stream handlers.""" | """Initialize the stream handlers.""" | ||||
| self._stream_handler = {} | self._stream_handler = {} | ||||
| for stream in Streams: | |||||
| mode = stream.value | |||||
| stream_handler = STREAM_HANDLER_MAP.get(mode) | |||||
| self._stream_handler[mode] = stream_handler() | |||||
| for mode, stream_class in STREAM_HANDLER_MAP.items(): | |||||
| self._stream_handler[mode] = stream_class() | |||||
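# Editor's note: with the map-driven loop above, supporting a new stream only
# requires a new entry in STREAM_HANDLER_MAP. A hedged usage sketch, assuming
# the handler APIs used elsewhere in this change:
cache = DebuggerCache()
cache.initialize()
multi_card_graph_stream = cache.get_stream_handler(Streams.GRAPH)
# Per-rank handlers are now fetched explicitly; rank 0 is the single-card case.
graph_handler = multi_card_graph_stream.get_graph_handler_by_rank_id(0)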
| def clean(self): | def clean(self): | ||||
| """Clean cache for all stream.""" | """Clean cache for all stream.""" | ||||
| @@ -0,0 +1,41 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Debugger train job register.""" | |||||
| import os | |||||
| from mindinsight.utils.folder_analyzer import FolderAnalyzer | |||||
| from mindinsight.datavisual.common.log import logger | |||||
| class DebuggerFolderAnalyzer(FolderAnalyzer): | |||||
| """Debugger train job register.""" | |||||
| def analyze(self, entry, summary_base_dir, relative_path): | |||||
| """Check dir by debugger register.""" | |||||
| update_info = {} | |||||
| if entry.is_dir(): | |||||
| sub_relative_path = os.path.join(relative_path, entry.name) | |||||
| entry_path = os.path.join(summary_base_dir, sub_relative_path) | |||||
| try: | |||||
| subdir_entries = os.scandir(entry_path) | |||||
| except PermissionError: | |||||
| logger.warning('Path of %s under summary base directory is not accessible.', entry.name) | |||||
| return update_info | |||||
| subdir_entries = [subdir_entry for subdir_entry in subdir_entries if not subdir_entry.is_symlink()] | |||||
| subdir_entries = sorted(subdir_entries, key=lambda x: x.stat().st_mtime) | |||||
| for subdir_entry in subdir_entries: | |||||
| if subdir_entry.is_dir() and subdir_entry.name.startswith(".metadata"): | |||||
| update_info = {'dump_dir': sub_relative_path} | |||||
| return update_info | |||||
| return update_info | |||||
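# Editor's note: a hypothetical usage sketch of the analyzer above; the
# summary directory path is invented.
import os

analyzer = DebuggerFolderAnalyzer()
summary_base_dir = '/tmp/summary'  # hypothetical path
for entry in os.scandir(summary_base_dir):
    update_info = analyzer.analyze(entry, summary_base_dir, relative_path='./')
    if update_info:
        print('Found debugger dump dir:', update_info['dump_dir'])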
| @@ -0,0 +1,15 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Debugger server module.""" | |||||
| @@ -19,7 +19,7 @@ from functools import wraps | |||||
| import mindinsight | import mindinsight | ||||
| from mindinsight.debugger.common.log import LOGGER as log | from mindinsight.debugger.common.log import LOGGER as log | ||||
| from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ | from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ | ||||
| Streams, RunLevel | |||||
| Streams, RunLevel, version_match | |||||
| from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum, ParamNameEnum | from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum, ParamNameEnum | ||||
| from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base | from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base | ||||
| from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto | from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto | ||||
| @@ -117,9 +117,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| # clean cache data at the beginning of new step or node has been changed. | # clean cache data at the beginning of new step or node has been changed. | ||||
| if is_new_step or is_new_node: | if is_new_step or is_new_node: | ||||
| self._cache_store.clean_data() | self._cache_store.clean_data() | ||||
| self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step) | |||||
| self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).clean_tensors( | |||||
| request.cur_step) | |||||
| if is_new_step: | if is_new_step: | ||||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() | |||||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(0).clean() | |||||
| # receive graph at the beginning of the training | # receive graph at the beginning of the training | ||||
| if self._status == ServerStatus.RECEIVE_GRAPH: | if self._status == ServerStatus.RECEIVE_GRAPH: | ||||
| self._send_graph_flag(metadata_stream) | self._send_graph_flag(metadata_stream) | ||||
| @@ -141,7 +142,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| self._status = ServerStatus.WAITING | self._status = ServerStatus.WAITING | ||||
| metadata_stream.state = ServerStatus.WAITING.value | metadata_stream.state = ServerStatus.WAITING.value | ||||
| metadata = metadata_stream.get() | metadata = metadata_stream.get() | ||||
| res = self._cache_store.get_stream_handler(Streams.GRAPH).get() | |||||
| res = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).get() | |||||
| res.update(metadata) | res.update(metadata) | ||||
| self._cache_store.put_data(res) | self._cache_store.put_data(res) | ||||
| log.debug("Put graph into data queue.") | log.debug("Put graph into data queue.") | ||||
| @@ -157,7 +158,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| # put new metadata into cache | # put new metadata into cache | ||||
| metadata_stream.put(metadata_proto) | metadata_stream.put(metadata_proto) | ||||
| # update current node name and graph name | # update current node name and graph name | ||||
| graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH) | |||||
| graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0) | |||||
| full_name = metadata_proto.cur_node | full_name = metadata_proto.cur_node | ||||
| graph_name = graph_stream.get_graph_id_by_full_name( | graph_name = graph_stream.get_graph_id_by_full_name( | ||||
| full_name) if full_name else metadata_stream.graph_name | full_name) if full_name else metadata_stream.graph_name | ||||
| @@ -182,7 +183,8 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| def _send_watchpoint_hit_flag(self): | def _send_watchpoint_hit_flag(self): | ||||
| """Send Watchpoint hit flag.""" | """Send Watchpoint hit flag.""" | ||||
| watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||||
| watchpoint_hit_stream = self._cache_store.get_stream_handler( |||||
| Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(0) |||||
| if not self._received_hit: | if not self._received_hit: | ||||
| return | return | ||||
| watchpoint_hits = self._received_hit | watchpoint_hits = self._received_hit | ||||
| @@ -344,7 +346,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| run_cmd.node_name = '' | run_cmd.node_name = '' | ||||
| # clean watchpoint hit cache | # clean watchpoint hit cache | ||||
| if run_cmd.run_level == RunLevel.RECHECK.value: | if run_cmd.run_level == RunLevel.RECHECK.value: | ||||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() | |||||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(0).clean() | |||||
| log.debug("Receive RunCMD. Clean watchpoint hit cache.") | log.debug("Receive RunCMD. Clean watchpoint hit cache.") | ||||
| # update metadata state from sending to running | # update metadata state from sending to running | ||||
| metadata_stream.state = ServerStatus.RUNNING.value | metadata_stream.state = ServerStatus.RUNNING.value | ||||
| @@ -365,8 +367,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| log.info("The training from %s has finished.", client_ip) | log.info("The training from %s has finished.", client_ip) | ||||
| else: | else: | ||||
| ms_version = request.ms_version | ms_version = request.ms_version | ||||
| if not ms_version: | |||||
| ms_version = '1.0.x' | |||||
| if version_match(ms_version, mindinsight.__version__) is False: | if version_match(ms_version, mindinsight.__version__) is False: | ||||
| log.info("Version is mismatched, mindspore is: %s, mindinsight is: %s", | log.info("Version is mismatched, mindspore is: %s, mindinsight is: %s", | ||||
| ms_version, mindinsight.__version__) | ms_version, mindinsight.__version__) | ||||
| @@ -403,8 +403,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| graph = GraphProto.FromString(serial_graph) | graph = GraphProto.FromString(serial_graph) | ||||
| log.debug("Deserialize the graph %s. Receive %s nodes", graph.name, len(graph.node)) | log.debug("Deserialize the graph %s. Receive %s nodes", graph.name, len(graph.node)) | ||||
| graph_dict = {graph.name: graph} | graph_dict = {graph.name: graph} | ||||
| self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict) | |||||
| self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals) | |||||
| self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).put(graph_dict) | |||||
| self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals( | |||||
| graph.const_vals) | |||||
| self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name | self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name | ||||
| self._record_parameter_names() | self._record_parameter_names() | ||||
| self._status = ServerStatus.RECEIVE_GRAPH | self._status = ServerStatus.RECEIVE_GRAPH | ||||
| @@ -429,10 +430,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| log.debug("Deserialize the graph %s. Receive %s nodes", sub_graph.name, | log.debug("Deserialize the graph %s. Receive %s nodes", sub_graph.name, | ||||
| len(sub_graph.node)) | len(sub_graph.node)) | ||||
| serial_graph = b"" | serial_graph = b"" | ||||
| self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals( | |||||
| self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals( | |||||
| sub_graph.const_vals) | sub_graph.const_vals) | ||||
| self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict) | |||||
| self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).put(graph_dict) | |||||
| self._record_parameter_names() | self._record_parameter_names() | ||||
| self._status = ServerStatus.RECEIVE_GRAPH | self._status = ServerStatus.RECEIVE_GRAPH | ||||
| log.debug("Send the reply for graph.") | log.debug("Send the reply for graph.") | ||||
| @@ -440,9 +441,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| def _record_parameter_names(self): | def _record_parameter_names(self): | ||||
| """Record parameter full names in tensor handler.""" | """Record parameter full names in tensor handler.""" | ||||
| parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).search_in_graph( | |||||
| pattern={'node_category': TargetTypeEnum.PARAMETER.value}) | |||||
| tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR) | |||||
| parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0)\ | |||||
| .search_in_graph(pattern={'node_category': TargetTypeEnum.PARAMETER.value}) | |||||
| tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0) | |||||
| for node in parameter_nodes: | for node in parameter_nodes: | ||||
| tensor_name = [node.full_name + ':0'] | tensor_name = [node.full_name + ':0'] | ||||
| tensor_stream.record_parameter_names(tensor_name) | tensor_stream.record_parameter_names(tensor_name) | ||||
| @@ -452,7 +453,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| """Send tensors into DebuggerCache.""" | """Send tensors into DebuggerCache.""" | ||||
| log.info("Received tensor.") | log.info("Received tensor.") | ||||
| tensor_contents = [] | tensor_contents = [] | ||||
| tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR) | |||||
| tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0) | |||||
| metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | ||||
| step = metadata_stream.step | step = metadata_stream.step | ||||
| for tensor in request_iterator: | for tensor in request_iterator: | ||||
| @@ -482,7 +483,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| # save the watchpoint_hits data | # save the watchpoint_hits data | ||||
| watchpoint_hits = [] | watchpoint_hits = [] | ||||
| watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT) | watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT) | ||||
| graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH) | |||||
| graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0) | |||||
| for watchpoint_hit_proto in request_iterator: | for watchpoint_hit_proto in request_iterator: | ||||
| node_full_name = watchpoint_hit_proto.tensor.node_name | node_full_name = watchpoint_hit_proto.tensor.node_name | ||||
| graph_name = graph_stream.get_graph_id_by_full_name(node_full_name) | graph_name = graph_stream.get_graph_id_by_full_name(node_full_name) | ||||
| @@ -517,10 +518,3 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||||
| self._received_hit = watchpoint_hits | self._received_hit = watchpoint_hits | ||||
| reply = get_ack_reply() | reply = get_ack_reply() | ||||
| return reply | return reply | ||||
| def version_match(mi_version, ms_version): | |||||
| """Judge if the version of Mindinsight and Mindspore is matched""" | |||||
| mi_major, mi_minor = mi_version.split('.')[:2] | |||||
| ms_major, ms_minor = ms_version.split('.')[:2] | |||||
| return mi_major == ms_major and mi_minor == ms_minor | |||||
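# Editor's note: the helper removed above now comes from
# mindinsight.debugger.common.utils (see the import change earlier in this
# file). A sketch consistent with the removed body, comparing only the
# major.minor components:
def version_match(mi_version, ms_version):
    """Check whether the MindInsight and MindSpore versions match."""
    mi_major, mi_minor = mi_version.split('.')[:2]
    ms_major, ms_minor = ms_version.split('.')[:2]
    return mi_major == ms_major and mi_minor == ms_minor

assert version_match('1.2.0', '1.2.1')      # patch versions may differ
assert not version_match('1.2.0', '1.1.0')  # minor versions must match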
| @@ -0,0 +1,613 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Debugger Offline server.""" | |||||
| import copy | |||||
| from collections import defaultdict | |||||
| from importlib import import_module | |||||
| from threading import Event | |||||
| from multiprocessing import Process, Manager | |||||
| import mindinsight | |||||
| from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerModuleNotFoundError | |||||
| from mindinsight.debugger.common.log import LOGGER as log | |||||
| from mindinsight.debugger.common.utils import Streams, ServerStatus, version_match, DebuggerServerMode, get_ack_reply, \ | |||||
| RunLevel | |||||
| from mindinsight.debugger.conditionmgr.condition import ParamNameEnum | |||||
| from mindinsight.debugger.debugger_services.debugger_server_base import DebuggerServerBase, debugger_server_wrap | |||||
| from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply | |||||
| from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto | |||||
| from mindinsight.debugger.stream_cache.data_loader import DataLoader | |||||
| from mindinsight.utils.exceptions import MindInsightException | |||||
| class DebuggerOfflineServer(DebuggerServerBase): | |||||
| """Debugger Offline Server.""" | |||||
| _MAX_TRY_EXCEPT_COUNT = 500 | |||||
| def __init__(self, cache_store, context): | |||||
| super(DebuggerOfflineServer, self).__init__(cache_store, context) | |||||
| self._offline_server_manager = DebuggerOfflineManager(cache_store, context.dbg_dir) | |||||
| self._running = Event() | |||||
| self._running.clear() | |||||
| def run(self): | |||||
| """Start the debugger offline server.""" | |||||
| log.info("Initialize Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir) | |||||
| self._offline_server_manager.initialize() | |||||
| self._running.set() | |||||
| log.info("Start Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir) | |||||
| try_count = 0 | |||||
| while self._running.is_set() and try_count < self._MAX_TRY_EXCEPT_COUNT: | |||||
| try: | |||||
| self._offline_server_manager.wait_for_termination() | |||||
| if not self._offline_server_manager.is_runnable(): | |||||
| break | |||||
| except MindInsightException as err: | |||||
| log.exception(err) | |||||
| log.warning("Error happens during listening on user commands. Restart listening again.") | |||||
| finally: | |||||
| try_count += 1 | |||||
| # Protect the server from too many failed commands. |||||
| if try_count == self._MAX_TRY_EXCEPT_COUNT: | |||||
| self._cache_store.clean() | |||||
| metadata = self._cache_store.get_stream_handler(Streams.METADATA).get() | |||||
| self._cache_store.put_data(metadata) | |||||
| log.warning("Exception exceed %d times, stop server.", try_count) | |||||
| def stop(self): | |||||
| """Stop offline debugger server.""" | |||||
| log.debug("Start to wait for thread started.") | |||||
| self._running.wait() | |||||
| log.info("Start to stop offline debugger server.") | |||||
| self._running.clear() | |||||
| self._offline_server_manager.stop() | |||||
| self.join() | |||||
| class DebuggerOfflineManager: | |||||
| """Debugger offline manager which is used to handle user commands.""" | |||||
| def __init__(self, cache_store, dbg_dir): | |||||
| cache_store.initialize() | |||||
| self._cache_store = cache_store | |||||
| self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) | |||||
| self._dbg_dir = dbg_dir | |||||
| self._dbg_services_module = self._get_dbg_service_module() | |||||
| self._dbg_service = None | |||||
| self._command_listener = CommandListener(cache_store) | |||||
| self._data_loader = DataLoader(dbg_dir) | |||||
| self._is_running_flag = False | |||||
| self._old_run_cmd = {} | |||||
| def stop(self): | |||||
| """Stop server.""" | |||||
| self._is_running_flag = False | |||||
| self._command_listener.stop() | |||||
| self._cache_store.clean() | |||||
| event = get_ack_reply() | |||||
| event.exit = True | |||||
| self._cache_store.put_command(event) | |||||
| log.info("Stop debugger offline manager.") | |||||
| def is_runnable(self): | |||||
| """Check if the offline manager is runnable.""" | |||||
| state = self._metadata_stream.state | |||||
| flag = self._is_running_flag and state not in [ServerStatus.MISMATCH.value, ServerStatus.PENDING.value] | |||||
| if not flag: | |||||
| log.debug("The offline manager is not runnable, is_running_flag: %s, metadata state: %s", | |||||
| self._is_running_flag, state) | |||||
| return flag | |||||
| @staticmethod | |||||
| def _get_dbg_service_module(): | |||||
| """Get dbg service module from MindSpore.""" | |||||
| try: | |||||
| dbg_services_module = import_module('mindspore.offline_debug.dbg_services') | |||||
| except ModuleNotFoundError as err: | |||||
| log.error("Failed to find module dbg_services. %s", err) | |||||
| raise DebuggerModuleNotFoundError("dbg_services") | |||||
| return dbg_services_module | |||||
| @debugger_server_wrap | |||||
| def initialize(self): | |||||
| """Start to load offline debugger data.""" | |||||
| self._data_loader.initialize() | |||||
| is_sync = self._data_loader.get_sync_flag() | |||||
| net_name = self._data_loader.get_net_name() | |||||
| net_dir = self._data_loader.get_net_dir() | |||||
| self._dbg_service = self._dbg_services_module.DbgServices(net_dir) | |||||
| self._dbg_service.initialize(net_name=net_name, is_sync_mode=is_sync) | |||||
| self._cache_store.clean() | |||||
| self._command_listener.start() | |||||
| self._is_running_flag = True | |||||
| self._check_version() | |||||
| if self._metadata_stream.state == ServerStatus.MISMATCH.value: | |||||
| log.info("The MindSpore and MindInsight version are mismatched. Failed to initialize offline server.") | |||||
| return | |||||
| self._load_metadata() | |||||
| self._load_graphs() | |||||
| log.info("Success initialize offline server for %s", self._dbg_dir) | |||||
| def _check_version(self): | |||||
| """Check version.""" | |||||
| ms_version = self._dbg_services_module.get_version() | |||||
| mi_version = mindinsight.__version__ | |||||
| self._metadata_stream.debugger_version = {'ms': ms_version, 'mi': mi_version} | |||||
| if not version_match(ms_version, mi_version): |||||
| log.info("Versions are mismatched, dbg_services is: %s, mindinsight is: %s", |||||
| ms_version, mi_version) |||||
| self._metadata_stream.state = ServerStatus.MISMATCH.value | |||||
| metadata = self._metadata_stream.get(['state', 'debugger_version']) | |||||
| self._cache_store.put_data(metadata) | |||||
| def _load_metadata(self): | |||||
| """Load metadata.""" | |||||
| self._metadata_stream.debugger_type = DebuggerServerMode.OFFLINE.value | |||||
| device_info = self._data_loader.load_device_info() | |||||
| # The backend refers to the running environment on which the offline debugger |||||
| # data was generated. | |||||
| # Currently supported options: `GPU`, `Ascend` | |||||
| backend = device_info.get('device_target', 'Ascend') | |||||
| self._metadata_stream.backend = backend | |||||
| device_stream = self._cache_store.get_stream_handler(Streams.DEVICE) | |||||
| device_stream.put(device_info.get('server_list')) | |||||
| rank_id = 0 | |||||
| rank_0_info = device_stream.get(rank_id)['devices'][0] | |||||
| self._metadata_stream.client_ip = rank_0_info.get('server_id') | |||||
| # Get the step number per device as dict(device_id, step_num); it may grow as training proceeds. |||||
| step_num_per_device = self._data_loader.load_step_number() | |||||
| device_stream.add_step_num_info(step_num_per_device) | |||||
| self._metadata_stream.max_step_num = max(step_num_per_device.values()) | |||||
| def _load_graphs(self): | |||||
| """Load graphs.""" | |||||
| # The format of graphs is a list of {'device_id': int, 'graph_protos': [GraphProto]}. |||||
| graphs = self._data_loader.load_graphs() | |||||
| device_stream = self._cache_store.get_stream_handler(Streams.DEVICE) | |||||
| graph_per_rank = {} | |||||
| for graph in graphs: | |||||
| device_id = int(graph.get('device_id')) | |||||
| rank_id = device_stream.get_rank_id_by_device_id(device_id) | |||||
| graph_per_rank[rank_id] = {} | |||||
| tensor_stream_per_rank = self._cache_store.get_stream_handler(Streams.TENSOR).\ | |||||
| get_tensor_handler_by_rank_id(rank_id, create_if_not_exit=True) | |||||
| for graph_proto in graph.get('graph_protos'): | |||||
| graph_per_rank[rank_id][graph_proto.name] = graph_proto | |||||
| tensor_stream_per_rank.put_const_vals(graph_proto.const_vals) | |||||
| # graph_per_rank has the format Dict[<rank_id>, Dict[<graph_name>, <GraphProto>]]. |||||
| self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_per_rank) | |||||
| device_stream.add_graph_name_info(graph_per_rank) | |||||
| self._metadata_stream.state = ServerStatus.RECEIVE_GRAPH.value | |||||
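# Editor's note: an illustrative shape of the data handled by _load_metadata
# and _load_graphs above; all values are invented.
graphs_example = [
    {'device_id': 0, 'graph_protos': []},  # entries come from load_graphs()
]
graph_per_rank_example = {
    0: {'kernel_graph_0': None},  # Dict[<rank_id>, Dict[<graph_name>, <GraphProto>]]
}
step_num_per_device_example = {0: 10, 1: 10}  # dict(device_id, step_num)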
| @debugger_server_wrap | |||||
| def wait_for_termination(self): | |||||
| """Begin to listen on command event.""" | |||||
| log.info("Begin to listen for user commands.") | |||||
| self._send_graph() | |||||
| while self.is_runnable(): | |||||
| if not self._command_listener.has_new_command() and self._old_run_cmd: | |||||
| self._deal_with_old_run_cmd() | |||||
| continue | |||||
| cmd = self._command_listener.get_next_command() | |||||
| self.deal_with_cmd(cmd) | |||||
| def _send_graph(self): | |||||
| """Put graph and metadata info into data queue.""" | |||||
| if not self.is_runnable(): | |||||
| return | |||||
| self._metadata_stream.state = ServerStatus.WAITING.value | |||||
| metadata = self._metadata_stream.get() | |||||
| res = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).get() | |||||
| res.update(metadata) | |||||
| self._cache_store.put_data(res) | |||||
| def _deal_with_old_run_cmd(self): | |||||
| """Deal with old run command.""" | |||||
| left_step_count = self._old_run_cmd.get('left_step_count') | |||||
| if left_step_count: | |||||
| self._execute_one_step() | |||||
| # Only update the count if old_run_cmd was not cleared by a watchpoint hit. |||||
| if self._old_run_cmd: | |||||
| self._old_run_cmd['left_step_count'] = left_step_count - 1 if left_step_count > 0 else -1 | |||||
| if not self._old_run_cmd.get('left_step_count'): | |||||
| self._old_run_cmd.clear() | |||||
| def deal_with_cmd(self, cmd): | |||||
| """Deal with command.""" | |||||
| if cmd is None: | |||||
| return | |||||
| if isinstance(cmd, dict): | |||||
| self._deal_with_view_cmd(cmd) | |||||
| elif isinstance(cmd, EventReply): | |||||
| self._on_event(cmd) | |||||
| def _on_event(self, event): | |||||
| """ | |||||
| Deal with different command event. | |||||
| Args: | |||||
| event (EventReply): Command Event. | |||||
| """ | |||||
| if event.HasField('run_cmd'): | |||||
| self._deal_with_run_cmd(event) | |||||
| elif event.HasField('exit'): | |||||
| self._cache_store.clean() | |||||
| self._update_state(ServerStatus.PENDING) | |||||
| log.debug("Clean cache for exit cmd.") | |||||
| else: | |||||
| self._deal_with_set_cmd(event) | |||||
| log.debug("Deal with set cmd.") | |||||
| def _deal_with_view_cmd(self, event): | |||||
| """ | |||||
| Deal with view cmd. | |||||
| Args: | |||||
| event (dict): View command params. | |||||
| - view_cmd (EventReply): EventReply with view command. | |||||
| - node_name (str): The center node name for view command. | |||||
| - tensor_name (str): The center tensor name for view command. | |||||
| - graph_name (str): The graph name of center node. | |||||
| - rank_id (int): The rank id of the device that holds the tensor. |||||
| """ | |||||
| view_cmd = event.pop('view_cmd', None) |||||
| view_cmd = view_cmd.view_cmd if view_cmd is not None else None |||||
| node_info = event | |||||
| log.debug("Receive view cmd for node: %s.", event) | |||||
| if not (view_cmd and node_info): | |||||
| log.info("Invalid view command. Ignore it.") | |||||
| return | |||||
| # read tensor value by dbg_service | |||||
| rank_id = node_info.get('rank_id', 0) | |||||
| device_id = self._cache_store.get_stream_handler(Streams.DEVICE).get_device_id_by_rank_id(rank_id) | |||||
| cur_step = self._metadata_stream.step | |||||
| tensor_protos = view_cmd.tensors | |||||
| root_graph_id = self.get_root_graph_id() | |||||
| tensor_infos = [ | |||||
| self._dbg_services_module.TensorInfo( | |||||
| node_name=tensor_proto.node_name, | |||||
| slot=int(tensor_proto.slot), | |||||
| iteration=cur_step - 1 if tensor_proto.iter == 'prev' else cur_step, | |||||
| device_id=device_id, | |||||
| is_parameter=tensor_proto.truncate, | |||||
| root_graph_id=root_graph_id | |||||
| ) for tensor_proto in tensor_protos] | |||||
| res = self._dbg_service.read_tensors(tensor_infos) | |||||
| # put tensor into cache | |||||
| for tensor_proto, tensor_data in zip(tensor_protos, res): | |||||
| log.debug("Tensor name: %s:%s, tensor type: %s, tensor size: %s", tensor_proto.node_name, tensor_proto.slot, | |||||
| tensor_data.dtype, tensor_data.data_size) | |||||
| tensor_proto.tensor_content = tensor_data.data_ptr | |||||
| tensor_proto.ClearField('dims') | |||||
| tensor_proto.dims.extend(tensor_data.shape) | |||||
| tensor_proto.data_type = tensor_data.dtype | |||||
| self._put_tensor_value_into_cache(cur_step, node_info, rank_id, tensor_protos) | |||||
| log.info("Put tensor value into cache.") | |||||
| def get_root_graph_id(self): | |||||
| """Get root graph id.""" | |||||
| is_sync = self._data_loader.get_sync_flag() | |||||
| graph_id = 0 if is_sync else 1 | |||||
| return graph_id | |||||
| def _put_tensor_value_into_cache(self, cur_step, node_info, rank_id, tensor_protos): | |||||
| """Put tensor value into tensor cache.""" | |||||
| tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR). \ | |||||
| get_tensor_handler_by_rank_id(rank_id) | |||||
| update_data_flag = False | |||||
| for tensor_proto in tensor_protos: | |||||
| if not tensor_proto.tensor_content: | |||||
| log.warning("Tensor %s:%s is empty.", | |||||
| tensor_proto.node_name, tensor_proto.slot) | |||||
| try: | |||||
| has_update = tensor_stream.put({ | |||||
| 'step': cur_step, | |||||
| 'tensor_proto': tensor_proto, | |||||
| 'tensor_contents': [tensor_proto.tensor_content] | |||||
| }) | |||||
| except ValueError as err: | |||||
| log.warning("Failed to put %s:%s into cache. Ignore it. %s", | |||||
| tensor_proto.node_name, tensor_proto.slot, str(err)) | |||||
| continue | |||||
| if has_update: | |||||
| update_data_flag = True | |||||
| if update_data_flag: | |||||
| # send message to frontend | |||||
| metadata = self._metadata_stream.get(['step', 'state']) | |||||
| ret = {'receive_tensor': node_info.copy()} | |||||
| ret.update(metadata) | |||||
| self._cache_store.put_data(ret) | |||||
| def _deal_with_run_cmd(self, event): | |||||
| """Deal with run cmd.""" | |||||
| run_cmd = event.run_cmd | |||||
| parsed_run_cmd = self._get_parsed_run_cmd(run_cmd) | |||||
| if parsed_run_cmd.run_steps > 0: | |||||
| self._execute_one_step() | |||||
| elif run_cmd.run_level == RunLevel.RECHECK.value: | |||||
| log.info("Deal with recheck command.") | |||||
| self._check_watchpoint(self._metadata_stream.step) | |||||
| def _execute_one_step(self): | |||||
| """Execute on step.""" | |||||
| new_step = self._metadata_stream.step + 1 | |||||
| if new_step > self._metadata_stream.max_step_num: | |||||
| self._old_run_cmd.clear() | |||||
| log.info("The server is already at the last step. %s", self._metadata_stream.max_step_num) | |||||
| return | |||||
| log.info("Go to next step: %s.", new_step) | |||||
| self._check_watchpoint(new_step) | |||||
| self._metadata_stream.step = new_step | |||||
| self._cache_store.get_stream_handler(Streams.TENSOR).set_step(new_step) | |||||
| self._cache_store.put_data(self._metadata_stream.get('step')) | |||||
| def _get_parsed_run_cmd(self, run_cmd): | |||||
| """Get parsed run command.""" | |||||
| if run_cmd.run_level == RunLevel.STEP.value: | |||||
| # receive pause cmd | |||||
| if not run_cmd.run_steps: | |||||
| log.debug("Pause training and wait for next command.") | |||||
| self._old_run_cmd.clear() | |||||
| # update metadata state from sending to waiting | |||||
| self._update_state(ServerStatus.WAITING) | |||||
| return run_cmd | |||||
| # receive step cmd | |||||
| left_steps = run_cmd.run_steps - 1 | |||||
| run_cmd.run_steps = 1 | |||||
| if left_steps: | |||||
| self._old_run_cmd['left_step_count'] = left_steps if left_steps > 0 else -1 | |||||
| elif run_cmd.node_name: | |||||
| self._old_run_cmd['node_name'] = run_cmd.node_name | |||||
| run_cmd.node_name = '' | |||||
| return run_cmd | |||||
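# Editor's note: a standalone re-statement of the step arithmetic above, to
# make the left_step_count semantics explicit (-1 encodes "run to the end").
# The function name is illustrative.
def parse_steps(run_steps):
    """Return (action, left_step_count) the way _get_parsed_run_cmd would."""
    if not run_steps:
        return 'pause', None
    left_steps = run_steps - 1
    if not left_steps:
        return 'step', None
    return 'step', left_steps if left_steps > 0 else -1

assert parse_steps(0) == ('pause', None)   # pause command
assert parse_steps(1) == ('step', None)    # single step, nothing left over
assert parse_steps(3) == ('step', 2)       # one step now, two remembered
assert parse_steps(-1) == ('step', -1)     # run until the last step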
| def _check_watchpoint(self, step): | |||||
| """Save watchpoint hits into cache.""" | |||||
| self._update_state(ServerStatus.RUNNING) | |||||
| # Clean watchpoint_hits in cache | |||||
| multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||||
| multi_card_hit_streams.clean() | |||||
| hits = Manager().list() | |||||
| check_watchpoints_process = Process(target=self._check_watchpoint_work, args=(hits, step,)) | |||||
| check_watchpoints_process.start() | |||||
| check_watchpoints_process.join() | |||||
| log.info("finish check watchpoint of %s", step) | |||||
| if hits: | |||||
| log.info("Received WatchpointHits. Left run cmd %s change to empty.", self._old_run_cmd) | |||||
| self._old_run_cmd.clear() | |||||
| self._update_state(ServerStatus.WAITING) | |||||
| self._save_watchpoint_hits(hits) | |||||
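# Editor's note: Manager().list() above is a proxy shared across the process
# boundary; a plain Python list passed to the child process would leave the
# parent's copy empty. A minimal standalone analogue of the pattern:
from multiprocessing import Manager, Process

def _append_hit(shared_hits):
    shared_hits.append({'name': 'Conv2D-op1'})  # stands in for a hit dict

if __name__ == '__main__':
    hits = Manager().list()
    worker = Process(target=_append_hit, args=(hits,))
    worker.start()
    worker.join()
    assert list(hits) == [{'name': 'Conv2D-op1'}]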
| def _save_watchpoint_hits(self, hits): | |||||
| """Save watchpoint hits.""" | |||||
| multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||||
| multi_card_graph_streams = self._cache_store.get_stream_handler(Streams.GRAPH) | |||||
| device_stream = self._cache_store.get_stream_handler(Streams.DEVICE) | |||||
| watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT) | |||||
| watchpoint_hits = defaultdict(list) | |||||
| for hit in hits: | |||||
| log.info("Received hit\n: " | |||||
| "name:%s, slot:%s, condition:%s, " | |||||
| "watchpoint_id:%s" | |||||
| "error_code:%s, device_id:%s", | |||||
| hit['name'], hit['slot'], hit['condition'], | |||||
| hit['watchpoint_id'], hit['error_code'], hit['device_id']) | |||||
| rank_id = device_stream.get_rank_id_by_device_id(hit['device_id']) | |||||
| watchpoint_hit = {} | |||||
| self._add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit) | |||||
| if not watchpoint_hit: | |||||
| continue | |||||
| self._add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit) | |||||
| watchpoint_hit['error_code'] = hit['error_code'] | |||||
| watchpoint_hits[rank_id].append(watchpoint_hit) | |||||
| # save hit info into cache | |||||
| multi_card_hit_streams.put(watchpoint_hits) | |||||
| self._cache_store.put_data({'receive_watchpoint_hits': True}) | |||||
| log.debug("Send the watchpoint hits to DataQueue.") | |||||
| @staticmethod | |||||
| def _add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit): | |||||
| """Add hit node info.""" | |||||
| graph_stream = multi_card_graph_streams.get_graph_handler_by_rank_id(rank_id) | |||||
| node_full_name = hit['name'] | |||||
| graph_name = graph_stream.get_graph_id_by_full_name(node_full_name) | |||||
| if not graph_name: | |||||
| log.warning("Cannot find node %s in graph. Skip it.", node_full_name) | |||||
| return | |||||
| ui_node_name = graph_stream.get_node_name_by_full_name(node_full_name, graph_name) | |||||
| log.debug("Receive watch point hit: %s:%s", node_full_name, hit['slot']) | |||||
| if not ui_node_name: | |||||
| log.info("Not support to show %s on graph.", node_full_name) | |||||
| return | |||||
| watchpoint_hit.update({ | |||||
| 'tensor_proto': TensorProto(node_name=node_full_name, slot=str(hit['slot'])), | |||||
| 'node_name': ui_node_name, | |||||
| 'graph_name': graph_name | |||||
| }) | |||||
| @staticmethod | |||||
| def _add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit): | |||||
| """Add watchpoint hit info.""" | |||||
| watchpoint = copy.deepcopy(watchpoint_stream.get_watchpoint_by_id(hit['watchpoint_id'])) | |||||
| hit_params = {} | |||||
| # get hit actual value | |||||
| for param in hit['parameters']: | |||||
| if param['name'] not in (ParamNameEnum.RTOL.value, ParamNameEnum.RANGE_START_INCLUSIVE.value, | |||||
| ParamNameEnum.RANGE_END_INCLUSIVE.value) \ | |||||
| and hit['error_code'] == 0: | |||||
| hit_params[param['name']] = param['actual_value'] | |||||
| # update actual value into watchpoint | |||||
| watchpoint_condition_params = watchpoint.condition['params'] | |||||
| for i, param in enumerate(watchpoint_condition_params): | |||||
| name = param['name'] | |||||
| if name in hit_params: |||||
| watchpoint_condition_params[i]['actual_value'] = hit_params[name] | |||||
| else: | |||||
| watchpoint_condition_params[i]['actual_value'] = None | |||||
| watchpoint_hit['watchpoint'] = watchpoint | |||||
| def _deal_with_set_cmd(self, event): | |||||
| """ | |||||
| Deal with set cmd. | |||||
| Args: | |||||
| event (EventReply): User command event including set_cmd. | |||||
| """ | |||||
| set_cmd = event.set_cmd | |||||
| set_cmd_id = set_cmd.id | |||||
| delete = set_cmd.delete | |||||
| if not delete: | |||||
| log.info("Add watchpoint by using dbg_server.") | |||||
| watch_condition = set_cmd.watch_condition | |||||
| param_list = [] | |||||
| for param in watch_condition.params: | |||||
| param_list.append( | |||||
| self._dbg_services_module.Parameter(param.name, param.disabled, param.value)) | |||||
| watch_nodes = set_cmd.watch_nodes | |||||
| check_nodes = self._get_check_nodes(watch_nodes) | |||||
| log.debug("Watchpoint %s, condition: %s, watch nodes: %s", | |||||
| set_cmd_id, watch_condition.condition, check_nodes) | |||||
| self._dbg_service.add_watchpoint(set_cmd_id, watch_condition.condition, check_nodes, param_list) | |||||
| else: | |||||
| log.info("Remove watchpoint by using dbg_server.") | |||||
| self._dbg_service.remove_watchpoint(set_cmd_id) | |||||
| def _get_check_nodes(self, watch_nodes): | |||||
| """Get check nodes format""" | |||||
| check_nodes = {} | |||||
| device_stream = self._cache_store.get_stream_handler(Streams.DEVICE) | |||||
| root_graph_id = self.get_root_graph_id() | |||||
| for watch_node in watch_nodes: | |||||
| node_name = watch_node.node_name | |||||
| rank_id = watch_node.rank_id | |||||
| device_id = device_stream.get_device_id_by_rank_id(rank_id) | |||||
| if node_name not in check_nodes: | |||||
| is_parameter = bool(watch_node.node_type == NodeTypeEnum.PARAMETER.value) | |||||
| check_nodes[node_name] = { | |||||
| "device_id": [device_id], | |||||
| "is_parameter": is_parameter, | |||||
| "root_graph_id": [root_graph_id] | |||||
| } | |||||
| else: | |||||
| check_nodes[node_name]["device_id"].append(device_id) | |||||
| return check_nodes | |||||
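# Editor's note: an example of the structure built above (node name and ids
# are invented). A node watched on two ranks collapses into a single entry
# whose device_id list has one element per watching rank:
check_nodes_example = {
    'Default/network/conv1.weight': {
        'device_id': [0, 1],       # one device id per watching rank
        'is_parameter': True,
        'root_graph_id': [0],      # 0 for sync mode, 1 for async mode
    },
}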
| def _update_state(self, server_status): | |||||
| """ | |||||
| Update state in metadata stream. | |||||
| Args: | |||||
| server_status (ServerStatus): The enum value in ServerStatus. | |||||
| """ | |||||
| if self._metadata_stream.state != server_status.value: | |||||
| self._metadata_stream.state = server_status.value | |||||
| self._cache_store.put_data(self._metadata_stream.get()) | |||||
| def _check_watchpoint_work(self, hits, step): | |||||
| """The check WatchPoint function work in another process.""" | |||||
| log.info("Start checking WatchPointHit process.") | |||||
| res = self._dbg_service.check_watchpoints(step) | |||||
| for watchpoint_hit in res: | |||||
| hit_dict = convert_watchpointhit(watchpoint_hit) | |||||
| hits.append(hit_dict) | |||||
| log.info("Checking WatchPointHit process is finished.") | |||||
| class CommandListener: | |||||
| """Event listener.""" | |||||
| def __init__(self, cache_store): | |||||
| self._cache_store = cache_store | |||||
| self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) | |||||
| # the next position of command queue to be queried | |||||
| self._pos = '0' | |||||
| self._is_waiting = Event() | |||||
| def start(self): | |||||
| """Start event listener.""" | |||||
| self._pos = '0' | |||||
| self._is_waiting.set() | |||||
| def stop(self): | |||||
| """Stop event listener.""" | |||||
| # Stop waiting for new user commands; old commands can still be fetched. |||||
| self._is_waiting.clear() | |||||
| def has_new_command(self): | |||||
| """Check if there is new command in command queue.""" | |||||
| return self._cache_store.has_command(self._pos) | |||||
| def get_next_command(self): | |||||
| """Get next command.""" | |||||
| event = None | |||||
| while event is None and self.has_new_command(): | |||||
| self._pos, event = self._cache_store.get_command(self._pos) | |||||
| log.debug("Deal with old %s-th command:\n%s.", self._pos, event) | |||||
| if event is None: | |||||
| event = self._wait_for_next_command() | |||||
| return event | |||||
| def _wait_for_next_command(self): | |||||
| """ | |||||
| Wait for next command. | |||||
| Returns: | |||||
| EventReply, the command event. | |||||
| """ | |||||
| if not self._is_waiting.is_set(): | |||||
| self._metadata_stream.state = ServerStatus.PENDING.value | |||||
| return None | |||||
| log.info("Start to wait for command.") | |||||
| if self._metadata_stream.state != ServerStatus.WAITING.value: | |||||
| self._metadata_stream.state = ServerStatus.WAITING.value | |||||
| self._cache_store.put_data(self._metadata_stream.get()) | |||||
| log.debug("Wait for %s-th command", self._pos) | |||||
| event = None | |||||
| while event is None and self._is_waiting.is_set(): | |||||
| self._pos, event = self._cache_store.get_command(self._pos) | |||||
| return event | |||||
| def convert_watchpointhit(watchpointhit): | |||||
| """Convert watchpointhit object to dict.""" | |||||
| parameters = watchpointhit.parameters | |||||
| param_list = [] | |||||
| for param in parameters: | |||||
| param_dict = convert_param(param) | |||||
| param_list.append(param_dict) | |||||
| watchpointhit_dict = {'condition': watchpointhit.condition, | |||||
| 'device_id': watchpointhit.device_id, | |||||
| 'error_code': watchpointhit.error_code, | |||||
| 'name': watchpointhit.name, | |||||
| 'parameters': param_list, | |||||
| 'slot': watchpointhit.slot, | |||||
| 'watchpoint_id': watchpointhit.watchpoint_id} | |||||
| return watchpointhit_dict | |||||
| def convert_param(param): | |||||
| """Convert parameter object to dict""" | |||||
| param_dict = {'actual_value': param.actual_value, | |||||
| 'disabled': param.disabled, | |||||
| 'hit': param.hit, | |||||
| 'name': param.name, | |||||
| 'value': param.value} | |||||
| return param_dict | |||||
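# Editor's note: an illustrative result of the two converters above; every
# field value is invented.
watchpointhit_dict_example = {
    'condition': 'tensor_general_overflow',  # hypothetical condition id
    'device_id': 0,
    'error_code': 0,
    'name': 'Default/network/Conv2D-op1',
    'parameters': [
        {'actual_value': 1.5, 'disabled': False, 'hit': True,
         'name': 'abs_mean_gt', 'value': 1.0},
    ],
    'slot': 0,
    'watchpoint_id': 1,
}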
| @@ -0,0 +1,58 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Debugger Online server.""" | |||||
| from concurrent import futures | |||||
| import grpc | |||||
| from mindinsight.debugger.common.log import LOGGER as log | |||||
| from mindinsight.conf import settings | |||||
| from mindinsight.debugger.debugger_services.debugger_grpc_server import DebuggerGrpcServer | |||||
| from mindinsight.debugger.debugger_services.debugger_server_base import DebuggerServerBase | |||||
| from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base | |||||
| def get_debugger_hostname(): | |||||
| """Get hostname for online debugger server.""" | |||||
| grpc_port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else 50051 | |||||
| host = settings.HOST if hasattr(settings, 'HOST') else '[::]' | |||||
| hostname = "{}:{}".format(host, grpc_port) | |||||
| return hostname | |||||
| class DebuggerOnlineServer(DebuggerServerBase): | |||||
| """Debugger Online Server.""" | |||||
| def __init__(self, cache_store, context): | |||||
| super(DebuggerOnlineServer, self).__init__(cache_store, context) | |||||
| self._grpc_server_manager = self.get_grpc_server_manager() | |||||
| def run(self): | |||||
| self._grpc_server_manager.start() | |||||
| log.info("Start grpc server %s", self._context.hostname) | |||||
| self._grpc_server_manager.wait_for_termination() | |||||
| def get_grpc_server_manager(self): | |||||
| """Get grpc server instance according to hostname.""" | |||||
| if self._context.hostname is None: | |||||
| self._context.hostname = get_debugger_hostname() | |||||
| grpc_server = DebuggerGrpcServer(self._cache_store, None) | |||||
| grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) | |||||
| grpc_server_base.add_EventListenerServicer_to_server(grpc_server, grpc_server_manager) | |||||
| grpc_server_manager.add_insecure_port(self._context.hostname) | |||||
| return grpc_server_manager | |||||
| def stop(self): | |||||
| self._grpc_server_manager.stop(grace=None) | |||||
| self.join() | |||||
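# Editor's note: a hedged usage sketch of the online server above. The context
# class is defined in debugger_server_factory (later in this change), and the
# hostname is illustrative.
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext

cache_store = DebuggerCache()
context = DebuggerServerContext(dbg_mode='online', hostname='127.0.0.1:50051')
server = DebuggerOnlineServer(cache_store, context)
server.start()   # DebuggerServerBase subclasses threading.Thread
# ... interact with the training session ...
server.stop()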
| @@ -0,0 +1,58 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """DebuggerServerBase.""" | |||||
| import threading | |||||
| from abc import abstractmethod | |||||
| from functools import wraps | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerServerRunningError | |||||
| from mindinsight.debugger.common.log import LOGGER as log | |||||
| def debugger_server_wrap(func): | |||||
| """Wrapper for catch exception.""" | |||||
| @wraps(func) | |||||
| def record_log(*args, **kwargs): | |||||
| try: | |||||
| return func(*args, **kwargs) | |||||
| except Exception as err: | |||||
| log.exception(err) | |||||
| raise DebuggerServerRunningError(str(err)) | |||||
| return record_log | |||||
| class DebuggerServerBase(threading.Thread): | |||||
| """ | |||||
| Debugger Server Base. | |||||
| Args: | |||||
| cache_store (DebuggerCacheStore): Cache store for debugger server. | |||||
| context (DebuggerServerContext): Context for initialize debugger server. | |||||
| """ | |||||
| def __init__(self, cache_store, context): | |||||
| super(DebuggerServerBase, self).__init__() | |||||
| self._cache_store = cache_store | |||||
| self._context = context | |||||
| @abstractmethod | |||||
| @debugger_server_wrap | |||||
| def run(self): | |||||
| """Function that should be called when thread started.""" | |||||
| @abstractmethod | |||||
| @debugger_server_wrap | |||||
| def stop(self): | |||||
| """Stop debugger server.""" | |||||
| @@ -0,0 +1,92 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Debugger server factory.""" | |||||
| import threading | |||||
| from mindinsight.debugger.common.utils import DebuggerServerMode | |||||
| from mindinsight.debugger.debugger_services.debugger_offline_server import DebuggerOfflineServer | |||||
| from mindinsight.debugger.debugger_services.debugger_online_server import DebuggerOnlineServer | |||||
| class DebuggerServerFactory: | |||||
| """Create debugger server according to debugger mode.""" | |||||
| _lock = threading.Lock() | |||||
| _instance = None | |||||
| def __init__(self): | |||||
| self._server_map = { | |||||
| DebuggerServerMode.ONLINE.value: DebuggerOnlineServer, | |||||
| DebuggerServerMode.OFFLINE.value: DebuggerOfflineServer | |||||
| } | |||||
| def __new__(cls, *args, **kwargs): |||||
| if cls._instance is None: |||||
| with cls._lock: |||||
| if cls._instance is None: |||||
| # object.__new__() takes no extra arguments, so they are not forwarded. |||||
| cls._instance = super().__new__(cls) |||||
| return cls._instance |||||
| def get_debugger_server(self, cache_store, context): | |||||
| """ | |||||
| Get debugger server according to debugger_context and cache_store. | |||||
| Args: | |||||
| cache_store (DebuggerCacheStore): Cache store for debugger server. | |||||
| context (DebuggerServerContext): Context for initialize debugger server. | |||||
| Returns: | |||||
| DebuggerServerBase, Debugger server object. | |||||
| """ | |||||
| dbg_server = None | |||||
| dbg_server_class = self._server_map.get(context.dbg_mode) | |||||
| if dbg_server_class: | |||||
| dbg_server = dbg_server_class(cache_store, context) | |||||
| return dbg_server | |||||
| class DebuggerServerContext: | |||||
| """ | |||||
| Debugger server context. | |||||
| Args: | |||||
| dbg_mode (str): The debugger mode. Options: `online` or `offline`. |||||
| train_job (str): The relative directory of debugger dump data for one training. |||||
| Used only when dbg_mode is `offline`. |||||
| dbg_dir (str): The base directory of debugger dump data for one training. |||||
| Used only when dbg_mode is `offline`. |||||
| hostname (str): The hostname used for online debugger server. |||||
| Used only when dbg_mode is `online`. |||||
| """ | |||||
| def __init__(self, dbg_mode, train_job=None, dbg_dir=None, hostname=None): | |||||
| self._dbg_mode = dbg_mode | |||||
| self._train_job = train_job | |||||
| self._dbg_dir = dbg_dir | |||||
| self.hostname = hostname | |||||
| @property | |||||
| def dbg_mode(self): | |||||
| """Property of debugger mode.""" | |||||
| return self._dbg_mode | |||||
| @property | |||||
| def dbg_dir(self): | |||||
| """Property of debugger mode.""" | |||||
| return self._dbg_dir | |||||
| @property | |||||
| def train_job(self): | |||||
| """The property of train job.""" | |||||
| return self._train_job | |||||
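# Editor's note: a hedged end-to-end sketch of the factory above; the dump
# directory and train job names are invented.
from mindinsight.debugger.debugger_cache import DebuggerCache

cache_store = DebuggerCache()
context = DebuggerServerContext(dbg_mode='offline', train_job='job_0',
                                dbg_dir='/tmp/dump')  # hypothetical paths
server = DebuggerServerFactory().get_debugger_server(cache_store, context)
if server is not None:  # an unknown dbg_mode yields None
    server.start()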
| @@ -13,17 +13,8 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Implement the debugger server.""" | """Implement the debugger server.""" | ||||
| import signal | |||||
| from concurrent import futures | |||||
| from functools import wraps | from functools import wraps | ||||
| from threading import Thread | |||||
| import grpc | |||||
| from mindinsight.debugger.conditionmgr.condition import ConditionContext | |||||
| from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr | |||||
| from mindinsight.debugger.conditionmgr.recommender import recommend_watchpoints | |||||
| from mindinsight.conf import settings | |||||
| from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | ||||
| from mindinsight.datavisual.utils.tools import to_float | from mindinsight.datavisual.utils.tools import to_float | ||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | ||||
| @@ -32,9 +23,11 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue | |||||
| from mindinsight.debugger.common.log import LOGGER as log | from mindinsight.debugger.common.log import LOGGER as log | ||||
| from mindinsight.debugger.common.utils import ServerStatus, \ | from mindinsight.debugger.common.utils import ServerStatus, \ | ||||
| create_view_event_from_tensor_basic_info, Streams | create_view_event_from_tensor_basic_info, Streams | ||||
| from mindinsight.debugger.conditionmgr.condition import ConditionContext | |||||
| from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr | |||||
| from mindinsight.debugger.conditionmgr.recommender import recommend_watchpoints | |||||
| from mindinsight.debugger.debugger_cache import DebuggerCache | from mindinsight.debugger.debugger_cache import DebuggerCache | ||||
| from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer | |||||
| from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base | |||||
| from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerFactory | |||||
| from mindinsight.debugger.stream_operator.tensor_detail_info import TensorDetailInfo | from mindinsight.debugger.stream_operator.tensor_detail_info import TensorDetailInfo | ||||
| from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator | from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator | ||||
| from mindinsight.debugger.stream_operator.watchpoint_operator import WatchpointOperator | from mindinsight.debugger.stream_operator.watchpoint_operator import WatchpointOperator | ||||
| @@ -57,25 +50,29 @@ def try_except(func): | |||||
| return send_latest_metadata | return send_latest_metadata | ||||
| class DebuggerServer: | |||||
| class DebuggerSession: | |||||
| """The server manager of debugger.""" | """The server manager of debugger.""" | ||||
| def __init__(self): | |||||
| def __init__(self, context): | |||||
| self.condition_mgr = ConditionMgr() | self.condition_mgr = ConditionMgr() | ||||
| self.cache_store = DebuggerCache() | self.cache_store = DebuggerCache() | ||||
| self.grpc_server = DebuggerGrpcServer(self.cache_store, self.condition_mgr) | |||||
| self.grpc_server_manager = None | |||||
| self.back_server = None | |||||
| self.context = context | |||||
| self.back_server = DebuggerServerFactory().get_debugger_server(self.cache_store, context) | |||||
| @property | |||||
| def train_job(self): | |||||
| """The property of train job.""" | |||||
| return self.context.train_job | |||||
| def get_condition_collections(self, train_id): | |||||
| def get_condition_collections(self, train_id=""): | |||||
| """Get default condition_collections""" | """Get default condition_collections""" | ||||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | ||||
| condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step) | condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step) | ||||
| log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend) | log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend) | ||||
| return self.condition_mgr.get_all_collections(condition_context) | return self.condition_mgr.get_all_collections(condition_context) | ||||
| def set_recommended_watch_points(self, set_recommended, train_id): | |||||
| """set recommended watch points.""" | |||||
| def set_recommended_watch_points(self, set_recommended, train_id=""): | |||||
| """Set recommended watch points.""" | |||||
| if not isinstance(set_recommended, bool): | if not isinstance(set_recommended, bool): | ||||
| log.error("Bool param should be given for set_recommended") | log.error("Bool param should be given for set_recommended") | ||||
| raise DebuggerParamValueError("Bool param should be given.") | raise DebuggerParamValueError("Bool param should be given.") | ||||
| @@ -97,38 +94,28 @@ class DebuggerServer: | |||||
| def _add_recommended_watchpoints(self, condition_context): | def _add_recommended_watchpoints(self, condition_context): | ||||
| """Add predefined watchpoints.""" | """Add predefined watchpoints.""" | ||||
| log.debug("Add predefined watchpoints.") | log.debug("Add predefined watchpoints.") | ||||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||||
| watchpoints = recommend_watchpoints(self.condition_mgr, graph_stream, condition_context) | |||||
| multi_card_graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||||
| watchpoints = recommend_watchpoints(self.condition_mgr, multi_card_graph_stream, condition_context) | |||||
| watch_point_stream_handler = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | watch_point_stream_handler = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | ||||
| device_stream = self.cache_store.get_stream_handler(Streams.DEVICE) | |||||
| watch_points_ids = [] | watch_points_ids = [] | ||||
| for watchpoint in watchpoints: | for watchpoint in watchpoints: | ||||
| watch_points_id = watch_point_stream_handler.create_watchpoint( | watch_points_id = watch_point_stream_handler.create_watchpoint( | ||||
| watch_condition=watchpoint.get_watch_condition_dict(), | watch_condition=watchpoint.get_watch_condition_dict(), | ||||
| watch_nodes=watchpoint.watch_nodes, | watch_nodes=watchpoint.watch_nodes, | ||||
| name=watchpoint.name, | name=watchpoint.name, | ||||
| condition_mgr=self.condition_mgr | |||||
| condition_mgr=self.condition_mgr, | |||||
| device_amount=device_stream.device_amount | |||||
| ) | ) | ||||
| watch_points_ids.append(watch_points_id) | watch_points_ids.append(watch_points_id) | ||||
| return watch_points_ids | return watch_points_ids | ||||
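| # A minimal sketch of how this helper is driven; the backend and step | |||||
| # values below are illustrative, not real defaults: | |||||
| # | |||||
| #   condition_context = ConditionContext(backend='Ascend', step=0) | |||||
| #   watch_point_ids = self._add_recommended_watchpoints(condition_context) | |||||
| #   # -> one watchpoint id per recommended condition, e.g. [1, 2, 3] | |||||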
| def start(self): | def start(self): | ||||
| """Start server.""" | """Start server.""" | ||||
| grpc_port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else 50051 | |||||
| host = settings.HOST if hasattr(settings, 'HOST') else '[::]' | |||||
| hostname = "{}:{}".format(host, grpc_port) | |||||
| # initialize a grpc server | |||||
| grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) | |||||
| grpc_server_base.add_EventListenerServicer_to_server(self.grpc_server, grpc_server_manager) | |||||
| grpc_server_manager.add_insecure_port(hostname) | |||||
| grpc_server_manager.start() | |||||
| my_server_thread = Thread(target=grpc_server_manager.wait_for_termination) | |||||
| # start grpc server | |||||
| my_server_thread.start() | |||||
| self.back_server = my_server_thread | |||||
| self.grpc_server_manager = grpc_server_manager | |||||
| self.back_server.start() | |||||
| # register stop server handler | # register stop server handler | ||||
| signal.signal(signal.SIGINT, self._stop_handler) | |||||
| log.info("Start grpc server %s", hostname) | |||||
| # Stopping via a SIGINT handler is disabled; sessions are now shut down | |||||
| # explicitly through the session manager. | |||||
| # signal.signal(signal.SIGINT, self._stop_handler) | |||||
| log.info("Start debugger backend server.") | |||||
| def _stop_handler(self, signum, frame): | def _stop_handler(self, signum, frame): | ||||
| """Register stop server handler.""" | """Register stop server handler.""" | ||||
| @@ -139,8 +126,7 @@ class DebuggerServer: | |||||
| """Stop debugger server.""" | """Stop debugger server.""" | ||||
| log.info("Send terminate info to client.") | log.info("Send terminate info to client.") | ||||
| self.control({'mode': 'terminate'}) | self.control({'mode': 'terminate'}) | ||||
| self.grpc_server_manager.stop(grace=None) | |||||
| self.back_server.join() | |||||
| self.back_server.stop() | |||||
| log.info("Stop debugger server.") | log.info("Stop debugger server.") | ||||
| def poll_data(self, pos): | def poll_data(self, pos): | ||||
| @@ -172,6 +158,7 @@ class DebuggerServer: | |||||
| - graph_name (str): The graph name. | - graph_name (str): The graph name. | ||||
| - watch_point_id (int): The id of watchpoint. Default: 0. | - watch_point_id (int): The id of watchpoint. Default: 0. | ||||
| - node_category (str): The node category. Default: None. | - node_category (str): The node category. Default: None. | ||||
| - rank_id (int): The rank id. Default: 0. | |||||
| Returns: | Returns: | ||||
| dict, the searched nodes. | dict, the searched nodes. | ||||
| @@ -179,19 +166,20 @@ class DebuggerServer: | |||||
| log.info("receive search request with filter_condition: %s", filter_condition) | log.info("receive search request with filter_condition: %s", filter_condition) | ||||
| # validate watchpoint id | # validate watchpoint id | ||||
| watch_point_id = filter_condition.pop('watch_point_id', 0) | watch_point_id = filter_condition.pop('watch_point_id', 0) | ||||
| rank_id = filter_condition.pop('rank_id', 0) | |||||
| watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | ||||
| watchpoint_stream.validate_watchpoint_id(watch_point_id) | watchpoint_stream.validate_watchpoint_id(watch_point_id) | ||||
| # validate and update graph name | # validate and update graph name | ||||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) | |||||
| graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name')) | graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name')) | ||||
| filter_condition['graph_name'] = graph_name | filter_condition['graph_name'] = graph_name | ||||
| # get searched graph | # get searched graph | ||||
| graph = graph_stream.search_nodes(filter_condition) | graph = graph_stream.search_nodes(filter_condition) | ||||
| # add watched label to graph | # add watched label to graph | ||||
| watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, graph_name) | |||||
| watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, graph_name, rank_id) | |||||
| return graph | return graph | ||||
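| # An illustrative filter_condition for the search above (all values are | |||||
| # made up; keys follow the docstring and the pops/gets in this method): | |||||
| # | |||||
| #   {'graph_name': 'graph_0', 'watch_point_id': 1, | |||||
| #    'node_category': None, 'rank_id': 0} | |||||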
| def tensor_comparisons(self, name, shape, detail='data', tolerance='0'): | |||||
| def tensor_comparisons(self, name, shape, detail='data', tolerance='0', rank_id=0): | |||||
| """ | """ | ||||
| Get tensor comparisons data for given name, detail, shape and tolerance. | Get tensor comparisons data for given name, detail, shape and tolerance. | ||||
| @@ -202,6 +190,7 @@ class DebuggerServer: | |||||
| shape (str): Specify concrete dimensions of shape. | shape (str): Specify concrete dimensions of shape. | ||||
| tolerance (str): Specify tolerance of difference between current step tensor and previous | tolerance (str): Specify tolerance of difference between current step tensor and previous | ||||
| step tensor. Default value is 0. | step tensor. Default value is 0. | ||||
| rank_id (int): The rank id. Default: 0. | |||||
| Raises: | Raises: | ||||
| DebuggerParamValueError, If node type is not parameter or value of detail is not support. | DebuggerParamValueError, If node type is not parameter or value of detail is not support. | ||||
| @@ -220,9 +209,10 @@ class DebuggerServer: | |||||
| parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR) | parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR) | ||||
| node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name) | node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name) | ||||
| tolerance = to_float(tolerance, 'tolerance') | tolerance = to_float(tolerance, 'tolerance') | ||||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) | |||||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id) | |||||
| cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step | |||||
| if node_type == NodeTypeEnum.PARAMETER.value: | if node_type == NodeTypeEnum.PARAMETER.value: | ||||
| reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance) | |||||
| reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance, cur_step) | |||||
| else: | else: | ||||
| raise DebuggerParamValueError( | raise DebuggerParamValueError( | ||||
| "The node type must be parameter, but got {}.".format(node_type)) | "The node type must be parameter, but got {}.".format(node_type)) | ||||
| @@ -270,10 +260,18 @@ class DebuggerServer: | |||||
| self.cache_store.clean_data() | self.cache_store.clean_data() | ||||
| log.info("Clean data queue cache when retrieve all request.") | log.info("Clean data queue cache when retrieve all request.") | ||||
| result = {} | result = {} | ||||
| for stream in [Streams.METADATA, Streams.GRAPH]: | |||||
| for stream in [Streams.METADATA, Streams.GRAPH, Streams.DEVICE]: | |||||
| sub_res = self.cache_store.get_stream_handler(stream).get() | sub_res = self.cache_store.get_stream_handler(stream).get() | ||||
| result.update(sub_res) | result.update(sub_res) | ||||
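| # If the device stream reported nothing, synthesize a single-device | |||||
| # entry from the metadata so there is always at least one device. | |||||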
| devices = result['devices'] | |||||
| if not devices: | |||||
| graph = result['graph'] | |||||
| metadata = result['metadata'] | |||||
| device = {'rank_id': 0, 'server_ip': metadata.get('ip', 'localhost'), | |||||
| 'device_id': metadata.get('device_name', ''), | |||||
| 'graph_names': graph.get('graph_names', [])} | |||||
| devices.append(device) | |||||
| sub_res = self._hide_parameters_for_ui() | sub_res = self._hide_parameters_for_ui() | ||||
| result.update(sub_res) | result.update(sub_res) | ||||
| @@ -298,7 +296,8 @@ class DebuggerServer: | |||||
| log.debug("Retrieve node %s.", filter_condition) | log.debug("Retrieve node %s.", filter_condition) | ||||
| # validate node name | # validate node name | ||||
| node_name = filter_condition.get('name') | node_name = filter_condition.get('name') | ||||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||||
| rank_id = filter_condition.get('rank_id', 0) | |||||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) | |||||
| graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name')) | graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name')) | ||||
| if node_name: | if node_name: | ||||
| # validate node name | # validate node name | ||||
| @@ -325,24 +324,27 @@ class DebuggerServer: | |||||
| dict, reply with graph. | dict, reply with graph. | ||||
| """ | """ | ||||
| # validate watch_point_id | # validate watch_point_id | ||||
| rank_id = filter_condition.get('rank_id', 0) | |||||
| watch_point_id = filter_condition.get('watch_point_id', 0) | watch_point_id = filter_condition.get('watch_point_id', 0) | ||||
| watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | ||||
| watchpoint_stream.validate_watchpoint_id(watch_point_id) | watchpoint_stream.validate_watchpoint_id(watch_point_id) | ||||
| # get graph | # get graph | ||||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) | |||||
| reply = graph_stream.get(filter_condition) | reply = graph_stream.get(filter_condition) | ||||
| graph = reply.get('graph') | graph = reply.get('graph') | ||||
| # add watched label to graph | # add watched label to graph | ||||
| watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, filter_condition.get('graph_name')) | |||||
| watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, filter_condition.get('graph_name'), | |||||
| rank_id) | |||||
| return reply | return reply | ||||
| def retrieve_tensor_history(self, node_name, graph_name=None): | |||||
| def retrieve_tensor_history(self, node_name, graph_name=None, rank_id=0): | |||||
| """ | """ | ||||
| Retrieve tensor history for leaf node. | Retrieve tensor history for leaf node. | ||||
| Args: | Args: | ||||
| node_name (str): The name of leaf node. | node_name (str): The name of leaf node. | ||||
| graph_name (str): The graph name. Default: None. | graph_name (str): The graph name. Default: None. | ||||
| rank_id (int): The rank id. Default: 0. | |||||
| Returns: | Returns: | ||||
| dict, the tensor history and metadata. | dict, the tensor history and metadata. | ||||
| @@ -352,34 +354,34 @@ class DebuggerServer: | |||||
| if metadata_stream.state == ServerStatus.PENDING.value: | if metadata_stream.state == ServerStatus.PENDING.value: | ||||
| log.info("The backend is in pending status.") | log.info("The backend is in pending status.") | ||||
| return metadata_stream.get(['state', 'step']) | return metadata_stream.get(['state', 'step']) | ||||
| res = self._get_tensor_history(node_name, graph_name) | |||||
| res = self._get_tensor_history(node_name, graph_name, rank_id) | |||||
| return res | return res | ||||
| def _get_tensor_history(self, node_name, graph_name=None): | |||||
| def _get_tensor_history(self, node_name, graph_name=None, rank_id=0): | |||||
| """ | """ | ||||
| Get tensor history for single node. | Get tensor history for single node. | ||||
| Args: | Args: | ||||
| node_name (str): The name of leaf node. | node_name (str): The name of leaf node. | ||||
| graph_name (str): The graph name. Default: None. | graph_name (str): The graph name. Default: None. | ||||
| rank_id (int): The rank id. Default: 0. | |||||
| Returns: | Returns: | ||||
| dict, the tensor history and metadata. | dict, the tensor history and metadata. | ||||
| """ | """ | ||||
| # get basic tensor history | # get basic tensor history | ||||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) | |||||
| tensor_history = graph_stream.get_tensor_history(node_name, graph_name) | tensor_history = graph_stream.get_tensor_history(node_name, graph_name) | ||||
| # add tensor value for tensor history | # add tensor value for tensor history | ||||
| self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name) | |||||
| self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name, rank_id) | |||||
| # add hit label for tensor history | # add hit label for tensor history | ||||
| watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||||
| watchpoint_hit_stream.update_tensor_history(tensor_history) | |||||
| self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).update_tensor_history(tensor_history, rank_id) | |||||
| # add metadata | # add metadata | ||||
| metadata = self.cache_store.get_stream_handler(Streams.METADATA).get(['step']) | metadata = self.cache_store.get_stream_handler(Streams.METADATA).get(['step']) | ||||
| tensor_history.update(metadata) | tensor_history.update(metadata) | ||||
| return tensor_history | return tensor_history | ||||
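| # Illustrative call; node and graph names are made up: | |||||
| # | |||||
| #   history = self._get_tensor_history('Default/conv1', 'graph_0', rank_id=1) | |||||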
| def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name): | |||||
| def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name, rank_id): | |||||
| """ | """ | ||||
| Add tensor values for the tensor history and send a ViewCMD if any value is missing. | Add tensor values for the tensor history and send a ViewCMD if any value is missing. | ||||
| @@ -387,48 +389,53 @@ class DebuggerServer: | |||||
| tensor_history (list[dict]): A list of tensor info, including name and type. | tensor_history (list[dict]): A list of tensor info, including name and type. | ||||
| node_name (str): The UI node name. | node_name (str): The UI node name. | ||||
| graph_name (str): The graph name. Default: None. | graph_name (str): The graph name. Default: None. | ||||
| rank_id (int): The rank id. Default: 0. | |||||
| Returns: | Returns: | ||||
| dict, the tensor info. | dict, the tensor info. | ||||
| """ | """ | ||||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) | |||||
| missed_tensors = tensor_stream.update_tensor_history(tensor_history) | |||||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id) | |||||
| cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step | |||||
| missed_tensors = tensor_stream.update_tensor_history(tensor_history, cur_step) | |||||
| if missed_tensors: | if missed_tensors: | ||||
| view_cmd = create_view_event_from_tensor_basic_info(missed_tensors) | view_cmd = create_view_event_from_tensor_basic_info(missed_tensors) | ||||
| self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name}) | |||||
| self.cache_store.put_command( | |||||
| {'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name, 'rank_id': rank_id}) | |||||
| log.debug("Send view cmd.") | log.debug("Send view cmd.") | ||||
| def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False): | |||||
| def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False, rank_id=0): | |||||
| """Retrieve the tensor value.""" | """Retrieve the tensor value.""" | ||||
| log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape) | log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape) | ||||
| self.validate_tensor_param(name, detail) | self.validate_tensor_param(name, detail) | ||||
| # Limit the query to at most two dimensions for the tensor table view. | # Limit the query to at most two dimensions for the tensor table view. | ||||
| parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR) | parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR) | ||||
| node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name, graph_name) | |||||
| node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name, graph_name, rank_id) | |||||
| reply = self.cache_store.get_stream_handler(Streams.TENSOR).get( | reply = self.cache_store.get_stream_handler(Streams.TENSOR).get( | ||||
| {'name': tensor_name, | {'name': tensor_name, | ||||
| 'node_type': node_type, | 'node_type': node_type, | ||||
| 'shape': parsed_shape, | 'shape': parsed_shape, | ||||
| 'prev': prev} | |||||
| 'prev': prev}, | |||||
| rank_id | |||||
| ) | ) | ||||
| reply['tensor_value']['name'] = name | reply['tensor_value']['name'] = name | ||||
| return reply | return reply | ||||
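| # Illustrative call; tensor names use the 'node_name:slot' form that is | |||||
| # split in _get_tensor_name_and_type_by_ui_name below (values made up): | |||||
| # | |||||
| #   reply = self.retrieve_tensor_value('Default/conv1:0', 'data', '[:, :]', | |||||
| #                                      graph_name='graph_0', rank_id=1) | |||||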
| def _get_tensor_name_and_type_by_ui_name(self, name, graph_name=None): | |||||
| def _get_tensor_name_and_type_by_ui_name(self, name, graph_name=None, rank_id=0): | |||||
| """ | """ | ||||
| Get inner tensor name and type by UI name. | Get inner tensor name and type by UI name. | ||||
| Args: | Args: | ||||
| name (str): Node name shown in UI. | name (str): Node name shown in UI. | ||||
| graph_name (Union[str, None]): The graph name. Default: None. | graph_name (Union[str, None]): The graph name. Default: None. | ||||
| rank_id (int): The rank id. Default: 0. | |||||
| Returns: | Returns: | ||||
| str, full name of tensor. | str, full name of tensor. | ||||
| str, node type of tensor. | str, node type of tensor. | ||||
| """ | """ | ||||
| node_name, slot = name.rsplit(':', 1) | node_name, slot = name.rsplit(':', 1) | ||||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) | |||||
| graph_name = graph_name if graph_name else graph_stream.get_graph_id_by_name(node_name) | graph_name = graph_name if graph_name else graph_stream.get_graph_id_by_name(node_name) | ||||
| node_type = graph_stream.get_node_type(node_name, graph_name) | node_type = graph_stream.get_node_type(node_name, graph_name) | ||||
| full_name = graph_stream.get_full_name(node_name, graph_name) | full_name = graph_stream.get_full_name(node_name, graph_name) | ||||
| @@ -483,6 +490,7 @@ class DebuggerServer: | |||||
| - offset (int): The offset of current page. | - offset (int): The offset of current page. | ||||
| - node_name (str): The retrieved node name. | - node_name (str): The retrieved node name. | ||||
| - graph_name (str): The retrieved graph name. | - graph_name (str): The retrieved graph name. | ||||
| - rank_id (int): The rank id. | |||||
| Returns: | Returns: | ||||
| dict, watch point list or relative graph. | dict, watch point list or relative graph. | ||||
| @@ -496,7 +504,13 @@ class DebuggerServer: | |||||
| log.info("The backend is in pending status.") | log.info("The backend is in pending status.") | ||||
| return metadata_stream.get() | return metadata_stream.get() | ||||
| reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).group_by(group_condition) | |||||
| rank_id = group_condition.pop('rank_id', 0) | |||||
| reply = {} | |||||
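| # Only query per-rank hits when the rank is known to the multi-card | |||||
| # stream; otherwise fall through with an empty reply. | |||||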
| multi_watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||||
| if multi_watchpoint_hit_stream.check_rank_id(rank_id): | |||||
| watchpoint_hit_stream = multi_watchpoint_hit_stream.get_hit_handler_by_rank_id(rank_id) | |||||
| reply = watchpoint_hit_stream.group_by(group_condition) | |||||
| reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable() | reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable() | ||||
| return reply | return reply | ||||
| @@ -591,40 +605,6 @@ class DebuggerServer: | |||||
| training_controller.validate_mode(mode) | training_controller.validate_mode(mode) | ||||
| return training_controller.control(mode, params) | return training_controller.control(mode, params) | ||||
| def retrieve_node_by_bfs(self, node_name, graph_name=None, ascend=False): | |||||
| """ | |||||
| Get the graph of the next node according to node_name. | |||||
| Args: | |||||
| node_name (str): The name of current chosen leaf node. | |||||
| graph_name (str): The graph name. | |||||
| ascend (bool): If True, traverse the input nodes; | |||||
| If False, traverse the output nodes. Default: False. | |||||
| Returns: | |||||
| dict, the next node information. | |||||
| """ | |||||
| log.info("Retrieve node <%s> by bfs, `ascend` is :%s", | |||||
| node_name, ascend) | |||||
| reply = {} | |||||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||||
| graph_name = graph_stream.validate_graph_name(graph_name) | |||||
| next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend) | |||||
| # no next node | |||||
| if next_node_name is None: | |||||
| return reply | |||||
| # add graph and tensor history for next node | |||||
| filter_condition = { | |||||
| 'name': next_node_name, | |||||
| 'graph_name': graph_name, | |||||
| 'single_node': True | |||||
| } | |||||
| search_graph = self._get_nodes_info(filter_condition) | |||||
| reply = {'name': next_node_name} | |||||
| reply.update(search_graph) | |||||
| return reply | |||||
| @try_except | @try_except | ||||
| def recheck(self): | def recheck(self): | ||||
| """ | """ | ||||
| @@ -635,13 +615,14 @@ class DebuggerServer: | |||||
| """ | """ | ||||
| return TrainingControlOperator(self.cache_store).recheck() | return TrainingControlOperator(self.cache_store).recheck() | ||||
| def retrieve_tensor_graph(self, tensor_name, graph_name): | |||||
| def retrieve_tensor_graph(self, tensor_name, graph_name, rank_id=0): | |||||
| """ | """ | ||||
| Retrieve tensor graph. | Retrieve tensor graph. | ||||
| Args: | Args: | ||||
| tensor_name (str): The tensor name from UI. | tensor_name (str): The tensor name from UI. | ||||
| graph_name (str): The graph name. | graph_name (str): The graph name. | ||||
| rank_id (int): The rank id. Default: 0. | |||||
| Returns: | Returns: | ||||
| dict, tensor graph object. | dict, tensor graph object. | ||||
| @@ -650,16 +631,17 @@ class DebuggerServer: | |||||
| log.error("Failed to get tensor graph the MindSpore is not in waiting state.") | log.error("Failed to get tensor graph the MindSpore is not in waiting state.") | ||||
| raise DebuggerTensorGraphError | raise DebuggerTensorGraphError | ||||
| log.info("Retrieve tensor graph for %s from %s", tensor_name, graph_name) | log.info("Retrieve tensor graph for %s from %s", tensor_name, graph_name) | ||||
| tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name) | |||||
| tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name, rank_id) | |||||
| return tensor_graph_ops | return tensor_graph_ops | ||||
| def retrieve_tensor_hits(self, tensor_name, graph_name): | |||||
| def retrieve_tensor_hits(self, tensor_name, graph_name, rank_id=0): | |||||
| """ | """ | ||||
| Retrieve tensor hit information. | Retrieve tensor hit information. | ||||
| Args: | Args: | ||||
| tensor_name (str): The tensor name from UI. | tensor_name (str): The tensor name from UI. | ||||
| graph_name (str): The graph name. | graph_name (str): The graph name. | ||||
| rank_id (int): The rank id. Default: 0. | |||||
| Returns: | Returns: | ||||
| dict, tensor hit info. | dict, tensor hit info. | ||||
| @@ -668,7 +650,7 @@ class DebuggerServer: | |||||
| log.error("Failed to get tensor hits as the MindSpore is not in waiting state.") | log.error("Failed to get tensor hits as the MindSpore is not in waiting state.") | ||||
| raise DebuggerTensorHitError | raise DebuggerTensorHitError | ||||
| log.info("Retrieve tensor hits for %s from %s", tensor_name, graph_name) | log.info("Retrieve tensor hits for %s from %s", tensor_name, graph_name) | ||||
| watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name) | |||||
| watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name, rank_id) | |||||
| return {'watch_points': watch_points} | return {'watch_points': watch_points} | ||||
| def _hide_parameters_for_ui(self): | def _hide_parameters_for_ui(self): | ||||
| @@ -122,6 +122,9 @@ message WatchCondition { | |||||
| message WatchNode { | message WatchNode { | ||||
| string node_name = 1; | string node_name = 1; | ||||
| string node_type = 2; | string node_type = 2; | ||||
| string graph_name = 3; | |||||
| int32 rank_id = 4; | |||||
| int32 device_id = 5; | |||||
| } | } | ||||
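| // Illustrative only: with the new fields a WatchNode can carry multi-card | |||||
| // placement, e.g. node_name: "Default/conv1"  node_type: "leaf" | |||||
| //                 graph_name: "graph_0"  rank_id: 1  device_id: 1 | |||||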
| message WatchpointHit { | message WatchpointHit { | ||||
| @@ -2,8 +2,6 @@ | |||||
| # Generated by the protocol buffer compiler. DO NOT EDIT! | # Generated by the protocol buffer compiler. DO NOT EDIT! | ||||
| # source: mindinsight/debugger/proto/debug_grpc.proto | # source: mindinsight/debugger/proto/debug_grpc.proto | ||||
| import sys | |||||
| _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) | |||||
| from google.protobuf import descriptor as _descriptor | from google.protobuf import descriptor as _descriptor | ||||
| from google.protobuf import message as _message | from google.protobuf import message as _message | ||||
| from google.protobuf import reflection as _reflection | from google.protobuf import reflection as _reflection | ||||
| @@ -21,7 +19,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( | |||||
| package='debugger', | package='debugger', | ||||
| syntax='proto3', | syntax='proto3', | ||||
| serialized_options=None, | serialized_options=None, | ||||
| serialized_pb=_b('\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"\x92\x01\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\x12\x11\n\tgraph_num\x18\x06 \x01(\x05\x12\x12\n\nms_version\x18\x07 \x01(\t\")\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\"\x87\x02\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\x12\x19\n\x0fversion_matched\x18\x06 \x01(\x08H\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\x81\x04\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x32\n\x06params\x18\x04 \x03(\x0b\x32\".debugger.WatchCondition.Parameter\x1a]\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64isabled\x18\x02 \x01(\x08\x12\r\n\x05value\x18\x03 \x01(\x01\x12\x0b\n\x03hit\x18\x04 \x01(\x08\x12\x14\n\x0c\x61\x63tual_value\x18\x05 \x01(\x01\"\x95\x02\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x0c\n\x08overflow\x10\x02\x12\t\n\x05sd_gt\x10\x0b\x12\t\n\x05sd_lt\x10\x0c\x12\x1b\n\x17tensor_general_overflow\x10\r\x12\x19\n\x15tensor_initialization\x10\x0e\x12\x14\n\x10tensor_too_large\x10\x0f\x12\x14\n\x10tensor_too_small\x10\x10\x12\x13\n\x0ftensor_all_zero\x10\x11\x12\x1b\n\x17tensor_change_too_large\x10\x12\x12\x1b\n\x17tensor_change_too_small\x10\x13\x12\x16\n\x12tensor_not_changed\x10\x14\x12\x10\n\x0ctensor_range\x10\x15\"1\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\"\x89\x01\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x12\x12\n\nerror_code\x18\x04 \x01(\x05\x32\x81\x03\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x12<\n\x0fSendMultiGraphs\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3') | |||||
| serialized_pb=b'\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"\x92\x01\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\x12\x11\n\tgraph_num\x18\x06 \x01(\x05\x12\x12\n\nms_version\x18\x07 \x01(\t\")\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\"\x87\x02\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\x12\x19\n\x0fversion_matched\x18\x06 \x01(\x08H\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\x81\x04\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x32\n\x06params\x18\x04 \x03(\x0b\x32\".debugger.WatchCondition.Parameter\x1a]\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64isabled\x18\x02 \x01(\x08\x12\r\n\x05value\x18\x03 \x01(\x01\x12\x0b\n\x03hit\x18\x04 \x01(\x08\x12\x14\n\x0c\x61\x63tual_value\x18\x05 \x01(\x01\"\x95\x02\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x0c\n\x08overflow\x10\x02\x12\t\n\x05sd_gt\x10\x0b\x12\t\n\x05sd_lt\x10\x0c\x12\x1b\n\x17tensor_general_overflow\x10\r\x12\x19\n\x15tensor_initialization\x10\x0e\x12\x14\n\x10tensor_too_large\x10\x0f\x12\x14\n\x10tensor_too_small\x10\x10\x12\x13\n\x0ftensor_all_zero\x10\x11\x12\x1b\n\x17tensor_change_too_large\x10\x12\x12\x1b\n\x17tensor_change_too_small\x10\x13\x12\x16\n\x12tensor_not_changed\x10\x14\x12\x10\n\x0ctensor_range\x10\x15\"i\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\x12\x12\n\ngraph_name\x18\x03 \x01(\t\x12\x0f\n\x07rank_id\x18\x04 \x01(\x05\x12\x11\n\tdevice_id\x18\x05 \x01(\x05\"\x89\x01\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x12\x12\n\nerror_code\x18\x04 \x01(\x05\x32\x81\x03\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x12<\n\x0fSendMultiGraphs\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3' | |||||
| , | , | ||||
| dependencies=[mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.DESCRIPTOR,]) | dependencies=[mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.DESCRIPTOR,]) | ||||
| @@ -130,7 +128,7 @@ _METADATA = _descriptor.Descriptor( | |||||
| _descriptor.FieldDescriptor( | _descriptor.FieldDescriptor( | ||||
| name='device_name', full_name='debugger.Metadata.device_name', index=0, | name='device_name', full_name='debugger.Metadata.device_name', index=0, | ||||
| number=1, type=9, cpp_type=9, label=1, | number=1, type=9, cpp_type=9, label=1, | ||||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||||
| message_type=None, enum_type=None, containing_type=None, | message_type=None, enum_type=None, containing_type=None, | ||||
| is_extension=False, extension_scope=None, | is_extension=False, extension_scope=None, | ||||
| serialized_options=None, file=DESCRIPTOR), | serialized_options=None, file=DESCRIPTOR), | ||||
| @@ -144,14 +142,14 @@ _METADATA = _descriptor.Descriptor( | |||||
| _descriptor.FieldDescriptor( | _descriptor.FieldDescriptor( | ||||
| name='backend', full_name='debugger.Metadata.backend', index=2, | name='backend', full_name='debugger.Metadata.backend', index=2, | ||||
| number=3, type=9, cpp_type=9, label=1, | number=3, type=9, cpp_type=9, label=1, | ||||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||||
| message_type=None, enum_type=None, containing_type=None, | message_type=None, enum_type=None, containing_type=None, | ||||
| is_extension=False, extension_scope=None, | is_extension=False, extension_scope=None, | ||||
| serialized_options=None, file=DESCRIPTOR), | serialized_options=None, file=DESCRIPTOR), | ||||
| _descriptor.FieldDescriptor( | _descriptor.FieldDescriptor( | ||||
| name='cur_node', full_name='debugger.Metadata.cur_node', index=3, | name='cur_node', full_name='debugger.Metadata.cur_node', index=3, | ||||
| number=4, type=9, cpp_type=9, label=1, | number=4, type=9, cpp_type=9, label=1, | ||||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||||
| message_type=None, enum_type=None, containing_type=None, | message_type=None, enum_type=None, containing_type=None, | ||||
| is_extension=False, extension_scope=None, | is_extension=False, extension_scope=None, | ||||
| serialized_options=None, file=DESCRIPTOR), | serialized_options=None, file=DESCRIPTOR), | ||||
| @@ -172,7 +170,7 @@ _METADATA = _descriptor.Descriptor( | |||||
| _descriptor.FieldDescriptor( | _descriptor.FieldDescriptor( | ||||
| name='ms_version', full_name='debugger.Metadata.ms_version', index=6, | name='ms_version', full_name='debugger.Metadata.ms_version', index=6, | ||||
| number=7, type=9, cpp_type=9, label=1, | number=7, type=9, cpp_type=9, label=1, | ||||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||||
| message_type=None, enum_type=None, containing_type=None, | message_type=None, enum_type=None, containing_type=None, | ||||
| is_extension=False, extension_scope=None, | is_extension=False, extension_scope=None, | ||||
| serialized_options=None, file=DESCRIPTOR), | serialized_options=None, file=DESCRIPTOR), | ||||
| @@ -203,7 +201,7 @@ _CHUNK = _descriptor.Descriptor( | |||||
| _descriptor.FieldDescriptor( | _descriptor.FieldDescriptor( | ||||
| name='buffer', full_name='debugger.Chunk.buffer', index=0, | name='buffer', full_name='debugger.Chunk.buffer', index=0, | ||||
| number=1, type=12, cpp_type=9, label=1, | number=1, type=12, cpp_type=9, label=1, | ||||
| has_default_value=False, default_value=_b(""), | |||||
| has_default_value=False, default_value=b"", | |||||
| message_type=None, enum_type=None, containing_type=None, | message_type=None, enum_type=None, containing_type=None, | ||||
| is_extension=False, extension_scope=None, | is_extension=False, extension_scope=None, | ||||
| serialized_options=None, file=DESCRIPTOR), | serialized_options=None, file=DESCRIPTOR), | ||||
| @@ -311,7 +309,7 @@ _RUNCMD = _descriptor.Descriptor( | |||||
| _descriptor.FieldDescriptor( | _descriptor.FieldDescriptor( | ||||
| name='run_level', full_name='debugger.RunCMD.run_level', index=0, | name='run_level', full_name='debugger.RunCMD.run_level', index=0, | ||||
| number=1, type=9, cpp_type=9, label=1, | number=1, type=9, cpp_type=9, label=1, | ||||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||||
| message_type=None, enum_type=None, containing_type=None, | message_type=None, enum_type=None, containing_type=None, | ||||
| is_extension=False, extension_scope=None, | is_extension=False, extension_scope=None, | ||||
| serialized_options=None, file=DESCRIPTOR), | serialized_options=None, file=DESCRIPTOR), | ||||
| @@ -325,7 +323,7 @@ _RUNCMD = _descriptor.Descriptor( | |||||
| _descriptor.FieldDescriptor( | _descriptor.FieldDescriptor( | ||||
| name='node_name', full_name='debugger.RunCMD.node_name', index=2, | name='node_name', full_name='debugger.RunCMD.node_name', index=2, | ||||
| number=3, type=9, cpp_type=9, label=1, | number=3, type=9, cpp_type=9, label=1, | ||||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||||
| message_type=None, enum_type=None, containing_type=None, | message_type=None, enum_type=None, containing_type=None, | ||||
| is_extension=False, extension_scope=None, | is_extension=False, extension_scope=None, | ||||
| serialized_options=None, file=DESCRIPTOR), | serialized_options=None, file=DESCRIPTOR), | ||||
| @@ -442,7 +440,7 @@ _WATCHCONDITION_PARAMETER = _descriptor.Descriptor( | |||||
| _descriptor.FieldDescriptor( | _descriptor.FieldDescriptor( | ||||
| name='name', full_name='debugger.WatchCondition.Parameter.name', index=0, | name='name', full_name='debugger.WatchCondition.Parameter.name', index=0, | ||||
| number=1, type=9, cpp_type=9, label=1, | number=1, type=9, cpp_type=9, label=1, | ||||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||||
| message_type=None, enum_type=None, containing_type=None, | message_type=None, enum_type=None, containing_type=None, | ||||
| is_extension=False, extension_scope=None, | is_extension=False, extension_scope=None, | ||||
| serialized_options=None, file=DESCRIPTOR), | serialized_options=None, file=DESCRIPTOR), | ||||
| @@ -546,14 +544,35 @@ _WATCHNODE = _descriptor.Descriptor( | |||||
| _descriptor.FieldDescriptor( | _descriptor.FieldDescriptor( | ||||
| name='node_name', full_name='debugger.WatchNode.node_name', index=0, | name='node_name', full_name='debugger.WatchNode.node_name', index=0, | ||||
| number=1, type=9, cpp_type=9, label=1, | number=1, type=9, cpp_type=9, label=1, | ||||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||||
| message_type=None, enum_type=None, containing_type=None, | message_type=None, enum_type=None, containing_type=None, | ||||
| is_extension=False, extension_scope=None, | is_extension=False, extension_scope=None, | ||||
| serialized_options=None, file=DESCRIPTOR), | serialized_options=None, file=DESCRIPTOR), | ||||
| _descriptor.FieldDescriptor( | _descriptor.FieldDescriptor( | ||||
| name='node_type', full_name='debugger.WatchNode.node_type', index=1, | name='node_type', full_name='debugger.WatchNode.node_type', index=1, | ||||
| number=2, type=9, cpp_type=9, label=1, | number=2, type=9, cpp_type=9, label=1, | ||||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||||
| message_type=None, enum_type=None, containing_type=None, | |||||
| is_extension=False, extension_scope=None, | |||||
| serialized_options=None, file=DESCRIPTOR), | |||||
| _descriptor.FieldDescriptor( | |||||
| name='graph_name', full_name='debugger.WatchNode.graph_name', index=2, | |||||
| number=3, type=9, cpp_type=9, label=1, | |||||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||||
| message_type=None, enum_type=None, containing_type=None, | |||||
| is_extension=False, extension_scope=None, | |||||
| serialized_options=None, file=DESCRIPTOR), | |||||
| _descriptor.FieldDescriptor( | |||||
| name='rank_id', full_name='debugger.WatchNode.rank_id', index=3, | |||||
| number=4, type=5, cpp_type=1, label=1, | |||||
| has_default_value=False, default_value=0, | |||||
| message_type=None, enum_type=None, containing_type=None, | |||||
| is_extension=False, extension_scope=None, | |||||
| serialized_options=None, file=DESCRIPTOR), | |||||
| _descriptor.FieldDescriptor( | |||||
| name='device_id', full_name='debugger.WatchNode.device_id', index=4, | |||||
| number=5, type=5, cpp_type=1, label=1, | |||||
| has_default_value=False, default_value=0, | |||||
| message_type=None, enum_type=None, containing_type=None, | message_type=None, enum_type=None, containing_type=None, | ||||
| is_extension=False, extension_scope=None, | is_extension=False, extension_scope=None, | ||||
| serialized_options=None, file=DESCRIPTOR), | serialized_options=None, file=DESCRIPTOR), | ||||
| @@ -570,7 +589,7 @@ _WATCHNODE = _descriptor.Descriptor( | |||||
| oneofs=[ | oneofs=[ | ||||
| ], | ], | ||||
| serialized_start=1335, | serialized_start=1335, | ||||
| serialized_end=1384, | |||||
| serialized_end=1440, | |||||
| ) | ) | ||||
| @@ -621,8 +640,8 @@ _WATCHPOINTHIT = _descriptor.Descriptor( | |||||
| extension_ranges=[], | extension_ranges=[], | ||||
| oneofs=[ | oneofs=[ | ||||
| ], | ], | ||||
| serialized_start=1387, | |||||
| serialized_end=1524, | |||||
| serialized_start=1443, | |||||
| serialized_end=1580, | |||||
| ) | ) | ||||
| _EVENTREPLY.fields_by_name['status'].enum_type = _EVENTREPLY_STATUS | _EVENTREPLY.fields_by_name['status'].enum_type = _EVENTREPLY_STATUS | ||||
| @@ -750,8 +769,8 @@ _EVENTLISTENER = _descriptor.ServiceDescriptor( | |||||
| file=DESCRIPTOR, | file=DESCRIPTOR, | ||||
| index=0, | index=0, | ||||
| serialized_options=None, | serialized_options=None, | ||||
| serialized_start=1527, | |||||
| serialized_end=1912, | |||||
| serialized_start=1583, | |||||
| serialized_end=1968, | |||||
| methods=[ | methods=[ | ||||
| _descriptor.MethodDescriptor( | _descriptor.MethodDescriptor( | ||||
| name='WaitCMD', | name='WaitCMD', | ||||
| @@ -1,5 +1,4 @@ | |||||
| # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! | ||||
| """Client and server classes corresponding to protobuf-defined services.""" | |||||
| import grpc | import grpc | ||||
| from mindinsight.debugger.proto import debug_grpc_pb2 as mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2 | from mindinsight.debugger.proto import debug_grpc_pb2 as mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2 | ||||
| @@ -7,7 +6,7 @@ from mindinsight.debugger.proto import ms_graph_pb2 as mindinsight_dot_debugger_ | |||||
| class EventListenerStub(object): | class EventListenerStub(object): | ||||
| """Missing associated documentation comment in .proto file.""" | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| def __init__(self, channel): | def __init__(self, channel): | ||||
| """Constructor. | """Constructor. | ||||
| @@ -48,40 +47,40 @@ class EventListenerStub(object): | |||||
| class EventListenerServicer(object): | class EventListenerServicer(object): | ||||
| """Missing associated documentation comment in .proto file.""" | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| def WaitCMD(self, request, context): | def WaitCMD(self, request, context): | ||||
| """Missing associated documentation comment in .proto file.""" | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | context.set_code(grpc.StatusCode.UNIMPLEMENTED) | ||||
| context.set_details('Method not implemented!') | context.set_details('Method not implemented!') | ||||
| raise NotImplementedError('Method not implemented!') | raise NotImplementedError('Method not implemented!') | ||||
| def SendMetadata(self, request, context): | def SendMetadata(self, request, context): | ||||
| """Missing associated documentation comment in .proto file.""" | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | context.set_code(grpc.StatusCode.UNIMPLEMENTED) | ||||
| context.set_details('Method not implemented!') | context.set_details('Method not implemented!') | ||||
| raise NotImplementedError('Method not implemented!') | raise NotImplementedError('Method not implemented!') | ||||
| def SendGraph(self, request_iterator, context): | def SendGraph(self, request_iterator, context): | ||||
| """Missing associated documentation comment in .proto file.""" | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | context.set_code(grpc.StatusCode.UNIMPLEMENTED) | ||||
| context.set_details('Method not implemented!') | context.set_details('Method not implemented!') | ||||
| raise NotImplementedError('Method not implemented!') | raise NotImplementedError('Method not implemented!') | ||||
| def SendTensors(self, request_iterator, context): | def SendTensors(self, request_iterator, context): | ||||
| """Missing associated documentation comment in .proto file.""" | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | context.set_code(grpc.StatusCode.UNIMPLEMENTED) | ||||
| context.set_details('Method not implemented!') | context.set_details('Method not implemented!') | ||||
| raise NotImplementedError('Method not implemented!') | raise NotImplementedError('Method not implemented!') | ||||
| def SendWatchpointHits(self, request_iterator, context): | def SendWatchpointHits(self, request_iterator, context): | ||||
| """Missing associated documentation comment in .proto file.""" | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | context.set_code(grpc.StatusCode.UNIMPLEMENTED) | ||||
| context.set_details('Method not implemented!') | context.set_details('Method not implemented!') | ||||
| raise NotImplementedError('Method not implemented!') | raise NotImplementedError('Method not implemented!') | ||||
| def SendMultiGraphs(self, request_iterator, context): | def SendMultiGraphs(self, request_iterator, context): | ||||
| """Missing associated documentation comment in .proto file.""" | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | context.set_code(grpc.StatusCode.UNIMPLEMENTED) | ||||
| context.set_details('Method not implemented!') | context.set_details('Method not implemented!') | ||||
| raise NotImplementedError('Method not implemented!') | raise NotImplementedError('Method not implemented!') | ||||
| @@ -127,7 +126,7 @@ def add_EventListenerServicer_to_server(servicer, server): | |||||
| # This class is part of an EXPERIMENTAL API. | # This class is part of an EXPERIMENTAL API. | ||||
| class EventListener(object): | class EventListener(object): | ||||
| """Missing associated documentation comment in .proto file.""" | |||||
| """Missing associated documentation comment in .proto file""" | |||||
| @staticmethod | @staticmethod | ||||
| def WaitCMD(request, | def WaitCMD(request, | ||||
| @@ -135,7 +134,6 @@ class EventListener(object): | |||||
| options=(), | options=(), | ||||
| channel_credentials=None, | channel_credentials=None, | ||||
| call_credentials=None, | call_credentials=None, | ||||
| insecure=False, | |||||
| compression=None, | compression=None, | ||||
| wait_for_ready=None, | wait_for_ready=None, | ||||
| timeout=None, | timeout=None, | ||||
| @@ -144,7 +142,7 @@ class EventListener(object): | |||||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, | mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, | ||||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | ||||
| options, channel_credentials, | options, channel_credentials, | ||||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| @staticmethod | @staticmethod | ||||
| def SendMetadata(request, | def SendMetadata(request, | ||||
| @@ -152,7 +150,6 @@ class EventListener(object): | |||||
| options=(), | options=(), | ||||
| channel_credentials=None, | channel_credentials=None, | ||||
| call_credentials=None, | call_credentials=None, | ||||
| insecure=False, | |||||
| compression=None, | compression=None, | ||||
| wait_for_ready=None, | wait_for_ready=None, | ||||
| timeout=None, | timeout=None, | ||||
| @@ -161,7 +158,7 @@ class EventListener(object): | |||||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, | mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, | ||||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | ||||
| options, channel_credentials, | options, channel_credentials, | ||||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| @staticmethod | @staticmethod | ||||
| def SendGraph(request_iterator, | def SendGraph(request_iterator, | ||||
| @@ -169,7 +166,6 @@ class EventListener(object): | |||||
| options=(), | options=(), | ||||
| channel_credentials=None, | channel_credentials=None, | ||||
| call_credentials=None, | call_credentials=None, | ||||
| insecure=False, | |||||
| compression=None, | compression=None, | ||||
| wait_for_ready=None, | wait_for_ready=None, | ||||
| timeout=None, | timeout=None, | ||||
| @@ -178,7 +174,7 @@ class EventListener(object): | |||||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString, | mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString, | ||||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | ||||
| options, channel_credentials, | options, channel_credentials, | ||||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| @staticmethod | @staticmethod | ||||
| def SendTensors(request_iterator, | def SendTensors(request_iterator, | ||||
| @@ -186,7 +182,6 @@ class EventListener(object): | |||||
| options=(), | options=(), | ||||
| channel_credentials=None, | channel_credentials=None, | ||||
| call_credentials=None, | call_credentials=None, | ||||
| insecure=False, | |||||
| compression=None, | compression=None, | ||||
| wait_for_ready=None, | wait_for_ready=None, | ||||
| timeout=None, | timeout=None, | ||||
| @@ -195,7 +190,7 @@ class EventListener(object): | |||||
| mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.SerializeToString, | mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.SerializeToString, | ||||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | ||||
| options, channel_credentials, | options, channel_credentials, | ||||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| @staticmethod | @staticmethod | ||||
| def SendWatchpointHits(request_iterator, | def SendWatchpointHits(request_iterator, | ||||
| @@ -203,7 +198,6 @@ class EventListener(object): | |||||
| options=(), | options=(), | ||||
| channel_credentials=None, | channel_credentials=None, | ||||
| call_credentials=None, | call_credentials=None, | ||||
| insecure=False, | |||||
| compression=None, | compression=None, | ||||
| wait_for_ready=None, | wait_for_ready=None, | ||||
| timeout=None, | timeout=None, | ||||
| @@ -212,7 +206,7 @@ class EventListener(object): | |||||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.SerializeToString, | mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.SerializeToString, | ||||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | ||||
| options, channel_credentials, | options, channel_credentials, | ||||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| @staticmethod | @staticmethod | ||||
| def SendMultiGraphs(request_iterator, | def SendMultiGraphs(request_iterator, | ||||
| @@ -220,7 +214,6 @@ class EventListener(object): | |||||
| options=(), | options=(), | ||||
| channel_credentials=None, | channel_credentials=None, | ||||
| call_credentials=None, | call_credentials=None, | ||||
| insecure=False, | |||||
| compression=None, | compression=None, | ||||
| wait_for_ready=None, | wait_for_ready=None, | ||||
| timeout=None, | timeout=None, | ||||
| @@ -229,4 +222,4 @@ class EventListener(object): | |||||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString, | mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString, | ||||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | ||||
| options, channel_credentials, | options, channel_credentials, | ||||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||||
| @@ -229,6 +229,9 @@ message NodeProto { | |||||
| // full name with scope | // full name with scope | ||||
| optional string full_name = 8; | optional string full_name = 8; | ||||
| // The corresponding source code for this node. | |||||
| optional string source_address = 9; | |||||
| } | } | ||||
| // Models | // Models | ||||
| @@ -0,0 +1,172 @@ | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Implement the session manager.""" | |||||
| import os | |||||
| import threading | |||||
| from urllib.parse import unquote | |||||
| import _thread | |||||
| from mindinsight.conf import settings | |||||
| from mindinsight.debugger.common.log import LOGGER as logger | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \ | |||||
| DebuggerSessionNotFoundError | |||||
| from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext | |||||
| from mindinsight.debugger.debugger_session import DebuggerSession | |||||
| class SessionManager: | |||||
| """The server manager of debugger.""" | |||||
| ONLINE_TYPE = "ONLINE" | |||||
| MAX_SESSION_NUM = 2 | |||||
| ONLINE_SESSION_ID = "0" | |||||
| _instance = None | |||||
| _cls_lock = threading.Lock() | |||||
| def __init__(self): | |||||
| self.train_jobs = {} | |||||
| self.sessions = {} | |||||
| self.session_id = 1 | |||||
| self.online_session = None | |||||
| self._lock = threading.Lock() | |||||
| self._exiting = False | |||||
| enable_debugger = getattr(settings, 'ENABLE_DEBUGGER', False) | |||||
| if enable_debugger: | |||||
| self.creat_session(self.ONLINE_TYPE) | |||||
| @classmethod | |||||
| def get_instance(cls): | |||||
| """Get the singleton instance.""" | |||||
| with cls._cls_lock: | |||||
| if cls._instance is None: | |||||
| cls._instance = SessionManager() | |||||
| return cls._instance | |||||
| def exit(self): | |||||
| """ | |||||
| Called when the gunicorn worker process is exiting. | |||||
| """ | |||||
| with self._lock: | |||||
| logger.info("Start to exit sessions.") | |||||
| self._exiting = True | |||||
| for session in self.sessions.values(): | |||||
| session.stop() | |||||
| if self.online_session is not None: | |||||
| self.online_session.stop() | |||||
| logger.info("Exited.") | |||||
| def get_session(self, session_id): | |||||
| """ | |||||
| Get the session by session id. | |||||
| Args: | |||||
| session_id (Union[None, str]): The id of the session. | |||||
| Returns: | |||||
| DebuggerSession, debugger session object. | |||||
| """ | |||||
| with self._lock: | |||||
| if session_id == self.ONLINE_SESSION_ID and self.online_session is not None: | |||||
| return self.online_session | |||||
| if session_id in self.sessions: | |||||
| return self.sessions.get(session_id) | |||||
| raise DebuggerSessionNotFoundError("{}".format(session_id)) | |||||
| def creat_session(self, session_type, train_job=None): | |||||
| """ | |||||
| Create session by the train job info. | |||||
| Args: | |||||
| session_type (str): The session type, online or offline. | |||||
| train_job (str): The train job info. | |||||
| Returns: | |||||
| str, session id. | |||||
| """ | |||||
| with self._lock: | |||||
| if self._exiting: | |||||
| logger.info( | |||||
| "System is exiting, will terminate the thread.") | |||||
| _thread.exit() | |||||
| if session_type == self.ONLINE_TYPE: | |||||
| if self.online_session is None: | |||||
| context = DebuggerServerContext(dbg_mode='online') | |||||
| self.online_session = DebuggerSession(context) | |||||
| self.online_session.start() | |||||
| return self.ONLINE_SESSION_ID | |||||
| if train_job in self.train_jobs: | |||||
| return self.train_jobs.get(train_job) | |||||
| self._check_session_num() | |||||
| summary_base_dir = settings.SUMMARY_BASE_DIR | |||||
| unquote_path = unquote(train_job, errors='strict') | |||||
| whole_path = os.path.join(summary_base_dir, unquote_path) | |||||
| normalized_path = validate_and_normalize_path(whole_path) | |||||
| context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path) | |||||
| session = DebuggerSession(context) | |||||
| session.start() | |||||
| session_id = str(self.session_id) | |||||
| self.sessions[session_id] = session | |||||
| self.train_jobs[train_job] = session_id | |||||
| self.session_id += 1 | |||||
| return session_id | |||||
| def delete_session(self, session_id): | |||||
| """Delete session by session id.""" | |||||
| with self._lock: | |||||
| if session_id == self.ONLINE_SESSION_ID: | |||||
| self.online_session.stop() | |||||
| self.online_session = None | |||||
| return | |||||
| if session_id not in self.sessions: | |||||
| raise DebuggerSessionNotFoundError("session id {}".format(session_id)) | |||||
| session = self.sessions.get(session_id) | |||||
| session.stop() | |||||
| self.sessions.pop(session_id) | |||||
| self.train_jobs.pop(session.train_job) | |||||
| return | |||||
| def get_sessions(self): | |||||
| """get all sessions""" | |||||
| return {"train_jobs": self.train_jobs} | |||||
| def _check_session_num(self): | |||||
| """Check the amount of sessions.""" | |||||
| if len(self.sessions) >= self.MAX_SESSION_NUM: | |||||
| raise DebuggerSessionNumOverBoundError() | |||||
| def validate_and_normalize_path(path): | |||||
| """Validate and normalize_path""" | |||||
| if not path: | |||||
| raise ValueError("The path is empty.") | |||||
| path_str = str(path) | |||||
| if not path_str.startswith("/"): | |||||
| raise ValueError("The path must be an absolute path.") | |||||
| try: | |||||
| normalized_path = os.path.realpath(path) | |||||
| except ValueError: | |||||
| raise ValueError("The path is invalid!") | |||||
| return normalized_path | |||||
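| # Editor's sketch (hedged, not part of the diff): a minimal usage example of | |||||
| # the session manager, assuming settings.SUMMARY_BASE_DIR is an absolute path | |||||
| # and the train-job directory below (hypothetical) exists under it. | |||||
| if __name__ == '__main__': | |||||
| manager = SessionManager.get_instance() | |||||
| session_id = manager.creat_session("OFFLINE", train_job="resnet50_dump") | |||||
| print(manager.get_sessions())  # -> {'train_jobs': {'resnet50_dump': '1'}} | |||||
| manager.delete_session(session_id) | |||||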
| @@ -0,0 +1,210 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """This file is used to define the DataLoader.""" | |||||
| import os | |||||
| import json | |||||
| from mindinsight.debugger.proto.ms_graph_pb2 import ModelProto | |||||
| from mindinsight.debugger.common.log import LOGGER as log | |||||
| from mindinsight.utils.exceptions import ParamValueError | |||||
| from mindinsight.debugger.common.utils import DumpSettings | |||||
| class DataLoader: | |||||
| """The DataLoader object provides interface to load graphs and device information from base_dir.""" | |||||
| def __init__(self, base_dir): | |||||
| self._debugger_base_dir = base_dir | |||||
| self._graph_protos = [] | |||||
| self._device_info = {} | |||||
| self._step_num = {} | |||||
| # flag for whether the data is from sync dump or async dump, True for sync dump, False for async dump. | |||||
| self._is_sync = None | |||||
| self._net_dir = "" | |||||
| self._net_name = "" | |||||
| self.initialize() | |||||
| def initialize(self): | |||||
| """Initialize the data_mode and net_dir of DataLoader.""" | |||||
| dump_config_file = os.path.join(self._debugger_base_dir, ".metadata", "data_dump.json") | |||||
| with open(dump_config_file, 'r') as load_f: | |||||
| dump_config = json.load(load_f) | |||||
| common_settings = dump_config.get(DumpSettings.COMMON_DUMP_SETTINGS.value) | |||||
| if not common_settings: | |||||
| raise ParamValueError('common_dump_settings not found in dump_config file.') | |||||
| self._net_name = common_settings['net_name'] | |||||
| if dump_config.get(DumpSettings.E2E_DUMP_SETTINGS.value) and \ | |||||
| dump_config[DumpSettings.E2E_DUMP_SETTINGS.value]['enable']: | |||||
| self._is_sync = True | |||||
| self._net_dir = os.path.join(self._debugger_base_dir, self._net_name) | |||||
| elif dump_config.get(DumpSettings.ASYNC_DUMP_SETTINGS.value) and \ | |||||
| dump_config[DumpSettings.ASYNC_DUMP_SETTINGS.value]['enable']: | |||||
| self._is_sync = False | |||||
| self._net_dir = self._debugger_base_dir | |||||
| else: | |||||
| raise ParamValueError('The data must be generated from sync dump or async dump.') | |||||
| def load_graphs(self): | |||||
| """Load graphs from the debugger_base_dir.""" | |||||
| files = os.listdir(self._net_dir) | |||||
| for file in files: | |||||
| if not self.is_device_dir(file): | |||||
| continue | |||||
| device_id, device_dir = self.get_device_id_and_dir(file) | |||||
| graphs_dir = os.path.join(device_dir, 'graphs') | |||||
| if not os.path.exists(graphs_dir) or not os.path.isdir(graphs_dir): | |||||
| log.debug("Directory '%s' not exist.", graphs_dir) | |||||
| self._graph_protos.append({'device_id': device_id, 'graph_protos': []}) | |||||
| continue | |||||
| graph_protos = get_graph_protos_from_dir(graphs_dir) | |||||
| self._graph_protos.append({'device_id': device_id, 'graph_protos': graph_protos}) | |||||
| return self._graph_protos | |||||
| def load_device_info(self): | |||||
| """Load device_info from file""" | |||||
| hccl_json_file = os.path.join(self._debugger_base_dir, '.metadata', 'hccl.json') | |||||
| if not os.path.isfile(hccl_json_file): | |||||
| device = [] | |||||
| device_ids = self.get_all_device_id() | |||||
| device_ids.sort() | |||||
| for i, device_id in enumerate(device_ids): | |||||
| rank_id = i | |||||
| device.append({'device_id': str(device_id), 'rank_id': str(rank_id)}) | |||||
| device_target = 'Ascend' | |||||
| self._device_info = {'device_target': device_target, | |||||
| 'server_list': [{'server_id': 'localhost', 'device': device}]} | |||||
| else: | |||||
| with open(hccl_json_file, 'r') as load_f: | |||||
| load_dict = json.load(load_f) | |||||
| self._device_info = {'device_target': 'Ascend', 'server_list': load_dict['server_list']} | |||||
| return self._device_info | |||||
| def load_step_number(self): | |||||
| """Load step number in the directory""" | |||||
| files = os.listdir(self._net_dir) | |||||
| for file in files: | |||||
| if not self.is_device_dir(file): | |||||
| continue | |||||
| device_id, device_dir = self.get_device_id_and_dir(file) | |||||
| max_step = 0 | |||||
| files_in_device = os.listdir(device_dir) | |||||
| if self._is_sync: | |||||
| for file_in_device in files_in_device: | |||||
| abs_file_in_device = os.path.join(device_dir, file_in_device) | |||||
| if os.path.isdir(abs_file_in_device) and file_in_device.startswith("iteration_"): | |||||
| step_id_str = file_in_device.split('_')[-1] | |||||
| max_step = update_max_step(step_id_str, max_step) | |||||
| self._step_num[str(device_id)] = max_step | |||||
| else: | |||||
| net_graph_dir = [] | |||||
| for file_in_device in files_in_device: | |||||
| abs_file_in_device = os.path.join(device_dir, file_in_device) | |||||
| if os.path.isdir(abs_file_in_device) and file_in_device.startswith(self._net_name): | |||||
| net_graph_dir.append(abs_file_in_device) | |||||
| if not net_graph_dir: | |||||
| log.warning("No graph directory found in device_dir: %s.", device_dir) | |||||
| self._step_num[str(device_id)] = max_step | |||||
| continue | |||||
| if len(net_graph_dir) > 1: | |||||
| log.warning("There is more than one graph directory in device_dir: %s. " | |||||
| "OfflineDebugger uses the data in %s.", device_dir, net_graph_dir[0]) | |||||
| net_graph_dir_to_use = net_graph_dir[0] | |||||
| graph_id = net_graph_dir_to_use.split('_')[-1] | |||||
| graph_id_dir = os.path.join(net_graph_dir_to_use, graph_id) | |||||
| step_ids = os.listdir(graph_id_dir) | |||||
| for step_id_str in step_ids: | |||||
| max_step = update_max_step(step_id_str, max_step) | |||||
| self._step_num[str(device_id)] = max_step | |||||
| return self._step_num | |||||
| def is_device_dir(self, file_name): | |||||
| """Judge if the file_name is a sub directory named 'device_x'.""" | |||||
| if not file_name.startswith("device_"): | |||||
| return False | |||||
| id_str = file_name.split("_")[-1] | |||||
| if not id_str.isdigit(): | |||||
| return False | |||||
| device_dir = os.path.join(self._net_dir, file_name) | |||||
| if not os.path.isdir(device_dir): | |||||
| return False | |||||
| return True | |||||
| def get_device_id_and_dir(self, file_name): | |||||
| """Get device_id and absolute directory of file_name.""" | |||||
| id_str = file_name.split("_")[-1] | |||||
| device_id = int(id_str) | |||||
| device_dir = os.path.join(self._net_dir, file_name) | |||||
| return device_id, device_dir | |||||
| def get_all_device_id(self): | |||||
| """Get all device_id int the debugger_base_dir""" | |||||
| device_ids = [] | |||||
| files = os.listdir(self._net_dir) | |||||
| for file in files: | |||||
| if not self.is_device_dir(file): | |||||
| continue | |||||
| id_str = file.split("_")[-1] | |||||
| device_id = int(id_str) | |||||
| device_ids.append(device_id) | |||||
| return device_ids | |||||
| def get_net_dir(self): | |||||
| """Get graph_name directory of the data.""" | |||||
| return self._net_dir | |||||
| def get_sync_flag(self): | |||||
| """Get the sync flag of the data.""" | |||||
| return self._is_sync | |||||
| def get_net_name(self): | |||||
| """Get net_name of the data.""" | |||||
| return self._net_name | |||||
| def load_graph_from_file(graph_file_name): | |||||
| """Load graph from file.""" | |||||
| with open(graph_file_name, 'rb') as file_handler: | |||||
| model_bytes = file_handler.read() | |||||
| model = ModelProto.FromString(model_bytes) | |||||
| graph = model.graph | |||||
| return graph | |||||
| def get_graph_protos_from_dir(graphs_dir): | |||||
| """ | |||||
| Get graph protos from the graph directory. | |||||
| Args: | |||||
| graphs_dir (str): The absolute directory of graph files. | |||||
| Returns: | |||||
| list, list of 'GraphProto' objects. | |||||
| """ | |||||
| files_in_graph_dir = os.listdir(graphs_dir) | |||||
| graph_protos = [] | |||||
| pre_file_name = "ms_output_trace_code_graph_" | |||||
| for file_in_device in files_in_graph_dir: | |||||
| if file_in_device.startswith(pre_file_name) and file_in_device.endswith(".pb"): | |||||
| abs_graph_file = os.path.join(graphs_dir, file_in_device) | |||||
| graph_proto = load_graph_from_file(abs_graph_file) | |||||
| graph_protos.append(graph_proto) | |||||
| return graph_protos | |||||
| def update_max_step(step_id_str, max_step): | |||||
| """Update max_step by compare step_id_str and max_step.""" | |||||
| res = max_step | |||||
| if step_id_str.isdigit(): | |||||
| step_id = int(step_id_str) | |||||
| if step_id > max_step: | |||||
| res = step_id | |||||
| return res | |||||
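| # Editor's sketch (hedged, not part of the diff): an end-to-end read of a dump | |||||
| # directory; the path is hypothetical and must follow the layout checked above | |||||
| # ('.metadata/data_dump.json' plus 'device_<id>' sub-directories). | |||||
| if __name__ == '__main__': | |||||
| loader = DataLoader("/data/dump/resnet50") | |||||
| print(loader.load_graphs())       # [{'device_id': 0, 'graph_protos': [...]}, ...] | |||||
| print(loader.load_device_info()) | |||||
| print(loader.load_step_number())  # e.g. {'0': 99} | |||||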
| @@ -14,7 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """The definition of tensor stream.""" | """The definition of tensor stream.""" | ||||
| from abc import abstractmethod, ABC | from abc import abstractmethod, ABC | ||||
| import numpy as np | import numpy as np | ||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | ||||
| @@ -149,7 +148,10 @@ class OpTensor(BaseTensor): | |||||
| @property | @property | ||||
| def shape(self): | def shape(self): | ||||
| """The property of tensor shape.""" | """The property of tensor shape.""" | ||||
| return list(self._tensor_proto.dims) | |||||
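| # The tensor proto appears to encode a scalar as dims == [0]; normalize it to an empty shape. | |||||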
| dims = list(self._tensor_proto.dims) | |||||
| if dims == [0]: | |||||
| dims = [] | |||||
| return dims | |||||
| @property | @property | ||||
| def value(self): | def value(self): | ||||
| @@ -254,12 +256,13 @@ class OpTensor(BaseTensor): | |||||
| class ConstTensor(BaseTensor): | class ConstTensor(BaseTensor): | ||||
| """Tensor data structure for Const Node.""" | """Tensor data structure for Const Node.""" | ||||
| _STRING_TYPE = 'DT_STRING' | _STRING_TYPE = 'DT_STRING' | ||||
| _DT_TYPE = 'DT_TYPE' | |||||
| def __init__(self, const_proto): | def __init__(self, const_proto): | ||||
| # the type of const_proto is NamedValueProto | # the type of const_proto is NamedValueProto | ||||
| super(ConstTensor, self).__init__() | super(ConstTensor, self).__init__() | ||||
| self._const_proto = const_proto | self._const_proto = const_proto | ||||
| self._value = self.generate_value_from_proto(const_proto) | |||||
| self._value = self.generate_value_from_proto(const_proto.value) | |||||
| def set_step(self, step): | def set_step(self, step): | ||||
| """Set step value.""" | """Set step value.""" | ||||
| @@ -295,16 +298,25 @@ class ConstTensor(BaseTensor): | |||||
| Returns: | Returns: | ||||
| Union[None, str, np.ndarray], the value of the tensor. | Union[None, str, np.ndarray], the value of the tensor. | ||||
| """ | """ | ||||
| fields = tensor_proto.value.ListFields() | |||||
| fields = tensor_proto.ListFields() | |||||
| if len(fields) != 2: | if len(fields) != 2: | ||||
| log.warning("Unexpected const proto <%s>.\n Please check offline.", tensor_proto) | log.warning("Unexpected const proto <%s>.\n Please check offline.", tensor_proto) | ||||
| tensor_value = None | tensor_value = None | ||||
| for field_obj, field_value in fields: | for field_obj, field_value in fields: | ||||
| if field_obj.name != 'dtype': | if field_obj.name != 'dtype': | ||||
| tensor_value = field_value | |||||
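| # A tuple const stores its elements in a repeated field; rebuild the value element by element, recursing into this method. | |||||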
| if tensor_proto.dtype == DataType.DT_TUPLE: | |||||
| tensor_values = [] | |||||
| for field_value_element in field_value: | |||||
| value_element = self.generate_value_from_proto(field_value_element) | |||||
| tensor_values.append(value_element) | |||||
| tensor_value = tensor_values | |||||
| elif tensor_proto.dtype == DataType.DT_TYPE: | |||||
| tensor_value = DataType.Name(field_value.data_type) | |||||
| else: | |||||
| tensor_value = field_value | |||||
| break | break | ||||
| if tensor_value is not None and self.dtype != self._STRING_TYPE: | |||||
| tensor_value = np.array(tensor_value, dtype=NUMPY_TYPE_MAP.get(self.dtype)) | |||||
| if tensor_value is not None and DataType.Name(tensor_proto.dtype) != self._STRING_TYPE: | |||||
| tensor_value = np.array(tensor_value, dtype=NUMPY_TYPE_MAP.get(DataType.Name(tensor_proto.dtype))) | |||||
| return tensor_value | return tensor_value | ||||
| def get_tensor_value_by_shape(self, shape=None): | def get_tensor_value_by_shape(self, shape=None): | ||||
| @@ -328,7 +340,8 @@ class ConstTensor(BaseTensor): | |||||
| Returns: | Returns: | ||||
| dict, overall statistics. | dict, overall statistics. | ||||
| """ | """ | ||||
| if self.empty or self.dtype == self._STRING_TYPE: | |||||
| if self.empty or self.dtype == self._STRING_TYPE or self.dtype == self._DT_TYPE: | |||||
| log.debug("The tensor dtype is: %s, skip getting statistics.", self.dtype) | |||||
| return {} | return {} | ||||
| stats = TensorUtils.get_statistics_from_tensor(self.value) | stats = TensorUtils.get_statistics_from_tensor(self.value) | ||||
| statistics = TensorUtils.get_overall_statistic_dict(stats) | statistics = TensorUtils.get_overall_statistic_dict(stats) | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -184,7 +184,7 @@ class Watchpoint: | |||||
| def __init__(self, watchpoint_id, watch_condition, name=None): | def __init__(self, watchpoint_id, watch_condition, name=None): | ||||
| self._id = watchpoint_id | self._id = watchpoint_id | ||||
| self._condition = watch_condition | self._condition = watch_condition | ||||
| self._watch_node = WatchNodeTree() | |||||
| self._watch_node = {0: WatchNodeTree()} | |||||
| self.name = name | self.name = name | ||||
| @property | @property | ||||
| @@ -214,32 +214,36 @@ class Watchpoint: | |||||
| else: | else: | ||||
| self._watch_node = other_watchpoint.nodes | self._watch_node = other_watchpoint.nodes | ||||
| def add_nodes(self, nodes): | |||||
| def add_nodes(self, nodes, rank_id): | |||||
| """Add node into watchpoint.""" | """Add node into watchpoint.""" | ||||
| if not nodes: | if not nodes: | ||||
| log.warning("Add empty nodes.") | log.warning("Add empty nodes.") | ||||
| return | return | ||||
| if rank_id not in self._watch_node: | |||||
| self._watch_node[rank_id] = WatchNodeTree() | |||||
| if not isinstance(nodes, list): | if not isinstance(nodes, list): | ||||
| nodes = [nodes] | nodes = [nodes] | ||||
| for node in nodes: | for node in nodes: | ||||
| self._watch_node.add_node(node.name, node.type, node.full_name) | |||||
| watch_node = self._watch_node.get(rank_id) | |||||
| watch_node.add_node(node.name, node.type, node.full_name) | |||||
| def remove_nodes(self, nodes): | |||||
| def remove_nodes(self, nodes, rank_id): | |||||
| """Remove nodes from watchpoint.""" | """Remove nodes from watchpoint.""" | ||||
| if not nodes: | if not nodes: | ||||
| return | return | ||||
| self.validate_rank_id(rank_id) | |||||
| if not isinstance(nodes, list): | if not isinstance(nodes, list): | ||||
| nodes = [nodes] | nodes = [nodes] | ||||
| for node in nodes: | for node in nodes: | ||||
| self._watch_node.remove_node(node.name) | |||||
| self._watch_node.get(rank_id).remove_node(node.name) | |||||
| def get_node_status(self, node_name, node_type, full_name): | |||||
| def get_node_status(self, node_name, node_type, full_name, rank_id): | |||||
| """Judge if the node is in watch nodes.""" | """Judge if the node is in watch nodes.""" | ||||
| if is_cst_type(node_type): | if is_cst_type(node_type): | ||||
| return WatchNodeTree.INVALID | return WatchNodeTree.INVALID | ||||
| scope_names = node_name.split('/') | scope_names = node_name.split('/') | ||||
| cur_node = self._watch_node | |||||
| self.validate_rank_id(rank_id) | |||||
| cur_node = self._watch_node.get(rank_id) | |||||
| status = 1 | status = 1 | ||||
| for scope_name in scope_names: | for scope_name in scope_names: | ||||
| cur_node = cur_node.get(scope_name) | cur_node = cur_node.get(scope_name) | ||||
| @@ -250,7 +254,7 @@ class Watchpoint: | |||||
| status = WatchNodeTree.TOTAL_WATCH | status = WatchNodeTree.TOTAL_WATCH | ||||
| break | break | ||||
| if status == WatchNodeTree.TOTAL_WATCH and cur_node.node_name != node_name: | if status == WatchNodeTree.TOTAL_WATCH and cur_node.node_name != node_name: | ||||
| self._watch_node.add_node(node_name, node_type, full_name) | |||||
| self._watch_node.get(rank_id).add_node(node_name, node_type, full_name) | |||||
| return status | return status | ||||
| @@ -278,11 +282,14 @@ class Watchpoint: | |||||
| Returns: | Returns: | ||||
| list[NodeBasicInfo], the list of watch node basic infos. | list[NodeBasicInfo], the list of watch node basic infos. | ||||
| """ | """ | ||||
| watch_nodes = [] | |||||
| self._get_watch_node(self._watch_node, watch_nodes) | |||||
| return watch_nodes | |||||
| def get_pending_cmd(self, watch_nodes): | |||||
| watch_nodes_for_devices = {} | |||||
| for rank_id, watch_node_tree in self._watch_node.items(): | |||||
| watch_nodes = [] | |||||
| self._get_watch_node(watch_node_tree, watch_nodes) | |||||
| watch_nodes_for_devices[rank_id] = watch_nodes | |||||
| return watch_nodes_for_devices | |||||
| def get_pending_cmd(self, watch_nodes_for_devices): | |||||
| """Return the watchpoint in proto format.""" | """Return the watchpoint in proto format.""" | ||||
| # construct SetCMD | # construct SetCMD | ||||
| condition_id = self._condition.get('id') | condition_id = self._condition.get('id') | ||||
| @@ -309,10 +316,12 @@ class Watchpoint: | |||||
| param_proto.name = param_name | param_proto.name = param_name | ||||
| param_proto.disabled = True | param_proto.disabled = True | ||||
| for watch_node in watch_nodes: | |||||
| event_node = set_cmd.watch_nodes.add() | |||||
| event_node.node_name = watch_node.full_name | |||||
| event_node.node_type = watch_node.type | |||||
| for rank_id, watch_nodes in watch_nodes_for_devices.items(): | |||||
| for watch_node in watch_nodes: | |||||
| event_node = set_cmd.watch_nodes.add() | |||||
| event_node.node_name = watch_node.full_name | |||||
| event_node.node_type = watch_node.type | |||||
| event_node.rank_id = rank_id | |||||
| return set_cmd | return set_cmd | ||||
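| # Editor's sketch (hedged, not part of the diff) of the per-rank flow above; | |||||
| # node_a/node_b are hypothetical objects with name/type/full_name attributes | |||||
| # and the condition id is a placeholder: | |||||
| #     watchpoint = Watchpoint(1, {'id': 'tensor_overflow', 'params': []}) | |||||
| #     watchpoint.add_nodes(node_a, rank_id=0)  # watched on rank 0 | |||||
| #     watchpoint.add_nodes(node_b, rank_id=1)  # watched on rank 1 | |||||
| #     set_cmd = watchpoint.get_pending_cmd({0: [...], 1: [...]})  # each node carries its rank_id | |||||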
| def get_watch_condition_info(self): | def get_watch_condition_info(self): | ||||
| @@ -325,6 +334,11 @@ class Watchpoint: | |||||
| watchpoint_info['name'] = self.name | watchpoint_info['name'] = self.name | ||||
| return watchpoint_info | return watchpoint_info | ||||
| def validate_rank_id(self, rank_id): | |||||
| """Check whether the rank id has watch nodes registered.""" | |||||
| if rank_id not in self._watch_node: | |||||
| log.warning("Rank id %s does not exist in the watch nodes.", rank_id) | |||||
| return | |||||
| class WatchpointHit: | class WatchpointHit: | ||||
| """The watchpoint hit structure.""" | """The watchpoint hit structure.""" | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -15,9 +15,10 @@ | |||||
| """Import the streams handlers.""" | """Import the streams handlers.""" | ||||
| from .event_handler import EventHandler | from .event_handler import EventHandler | ||||
| from .metadata_handler import MetadataHandler | from .metadata_handler import MetadataHandler | ||||
| from .graph_handler import GraphHandler | |||||
| from .tensor_handler import TensorHandler | |||||
| from .watchpoint_handler import WatchpointHandler, WatchpointHitHandler | |||||
| from .graph_handler import GraphHandler, MultiCardGraphHandler | |||||
| from .tensor_handler import TensorHandler, MultiCardTensorHandler | |||||
| from .watchpoint_handler import WatchpointHandler, WatchpointHitHandler, MultiCardWatchpointHitHandler | |||||
| __all__ = ['EventHandler', 'MetadataHandler', 'GraphHandler', 'TensorHandler', | |||||
| 'WatchpointHandler', 'WatchpointHitHandler'] | |||||
| __all__ = ['EventHandler', 'MetadataHandler', 'GraphHandler', 'TensorHandler', 'WatchpointHitHandler', | |||||
| 'MultiCardGraphHandler', 'MultiCardTensorHandler', | |||||
| 'WatchpointHandler', 'MultiCardWatchpointHitHandler'] | |||||
| @@ -0,0 +1,198 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Define the Device stream handler.""" | |||||
| from collections import defaultdict | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, DeviceIdUnregistered, \ | |||||
| DebuggerParamTypeError | |||||
| from mindinsight.debugger.common.log import LOGGER as log | |||||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||||
| class DeviceHandler(StreamHandlerBase): | |||||
| """Metadata Handler.""" | |||||
| def __init__(self): | |||||
| # contains all device infos, the format is like Dict[int(<rank_id>), <device_info>] | |||||
| self._rank_info = defaultdict(DeviceInfo) | |||||
| self._device_rank_map = {} | |||||
| @property | |||||
| def rank_ids(self): | |||||
| """The rank ids.""" | |||||
| return list(self._rank_info) | |||||
| @property | |||||
| def device_amount(self): | |||||
| """The rank ids.""" | |||||
| return len(self._rank_info) | |||||
| def put(self, value): | |||||
| """ | |||||
| Put value into device info cache. | |||||
| Args: | |||||
| value (list): The list of server infos. Each item is formatted like: | |||||
| { | |||||
| "server_id": str, | |||||
| "device": list[<Device Info>] | |||||
| }, | |||||
| The format of <Device Info> is like: | |||||
| { | |||||
| "device_id": str, | |||||
| "device_ip": str, | |||||
| "rank_id": str | |||||
| }. | |||||
| """ | |||||
| if not isinstance(value, list): | |||||
| log.error("Invalid input type. list object is expected.") | |||||
| raise DebuggerParamTypeError("List object is expected.") | |||||
| try: | |||||
| self._extract_rank_info(value) | |||||
| except TypeError as err: | |||||
| log.exception(err) | |||||
| log.error("Invalid Device info.") | |||||
| raise DebuggerParamValueError("Invalid device info.") | |||||
| log.debug("Put Device into cache") | |||||
| def _extract_rank_info(self, value): | |||||
| """Extract rank info and save.""" | |||||
| for server_info in value: | |||||
| server_ip = server_info.get('server_id') | |||||
| for device_info in server_info.get('device', []): | |||||
| rank_id = int(device_info.get('rank_id')) | |||||
| if rank_id in self._rank_info: | |||||
| log.error("Repeated rank info for rank_id: %d", rank_id) | |||||
| raise DebuggerParamValueError("Repeated rank info.") | |||||
| device_info_obj = self._rank_info[rank_id] | |||||
| device_info_obj.rank_id = rank_id | |||||
| device_info_obj.server_ip = server_ip | |||||
| device_info_obj.device_id = int(device_info.get('device_id')) | |||||
| device_info_obj.device_ip = device_info.get('device_ip') | |||||
| self._device_rank_map[device_info_obj.device_id] = rank_id | |||||
| def add_step_num_info(self, step_info): | |||||
| """ | |||||
| Add step number information for each device. | |||||
| Args: | |||||
| step_info (dict): Step info per device. The key is the device id, the value | |||||
| is the relative step number. | |||||
| """ | |||||
| if not step_info: | |||||
| log.warning("No step number information.") | |||||
| return | |||||
| if len(step_info) == 1 and not self._rank_info: | |||||
| device_id = int(list(step_info)[0]) | |||||
| log.info("Default registered device %d as rank 0.", device_id) | |||||
| self._rank_info[0].device_id = device_id | |||||
| self._device_rank_map[device_id] = 0 | |||||
| if len(step_info) > 1 and not self._rank_info: | |||||
| log.error("Missing device info for multi-card training.") | |||||
| raise DeviceIdUnregistered("all") | |||||
| for device_id, step_num in step_info.items(): | |||||
| device_id = int(device_id) | |||||
| rank_id = self.get_rank_id_by_device_id(device_id) | |||||
| self._rank_info[rank_id].step_num = step_num | |||||
| def add_graph_name_info(self, graphs): | |||||
| """ | |||||
| Add graph name per device. | |||||
| Args: | |||||
| graphs (dict): Graph infos keyed by rank id. Each value is a dict keyed by graph name. | |||||
| """ | |||||
| for rank_id, graph_info in graphs.items(): | |||||
| graph_names = list(graph_info) | |||||
| self._rank_info[rank_id].graph_names = graph_names | |||||
| def get(self, filter_condition=None): | |||||
| """ | |||||
| Get device information according to filter_condition. | |||||
| Args: | |||||
| filter_condition (Union[None, int, list]): The rank id(s) to query. If None, return all ranks. | |||||
| Returns: | |||||
| dict, the device info. | |||||
| """ | |||||
| if filter_condition is None: | |||||
| filter_condition = self.rank_ids | |||||
| if not isinstance(filter_condition, list): | |||||
| filter_condition = [filter_condition] | |||||
| device_infos = [] | |||||
| for rank_id in filter_condition: | |||||
| device_info = self._rank_info.get(rank_id) | |||||
| if device_info is None: | |||||
| log.error("Invalid rank id.") | |||||
| raise DeviceIdUnregistered(rank_id) | |||||
| device_infos.append(device_info.to_dict()) | |||||
| return {'devices': device_infos} | |||||
| def get_rank_id_by_device_id(self, device_id): | |||||
| """ | |||||
| Get rank id by device id. | |||||
| Args: | |||||
| device_id (int): The device id. | |||||
| Returns: | |||||
| int, the rank id. | |||||
| """ | |||||
| rank_id = self._device_rank_map.get(device_id) | |||||
| if rank_id is None: | |||||
| log.error("Failed to find rank_id for device_id %s", device_id) | |||||
| raise DeviceIdUnregistered(device_id) | |||||
| return rank_id | |||||
| def get_device_id_by_rank_id(self, rank_id): | |||||
| """ | |||||
| Get device id by rank id. | |||||
| Args: | |||||
| rank_id (int): The rank id. | |||||
| Returns: | |||||
| int, the device id. | |||||
| """ | |||||
| device_info = self._rank_info.get(rank_id) | |||||
| if device_info: | |||||
| return device_info.device_id | |||||
| log.error("Failed to find device id according to rank_id %s", rank_id) | |||||
| raise DeviceIdUnregistered(rank_id) | |||||
| class DeviceInfo: | |||||
| """Device info object.""" | |||||
| def __init__(self): | |||||
| self.rank_id = 0 | |||||
| self.device_id = 0 | |||||
| self.server_ip = '' | |||||
| self.graph_names = [] | |||||
| self.device_ip = '' | |||||
| self.step_num = 0 | |||||
| def to_dict(self): | |||||
| """Convert device info to dict.""" | |||||
| res = { | |||||
| 'rank_id': self.rank_id, | |||||
| 'server_ip': self.server_ip, | |||||
| 'device_id': self.device_id, | |||||
| 'device_ip': self.device_ip, | |||||
| 'graph_names': self.graph_names, | |||||
| 'total_step_num': self.step_num | |||||
| } | |||||
| return res | |||||
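| # Editor's sketch (hedged, not part of the diff): registering devices in the | |||||
| # documented format; the server id and ips are hypothetical. | |||||
| if __name__ == '__main__': | |||||
| handler = DeviceHandler() | |||||
| handler.put([{'server_id': '10.0.0.1', 'device': [ | |||||
| {'device_id': '0', 'device_ip': '192.168.1.1', 'rank_id': '0'}, | |||||
| {'device_id': '1', 'device_ip': '192.168.1.2', 'rank_id': '1'}]}]) | |||||
| handler.add_step_num_info({'0': 100, '1': 100}) | |||||
| print(handler.get(0))  # {'devices': [{'rank_id': 0, 'device_id': 0, ...}]} | |||||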
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -24,6 +24,55 @@ from mindinsight.debugger.stream_cache.debugger_multigraph import DebuggerMultiG | |||||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | ||||
| class MultiCardGraphHandler: | |||||
| """Multi-card Graph Handler.""" | |||||
| def __init__(self): | |||||
| self._graph_handlers = {0: GraphHandler()} | |||||
| @property | |||||
| def graph_handlers(self): | |||||
| """The property of whole_graph.""" | |||||
| return self._graph_handlers | |||||
| def get_graph_handler_by_rank_id(self, rank_id=0): | |||||
| """Get handler by rank id""" | |||||
| if rank_id in self._graph_handlers: | |||||
| return self._graph_handlers.get(rank_id) | |||||
| log.error("There is no rank id %d.", rank_id) | |||||
| raise ValueError | |||||
| def put(self, value): | |||||
| """put graphs into graph_handlers""" | |||||
| for rank_id, graph in value.items(): | |||||
| if rank_id not in self._graph_handlers: | |||||
| self._graph_handlers[rank_id] = GraphHandler() | |||||
| self._graph_handlers[rank_id].put(graph) | |||||
| def get(self, filter_condition=None, rank_id=0): | |||||
| """Get the graph of specific node for specific device.""" | |||||
| if rank_id in self._graph_handlers: | |||||
| return self._graph_handlers.get(rank_id).get(filter_condition) | |||||
| log.error("There is no rank id %d.", rank_id) | |||||
| raise ValueError | |||||
| def has_graph(self): | |||||
| """Check whether any graph handler holds a graph.""" | |||||
| return any(graph_handler.graph for graph_handler in self._graph_handlers.values()) | |||||
| def register_graph_handler(self, rank_id, graph_handler): | |||||
| """Register graph handler.""" | |||||
| self._graph_handlers[rank_id] = graph_handler | |||||
| def clean(self): | |||||
| """Clean cache.""" | |||||
| self.__init__() | |||||
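| # Editor's sketch (hedged, not part of the diff): usage of the multi-card wrapper, | |||||
| # where graphs_rank0/graphs_rank1 are hypothetical {graph_name: GraphProto} dicts: | |||||
| #     multi_graph = MultiCardGraphHandler() | |||||
| #     multi_graph.put({0: graphs_rank0, 1: graphs_rank1}) | |||||
| #     handler = multi_graph.get_graph_handler_by_rank_id(1) | |||||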
| class GraphHandler(StreamHandlerBase): | class GraphHandler(StreamHandlerBase): | ||||
| """Metadata Handler.""" | """Metadata Handler.""" | ||||
| @@ -68,7 +117,7 @@ class GraphHandler(StreamHandlerBase): | |||||
| Put value into graph cache. Called by grpc server. | Put value into graph cache. Called by grpc server. | ||||
| Args: | Args: | ||||
| value (GraphProto): The Graph proto message. | |||||
| value (dict): The Graph proto message. Each item is format like (<graph_name>, GraphProto). | |||||
| """ | """ | ||||
| log.info("Put graph into cache.") | log.info("Put graph into cache.") | ||||
| sorted_value_list = self._sort_graph(value) | sorted_value_list = self._sort_graph(value) | ||||
| @@ -430,8 +479,8 @@ class GraphHandler(StreamHandlerBase): | |||||
| graph_name, node_name = self._parse_node_name(scope_name, graph_name) | graph_name, node_name = self._parse_node_name(scope_name, graph_name) | ||||
| graph = self._get_graph(graph_name) | graph = self._get_graph(graph_name) | ||||
| # to make sure fully match the scope name | # to make sure fully match the scope name | ||||
| node_name = node_name + '/' if not node_name.endswith('/') else node_name | |||||
| nodes = graph.search_leaf_nodes_by_pattern(node_name) | |||||
| node_name = node_name + '/' if node_name and not node_name.endswith('/') else node_name | |||||
| nodes = graph.search_leaf_nodes_by_pattern(node_name, True) | |||||
| res = [self.construct_node_basic_info(full_name=node.full_name, | res = [self.construct_node_basic_info(full_name=node.full_name, | ||||
| graph_name=graph_name, | graph_name=graph_name, | ||||
| node_name=node.name, | node_name=node.name, | ||||
| @@ -448,45 +497,6 @@ class GraphHandler(StreamHandlerBase): | |||||
| log.debug("Get empty full name.") | log.debug("Get empty full name.") | ||||
| return node_name | return node_name | ||||
| def get_node_by_bfs_order(self, node_name=None, ascend=True): | |||||
| """ | |||||
| Traverse the graph in order of breath-first search by given node. | |||||
| Args: | |||||
| node_name (str): The name of current chosen leaf node. | |||||
| ascend (bool): If True, traverse the input nodes; | |||||
| If False, traverse the output nodes. Default is True. | |||||
| Returns: | |||||
| Union[None, dict], the next node object in dict type or None. | |||||
| """ | |||||
| bfs_order = self.bfs_order | |||||
| length = len(bfs_order) | |||||
| if not bfs_order: | |||||
| log.error('Cannot get the BFS order of the graph!') | |||||
| msg = 'Cannot get the BFS order of the graph!' | |||||
| raise DebuggerParamValueError(msg) | |||||
| if node_name is None: | |||||
| if ascend is False: | |||||
| next_node = None | |||||
| else: | |||||
| next_node = bfs_order[0] | |||||
| else: | |||||
| try: | |||||
| index = bfs_order.index(node_name) | |||||
| log.debug("The index of the node in BFS list is: %d", index) | |||||
| except ValueError as err: | |||||
| log.error('Cannot find the node: %s. Please check ' | |||||
| 'the node name: %s', node_name, err) | |||||
| msg = f'Cannot find the node: {node_name}. ' \ | |||||
| f'Please check the node name {err}.' | |||||
| raise DebuggerParamValueError(msg) | |||||
| next_node = self._get_next_node_in_bfs(index, length, ascend) | |||||
| return next_node | |||||
| def _get_next_node_in_bfs(self, index, length, ascend): | def _get_next_node_in_bfs(self, index, length, ascend): | ||||
| """ | """ | ||||
| Get the next node in bfs order. | Get the next node in bfs order. | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -13,8 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Define the metadata stream handler.""" | """Define the metadata stream handler.""" | ||||
| from mindinsight.debugger.common.log import LOGGER as log | from mindinsight.debugger.common.log import LOGGER as log | ||||
| from mindinsight.debugger.common.utils import ServerStatus | |||||
| from mindinsight.debugger.common.utils import ServerStatus, DebuggerServerMode | |||||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | ||||
| @@ -24,28 +25,36 @@ class MetadataHandler(StreamHandlerBase): | |||||
| def __init__(self): | def __init__(self): | ||||
| self._state = ServerStatus.PENDING | self._state = ServerStatus.PENDING | ||||
| self._device_name = "" | self._device_name = "" | ||||
| self._step = 0 | |||||
| self.step = 0 | |||||
| self._client_ip = "" | self._client_ip = "" | ||||
| self._cur_node_name = "" | self._cur_node_name = "" | ||||
| self._cur_full_name = "" | self._cur_full_name = "" | ||||
| self._backend = "" | |||||
| self.backend = "" | |||||
| self._enable_recheck = False | self._enable_recheck = False | ||||
| self._cur_graph_name = "" | self._cur_graph_name = "" | ||||
| # If recommendation_confirmed is true, it only means the user has answered yes or no to the question, | # If recommendation_confirmed is true, it only means the user has answered yes or no to the question, | ||||
| # it does not necessarily mean that the user will use the recommended watch points. | # it does not necessarily mean that the user will use the recommended watch points. | ||||
| self._recommendation_confirmed = False | self._recommendation_confirmed = False | ||||
| self._debugger_version = {} | self._debugger_version = {} | ||||
| # maximum step number among all devices | |||||
| self._max_step_num = 0 | |||||
| self._debugger_type = DebuggerServerMode.ONLINE.value | |||||
| @property | |||||
| def debugger_type(self): | |||||
| """The property of debugger_type.""" | |||||
| return self._debugger_type | |||||
| @debugger_type.setter | |||||
| def debugger_type(self, debugger_type): | |||||
| """The property of debugger_type.""" | |||||
| self._debugger_type = debugger_type | |||||
| @property | @property | ||||
| def device_name(self): | def device_name(self): | ||||
| """The property of device name.""" | """The property of device name.""" | ||||
| return self._device_name | return self._device_name | ||||
| @property | |||||
| def step(self): | |||||
| """The property of current step.""" | |||||
| return self._step | |||||
| @property | @property | ||||
| def node_name(self): | def node_name(self): | ||||
| """The property of current node name.""" | """The property of current node name.""" | ||||
| @@ -71,11 +80,6 @@ class MetadataHandler(StreamHandlerBase): | |||||
| """The property of current node name.""" | """The property of current node name.""" | ||||
| return self._cur_full_name | return self._cur_full_name | ||||
| @property | |||||
| def backend(self): | |||||
| """The property of current backend.""" | |||||
| return self._backend | |||||
| @property | @property | ||||
| def state(self): | def state(self): | ||||
| """The property of state.""" | """The property of state.""" | ||||
| @@ -152,6 +156,16 @@ class MetadataHandler(StreamHandlerBase): | |||||
| """ | """ | ||||
| self._debugger_version = value | self._debugger_version = value | ||||
| @property | |||||
| def max_step_num(self): | |||||
| """The property of max_step_num.""" | |||||
| return self._max_step_num | |||||
| @max_step_num.setter | |||||
| def max_step_num(self, max_step_num): | |||||
| """Set the property of max_step_num.""" | |||||
| self._max_step_num = max_step_num | |||||
| def put(self, value): | def put(self, value): | ||||
| """ | """ | ||||
| Put value into metadata cache. Called by grpc server. | Put value into metadata cache. Called by grpc server. | ||||
| @@ -160,10 +174,10 @@ class MetadataHandler(StreamHandlerBase): | |||||
| value (MetadataProto): The Metadata proto message. | value (MetadataProto): The Metadata proto message. | ||||
| """ | """ | ||||
| self._device_name = value.device_name.split(':')[0] | self._device_name = value.device_name.split(':')[0] | ||||
| self._step = value.cur_step | |||||
| self.step = value.cur_step | |||||
| self._cur_full_name = value.cur_node | self._cur_full_name = value.cur_node | ||||
| self._backend = value.backend if value.backend else "Ascend" | |||||
| log.debug("Put metadata into cache at the %d-th step.", self._step) | |||||
| self.backend = value.backend if value.backend else "Ascend" | |||||
| log.debug("Put metadata into cache at the %d-th step.", self.step) | |||||
| def get(self, filter_condition=None): | def get(self, filter_condition=None): | ||||
| """ | """ | ||||
| @@ -190,6 +204,8 @@ class MetadataHandler(StreamHandlerBase): | |||||
| 'recommendation_confirmed': self._recommendation_confirmed, | 'recommendation_confirmed': self._recommendation_confirmed, | ||||
| 'debugger_version': self.debugger_version | 'debugger_version': self.debugger_version | ||||
| } | } | ||||
| if self.debugger_type == DebuggerServerMode.OFFLINE.value: | |||||
| metadata['total_step_num'] = self.max_step_num | |||||
| else: | else: | ||||
| if not isinstance(filter_condition, list): | if not isinstance(filter_condition, list): | ||||
| filter_condition = [filter_condition] | filter_condition = [filter_condition] | ||||
| @@ -28,6 +28,46 @@ from mindinsight.utils.tensor import TensorUtils, TensorComparison | |||||
| TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter']) | TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter']) | ||||
| class MultiCardTensorHandler: | |||||
| """Multi-card Tensor Handler.""" | |||||
| def __init__(self): | |||||
| self.tensor_handlers = {0: TensorHandler()} | |||||
| def set_step(self, step_id): | |||||
| """Set step id.""" | |||||
| for tensor_handler in self.tensor_handlers.values(): | |||||
| tensor_handler.cur_step = step_id | |||||
| def get_tensor_handler_by_rank_id(self, rank_id=0, create_if_not_exit=False): | |||||
| """get handler by rank id""" | |||||
| if rank_id in self.tensor_handlers: | |||||
| return self.tensor_handlers.get(rank_id) | |||||
| if create_if_not_exit: | |||||
| tensor_handler = TensorHandler() | |||||
| self.tensor_handlers[rank_id] = tensor_handler | |||||
| return tensor_handler | |||||
| log.error("There is no rank id %d in MultiCardTensorHandler.", rank_id) | |||||
| raise ValueError | |||||
| def put(self, value): | |||||
| """put graphs into graph_handlers""" | |||||
| for rank_id, tensor in value: | |||||
| if rank_id not in self.tensor_handlers: | |||||
| self.tensor_handlers[rank_id] = TensorHandler() | |||||
| self.tensor_handlers[rank_id].put(tensor) | |||||
| def get(self, filter_condition=None, rank_id=0): | |||||
| """Get the graph of specific node for specific device.""" | |||||
| if rank_id in self.tensor_handlers: | |||||
| return self.tensor_handlers.get(rank_id).get(filter_condition) | |||||
| log.error("There is no rank id %d.", rank_id) | |||||
| raise ValueError | |||||
| def clean(self): | |||||
| """Clean cache.""" | |||||
| self.__init__() | |||||
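| # Editor's sketch (hedged, not part of the diff): usage of the multi-card tensor cache: | |||||
| #     multi_tensor = MultiCardTensorHandler() | |||||
| #     handler = multi_tensor.get_tensor_handler_by_rank_id(1, create_if_not_exit=True) | |||||
| #     multi_tensor.set_step(3)  # propagate the current step to every handler | |||||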
| class TensorHandler(StreamHandlerBase): | class TensorHandler(StreamHandlerBase): | ||||
| """Metadata Handler.""" | """Metadata Handler.""" | ||||
| @@ -46,6 +86,11 @@ class TensorHandler(StreamHandlerBase): | |||||
| """The property of current step.""" | """The property of current step.""" | ||||
| return self._cur_step | return self._cur_step | ||||
| @cur_step.setter | |||||
| def cur_step(self, step_id): | |||||
| """The property of current step.""" | |||||
| self._cur_step = step_id | |||||
| @property | @property | ||||
| def prev_step(self): | def prev_step(self): | ||||
| """The property of previous step.""" | """The property of previous step.""" | ||||
| @@ -172,7 +217,7 @@ class TensorHandler(StreamHandlerBase): | |||||
| log.error("No tensor named %s at the step %s", name, step) | log.error("No tensor named %s at the step %s", name, step) | ||||
| raise DebuggerParamValueError("No tensor named {}".format(name)) | raise DebuggerParamValueError("No tensor named {}".format(name)) | ||||
| tensor_info = tensor.get_full_info(shape) | tensor_info = tensor.get_full_info(shape) | ||||
| self._update_has_prev_step_field(tensor_info, name, node_type) | |||||
| self._update_has_prev_step_field(tensor_info, name, node_type, self.cur_step) | |||||
| return {'tensor_value': tensor_info} | return {'tensor_value': tensor_info} | ||||
| def _get_tensor(self, tensor_name, node_type=None, step=None): | def _get_tensor(self, tensor_name, node_type=None, step=None): | ||||
| @@ -198,20 +243,21 @@ class TensorHandler(StreamHandlerBase): | |||||
| return tensor | return tensor | ||||
| def _get_basic_info(self, tensor_name, node_type=None): | |||||
| def _get_basic_info(self, tensor_name, node_type, step): | |||||
| """Get the latest basic tensor info by tensor name.""" | """Get the latest basic tensor info by tensor name.""" | ||||
| tensor = self._get_tensor(tensor_name, node_type) | |||||
| tensor = self._get_tensor(tensor_name, node_type, step) | |||||
| if tensor: | if tensor: | ||||
| return tensor.get_basic_info() | return tensor.get_basic_info() | ||||
| return None | return None | ||||
| def update_tensor_history(self, tensor_history): | |||||
| def update_tensor_history(self, tensor_history, step=None): | |||||
| """ | """ | ||||
| Add tensor basic info in tensor_history. | Add tensor basic info in tensor_history. | ||||
| Args: | Args: | ||||
| tensor_history (dict): Tensor history, including a list of tensor name and type. | tensor_history (dict): Tensor history, including a list of tensor name and type. | ||||
| step (int): The step of tensor info. Default: None. | |||||
| Returns: | Returns: | ||||
| list[dict], the list of tensor basic info cache. | list[dict], the list of tensor basic info cache. | ||||
| @@ -220,9 +266,9 @@ class TensorHandler(StreamHandlerBase): | |||||
| for tensor_info in tensor_history.get('tensor_history'): | for tensor_info in tensor_history.get('tensor_history'): | ||||
| tensor_name = tensor_info.get('full_name') | tensor_name = tensor_info.get('full_name') | ||||
| node_type = tensor_info.get('node_type') | node_type = tensor_info.get('node_type') | ||||
| basic_info = self._get_basic_info(tensor_name, node_type) | |||||
| basic_info = self._get_basic_info(tensor_name, node_type, step) | |||||
| # add `has_prev_step` field to tensor basic info. | # add `has_prev_step` field to tensor basic info. | ||||
| missing_tensors_info = self._update_has_prev_step_field(basic_info, tensor_name, node_type) | |||||
| missing_tensors_info = self._update_has_prev_step_field(basic_info, tensor_name, node_type, step) | |||||
| if basic_info: | if basic_info: | ||||
| tensor_info.update(basic_info) | tensor_info.update(basic_info) | ||||
| if missing_tensors_info: | if missing_tensors_info: | ||||
| @@ -230,14 +276,14 @@ class TensorHandler(StreamHandlerBase): | |||||
| return missed_tensors | return missed_tensors | ||||
| def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type): | |||||
| def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type, step=None): | |||||
| """Update has_prev_step field in tensor info.""" | """Update has_prev_step field in tensor info.""" | ||||
| missing_tensors_info = self._get_missing_tensor_info(tensor_name, node_type) | |||||
| if not missing_tensors_info and node_type == NodeTypeEnum.PARAMETER.value and self.cur_step > 0: | |||||
| missing_tensors_info = self._get_missing_tensor_info(tensor_name, node_type, step) | |||||
| if not missing_tensors_info and node_type == NodeTypeEnum.PARAMETER.value and step is not None and step > 0: | |||||
| tensor_info['has_prev_step'] = True | tensor_info['has_prev_step'] = True | ||||
| return missing_tensors_info | return missing_tensors_info | ||||
| def _get_missing_tensor_info(self, tensor_name, node_type): | |||||
| def _get_missing_tensor_info(self, tensor_name, node_type, step): | |||||
| """ | """ | ||||
| Get missing tensor infos. | Get missing tensor infos. | ||||
| @@ -248,7 +294,6 @@ class TensorHandler(StreamHandlerBase): | |||||
| Returns: | Returns: | ||||
| list, list of missing tensor basic information. | list, list of missing tensor basic information. | ||||
| """ | """ | ||||
| step = self.cur_step | |||||
| missing_tensors_info = [] | missing_tensors_info = [] | ||||
| # check the current step value is missing | # check the current step value is missing | ||||
| if self._is_tensor_value_missing(tensor_name, step): | if self._is_tensor_value_missing(tensor_name, step): | ||||
| @@ -278,13 +323,13 @@ class TensorHandler(StreamHandlerBase): | |||||
| tensor = self._get_tensor(tensor_name, step=step) | tensor = self._get_tensor(tensor_name, step=step) | ||||
| return bool(not tensor or tensor.empty) | return bool(not tensor or tensor.empty) | ||||
| def get_valid_tensor_by_name(self, tensor_name, prev=False): | |||||
| def get_valid_tensor_by_name(self, tensor_name, step, prev=False): | |||||
| """Get tensor value by name in numpy type.""" | """Get tensor value by name in numpy type.""" | ||||
| step = self.prev_step if prev else self.cur_step | |||||
| if step < 0: | |||||
| log.warning("%d step has no previous value for tensor: %s", self.cur_step, tensor_name) | |||||
| target_step = step - 1 if prev else step | |||||
| if target_step < 0: | |||||
| log.warning("Step %d has no previous value for tensor: %s", target_step, tensor_name) | |||||
| return None | return None | ||||
| tensor = self._get_tensor(tensor_name, step=step) | |||||
| tensor = self._get_tensor(tensor_name, step=target_step) | |||||
| if tensor and tensor.empty: | if tensor and tensor.empty: | ||||
| log.warning("%s has empty value.", tensor_name) | log.warning("%s has empty value.", tensor_name) | ||||
| return None | return None | ||||
| @@ -316,9 +361,9 @@ class TensorHandler(StreamHandlerBase): | |||||
| self._tensors.pop(param) | self._tensors.pop(param) | ||||
| log.debug("Clean param %s in cache.", param) | log.debug("Clean param %s in cache.", param) | ||||
| def get_tensors_diff(self, tensor_name, shape, tolerance=0): | |||||
| def get_tensors_diff(self, tensor_name, shape, tolerance=0, step=None): | |||||
| """ | """ | ||||
| Get tensor comparisons data for given name, detail, shape and tolerance. | |||||
| Get tensor comparisons data for given name, detail, shape and tolerance. | |||||
| Args: | Args: | ||||
| tensor_name (str): The name of tensor for cache. | tensor_name (str): The name of tensor for cache. | ||||
| @@ -329,6 +374,7 @@ class TensorHandler(StreamHandlerBase): | |||||
| calculate the min value and max value of the result of the current step tensor subtract | calculate the min value and max value of the result of the current step tensor subtract | ||||
| the previous step tensor. If the absolute value of result is less than or equal to | the previous step tensor. If the absolute value of result is less than or equal to | ||||
| boundary value, the result will set to be zero. | boundary value, the result will set to be zero. | ||||
| step (int): The step of the tensor. Default: None. | |||||
| Raises: | Raises: | ||||
| DebuggerParamValueError, If get current step node and previous step node failed or | DebuggerParamValueError, If get current step node and previous step node failed or | ||||
| @@ -337,8 +383,8 @@ class TensorHandler(StreamHandlerBase): | |||||
| Returns: | Returns: | ||||
| dict, the retrieved data. | dict, the retrieved data. | ||||
| """ | """ | ||||
| curr_tensor = self.get_valid_tensor_by_name(tensor_name) | |||||
| prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True) | |||||
| curr_tensor = self.get_valid_tensor_by_name(tensor_name, step=step) | |||||
| prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True, step=step) | |||||
| if not (curr_tensor and prev_tensor): | if not (curr_tensor and prev_tensor): | ||||
| log.error("Get current step and previous step for this tensor name %s failed.", tensor_name) | log.error("Get current step and previous step for this tensor name %s failed.", tensor_name) | ||||
| raise DebuggerParamValueError(f"Get current step and previous step for this tensor name " | raise DebuggerParamValueError(f"Get current step and previous step for this tensor name " | ||||
| @@ -386,22 +432,23 @@ class TensorHandler(StreamHandlerBase): | |||||
| stats_info['statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=diff_tensor_stats) | stats_info['statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=diff_tensor_stats) | ||||
| return stats_info | return stats_info | ||||
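The tolerance rule described in the docstring (differences whose absolute value falls within the boundary are zeroed) can be illustrated with a small NumPy sketch; interpreting `tolerance` as a fraction of the largest absolute difference is an assumption for this illustration, not the handler's exact implementation:

```python
import numpy as np

def diff_with_tolerance(curr, prev, tolerance=0.0):
    """Subtract prev from curr, zeroing entries within the boundary."""
    diff = curr - prev
    # assumed boundary: tolerance scaled by the largest absolute difference
    boundary = tolerance * np.max(np.abs(diff)) if diff.size else 0.0
    diff[np.abs(diff) <= boundary] = 0.0
    return diff

curr = np.array([1.0, 2.0, 3.5])
prev = np.array([1.0, 1.9, 1.0])
print(diff_with_tolerance(curr, prev, tolerance=0.05))  # -> [0.  0.  2.5]
```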
| def get_tensor_info_for_tensor_graph(self, tensor_name, node_type): | |||||
| def get_tensor_info_for_tensor_graph(self, tensor_name, node_type, step): | |||||
| """ | """ | ||||
| Get Tensor info for tensor graphs. | Get Tensor info for tensor graphs. | ||||
| Args: | Args: | ||||
| tensor_name (str): Tensor name, format like `node_name:slot`. | tensor_name (str): Tensor name, format like `node_name:slot`. | ||||
| node_type (str): Node type. | node_type (str): Node type. | ||||
| step (int): The step of tensor info. | |||||
| Returns: | Returns: | ||||
| dict, tensor infos, including overall statistics, tensor shape and has_prev_step info. | dict, tensor infos, including overall statistics, tensor shape and has_prev_step info. | ||||
| list, list of missing tensor basic information. | list, list of missing tensor basic information. | ||||
| """ | """ | ||||
| res = {} | res = {} | ||||
| tensor = self._get_tensor(tensor_name, node_type) | |||||
| tensor = self._get_tensor(tensor_name, node_type, step) | |||||
| if tensor and not tensor.empty: | if tensor and not tensor.empty: | ||||
| res['statistics'] = tensor.get_tensor_statistics() | res['statistics'] = tensor.get_tensor_statistics() | ||||
| res['shape'] = tensor.shape | res['shape'] = tensor.shape | ||||
| missing_tensors = self._update_has_prev_step_field(res, tensor_name, node_type) | |||||
| missing_tensors = self._update_has_prev_step_field(res, tensor_name, node_type, step) | |||||
| return res, missing_tensors | return res, missing_tensors | ||||
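For reference, a hypothetical shape of the `(res, missing_tensors)` pair returned above; the keys follow the `statistics`/`shape`/`has_prev_step` fields set in this method, while the concrete values are placeholders:

```python
# Hypothetical return shape of get_tensor_info_for_tensor_graph:
res = {
    'statistics': {'max': 1.0, 'min': -1.0},  # overall statistics
    'shape': [32, 3, 224, 224],
    'has_prev_step': True,  # set for parameter nodes when step > 0
}
missing_tensors = []  # basic info of values still to be fetched from the client
```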
| @@ -105,12 +105,12 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| return {'watch_points': reply} | return {'watch_points': reply} | ||||
| def get_pending_commands(self, graph_stream): | |||||
| def get_pending_commands(self, multi_card_graph_stream): | |||||
| """ | """ | ||||
| Get all watchpoints in SetCMD proto format. | Get all watchpoints in SetCMD proto format. | ||||
| Args: | Args: | ||||
| graph_stream (GraphHandler): Graph handler. | |||||
| multi_card_graph_stream (MultiCardGraphHandler): Multi card graph handler. | |||||
| Returns: | Returns: | ||||
| list[SetCMD], updated watchpoint to be sent to MindSpore. | list[SetCMD], updated watchpoint to be sent to MindSpore. | ||||
| @@ -118,9 +118,13 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| newly_set_cmds = [] | newly_set_cmds = [] | ||||
| for _, watchpoint in self._updated_watchpoints.items(): | for _, watchpoint in self._updated_watchpoints.items(): | ||||
| # construct set command with leaf nodes | # construct set command with leaf nodes | ||||
| watch_nodes = watchpoint.get_watch_nodes() | |||||
| leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes) | |||||
| newly_set_cmds.append(watchpoint.get_pending_cmd(leaf_watch_nodes)) | |||||
| watch_nodes_for_devices = watchpoint.get_watch_nodes() | |||||
| leaf_watch_nodes_for_devices = {} | |||||
| for rank_id, watch_nodes in watch_nodes_for_devices.items(): | |||||
| graph_stream = multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id) | |||||
| leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes) | |||||
| leaf_watch_nodes_for_devices[rank_id] = leaf_watch_nodes | |||||
| newly_set_cmds.append(watchpoint.get_pending_cmd(leaf_watch_nodes_for_devices)) | |||||
| newly_set_cmds.extend(self._deleted_watchpoints) | newly_set_cmds.extend(self._deleted_watchpoints) | ||||
| self.sync_set_cmd(newly_set_cmds) | self.sync_set_cmd(newly_set_cmds) | ||||
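The multi-card change replaces the single watch-node list with a per-rank mapping, and each rank's watch nodes are expanded against that rank's own graph. A minimal sketch of the data flow, with `expand_fn` standing in for `_expand_to_leaf_nodes` (names and values hypothetical):

```python
# rank id -> watch nodes for that device's graph (hypothetical values)
watch_nodes_for_devices = {
    0: ['Default/network/conv1', 'Default/network/fc1'],
    1: ['Default/network/conv1'],
}

def expand_per_device(watch_nodes_for_devices, expand_fn):
    """Expand each rank's watch nodes to leaf nodes on that rank's graph."""
    return {rank_id: expand_fn(rank_id, nodes)
            for rank_id, nodes in watch_nodes_for_devices.items()}

# identity expansion just to show the resulting per-rank shape
leaves = expand_per_device(watch_nodes_for_devices, lambda rank, nodes: nodes)
assert set(leaves) == {0, 1}
```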
| @@ -161,7 +165,7 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| """ | """ | ||||
| return self._outdated | return self._outdated | ||||
| def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None): | |||||
| def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None, rank_id=0): | |||||
| """ | """ | ||||
| Set watch nodes for graph. | Set watch nodes for graph. | ||||
| @@ -170,23 +174,24 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| graph_stream (GraphHandler): The graph handler. | graph_stream (GraphHandler): The graph handler. | ||||
| watch_point_id (int): The id of watchpoint. | watch_point_id (int): The id of watchpoint. | ||||
| graph_name (str): The graph name. | graph_name (str): The graph name. | ||||
| rank_id (int): The rank id. | |||||
| """ | """ | ||||
| if not (watch_point_id and graph): | if not (watch_point_id and graph): | ||||
| return | return | ||||
| log.debug("add watch flags") | log.debug("add watch flags") | ||||
| watchpoint = self._watchpoints.get(watch_point_id) | watchpoint = self._watchpoints.get(watch_point_id) | ||||
| self._set_watch_status_recursively(graph, graph_stream, watchpoint, graph_name) | |||||
| self._set_watch_status_recursively(graph, graph_stream, watchpoint, graph_name, rank_id) | |||||
| def _set_watch_status_recursively(self, graph, graph_stream, watchpoint, graph_name=None): | |||||
| def _set_watch_status_recursively(self, graph, graph_stream, watchpoint, graph_name=None, rank_id=0): | |||||
| """Set watch status to graph.""" | """Set watch status to graph.""" | ||||
| if graph.get('children'): | if graph.get('children'): | ||||
| self._set_watch_status_recursively( | self._set_watch_status_recursively( | ||||
| graph.get('children'), graph_stream, watchpoint, graph_name) | |||||
| graph.get('children'), graph_stream, watchpoint, graph_name, rank_id=rank_id) | |||||
| if graph.get('nodes'): | if graph.get('nodes'): | ||||
| _ = self._set_watch_state_for_nodes(graph['nodes'], graph_stream, watchpoint, graph_name) | |||||
| _ = self._set_watch_state_for_nodes(graph['nodes'], graph_stream, watchpoint, graph_name, rank_id) | |||||
| def _set_watch_state_for_nodes(self, nodes, graph_stream, watchpoint, graph_name): | |||||
| def _set_watch_state_for_nodes(self, nodes, graph_stream, watchpoint, graph_name, rank_id=0): | |||||
| """ | """ | ||||
| Set watch state for nodes. | Set watch state for nodes. | ||||
| @@ -204,11 +209,11 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| node_name = node.get('name') | node_name = node.get('name') | ||||
| # search result could have `nodes` in nodes object | # search result could have `nodes` in nodes object | ||||
| if node.get('nodes'): | if node.get('nodes'): | ||||
| flag = self._set_watch_state_for_nodes(node.get('nodes'), graph_stream, watchpoint, graph_name) | |||||
| flag = self._set_watch_state_for_nodes(node.get('nodes'), graph_stream, watchpoint, graph_name, rank_id) | |||||
| else: | else: | ||||
| full_name = graph_stream.get_full_name(node_name, graph_name) | full_name = graph_stream.get_full_name(node_name, graph_name) | ||||
| new_node_name = node_name if graph_name is None else '/'.join([graph_name, node_name]) | new_node_name = node_name if graph_name is None else '/'.join([graph_name, node_name]) | ||||
| flag = watchpoint.get_node_status(new_node_name, node.get('type'), full_name) | |||||
| flag = watchpoint.get_node_status(new_node_name, node.get('type'), full_name, rank_id) | |||||
| node['watched'] = flag | node['watched'] = flag | ||||
| if flag == WatchNodeTree.NOT_WATCH: | if flag == WatchNodeTree.NOT_WATCH: | ||||
| continue | continue | ||||
| @@ -224,7 +229,8 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| state = WatchNodeTree.TOTAL_WATCH | state = WatchNodeTree.TOTAL_WATCH | ||||
| return state | return state | ||||
| def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None, name=None): | |||||
| def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None, name=None, | |||||
| device_amount=8): | |||||
| """ | """ | ||||
| Create watchpoint. | Create watchpoint. | ||||
| Args: | Args: | ||||
| @@ -241,9 +247,10 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| } | } | ||||
| - id (str): Id of condition. | - id (str): Id of condition. | ||||
| - param (list[dict]): The list of param for this condition. | - param (list[dict]): The list of param for this condition. | ||||
| watch_nodes (list[NodeBasicInfo]): The list of node basic info. | |||||
| watch_nodes (dict[int, list[NodeBasicInfo]]): The dict mapping rank id to a list of node basic info. | |||||
| watch_point_id (int): The id of watchpoint. | watch_point_id (int): The id of watchpoint. | ||||
| name (str): The name of watchpoint. | name (str): The name of watchpoint. | ||||
| device_amount (int): The amount of devices. | |||||
| Returns: | Returns: | ||||
| int, the new id of watchpoint. | int, the new id of watchpoint. | ||||
| @@ -253,7 +260,9 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| new_id = self._latest_id + 1 | new_id = self._latest_id + 1 | ||||
| watchpoint = Watchpoint(new_id, watch_condition, name) | watchpoint = Watchpoint(new_id, watch_condition, name) | ||||
| if watch_nodes: | if watch_nodes: | ||||
| watchpoint.add_nodes(watch_nodes) | |||||
| for rank_id, watch_nodes_for_device in watch_nodes.items(): | |||||
| validate_rank_id(rank_id, device_amount) | |||||
| watchpoint.add_nodes(watch_nodes_for_device, rank_id) | |||||
| elif watch_point_id: | elif watch_point_id: | ||||
| self.validate_watchpoint_id(watch_point_id) | self.validate_watchpoint_id(watch_point_id) | ||||
| watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id)) | watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id)) | ||||
| @@ -261,7 +270,7 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| self._outdated = True | self._outdated = True | ||||
| return new_id | return new_id | ||||
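A runnable sketch of the new `watch_nodes` contract for `create_watchpoint` — a dict keyed by rank id instead of a flat list — with `NodeBasicInfo` stubbed as a namedtuple (an assumption; only the mapping shape and the rank-id bound come from the diff):

```python
from collections import namedtuple

NodeBasicInfo = namedtuple('NodeBasicInfo', ['name', 'full_name', 'type'])

# hypothetical per-rank watch-node mapping
watch_nodes = {
    0: [NodeBasicInfo('Default/conv1', 'Default/conv1', 'Conv2D')],
    1: [NodeBasicInfo('Default/conv1', 'Default/conv1', 'Conv2D')],
}

# mirrors validate_rank_id: every rank id must stay below device_amount
device_amount = 8
assert all(rank_id < device_amount for rank_id in watch_nodes)
```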
| def update_watchpoint(self, watch_point_id, watch_nodes, watched=False): | |||||
| def update_watchpoint(self, watch_point_id, watch_nodes, watched=False, rank_id=0): | |||||
| """ | """ | ||||
| Update watchpoint. | Update watchpoint. | ||||
| @@ -270,13 +279,14 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| watch_nodes (list[NodeBasicInfo]): The list of node basic info. | watch_nodes (list[NodeBasicInfo]): The list of node basic info. | ||||
| watched (bool): The update operator on nodes. If False, remove nodes from watch nodes. | watched (bool): The update operator on nodes. If False, remove nodes from watch nodes. | ||||
| If True, add nodes to watch nodes. Default: False. | If True, add nodes to watch nodes. Default: False. | ||||
| rank_id (int): The rank id. | |||||
| """ | """ | ||||
| self.validate_watchpoint_id(watch_point_id) | self.validate_watchpoint_id(watch_point_id) | ||||
| watchpoint = self._watchpoints.get(watch_point_id) | watchpoint = self._watchpoints.get(watch_point_id) | ||||
| if watched: | if watched: | ||||
| watchpoint.add_nodes(watch_nodes) | |||||
| watchpoint.add_nodes(watch_nodes, rank_id) | |||||
| else: | else: | ||||
| watchpoint.remove_nodes(watch_nodes) | |||||
| watchpoint.remove_nodes(watch_nodes, rank_id) | |||||
| self._updated_watchpoints[watch_point_id] = watchpoint | self._updated_watchpoints[watch_point_id] = watchpoint | ||||
| self._outdated = True | self._outdated = True | ||||
| log.debug("Update watchpoint %d in cache.", watch_point_id) | log.debug("Update watchpoint %d in cache.", watch_point_id) | ||||
| @@ -328,6 +338,58 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id)) | raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id)) | ||||
| class MultiCardWatchpointHitHandler: | |||||
| """Multi-card Watchpoint-hit Handler.""" | |||||
| def __init__(self): | |||||
| self.watchpoint_hit_handlers = {0: WatchpointHitHandler()} | |||||
| def get_hit_handler_by_rank_id(self, rank_id=0): | |||||
| """Get handler by rank id.""" | |||||
| if rank_id in self.watchpoint_hit_handlers: | |||||
| return self.watchpoint_hit_handlers.get(rank_id) | |||||
| log.error("There is no rank id %d.", rank_id) | |||||
| raise ValueError("There is no rank id {}.".format(rank_id)) | |||||
| def put(self, value): | |||||
| """Put watchpoint hit into cache.""" | |||||
| for rank_id, tensor_hit_values in value.items(): | |||||
| if rank_id not in self.watchpoint_hit_handlers: | |||||
| self.watchpoint_hit_handlers[rank_id] = WatchpointHitHandler() | |||||
| cur_hit_handler = self.watchpoint_hit_handlers[rank_id] | |||||
| for tensor_hit_value in tensor_hit_values: | |||||
| cur_hit_handler.put(tensor_hit_value) | |||||
| def get(self, filter_condition=None, rank_id=0): | |||||
| """Get the graph of specific node for specific device.""" | |||||
| if rank_id in self.watchpoint_hit_handlers: | |||||
| return self.watchpoint_hit_handlers.get(rank_id).get(filter_condition) | |||||
| log.error("There is no rank id %d.", rank_id) | |||||
| raise ValueError("There is no rank id {}.".format(rank_id)) | |||||
| def update_tensor_history(self, tensor_history, rank_id): | |||||
| """ | |||||
| Add hit flag to tensor history. | |||||
| Args: | |||||
| tensor_history (dict): The tensor history. | |||||
| rank_id (int): The rank id. | |||||
| """ | |||||
| if rank_id in self.watchpoint_hit_handlers: | |||||
| self.watchpoint_hit_handlers[rank_id].update_tensor_history(tensor_history) | |||||
| else: | |||||
| for tensor_info in tensor_history.get('tensor_history'): | |||||
| tensor_info['is_hit'] = False | |||||
| def check_rank_id(self, rank_id): | |||||
| """check if has the rank id.""" | |||||
| return rank_id in self.watchpoint_hit_handlers | |||||
| def clean(self): | |||||
| """Clean cache.""" | |||||
| self.__init__() | |||||
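The payload consumed by `MultiCardWatchpointHitHandler.put` groups hits by rank id, and per-rank handlers are created lazily on first use. A hypothetical payload (the hit-value fields are placeholders):

```python
# rank id -> list of tensor hit values for that device
value = {
    0: [{'tensor_name': 'Default/conv1:0', 'watchpoint_id': 1}],
    2: [{'tensor_name': 'Default/fc1:0', 'watchpoint_id': 1}],
}
# after put(value), rank 2 gets its own WatchpointHitHandler even though
# only rank 0 existed at construction time
```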
| class WatchpointHitHandler(StreamHandlerBase): | class WatchpointHitHandler(StreamHandlerBase): | ||||
| """Watchpoint hit handler.""" | """Watchpoint hit handler.""" | ||||
| @@ -743,3 +805,9 @@ def _get_error_list(error_code): | |||||
| error_list.append(error_str) | error_list.append(error_str) | ||||
| return error_list | return error_list | ||||
| def validate_rank_id(rank_id, device_amount): | |||||
| """validate rank id""" | |||||
| if rank_id >= device_amount: | |||||
| log.debug("The rank id %d over device amount.", rank_id) | |||||
| @@ -23,17 +23,19 @@ class TensorDetailInfo: | |||||
| def __init__(self, cache): | def __init__(self, cache): | ||||
| self._put_command = cache.put_command | self._put_command = cache.put_command | ||||
| self._tensor_stream = cache.get_stream_handler(Streams.TENSOR) | |||||
| self._graph_stream = cache.get_stream_handler(Streams.GRAPH) | |||||
| self._hit_stream = cache.get_stream_handler(Streams.WATCHPOINT_HIT) | |||||
| self._metadata_stream = cache.get_stream_handler(Streams.METADATA) | |||||
| self._multi_card_tensor_stream = cache.get_stream_handler(Streams.TENSOR) | |||||
| self._multi_card_graph_stream = cache.get_stream_handler(Streams.GRAPH) | |||||
| self._multi_card_hit_stream = cache.get_stream_handler(Streams.WATCHPOINT_HIT) | |||||
| def validate_tensor_name(self, tensor_name, graph_name): | |||||
| def validate_tensor_name(self, tensor_name, graph_name, rank_id): | |||||
| """ | """ | ||||
| Validate the tensor name and check that its node exists in the graph. | Validate the tensor name and check that its node exists in the graph. | ||||
| Args: | Args: | ||||
| tensor_name (str): The tensor name on UI. | tensor_name (str): The tensor name on UI. | ||||
| graph_name (str): The graph name. | graph_name (str): The graph name. | ||||
| rank_id (int): The rank id. | |||||
| """ | """ | ||||
| # validate tensor name format | # validate tensor name format | ||||
| if not isinstance(tensor_name, str) or ':' not in tensor_name: | if not isinstance(tensor_name, str) or ':' not in tensor_name: | ||||
| @@ -41,15 +43,17 @@ class TensorDetailInfo: | |||||
| raise DebuggerParamValueError("Invalid tensor name.") | raise DebuggerParamValueError("Invalid tensor name.") | ||||
| node_name, _ = tensor_name.rsplit(':', 1) | node_name, _ = tensor_name.rsplit(':', 1) | ||||
| # check if the node name is in graph | # check if the node name is in graph | ||||
| self._graph_stream.validate_node_name(node_name=node_name, graph_name=graph_name) | |||||
| self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).validate_node_name(node_name=node_name, | |||||
| graph_name=graph_name) | |||||
| def get_tensor_graph(self, tensor_name, graph_name): | |||||
| def get_tensor_graph(self, tensor_name, graph_name, rank_id=0): | |||||
| """ | """ | ||||
| Get the graph related to specific tensor. | Get the graph related to specific tensor. | ||||
| Args: | Args: | ||||
| tensor_name (str): The ui name of tensor. Format like {node_name}:{slot}. | tensor_name (str): The ui name of tensor. Format like {node_name}:{slot}. | ||||
| graph_name (str): The graph name. | graph_name (str): The graph name. | ||||
| rank_id (int): The rank id. | |||||
| Returns: | Returns: | ||||
| dict, tensor graph, format is {'nodes': [Node object]}. | dict, tensor graph, format is {'nodes': [Node object]}. | ||||
| @@ -68,8 +72,9 @@ class TensorDetailInfo: | |||||
| 'slot_mapping': list[pair<slot, slot>], | 'slot_mapping': list[pair<slot, slot>], | ||||
| }. | }. | ||||
| """ | """ | ||||
| self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name) | |||||
| graph = self._graph_stream.get_tensor_graph(tensor_name, graph_name) | |||||
| self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name, rank_id=rank_id) | |||||
| graph = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_tensor_graph(tensor_name, | |||||
| graph_name) | |||||
| # add watchpoint hits info and statistics info for each tensor in tensor graph. | # add watchpoint hits info and statistics info for each tensor in tensor graph. | ||||
| # record missing tensor basic info | # record missing tensor basic info | ||||
| nodes = graph.get('graph', {}).get('nodes', []) | nodes = graph.get('graph', {}).get('nodes', []) | ||||
| @@ -77,13 +82,13 @@ class TensorDetailInfo: | |||||
| for node in nodes: | for node in nodes: | ||||
| node['graph_name'] = graph_name | node['graph_name'] = graph_name | ||||
| for slot_info in node.get('slots', []): | for slot_info in node.get('slots', []): | ||||
| self._add_watchpoint_hit_info(slot_info, node, graph_name) | |||||
| self._add_tensor_info(slot_info, node, missing_tensors) | |||||
| self._add_watchpoint_hit_info(slot_info, node, graph_name, rank_id) | |||||
| self._add_tensor_info(slot_info, node, missing_tensors, rank_id) | |||||
| # query missing tensor values from client | # query missing tensor values from client | ||||
| self._ask_for_missing_tensor_value(missing_tensors, tensor_name, graph_name) | self._ask_for_missing_tensor_value(missing_tensors, tensor_name, graph_name) | ||||
| return graph | return graph | ||||
| def _add_watchpoint_hit_info(self, slot_info, node, graph_name): | |||||
| def _add_watchpoint_hit_info(self, slot_info, node, graph_name, rank_id): | |||||
| """ | """ | ||||
| Add watchpoint hit info for the tensor. | Add watchpoint hit info for the tensor. | ||||
| @@ -93,9 +98,12 @@ class TensorDetailInfo: | |||||
| graph_name (str): Graph name. | graph_name (str): Graph name. | ||||
| rank_id (int): The rank id. | |||||
| """ | """ | ||||
| tensor_name = ':'.join([node.get('name'), slot_info.get('slot')]) | tensor_name = ':'.join([node.get('name'), slot_info.get('slot')]) | ||||
| slot_info.update(self._hit_stream.get_tensor_hit_infos(tensor_name, graph_name)) | |||||
| if self._multi_card_hit_stream.check_rank_id(rank_id=rank_id): | |||||
| slot_info.update( | |||||
| self._multi_card_hit_stream.get_hit_handler_by_rank_id(rank_id).get_tensor_hit_infos(tensor_name, | |||||
| graph_name)) | |||||
| def _add_tensor_info(self, slot_info, node, missing_tensors): | |||||
| def _add_tensor_info(self, slot_info, node, missing_tensors, rank_id): | |||||
| """ | """ | ||||
| Add the tensor info and query for missed tensors. | Add the tensor info and query for missed tensors. | ||||
| @@ -106,7 +114,8 @@ class TensorDetailInfo: | |||||
| """ | """ | ||||
| tensor_name = ':'.join([node.get('full_name'), slot_info.get('slot')]) | tensor_name = ':'.join([node.get('full_name'), slot_info.get('slot')]) | ||||
| node_type = node.get('type') | node_type = node.get('type') | ||||
| tensor_info, cur_missing_tensors = self._tensor_stream.get_tensor_info_for_tensor_graph(tensor_name, node_type) | |||||
| tensor_info, cur_missing_tensors = self._multi_card_tensor_stream.get_tensor_handler_by_rank_id( | |||||
| rank_id).get_tensor_info_for_tensor_graph(tensor_name, node_type, self._metadata_stream.step) | |||||
| slot_info.update(tensor_info) | slot_info.update(tensor_info) | ||||
| if cur_missing_tensors: | if cur_missing_tensors: | ||||
| log.debug("Get missing tensor basic infos for %s", tensor_name) | log.debug("Get missing tensor basic infos for %s", tensor_name) | ||||
| @@ -128,20 +137,24 @@ class TensorDetailInfo: | |||||
| self._put_command({'view_cmd': view_cmd, 'tensor_name': tensor_name, 'graph_name': graph_name}) | self._put_command({'view_cmd': view_cmd, 'tensor_name': tensor_name, 'graph_name': graph_name}) | ||||
| log.debug("Send view cmd for tensor-graphs.") | log.debug("Send view cmd for tensor-graphs.") | ||||
| def get_tensor_watch_points(self, tensor_name, graph_name): | |||||
| def get_tensor_watch_points(self, tensor_name, graph_name, rank_id=0): | |||||
| """ | """ | ||||
| Get all watchpoints that the tensor hit. | Get all watchpoints that the tensor hit. | ||||
| Args: | Args: | ||||
| tensor_name (str): Tensor name from UI. | tensor_name (str): Tensor name from UI. | ||||
| graph_name (str): The graph name. | graph_name (str): The graph name. | ||||
| rank_id (int): The rank id. | |||||
| Returns: | Returns: | ||||
| list, watchpoint hit infos. | list, watchpoint hit infos. | ||||
| """ | """ | ||||
| # validate tensor_name | # validate tensor_name | ||||
| self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name) | |||||
| self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name, rank_id=rank_id) | |||||
| # get watchpoint info that the tensor hit | # get watchpoint info that the tensor hit | ||||
| tensor_hit_info = self._hit_stream.get_tensor_hit_infos(tensor_name, graph_name) | |||||
| if not self._multi_card_hit_stream.check_rank_id(rank_id=rank_id): | |||||
| return [] | |||||
| tensor_hit_info = self._multi_card_hit_stream.get_hit_handler_by_rank_id(rank_id).get_tensor_hit_infos( | |||||
| tensor_name, graph_name) | |||||
| watch_points = tensor_hit_info.get('watch_points', []) | watch_points = tensor_hit_info.get('watch_points', []) | ||||
| return watch_points | return watch_points | ||||
| @@ -18,7 +18,8 @@ import enum | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerContinueError, DebuggerParamValueError, \ | from mindinsight.debugger.common.exceptions.exceptions import DebuggerContinueError, DebuggerParamValueError, \ | ||||
| DebuggerPauseError, DebuggerRecheckError, DebuggerStepNumError | DebuggerPauseError, DebuggerRecheckError, DebuggerStepNumError | ||||
| from mindinsight.debugger.common.log import LOGGER as log | from mindinsight.debugger.common.log import LOGGER as log | ||||
| from mindinsight.debugger.common.utils import Streams, get_ack_reply, ServerStatus, RunLevel, is_scope_type | |||||
| from mindinsight.debugger.common.utils import Streams, get_ack_reply, ServerStatus, RunLevel, is_scope_type, \ | |||||
| DebuggerServerMode | |||||
| from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD | from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD | ||||
| from mindinsight.utils.exceptions import MindInsightException | from mindinsight.utils.exceptions import MindInsightException | ||||
| @@ -29,6 +30,7 @@ class ControlTypeEnum(enum.Enum): | |||||
| CONTINUE = 'continue' # continue to run training | CONTINUE = 'continue' # continue to run training | ||||
| PAUSE = 'pause' # suspend training | PAUSE = 'pause' # suspend training | ||||
| TERMINATE = 'terminate' # terminate training | TERMINATE = 'terminate' # terminate training | ||||
| RESET = 'reset' # reset the step_id in offline debugger | |||||
| class TrainingControlOperator: | class TrainingControlOperator: | ||||
| @@ -39,7 +41,7 @@ class TrainingControlOperator: | |||||
| def __init__(self, cache_store): | def __init__(self, cache_store): | ||||
| self._cache_store = cache_store | self._cache_store = cache_store | ||||
| self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT) | self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT) | ||||
| self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH) | |||||
| self._multi_card_graph_stream = cache_store.get_stream_handler(Streams.GRAPH) | |||||
| self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) | self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) | ||||
| @staticmethod | @staticmethod | ||||
| @@ -71,6 +73,9 @@ class TrainingControlOperator: | |||||
| """ | """ | ||||
| if mode == ControlTypeEnum.CONTINUE.value: | if mode == ControlTypeEnum.CONTINUE.value: | ||||
| reply = self.continue_training(params) | reply = self.continue_training(params) | ||||
| elif mode == ControlTypeEnum.RESET.value: | |||||
| step_id = params['steps'] | |||||
| reply = self.reset_training_step(step_id) | |||||
| else: | else: | ||||
| mode_mapping = { | mode_mapping = { | ||||
| ControlTypeEnum.PAUSE.value: self.pause_training, | ControlTypeEnum.PAUSE.value: self.pause_training, | ||||
| @@ -150,13 +155,15 @@ class TrainingControlOperator: | |||||
| if level == RunLevel.NODE.value: | if level == RunLevel.NODE.value: | ||||
| node_name = params.get('name') | node_name = params.get('name') | ||||
| graph_name = params.get('graph_name') | graph_name = params.get('graph_name') | ||||
| self._validate_continue_node_name(node_name, graph_name) | |||||
| rank_id = params.get('rank_id', 0) | |||||
| self._validate_continue_node_name(node_name, graph_name, rank_id) | |||||
| def _validate_continue_node_name(self, node_name, graph_name): | |||||
| def _validate_continue_node_name(self, node_name, graph_name, rank_id): | |||||
| """Validate if the node is a leaf node.""" | """Validate if the node is a leaf node.""" | ||||
| if not node_name: | if not node_name: | ||||
| return | return | ||||
| node_type = self._graph_stream.get_node_type(node_name, graph_name) | |||||
| node_type = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_node_type(node_name, | |||||
| graph_name) | |||||
| if is_scope_type(node_type): | if is_scope_type(node_type): | ||||
| log.error("Scope type node has no tensor history.") | log.error("Scope type node has no tensor history.") | ||||
| raise DebuggerParamValueError("Invalid leaf node name.") | raise DebuggerParamValueError("Invalid leaf node name.") | ||||
| @@ -188,7 +195,9 @@ class TrainingControlOperator: | |||||
| name = params.get('name', '') | name = params.get('name', '') | ||||
| graph_name = params.get('graph_name') | graph_name = params.get('graph_name') | ||||
| if name: | if name: | ||||
| name = self._cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name, graph_name) | |||||
| rank_id = params.get('rank_id', 0) | |||||
| name = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_full_name(name, | |||||
| graph_name) | |||||
| run_cmd = RunCMD(run_level='node', node_name=name) | run_cmd = RunCMD(run_level='node', node_name=name) | ||||
| else: | else: | ||||
| run_cmd = RunCMD(run_level='recheck') | run_cmd = RunCMD(run_level='recheck') | ||||
| @@ -199,7 +208,7 @@ class TrainingControlOperator: | |||||
| def _send_watchpoints(self): | def _send_watchpoints(self): | ||||
| """Send watchpoints to client.""" | """Send watchpoints to client.""" | ||||
| set_commands = self._watchpoint_stream.get_pending_commands(self._graph_stream) | |||||
| set_commands = self._watchpoint_stream.get_pending_commands(self._multi_card_graph_stream) | |||||
| if not set_commands: | if not set_commands: | ||||
| return | return | ||||
| for set_cmd in set_commands: | for set_cmd in set_commands: | ||||
| @@ -274,3 +283,30 @@ class TrainingControlOperator: | |||||
| else: | else: | ||||
| log.debug("Send the recheck to command queue.") | log.debug("Send the recheck to command queue.") | ||||
| return metadata_stream.get(['state', 'enable_recheck']) | return metadata_stream.get(['state', 'enable_recheck']) | ||||
| def reset_training_step(self, step_id): | |||||
| """ | |||||
| Reset the training step. | |||||
| Args: | |||||
| step_id (int): The target step_id. | |||||
| Returns: | |||||
| dict, metadata info. | |||||
| """ | |||||
| metadata_stream = self._metadata_stream | |||||
| if metadata_stream.debugger_type == DebuggerServerMode.ONLINE.value: | |||||
| log.error("'step_id' can not be changed manually in online debugger.") | |||||
| return metadata_stream.get(['state', 'enable_recheck', 'step']) | |||||
| if step_id > metadata_stream.max_step_num: | |||||
| log.error("Invalid step_id, step_id should be less than %d.", metadata_stream.max_step_num) | |||||
| raise DebuggerParamValueError("Invalid step_id.") | |||||
| metadata_stream.state = ServerStatus.SENDING.value | |||||
| metadata_stream.step = step_id | |||||
| self._cache_store.get_stream_handler(Streams.TENSOR).set_step(step_id) | |||||
| self._cache_store.clean_data() | |||||
| self._cache_store.clean_command() | |||||
| metadata_stream.enable_recheck = False | |||||
| metadata_stream.state = ServerStatus.WAITING.value | |||||
| log.debug("Send the Change_training_step CMD.") | |||||
| return metadata_stream.get(['state', 'enable_recheck', 'step']) | |||||
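Request shape note: the RESET control type added above is only valid for the offline debugger and carries the target step in `params['steps']`. A hypothetical request body for the session-scoped control route shown later in this diff:

```python
# Hypothetical body for POST .../debugger/sessions/{session_id}/control
control_request = {
    'mode': 'reset',          # ControlTypeEnum.RESET.value
    'params': {'steps': 10},  # target step id, must not exceed max_step_num
}
```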
| @@ -31,8 +31,9 @@ class WatchpointOperator: | |||||
| def __init__(self, cache_store, condition_mgr): | def __init__(self, cache_store, condition_mgr): | ||||
| self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT) | self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT) | ||||
| self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH) | |||||
| self._multi_card_graph_stream = cache_store.get_stream_handler(Streams.GRAPH) | |||||
| self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) | self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) | ||||
| self._device_stream = cache_store.get_stream_handler(Streams.DEVICE) | |||||
| self._condition_mgr = condition_mgr | self._condition_mgr = condition_mgr | ||||
| def create_watchpoint(self, params): | def create_watchpoint(self, params): | ||||
| @@ -70,11 +71,6 @@ class WatchpointOperator: | |||||
| "Failed to create watchpoint as the MindSpore is not in waiting state.") | "Failed to create watchpoint as the MindSpore is not in waiting state.") | ||||
| self._validate_watch_condition(watch_condition) | self._validate_watch_condition(watch_condition) | ||||
| watch_nodes = self._get_watch_node_with_basic_info( | |||||
| node_names=params.get('watch_nodes'), | |||||
| search_pattern=params.get('search_pattern'), | |||||
| graph_name=params.get('graph_name')) | |||||
| validate_watch_condition(self._condition_mgr, watch_condition) | validate_watch_condition(self._condition_mgr, watch_condition) | ||||
| condition_id = watch_condition.get('id') | condition_id = watch_condition.get('id') | ||||
| condition = self._condition_mgr.get_condition(condition_id) | condition = self._condition_mgr.get_condition(condition_id) | ||||
| @@ -84,10 +80,11 @@ class WatchpointOperator: | |||||
| raise DebuggerConditionUnavailableError( | raise DebuggerConditionUnavailableError( | ||||
| "Failed to create watchpoint as the condition is not available.") | "Failed to create watchpoint as the condition is not available.") | ||||
| watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._graph_stream).copy() | |||||
| watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._multi_card_graph_stream) | |||||
| watchpoint_stream = self._watchpoint_stream | watchpoint_stream = self._watchpoint_stream | ||||
| watch_point_id = watchpoint_stream.create_watchpoint( | |||||
| self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id')) | |||||
| watch_point_id = watchpoint_stream.create_watchpoint(self._condition_mgr, watch_condition, watch_nodes, | |||||
| params.get('watch_point_id'), | |||||
| self._device_stream.device_amount) | |||||
| log.info("Create watchpoint %d", watch_point_id) | log.info("Create watchpoint %d", watch_point_id) | ||||
| metadata_stream.enable_recheck = watchpoint_stream.is_recheckable() | metadata_stream.enable_recheck = watchpoint_stream.is_recheckable() | ||||
| @@ -115,6 +112,7 @@ class WatchpointOperator: | |||||
| 1 for add nodes to watch nodes. | 1 for add nodes to watch nodes. | ||||
| - search_pattern (dict): The search pattern. | - search_pattern (dict): The search pattern. | ||||
| - graph_name (str): The relative graph_name of the watched node. | - graph_name (str): The relative graph_name of the watched node. | ||||
| - rank_id (int): The rank id. | |||||
| Returns: | Returns: | ||||
| dict, the metadata info. | dict, the metadata info. | ||||
| @@ -137,13 +135,14 @@ class WatchpointOperator: | |||||
| watch_nodes = self._get_watch_node_with_basic_info( | watch_nodes = self._get_watch_node_with_basic_info( | ||||
| node_names=params.get('watch_nodes'), | node_names=params.get('watch_nodes'), | ||||
| search_pattern=params.get('search_pattern'), | search_pattern=params.get('search_pattern'), | ||||
| graph_name=params.get('graph_name')) | |||||
| watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, params.get('mode')) | |||||
| graph_name=params.get('graph_name'), | |||||
| rank_id=params.get('rank_id', 0)) | |||||
| watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, params.get('mode'), params.get('rank_id', 0)) | |||||
| metadata_stream.enable_recheck = watchpoint_stream.is_recheckable() | metadata_stream.enable_recheck = watchpoint_stream.is_recheckable() | ||||
| log.info("Update watchpoint with id: %d", watch_point_id) | log.info("Update watchpoint with id: %d", watch_point_id) | ||||
| return metadata_stream.get(['state', 'enable_recheck']) | return metadata_stream.get(['state', 'enable_recheck']) | ||||
| def _get_watch_node_with_basic_info(self, node_names, search_pattern=None, graph_name=None): | |||||
| def _get_watch_node_with_basic_info(self, node_names, search_pattern=None, graph_name=None, rank_id=0): | |||||
| """ | """ | ||||
| Get watch node with basic info. | Get watch node with basic info. | ||||
| @@ -151,20 +150,21 @@ class WatchpointOperator: | |||||
| node_names (list[str]): A list of node names. | node_names (list[str]): A list of node names. | ||||
| search_pattern (dict): Get watch node with search pattern. Default: None | search_pattern (dict): Get watch node with search pattern. Default: None | ||||
| graph_name (str): The relative graph_name of the watched node. Default: None. | graph_name (str): The relative graph_name of the watched node. Default: None. | ||||
| rank_id (int): The rank id. | |||||
| Returns: | Returns: | ||||
| list[NodeBasicInfo], a list of node basic infos. | list[NodeBasicInfo], a list of node basic infos. | ||||
| """ | """ | ||||
| if not node_names: | if not node_names: | ||||
| return [] | return [] | ||||
| graph_name = self._graph_stream.validate_graph_name(graph_name) | |||||
| graph_name = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).validate_graph_name(graph_name) | |||||
| if search_pattern is not None: | if search_pattern is not None: | ||||
| watch_nodes = self._get_watch_nodes_by_search(node_names, search_pattern, graph_name) | |||||
| watch_nodes = self._get_watch_nodes_by_search(node_names, search_pattern, graph_name, rank_id) | |||||
| else: | else: | ||||
| watch_nodes = self._get_node_basic_infos(node_names, graph_name=graph_name) | |||||
| watch_nodes = self._get_node_basic_infos(node_names, graph_name=graph_name, rank_id=rank_id) | |||||
| return watch_nodes | return watch_nodes | ||||
| def _get_watch_nodes_by_search(self, node_names, search_pattern, graph_name): | |||||
| def _get_watch_nodes_by_search(self, node_names, search_pattern, graph_name, rank_id): | |||||
| """ | """ | ||||
| Get watched leaf nodes by search name. | Get watched leaf nodes by search name. | ||||
| @@ -180,7 +180,7 @@ class WatchpointOperator: | |||||
| list[NodeBasicInfo], a list of node basic infos. | list[NodeBasicInfo], a list of node basic infos. | ||||
| """ | """ | ||||
| search_pattern['graph_name'] = graph_name | search_pattern['graph_name'] = graph_name | ||||
| search_nodes = self._graph_stream.search_nodes(search_pattern) | |||||
| search_nodes = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).search_nodes(search_pattern) | |||||
| watch_node_names = set() | watch_node_names = set() | ||||
| for name in node_names: | for name in node_names: | ||||
| names = self._get_watch_names_by_search(search_nodes, name) | names = self._get_watch_names_by_search(search_nodes, name) | ||||
| @@ -260,7 +260,7 @@ class WatchpointOperator: | |||||
| log.info("Delete watchpoint with id: %s", watch_point_id) | log.info("Delete watchpoint with id: %s", watch_point_id) | ||||
| return metadata_stream.get(['state', 'enable_recheck']) | return metadata_stream.get(['state', 'enable_recheck']) | ||||
| def _get_node_basic_infos(self, node_names, graph_name=None): | |||||
| def _get_node_basic_infos(self, node_names, graph_name=None, rank_id=0): | |||||
| """ | """ | ||||
| Get watch node info according to node names. | Get watch node info according to node names. | ||||
| @@ -273,7 +273,7 @@ class WatchpointOperator: | |||||
| """ | """ | ||||
| if not node_names: | if not node_names: | ||||
| return [] | return [] | ||||
| graph_stream = self._graph_stream | |||||
| graph_stream = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id) | |||||
| node_infos = [] | node_infos = [] | ||||
| for node_name in node_names: | for node_name in node_names: | ||||
| node_info = graph_stream.get_node_basic_info(node_name, graph_name) | node_info = graph_stream.get_node_basic_info(node_name, graph_name) | ||||
| @@ -26,7 +26,7 @@ limitations under the License. | |||||
| <div class="cl-center" | <div class="cl-center" | ||||
| :class="showWarmText ? 'cl-center-height' : ''"> | :class="showWarmText ? 'cl-center-height' : ''"> | ||||
| <router-view></router-view> | |||||
| <router-view :key="$route.fullPath"></router-view> | |||||
| </div> | </div> | ||||
| </div> | </div> | ||||
| </template> | </template> | ||||
| @@ -362,8 +362,9 @@ export default { | |||||
| const params = { | const params = { | ||||
| tensor_name: this.curRowObj.name, | tensor_name: this.curRowObj.name, | ||||
| graph_name: this.curRowObj.graph_name, | graph_name: this.curRowObj.graph_name, | ||||
| rank_id: this.curRowObj.rank_id, | |||||
| }; | }; | ||||
| RequestService.getTensorGraphData(params).then( | |||||
| RequestService.getTensorGraphData(params, this.curRowObj.sessionId).then( | |||||
| (res) => { | (res) => { | ||||
| if (res && res.data && res.data.graph && res.data.graph.nodes && res.data.graph.nodes.length) { | if (res && res.data && res.data.graph && res.data.graph.nodes && res.data.graph.nodes.length) { | ||||
| this.graphShow = true; | this.graphShow = true; | ||||
| @@ -419,8 +420,9 @@ export default { | |||||
| const params = { | const params = { | ||||
| tensor_name: this.curRowObj.name, | tensor_name: this.curRowObj.name, | ||||
| graph_name: this.curRowObj.graph_name, | graph_name: this.curRowObj.graph_name, | ||||
| rank_id: this.curRowObj.rank_id, | |||||
| }; | }; | ||||
| RequestService.tensorHitsData(params).then( | |||||
| RequestService.tensorHitsData(params, this.curRowObj.sessionId).then( | |||||
| (res) => { | (res) => { | ||||
| if (res && res.data && res.data.watch_points && res.data.watch_points.length) { | if (res && res.data && res.data.watch_points && res.data.watch_points.length) { | ||||
| this.leftDataShow = true; | this.leftDataShow = true; | ||||
| @@ -995,11 +997,12 @@ export default { | |||||
| shape: encodeURIComponent(shape), | shape: encodeURIComponent(shape), | ||||
| tolerance: this.tolerance / 100, | tolerance: this.tolerance / 100, | ||||
| graph_name: row.graph_name, | graph_name: row.graph_name, | ||||
| rank_id: row.rank_id, | |||||
| }; | }; | ||||
| if (loadingFlag) { | if (loadingFlag) { | ||||
| this.loadingInstance = this.$loading(this.loadingOption); | this.loadingInstance = this.$loading(this.loadingOption); | ||||
| } | } | ||||
| RequestService.tensorComparisons(params).then( | |||||
| RequestService.tensorComparisons(params, row.sessionId).then( | |||||
| (res) => { | (res) => { | ||||
| if (res && res.data && res.data.tensor_value) { | if (res && res.data && res.data.tensor_value) { | ||||
| if (row.shape === '[]') { | if (row.shape === '[]') { | ||||
| @@ -1088,11 +1091,12 @@ export default { | |||||
| shape: encodeURIComponent(shape), | shape: encodeURIComponent(shape), | ||||
| graph_name: row.graph_name, | graph_name: row.graph_name, | ||||
| prev: this.gridType === 'preStep' ? true : false, | prev: this.gridType === 'preStep' ? true : false, | ||||
| rank_id: row.rank_id, | |||||
| }; | }; | ||||
| if (loadingFlag) { | if (loadingFlag) { | ||||
| this.loadingInstance = this.$loading(this.loadingOption); | this.loadingInstance = this.$loading(this.loadingOption); | ||||
| } | } | ||||
| RequestService.tensors(params).then( | |||||
| RequestService.tensors(params, row.sessionId).then( | |||||
| (res) => { | (res) => { | ||||
| if (row.shape === '[]') { | if (row.shape === '[]') { | ||||
| this.showFilterInput = false; | this.showFilterInput = false; | ||||
| @@ -24,7 +24,9 @@ | |||||
| "dataLoading": "Loading data...", | "dataLoading": "Loading data...", | ||||
| "notice": "Information", | "notice": "Information", | ||||
| "caseMode": "Not case sensitive", | "caseMode": "Not case sensitive", | ||||
| "all": "All" | |||||
| "all": "All", | |||||
| "details": "Details", | |||||
| "delete": "Delete" | |||||
| }, | }, | ||||
| "symbols": { | "symbols": { | ||||
| "leftbracket": "(", | "leftbracket": "(", | ||||
| @@ -52,12 +54,14 @@ | |||||
| "operation": "Operation", | "operation": "Operation", | ||||
| "viewDashboard": "Training Dashboard", | "viewDashboard": "Training Dashboard", | ||||
| "viewProfiler": "Profiling", | "viewProfiler": "Profiling", | ||||
| "viewOfflineDebugger": "Offline Debugger", | |||||
| "modelTraceback": "Model Lineage", | "modelTraceback": "Model Lineage", | ||||
| "dataTraceback": "Dataset Lineage", | "dataTraceback": "Dataset Lineage", | ||||
| "comparePlate": "Comparison Dashboard", | "comparePlate": "Comparison Dashboard", | ||||
| "disableProfilerTip": "Failed to view profiling because no profiler log is available.", | "disableProfilerTip": "Failed to view profiling because no profiler log is available.", | ||||
| "disableDashboardTip": "Failed to view training dashboard because no summary log or pb files are available.", | "disableDashboardTip": "Failed to view training dashboard because no summary log or pb files are available.", | ||||
| "disableParameterTip": "Failed to view parameter details because no lineage log is available.", | "disableParameterTip": "Failed to view parameter details because no lineage log is available.", | ||||
| "disableOfflineDebugger": "Failed to view offline debugger because no debugger log is available.", | |||||
| "openNewTab": "Open Link in New Tab", | "openNewTab": "Open Link in New Tab", | ||||
| "paramDetails": "Parameter Details", | "paramDetails": "Parameter Details", | ||||
| "trainingParamDetails": "Training Parameter Details", | "trainingParamDetails": "Training Parameter Details", | ||||
| @@ -80,7 +84,12 @@ | |||||
| "tensorUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#tensor-visualization", | "tensorUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#tensor-visualization", | ||||
| "graphUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#computational-graph-visualization", | "graphUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#computational-graph-visualization", | ||||
| "dataProcessUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#dataset-graph-visualization", | "dataProcessUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#dataset-graph-visualization", | ||||
| "imageUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#image-visualization" | |||||
| "imageUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#image-visualization", | |||||
| "sessionLimit": "The number of sessions of the offline debugger exceeds the number of online sessions", | |||||
| "sessionLimitNum": "At most 2 exist at the same time", | |||||
| "sessionLists": "List of currently existing sessions", | |||||
| "deleteSessionConfirm": "This operation will delete the current session, do you want to continue?", | |||||
| "deleteSessionSuccess": "Delete session successfully!" | |||||
| }, | }, | ||||
| "modelTraceback": { | "modelTraceback": { | ||||
| "summaryPath": "Summary Path", | "summaryPath": "Summary Path", | ||||
| @@ -561,7 +570,7 @@ | |||||
| "terminate": "TERMINATE", | "terminate": "TERMINATE", | ||||
| "selectCondition": "Select a condition", | "selectCondition": "Select a condition", | ||||
| "inputStep": "Enter a step value", | "inputStep": "Enter a step value", | ||||
| "inputTip": "A positive integer less than 2147483648", | |||||
| "inputTip": "A positive integer less than or equal to {total_step_num}", | |||||
| "curHitNode": "Watch Point Hit List", | "curHitNode": "Watch Point Hit List", | ||||
| "backstageStatus": "The backend running status is ", | "backstageStatus": "The backend running status is ", | ||||
| "view": "View", | "view": "View", | ||||
| @@ -830,7 +839,9 @@ | |||||
| "allPositive": "he parameter value must be greater than 0.", | "allPositive": "he parameter value must be greater than 0.", | ||||
| "watchOverflow": "The asynchronous full overflow watching function must be enabled before the training starts." | "watchOverflow": "The asynchronous full overflow watching function must be enabled before the training starts." | ||||
| }, | }, | ||||
| "paramValueTip": "Preset Value: {value}" | |||||
| "paramValueTip": "Preset Value: {value}", | |||||
| "logicCard": "Logic card", | |||||
| "inpStepTip": "Step:0~{total_step_num}" | |||||
| }, | }, | ||||
| "explain": { | "explain": { | ||||
| "explain": "Model Explanation", | "explain": "Model Explanation", | ||||
| @@ -952,6 +963,7 @@ | |||||
| "5054B183": "Backend training is in progress or has ended. Please try again later", | "5054B183": "Backend training is in progress or has ended. Please try again later", | ||||
| "5054B184": "The operation is too fast, the backend service has been suspended.", | "5054B184": "The operation is too fast, the backend service has been suspended.", | ||||
| "5054B189": "Do not set the value repeatedly.", | "5054B189": "Do not set the value repeatedly.", | ||||
| "5054B083": "Failed to create the watchpoint. Do not use invalid rules." | |||||
| "5054B083": "Failed to create the watchpoint. Do not use invalid rules.", | |||||
| "5054B202": "The debugger offline server module was not found" | |||||
| } | } | ||||
| } | } | ||||
| @@ -24,7 +24,9 @@ | |||||
| "dataLoading": "数据加载中", | "dataLoading": "数据加载中", | ||||
| "notice": "提示", | "notice": "提示", | ||||
| "caseMode": "不区分大小写", | "caseMode": "不区分大小写", | ||||
| "all": "全部" | |||||
| "all": "全部", | |||||
| "details": "详情", | |||||
| "delete": "删除" | |||||
| }, | }, | ||||
| "symbols": { | "symbols": { | ||||
| "leftbracket": "(", | "leftbracket": "(", | ||||
| @@ -52,12 +54,14 @@ | |||||
| "operation": "操作", | "operation": "操作", | ||||
| "viewDashboard": "训练看板", | "viewDashboard": "训练看板", | ||||
| "viewProfiler": "性能分析", | "viewProfiler": "性能分析", | ||||
| "viewOfflineDebugger": "离线调试器", | |||||
| "modelTraceback": "模型溯源", | "modelTraceback": "模型溯源", | ||||
| "dataTraceback": "数据溯源", | "dataTraceback": "数据溯源", | ||||
| "comparePlate": "对比看板", | "comparePlate": "对比看板", | ||||
| "disableProfilerTip": "无profiler日志,无法查看性能分析", | "disableProfilerTip": "无profiler日志,无法查看性能分析", | ||||
| "disableDashboardTip": "无summary日志或pb文件,无法查看训练看板", | "disableDashboardTip": "无summary日志或pb文件,无法查看训练看板", | ||||
| "disableParameterTip": "无lineage日志,无法查看参数详情", | "disableParameterTip": "无lineage日志,无法查看参数详情", | ||||
| "disableOfflineDebugger": "无Debugger日志,无法查看离线调试器", | |||||
| "openNewTab": "打开新页签", | "openNewTab": "打开新页签", | ||||
| "paramDetails": "参数详情", | "paramDetails": "参数详情", | ||||
| "trainingParamDetails": "训练参数详情", | "trainingParamDetails": "训练参数详情", | ||||
| @@ -80,7 +84,12 @@ | |||||
| "tensorUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id8", | "tensorUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id8", | ||||
| "graphUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id5", | "graphUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id5", | ||||
| "dataProcessUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id6", | "dataProcessUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id6", | ||||
| "imageUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id7" | |||||
| "imageUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id7", | |||||
| "sessionLimit": "离线调试器的session个数超过上线", | |||||
| "sessionLimitNum": "最多同时存在2个", | |||||
| "sessionLists": "目前存在的session列表", | |||||
| "deleteSessionConfirm": "此操作将删除当前session, 是否继续?", | |||||
| "deleteSessionSuccess": "删除session成功!" | |||||
| }, | }, | ||||
| "modelTraceback": { | "modelTraceback": { | ||||
| "summaryPath": "训练日志路径", | "summaryPath": "训练日志路径", | ||||
| @@ -560,7 +569,7 @@ | |||||
| "terminate": "结束", | "terminate": "结束", | ||||
| "selectCondition": "请选择条件", | "selectCondition": "请选择条件", | ||||
| "inputStep": "请输入轮次值", | "inputStep": "请输入轮次值", | ||||
| "inputTip": "小于2147483648的正整数", | |||||
| "inputTip": "小于等于{total_step_num}的正整数", | |||||
| "curHitNode": "命中的监测点", | "curHitNode": "命中的监测点", | ||||
| "backstageStatus": "后台运行状态是", | "backstageStatus": "后台运行状态是", | ||||
| "view": "查看", | "view": "查看", | ||||
| @@ -825,7 +834,9 @@ | |||||
| "allPositive": "此参数值必须大于0", | "allPositive": "此参数值必须大于0", | ||||
| "watchOverflow": "训练开始前需开启异步全量溢出监测功能" | "watchOverflow": "训练开始前需开启异步全量溢出监测功能" | ||||
| }, | }, | ||||
| "paramValueTip": "设置值为:{value}" | |||||
| "paramValueTip": "设置值为:{value}", | |||||
| "logicCard": "逻辑卡", | |||||
| "inpStepTip": "可输入当前轮次:0~{total_step_num}" | |||||
| }, | }, | ||||
| "explain": { | "explain": { | ||||
| "explain": "模型解释", | "explain": "模型解释", | ||||
| @@ -947,6 +958,7 @@ | |||||
| "5054B183": "后台训练运行中,请稍后重试", | "5054B183": "后台训练运行中,请稍后重试", | ||||
| "5054B184": "操作过快,后台服务已暂停。", | "5054B184": "操作过快,后台服务已暂停。", | ||||
| "5054B189": "请勿重复设置。", | "5054B189": "请勿重复设置。", | ||||
| "5054B083": "监测点创建失败,请勿使用已失效规则。" | |||||
| "5054B083": "监测点创建失败,请勿使用已失效规则。", | |||||
| "5054B202": "未找到调试器离线服务器模块" | |||||
| } | } | ||||
| } | } | ||||
| @@ -157,6 +157,10 @@ export default new Router({ | |||||
| path: '/debugger', | path: '/debugger', | ||||
| component: () => import('./views/debugger/debugger.vue'), | component: () => import('./views/debugger/debugger.vue'), | ||||
| }, | }, | ||||
| { | |||||
| path: '/offline-debugger', | |||||
| component: () => import('./views/debugger/debugger.vue'), | |||||
| }, | |||||
| { | { | ||||
| path: '/explain', | path: '/explain', | ||||
| component: () => import('./views/explain/summary-list.vue'), | component: () => import('./views/explain/summary-list.vue'), | ||||
| @@ -62,7 +62,14 @@ axios.interceptors.response.use( | |||||
| const errorData = i18n.messages[i18n.locale].error; | const errorData = i18n.messages[i18n.locale].error; | ||||
| const path = router.currentRoute.path; | const path = router.currentRoute.path; | ||||
| if (path === '/debugger') { | |||||
| if (path === '/debugger' || path === '/offline-debugger') { | |||||
| if ( | |||||
| error.response && | |||||
| error.response.data && | |||||
| error.response.data.error_code === '5054B281' | |||||
| ) { | |||||
| router.push('/'); | |||||
| } | |||||
| return Promise.reject(error); | return Promise.reject(error); | ||||
| } | } | ||||
| // error returned by backend | // error returned by backend | ||||
| @@ -309,55 +309,74 @@ export default { | |||||
| }); | }); | ||||
| }, | }, | ||||
| // debugger | // debugger | ||||
| pollData(params) { | |||||
| getSession(params) { | |||||
| return axios({ | |||||
| method: 'post', | |||||
| url: 'v1/mindinsight/debugger/sessions', | |||||
| data: params, | |||||
| }); | |||||
| }, | |||||
| deleteSession(sessionId) { | |||||
| return axios({ | |||||
| method: 'post', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/delete`, | |||||
| }); | |||||
| }, | |||||
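The session lifecycle implied by these routes can be sketched as a tiny Python client. This is hypothetical: the host/port, request body, and response shape are assumptions (run it only against a live MindInsight instance); only the endpoint paths come from the calls above:

```python
import json
import urllib.request

BASE = 'http://127.0.0.1:8080/v1/mindinsight/debugger'

def post(url, body=None):
    """POST a JSON body and return the parsed JSON response."""
    req = urllib.request.Request(
        url, data=json.dumps(body or {}).encode(),
        headers={'Content-Type': 'application/json'})
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)

# create (or reuse) a session for a dump directory -- body shape assumed --
# then release it through the delete route
session_id = post(f'{BASE}/sessions',
                  {'dump_dir': '/path/to/dump', 'session_type': 'OFFLINE'})
post(f'{BASE}/sessions/{session_id}/delete')
```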
| checkSessions() { | |||||
| return axios({ | return axios({ | ||||
| method: 'get', | method: 'get', | ||||
| url: 'v1/mindinsight/debugger/poll-data', | |||||
| url: `v1/mindinsight/debugger/sessions`, | |||||
| }); | |||||
| }, | |||||
| pollData(params, sessionId) { | |||||
| return axios({ | |||||
| method: 'get', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/poll-data`, | |||||
| params: params, | params: params, | ||||
| headers: { | headers: { | ||||
| ignoreError: true, | ignoreError: true, | ||||
| }, | }, | ||||
| }); | }); | ||||
| }, | }, | ||||
| retrieve(params) { | |||||
| retrieve(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'post', | method: 'post', | ||||
| url: 'v1/mindinsight/debugger/retrieve', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/retrieve`, | |||||
| data: params, | data: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| createWatchpoint(params) { | |||||
| createWatchpoint(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'post', | method: 'post', | ||||
| url: 'v1/mindinsight/debugger/create-watchpoint', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/create-watchpoint`, | |||||
| data: params, | data: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| updateWatchpoint(params) { | |||||
| updateWatchpoint(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'post', | method: 'post', | ||||
| url: 'v1/mindinsight/debugger/update-watchpoint', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/update-watchpoint`, | |||||
| data: params, | data: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| deleteWatchpoint(params) { | |||||
| deleteWatchpoint(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'post', | method: 'post', | ||||
| url: 'v1/mindinsight/debugger/delete-watchpoint', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/delete-watchpoint`, | |||||
| data: params, | data: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| control(params) { | |||||
| control(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'post', | method: 'post', | ||||
| url: 'v1/mindinsight/debugger/control', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/control`, | |||||
| data: params, | data: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| search(params) { | |||||
| search(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'get', | method: 'get', | ||||
| url: 'v1/mindinsight/debugger/search', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/search`, | |||||
| params: params, | params: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
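Taken together, these endpoints replace the single implicit debugger server with explicit sessions: a caller first creates (or reuses) a session, then threads its id through every later request. A hedged sketch of the flow, assuming the OFFLINE payload used elsewhere in this change, an illustrative dump directory, and an illustrative `pos` cursor for poll-data:

    // Create or reuse an offline session, then poll it; error handling omitted.
    RequestService.getSession({session_type: 'OFFLINE', dump_dir: 'my_dump_dir'})
      .then((res) => {
        const sessionId = res.data; // the backend returns the session id directly
        return RequestService.pollData({pos: 0}, sessionId);
      })
      .then((res) => console.log(res.data));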
| @@ -368,43 +387,43 @@ export default { | |||||
| params: params, | params: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| tensorComparisons(params) { | |||||
| tensorComparisons(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'get', | method: 'get', | ||||
| url: 'v1/mindinsight/debugger/tensor-comparisons', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-comparisons`, | |||||
| params: params, | params: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| tensors(params) { | |||||
| tensors(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'get', | method: 'get', | ||||
| url: 'v1/mindinsight/debugger/tensors', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/tensors`, | |||||
| params: params, | params: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| retrieveTensorHistory(params) { | |||||
| retrieveTensorHistory(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'post', | method: 'post', | ||||
| url: 'v1/mindinsight/debugger/tensor-history', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-history`, | |||||
| data: params, | data: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| queryConditions(trainId) { | |||||
| queryConditions(sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'get', | method: 'get', | ||||
| url: `v1/mindinsight/conditionmgr/train-jobs/${trainId}/condition-collections`, | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/condition-collections`, | |||||
| }); | }); | ||||
| }, | }, | ||||
| recheckWatchPoints() { | |||||
| recheckWatchPoints(sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'post', | method: 'post', | ||||
| url: `v1/mindinsight/debugger/recheck`, | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/recheck`, | |||||
| }); | }); | ||||
| }, | }, | ||||
| searchWatchpointHits(params) { | |||||
| searchWatchpointHits(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'post', | method: 'post', | ||||
| url: `v1/mindinsight/debugger/search-watchpoint-hits`, | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/search-watchpoint-hits`, | |||||
| data: params, | data: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| @@ -447,33 +466,25 @@ export default { | |||||
| data: params, | data: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| tensorHitsData(params) { | |||||
| tensorHitsData(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'get', | method: 'get', | ||||
| url: 'v1/mindinsight/debugger/tensor-hits', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-hits`, | |||||
| params: params, | params: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| getTensorGraphData(params) { | |||||
| getTensorGraphData(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'get', | method: 'get', | ||||
| url: 'v1/mindinsight/debugger/tensor-graphs', | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-graphs`, | |||||
| params: params, | params: params, | ||||
| }); | }); | ||||
| }, | }, | ||||
| getCpuUtilization(params) { | |||||
| return axios({ | |||||
| method: 'post', | |||||
| url: 'v1/mindinsight/profile/minddata-cpu-utilization-summary', | |||||
| params: params.params, | |||||
| data: params.body, | |||||
| }); | |||||
| }, | |||||
| setRecommendWatchPoints(params) { | |||||
| setRecommendWatchPoints(params, sessionId) { | |||||
| return axios({ | return axios({ | ||||
| method: 'post', | method: 'post', | ||||
| url: `v1/mindinsight/conditionmgr/train-jobs/${params.trainId}/set-recommended-watch-points`, | |||||
| data: params.body, | |||||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/set-recommended-watch-points`, | |||||
| data: params, | |||||
| }); | }); | ||||
| }, | }, | ||||
| // memory-detail apis | // memory-detail apis | ||||
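Apart from getCpuUtilization (a profiler call removed here), the migration pattern is uniform: the train-job id disappears from the payload and a session id argument is appended instead. A before/after sketch for the recommended-watchpoints call, with an illustrative payload:

    // before: RequestService.setRecommendWatchPoints({trainId: id, body: payload});
    // after: the payload is sent as-is and the session id moves into the URL.
    RequestService.setRecommendWatchPoints(payload, sessionId);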
| @@ -46,6 +46,17 @@ limitations under the License. | |||||
| </div> | </div> | ||||
| <div class="content" | <div class="content" | ||||
| v-show="radio1==='tree'"> | v-show="radio1==='tree'"> | ||||
| <div class="node-type"> | |||||
| <div class="label">{{ $t('debugger.logicCard') }}</div> | |||||
| <el-select v-model="logicCard.value" | |||||
| @change="logicCardChange" | |||||
| :disabled="!trainId"> | |||||
| <el-option v-for="item in logicCard.options" | |||||
| :key="item" | |||||
| :value="item"> | |||||
| </el-option> | |||||
| </el-select> | |||||
| </div> | |||||
| <div class="node-type"> | <div class="node-type"> | ||||
| <div class="label">{{ $t('debugger.graphFile') }}</div> | <div class="label">{{ $t('debugger.graphFile') }}</div> | ||||
| <el-select v-model="graphFiles.value" | <el-select v-model="graphFiles.value" | ||||
| @@ -209,6 +220,17 @@ limitations under the License. | |||||
| </div> | </div> | ||||
| <div class="content" | <div class="content" | ||||
| v-show="radio1==='hit'"> | v-show="radio1==='hit'"> | ||||
| <div class="node-type"> | |||||
| <div class="label">{{ $t('debugger.logicCard') }}</div> | |||||
| <el-select v-model="logicCard.value" | |||||
| :disabled="!trainId" | |||||
| @change="logicCardChange();searchWatchpointHits(true);"> | |||||
| <el-option v-for="item in logicCard.options" | |||||
| :key="item" | |||||
| :value="item"> | |||||
| </el-option> | |||||
| </el-select> | |||||
| </div> | |||||
| <div class="hit-list-wrap"> | <div class="hit-list-wrap"> | ||||
| <el-table class="watchpoint-table" | <el-table class="watchpoint-table" | ||||
| :data="watchPointHits" | :data="watchPointHits" | ||||
| @@ -261,7 +283,7 @@ limitations under the License. | |||||
| <div class="step"> | <div class="step"> | ||||
| <el-tooltip class="item" | <el-tooltip class="item" | ||||
| effect="light" | effect="light" | ||||
| :content="$t('debugger.inputTip')" | |||||
| :content="$t('debugger.inputTip',{total_step_num:metadata.total_step_num})" | |||||
| placement="top-start"> | placement="top-start"> | ||||
| <el-input v-model="step" | <el-input v-model="step" | ||||
| :placeholder="$t('debugger.inputStep')" | :placeholder="$t('debugger.inputStep')" | ||||
| @@ -330,6 +352,25 @@ limitations under the License. | |||||
| v-show="metadata.state === state.sending"> | v-show="metadata.state === state.sending"> | ||||
| <i class="el-icon-time"></i> | <i class="el-icon-time"></i> | ||||
| </el-tooltip> | </el-tooltip> | ||||
| <i class="el-icon-edit" | |||||
| v-if="trainId && !isShowInp" | |||||
| :title="$t('debugger.inpStepTip',{total_step_num:metadata.total_step_num})" | |||||
| @click="editStep"></i> | |||||
| <el-tooltip class="item" | |||||
| effect="light" | |||||
| :content="$t('debugger.inputTip',{total_step_num:metadata.total_step_num})" | |||||
| placement="top-start" | |||||
| v-if="trainId && isShowInp"> | |||||
| <el-input v-model="newStep" | |||||
| type="text" | |||||
| @input="newStepChange"></el-input> | |||||
| </el-tooltip> | |||||
| <i class="el-icon-check" | |||||
| v-if="trainId && isShowInp" | |||||
| @click="saveStepValue"></i> | |||||
| <i class="el-icon-close" | |||||
| v-if="trainId && isShowInp" | |||||
| @click="isShowInp=false"></i> | |||||
| </div> | </div> | ||||
| <div class="svg-wrap" | <div class="svg-wrap" | ||||
| :class="{collapse: collapseTable}"> | :class="{collapse: collapseTable}"> | ||||
| @@ -505,7 +546,7 @@ limitations under the License. | |||||
| :close-on-click-modal="false" | :close-on-click-modal="false" | ||||
| :modal-append-to-body="false" | :modal-append-to-body="false" | ||||
| class="creat-watch-point-dialog" | class="creat-watch-point-dialog" | ||||
| width="890px"> | |||||
| width="930px"> | |||||
| <div class="conditions-container"> | <div class="conditions-container"> | ||||
| <div class="condition-item" | <div class="condition-item" | ||||
| @@ -787,6 +828,11 @@ export default { | |||||
| value: '', | value: '', | ||||
| graphs: {}, | graphs: {}, | ||||
| }, | }, | ||||
| logicCard: { | |||||
| options: [], | |||||
| value: '', | |||||
| }, | |||||
| devices: [], | |||||
| allGraphData: {}, // Original graph input data | allGraphData: {}, // Original graph input data | ||||
| firstFloorNodes: [], // ID array of the first layer node. | firstFloorNodes: [], // ID array of the first layer node. | ||||
| nodesCountLimit: 1500, // Maximum number of sub-nodes in a namespace. | nodesCountLimit: 1500, // Maximum number of sub-nodes in a namespace. | ||||
| @@ -830,7 +876,7 @@ export default { | |||||
| expandKeys: [], | expandKeys: [], | ||||
| isHitIntoView: true, | isHitIntoView: true, | ||||
| searchedWord: '', | searchedWord: '', | ||||
| trainId: '', | |||||
| trainId: this.$route.query.dir, | |||||
| recommendWatchPointDialog: false, | recommendWatchPointDialog: false, | ||||
| hitsOutdated: false, | hitsOutdated: false, | ||||
| conflictFlag: false, | conflictFlag: false, | ||||
| @@ -859,6 +905,9 @@ export default { | |||||
| }, | }, | ||||
| loadingInstance: null, | loadingInstance: null, | ||||
| paramErrorMsg: '', | paramErrorMsg: '', | ||||
| sessionId: this.$route.query.sessionId, | |||||
| isShowInp: false, | |||||
| newStep: '', | |||||
| }; | }; | ||||
| }, | }, | ||||
| components: {debuggerTensor, tree}, | components: {debuggerTensor, tree}, | ||||
| @@ -866,6 +915,12 @@ export default { | |||||
| mounted() { | mounted() { | ||||
| document.title = `${this.$t('debugger.debugger')}-MindInsight`; | document.title = `${this.$t('debugger.debugger')}-MindInsight`; | ||||
| this.nodeTypes.label = this.$t('debugger.nodeType'); | this.nodeTypes.label = this.$t('debugger.nodeType'); | ||||
| if (this.trainId) { | |||||
| document.title = `${this.trainId}-${this.$t('debugger.debugger')}-MindInsight`; | |||||
| this.retrieveAll(); | |||||
| } else { | |||||
| this.getSession(); | |||||
| } | |||||
| }, | }, | ||||
| watch: { | watch: { | ||||
| 'metadata.state': { | 'metadata.state': { | ||||
| @@ -896,7 +951,7 @@ export default { | |||||
| if (newValue === this.state.waiting) { | if (newValue === this.state.waiting) { | ||||
| if (this.oldState === this.state.pending || oldValue === this.state.pending) { | if (this.oldState === this.state.pending || oldValue === this.state.pending) { | ||||
| this.loadNode(this.node, this.resolve); | |||||
| this.retrieveAll(); | |||||
| } else if (this.oldState === this.state.running || oldValue === this.state.running) { | } else if (this.oldState === this.state.running || oldValue === this.state.running) { | ||||
| this.pagination.currentPage = 1; | this.pagination.currentPage = 1; | ||||
| this.watchPointHits = []; | this.watchPointHits = []; | ||||
| @@ -914,6 +969,8 @@ export default { | |||||
| this.curRowObj.type = type; | this.curRowObj.type = type; | ||||
| this.curRowObj.curFileName = this.graphFiles.value; | this.curRowObj.curFileName = this.graphFiles.value; | ||||
| this.curRowObj.step = this.metadata.step; | this.curRowObj.step = this.metadata.step; | ||||
| this.curRowObj.rank_id = this.logicCard.value; | |||||
| this.curRowObj.sessionId = this.sessionId; | |||||
| this.tensorCompareFlag = true; | this.tensorCompareFlag = true; | ||||
| }, | }, | ||||
| closeTensor(tensor, graphName) { | closeTensor(tensor, graphName) { | ||||
| @@ -922,6 +979,19 @@ export default { | |||||
| this.queryAllTreeData(tensor, true, graphName, true); | this.queryAllTreeData(tensor, true, graphName, true); | ||||
| } | } | ||||
| }, | }, | ||||
| logicCardChange() { | |||||
| const device = this.devices.find((val) => val.rank_id === this.logicCard.value); | |||||
| this.graphFiles.options = JSON.parse(JSON.stringify(device.graph_names)); | |||||
| if (this.graphFiles.options.length > 1) { | |||||
| this.graphFiles.options.unshift(this.$t('debugger.all')); | |||||
| } | |||||
| this.graphFiles.value = this.graphFiles.options[0]; | |||||
| this.metadata.ip = device.server_ip; | |||||
| this.metadata.device_name = device.device_id; | |||||
| this.queryGraphByFile(); | |||||
| }, | |||||
| queryGraphByFile() { | queryGraphByFile() { | ||||
| this.searchWord = ''; | this.searchWord = ''; | ||||
| this.nodeTypes.value = 'all'; | this.nodeTypes.value = 'all'; | ||||
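logicCardChange() keys everything off the devices array that retrieve responses now include (see the updated expect_results fixtures near the end of this patch). A small sketch of that shape and the lookup, with values copied from those fixtures:

    // Each devices entry pairs a rank with its device id and graph files.
    const devices = [{rank_id: 0, device_id: '0', graph_names: ['graph_0', 'graph_1']}];
    const device = devices.find((val) => val.rank_id === 0);
    console.log(device.graph_names); // ['graph_0', 'graph_1']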
| @@ -931,12 +1001,13 @@ export default { | |||||
| params: { | params: { | ||||
| watch_point_id: this.curWatchPointId ? this.curWatchPointId : 0, | watch_point_id: this.curWatchPointId ? this.curWatchPointId : 0, | ||||
| graph_name: this.graphFiles.value, | graph_name: this.graphFiles.value, | ||||
| rank_id: this.logicCard.value, | |||||
| }, | }, | ||||
| }; | }; | ||||
| if (this.graphFiles.value === this.$t('debugger.all')) { | if (this.graphFiles.value === this.$t('debugger.all')) { | ||||
| delete params.params.graph_name; | delete params.params.graph_name; | ||||
| } | } | ||||
| RequestService.retrieve(params).then( | |||||
| RequestService.retrieve(params, this.sessionId).then( | |||||
| (res) => { | (res) => { | ||||
| if (res.data && res.data.metadata) { | if (res.data && res.data.metadata) { | ||||
| this.dealMetadata(res.data.metadata); | this.dealMetadata(res.data.metadata); | ||||
| @@ -975,6 +1046,7 @@ export default { | |||||
| d3.select('#graph svg').remove(); | d3.select('#graph svg').remove(); | ||||
| this.selectedNode.name = ''; | this.selectedNode.name = ''; | ||||
| this.dealGraphData(JSON.parse(JSON.stringify(graph.nodes))); | this.dealGraphData(JSON.parse(JSON.stringify(graph.nodes))); | ||||
| this.tableData = []; | |||||
| } | } | ||||
| }, | }, | ||||
| (err) => { | (err) => { | ||||
| @@ -1015,11 +1087,12 @@ export default { | |||||
| watch_nodes: watchNodes, | watch_nodes: watchNodes, | ||||
| mode: type ? 1 : 0, | mode: type ? 1 : 0, | ||||
| graph_name: this.graphFiles.value, | graph_name: this.graphFiles.value, | ||||
| rank_id: this.logicCard.value, | |||||
| }; | }; | ||||
| if (this.graphFiles.value === this.$t('debugger.all')) { | if (this.graphFiles.value === this.$t('debugger.all')) { | ||||
| delete params.graph_name; | delete params.graph_name; | ||||
| } | } | ||||
| RequestService.updateWatchpoint(params).then( | |||||
| RequestService.updateWatchpoint(params, this.sessionId).then( | |||||
| (res) => { | (res) => { | ||||
| this.defaultCheckedArr = this.$refs.tree.getCheckedKeys(); | this.defaultCheckedArr = this.$refs.tree.getCheckedKeys(); | ||||
| if (res && res.data && res.data.metadata && res.data.metadata.enable_recheck !== undefined) { | if (res && res.data && res.data.metadata && res.data.metadata.enable_recheck !== undefined) { | ||||
| @@ -1049,12 +1122,16 @@ export default { | |||||
| queryGraphByWatchpoint(id) { | queryGraphByWatchpoint(id) { | ||||
| const params = { | const params = { | ||||
| mode: 'watchpoint', | mode: 'watchpoint', | ||||
| params: {watch_point_id: id, graph_name: this.graphFiles.value}, | |||||
| params: { | |||||
| watch_point_id: id, | |||||
| graph_name: this.graphFiles.value, | |||||
| rank_id: this.logicCard.value, | |||||
| }, | |||||
| }; | }; | ||||
| if (this.graphFiles.value === this.$t('debugger.all')) { | if (this.graphFiles.value === this.$t('debugger.all')) { | ||||
| delete params.params.graph_name; | delete params.params.graph_name; | ||||
| } | } | ||||
| RequestService.retrieve(params).then( | |||||
| RequestService.retrieve(params, this.sessionId).then( | |||||
| (res) => { | (res) => { | ||||
| if (res.data && res.data.graph) { | if (res.data && res.data.graph) { | ||||
| const graph = res.data.graph; | const graph = res.data.graph; | ||||
| @@ -1306,11 +1383,12 @@ export default { | |||||
| level: 'node', | level: 'node', | ||||
| name: this.selectedNode.name.replace('_unfold', ''), | name: this.selectedNode.name.replace('_unfold', ''), | ||||
| graph_name: this.graphFiles.value, | graph_name: this.graphFiles.value, | ||||
| rank_id: this.logicCard.value, | |||||
| }; | }; | ||||
| if (this.graphFiles.value === this.$t('debugger.all')) { | if (this.graphFiles.value === this.$t('debugger.all')) { | ||||
| delete params.graph_name; | delete params.graph_name; | ||||
| } | } | ||||
| RequestService.control(params).then( | |||||
| RequestService.control(params, this.sessionId).then( | |||||
| (res) => { | (res) => { | ||||
| if (res && res.data) { | if (res && res.data) { | ||||
| } | } | ||||
| @@ -1387,12 +1465,13 @@ export default { | |||||
| node_type: type, | node_type: type, | ||||
| single_node: false, | single_node: false, | ||||
| graph_name: this.graphFiles.value, | graph_name: this.graphFiles.value, | ||||
| rank_id: this.logicCard.value, | |||||
| }; | }; | ||||
| if (this.graphFiles.value === this.$t('debugger.all')) { | if (this.graphFiles.value === this.$t('debugger.all')) { | ||||
| delete params.params.graph_name; | delete params.params.graph_name; | ||||
| } | } | ||||
| } | } | ||||
| RequestService.retrieve(params) | |||||
| RequestService.retrieve(params, this.sessionId) | |||||
| .then( | .then( | ||||
| (response) => { | (response) => { | ||||
| if (response && response.data && response.data.graph) { | if (response && response.data && response.data.graph) { | ||||
| @@ -1560,7 +1639,12 @@ export default { | |||||
| graphName = key.split('/')[0]; | graphName = key.split('/')[0]; | ||||
| key = key.replace(`${graphName}/`, ''); | key = key.replace(`${graphName}/`, ''); | ||||
| } | } | ||||
| const obj = {name: key, IOType: 'output', graph_name: graphName}; | |||||
| const obj = { | |||||
| name: key, | |||||
| IOType: 'output', | |||||
| graph_name: graphName, | |||||
| rank_id: this.logicCard.value, | |||||
| }; | |||||
| IOInfo.push(obj); | IOInfo.push(obj); | ||||
| this.selectedNode.outputNum++; | this.selectedNode.outputNum++; | ||||
| }); | }); | ||||
| @@ -1572,7 +1656,12 @@ export default { | |||||
| graphName = key.split('/')[0]; | graphName = key.split('/')[0]; | ||||
| key = key.replace(`${graphName}/`, ''); | key = key.replace(`${graphName}/`, ''); | ||||
| } | } | ||||
| const obj = {name: key, IOType: 'input', graph_name: graphName}; | |||||
| const obj = { | |||||
| name: key, | |||||
| IOType: 'input', | |||||
| graph_name: graphName, | |||||
| rank_id: this.logicCard.value, | |||||
| }; | |||||
| IOInfo.push(obj); | IOInfo.push(obj); | ||||
| this.selectedNode.inputNum++; | this.selectedNode.inputNum++; | ||||
| }); | }); | ||||
| @@ -1606,11 +1695,7 @@ export default { | |||||
| `translate(${this.graph.transform.x},` + `${this.graph.transform.y}) scale(${this.graph.transform.k})`, | `translate(${this.graph.transform.x},` + `${this.graph.transform.y}) scale(${this.graph.transform.k})`, | ||||
| ); | ); | ||||
| const transitionTime = Math.min( | |||||
| Math.abs(screenChange.x) * 2, | |||||
| Math.abs(screenChange.y) * 2, | |||||
| needDelay ? 800 : 0, | |||||
| ); | |||||
| const transitionTime = Math.min(Math.abs(screenChange.x) * 2, Math.abs(screenChange.y) * 2, needDelay ? 800 : 0); | |||||
| this.graph.dom.style.transition = `${transitionTime / 1000}s`; | this.graph.dom.style.transition = `${transitionTime / 1000}s`; | ||||
| this.graph.dom.style['transition-timing-function'] = 'linear'; | this.graph.dom.style['transition-timing-function'] = 'linear'; | ||||
| @@ -1829,8 +1914,8 @@ export default { | |||||
| height: calc(100% - 145px); | height: calc(100% - 145px); | ||||
| } | } | ||||
| .deb-wrap .left-wrap .left .content .node-type { | .deb-wrap .left-wrap .left .content .node-type { | ||||
| height: 50px; | |||||
| padding: 15px 15px 0 15px; | |||||
| height: 40px; | |||||
| padding: 10px 15px 0 15px; | |||||
| } | } | ||||
| .deb-wrap .left-wrap .left .content .node-type .label { | .deb-wrap .left-wrap .left .content .node-type .label { | ||||
| display: inline-block; | display: inline-block; | ||||
| @@ -1855,7 +1940,7 @@ export default { | |||||
| font-size: 12px; | font-size: 12px; | ||||
| } | } | ||||
| .deb-wrap .left-wrap .left .content .tree-wrap { | .deb-wrap .left-wrap .left .content .tree-wrap { | ||||
| height: calc(70% - 155px); | |||||
| height: calc(70% - 172px); | |||||
| overflow-y: auto; | overflow-y: auto; | ||||
| padding: 0 15px 15px; | padding: 0 15px 15px; | ||||
| position: relative; | position: relative; | ||||
| @@ -1973,12 +2058,13 @@ export default { | |||||
| color: red; | color: red; | ||||
| } | } | ||||
| .deb-wrap .left-wrap .left .content .hit-list-wrap { | .deb-wrap .left-wrap .left .content .hit-list-wrap { | ||||
| height: 100%; | |||||
| height: calc(100% - 40px); | |||||
| padding: 10px; | padding: 10px; | ||||
| } | } | ||||
| .deb-wrap .left-wrap .left .content .hit-list-wrap .watchpoint-table { | .deb-wrap .left-wrap .left .content .hit-list-wrap .watchpoint-table { | ||||
| max-height: calc(100% - 45px); | max-height: calc(100% - 45px); | ||||
| overflow: auto; | overflow: auto; | ||||
| margin-top: 10px; | |||||
| } | } | ||||
| .deb-wrap .left-wrap .left .content .hit-list-wrap .el-table::before { | .deb-wrap .left-wrap .left .content .hit-list-wrap .el-table::before { | ||||
| height: 0; | height: 0; | ||||
| @@ -2096,7 +2182,7 @@ export default { | |||||
| /* Opera */ | /* Opera */ | ||||
| } | } | ||||
| .deb-wrap .right .header { | .deb-wrap .right .header { | ||||
| padding: 15px; | |||||
| line-height: 51px; | |||||
| border-bottom: 1px solid #ebeef5; | border-bottom: 1px solid #ebeef5; | ||||
| position: relative; | position: relative; | ||||
| background: #fff; | background: #fff; | ||||
| @@ -2113,6 +2199,25 @@ export default { | |||||
| .deb-wrap .right .header .item + .item { | .deb-wrap .right .header .item + .item { | ||||
| margin-left: 15px; | margin-left: 15px; | ||||
| } | } | ||||
| .deb-wrap .right .header .el-icon-edit { | |||||
| margin-left: 5px; | |||||
| } | |||||
| .deb-wrap .right .header i { | |||||
| font-size: 18px; | |||||
| margin: 0 2px; | |||||
| color: #00a5a7; | |||||
| cursor: pointer; | |||||
| } | |||||
| .deb-wrap .right .header .el-icon-close { | |||||
| color: #f56c6c; | |||||
| } | |||||
| .deb-wrap .right .header .el-input { | |||||
| width: 45px; | |||||
| } | |||||
| .deb-wrap .right .header .el-input input { | |||||
| padding: 0; | |||||
| text-align: center; | |||||
| } | |||||
| .deb-wrap .right .header .tooltip { | .deb-wrap .right .header .tooltip { | ||||
| margin-left: 5px; | margin-left: 5px; | ||||
| cursor: pointer; | cursor: pointer; | ||||
| @@ -2343,13 +2448,13 @@ export default { | |||||
| display: none; | display: none; | ||||
| } | } | ||||
| .deb-wrap .creat-watch-point-dialog .conditions-container .collection { | .deb-wrap .creat-watch-point-dialog .conditions-container .collection { | ||||
| width: 200px; | |||||
| width: 210px; | |||||
| } | } | ||||
| .deb-wrap .creat-watch-point-dialog .conditions-container .condition, | .deb-wrap .creat-watch-point-dialog .conditions-container .condition, | ||||
| .deb-wrap .creat-watch-point-dialog .conditions-container .param, | .deb-wrap .creat-watch-point-dialog .conditions-container .param, | ||||
| .deb-wrap .creat-watch-point-dialog .conditions-container .param-value { | .deb-wrap .creat-watch-point-dialog .conditions-container .param-value { | ||||
| margin-left: 10px; | margin-left: 10px; | ||||
| width: 200px; | |||||
| width: 210px; | |||||
| } | } | ||||
| .deb-wrap .creat-watch-point-dialog .conditions-container .percent-sign { | .deb-wrap .creat-watch-point-dialog .conditions-container .percent-sign { | ||||
| display: inline-block; | display: inline-block; | ||||
| @@ -96,6 +96,16 @@ limitations under the License. | |||||
| :title="$t('summaryManage.disableProfilerTip')"> | :title="$t('summaryManage.disableProfilerTip')"> | ||||
| {{$t('summaryManage.viewProfiler')}} | {{$t('summaryManage.viewProfiler')}} | ||||
| </span> | </span> | ||||
| <span class="menu-item operate-btn" | |||||
| v-if="scope.row.viewOfflineDebugger" | |||||
| @contextmenu.prevent="rightClick(scope.row, $event, 2)" | |||||
| @click.stop="goToOfflineDebugger(scope.row)"> | |||||
| {{$t('summaryManage.viewOfflineDebugger')}} </span> | |||||
| <span class="menu-item operate-btn button-disable" | |||||
| v-else | |||||
| :title="$t('summaryManage.disableOfflineDebugger')"> | |||||
| {{$t('summaryManage.viewOfflineDebugger')}} | |||||
| </span> | |||||
| <span class="menu-item operate-btn" | <span class="menu-item operate-btn" | ||||
| v-if="scope.row.paramDetails" | v-if="scope.row.paramDetails" | ||||
| @click.stop="showModelDialog(scope.row)"> | @click.stop="showModelDialog(scope.row)"> | ||||
| @@ -157,6 +167,45 @@ limitations under the License. | |||||
| <li @click="doRightClick()">{{$t('summaryManage.openNewTab')}}</li> | <li @click="doRightClick()">{{$t('summaryManage.openNewTab')}}</li> | ||||
| </ul> | </ul> | ||||
| </div> | </div> | ||||
| <el-dialog :visible.sync="debuggerDialog.showDialogModel" | |||||
| width="50%" | |||||
| :close-on-click-modal="false" | |||||
| class="details-data-list"> | |||||
| <span slot="title"> | |||||
| <span class="sessionMsg">{{ debuggerDialog.title }}</span> | |||||
| <el-tooltip placement="right" | |||||
| effect="light" | |||||
| popper-class="legend-tip" | |||||
| :content="$t('summaryManage.sessionLimitNum')"> | |||||
| <i class="el-icon-info"></i> | |||||
| </el-tooltip> | |||||
| </span> | |||||
| <div class="session-title">{{ $t('summaryManage.sessionLists') }}</div> | |||||
| <el-table :data="debuggerDialog.trainJobs"> | |||||
| <el-table-column width="50" | |||||
| type="index" | |||||
| :label="$t('summaryManage.sorting')"> | |||||
| </el-table-column> | |||||
| <el-table-column min-width="300" | |||||
| prop="relative_path" | |||||
| :label="$t('summaryManage.summaryPath')" | |||||
| show-overflow-tooltip> | |||||
| </el-table-column> | |||||
| <!-- operate --> | |||||
| <el-table-column prop="operate" | |||||
| :label="$t('summaryManage.operation')" | |||||
| class-name="operate-container"> | |||||
| <template slot-scope="scope"> | |||||
| <span class="menu-item operate-btn first-btn" | |||||
| @click="deleteSession(scope.row.session_id)"> | |||||
| {{$t('public.delete')}} </span> | |||||
| <span class="menu-item operate-btn first-btn" | |||||
| @click="viewSession(scope.row)"> | |||||
| {{$t('debugger.view')}} </span> | |||||
| </template> | |||||
| </el-table-column> | |||||
| </el-table> | |||||
| </el-dialog> | |||||
| </div> | </div> | ||||
| </template> | </template> | ||||
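The dialog rows are produced by checkSessions() further down, which flattens the train_jobs map returned by the sessions endpoint. A sketch under the assumption that the keys are URL-encoded dump paths mapped to session ids:

    // Assumed payload of GET v1/mindinsight/debugger/sessions (illustrative values):
    const trainJobs = {'dump%2Fdir': 1};
    const rows = Object.keys(trainJobs).map((val) => ({
      relative_path: decodeURIComponent(val), // 'dump/dir'
      session_id: trainJobs[val],
    }));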
| @@ -223,7 +272,12 @@ export default { | |||||
| type: 0, | type: 0, | ||||
| }, | }, | ||||
| tableDom: null, | tableDom: null, | ||||
| operateWidth: localStorage.getItem('milang') === 'en-us' ? 400 : 290, | |||||
| operateWidth: localStorage.getItem('milang') === 'en-us' ? 550 : 400, | |||||
| debuggerDialog: { | |||||
| title: this.$t('summaryManage.sessionLimit'), | |||||
| showDialogModel: false, | |||||
| trainJobs: [], | |||||
| }, | |||||
| }; | }; | ||||
| }, | }, | ||||
| computed: {}, | computed: {}, | ||||
| @@ -286,6 +340,7 @@ export default { | |||||
| i.update_time = i.update_time ? i.update_time : '--'; | i.update_time = i.update_time ? i.update_time : '--'; | ||||
| i.viewProfiler = i.profiler_dir && i.profiler_dir.length; | i.viewProfiler = i.profiler_dir && i.profiler_dir.length; | ||||
| i.viewDashboard = i.summary_files || i.graph_files || i.lineage_files; | i.viewDashboard = i.summary_files || i.graph_files || i.lineage_files; | ||||
| i.viewOfflineDebugger = i.dump_dir; | |||||
| i.paramDetails = i.lineage_files; | i.paramDetails = i.lineage_files; | ||||
| }); | }); | ||||
| this.currentFolder = res.data.name ? res.data.name : '--'; | this.currentFolder = res.data.name ? res.data.name : '--'; | ||||
| @@ -363,7 +418,83 @@ export default { | |||||
| }, | }, | ||||
| }); | }); | ||||
| }, | }, | ||||
| /** | |||||
| * Go to the offline debugger. | |||||
| * @param {Object} row Selected row. | |||||
| */ | |||||
| goToOfflineDebugger(row) { | |||||
| this.contextMenu.show = false; | |||||
| const debuggerDir = row.dump_dir; | |||||
| const params = { | |||||
| session_type: 'OFFLINE', | |||||
| dump_dir: debuggerDir, | |||||
| }; | |||||
| this.getSessionId(params).then((value) => { | |||||
| if (value !== undefined) { | |||||
| this.$router.push({ | |||||
| path: '/offline-debugger', | |||||
| query: { | |||||
| dir: debuggerDir, | |||||
| sessionId: value, | |||||
| }, | |||||
| }); | |||||
| } | |||||
| }); | |||||
| }, | |||||
| getSessionId(params) { | |||||
| return RequestService.getSession(params).then( | |||||
| (res) => { | |||||
| if (res && res.data) { | |||||
| const sessionId = res.data; | |||||
| return sessionId; | |||||
| } | |||||
| }, | |||||
| (error) => { | |||||
| if (error && error.response && error.response.data && error.response.data.error_code === '5054B280') { | |||||
| this.checkSessions(); | |||||
| } | |||||
| }, | |||||
| ); | |||||
| }, | |||||
| deleteSession(sessionId) { | |||||
| this.$confirm(this.$t('summaryManage.deleteSessionConfirm'), this.$t('public.notice'), { | |||||
| confirmButtonText: this.$t('public.sure'), | |||||
| cancelButtonText: this.$t('public.cancel'), | |||||
| type: 'warning', | |||||
| }).then(() => { | |||||
| RequestService.deleteSession(sessionId).then((res) => { | |||||
| this.$message({ | |||||
| type: 'success', | |||||
| message: this.$t('summaryManage.deleteSessionSuccess'), | |||||
| }); | |||||
| this.checkSessions(); | |||||
| }); | |||||
| }); | |||||
| }, | |||||
| checkSessions() { | |||||
| RequestService.checkSessions().then((res) => { | |||||
| if (res && res.data && res.data.train_jobs) { | |||||
| const trainJobs = res.data.train_jobs; | |||||
| this.debuggerDialog.trainJobs = Object.keys(trainJobs).map((val) => { | |||||
| return { | |||||
| relative_path: decodeURIComponent(val), | |||||
| session_id: trainJobs[val], | |||||
| }; | |||||
| }); | |||||
| this.debuggerDialog.showDialogModel = true; | |||||
| } | |||||
| }); | |||||
| }, | |||||
| viewSession(row) { | |||||
| const dir = row.relative_path; | |||||
| this.$router.push({ | |||||
| path: '/offline-debugger', | |||||
| query: { | |||||
| dir, | |||||
| sessionId: row.session_id, | |||||
| }, | |||||
| }); | |||||
| }, | |||||
| rightClick(row, event, type) { | rightClick(row, event, type) { | ||||
| const maxWidth = 175; | const maxWidth = 175; | ||||
| this.contextMenu.data = row; | this.contextMenu.data = row; | ||||
| @@ -380,7 +511,28 @@ export default { | |||||
| if (!row) { | if (!row) { | ||||
| return; | return; | ||||
| } | } | ||||
| if (this.contextMenu.type) { | |||||
| if (this.contextMenu.type === 2) { | |||||
| // open offline debugger | |||||
| this.contextMenu.show = false; | |||||
| const debuggerDir = row.dump_dir; | |||||
| const params = { | |||||
| session_type: 'OFFLINE', | |||||
| dump_dir: debuggerDir, | |||||
| }; | |||||
| this.getSessionId(params).then((value) => { | |||||
| if (value !== undefined) { | |||||
| const routeUrl = this.$router.resolve({ | |||||
| path: '/offline-debugger', | |||||
| query: { | |||||
| dir: debuggerDir, | |||||
| sessionId: value, | |||||
| }, | |||||
| }); | |||||
| window.open(routeUrl.href, '_blank'); | |||||
| } | |||||
| }); | |||||
| } else if (this.contextMenu.type === 1) { | |||||
| // open profiling | |||||
| this.contextMenu.show = false; | this.contextMenu.show = false; | ||||
| const profilerDir = encodeURIComponent(row.profiler_dir); | const profilerDir = encodeURIComponent(row.profiler_dir); | ||||
| const trainId = encodeURIComponent(row.train_id); | const trainId = encodeURIComponent(row.train_id); | ||||
| @@ -400,7 +552,7 @@ export default { | |||||
| }, | }, | ||||
| }); | }); | ||||
| window.open(routeUrl.href, '_blank'); | window.open(routeUrl.href, '_blank'); | ||||
| } else { | |||||
| } else { // open training dashboard | |||||
| this.contextMenu.show = false; | this.contextMenu.show = false; | ||||
| const trainId = encodeURIComponent(row.train_id); | const trainId = encodeURIComponent(row.train_id); | ||||
| @@ -693,6 +845,16 @@ export default { | |||||
| #cl-summary-manage .details-data-list .el-dialog__body .details-data-title { | #cl-summary-manage .details-data-list .el-dialog__body .details-data-title { | ||||
| margin-bottom: 20px; | margin-bottom: 20px; | ||||
| } | } | ||||
| #cl-summary-manage .details-data-list .sessionMsg { | |||||
| color: #333; | |||||
| font-weight: bold; | |||||
| font-size: 16px; | |||||
| margin-right: 5px; | |||||
| } | |||||
| #cl-summary-manage .details-data-list .session-title { | |||||
| margin-bottom: 10px; | |||||
| color: #333; | |||||
| } | |||||
| #cl-summary-manage .is-disabled.custom-btn { | #cl-summary-manage .is-disabled.custom-btn { | ||||
| background-color: #f5f5f6; | background-color: #f5f5f6; | ||||
| border: 1px solid #dfe1e6 !important; | border: 1px solid #dfe1e6 !important; | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -12,15 +12,10 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Module init file.""" | |||||
| from mindinsight.backend.conditionmgr.conditionmgr_api import init_module as init_query_module | |||||
| """Train job register.""" | |||||
| def init_module(app): | |||||
| """ | |||||
| Init module entry. | |||||
| Args: | |||||
| app (Flask): A Flask instance. | |||||
| """ | |||||
| init_query_module(app) | |||||
| class FolderAnalyzer: | |||||
| """Train job register. The subclass should implement the analyze method and return update info.""" | |||||
| def analyze(self, entry, summary_base_dir, relative_path): | |||||
| """Analyze file.""" | |||||
| @@ -18,4 +18,6 @@ six>=1.12.0 | |||||
| Werkzeug>=1.0.0 | Werkzeug>=1.0.0 | ||||
| pandas>=1.0.4 | pandas>=1.0.4 | ||||
| yapf>=0.30.0 | yapf>=0.30.0 | ||||
| grpcio>=1.27.3 | |||||
| treelib>=1.6.1 | |||||
| grpcio>=1.27.3 | |||||
| XlsxWriter>=1.2.9 | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -25,13 +25,15 @@ from mindinsight.conf import settings | |||||
| from mindinsight.datavisual.utils import tools | from mindinsight.datavisual.utils import tools | ||||
| from mindinsight.debugger.proto import ms_graph_pb2 | from mindinsight.debugger.proto import ms_graph_pb2 | ||||
| from mindinsight.debugger.stream_handler.graph_handler import GraphHandler | from mindinsight.debugger.stream_handler.graph_handler import GraphHandler | ||||
| from mindinsight.debugger.session_manager import SessionManager | |||||
| GRAPH_PROTO_FILE = os.path.join( | GRAPH_PROTO_FILE = os.path.join( | ||||
| os.path.dirname(__file__), '../../../utils/resource/graph_pb/lenet.pb' | os.path.dirname(__file__), '../../../utils/resource/graph_pb/lenet.pb' | ||||
| ) | ) | ||||
| DEBUGGER_BASE_URL = '/v1/mindinsight/debugger' | |||||
| DEBUGGER_EXPECTED_RESULTS = os.path.join(os.path.dirname(__file__), 'expect_results') | |||||
| DEBUGGER_BASE_URL = '/v1/mindinsight/debugger/sessions/0/' | |||||
| DEBUGGER_TEST_BASE_DIR = os.path.dirname(__file__) | |||||
| DEBUGGER_EXPECTED_RESULTS = os.path.join(DEBUGGER_TEST_BASE_DIR, 'expect_results') | |||||
| def init_graph_handler(): | def init_graph_handler(): | ||||
| @@ -51,14 +53,13 @@ def init_graph_handler(): | |||||
| @pytest.fixture(scope='session') | @pytest.fixture(scope='session') | ||||
| def app_client(): | def app_client(): | ||||
| """This fixture is flask server.""" | """This fixture is flask server.""" | ||||
| packages = ["mindinsight.backend.debugger", "mindinsight.backend.conditionmgr"] | |||||
| packages = ["mindinsight.backend.debugger"] | |||||
| settings.ENABLE_DEBUGGER = True | settings.ENABLE_DEBUGGER = True | ||||
| mock_obj = Mock(return_value=packages) | mock_obj = Mock(return_value=packages) | ||||
| tools.find_app_package = mock_obj | tools.find_app_package = mock_obj | ||||
| from mindinsight.backend.application import APP | from mindinsight.backend.application import APP | ||||
| from mindinsight.backend.debugger.debugger_api import BACKEND_SERVER | |||||
| APP.response_class = Response | APP.response_class = Response | ||||
| client = APP.test_client() | client = APP.test_client() | ||||
| original_val = settings.ENABLE_RECOMMENDED_WATCHPOINTS | original_val = settings.ENABLE_RECOMMENDED_WATCHPOINTS | ||||
| @@ -67,4 +68,4 @@ def app_client(): | |||||
| yield client | yield client | ||||
| finally: | finally: | ||||
| settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_val | settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_val | ||||
| BACKEND_SERVER.stop() | |||||
| SessionManager.get_instance().online_session.stop() | |||||
| @@ -0,0 +1,20 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Function: | |||||
| Test debugger services. | |||||
| Usage: | |||||
| pytest tests/st/func/debugger | |||||
| """ | |||||
| @@ -0,0 +1,141 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Mock of the DbgServices module, which provides offline debugger APIs. | |||||
| """ | |||||
| from unittest.mock import MagicMock | |||||
| import numpy as np | |||||
| import mindinsight | |||||
| from mindinsight.debugger.common.log import LOGGER as log | |||||
| def get_version(): | |||||
| """Get version.""" | |||||
| return mindinsight.__version__ | |||||
| class DbgServices: | |||||
| """ | |||||
| DbgServices. | |||||
| Args: | |||||
| dump_file_path (str): Directory where the dump files are saved. | |||||
| """ | |||||
| def __init__(self, dump_file_path, verbose=True): | |||||
| self._verbose = verbose | |||||
| self.dump_file_path = dump_file_path | |||||
| self.dbg_instance = MagicMock() | |||||
| self._watchpoints = {} | |||||
| self.print_mes("in Python __init__, file path is {}".format(dump_file_path)) | |||||
| self.version = get_version() | |||||
| self.initialized = False | |||||
| self.is_sync = True | |||||
| def print_mes(self, mes): | |||||
| """Print message.""" | |||||
| if self._verbose: | |||||
| log.info(mes) | |||||
| def initialize(self, net_name, is_sync_mode): | |||||
| """Initialize.""" | |||||
| self.print_mes(" Python Initialize dump_file_path: {}, is_sync: {}".format(net_name, is_sync_mode)) | |||||
| self.initialized = True | |||||
| def add_watchpoint(self, watchpoint_id, watch_condition, check_node_list, parameter_list): | |||||
| """Add watchpoint.""" | |||||
| self.print_mes("Add watchpoint with watchpoint id: {}".format(watchpoint_id)) | |||||
| self._watchpoints[watchpoint_id] = {'watch_condition': watch_condition, | |||||
| 'check_nodes': check_node_list, | |||||
| 'parameter_list': parameter_list} | |||||
| return 0 | |||||
| def remove_watchpoint(self, watchpoint_id): | |||||
| """Remove watchpoints.""" | |||||
| self.print_mes("Remove watchpoint with watchpoint id: {}".format(watchpoint_id)) | |||||
| return self._watchpoints.pop(watchpoint_id) | |||||
| def check_watchpoints(self, iteration): | |||||
| """Check watchpoints.""" | |||||
| self.print_mes("Check watchpoint at iteration: {}".format(iteration)) | |||||
| watch_hits = [] | |||||
| for watchpoint_id, watchpoint in self._watchpoints.items(): | |||||
| # add param hit info | |||||
| for param in watchpoint.get('parameter_list'): | |||||
| param.hit = True | |||||
| param.value = 0.0 | |||||
| for watch_node_name, node_info in watchpoint.get('check_nodes').items(): | |||||
| for device_id in node_info.get('device_id'): | |||||
| hit = WatchpointHit(watch_node_name, | |||||
| 0, | |||||
| watchpoint.get('watch_condition'), | |||||
| watchpoint_id, | |||||
| watchpoint.get('parameter_list'), | |||||
| 0, | |||||
| device_id) | |||||
| watch_hits.append(hit) | |||||
| return watch_hits | |||||
| def read_tensors(self, info): | |||||
| """Read tensor values.""" | |||||
| value = np.asarray(list(range(12)), dtype=np.int32).tobytes() | |||||
| info_list_inst = [] | |||||
| for _ in info: | |||||
| tensor_data = TensorData(value, len(value), 4, [2, 2, 3]) | |||||
| info_list_inst.append(tensor_data) | |||||
| return info_list_inst | |||||
| class TensorInfo: | |||||
| """Tensor Information.""" | |||||
| def __init__(self, node_name, slot, iteration, device_id, is_parameter): | |||||
| self.node_name = node_name | |||||
| self.slot = slot | |||||
| self.iteration = iteration | |||||
| self.device_id = device_id | |||||
| self.is_parameter = is_parameter | |||||
| class TensorData: | |||||
| """Tensor data structure.""" | |||||
| def __init__(self, data_ptr, data_size, dtype, shape): | |||||
| self.data_ptr = data_ptr | |||||
| self.data_size = data_size | |||||
| self.dtype = dtype | |||||
| self.shape = shape | |||||
| class Parameter: | |||||
| """Parameter structure.""" | |||||
| def __init__(self, name, disabled, value, hit=False, actual_value=0.0): | |||||
| self.name = name | |||||
| self.disabled = disabled | |||||
| self.value = value | |||||
| self.hit = hit | |||||
| self.actual_value = actual_value | |||||
| class WatchpointHit: | |||||
| """Watchpoint hit structure.""" | |||||
| def __init__(self, name, slot, condition, watchpoint_id, parameters, error_code, device_id): | |||||
| self.name = name | |||||
| self.slot = slot | |||||
| self.condition = condition | |||||
| self.watchpoint_id = watchpoint_id | |||||
| self.parameters = parameters | |||||
| self.error_code = error_code | |||||
| self.device_id = device_id | |||||
| @@ -0,0 +1,77 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Function: | |||||
| Test debugger services. | |||||
| Usage: | |||||
| pytest tests/st/func/debugger | |||||
| """ | |||||
| import os | |||||
| import shutil | |||||
| from unittest import mock | |||||
| import pytest | |||||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||||
| from mindinsight.debugger.debugger_services.debugger_server_factory import \ | |||||
| DebuggerServerFactory, DebuggerServerContext | |||||
| from tests.st.func.debugger.utils import build_dump_file_structure | |||||
| from tests.st.func.debugger.debugger_services import mock_dbg_services | |||||
| class TestDebuggerServerFactory: | |||||
| """Test debugger on Ascend backend.""" | |||||
| @classmethod | |||||
| def setup_class(cls): | |||||
| """Setup class.""" | |||||
| cls.debugger_tmp_dir, cls.dump_files_dir = build_dump_file_structure() | |||||
| cls._dbg_dir = os.path.join(cls.dump_files_dir, 'Ascend/sync') | |||||
| cls._dbg_server_factory = DebuggerServerFactory() | |||||
| @classmethod | |||||
| def teardown_class(cls): | |||||
| """Run after test this class.""" | |||||
| if os.path.exists(cls.debugger_tmp_dir): | |||||
| shutil.rmtree(cls.debugger_tmp_dir) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.env_single | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| def test_get_dbg_online_server(self): | |||||
| """Get debugger online server""" | |||||
| context = DebuggerServerContext(dbg_mode='online') | |||||
| server_obj = self._dbg_server_factory.get_debugger_server(DebuggerCache(), context) | |||||
| server_obj.start() | |||||
| server_obj.stop() | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.env_single | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @mock.patch('mindinsight.debugger.debugger_services.debugger_offline_server.import_module') | |||||
| def test_get_dbg_offline_server(self, mock_import): | |||||
| """Get debugger offline server""" | |||||
| mock_import.return_value = mock_dbg_services | |||||
| context = DebuggerServerContext(dbg_mode='offline', dbg_dir=self._dbg_dir) | |||||
| server_obj = self._dbg_server_factory.get_debugger_server(DebuggerCache(), context) | |||||
| server_obj.start() | |||||
| server_obj.stop() | |||||
| @@ -0,0 +1,15 @@ | |||||
| { | |||||
| "common_dump_settings": { | |||||
| "dump_mode": 0, | |||||
| "path": "/absolute_path", | |||||
| "net_name": "Lenet", | |||||
| "iteration": 0, | |||||
| "input_output": 0, | |||||
| "kernels": ["Default/Conv-op12"], | |||||
| "support_device": [0,1,2,3,4,5,6,7] | |||||
| }, | |||||
| "async_dump_settings": { | |||||
| "enable": true, | |||||
| "op_debug_mode": 0 | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,23 @@ | |||||
| { | |||||
| "version": "1.0", | |||||
| "server_count": "1", | |||||
| "server_list": [ | |||||
| { | |||||
| "server_id": "0.0.0.0", | |||||
| "device": [ | |||||
| { | |||||
| "device_id": "0", | |||||
| "device_ip": "0.0.0.1", | |||||
| "rank_id": "0" | |||||
| }, | |||||
| { | |||||
| "device_id": "1", | |||||
| "device_ip": "0.0.0.2", | |||||
| "rank_id": "1" | |||||
| } | |||||
| ], | |||||
| "host_nic_ip": "reserve" | |||||
| } | |||||
| ], | |||||
| "status": "completed" | |||||
| } | |||||
| @@ -0,0 +1,15 @@ | |||||
| { | |||||
| "common_dump_settings": { | |||||
| "dump_mode": 0, | |||||
| "path": "/absolute_path", | |||||
| "net_name": "Lenet", | |||||
| "iteration": 0, | |||||
| "input_output": 0, | |||||
| "kernels": ["Default/Conv-op12"], | |||||
| "support_device": [0,1,2,3,4,5,6,7] | |||||
| }, | |||||
| "e2e_dump_settings": { | |||||
| "enable": true, | |||||
| "trans_flag": false | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,23 @@ | |||||
| { | |||||
| "version": "1.0", | |||||
| "server_count": "1", | |||||
| "server_list": [ | |||||
| { | |||||
| "server_id": "0.0.0.0", | |||||
| "device": [ | |||||
| { | |||||
| "device_id": "0", | |||||
| "device_ip": "0.0.0.1", | |||||
| "rank_id": "0" | |||||
| }, | |||||
| { | |||||
| "device_id": "1", | |||||
| "device_ip": "0.0.0.2", | |||||
| "rank_id": "1" | |||||
| } | |||||
| ], | |||||
| "host_nic_ip": "reserve" | |||||
| } | |||||
| ], | |||||
| "status": "completed" | |||||
| } | |||||
| @@ -0,0 +1,15 @@ | |||||
| { | |||||
| "common_dump_settings": { | |||||
| "dump_mode": 0, | |||||
| "path": "/absolute_path", | |||||
| "net_name": "Lenet", | |||||
| "iteration": 0, | |||||
| "input_output": 0, | |||||
| "kernels": ["Default/Conv-op12"], | |||||
| "support_device": [0,1,2,3,4,5,6,7] | |||||
| }, | |||||
| "e2e_dump_settings": { | |||||
| "enable": true, | |||||
| "trans_flag": false | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,21 @@ | |||||
| { | |||||
| "device_target": "Ascend", | |||||
| "server_list": [ | |||||
| { | |||||
| "server_id": "0.0.0.0", | |||||
| "device": [ | |||||
| { | |||||
| "device_id": "0", | |||||
| "device_ip": "0.0.0.1", | |||||
| "rank_id": "0" | |||||
| }, | |||||
| { | |||||
| "device_id": "1", | |||||
| "device_ip": "0.0.0.2", | |||||
| "rank_id": "1" | |||||
| } | |||||
| ], | |||||
| "host_nic_ip": "reserve" | |||||
| } | |||||
| ] | |||||
| } | |||||
| @@ -1 +1 @@ | |||||
| {"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false, "debugger_version": {"ms": "1.2.0"}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []} | |||||
| {"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": ["graph_0", "graph_1"]}], "watch_points": []} | |||||
| @@ -1 +1 @@ | |||||
| {"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {"ms": "1.2.0"}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []} | |||||
| {"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": ["graph_0", "graph_1"]}], "watch_points": []} | |||||
| @@ -1,29 +1 @@ | |||||
| { | |||||
| "tensor_value": { | |||||
| "full_name": "Default/TransData-op99:0", | |||||
| "step": 1, | |||||
| "dtype": "DT_FLOAT32", | |||||
| "shape": [ | |||||
| 2, | |||||
| 3 | |||||
| ], | |||||
| "has_prev_step": false, | |||||
| "statistics": { | |||||
| "overall_max": 6.0, | |||||
| "overall_min": 1.0, | |||||
| "overall_avg": 3.5, | |||||
| "overall_count": 6, | |||||
| "overall_nan_count": 0, | |||||
| "overall_neg_inf_count": 0, | |||||
| "overall_pos_inf_count": 0, | |||||
| "overall_zero_count": 0.0, | |||||
| "overall_neg_zero_count": 0.0, | |||||
| "overall_pos_zero_count": 6.0 | |||||
| }, | |||||
| "value": [ | |||||
| 5.0, | |||||
| 6.0 | |||||
| ], | |||||
| "name": "Default/TransData-op99:0" | |||||
| } | |||||
| } | |||||
| {"tensor_value": {"full_name": "Default/TransData-op99:0", "step": 1, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "value": [5.0, 6.0], "statistics": {"overall_max": 6.0, "overall_min": 1.0, "overall_avg": 3.5, "overall_count": 6, "overall_nan_count": 0, "overall_neg_inf_count": 0, "overall_pos_inf_count": 0, "overall_zero_count": 0.0, "overall_neg_zero_count": 0.0, "overall_pos_zero_count": 6.0}, "name": "Default/TransData-op99:0"}} | |||||
| @@ -1 +1 @@ | |||||
| {"metadata": {"state": "mismatch", "step": 0, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {"ms": "1.0.0"}}, "graph": {}, "watch_points": []} | |||||
| {"metadata": {"state": "mismatch", "step": 0, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": []}], "watch_points": []} | |||||
| @@ -0,0 +1,149 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Test DataLoader of offline debugger.""" | |||||
| import os | |||||
| import shutil | |||||
| import pytest | |||||
| from mindinsight.debugger.stream_cache.data_loader import DataLoader | |||||
| from tests.st.func.debugger.conftest import GRAPH_PROTO_FILE | |||||
| from tests.st.func.debugger.utils import build_dump_file_structure | |||||
| from tests.utils.tools import compare_result_with_file, compare_result_with_binary_file | |||||
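| # build_dump_file_structure() below (added to tests/st/func/debugger/utils.py in | |||||
| # this change) creates the temporary Ascend/GPU dump trees these tests read from. | |||||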
| class TestDataLoader: | |||||
| """Test DataLoader.""" | |||||
| @classmethod | |||||
| def setup_class(cls): | |||||
| """Init TestDataLoader for DataLoader unittest.""" | |||||
| cls.debugger_tmp_dir, cls.dump_files_dir = build_dump_file_structure() | |||||
| cls.expected_results_dir = os.path.join(os.path.dirname(__file__), | |||||
| 'expect_results/offline_debugger') | |||||
| cls.dump_files_dir_ascend = os.path.join(cls.dump_files_dir, | |||||
| 'Ascend/sync') | |||||
| cls.data_loader_ascend = DataLoader(cls.dump_files_dir_ascend) | |||||
| cls.data_loader_ascend.initialize() | |||||
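| # The GPU sync and Ascend async dump roots below are initialized the same way. | |||||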
| cls.dump_files_dir_gpu = os.path.join(cls.dump_files_dir, | |||||
| 'GPU/sync') | |||||
| cls.data_loader_gpu = DataLoader(cls.dump_files_dir_gpu) | |||||
| cls.data_loader_gpu.initialize() | |||||
| cls.dump_files_dir_ascend_async = os.path.join(cls.dump_files_dir, | |||||
| 'Ascend/async') | |||||
| cls.data_loader_ascend_async = DataLoader(cls.dump_files_dir_ascend_async) | |||||
| cls.data_loader_ascend_async.initialize() | |||||
| @classmethod | |||||
| def teardown_class(cls): | |||||
| """Run after test this class.""" | |||||
| if os.path.exists(cls.debugger_tmp_dir): | |||||
| shutil.rmtree(cls.debugger_tmp_dir) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.env_single | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| def test_load_graphs_ascend(self): | |||||
| """Test load_graphs function of offline-debugger.""" | |||||
| res = self.data_loader_ascend.load_graphs() | |||||
| expected_result0 = GRAPH_PROTO_FILE | |||||
| res0 = res[0]['graph_protos'][0].SerializeToString() | |||||
| compare_result_with_binary_file(res0, expected_result0) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.env_single | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| def test_load_device_info_ascend(self): | |||||
| """Test load_device_info of ascend chip for offline-debugger.""" | |||||
| res = self.data_loader_ascend.load_device_info() | |||||
| expected_result = os.path.join(self.expected_results_dir, 'load_device_info_ascend.json') | |||||
| compare_result_with_file(res, expected_result) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.env_single | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| def test_load_step_num_ascend(self): | |||||
| """Test load_step_num of ascend chip for offline-debugger.""" | |||||
| res = self.data_loader_ascend.load_step_number() | |||||
| expected_result = {"0": 4, "1": 4} | |||||
| assert res == expected_result | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.env_single | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| def test_get_net_name_ascend(self): | |||||
| """Test get_net_name of ascend chip for offline-debugger.""" | |||||
| res = self.data_loader_ascend.get_net_name() | |||||
| assert res == 'Lenet' | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.env_single | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| def test_get_sync_flag(self): | |||||
| """Test get_sync_flag of ascend chip for offline-debugger.""" | |||||
| res = self.data_loader_ascend.get_sync_flag() | |||||
| assert res | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.env_single | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| def test_load_graphs_gpu(self): | |||||
| """Test load_graphs function of offline-debugger.""" | |||||
| res = self.data_loader_gpu.load_graphs() | |||||
| expected_result0 = GRAPH_PROTO_FILE | |||||
| res0 = res[0]['graph_protos'][0].SerializeToString() | |||||
| compare_result_with_binary_file(res0, expected_result0) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.env_single | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| def test_load_step_num_gpu(self): | |||||
| """Test load_step_num of ascend chip for offline-debugger.""" | |||||
| res = self.data_loader_gpu.load_step_number() | |||||
| expected_result = {"0": 3, "1": 3} | |||||
| assert res == expected_result | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.env_single | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| def test_load_step_num_ascend_async(self): | |||||
| """Test load_step_num of ascend chip for offline-debugger.""" | |||||
| res = self.data_loader_ascend_async.load_step_number() | |||||
| expected_result = {"0": 3, "1": 3} | |||||
| assert res == expected_result | |||||
| @@ -84,7 +84,7 @@ class TestAscendDebugger: | |||||
| def test_get_conditions(self, app_client): | def test_get_conditions(self, app_client): | ||||
| """Test get conditions for ascend.""" | """Test get conditions for ascend.""" | ||||
| url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/condition-collections' | |||||
| url = '/v1/mindinsight/debugger/sessions/0/condition-collections' | |||||
| body_data = {} | body_data = {} | ||||
| expect_file = 'get_conditions_for_ascend.json' | expect_file = 'get_conditions_for_ascend.json' | ||||
| with self._debugger_client.get_thread_instance(): | with self._debugger_client.get_thread_instance(): | ||||
| @@ -191,7 +191,7 @@ class TestAscendDebugger: | |||||
| check_state(app_client) | check_state(app_client) | ||||
| # prepare tensor value | # prepare tensor value | ||||
| url = 'tensor-history' | url = 'tensor-history' | ||||
| body_data = {'name': node_name} | |||||
| body_data = {'name': node_name, 'rank_id': 0} | |||||
| expect_file = 'retrieve_empty_tensor_history.json' | expect_file = 'retrieve_empty_tensor_history.json' | ||||
| send_and_compare_result(app_client, url, body_data, expect_file) | send_and_compare_result(app_client, url, body_data, expect_file) | ||||
| # check full tensor history from poll data | # check full tensor history from poll data | ||||
| @@ -229,7 +229,7 @@ class TestAscendDebugger: | |||||
| get_request_result(app_client, url, body_data) | get_request_result(app_client, url, body_data) | ||||
| check_state(app_client) | check_state(app_client) | ||||
| get_request_result( | get_request_result( | ||||
| app_client=app_client, url='tensor-history', body_data={'name': node_name}) | |||||
| app_client=app_client, url='tensor-history', body_data={'name': node_name, 'rank_id': 0}) | |||||
| res = get_request_result( | res = get_request_result( | ||||
| app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get') | app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get') | ||||
| assert res.get('receive_tensor', {}).get('node_name') == node_name | assert res.get('receive_tensor', {}).get('node_name') == node_name | ||||
| @@ -239,30 +239,12 @@ class TestAscendDebugger: | |||||
| 'name': node_name + ':0', | 'name': node_name + ':0', | ||||
| 'detail': 'data', | 'detail': 'data', | ||||
| 'shape': quote('[:, :]'), | 'shape': quote('[:, :]'), | ||||
| 'tolerance': 1 | |||||
| } | |||||
| 'tolerance': 1, | |||||
| 'rank_id': 0} | |||||
| expect_file = 'compare_tensors.json' | expect_file = 'compare_tensors.json' | ||||
| send_and_compare_result(app_client, url, body_data, expect_file, method='get') | send_and_compare_result(app_client, url, body_data, expect_file, method='get') | ||||
| send_terminate_cmd(app_client) | send_terminate_cmd(app_client) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.env_single | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.parametrize("body_data, expect_file", [ | |||||
| ({'ascend': True}, 'retrieve_node_by_bfs_ascend.json'), | |||||
| ({'name': 'Default/args0', 'ascend': False}, 'retrieve_node_by_bfs.json') | |||||
| ]) | |||||
| def test_retrieve_bfs_node(self, app_client, body_data, expect_file): | |||||
| """Test retrieve bfs node.""" | |||||
| with self._debugger_client.get_thread_instance(): | |||||
| check_state(app_client) | |||||
| # prepare tensor values | |||||
| url = 'retrieve_node_by_bfs' | |||||
| send_and_compare_result(app_client, url, body_data, expect_file, method='get') | |||||
| send_terminate_cmd(app_client) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.env_single | @pytest.mark.env_single | ||||
| @@ -441,7 +423,7 @@ class TestGPUDebugger: | |||||
| def test_get_conditions(self, app_client): | def test_get_conditions(self, app_client): | ||||
| """Test get conditions for gpu.""" | """Test get conditions for gpu.""" | ||||
| url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/condition-collections' | |||||
| url = '/v1/mindinsight/debugger/sessions/0/condition-collections' | |||||
| body_data = {} | body_data = {} | ||||
| expect_file = 'get_conditions_for_gpu.json' | expect_file = 'get_conditions_for_gpu.json' | ||||
| with self._debugger_client.get_thread_instance(): | with self._debugger_client.get_thread_instance(): | ||||
| @@ -16,8 +16,10 @@ | |||||
| import json | import json | ||||
| import os | import os | ||||
| import time | import time | ||||
| from tests.st.func.debugger.conftest import DEBUGGER_EXPECTED_RESULTS, DEBUGGER_BASE_URL | |||||
| import shutil | |||||
| import tempfile | |||||
| from mindinsight.debugger.proto import ms_graph_pb2 | |||||
| from tests.st.func.debugger.conftest import DEBUGGER_EXPECTED_RESULTS, DEBUGGER_BASE_URL, GRAPH_PROTO_FILE | |||||
| from tests.utils.tools import compare_result_with_file, get_url | from tests.utils.tools import compare_result_with_file, get_url | ||||
| @@ -74,10 +76,57 @@ def send_and_save_result(app_client, url, body_data, file_path, method='post'): | |||||
| def delete_random_items(res): | def delete_random_items(res): | ||||
| """delete the random items in metadata.""" | """delete the random items in metadata.""" | ||||
| if isinstance(res, dict) and res.get('metadata'): | |||||
| if res['metadata'].get('ip'): | |||||
| res['metadata'].pop('ip') | |||||
| if res['metadata'].get('pos'): | |||||
| res['metadata'].pop('pos') | |||||
| if res['metadata'].get('debugger_version') and res['metadata']['debugger_version'].get('mi'): | |||||
| res['metadata']['debugger_version'].pop('mi') | |||||
| if isinstance(res, dict): | |||||
| if res.get('metadata'): | |||||
| if res['metadata'].get('ip'): | |||||
| res['metadata'].pop('ip') | |||||
| if res['metadata'].get('pos'): | |||||
| res['metadata'].pop('pos') | |||||
| if res['metadata'].get('debugger_version') and res['metadata']['debugger_version'].get('mi'): | |||||
| res['metadata']['debugger_version'].pop('mi') | |||||
| res['metadata']['debugger_version'].pop('ms') | |||||
| if res.get('devices'): | |||||
| for device in res.get('devices'): | |||||
| if device.get('server_ip'): | |||||
| device.pop('server_ip') | |||||
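| # The 'ms' version and per-device 'server_ip' are environment-dependent, so they | |||||
| # are stripped before comparing responses with the stored expected files. | |||||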
| def build_dump_file_structure(): | |||||
| """Build the dump file structure.""" | |||||
| async_file_structure = { | |||||
| "Ascend/async/device_0/Lenet_graph_1/1": 3, | |||||
| "Ascend/async/device_1/Lenet_graph_1/1": 3 | |||||
| } | |||||
| sync_file_structure = { | |||||
| "Ascend/sync/Lenet/device_0": 4, | |||||
| "Ascend/sync/Lenet/device_1": 4, | |||||
| "GPU/sync/Lenet/device_0": 3, | |||||
| "GPU/sync/Lenet/device_1": 3 | |||||
| } | |||||
| debugger_tmp_dir = tempfile.mkdtemp(suffix='debugger_tmp') | |||||
| dump_files_dir = os.path.join(debugger_tmp_dir, 'dump_files') | |||||
| shutil.copytree(os.path.join(os.path.dirname(__file__), 'dump_files'), dump_files_dir) | |||||
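| # Async dumps name step folders with bare step numbers; sync dumps use | |||||
| # 'iteration_<step>' (see the two loops below). | |||||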
| for sub_dir, steps in async_file_structure.items(): | |||||
| for step in range(1, steps + 1): | |||||
| os.makedirs(os.path.join(dump_files_dir, sub_dir, str(step)), exist_ok=True) | |||||
| for sub_dir, steps in sync_file_structure.items(): | |||||
| for step in range(1, steps + 1): | |||||
| os.makedirs(os.path.join(dump_files_dir, sub_dir, 'iteration_' + str(step)), exist_ok=True) | |||||
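| # Each sync dump root also carries a 'graphs' directory holding the serialized | |||||
| # model graph proto that DataLoader.load_graphs() reads back. | |||||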
| graph_dir_path = os.path.join(dump_files_dir, sub_dir, 'graphs') | |||||
| os.makedirs(graph_dir_path, exist_ok=True) | |||||
| graph_path = os.path.join(graph_dir_path, 'ms_output_trace_code_graph_0.pb') | |||||
| with open(GRAPH_PROTO_FILE, 'rb') as file_handler: | |||||
| content = file_handler.read() | |||||
| model = ms_graph_pb2.ModelProto() | |||||
| model.graph.ParseFromString(content) | |||||
| model_str = model.SerializeToString() | |||||
| with open(graph_path, 'wb') as file_handler: | |||||
| file_handler.write(model_str) | |||||
| return debugger_tmp_dir, dump_files_dir | |||||
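| # Minimal usage sketch of this helper (mirrors the DataLoader tests above): | |||||
| # debugger_tmp_dir, dump_files_dir = build_dump_file_structure() | |||||
| # loader = DataLoader(os.path.join(dump_files_dir, 'Ascend/sync')) | |||||
| # loader.initialize() | |||||
| # assert loader.load_step_number() == {'0': 4, '1': 4} | |||||
| # shutil.rmtree(debugger_tmp_dir)  # clean up the temporary tree | |||||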
| @@ -1 +1 @@ | |||||
| {"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "watch_points": []} | |||||
| {"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "devices": [{"rank_id": 0, "server_ip": "", "device_id": "", "graph_names": []}], "watch_points": []} | |||||
| @@ -1,5 +1,80 @@ | |||||
| [ | [ | ||||
| {"watchCondition": {"condition": "tensor_too_small", "value": 1.0, "params": [{"name": "abs_mean_lt", "disabled": true}, {"name": "max_lt", "value": 1.0}, {"name": "min_lt", "disabled": true}, {"name": "mean_lt", "disabled": true}]}, "id": 1, "watch_nodes_num": 0}, | |||||
| {"watchCondition": {"condition": "tensor_too_small", "value": 1.0, "params": [{"name": "abs_mean_lt", "disabled": true}, {"name": "max_lt", "disabled": true}, {"name": "min_lt", "value": 1.0}, {"name": "mean_lt", "disabled": true}]}, "id": 2, "watch_nodes_num": 172}, | |||||
| {"watchCondition": {"condition": "tensor_too_large", "value": 1.0, "params": [{"name": "abs_mean_gt", "disabled": true}, {"name": "max_gt", "value": 1.0}, {"name": "min_gt", "disabled": true}, {"name": "mean_gt", "disabled": true}]}, "id": 3, "watch_nodes_num": 1} | |||||
| { | |||||
| "watchCondition": { | |||||
| "condition": "tensor_too_small", | |||||
| "value": 1.0, | |||||
| "params": [ | |||||
| { | |||||
| "name": "abs_mean_lt", | |||||
| "disabled": true | |||||
| }, | |||||
| { | |||||
| "name": "max_lt", | |||||
| "value": 1.0 | |||||
| }, | |||||
| { | |||||
| "name": "min_lt", | |||||
| "disabled": true | |||||
| }, | |||||
| { | |||||
| "name": "mean_lt", | |||||
| "disabled": true | |||||
| } | |||||
| ] | |||||
| }, | |||||
| "id": 1, | |||||
| "watch_nodes_num": 0 | |||||
| }, | |||||
| { | |||||
| "watchCondition": { | |||||
| "condition": "tensor_too_small", | |||||
| "value": 1.0, | |||||
| "params": [ | |||||
| { | |||||
| "name": "abs_mean_lt", | |||||
| "disabled": true | |||||
| }, | |||||
| { | |||||
| "name": "max_lt", | |||||
| "disabled": true | |||||
| }, | |||||
| { | |||||
| "name": "min_lt", | |||||
| "value": 1.0 | |||||
| }, | |||||
| { | |||||
| "name": "mean_lt", | |||||
| "disabled": true | |||||
| } | |||||
| ] | |||||
| }, | |||||
| "id": 2, | |||||
| "watch_nodes_num": 142 | |||||
| }, | |||||
| { | |||||
| "watchCondition": { | |||||
| "condition": "tensor_too_large", | |||||
| "value": 1.0, | |||||
| "params": [ | |||||
| { | |||||
| "name": "abs_mean_gt", | |||||
| "disabled": true | |||||
| }, | |||||
| { | |||||
| "name": "max_gt", | |||||
| "value": 1.0 | |||||
| }, | |||||
| { | |||||
| "name": "min_gt", | |||||
| "disabled": true | |||||
| }, | |||||
| { | |||||
| "name": "mean_gt", | |||||
| "disabled": true | |||||
| } | |||||
| ] | |||||
| }, | |||||
| "id": 3, | |||||
| "watch_nodes_num": 1 | |||||
| } | |||||
| ] | ] | ||||
| @@ -111,20 +111,6 @@ class TestGraphHandler: | |||||
| node_name = self.graph_handler.get_node_name_by_full_name(full_name, 'kernel_graph_0') | node_name = self.graph_handler.get_node_name_by_full_name(full_name, 'kernel_graph_0') | ||||
| assert node_name == expect_node_name | assert node_name == expect_node_name | ||||
| @pytest.mark.parametrize("node_name, ascend, expect_next", [ | |||||
| (None, True, | |||||
| "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0"), | |||||
| (None, False, None), | |||||
| ("Default/tuple_getitem[10]_0/tuple_getitem-op206", True, | |||||
| "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89"), | |||||
| ("Default/tuple_getitem[10]_0/tuple_getitem-op206", False, | |||||
| "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op205") | |||||
| ]) | |||||
| def test_get_node_by_bfs_order(self, node_name, ascend, expect_next): | |||||
| """Test get node by BFS order.""" | |||||
| next_node = self.graph_handler.get_node_by_bfs_order(node_name, ascend) | |||||
| assert next_node == expect_next | |||||
| @pytest.mark.parametrize("tensor_name, expect_file", [ | @pytest.mark.parametrize("tensor_name, expect_file", [ | ||||
| ("Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0:0", "get_tensor_graph-0.json"), | ("Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0:0", "get_tensor_graph-0.json"), | ||||
| ("Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89:1", "get_tensor_graph-1.json"), | ("Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89:1", "get_tensor_graph-1.json"), | ||||
| @@ -40,7 +40,7 @@ class TestTensorHandler: | |||||
| def test_get_tensor_value_by_name_none(self): | def test_get_tensor_value_by_name_none(self): | ||||
| """Test get_tensor_value_by_name.""" | """Test get_tensor_value_by_name.""" | ||||
| res = self.tensor_handler.get_valid_tensor_by_name('tensor_name', True) | |||||
| res = self.tensor_handler.get_valid_tensor_by_name('tensor_name', step=0, prev=True) | |||||
| assert res is None | assert res is None | ||||
| @mock.patch.object(log, "error") | @mock.patch.object(log, "error") | ||||
| @@ -49,5 +49,5 @@ class TestTensorHandler: | |||||
| """Test get_tensors_diff.""" | """Test get_tensors_diff.""" | ||||
| mock_error.return_value = None | mock_error.return_value = None | ||||
| with pytest.raises(DebuggerParamValueError) as ex: | with pytest.raises(DebuggerParamValueError) as ex: | ||||
| self.tensor_handler.get_tensors_diff(tensor_name, {1, 1}) | |||||
| self.tensor_handler.get_tensors_diff(tensor_name, {1, 1}, step=0) | |||||
| assert f"Get current step and previous step for this tensor name {tensor_name} failed." in str(ex.value) | assert f"Get current step and previous step for this tensor name {tensor_name} failed." in str(ex.value) | ||||
| @@ -30,6 +30,7 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue | |||||
| DebuggerParamTypeError | DebuggerParamTypeError | ||||
| from mindinsight.debugger.common.log import LOGGER as log | from mindinsight.debugger.common.log import LOGGER as log | ||||
| from mindinsight.debugger.stream_cache.watchpoint import Watchpoint | from mindinsight.debugger.stream_cache.watchpoint import Watchpoint | ||||
| from mindinsight.debugger.stream_handler import MultiCardGraphHandler | |||||
| from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHandler, \ | from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHandler, \ | ||||
| WatchpointHitHandler, validate_watch_condition, validate_watch_condition_params | WatchpointHitHandler, validate_watch_condition, validate_watch_condition_params | ||||
| from tests.ut.debugger.configurations import init_graph_handler, mock_tensor_proto, \ | from tests.ut.debugger.configurations import init_graph_handler, mock_tensor_proto, \ | ||||
| @@ -48,7 +49,9 @@ class TestWatchpointHandler: | |||||
| '../expected_results/watchpoint') | '../expected_results/watchpoint') | ||||
| cls.graph_results_dir = os.path.join(os.path.dirname(__file__), | cls.graph_results_dir = os.path.join(os.path.dirname(__file__), | ||||
| '../expected_results/graph') | '../expected_results/graph') | ||||
| cls.multi_graph_stream = MultiCardGraphHandler() | |||||
| cls.graph_stream = init_graph_handler() | cls.graph_stream = init_graph_handler() | ||||
| cls.multi_graph_stream.register_graph_handler(0, cls.graph_stream) | |||||
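| # Watchpoint tests now route graphs through MultiCardGraphHandler, with the | |||||
| # single-card graph handler registered under rank_id 0. | |||||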
| cls.conditionmgr = None | cls.conditionmgr = None | ||||
| cls.handler = None | cls.handler = None | ||||
| @@ -69,7 +72,7 @@ class TestWatchpointHandler: | |||||
| ] | ] | ||||
| for watch_condition, watch_nodes, watch_point_id, expect_new_id in watchpoints: | for watch_condition, watch_nodes, watch_point_id, expect_new_id in watchpoints: | ||||
| watch_nodes = get_node_basic_infos(watch_nodes) | watch_nodes = get_node_basic_infos(watch_nodes) | ||||
| watch_point_id = self.handler.create_watchpoint(self.conditionmgr, watch_condition, watch_nodes, | |||||
| watch_point_id = self.handler.create_watchpoint(self.conditionmgr, watch_condition, {0: watch_nodes}, | |||||
| watch_point_id) | watch_point_id) | ||||
| assert watch_point_id == expect_new_id | assert watch_point_id == expect_new_id | ||||
| @@ -105,7 +108,7 @@ class TestWatchpointHandler: | |||||
| file_path = os.path.join(self.results_dir, result_file) | file_path = os.path.join(self.results_dir, result_file) | ||||
| with open(file_path, 'r') as file_handler: | with open(file_path, 'r') as file_handler: | ||||
| contents = json.load(file_handler) | contents = json.load(file_handler) | ||||
| protos = self.handler.get_pending_commands(self.graph_stream) | |||||
| protos = self.handler.get_pending_commands(self.multi_graph_stream) | |||||
| for proto in protos: | for proto in protos: | ||||
| msg_dict = json_format.MessageToDict(proto) | msg_dict = json_format.MessageToDict(proto) | ||||
| msg_dict['watch_nodes_num'] = len(msg_dict.pop('watchNodes', [])) | msg_dict['watch_nodes_num'] = len(msg_dict.pop('watchNodes', [])) | ||||
| @@ -48,7 +48,8 @@ class TestTrainingControlOperator: | |||||
| """Test validate leaf name.""" | """Test validate leaf name.""" | ||||
| args[0].return_value = 'name_scope' | args[0].return_value = 'name_scope' | ||||
| with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'): | with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'): | ||||
| self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name') | |||||
| self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name', | |||||
| rank_id=0) | |||||
| @pytest.mark.parametrize('mode, cur_state, state', [ | @pytest.mark.parametrize('mode, cur_state, state', [ | ||||
| ('continue', 'waiting', 'sending'), | ('continue', 'waiting', 'sending'), | ||||
| @@ -64,3 +65,12 @@ class TestTrainingControlOperator: | |||||
| """Test construct run event.""" | """Test construct run event.""" | ||||
| res = self._server._construct_run_event({'level': 'node'}) | res = self._server._construct_run_event({'level': 'node'}) | ||||
| assert res.run_cmd == RunCMD(run_level='node', node_name='') | assert res.run_cmd == RunCMD(run_level='node', node_name='') | ||||
| @pytest.mark.parametrize('mode, state', [ | |||||
| ('reset', 'waiting')]) | |||||
| def test_control_reset_step(self, mode, state): | |||||
| """Test control request, in 'reset' mode.""" | |||||
| with mock.patch.object(MetadataHandler, 'max_step_num', 10), \ | |||||
| mock.patch.object(MetadataHandler, 'debugger_type', 'offline'): | |||||
| res = self._server.control(mode=mode, params={'steps': 9}) | |||||
| assert res == {'metadata': {'enable_recheck': False, 'state': state, 'step': 9}} | |||||
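| # The mocked MetadataHandler stands in for an offline session capped at 10 steps, | |||||
| # so the 'reset' control request can target step 9 without a real dump directory. | |||||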
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -26,7 +26,7 @@ import numpy as np | |||||
| from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr | from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr | ||||
| from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus | from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus | ||||
| from mindinsight.debugger.debugger_cache import DebuggerCache | from mindinsight.debugger.debugger_cache import DebuggerCache | ||||
| from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer | |||||
| from mindinsight.debugger.debugger_services.debugger_grpc_server import DebuggerGrpcServer | |||||
| from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply, SetCMD, Chunk, WatchpointHit | from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply, SetCMD, Chunk, WatchpointHit | ||||
| from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType | from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType | ||||
| from mindinsight.debugger.stream_handler import WatchpointHitHandler, GraphHandler, \ | from mindinsight.debugger.stream_handler import WatchpointHitHandler, GraphHandler, \ | ||||
| @@ -30,11 +30,11 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue | |||||
| DebuggerCompareTensorError, DebuggerCreateWatchPointError, DebuggerDeleteWatchPointError | DebuggerCompareTensorError, DebuggerCreateWatchPointError, DebuggerDeleteWatchPointError | ||||
| from mindinsight.debugger.common.utils import Streams | from mindinsight.debugger.common.utils import Streams | ||||
| from mindinsight.debugger.debugger_cache import DebuggerCache | from mindinsight.debugger.debugger_cache import DebuggerCache | ||||
| from mindinsight.debugger.debugger_server import DebuggerServer | |||||
| from mindinsight.debugger.debugger_server import grpc_server_base | |||||
| from mindinsight.debugger.stream_operator import watchpoint_operator | |||||
| from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext | |||||
| from mindinsight.debugger.debugger_session import DebuggerSession as DebuggerServer | |||||
| from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \ | from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \ | ||||
| TensorHandler | TensorHandler | ||||
| from mindinsight.debugger.stream_operator import watchpoint_operator | |||||
| from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history | from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history | ||||
| @@ -48,12 +48,12 @@ class TestDebuggerServer: | |||||
| def setup_method(self): | def setup_method(self): | ||||
| """Prepare debugger server object.""" | """Prepare debugger server object.""" | ||||
| self._server = DebuggerServer() | |||||
| context = DebuggerServerContext(dbg_mode='online') | |||||
| self._server = DebuggerServer(context) | |||||
| @mock.patch.object(signal, 'signal') | @mock.patch.object(signal, 'signal') | ||||
| @mock.patch.object(Thread, 'join') | @mock.patch.object(Thread, 'join') | ||||
| @mock.patch.object(Thread, 'start') | @mock.patch.object(Thread, 'start') | ||||
| @mock.patch.object(grpc_server_base, 'add_EventListenerServicer_to_server') | |||||
| @mock.patch.object(grpc, 'server') | @mock.patch.object(grpc, 'server') | ||||
| def test_stop_server(self, *args): | def test_stop_server(self, *args): | ||||
| """Test stop debugger server.""" | """Test stop debugger server.""" | ||||
| @@ -62,7 +62,6 @@ class TestDebuggerServer: | |||||
| self._server.start() | self._server.start() | ||||
| self._server._stop_handler(MagicMock(), MagicMock()) | self._server._stop_handler(MagicMock(), MagicMock()) | ||||
| assert self._server.back_server is not None | assert self._server.back_server is not None | ||||
| assert self._server.grpc_server_manager == mock_grpc_server_manager | |||||
| @mock.patch.object(DebuggerCache, 'get_data') | @mock.patch.object(DebuggerCache, 'get_data') | ||||
| def test_poll_data(self, *args): | def test_poll_data(self, *args): | ||||
| @@ -186,7 +185,6 @@ class TestDebuggerServer: | |||||
| self._server.create_watchpoint({'watch_condition': {'id': 'inf'}}) | self._server.create_watchpoint({'watch_condition': {'id': 'inf'}}) | ||||
| @mock.patch.object(MetadataHandler, 'state', 'waiting') | @mock.patch.object(MetadataHandler, 'state', 'waiting') | ||||
| @mock.patch.object(MetadataHandler, 'backend', 'GPU') | |||||
| @mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=MagicMock()) | @mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=MagicMock()) | ||||
| @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope') | @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope') | ||||
| @mock.patch.object(watchpoint_operator, 'get_basic_node_info', return_value=MagicMock()) | @mock.patch.object(watchpoint_operator, 'get_basic_node_info', return_value=MagicMock()) | ||||
| @@ -194,6 +192,7 @@ class TestDebuggerServer: | |||||
| def test_create_watchpoint(self, *args): | def test_create_watchpoint(self, *args): | ||||
| """Test create watchpoint.""" | """Test create watchpoint.""" | ||||
| args[0].return_value = 1 | args[0].return_value = 1 | ||||
| self._server.cache_store.get_stream_handler((Streams.METADATA)).backend = 'GPU' | |||||
| res = self._server.create_watchpoint( | res = self._server.create_watchpoint( | ||||
| {'watch_condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]}, | {'watch_condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]}, | ||||
| 'watch_nodes': ['watch_node_name']}) | 'watch_nodes': ['watch_node_name']}) | ||||
| @@ -68,6 +68,13 @@ def compare_result_with_file(result, expected_file_path): | |||||
| assert result == expected_results | assert result == expected_results | ||||
| def compare_result_with_binary_file(result, expected_file_path): | |||||
| """Compare result with binary file which contain the expected results.""" | |||||
| with open(expected_file_path, 'rb') as file: | |||||
| expected_results = file.read() | |||||
| assert result == expected_results | |||||
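| # Used by the offline-debugger DataLoader tests to compare serialized graph protos. | |||||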
| def deal_float_for_dict(res: dict, expected_res: dict, decimal_num): | def deal_float_for_dict(res: dict, expected_res: dict, decimal_num): | ||||
| """ | """ | ||||
| Deal float rounded to specified decimals in dict. | Deal float rounded to specified decimals in dict. | ||||