| @@ -1,61 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Conditionmgr restful api.""" | |||
| import json | |||
| from flask import Blueprint, request | |||
| from mindinsight.conf import settings | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from mindinsight.utils.exceptions import ParamMissError | |||
| from mindinsight.backend.debugger.debugger_api import BACKEND_SERVER, _wrap_reply | |||
| BLUEPRINT = Blueprint("conditionmgr", __name__, | |||
| url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX) | |||
| @BLUEPRINT.route("/conditionmgr/train-jobs/<train_id>/condition-collections", methods=["GET"]) | |||
| def get_condition_collections(train_id): | |||
| """get condition collections""" | |||
| reply = _wrap_reply(BACKEND_SERVER.get_condition_collections, train_id) | |||
| return reply | |||
| @BLUEPRINT.route("/conditionmgr/train-jobs/<train_id>/set-recommended-watch-points", methods=["POST"]) | |||
| def set_recommended_watch_points(train_id): | |||
| """set recommended watch points.""" | |||
| body = request.stream.read() | |||
| try: | |||
| body = json.loads(body if body else "{}") | |||
except json.JSONDecodeError as err:
raise ParamValueError("JSON data parse failed.") from err
| request_body = body.get('requestBody') | |||
| if request_body is None: | |||
| raise ParamMissError('requestBody') | |||
| set_recommended = request_body.get('set_recommended') | |||
| reply = _wrap_reply(BACKEND_SERVER.set_recommended_watch_points, set_recommended, train_id) | |||
| return reply | |||
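# Minimal client sketch for the endpoint above, assuming a local MindInsight
# instance at localhost:8080; the train id and flag value are hypothetical.
import requests

payload = {"requestBody": {"set_recommended": True}}
resp = requests.post(
    "http://localhost:8080/v1/mindinsight/conditionmgr/train-jobs/xxxx/"
    "set-recommended-watch-points",
    json=payload,
)
print(resp.json())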
| def init_module(app): | |||
| """ | |||
| Init module entry. | |||
| Args: | |||
| app (Flask): The application obj. | |||
| """ | |||
| app.register_blueprint(BLUEPRINT) | |||
| @@ -26,6 +26,7 @@ import psutil | |||
| import gunicorn | |||
| from mindinsight.utils.computing_resource_mgr import terminate | |||
| from mindinsight.debugger.session_manager import SessionManager | |||
| gunicorn.SERVER_SOFTWARE = 'unknown' | |||
| @@ -110,4 +111,5 @@ def worker_int(worker): | |||
| global LISTEN_PROCESS | |||
| if LISTEN_PROCESS is not None: | |||
| LISTEN_PROCESS.terminate() | |||
| SessionManager.get_instance().exit() | |||
| worker.log.info("Worker int processed.") | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -19,6 +19,11 @@ from mindinsight.conf import settings | |||
| from mindinsight.datavisual.common.log import logger | |||
| from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER | |||
| from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater | |||
| from mindinsight.debugger.debugger_folder_analyzer import DebuggerFolderAnalyzer | |||
| ANALYZERS = { | |||
| "debugger_folder_analyzer": DebuggerFolderAnalyzer() | |||
| } | |||
| def init_module(app): | |||
| @@ -31,6 +36,8 @@ def init_module(app): | |||
| """ | |||
| # Just to suppress pylint warning about unused arg. | |||
| logger.debug("App: %s", type(app)) | |||
| for analyzer in ANALYZERS.values(): | |||
| DATA_MANAGER.register_folder_analyzer(analyzer) | |||
| DATA_MANAGER.register_brief_cache_item_updater(LineageCacheItemUpdater()) | |||
| # Let gunicorn load other modules first. | |||
| time.sleep(1) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -19,22 +19,13 @@ from urllib.parse import unquote | |||
| from flask import Blueprint, jsonify, request | |||
| from mindinsight.conf import settings | |||
| from mindinsight.debugger.debugger_server import DebuggerServer | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from mindinsight.debugger.session_manager import SessionManager | |||
| from mindinsight.utils.exceptions import ParamMissError, ParamValueError | |||
| BLUEPRINT = Blueprint("debugger", __name__, | |||
| url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX) | |||
| def _initialize_debugger_server(): | |||
| """Initialize a debugger server instance.""" | |||
| enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False | |||
| server = None | |||
| if enable_debugger: | |||
| server = DebuggerServer() | |||
| return server | |||
| def _unquote_param(param): | |||
| """ | |||
| Decode parameter value. | |||
| @@ -77,8 +68,8 @@ def _wrap_reply(func, *args, **kwargs): | |||
| return jsonify(reply) | |||
| @BLUEPRINT.route("/debugger/poll-data", methods=["GET"]) | |||
| def poll_data(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/poll-data", methods=["GET"]) | |||
| def poll_data(session_id): | |||
| """ | |||
| Wait for data to be updated on UI. | |||
| @@ -88,17 +79,17 @@ def poll_data(): | |||
| str, the updated data. | |||
| Examples: | |||
| >>> Get http://xxxx/v1/mindinsight/debugger/poll-data?pos=xx | |||
| >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/poll-data?pos=xx | |||
| """ | |||
| pos = request.args.get('pos') | |||
| reply = _wrap_reply(BACKEND_SERVER.poll_data, pos) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).poll_data, pos) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/search", methods=["GET"]) | |||
| def search(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/search", methods=["GET"]) | |||
| def search(session_id): | |||
| """ | |||
| Search nodes in specified watchpoint. | |||
| @@ -106,42 +97,25 @@ def search(): | |||
| str, the required data. | |||
| Examples: | |||
| >>> Get http://xxxx/v1/mindinsight/debugger/search?name=mock_name&watch_point_id=1 | |||
| >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/search?name=mock_name&watch_point_id=1 | |||
| """ | |||
| name = request.args.get('name') | |||
| graph_name = request.args.get('graph_name') | |||
| watch_point_id = int(request.args.get('watch_point_id', 0)) | |||
| node_category = request.args.get('node_category') | |||
| reply = _wrap_reply(BACKEND_SERVER.search, {'name': name, | |||
| 'graph_name': graph_name, | |||
| 'watch_point_id': watch_point_id, | |||
| 'node_category': node_category}) | |||
| rank_id = int(request.args.get('rank_id', 0)) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).search, | |||
| {'name': name, | |||
| 'graph_name': graph_name, | |||
| 'watch_point_id': watch_point_id, | |||
| 'node_category': node_category, | |||
'rank_id': rank_id})
| return reply | |||
| @BLUEPRINT.route("/debugger/retrieve_node_by_bfs", methods=["GET"]) | |||
| def retrieve_node_by_bfs(): | |||
| """ | |||
| Search node by bfs. | |||
| Returns: | |||
| str, the required data. | |||
| Examples: | |||
| >>> Get http://xxxx/v1/mindinsight/debugger/retrieve_node_by_bfs?name=node_name&ascend=true | |||
| """ | |||
| name = request.args.get('name') | |||
| graph_name = request.args.get('graph_name') | |||
| ascend = request.args.get('ascend', 'false') | |||
| ascend = ascend == 'true' | |||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_node_by_bfs, name, graph_name, ascend) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/tensor-comparisons", methods=["GET"]) | |||
| def tensor_comparisons(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-comparisons", methods=["GET"]) | |||
| def tensor_comparisons(session_id): | |||
| """ | |||
| Get tensor comparisons. | |||
| @@ -149,19 +123,21 @@ def tensor_comparisons(): | |||
| str, the required data. | |||
| Examples: | |||
| >>> Get http://xxxx/v1/mindinsight/debugger/tensor-comparisons | |||
| >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-comparisons | |||
| """ | |||
| name = request.args.get('name') | |||
| detail = request.args.get('detail', 'data') | |||
| shape = _unquote_param(request.args.get('shape')) | |||
| tolerance = request.args.get('tolerance', '0') | |||
| reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance) | |||
| rank_id = int(request.args.get('rank_id', 0)) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).tensor_comparisons, name, shape, detail, | |||
| tolerance, rank_id) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/retrieve", methods=["POST"]) | |||
| def retrieve(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/retrieve", methods=["POST"]) | |||
| def retrieve(session_id): | |||
| """ | |||
| Retrieve data according to mode and params. | |||
| @@ -169,17 +145,17 @@ def retrieve(): | |||
| str, the required data. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/retrieve | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/retrieve | |||
| """ | |||
| body = _read_post_request(request) | |||
| mode = body.get('mode') | |||
| params = body.get('params') | |||
| reply = _wrap_reply(BACKEND_SERVER.retrieve, mode, params) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve, mode, params) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/tensor-history", methods=["POST"]) | |||
| def retrieve_tensor_history(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-history", methods=["POST"]) | |||
| def retrieve_tensor_history(session_id): | |||
| """ | |||
Retrieve tensor history according to name and graph name.
| @@ -187,17 +163,19 @@ def retrieve_tensor_history(): | |||
| str, the required data. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/tensor-history | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-history | |||
| """ | |||
| body = _read_post_request(request) | |||
| name = body.get('name') | |||
| graph_name = body.get('graph_name') | |||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_history, name, graph_name) | |||
| rank_id = body.get('rank_id') | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_history, name, graph_name, | |||
| rank_id) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/tensors", methods=["GET"]) | |||
| def retrieve_tensor_value(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/tensors", methods=["GET"]) | |||
| def retrieve_tensor_value(session_id): | |||
| """ | |||
| Retrieve tensor value according to name and shape. | |||
| @@ -205,20 +183,22 @@ def retrieve_tensor_value(): | |||
| str, the required data. | |||
| Examples: | |||
| >>> GET http://xxxx/v1/mindinsight/debugger/tensors?name=tensor_name&detail=data&shape=[1,1,:,:] | |||
| >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensors?name=tensor_name&detail=data&shape=[1,1,:,:] | |||
| """ | |||
| name = request.args.get('name') | |||
| detail = request.args.get('detail') | |||
| shape = _unquote_param(request.args.get('shape')) | |||
| graph_name = request.args.get('graph_name') | |||
| prev = bool(request.args.get('prev') == 'true') | |||
| rank_id = int(request.args.get('rank_id', 0)) | |||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_value, name, detail, shape, graph_name, prev) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_value, name, detail, | |||
| shape, graph_name, prev, rank_id) | |||
| return reply | |||
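# Hedged client sketch for the tensor value endpoint above; the session id,
# tensor name and shape are illustrative, and the server address is assumed.
from urllib.parse import quote
import requests

url = ("http://localhost:8080/v1/mindinsight/debugger/sessions/xxxx/tensors"
       "?name=Default/conv1/Conv2D-op1:0&detail=data"
       "&shape=" + quote("[1,1,:,:]") + "&rank_id=0")
print(requests.get(url).json())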
| @BLUEPRINT.route("/debugger/create-watchpoint", methods=["POST"]) | |||
| def create_watchpoint(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/create-watchpoint", methods=["POST"]) | |||
| def create_watchpoint(session_id): | |||
| """ | |||
| Create watchpoint. | |||
| @@ -229,16 +209,16 @@ def create_watchpoint(): | |||
| MindInsightException: If method fails to be called. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/create-watchpoint | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/create-watchpoint | |||
| """ | |||
| params = _read_post_request(request) | |||
| params['watch_condition'] = params.pop('condition', None) | |||
| reply = _wrap_reply(BACKEND_SERVER.create_watchpoint, params) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).create_watchpoint, params) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/update-watchpoint", methods=["POST"]) | |||
| def update_watchpoint(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/update-watchpoint", methods=["POST"]) | |||
| def update_watchpoint(session_id): | |||
| """ | |||
| Update watchpoint. | |||
| @@ -249,17 +229,17 @@ def update_watchpoint(): | |||
| MindInsightException: If method fails to be called. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/update-watchpoint | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/update-watchpoint | |||
| """ | |||
| params = _read_post_request(request) | |||
| reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, params) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).update_watchpoint, params) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/delete-watchpoint", methods=["POST"]) | |||
| def delete_watchpoint(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/delete-watchpoint", methods=["POST"]) | |||
| def delete_watchpoint(session_id): | |||
| """ | |||
| delete watchpoint. | |||
| Delete watchpoint. | |||
| Returns: | |||
| str, reply message. | |||
| @@ -268,19 +248,19 @@ def delete_watchpoint(): | |||
| MindInsightException: If method fails to be called. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/delete-watchpoint | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/delete-watchpoint | |||
| """ | |||
| body = _read_post_request(request) | |||
| watch_point_id = body.get('watch_point_id') | |||
| reply = _wrap_reply(BACKEND_SERVER.delete_watchpoint, watch_point_id) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).delete_watchpoint, watch_point_id) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/control", methods=["POST"]) | |||
| def control(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/control", methods=["POST"]) | |||
| def control(session_id): | |||
| """ | |||
| Control request. | |||
| @@ -291,16 +271,16 @@ def control(): | |||
| MindInsightException: If method fails to be called. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/control | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/control | |||
| """ | |||
| params = _read_post_request(request) | |||
| reply = _wrap_reply(BACKEND_SERVER.control, params) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).control, params) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/recheck", methods=["POST"]) | |||
| def recheck(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/recheck", methods=["POST"]) | |||
| def recheck(session_id): | |||
| """ | |||
| Recheck request. | |||
| @@ -311,15 +291,15 @@ def recheck(): | |||
| MindInsightException: If method fails to be called. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/recheck | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/recheck | |||
| """ | |||
| reply = _wrap_reply(BACKEND_SERVER.recheck) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).recheck) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/tensor-graphs", methods=["GET"]) | |||
| def retrieve_tensor_graph(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-graphs", methods=["GET"]) | |||
| def retrieve_tensor_graph(session_id): | |||
| """ | |||
Retrieve tensor graph according to tensor name and graph name.
| @@ -327,16 +307,18 @@ def retrieve_tensor_graph(): | |||
| str, the required data. | |||
| Examples: | |||
| >>> GET http://xxxx/v1/mindinsight/debugger/tensor-graphs?tensor_name=tensor_name&graph_name=graph_name | |||
| >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-graphs?tensor_name=xxx&graph_name=xxx | |||
| """ | |||
| tensor_name = request.args.get('tensor_name') | |||
| graph_name = request.args.get('graph_name') | |||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_graph, tensor_name, graph_name) | |||
| rank_id = int(request.args.get('rank_id', 0)) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_graph, tensor_name, | |||
| graph_name, rank_id) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/tensor-hits", methods=["GET"]) | |||
| def retrieve_tensor_hits(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-hits", methods=["GET"]) | |||
| def retrieve_tensor_hits(session_id): | |||
| """ | |||
Retrieve tensor hits according to tensor name and graph name.
| @@ -344,16 +326,18 @@ def retrieve_tensor_hits(): | |||
| str, the required data. | |||
| Examples: | |||
| >>> GET http://xxxx/v1/mindinsight/debugger/tensor-hits?tensor_name=tensor_name&graph_name=graph_name | |||
| >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-hits?tensor_name=xxx&graph_name=xxx | |||
| """ | |||
| tensor_name = request.args.get('tensor_name') | |||
| graph_name = request.args.get('graph_name') | |||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_hits, tensor_name, graph_name) | |||
| rank_id = int(request.args.get('rank_id', 0)) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_hits, tensor_name, | |||
| graph_name, rank_id) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/search-watchpoint-hits", methods=["POST"]) | |||
| def search_watchpoint_hits(): | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/search-watchpoint-hits", methods=["POST"]) | |||
| def search_watchpoint_hits(session_id): | |||
| """ | |||
| Search watchpoint hits by group condition. | |||
| @@ -361,15 +345,75 @@ def search_watchpoint_hits(): | |||
| str, the required data. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/search-watchpoint-hits | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/search-watchpoint-hits | |||
| """ | |||
| body = _read_post_request(request) | |||
| group_condition = body.get('group_condition') | |||
| reply = _wrap_reply(BACKEND_SERVER.search_watchpoint_hits, group_condition) | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).search_watchpoint_hits, group_condition) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/condition-collections", methods=["GET"]) | |||
| def get_condition_collections(session_id): | |||
| """Get condition collections.""" | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).get_condition_collections) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/set-recommended-watch-points", methods=["POST"]) | |||
| def set_recommended_watch_points(session_id): | |||
| """Set recommended watch points.""" | |||
| body = _read_post_request(request) | |||
| request_body = body.get('requestBody') | |||
| if request_body is None: | |||
| raise ParamMissError('requestBody') | |||
| set_recommended = request_body.get('set_recommended') | |||
| reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).set_recommended_watch_points, | |||
| set_recommended) | |||
| return reply | |||
| BACKEND_SERVER = _initialize_debugger_server() | |||
| @BLUEPRINT.route("/debugger/sessions", methods=["POST"]) | |||
def create_session():
"""
Get the session ID if the session exists; otherwise create a new session.
| Returns: | |||
| str, session id. | |||
| Examples: | |||
>>> POST http://xxxx/v1/mindinsight/debugger/sessions
| """ | |||
| body = _read_post_request(request) | |||
| summary_dir = body.get('dump_dir') | |||
| session_type = body.get('session_type') | |||
| reply = _wrap_reply(SessionManager.get_instance().creat_session, session_type, summary_dir) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/sessions", methods=["GET"]) | |||
| def get_sessions(): | |||
| """ | |||
Check the current active sessions.
Examples:
>>> GET http://xxxx/v1/mindinsight/debugger/sessions
| """ | |||
| reply = _wrap_reply(SessionManager.get_instance().get_sessions) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/sessions/<session_id>/delete", methods=["POST"]) | |||
| def delete_session(session_id): | |||
| """ | |||
| Delete session by session id. | |||
| Examples: | |||
>>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/delete
| """ | |||
| reply = _wrap_reply(SessionManager.get_instance().delete_session, session_id) | |||
| return reply | |||
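# Hedged sketch of the session lifecycle above, assuming a local MindInsight
# instance; the dump directory and session type values are assumptions, and
# the create reply is taken to be the bare session id per the docstring.
import requests

BASE = "http://localhost:8080/v1/mindinsight/debugger"
# Create (or look up) a session for an offline dump directory.
session_id = requests.post(
    BASE + "/sessions",
    json={"session_type": "OFFLINE", "dump_dir": "/path/to/dump_dir"},
).json()
# List the currently active sessions.
print(requests.get(BASE + "/sessions").json())
# Delete the session by id when done.
requests.post("{}/sessions/{}/delete".format(BASE, session_id))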
| def init_module(app): | |||
| @@ -380,5 +424,3 @@ def init_module(app): | |||
| app (Flask): The application obj. | |||
| """ | |||
| app.register_blueprint(BLUEPRINT) | |||
| if BACKEND_SERVER: | |||
| BACKEND_SERVER.start() | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -112,6 +112,11 @@ class _BasicTrainJob: | |||
| """Get the lineage files count in the summary dir.""" | |||
| return self._entry['lineage_files'] | |||
| @property | |||
| def dump_dir(self): | |||
| """Get the dump file path in the summary dir.""" | |||
| return self._entry.get('dump_dir', None) | |||
| class CachedTrainJob: | |||
| """ | |||
| @@ -369,6 +374,10 @@ class _BaseCacheManager: | |||
| class _BriefCacheManager(_BaseCacheManager): | |||
| """A cache manager that holds all disk train jobs on disk.""" | |||
| def __init__(self, summary_base_dir): | |||
| super(_BriefCacheManager, self).__init__(summary_base_dir) | |||
| self._summary_watcher = SummaryWatcher() | |||
| def cache_train_job(self, train_id): | |||
| """ | |||
| Cache given train job. | |||
| @@ -386,7 +395,7 @@ class _BriefCacheManager(_BaseCacheManager): | |||
| def update_cache(self, executor): | |||
| """Update cache.""" | |||
| logger.info('Start to update BriefCacheManager.') | |||
| summaries_info = SummaryWatcher().list_summary_directories(self._summary_base_dir) | |||
| summaries_info = self._summary_watcher.list_summary_directories(self._summary_base_dir) | |||
| basic_train_jobs = [] | |||
| for info in summaries_info: | |||
| @@ -425,6 +434,10 @@ class _BriefCacheManager(_BaseCacheManager): | |||
| return new_cache_items | |||
| def register_folder_analyzer(self, analyzer): | |||
| """Register folder analyzer.""" | |||
| self._summary_watcher.register_folder_analyzer(analyzer) | |||
| @property | |||
| def cache_items(self): | |||
| """Get cache items.""" | |||
| @@ -1028,6 +1041,10 @@ class DataManager: | |||
| """Register brief cache item updater for brief cache manager.""" | |||
| self._brief_cache.register_cache_item_updater(updater) | |||
| def register_folder_analyzer(self, analyzer): | |||
| """Register folder analyzer.""" | |||
| self._brief_cache.register_folder_analyzer(analyzer) | |||
| def get_brief_cache(self): | |||
| """Get brief cache.""" | |||
| return self._brief_cache | |||
| @@ -254,22 +254,24 @@ class MSGraph(Graph): | |||
| return searched_list | |||
| def search_leaf_nodes_by_pattern(self, pattern): | |||
| def search_leaf_nodes_by_pattern(self, pattern, scope_pattern=False): | |||
| """ | |||
| Search leaf node by a given pattern. | |||
| Args: | |||
| pattern (Union[str, None]): The pattern of the node to search, | |||
| if None, return all node names. | |||
scope_pattern (bool): If True, return the child nodes of the scope. Default: False.
| Returns: | |||
| list[Node], a list of nodes. | |||
| """ | |||
| is_match = lambda x, y: x.lower().startswith(y) if scope_pattern else y in x.lower() | |||
| if pattern is not None: | |||
| pattern = pattern.lower() | |||
| searched_nodes = [ | |||
| node for name, node in self._leaf_nodes.items() | |||
| if pattern in name.lower() | |||
| if is_match(name, pattern) | |||
| ] | |||
| else: | |||
| searched_nodes = [node for node in self._leaf_nodes.values()] | |||
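# Standalone sketch of the two matching modes above; node names are made up.
names = ["Default/conv1/Conv2D-op1", "Default/conv1/BiasAdd-op2", "Gradients/conv1"]

def _match(name, pattern, scope_pattern):
    return name.lower().startswith(pattern) if scope_pattern else pattern in name.lower()

# Substring mode matches the pattern anywhere in the node name: all three hit.
print([n for n in names if _match(n, "conv1", scope_pattern=False)])
# Scope mode only matches names under the given scope prefix: the first two hit.
print([n for n in names if _match(n, "default/conv1", scope_pattern=True)])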
| @@ -29,6 +29,7 @@ from mindinsight.utils.exceptions import FileSystemPermissionError | |||
| LINEAGE_SUMMARY_SUFFIX = '_lineage' | |||
| EXPLAIN_SUMMARY_SUFFIX = '_explain' | |||
| DUMP_FILE_PREFIX = 'dump_' | |||
| class SummaryWatcher: | |||
| @@ -45,6 +46,13 @@ class SummaryWatcher: | |||
| # to avoid long-time blocking | |||
| MAX_SCAN_COUNT = 20000 | |||
| def __init__(self): | |||
| self._analyzers = [] | |||
| def register_folder_analyzer(self, analyzer): | |||
| """Register folder analyzer.""" | |||
| self._analyzers.append(analyzer) | |||
| def list_summary_directories(self, summary_base_dir, overall=True, list_explain=False): | |||
| """ | |||
| List summary directories within base directory. | |||
| @@ -104,7 +112,7 @@ class SummaryWatcher: | |||
| elif entry.is_dir(): | |||
| self._update_summary_dict(summary_dict, summary_base_dir, relative_path, entry, list_explain) | |||
| entry_path = os.path.realpath(os.path.join(summary_base_dir, entry.name)) | |||
| self._scan_subdir_entries(summary_dict, summary_base_dir, entry_path, entry.name, counter, list_explain) | |||
| self._scan_subdir_entries(summary_dict, summary_base_dir, entry_path, entry, counter, list_explain) | |||
| directories = [] | |||
| for key, value in summary_dict.items(): | |||
| @@ -119,7 +127,7 @@ class SummaryWatcher: | |||
| return directories | |||
| def _scan_subdir_entries(self, summary_dict, summary_base_dir, entry_path, entry_name, counter, list_explain): | |||
| def _scan_subdir_entries(self, summary_dict, summary_base_dir, entry_path, entry, counter, list_explain): | |||
| """ | |||
| Scan subdir entries. | |||
| @@ -134,7 +142,7 @@ class SummaryWatcher: | |||
| try: | |||
| subdir_entries = os.scandir(entry_path) | |||
| except PermissionError: | |||
| logger.warning('Path of %s under summary base directory is not accessible.', entry_name) | |||
| logger.warning('Path of %s under summary base directory is not accessible.', entry.name) | |||
| return | |||
| # sort in ascending order according to modification time. | |||
| @@ -149,11 +157,14 @@ class SummaryWatcher: | |||
| logger.info('Stop further scanning due to overall is False and ' | |||
| 'number of scanned files exceeds upper limit.') | |||
| break | |||
| subdir_relative_path = os.path.join('.', entry_name) | |||
| subdir_relative_path = os.path.join('.', entry.name) | |||
| if subdir_entry.is_symlink(): | |||
| pass | |||
| self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry, list_explain) | |||
| relative_path = './' | |||
| self._check_by_analyzers(entry, summary_base_dir, relative_path, summary_dict) | |||
| def _is_valid_summary_directory(self, summary_base_dir, relative_path): | |||
| """ | |||
| Check if the given summary directory is valid. | |||
| @@ -198,13 +209,11 @@ class SummaryWatcher: | |||
| list_explain (bool): Indicates whether to list only the mindexplain folder. | |||
| """ | |||
| try: | |||
| stat = entry.stat() | |||
| ctime, mtime = self._get_stat_time(entry) | |||
| except FileNotFoundError: | |||
| logger.warning('File %s not found', entry.name) | |||
| return | |||
| ctime = datetime.datetime.fromtimestamp(stat.st_ctime).astimezone() | |||
| mtime = datetime.datetime.fromtimestamp(stat.st_mtime).astimezone() | |||
| if entry.is_file(): | |||
| summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name) | |||
| pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name) | |||
| @@ -238,7 +247,10 @@ class SummaryWatcher: | |||
| summary_dict[relative_path]['explain_files'] += 1 | |||
| else: | |||
| summary_dict[relative_path]['summary_files'] += 1 | |||
| self._check_by_analyzers(entry, summary_base_dir, relative_path, summary_dict) | |||
| elif entry.is_dir(): | |||
| self._check_by_analyzers(entry, summary_base_dir, relative_path, summary_dict) | |||
| if list_explain: | |||
| return | |||
| @@ -261,6 +273,28 @@ class SummaryWatcher: | |||
| else: | |||
| summary_dict[relative_path] = _new_entry(ctime, mtime, profiler) | |||
| def _check_by_analyzers(self, entry, summary_base_dir, relative_path, summary_dict): | |||
| """Check by all analyzers.""" | |||
| try: | |||
| ctime, mtime = self._get_stat_time(entry) | |||
| except FileNotFoundError: | |||
| logger.warning('File %s not found', entry.name) | |||
| return | |||
| for analyzer in self._analyzers: | |||
| register_info = analyzer.analyze(entry, summary_base_dir, relative_path) | |||
| if register_info: | |||
| if relative_path not in summary_dict: | |||
| summary_dict[relative_path] = _new_entry(ctime, mtime) | |||
| summary_dict[relative_path].update(register_info) | |||
| def _get_stat_time(self, entry): | |||
| """Get ctime and mtime.""" | |||
| stat = entry.stat() | |||
| ctime = datetime.datetime.fromtimestamp(stat.st_ctime).astimezone() | |||
| mtime = datetime.datetime.fromtimestamp(stat.st_mtime).astimezone() | |||
| return ctime, mtime | |||
| def _find_profiler_dir(self, entry, summary_base_dir, relative_path): | |||
| """Find profiler dir by the given relative path.""" | |||
| profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name) | |||
| @@ -342,6 +376,9 @@ class SummaryWatcher: | |||
| if self._is_valid_profiler_directory(full_path)[0] or \ | |||
| self._is_valid_cluster_profiler_directory(full_path)[0]: | |||
| return True | |||
if os.path.exists(os.path.join(summary_directory, entry.name, ".metadata")):
| return True | |||
| return False | |||
| def _is_valid_profiler_directory(self, directory): | |||
| @@ -515,7 +552,8 @@ def _new_entry(ctime, mtime, profiler=None): | |||
| 'lineage_files': 0, | |||
| 'explain_files': 0, | |||
| 'graph_files': 0, | |||
| 'profiler': profiler | |||
| 'profiler': profiler, | |||
| 'dump_dir': None | |||
| } | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -150,7 +150,8 @@ class TrainTaskManager(BaseProcessor): | |||
| profiler_type=basic_info.profiler_type, | |||
| summary_files=basic_info.summary_files, | |||
| graph_files=basic_info.graph_files, | |||
| lineage_files=basic_info.lineage_files | |||
| lineage_files=basic_info.lineage_files, | |||
| dump_dir=basic_info.dump_dir | |||
| ) | |||
| if train_job.cache_status != CacheStatus.NOT_IN_CACHE: | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -20,6 +20,8 @@ from mindinsight.utils.constant import DebuggerErrors as DebuggerErrorCodes | |||
| _PARAM_ERROR_MASK = 0b00001 << 7 | |||
| _DEBUGGER_GRAPH_ERROR = 0b00010 << 7 | |||
| _DEBUGGER_RUNNING_ERROR = 0b00011 << 7 | |||
| _DEBUGGER_SERVER_ERROR = 0b00100 << 7 | |||
| _DEBUGGER_SESSION_ERROR = 0b00101 << 7 | |||
| @unique | |||
| @@ -44,6 +46,13 @@ class DebuggerErrors(DebuggerErrorCodes): | |||
| TENSOR_HIT_ERROR = 8 | _DEBUGGER_RUNNING_ERROR | |||
| SET_RECOMMEND_WATCHPOINT_ERROR = 9 | _DEBUGGER_RUNNING_ERROR | |||
| DEBUGGER_SERVER_RUNNING_ERROR = 0 | _DEBUGGER_SERVER_ERROR | |||
| DEVICE_ID_UNREGISTERED = 1 | _DEBUGGER_SERVER_ERROR | |||
| MODULE_NOT_FOUND_ERROR = 2 | _DEBUGGER_SERVER_ERROR | |||
| DEBUGGER_SESSION_OVER_BOUND_ERROR = 0 | _DEBUGGER_SESSION_ERROR | |||
| DEBUGGER_SESSION_NOT_FOUND_ERROR = 1 | _DEBUGGER_SESSION_ERROR | |||
| @unique | |||
| class DebuggerErrorMsg(Enum): | |||
| @@ -63,3 +72,10 @@ class DebuggerErrorMsg(Enum): | |||
| TENSOR_GRAPH_ERROR = "Get tensor graphs failed." | |||
| TENSOR_HIT_ERROR = "Get tensor hits failed." | |||
| SET_RECOMMEND_WATCHPOINT_ERROR = "Set Recommend Watchpoints failed." | |||
| DEBUGGER_SERVER_RUNNING_ERROR = "Debugger server running error. {}" | |||
| DEVICE_ID_UNREGISTERED = "Device id unregistered. Device id: {}" | |||
| MODULE_NOT_FOUND_ERROR = "{} module not found." | |||
| DEBUGGER_SESSION_OVER_BOUND_ERROR = "The amount of sessions is over limitation." | |||
| DEBUGGER_SESSION_NOT_FOUND_ERROR = "Session {} not found." | |||
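# Sketch of the error code layout above: a 5-bit category mask is shifted past
# a 7-bit sub-code, so the two parts never overlap.
_DEBUGGER_SESSION_ERROR = 0b00101 << 7   # category 5 -> 0x280
code = 1 | _DEBUGGER_SESSION_ERROR       # DEBUGGER_SESSION_NOT_FOUND_ERROR
print(hex(code))                         # 0x281: high bits category, low 7 bits sub-code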
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -190,3 +190,58 @@ class DebuggerConditionUnavailableError(MindInsightException): | |||
| message=DebuggerErrorMsg.DEBUGGER_CONDITION_UNAVAILABLE_ERROR.value.format(msg), | |||
| http_code=400 | |||
| ) | |||
| class DebuggerServerRunningError(MindInsightException): | |||
| """The condition unavailable error in debugger module.""" | |||
| def __init__(self, msg): | |||
| super(DebuggerServerRunningError, self).__init__( | |||
| error=DebuggerErrors.DEBUGGER_SERVER_RUNNING_ERROR, | |||
| message=DebuggerErrorMsg.DEBUGGER_SERVER_RUNNING_ERROR.value.format(msg), | |||
| http_code=500 | |||
| ) | |||
| class DeviceIdUnregistered(MindInsightException): | |||
| """The condition unavailable error in debugger module.""" | |||
| def __init__(self, msg): | |||
| super(DeviceIdUnregistered, self).__init__( | |||
| error=DebuggerErrors.DEVICE_ID_UNREGISTERED, | |||
| message=DebuggerErrorMsg.DEVICE_ID_UNREGISTERED.value.format(msg), | |||
| http_code=400 | |||
| ) | |||
| class DebuggerModuleNotFoundError(MindInsightException): | |||
| """The condition unavailable error in debugger module.""" | |||
| def __init__(self, msg): | |||
| super(DebuggerModuleNotFoundError, self).__init__( | |||
| error=DebuggerErrors.MODULE_NOT_FOUND_ERROR, | |||
| message=DebuggerErrorMsg.MODULE_NOT_FOUND_ERROR.value.format(msg), | |||
| http_code=500 | |||
| ) | |||
| class DebuggerSessionNumOverBoundError(MindInsightException): | |||
| """The condition unavailable error in debugger module.""" | |||
| def __init__(self): | |||
| super(DebuggerSessionNumOverBoundError, self).__init__( | |||
| error=DebuggerErrors.DEBUGGER_SESSION_OVER_BOUND_ERROR, | |||
| message=DebuggerErrorMsg.DEBUGGER_SESSION_OVER_BOUND_ERROR.value, | |||
| http_code=400 | |||
| ) | |||
| class DebuggerSessionNotFoundError(MindInsightException): | |||
| """The condition unavailable error in debugger module.""" | |||
| def __init__(self, msg): | |||
| super(DebuggerSessionNotFoundError, self).__init__( | |||
| error=DebuggerErrors.DEBUGGER_SESSION_NOT_FOUND_ERROR, | |||
| message=DebuggerErrorMsg.DEBUGGER_SESSION_NOT_FOUND_ERROR.value.format(msg), | |||
| http_code=400 | |||
| ) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -38,9 +38,12 @@ NUMPY_TYPE_MAP = { | |||
| 'DT_FLOAT32': np.float32, | |||
| 'DT_FLOAT64': np.float64, | |||
| 'DT_STRING': np.str | |||
| 'DT_STRING': np.str, | |||
| 'DT_TYPE': np.str | |||
| } | |||
| MS_VERSION = '1.0.x' | |||
| @enum.unique | |||
| class ReplyStates(enum.Enum): | |||
| @@ -71,6 +74,7 @@ class Streams(enum.Enum): | |||
| TENSOR = 'tensor' | |||
| WATCHPOINT = 'watchpoint' | |||
| WATCHPOINT_HIT = 'watchpoint_hit' | |||
| DEVICE = 'device' | |||
| class RunLevel(enum.Enum): | |||
| @@ -152,3 +156,26 @@ def is_scope_type(node_type): | |||
| def is_cst_type(node_type): | |||
| """Judge whether the type is const type.""" | |||
| return node_type == NodeTypeEnum.CONST.value | |||
| def version_match(ms_version, mi_version): | |||
| """Judge if the version of Mindinsight and Mindspore is matched.""" | |||
| if not ms_version: | |||
| ms_version = MS_VERSION | |||
| mi_major, mi_minor = mi_version.split('.')[:2] | |||
| ms_major, ms_minor = ms_version.split('.')[:2] | |||
| return mi_major == ms_major and mi_minor == ms_minor | |||
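# Illustrative checks for version_match above: only major and minor compare.
assert version_match('1.3.0', '1.3.1')          # patch versions may differ
assert not version_match('1.2.0', '1.3.0')      # minor version mismatch
assert not version_match('', '1.3.0')           # empty falls back to MS_VERSION '1.0.x'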
| @enum.unique | |||
| class DebuggerServerMode(enum.Enum): | |||
| """Debugger Server Mode.""" | |||
| ONLINE = 'online' | |||
| OFFLINE = 'offline' | |||
| class DumpSettings(enum.Enum): | |||
| """Dump settings.""" | |||
| E2E_DUMP_SETTINGS = 'e2e_dump_settings' | |||
| COMMON_DUMP_SETTINGS = 'common_dump_settings' | |||
| ASYNC_DUMP_SETTINGS = 'async_dump_settings' | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -64,13 +64,13 @@ class _ConditionParameterValue: | |||
| return self.parameter.name | |||
| def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_context): | |||
| def recommend_watchpoints(condition_mgr: ConditionMgr, multi_card_graph_stream, condition_context): | |||
| """ | |||
| Recommend watchpoints. | |||
| Args: | |||
| condition_mgr (ConditionMgr): Condition manager instance. | |||
| graph_stream (GraphHandler): Graph handler instance. | |||
multi_card_graph_stream (MultiCardGraphHandler): Multi-card graph handler instance.
| condition_context (ConditionContext): Context for condition. | |||
| Returns: | |||
| @@ -78,7 +78,7 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c | |||
| """ | |||
| watch_points = [] | |||
| if not graph_stream.graph: | |||
| if not multi_card_graph_stream.has_graph: | |||
| logger.warning("Given graph is None.") | |||
| return watch_points | |||
| @@ -86,7 +86,7 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c | |||
| return watch_points | |||
| # add weight watch points | |||
| merged_info = get_basic_node_info(TargetTypeEnum.WEIGHT.value, graph_stream) | |||
| merged_info = get_basic_node_info(TargetTypeEnum.WEIGHT.value, multi_card_graph_stream) | |||
| _recommend_weight_initialization(merged_info, condition_mgr, watch_points, condition_context) | |||
| _recommend_weight_change_too_large(merged_info, condition_mgr, watch_points, condition_context) | |||
| @@ -97,25 +97,27 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c | |||
| _recommend_weight_change_too_small(condition_mgr, trainable_weight_nodes, watch_points, condition_context) | |||
| # add gradient watch points | |||
| merged_info = get_basic_node_info(TargetTypeEnum.GRADIENT.value, graph_stream) | |||
| merged_info = get_basic_node_info(TargetTypeEnum.GRADIENT.value, multi_card_graph_stream) | |||
| _recommend_gradient_vanishing(merged_info, condition_mgr, watch_points, condition_context) | |||
| # add tensor watch points | |||
| merged_info = get_basic_node_info(TargetTypeEnum.TENSOR.value, graph_stream) | |||
| merged_info = get_basic_node_info(TargetTypeEnum.TENSOR.value, multi_card_graph_stream) | |||
| _recommend_operator_overflow(merged_info, condition_mgr, watch_points, condition_context) | |||
| _recommend_tensor_overflow(merged_info, condition_mgr, watch_points, condition_context) | |||
| _recommend_tensor_all_zero(merged_info, condition_mgr, watch_points, condition_context) | |||
| # add activation watch points | |||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.TANH.value) | |||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream, | |||
| ActivationFuncEnum.TANH.value) | |||
| _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, | |||
| ActivationFuncEnum.TANH.value) | |||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.SIGMOID.value) | |||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream, | |||
| ActivationFuncEnum.SIGMOID.value) | |||
| _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, | |||
| ActivationFuncEnum.SIGMOID.value) | |||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, | |||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream, | |||
| [ActivationFuncEnum.RELU.value, ActivationFuncEnum.RELUV2.value]) | |||
| _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, | |||
| ActivationFuncEnum.RELU.value) | |||
| @@ -318,12 +320,21 @@ def _recommend_activation_range(basic_info_nodes, condition_mgr, watch_points, c | |||
| watch_points.append(activation_range_watchpoint) | |||
| def get_basic_node_info(node_category, graph_stream, activation_func=None): | |||
| def get_basic_node_info(node_category, multi_card_graph_stream, activation_func=None): | |||
| """Get node merged info.""" | |||
| basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func) | |||
| merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph) | |||
| merged_info = _add_graph_name(merged_info, graph_stream) | |||
| return merged_info | |||
| nodes_for_devices = {} | |||
| has_node = False | |||
| for rank_id, graph_stream in multi_card_graph_stream.graph_handlers.items(): | |||
| basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func) | |||
| merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph) | |||
| merged_info = _add_graph_name(merged_info, graph_stream) | |||
| nodes_for_devices[rank_id] = merged_info | |||
| has_node = has_node or merged_info | |||
| if has_node: | |||
| return nodes_for_devices | |||
| return {} | |||
| def _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func=None): | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -17,17 +17,19 @@ import sys | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import Streams | |||
| from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, GraphHandler, \ | |||
| TensorHandler, WatchpointHandler, WatchpointHitHandler | |||
| from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, MultiCardGraphHandler, \ | |||
| MultiCardTensorHandler, WatchpointHandler, MultiCardWatchpointHitHandler | |||
| from mindinsight.debugger.stream_handler.device_handler import DeviceHandler | |||
| STREAM_HANDLER_MAP = { | |||
| Streams.COMMAND.value: EventHandler, | |||
| Streams.DATA.value: EventHandler, | |||
| Streams.METADATA.value: MetadataHandler, | |||
| Streams.GRAPH.value: GraphHandler, | |||
| Streams.TENSOR.value: TensorHandler, | |||
| Streams.GRAPH.value: MultiCardGraphHandler, | |||
| Streams.TENSOR.value: MultiCardTensorHandler, | |||
| Streams.WATCHPOINT.value: WatchpointHandler, | |||
| Streams.WATCHPOINT_HIT.value: WatchpointHitHandler | |||
| Streams.WATCHPOINT_HIT.value: MultiCardWatchpointHitHandler, | |||
| Streams.DEVICE.value: DeviceHandler | |||
| } | |||
| @@ -40,10 +42,8 @@ class DebuggerCache: | |||
| def initialize(self): | |||
| """Initialize the stream handlers.""" | |||
| self._stream_handler = {} | |||
| for stream in Streams: | |||
| mode = stream.value | |||
| stream_handler = STREAM_HANDLER_MAP.get(mode) | |||
| self._stream_handler[mode] = stream_handler() | |||
| for mode, stream_class in STREAM_HANDLER_MAP.items(): | |||
| self._stream_handler[mode] = stream_class() | |||
| def clean(self): | |||
| """Clean cache for all stream.""" | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Debugger train job register.""" | |||
| import os | |||
| from mindinsight.utils.folder_analyzer import FolderAnalyzer | |||
| from mindinsight.datavisual.common.log import logger | |||
| class DebuggerFolderAnalyzer(FolderAnalyzer): | |||
| """Debugger train job register.""" | |||
| def analyze(self, entry, summary_base_dir, relative_path): | |||
| """Check dir by debugger register.""" | |||
| update_info = {} | |||
| if entry.is_dir(): | |||
| sub_relative_path = os.path.join(relative_path, entry.name) | |||
| entry_path = os.path.join(summary_base_dir, sub_relative_path) | |||
| try: | |||
| subdir_entries = os.scandir(entry_path) | |||
| except PermissionError: | |||
| logger.warning('Path of %s under summary base directory is not accessible.', entry.name) | |||
| return update_info | |||
| subdir_entries = [subdir_entry for subdir_entry in subdir_entries if not subdir_entry.is_symlink()] | |||
| subdir_entries = sorted(subdir_entries, key=lambda x: x.stat().st_mtime) | |||
| for subdir_entry in subdir_entries: | |||
| if subdir_entry.is_dir() and subdir_entry.name.startswith(".metadata"): | |||
| update_info = {'dump_dir': sub_relative_path} | |||
| return update_info | |||
| return update_info | |||
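# Minimal usage sketch for the analyzer above: build a throwaway dump layout
# on disk and scan it; the rank_0/.metadata layout is a hypothetical example.
import tempfile

base = tempfile.mkdtemp()
os.makedirs(os.path.join(base, "rank_0", ".metadata"))
analyzer = DebuggerFolderAnalyzer()
with os.scandir(base) as entries:
    for dump_entry in entries:
        # Expect {'dump_dir': './rank_0'} for the directory holding .metadata.
        print(analyzer.analyze(dump_entry, base, "."))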
| @@ -0,0 +1,15 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Debugger server module.""" | |||
| @@ -19,7 +19,7 @@ from functools import wraps | |||
| import mindinsight | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ | |||
| Streams, RunLevel | |||
| Streams, RunLevel, version_match | |||
| from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum, ParamNameEnum | |||
| from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base | |||
| from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto | |||
| @@ -117,9 +117,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| # clean cache data at the beginning of new step or node has been changed. | |||
| if is_new_step or is_new_node: | |||
| self._cache_store.clean_data() | |||
| self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step) | |||
| self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).clean_tensors( | |||
| request.cur_step) | |||
| if is_new_step: | |||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() | |||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(0).clean() | |||
| # receive graph at the beginning of the training | |||
| if self._status == ServerStatus.RECEIVE_GRAPH: | |||
| self._send_graph_flag(metadata_stream) | |||
| @@ -141,7 +142,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| self._status = ServerStatus.WAITING | |||
| metadata_stream.state = ServerStatus.WAITING.value | |||
| metadata = metadata_stream.get() | |||
| res = self._cache_store.get_stream_handler(Streams.GRAPH).get() | |||
| res = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).get() | |||
| res.update(metadata) | |||
| self._cache_store.put_data(res) | |||
| log.debug("Put graph into data queue.") | |||
| @@ -157,7 +158,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| # put new metadata into cache | |||
| metadata_stream.put(metadata_proto) | |||
| # update current node name and graph name | |||
| graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH) | |||
| graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0) | |||
| full_name = metadata_proto.cur_node | |||
| graph_name = graph_stream.get_graph_id_by_full_name( | |||
| full_name) if full_name else metadata_stream.graph_name | |||
| @@ -182,7 +183,8 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| def _send_watchpoint_hit_flag(self): | |||
| """Send Watchpoint hit flag.""" | |||
| watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||
| watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id( | |||
| 0) | |||
| if not self._received_hit: | |||
| return | |||
| watchpoint_hits = self._received_hit | |||
| @@ -344,7 +346,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| run_cmd.node_name = '' | |||
| # clean watchpoint hit cache | |||
| if run_cmd.run_level == RunLevel.RECHECK.value: | |||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() | |||
| self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(0).clean() | |||
| log.debug("Receive RunCMD. Clean watchpoint hit cache.") | |||
| # update metadata state from sending to running | |||
| metadata_stream.state = ServerStatus.RUNNING.value | |||
| @@ -365,8 +367,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| log.info("The training from %s has finished.", client_ip) | |||
| else: | |||
| ms_version = request.ms_version | |||
| if not ms_version: | |||
| ms_version = '1.0.x' | |||
| if version_match(ms_version, mindinsight.__version__) is False: | |||
| log.info("Version is mismatched, mindspore is: %s, mindinsight is: %s", | |||
| ms_version, mindinsight.__version__) | |||
| @@ -403,8 +403,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| graph = GraphProto.FromString(serial_graph) | |||
| log.debug("Deserialize the graph %s. Receive %s nodes", graph.name, len(graph.node)) | |||
| graph_dict = {graph.name: graph} | |||
| self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict) | |||
| self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals) | |||
| self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).put(graph_dict) | |||
| self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals( | |||
| graph.const_vals) | |||
| self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name | |||
| self._record_parameter_names() | |||
| self._status = ServerStatus.RECEIVE_GRAPH | |||
| @@ -429,10 +430,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| log.debug("Deserialize the graph %s. Receive %s nodes", sub_graph.name, | |||
| len(sub_graph.node)) | |||
| serial_graph = b"" | |||
| self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals( | |||
| self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals( | |||
| sub_graph.const_vals) | |||
| self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict) | |||
| self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).put(graph_dict) | |||
| self._record_parameter_names() | |||
| self._status = ServerStatus.RECEIVE_GRAPH | |||
| log.debug("Send the reply for graph.") | |||
| @@ -440,9 +441,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| def _record_parameter_names(self): | |||
| """Record parameter full names in tensor handler.""" | |||
| parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).search_in_graph( | |||
| pattern={'node_category': TargetTypeEnum.PARAMETER.value}) | |||
| tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR) | |||
| parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0)\ | |||
| .search_in_graph(pattern={'node_category': TargetTypeEnum.PARAMETER.value}) | |||
| tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0) | |||
| for node in parameter_nodes: | |||
| tensor_name = [node.full_name + ':0'] | |||
| tensor_stream.record_parameter_names(tensor_name) | |||
| @@ -452,7 +453,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| """Send tensors into DebuggerCache.""" | |||
| log.info("Received tensor.") | |||
| tensor_contents = [] | |||
| tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR) | |||
| tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0) | |||
| metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | |||
| step = metadata_stream.step | |||
| for tensor in request_iterator: | |||
| @@ -482,7 +483,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| # save the watchpoint_hits data | |||
| watchpoint_hits = [] | |||
| watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH) | |||
| graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0) | |||
| for watchpoint_hit_proto in request_iterator: | |||
| node_full_name = watchpoint_hit_proto.tensor.node_name | |||
| graph_name = graph_stream.get_graph_id_by_full_name(node_full_name) | |||
| @@ -517,10 +518,3 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| self._received_hit = watchpoint_hits | |||
| reply = get_ack_reply() | |||
| return reply | |||
| def version_match(mi_version, ms_version): | |||
| """Judge if the version of Mindinsight and Mindspore is matched""" | |||
| mi_major, mi_minor = mi_version.split('.')[:2] | |||
| ms_major, ms_minor = ms_version.split('.')[:2] | |||
| return mi_major == ms_major and mi_minor == ms_minor | |||
| @@ -0,0 +1,613 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Debugger Offline server.""" | |||
| import copy | |||
| from collections import defaultdict | |||
| from importlib import import_module | |||
| from threading import Event | |||
| from multiprocessing import Process, Manager | |||
| import mindinsight | |||
| from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerModuleNotFoundError | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import Streams, ServerStatus, version_match, DebuggerServerMode, get_ack_reply, \ | |||
| RunLevel | |||
| from mindinsight.debugger.conditionmgr.condition import ParamNameEnum | |||
| from mindinsight.debugger.debugger_services.debugger_server_base import DebuggerServerBase, debugger_server_wrap | |||
| from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply | |||
| from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto | |||
| from mindinsight.debugger.stream_cache.data_loader import DataLoader | |||
| from mindinsight.utils.exceptions import MindInsightException | |||
| class DebuggerOfflineServer(DebuggerServerBase): | |||
| """Debugger Offline Server.""" | |||
| _MAX_TRY_EXCEPT_COUNT = 500 | |||
| def __init__(self, cache_store, context): | |||
| super(DebuggerOfflineServer, self).__init__(cache_store, context) | |||
| self._offline_server_manager = DebuggerOfflineManager(cache_store, context.dbg_dir) | |||
| self._running = Event() | |||
| self._running.clear() | |||
| def run(self): | |||
| """Start the debugger offline server.""" | |||
| log.info("Initialize Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir) | |||
| self._offline_server_manager.initialize() | |||
| self._running.set() | |||
| log.info("Start Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir) | |||
| try_count = 0 | |||
| while self._running.is_set() and try_count < self._MAX_TRY_EXCEPT_COUNT: | |||
| try: | |||
| self._offline_server_manager.wait_for_termination() | |||
| if not self._offline_server_manager.is_runnable(): | |||
| break | |||
| except MindInsightException as err: | |||
| log.exception(err) | |||
| log.warning("Error happens during listening on user commands. Restart listening again.") | |||
| finally: | |||
| try_count += 1 | |||
| # Protect the server from too many failed commands. | |||
| if try_count == self._MAX_TRY_EXCEPT_COUNT: | |||
| self._cache_store.clean() | |||
| metadata = self._cache_store.get_stream_handler(Streams.METADATA).get() | |||
| self._cache_store.put_data(metadata) | |||
| log.warning("Exception exceed %d times, stop server.", try_count) | |||
| def stop(self): | |||
| """Stop offline debugger server.""" | |||
| log.debug("Start to wait for thread started.") | |||
| self._running.wait() | |||
| log.info("Start to stop offline debugger server.") | |||
| self._running.clear() | |||
| self._offline_server_manager.stop() | |||
| self.join() | |||
| class DebuggerOfflineManager: | |||
| """Debugger offline manager which is used to handle user commands.""" | |||
| def __init__(self, cache_store, dbg_dir): | |||
| cache_store.initialize() | |||
| self._cache_store = cache_store | |||
| self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) | |||
| self._dbg_dir = dbg_dir | |||
| self._dbg_services_module = self._get_dbg_service_module() | |||
| self._dbg_service = None | |||
| self._command_listener = CommandListener(cache_store) | |||
| self._data_loader = DataLoader(dbg_dir) | |||
| self._is_running_flag = False | |||
| self._old_run_cmd = {} | |||
| def stop(self): | |||
| """Stop server.""" | |||
| self._is_running_flag = False | |||
| self._command_listener.stop() | |||
| self._cache_store.clean() | |||
| event = get_ack_reply() | |||
| event.exit = True | |||
| self._cache_store.put_command(event) | |||
| log.info("Stop debugger offline manager.") | |||
| def is_runnable(self): | |||
| """Check if the offline manager is runnable.""" | |||
| state = self._metadata_stream.state | |||
| flag = self._is_running_flag and state not in [ServerStatus.MISMATCH.value, ServerStatus.PENDING.value] | |||
| if not flag: | |||
| log.debug("The offline manager is not runnable, is_running_flag: %s, metadata state: %s", | |||
| self._is_running_flag, state) | |||
| return flag | |||
| @staticmethod | |||
| def _get_dbg_service_module(): | |||
| """Get dbg service module from MindSpore.""" | |||
| try: | |||
| dbg_services_module = import_module('mindspore.offline_debug.dbg_services') | |||
| except ModuleNotFoundError as err: | |||
| log.error("Failed to find module dbg_services. %s", err) | |||
| raise DebuggerModuleNotFoundError("dbg_services") | |||
| return dbg_services_module | |||
| @debugger_server_wrap | |||
| def initialize(self): | |||
| """Start to load offline debugger data.""" | |||
| self._data_loader.initialize() | |||
| is_sync = self._data_loader.get_sync_flag() | |||
| net_name = self._data_loader.get_net_name() | |||
| net_dir = self._data_loader.get_net_dir() | |||
| self._dbg_service = self._dbg_services_module.DbgServices(net_dir) | |||
| self._dbg_service.initialize(net_name=net_name, is_sync_mode=is_sync) | |||
| self._cache_store.clean() | |||
| self._command_listener.start() | |||
| self._is_running_flag = True | |||
| self._check_version() | |||
| if self._metadata_stream.state == ServerStatus.MISMATCH.value: | |||
| log.info("The MindSpore and MindInsight version are mismatched. Failed to initialize offline server.") | |||
| return | |||
| self._load_metadata() | |||
| self._load_graphs() | |||
| log.info("Success initialize offline server for %s", self._dbg_dir) | |||
| def _check_version(self): | |||
| """Check version.""" | |||
| ms_version = self._dbg_services_module.get_version() | |||
| mi_version = mindinsight.__version__ | |||
| self._metadata_stream.debugger_version = {'ms': ms_version, 'mi': mi_version} | |||
| if not version_match(ms_version, mi_version): | |||
| log.info("Versions are mismatched, dbg_services: %s, mindinsight: %s", | |||
| ms_version, mi_version) | |||
| self._metadata_stream.state = ServerStatus.MISMATCH.value | |||
| metadata = self._metadata_stream.get(['state', 'debugger_version']) | |||
| self._cache_store.put_data(metadata) | |||
| def _load_metadata(self): | |||
| """Load metadata.""" | |||
| self._metadata_stream.debugger_type = DebuggerServerMode.OFFLINE.value | |||
| device_info = self._data_loader.load_device_info() | |||
| # The backend refers to the running environment on which the offline debugger | |||
| # data was generated. Currently supported options: `GPU`, `Ascend`. | |||
| backend = device_info.get('device_target', 'Ascend') | |||
| self._metadata_stream.backend = backend | |||
| device_stream = self._cache_store.get_stream_handler(Streams.DEVICE) | |||
| device_stream.put(device_info.get('server_list')) | |||
| rank_id = 0 | |||
| rank_0_info = device_stream.get(rank_id)['devices'][0] | |||
| self._metadata_stream.client_ip = rank_0_info.get('server_id') | |||
| # Get the step number per device as Dict[device_id, step_num]; it may grow as training proceeds. | |||
| step_num_per_device = self._data_loader.load_step_number() | |||
| device_stream.add_step_num_info(step_num_per_device) | |||
| self._metadata_stream.max_step_num = max(step_num_per_device.values()) | |||
| def _load_graphs(self): | |||
| """Load graphs.""" | |||
| # The format of graphs is a list of {'device_id': int, 'graph_protos': [GraphProto]}. | |||
| graphs = self._data_loader.load_graphs() | |||
| device_stream = self._cache_store.get_stream_handler(Streams.DEVICE) | |||
| graph_per_rank = {} | |||
| for graph in graphs: | |||
| device_id = int(graph.get('device_id')) | |||
| rank_id = device_stream.get_rank_id_by_device_id(device_id) | |||
| graph_per_rank[rank_id] = {} | |||
| tensor_stream_per_rank = self._cache_store.get_stream_handler(Streams.TENSOR).\ | |||
| get_tensor_handler_by_rank_id(rank_id, create_if_not_exit=True) | |||
| for graph_proto in graph.get('graph_protos'): | |||
| graph_per_rank[rank_id][graph_proto.name] = graph_proto | |||
| tensor_stream_per_rank.put_const_vals(graph_proto.const_vals) | |||
| # graph_per_rank is formatted as Dict[<rank_id>, Dict[<graph_name>, <GraphProto>]]. | |||
| self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_per_rank) | |||
| device_stream.add_graph_name_info(graph_per_rank) | |||
| self._metadata_stream.state = ServerStatus.RECEIVE_GRAPH.value | |||
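| # For example, with two ranks each dumping one graph (names hypothetical): | |||
| # {0: {'kernel_graph_0': <GraphProto>}, 1: {'kernel_graph_1': <GraphProto>}} | |||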
| @debugger_server_wrap | |||
| def wait_for_termination(self): | |||
| """Begin to listen on command event.""" | |||
| log.info("Begin to listen for user commands.") | |||
| self._send_graph() | |||
| while self.is_runnable(): | |||
| if not self._command_listener.has_new_command() and self._old_run_cmd: | |||
| self._deal_with_old_run_cmd() | |||
| continue | |||
| cmd = self._command_listener.get_next_command() | |||
| self.deal_with_cmd(cmd) | |||
| def _send_graph(self): | |||
| """Put graph and metadata info into data queue.""" | |||
| if not self.is_runnable(): | |||
| return | |||
| self._metadata_stream.state = ServerStatus.WAITING.value | |||
| metadata = self._metadata_stream.get() | |||
| res = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).get() | |||
| res.update(metadata) | |||
| self._cache_store.put_data(res) | |||
| def _deal_with_old_run_cmd(self): | |||
| """Deal with old run command.""" | |||
| left_step_count = self._old_run_cmd.get('left_step_count') | |||
| if left_step_count: | |||
| self._execute_one_step() | |||
| # Decrease the count only if old_run_cmd was not cleared by a watchpoint hit. | |||
| if self._old_run_cmd: | |||
| self._old_run_cmd['left_step_count'] = left_step_count - 1 if left_step_count > 0 else -1 | |||
| if not self._old_run_cmd.get('left_step_count'): | |||
| self._old_run_cmd.clear() | |||
| def deal_with_cmd(self, cmd): | |||
| """Deal with command.""" | |||
| if cmd is None: | |||
| return | |||
| if isinstance(cmd, dict): | |||
| self._deal_with_view_cmd(cmd) | |||
| elif isinstance(cmd, EventReply): | |||
| self._on_event(cmd) | |||
| def _on_event(self, event): | |||
| """ | |||
| Deal with different command event. | |||
| Args: | |||
| event (EventReply): Command Event. | |||
| """ | |||
| if event.HasField('run_cmd'): | |||
| self._deal_with_run_cmd(event) | |||
| elif event.HasField('exit'): | |||
| self._cache_store.clean() | |||
| self._update_state(ServerStatus.PENDING) | |||
| log.debug("Clean cache for exit cmd.") | |||
| else: | |||
| self._deal_with_set_cmd(event) | |||
| log.debug("Deal with set cmd.") | |||
| def _deal_with_view_cmd(self, event): | |||
| """ | |||
| Deal with view cmd. | |||
| Args: | |||
| event (dict): View command params. | |||
| - view_cmd (EventReply): EventReply with view command. | |||
| - node_name (str): The center node name for view command. | |||
| - tensor_name (str): The center tensor name for view command. | |||
| - graph_name (str): The graph name of center node. | |||
| - rank_id (int): The device id of the tensor. | |||
| """ | |||
| # Guard against a missing 'view_cmd' key to avoid an AttributeError on None. | |||
| view_cmd = event.pop('view_cmd', None) | |||
| view_cmd = view_cmd.view_cmd if view_cmd is not None else None | |||
| node_info = event | |||
| log.debug("Receive view cmd for node: %s.", event) | |||
| if not (view_cmd and node_info): | |||
| log.info("Invalid view command. Ignore it.") | |||
| return | |||
| # read tensor value by dbg_service | |||
| rank_id = node_info.get('rank_id', 0) | |||
| device_id = self._cache_store.get_stream_handler(Streams.DEVICE).get_device_id_by_rank_id(rank_id) | |||
| cur_step = self._metadata_stream.step | |||
| tensor_protos = view_cmd.tensors | |||
| root_graph_id = self.get_root_graph_id() | |||
| tensor_infos = [ | |||
| self._dbg_services_module.TensorInfo( | |||
| node_name=tensor_proto.node_name, | |||
| slot=int(tensor_proto.slot), | |||
| iteration=cur_step - 1 if tensor_proto.iter == 'prev' else cur_step, | |||
| device_id=device_id, | |||
| is_parameter=tensor_proto.truncate, | |||
| root_graph_id=root_graph_id | |||
| ) for tensor_proto in tensor_protos] | |||
| res = self._dbg_service.read_tensors(tensor_infos) | |||
| # put tensor into cache | |||
| for tensor_proto, tensor_data in zip(tensor_protos, res): | |||
| log.debug("Tensor name: %s:%s, tensor type: %s, tensor size: %s", tensor_proto.node_name, tensor_proto.slot, | |||
| tensor_data.dtype, tensor_data.data_size) | |||
| tensor_proto.tensor_content = tensor_data.data_ptr | |||
| tensor_proto.ClearField('dims') | |||
| tensor_proto.dims.extend(tensor_data.shape) | |||
| tensor_proto.data_type = tensor_data.dtype | |||
| self._put_tensor_value_into_cache(cur_step, node_info, rank_id, tensor_protos) | |||
| log.info("Put tensor value into cache.") | |||
| def get_root_graph_id(self): | |||
| """Get root graph id.""" | |||
| is_sync = self._data_loader.get_sync_flag() | |||
| graph_id = 0 if is_sync else 1 | |||
| return graph_id | |||
| def _put_tensor_value_into_cache(self, cur_step, node_info, rank_id, tensor_protos): | |||
| """Put tensor value into tensor cache.""" | |||
| tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR). \ | |||
| get_tensor_handler_by_rank_id(rank_id) | |||
| update_data_flag = False | |||
| for tensor_proto in tensor_protos: | |||
| if not tensor_proto.tensor_content: | |||
| log.warning("Tensor %s:%s is empty.", | |||
| tensor_proto.node_name, tensor_proto.slot) | |||
| try: | |||
| has_update = tensor_stream.put({ | |||
| 'step': cur_step, | |||
| 'tensor_proto': tensor_proto, | |||
| 'tensor_contents': [tensor_proto.tensor_content] | |||
| }) | |||
| except ValueError as err: | |||
| log.warning("Failed to put %s:%s into cache. Ignore it. %s", | |||
| tensor_proto.node_name, tensor_proto.slot, str(err)) | |||
| continue | |||
| if has_update: | |||
| update_data_flag = True | |||
| if update_data_flag: | |||
| # send message to frontend | |||
| metadata = self._metadata_stream.get(['step', 'state']) | |||
| ret = {'receive_tensor': node_info.copy()} | |||
| ret.update(metadata) | |||
| self._cache_store.put_data(ret) | |||
| def _deal_with_run_cmd(self, event): | |||
| """Deal with run cmd.""" | |||
| run_cmd = event.run_cmd | |||
| parsed_run_cmd = self._get_parsed_run_cmd(run_cmd) | |||
| if parsed_run_cmd.run_steps > 0: | |||
| self._execute_one_step() | |||
| elif run_cmd.run_level == RunLevel.RECHECK.value: | |||
| log.info("Deal with recheck command.") | |||
| self._check_watchpoint(self._metadata_stream.step) | |||
| def _execute_one_step(self): | |||
| """Execute on step.""" | |||
| new_step = self._metadata_stream.step + 1 | |||
| if new_step > self._metadata_stream.max_step_num: | |||
| self._old_run_cmd.clear() | |||
| log.info("The server is already at the last step. %s", self._metadata_stream.max_step_num) | |||
| return | |||
| log.info("Go to next step: %s.", new_step) | |||
| self._check_watchpoint(new_step) | |||
| self._metadata_stream.step = new_step | |||
| self._cache_store.get_stream_handler(Streams.TENSOR).set_step(new_step) | |||
| self._cache_store.put_data(self._metadata_stream.get('step')) | |||
| def _get_parsed_run_cmd(self, run_cmd): | |||
| """Get parsed run command.""" | |||
| if run_cmd.run_level == RunLevel.STEP.value: | |||
| # receive pause cmd | |||
| if not run_cmd.run_steps: | |||
| log.debug("Pause training and wait for next command.") | |||
| self._old_run_cmd.clear() | |||
| # update metadata state from sending to waiting | |||
| self._update_state(ServerStatus.WAITING) | |||
| return run_cmd | |||
| # receive step cmd | |||
| left_steps = run_cmd.run_steps - 1 | |||
| run_cmd.run_steps = 1 | |||
| if left_steps: | |||
| self._old_run_cmd['left_step_count'] = left_steps if left_steps > 0 else -1 | |||
| elif run_cmd.node_name: | |||
| self._old_run_cmd['node_name'] = run_cmd.node_name | |||
| run_cmd.node_name = '' | |||
| return run_cmd | |||
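| # Illustrative walk-through (not part of the original patch): a RunCMD with | |||
| # run_level='step' and run_steps=3 is rewritten to run_steps=1 with | |||
| # left_step_count=2 stored in _old_run_cmd; run_steps=0 means pause, and a | |||
| # negative run_steps keeps left_step_count at -1, i.e. run to the last step. | |||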
| def _check_watchpoint(self, step): | |||
| """Save watchpoint hits into cache.""" | |||
| self._update_state(ServerStatus.RUNNING) | |||
| # Clean watchpoint_hits in cache | |||
| multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||
| multi_card_hit_streams.clean() | |||
| hits = Manager().list() | |||
| check_watchpoints_process = Process(target=self._check_watchpoint_work, args=(hits, step,)) | |||
| check_watchpoints_process.start() | |||
| check_watchpoints_process.join() | |||
| log.info("finish check watchpoint of %s", step) | |||
| if hits: | |||
| log.info("Received WatchpointHits. Left run cmd %s change to empty.", self._old_run_cmd) | |||
| self._old_run_cmd.clear() | |||
| self._update_state(ServerStatus.WAITING) | |||
| self._save_watchpoint_hits(hits) | |||
| def _save_watchpoint_hits(self, hits): | |||
| """Save watchpoint hits.""" | |||
| multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||
| multi_card_graph_streams = self._cache_store.get_stream_handler(Streams.GRAPH) | |||
| device_stream = self._cache_store.get_stream_handler(Streams.DEVICE) | |||
| watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| watchpoint_hits = defaultdict(list) | |||
| for hit in hits: | |||
| log.info("Received hit\n: " | |||
| "name:%s, slot:%s, condition:%s, " | |||
| "watchpoint_id:%s" | |||
| "error_code:%s, device_id:%s", | |||
| hit['name'], hit['slot'], hit['condition'], | |||
| hit['watchpoint_id'], hit['error_code'], hit['device_id']) | |||
| rank_id = device_stream.get_rank_id_by_device_id(hit['device_id']) | |||
| watchpoint_hit = {} | |||
| self._add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit) | |||
| if not watchpoint_hit: | |||
| continue | |||
| self._add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit) | |||
| watchpoint_hit['error_code'] = hit['error_code'] | |||
| watchpoint_hits[rank_id].append(watchpoint_hit) | |||
| # save hit info into cache | |||
| multi_card_hit_streams.put(watchpoint_hits) | |||
| self._cache_store.put_data({'receive_watchpoint_hits': True}) | |||
| log.debug("Send the watchpoint hits to DataQueue.") | |||
| @staticmethod | |||
| def _add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit): | |||
| """Add hit node info.""" | |||
| graph_stream = multi_card_graph_streams.get_graph_handler_by_rank_id(rank_id) | |||
| node_full_name = hit['name'] | |||
| graph_name = graph_stream.get_graph_id_by_full_name(node_full_name) | |||
| if not graph_name: | |||
| log.warning("Cannot find node %s in graph. Skip it.", node_full_name) | |||
| return | |||
| ui_node_name = graph_stream.get_node_name_by_full_name(node_full_name, graph_name) | |||
| log.debug("Receive watch point hit: %s:%s", node_full_name, hit['slot']) | |||
| if not ui_node_name: | |||
| log.info("Not support to show %s on graph.", node_full_name) | |||
| return | |||
| watchpoint_hit.update({ | |||
| 'tensor_proto': TensorProto(node_name=node_full_name, slot=str(hit['slot'])), | |||
| 'node_name': ui_node_name, | |||
| 'graph_name': graph_name | |||
| }) | |||
| @staticmethod | |||
| def _add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit): | |||
| """Add watchpoint hit info.""" | |||
| watchpoint = copy.deepcopy(watchpoint_stream.get_watchpoint_by_id(hit['watchpoint_id'])) | |||
| hit_params = {} | |||
| # get hit actual value | |||
| for param in hit['parameters']: | |||
| if param['name'] not in (ParamNameEnum.RTOL.value, ParamNameEnum.RANGE_START_INCLUSIVE.value, | |||
| ParamNameEnum.RANGE_END_INCLUSIVE.value) \ | |||
| and hit['error_code'] == 0: | |||
| hit_params[param['name']] = param['actual_value'] | |||
| # update actual value into watchpoint | |||
| watchpoint_condition_params = watchpoint.condition['params'] | |||
| for i, param in enumerate(watchpoint_condition_params): | |||
| name = param['name'] | |||
| if name in hit_params.keys(): | |||
| watchpoint_condition_params[i]['actual_value'] = hit_params[name] | |||
| else: | |||
| watchpoint_condition_params[i]['actual_value'] = None | |||
| watchpoint_hit['watchpoint'] = watchpoint | |||
| def _deal_with_set_cmd(self, event): | |||
| """ | |||
| Deal with set cmd. | |||
| Args: | |||
| event (EventReply): User command event including set_cmd. | |||
| """ | |||
| set_cmd = event.set_cmd | |||
| set_cmd_id = set_cmd.id | |||
| delete = set_cmd.delete | |||
| if not delete: | |||
| log.info("Add watchpoint by using dbg_server.") | |||
| watch_condition = set_cmd.watch_condition | |||
| param_list = [] | |||
| for param in watch_condition.params: | |||
| param_list.append( | |||
| self._dbg_services_module.Parameter(param.name, param.disabled, param.value)) | |||
| watch_nodes = set_cmd.watch_nodes | |||
| check_nodes = self._get_check_nodes(watch_nodes) | |||
| log.debug("Watchpoint %s, condition: %s, watch nodes: %s", | |||
| set_cmd_id, watch_condition.condition, check_nodes) | |||
| self._dbg_service.add_watchpoint(set_cmd_id, watch_condition.condition, check_nodes, param_list) | |||
| else: | |||
| log.info("Remove watchpoint by using dbg_server.") | |||
| self._dbg_service.remove_watchpoint(set_cmd_id) | |||
| def _get_check_nodes(self, watch_nodes): | |||
| """Get check nodes format""" | |||
| check_nodes = {} | |||
| device_stream = self._cache_store.get_stream_handler(Streams.DEVICE) | |||
| root_graph_id = self.get_root_graph_id() | |||
| for watch_node in watch_nodes: | |||
| node_name = watch_node.node_name | |||
| rank_id = watch_node.rank_id | |||
| device_id = device_stream.get_device_id_by_rank_id(rank_id) | |||
| if node_name not in check_nodes: | |||
| is_parameter = bool(watch_node.node_type == NodeTypeEnum.PARAMETER.value) | |||
| check_nodes[node_name] = { | |||
| "device_id": [device_id], | |||
| "is_parameter": is_parameter, | |||
| "root_graph_id": [root_graph_id] | |||
| } | |||
| else: | |||
| check_nodes[node_name]["device_id"].append(device_id) | |||
| return check_nodes | |||
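| # Shape illustration with hypothetical names: watching 'Default/fc1.weight' | |||
| # (a Parameter node) on ranks 0 and 1 in sync mode yields | |||
| # {'Default/fc1.weight': {'device_id': [0, 1], 'is_parameter': True, | |||
| #                         'root_graph_id': [0]}}. | |||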
| def _update_state(self, server_status): | |||
| """ | |||
| Update state in metadata stream. | |||
| Args: | |||
| server_status (ServerStatus): The enum value in ServerStatus. | |||
| """ | |||
| if self._metadata_stream.state != server_status.value: | |||
| self._metadata_stream.state = server_status.value | |||
| self._cache_store.put_data(self._metadata_stream.get()) | |||
| def _check_watchpoint_work(self, hits, step): | |||
| """The check WatchPoint function work in another process.""" | |||
| log.info("Start checking WatchPointHit process.") | |||
| res = self._dbg_service.check_watchpoints(step) | |||
| for watchpoint_hit in res: | |||
| hit_dict = convert_watchpointhit(watchpoint_hit) | |||
| hits.append(hit_dict) | |||
| log.info("Checking WatchPointHit process is finished.") | |||
| class CommandListener: | |||
| """Event listener.""" | |||
| def __init__(self, cache_store): | |||
| self._cache_store = cache_store | |||
| self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) | |||
| # the next position of command queue to be queried | |||
| self._pos = '0' | |||
| self._is_waiting = Event() | |||
| def start(self): | |||
| """Start event listener.""" | |||
| self._pos = '0' | |||
| self._is_waiting.set() | |||
| def stop(self): | |||
| """Stop event listener.""" | |||
| # stop waiting for new user commands but can still get old commands. | |||
| self._is_waiting.clear() | |||
| def has_new_command(self): | |||
| """Check if there is new command in command queue.""" | |||
| return self._cache_store.has_command(self._pos) | |||
| def get_next_command(self): | |||
| """Get next command.""" | |||
| event = None | |||
| while event is None and self.has_new_command(): | |||
| self._pos, event = self._cache_store.get_command(self._pos) | |||
| log.debug("Deal with old %s-th command:\n%s.", self._pos, event) | |||
| if event is None: | |||
| event = self._wait_for_next_command() | |||
| return event | |||
| def _wait_for_next_command(self): | |||
| """ | |||
| Wait for next command. | |||
| Returns: | |||
| EventReply, the command event. | |||
| """ | |||
| if not self._is_waiting.is_set(): | |||
| self._metadata_stream.state = ServerStatus.PENDING.value | |||
| return None | |||
| log.info("Start to wait for command.") | |||
| if self._metadata_stream.state != ServerStatus.WAITING.value: | |||
| self._metadata_stream.state = ServerStatus.WAITING.value | |||
| self._cache_store.put_data(self._metadata_stream.get()) | |||
| log.debug("Wait for %s-th command", self._pos) | |||
| event = None | |||
| while event is None and self._is_waiting.is_set(): | |||
| self._pos, event = self._cache_store.get_command(self._pos) | |||
| return event | |||
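| # Polling sketch (illustrative, assumes an existing cache_store): commands | |||
| # are consumed by queue position, so nothing is missed or repeated. | |||
| # | |||
| # listener = CommandListener(cache_store) | |||
| # listener.start() | |||
| # while listener.has_new_command(): | |||
| #     cmd = listener.get_next_command() | |||
| #     ...  # dispatch cmd | |||
| # listener.stop() | |||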
| def convert_watchpointhit(watchpointhit): | |||
| """Convert watchpointhit object to dict.""" | |||
| parameters = watchpointhit.parameters | |||
| param_list = [] | |||
| for param in parameters: | |||
| param_dict = convert_param(param) | |||
| param_list.append(param_dict) | |||
| watchpointhit_dict = {'condition': watchpointhit.condition, | |||
| 'device_id': watchpointhit.device_id, | |||
| 'error_code': watchpointhit.error_code, | |||
| 'name': watchpointhit.name, | |||
| 'parameters': param_list, | |||
| 'slot': watchpointhit.slot, | |||
| 'watchpoint_id': watchpointhit.watchpoint_id} | |||
| return watchpointhit_dict | |||
| def convert_param(param): | |||
| """Convert parameter object to dict""" | |||
| param_dict = {'actual_value': param.actual_value, | |||
| 'disabled': param.disabled, | |||
| 'hit': param.hit, | |||
| 'name': param.name, | |||
| 'value': param.value} | |||
| return param_dict | |||
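| # Hedged usage sketch of the dbg_services API as exercised by this module; | |||
| # the call signatures are taken from the code above, while the dump path and | |||
| # network name are hypothetical. | |||
| # | |||
| # dbg = import_module('mindspore.offline_debug.dbg_services') | |||
| # service = dbg.DbgServices('/tmp/dump_dir/net_dir') | |||
| # service.initialize(net_name='Network', is_sync_mode=True) | |||
| # for hit in service.check_watchpoints(1):  # check watchpoints at step 1 | |||
| #     print(convert_watchpointhit(hit)) | |||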
| @@ -0,0 +1,58 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Debugger Online server.""" | |||
| from concurrent import futures | |||
| import grpc | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.conf import settings | |||
| from mindinsight.debugger.debugger_services.debugger_grpc_server import DebuggerGrpcServer | |||
| from mindinsight.debugger.debugger_services.debugger_server_base import DebuggerServerBase | |||
| from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base | |||
| def get_debugger_hostname(): | |||
| """Get hostname for online debugger server.""" | |||
| grpc_port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else 50051 | |||
| host = settings.HOST if hasattr(settings, 'HOST') else '[::]' | |||
| hostname = "{}:{}".format(host, grpc_port) | |||
| return hostname | |||
| class DebuggerOnlineServer(DebuggerServerBase): | |||
| """Debugger Online Server.""" | |||
| def __init__(self, cache_store, context): | |||
| super(DebuggerOnlineServer, self).__init__(cache_store, context) | |||
| self._grpc_server_manager = self.get_grpc_server_manager() | |||
| def run(self): | |||
| self._grpc_server_manager.start() | |||
| log.info("Start grpc server %s", self._context.hostname) | |||
| self._grpc_server_manager.wait_for_termination() | |||
| def get_grpc_server_manager(self): | |||
| """Get grpc server instance according to hostname.""" | |||
| if self._context.hostname is None: | |||
| self._context.hostname = get_debugger_hostname() | |||
| grpc_server = DebuggerGrpcServer(self._cache_store, None) | |||
| grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) | |||
| grpc_server_base.add_EventListenerServicer_to_server(grpc_server, grpc_server_manager) | |||
| grpc_server_manager.add_insecure_port(self._context.hostname) | |||
| return grpc_server_manager | |||
| def stop(self): | |||
| self._grpc_server_manager.stop(grace=None) | |||
| self.join() | |||
| @@ -0,0 +1,58 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """DebuggerServerBase.""" | |||
| import threading | |||
| from abc import abstractmethod | |||
| from functools import wraps | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerServerRunningError | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| def debugger_server_wrap(func): | |||
| """Wrapper for catch exception.""" | |||
| @wraps(func) | |||
| def record_log(*args, **kwargs): | |||
| try: | |||
| return func(*args, **kwargs) | |||
| except Exception as err: | |||
| log.exception(err) | |||
| raise DebuggerServerRunningError(str(err)) | |||
| return record_log | |||
| class DebuggerServerBase(threading.Thread): | |||
| """ | |||
| Debugger Server Base. | |||
| Args: | |||
| cache_store (DebuggerCacheStore): Cache store for debugger server. | |||
| context (DebuggerServerContext): Context for initialize debugger server. | |||
| """ | |||
| def __init__(self, cache_store, context): | |||
| super(DebuggerServerBase, self).__init__() | |||
| self._cache_store = cache_store | |||
| self._context = context | |||
| @abstractmethod | |||
| @debugger_server_wrap | |||
| def run(self): | |||
| """Function that should be called when thread started.""" | |||
| @abstractmethod | |||
| @debugger_server_wrap | |||
| def stop(self): | |||
| """Stop debugger server.""" | |||
| @@ -0,0 +1,92 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Debugger server factory.""" | |||
| import threading | |||
| from mindinsight.debugger.common.utils import DebuggerServerMode | |||
| from mindinsight.debugger.debugger_services.debugger_offline_server import DebuggerOfflineServer | |||
| from mindinsight.debugger.debugger_services.debugger_online_server import DebuggerOnlineServer | |||
| class DebuggerServerFactory: | |||
| """Create debugger server according to debugger mode.""" | |||
| _lock = threading.Lock() | |||
| _instance = None | |||
| def __init__(self): | |||
| self._server_map = { | |||
| DebuggerServerMode.ONLINE.value: DebuggerOnlineServer, | |||
| DebuggerServerMode.OFFLINE.value: DebuggerOfflineServer | |||
| } | |||
| def __new__(cls, *args, **kwargs): | |||
| if cls._instance is None: | |||
| with cls._lock: | |||
| if cls._instance is None: | |||
| cls._instance = super().__new__(cls) | |||
| return cls._instance | |||
| def get_debugger_server(self, cache_store, context): | |||
| """ | |||
| Get debugger server according to debugger_context and cache_store. | |||
| Args: | |||
| cache_store (DebuggerCacheStore): Cache store for debugger server. | |||
| context (DebuggerServerContext): Context for initialize debugger server. | |||
| Returns: | |||
| DebuggerServerBase, Debugger server object. | |||
| """ | |||
| dbg_server = None | |||
| dbg_server_class = self._server_map.get(context.dbg_mode) | |||
| if dbg_server_class: | |||
| dbg_server = dbg_server_class(cache_store, context) | |||
| return dbg_server | |||
| class DebuggerServerContext: | |||
| """ | |||
| Debugger server context. | |||
| Args: | |||
| dbg_mode (str): The debugger mode. Optional: `online` or `offline`. | |||
| train_job (str): The relative directory of debugger dump data for one training. | |||
| Used only when dbg_mode is `offline`. | |||
| dbg_dir (str): The base directory of debugger dump data for one training. | |||
| Used only when dbg_mode is `offline`. | |||
| hostname (str): The hostname used for online debugger server. | |||
| Used only when dbg_mode is `online`. | |||
| """ | |||
| def __init__(self, dbg_mode, train_job=None, dbg_dir=None, hostname=None): | |||
| self._dbg_mode = dbg_mode | |||
| self._train_job = train_job | |||
| self._dbg_dir = dbg_dir | |||
| self.hostname = hostname | |||
| @property | |||
| def dbg_mode(self): | |||
| """Property of debugger mode.""" | |||
| return self._dbg_mode | |||
| @property | |||
| def dbg_dir(self): | |||
| """Property of debugger mode.""" | |||
| return self._dbg_dir | |||
| @property | |||
| def train_job(self): | |||
| """The property of train job.""" | |||
| return self._train_job | |||
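| # A minimal usage sketch, assuming a DebuggerCache instance; the dump path is | |||
| # hypothetical. The factory is a process-wide singleton, so repeated | |||
| # construction returns the same object. | |||
| # | |||
| # from mindinsight.debugger.debugger_cache import DebuggerCache | |||
| # | |||
| # context = DebuggerServerContext(dbg_mode='offline', dbg_dir='/tmp/dump_dir') | |||
| # server = DebuggerServerFactory().get_debugger_server(DebuggerCache(), context) | |||
| # server.start()  # DebuggerServerBase extends threading.Thread | |||
| # server.stop() | |||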
| @@ -13,17 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Implement the debugger server.""" | |||
| import signal | |||
| from concurrent import futures | |||
| from functools import wraps | |||
| from threading import Thread | |||
| import grpc | |||
| from mindinsight.debugger.conditionmgr.condition import ConditionContext | |||
| from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr | |||
| from mindinsight.debugger.conditionmgr.recommender import recommend_watchpoints | |||
| from mindinsight.conf import settings | |||
| from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | |||
| from mindinsight.datavisual.utils.tools import to_float | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||
| @@ -32,9 +23,11 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import ServerStatus, \ | |||
| create_view_event_from_tensor_basic_info, Streams | |||
| from mindinsight.debugger.conditionmgr.condition import ConditionContext | |||
| from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr | |||
| from mindinsight.debugger.conditionmgr.recommender import recommend_watchpoints | |||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||
| from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer | |||
| from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base | |||
| from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerFactory | |||
| from mindinsight.debugger.stream_operator.tensor_detail_info import TensorDetailInfo | |||
| from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator | |||
| from mindinsight.debugger.stream_operator.watchpoint_operator import WatchpointOperator | |||
| @@ -57,25 +50,29 @@ def try_except(func): | |||
| return send_latest_metadata | |||
| class DebuggerServer: | |||
| class DebuggerSession: | |||
| """The server manager of debugger.""" | |||
| def __init__(self): | |||
| def __init__(self, context): | |||
| self.condition_mgr = ConditionMgr() | |||
| self.cache_store = DebuggerCache() | |||
| self.grpc_server = DebuggerGrpcServer(self.cache_store, self.condition_mgr) | |||
| self.grpc_server_manager = None | |||
| self.back_server = None | |||
| self.context = context | |||
| self.back_server = DebuggerServerFactory().get_debugger_server(self.cache_store, context) | |||
| @property | |||
| def train_job(self): | |||
| """The property of train job.""" | |||
| return self.context.train_job | |||
| def get_condition_collections(self, train_id): | |||
| def get_condition_collections(self, train_id=""): | |||
| """Get default condition_collections""" | |||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | |||
| condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step) | |||
| log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend) | |||
| return self.condition_mgr.get_all_collections(condition_context) | |||
| def set_recommended_watch_points(self, set_recommended, train_id): | |||
| """set recommended watch points.""" | |||
| def set_recommended_watch_points(self, set_recommended, train_id=""): | |||
| """Set recommended watch points.""" | |||
| if not isinstance(set_recommended, bool): | |||
| log.error("Bool param should be given for set_recommended") | |||
| raise DebuggerParamValueError("Bool param should be given.") | |||
| @@ -97,38 +94,28 @@ class DebuggerServer: | |||
| def _add_recommended_watchpoints(self, condition_context): | |||
| """Add predefined watchpoints.""" | |||
| log.debug("Add predefined watchpoints.") | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| watchpoints = recommend_watchpoints(self.condition_mgr, graph_stream, condition_context) | |||
| multi_card_graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| watchpoints = recommend_watchpoints(self.condition_mgr, multi_card_graph_stream, condition_context) | |||
| watch_point_stream_handler = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| device_stream = self.cache_store.get_stream_handler(Streams.DEVICE) | |||
| watch_points_ids = [] | |||
| for watchpoint in watchpoints: | |||
| watch_points_id = watch_point_stream_handler.create_watchpoint( | |||
| watch_condition=watchpoint.get_watch_condition_dict(), | |||
| watch_nodes=watchpoint.watch_nodes, | |||
| name=watchpoint.name, | |||
| condition_mgr=self.condition_mgr | |||
| condition_mgr=self.condition_mgr, | |||
| device_amount=device_stream.device_amount | |||
| ) | |||
| watch_points_ids.append(watch_points_id) | |||
| return watch_points_ids | |||
| def start(self): | |||
| """Start server.""" | |||
| grpc_port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else 50051 | |||
| host = settings.HOST if hasattr(settings, 'HOST') else '[::]' | |||
| hostname = "{}:{}".format(host, grpc_port) | |||
| # initialize a grpc server | |||
| grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) | |||
| grpc_server_base.add_EventListenerServicer_to_server(self.grpc_server, grpc_server_manager) | |||
| grpc_server_manager.add_insecure_port(hostname) | |||
| grpc_server_manager.start() | |||
| my_server_thread = Thread(target=grpc_server_manager.wait_for_termination) | |||
| # start grpc server | |||
| my_server_thread.start() | |||
| self.back_server = my_server_thread | |||
| self.grpc_server_manager = grpc_server_manager | |||
| self.back_server.start() | |||
| # register stop server handler | |||
| signal.signal(signal.SIGINT, self._stop_handler) | |||
| log.info("Start grpc server %s", hostname) | |||
| # signal.signal(signal.SIGINT, self._stop_handler) | |||
| log.info("Start debugger backend server.") | |||
| def _stop_handler(self, signum, frame): | |||
| """Register stop server handler.""" | |||
| @@ -139,8 +126,7 @@ class DebuggerServer: | |||
| """Stop debugger server.""" | |||
| log.info("Send terminate info to client.") | |||
| self.control({'mode': 'terminate'}) | |||
| self.grpc_server_manager.stop(grace=None) | |||
| self.back_server.join() | |||
| self.back_server.stop() | |||
| log.info("Stop debugger server.") | |||
| def poll_data(self, pos): | |||
| @@ -172,6 +158,7 @@ class DebuggerServer: | |||
| - graph_name (str): The graph name. | |||
| - watch_point_id (int): The id of watchpoint. Default: 0. | |||
| - node_category (str): The node_category. Default: None | |||
| - rank_id (int): The id of rank. Default: 0. | |||
| Returns: | |||
| dict, the searched nodes. | |||
| @@ -179,19 +166,20 @@ class DebuggerServer: | |||
| log.info("receive search request with filter_condition: %s", filter_condition) | |||
| # validate watchpoint id | |||
| watch_point_id = filter_condition.pop('watch_point_id', 0) | |||
| rank_id = filter_condition.pop('rank_id', 0) | |||
| watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| watchpoint_stream.validate_watchpoint_id(watch_point_id) | |||
| # validate and update graph name | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) | |||
| graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name')) | |||
| filter_condition['graph_name'] = graph_name | |||
| # get searched graph | |||
| graph = graph_stream.search_nodes(filter_condition) | |||
| # add watched label to graph | |||
| watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, graph_name) | |||
| watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, graph_name, rank_id) | |||
| return graph | |||
| def tensor_comparisons(self, name, shape, detail='data', tolerance='0'): | |||
| def tensor_comparisons(self, name, shape, detail='data', tolerance='0', rank_id=0): | |||
| """ | |||
| Get tensor comparisons data for given name, detail, shape and tolerance. | |||
| @@ -202,6 +190,7 @@ class DebuggerServer: | |||
| shape (str): Specify concrete dimensions of shape. | |||
| tolerance (str): Specify tolerance of difference between current step tensor and previous | |||
| step tensor. Default value is 0. | |||
| rank_id (int): The id of rank. Default: 0. | |||
| Raises: | |||
| DebuggerParamValueError, If node type is not parameter or value of detail is not support. | |||
| @@ -220,9 +209,10 @@ class DebuggerServer: | |||
| parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR) | |||
| node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name) | |||
| tolerance = to_float(tolerance, 'tolerance') | |||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) | |||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id) | |||
| cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step | |||
| if node_type == NodeTypeEnum.PARAMETER.value: | |||
| reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance) | |||
| reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance, cur_step) | |||
| else: | |||
| raise DebuggerParamValueError( | |||
| "The node type must be parameter, but got {}.".format(node_type)) | |||
| @@ -270,10 +260,18 @@ class DebuggerServer: | |||
| self.cache_store.clean_data() | |||
| log.info("Clean data queue cache when retrieve all request.") | |||
| result = {} | |||
| for stream in [Streams.METADATA, Streams.GRAPH]: | |||
| for stream in [Streams.METADATA, Streams.GRAPH, Streams.DEVICE]: | |||
| sub_res = self.cache_store.get_stream_handler(stream).get() | |||
| result.update(sub_res) | |||
| devices = result['devices'] | |||
| if not devices: | |||
| graph = result['graph'] | |||
| metadata = result['metadata'] | |||
| device = {'rank_id': 0, 'server_ip': metadata.get('ip', 'localhost'), | |||
| 'device_id': metadata.get('device_name', ''), | |||
| 'graph_names': graph.get('graph_names', [])} | |||
| devices.append(device) | |||
| sub_res = self._hide_parameters_for_ui() | |||
| result.update(sub_res) | |||
| @@ -298,7 +296,8 @@ class DebuggerServer: | |||
| log.debug("Retrieve node %s.", filter_condition) | |||
| # validate node name | |||
| node_name = filter_condition.get('name') | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| rank_id = filter_condition.get('rank_id', 0) | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) | |||
| graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name')) | |||
| if node_name: | |||
| # validate node name | |||
| @@ -325,24 +324,27 @@ class DebuggerServer: | |||
| dict, reply with graph. | |||
| """ | |||
| # validate watch_point_id | |||
| rank_id = filter_condition.get('rank_id', 0) | |||
| watch_point_id = filter_condition.get('watch_point_id', 0) | |||
| watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| watchpoint_stream.validate_watchpoint_id(watch_point_id) | |||
| # get graph | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) | |||
| reply = graph_stream.get(filter_condition) | |||
| graph = reply.get('graph') | |||
| # add watched label to graph | |||
| watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, filter_condition.get('graph_name')) | |||
| watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, filter_condition.get('graph_name'), | |||
| rank_id) | |||
| return reply | |||
| def retrieve_tensor_history(self, node_name, graph_name=None): | |||
| def retrieve_tensor_history(self, node_name, graph_name=None, rank_id=0): | |||
| """ | |||
| Retrieve tensor history for leaf node. | |||
| Args: | |||
| node_name (str): The name of leaf node. | |||
| graph_name (str): The graph name. Default: None. | |||
| rank_id (int): The id of rank. Default: 0. | |||
| Returns: | |||
| dict, the tensor history and metadata. | |||
| @@ -352,34 +354,34 @@ class DebuggerServer: | |||
| if metadata_stream.state == ServerStatus.PENDING.value: | |||
| log.info("The backend is in pending status.") | |||
| return metadata_stream.get(['state', 'step']) | |||
| res = self._get_tensor_history(node_name, graph_name) | |||
| res = self._get_tensor_history(node_name, graph_name, rank_id) | |||
| return res | |||
| def _get_tensor_history(self, node_name, graph_name=None): | |||
| def _get_tensor_history(self, node_name, graph_name=None, rank_id=0): | |||
| """ | |||
| Get tensor history for single node. | |||
| Args: | |||
| node_name (str): The name of leaf node. | |||
| graph_name (str): The graph name. Default: None. | |||
| rank_id (int): The id of rank. Default: 0. | |||
| Returns: | |||
| dict, the tensor history and metadata. | |||
| """ | |||
| # get basic tensor history | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) | |||
| tensor_history = graph_stream.get_tensor_history(node_name, graph_name) | |||
| # add tensor value for tensor history | |||
| self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name) | |||
| self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name, rank_id) | |||
| # add hit label for tensor history | |||
| watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||
| watchpoint_hit_stream.update_tensor_history(tensor_history) | |||
| self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).update_tensor_history(tensor_history, rank_id) | |||
| # add metadata | |||
| metadata = self.cache_store.get_stream_handler(Streams.METADATA).get(['step']) | |||
| tensor_history.update(metadata) | |||
| return tensor_history | |||
| def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name): | |||
| def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name, rank_id): | |||
| """ | |||
| Add tensor value for_tensor_history and send ViewCMD if tensor value missed. | |||
| @@ -387,48 +389,53 @@ class DebuggerServer: | |||
| tensor_history (list[dict]): A list of tensor info, including name and type. | |||
| node_name (str): The UI node name. | |||
| graph_name (str): The graph name. Default: None. | |||
| rank_id (int): The id of rank. Default: 0. | |||
| Returns: | |||
| dict, the tensor info. | |||
| """ | |||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) | |||
| missed_tensors = tensor_stream.update_tensor_history(tensor_history) | |||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id) | |||
| cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step | |||
| missed_tensors = tensor_stream.update_tensor_history(tensor_history, cur_step) | |||
| if missed_tensors: | |||
| view_cmd = create_view_event_from_tensor_basic_info(missed_tensors) | |||
| self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name}) | |||
| self.cache_store.put_command( | |||
| {'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name, 'rank_id': rank_id}) | |||
| log.debug("Send view cmd.") | |||
| def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False): | |||
| def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False, rank_id=0): | |||
| """Retrieve the tensor value.""" | |||
| log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape) | |||
| self.validate_tensor_param(name, detail) | |||
| # Limit to query max two dimensions for tensor in table view. | |||
| parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR) | |||
| node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name, graph_name) | |||
| node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name, graph_name, rank_id) | |||
| reply = self.cache_store.get_stream_handler(Streams.TENSOR).get( | |||
| {'name': tensor_name, | |||
| 'node_type': node_type, | |||
| 'shape': parsed_shape, | |||
| 'prev': prev} | |||
| 'prev': prev}, | |||
| rank_id | |||
| ) | |||
| reply['tensor_value']['name'] = name | |||
| return reply | |||
| def _get_tensor_name_and_type_by_ui_name(self, name, graph_name=None): | |||
| def _get_tensor_name_and_type_by_ui_name(self, name, graph_name=None, rank_id=0): | |||
| """ | |||
| Get inner tensor name and type by UI name. | |||
| Args: | |||
| name (str): Node name shown in UI. | |||
| graph_name (Union[str, None]): The graph name, default is: None. | |||
| rank_id (int): The id of rank. Default: 0. | |||
| Returns: | |||
| str, full name of tensor. | |||
| str, node type of tensor. | |||
| """ | |||
| node_name, slot = name.rsplit(':', 1) | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id) | |||
| graph_name = graph_name if graph_name else graph_stream.get_graph_id_by_name(node_name) | |||
| node_type = graph_stream.get_node_type(node_name, graph_name) | |||
| full_name = graph_stream.get_full_name(node_name, graph_name) | |||
| @@ -483,6 +490,7 @@ class DebuggerServer: | |||
| - offset (int): The offset of current page. | |||
| - node_name (str): The retrieved node name. | |||
| - graph_name (str): The retrieved graph name. | |||
| - rank_id (int): The rank id. | |||
| Returns: | |||
| dict, watch point list or relative graph. | |||
| @@ -496,7 +504,13 @@ class DebuggerServer: | |||
| log.info("The backend is in pending status.") | |||
| return metadata_stream.get() | |||
| reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).group_by(group_condition) | |||
| rank_id = group_condition.pop('rank_id', 0) | |||
| reply = {} | |||
| multi_watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||
| if multi_watchpoint_hit_stream.check_rank_id(rank_id): | |||
| watchpoint_hit_stream = multi_watchpoint_hit_stream.get_hit_handler_by_rank_id(rank_id) | |||
| reply = watchpoint_hit_stream.group_by(group_condition) | |||
| reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable() | |||
| return reply | |||
| @@ -591,40 +605,6 @@ class DebuggerServer: | |||
| training_controller.validate_mode(mode) | |||
| return training_controller.control(mode, params) | |||
| def retrieve_node_by_bfs(self, node_name, graph_name=None, ascend=False): | |||
| """ | |||
| Get the graph of the next node according to node_name. | |||
| Args: | |||
| node_name (str): The name of current chosen leaf node. | |||
| graph_name (str): The graph name. | |||
| ascend (bool): If True, traverse the input nodes; | |||
| If False, traverse the output nodes. Default: False. | |||
| Returns: | |||
| dict, the next node information. | |||
| """ | |||
| log.info("Retrieve node <%s> by bfs, `ascend` is :%s", | |||
| node_name, ascend) | |||
| reply = {} | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| graph_name = graph_stream.validate_graph_name(graph_name) | |||
| next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend) | |||
| # no next node | |||
| if next_node_name is None: | |||
| return reply | |||
| # add graph and tensor history for next node | |||
| filter_condition = { | |||
| 'name': next_node_name, | |||
| 'graph_name': graph_name, | |||
| 'single_node': True | |||
| } | |||
| search_graph = self._get_nodes_info(filter_condition) | |||
| reply = {'name': next_node_name} | |||
| reply.update(search_graph) | |||
| return reply | |||
| @try_except | |||
| def recheck(self): | |||
| """ | |||
| @@ -635,13 +615,14 @@ class DebuggerServer: | |||
| """ | |||
| return TrainingControlOperator(self.cache_store).recheck() | |||
| def retrieve_tensor_graph(self, tensor_name, graph_name): | |||
| def retrieve_tensor_graph(self, tensor_name, graph_name, rank_id=0): | |||
| """ | |||
| Retrieve tensor graph. | |||
| Args: | |||
| tensor_name (str): The tensor name from UI. | |||
| graph_name (str): The graph name. | |||
| rank_id (int): The id of rank. Default: 0. | |||
| Returns: | |||
| dict, tensor graph object. | |||
| @@ -650,16 +631,17 @@ class DebuggerServer: | |||
| log.error("Failed to get tensor graph the MindSpore is not in waiting state.") | |||
| raise DebuggerTensorGraphError | |||
| log.info("Retrieve tensor graph for %s from %s", tensor_name, graph_name) | |||
| tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name) | |||
| tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name, rank_id) | |||
| return tensor_graph_ops | |||
| def retrieve_tensor_hits(self, tensor_name, graph_name): | |||
| def retrieve_tensor_hits(self, tensor_name, graph_name, rank_id=0): | |||
| """ | |||
| Retrieve tensor hit information. | |||
| Args: | |||
| tensor_name (str): The tensor name from UI. | |||
| graph_name (str): The graph name. | |||
| rank_id (int): The id of rank. Default: 0. | |||
| Returns: | |||
| dict, tensor hit info. | |||
| @@ -668,7 +650,7 @@ class DebuggerServer: | |||
| log.error("Failed to get tensor hits as the MindSpore is not in waiting state.") | |||
| raise DebuggerTensorHitError | |||
| log.info("Retrieve tensor hits for %s from %s", tensor_name, graph_name) | |||
| watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name) | |||
| watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name, rank_id) | |||
| return {'watch_points': watch_points} | |||
| def _hide_parameters_for_ui(self): | |||
| @@ -122,6 +122,9 @@ message WatchCondition { | |||
| message WatchNode { | |||
| string node_name = 1; | |||
| string node_type = 2; | |||
| string graph_name = 3; | |||
| int32 rank_id = 4; | |||
| int32 device_id = 5; | |||
| } | |||
| message WatchpointHit { | |||
| @@ -2,8 +2,6 @@ | |||
| # Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| # source: mindinsight/debugger/proto/debug_grpc.proto | |||
| import sys | |||
| _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) | |||
| from google.protobuf import descriptor as _descriptor | |||
| from google.protobuf import message as _message | |||
| from google.protobuf import reflection as _reflection | |||
| @@ -21,7 +19,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( | |||
| package='debugger', | |||
| syntax='proto3', | |||
| serialized_options=None, | |||
| serialized_pb=_b('\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"\x92\x01\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\x12\x11\n\tgraph_num\x18\x06 \x01(\x05\x12\x12\n\nms_version\x18\x07 \x01(\t\")\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\"\x87\x02\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\x12\x19\n\x0fversion_matched\x18\x06 \x01(\x08H\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\x81\x04\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x32\n\x06params\x18\x04 \x03(\x0b\x32\".debugger.WatchCondition.Parameter\x1a]\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64isabled\x18\x02 \x01(\x08\x12\r\n\x05value\x18\x03 \x01(\x01\x12\x0b\n\x03hit\x18\x04 \x01(\x08\x12\x14\n\x0c\x61\x63tual_value\x18\x05 \x01(\x01\"\x95\x02\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x0c\n\x08overflow\x10\x02\x12\t\n\x05sd_gt\x10\x0b\x12\t\n\x05sd_lt\x10\x0c\x12\x1b\n\x17tensor_general_overflow\x10\r\x12\x19\n\x15tensor_initialization\x10\x0e\x12\x14\n\x10tensor_too_large\x10\x0f\x12\x14\n\x10tensor_too_small\x10\x10\x12\x13\n\x0ftensor_all_zero\x10\x11\x12\x1b\n\x17tensor_change_too_large\x10\x12\x12\x1b\n\x17tensor_change_too_small\x10\x13\x12\x16\n\x12tensor_not_changed\x10\x14\x12\x10\n\x0ctensor_range\x10\x15\"1\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\"\x89\x01\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x12\x12\n\nerror_code\x18\x04 \x01(\x05\x32\x81\x03\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x12<\n\x0fSendMultiGraphs\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3') | |||
| serialized_pb=b'\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"\x92\x01\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\x12\x11\n\tgraph_num\x18\x06 \x01(\x05\x12\x12\n\nms_version\x18\x07 \x01(\t\")\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\"\x87\x02\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\x12\x19\n\x0fversion_matched\x18\x06 \x01(\x08H\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\x81\x04\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x32\n\x06params\x18\x04 \x03(\x0b\x32\".debugger.WatchCondition.Parameter\x1a]\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64isabled\x18\x02 \x01(\x08\x12\r\n\x05value\x18\x03 \x01(\x01\x12\x0b\n\x03hit\x18\x04 \x01(\x08\x12\x14\n\x0c\x61\x63tual_value\x18\x05 \x01(\x01\"\x95\x02\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x0c\n\x08overflow\x10\x02\x12\t\n\x05sd_gt\x10\x0b\x12\t\n\x05sd_lt\x10\x0c\x12\x1b\n\x17tensor_general_overflow\x10\r\x12\x19\n\x15tensor_initialization\x10\x0e\x12\x14\n\x10tensor_too_large\x10\x0f\x12\x14\n\x10tensor_too_small\x10\x10\x12\x13\n\x0ftensor_all_zero\x10\x11\x12\x1b\n\x17tensor_change_too_large\x10\x12\x12\x1b\n\x17tensor_change_too_small\x10\x13\x12\x16\n\x12tensor_not_changed\x10\x14\x12\x10\n\x0ctensor_range\x10\x15\"i\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\x12\x12\n\ngraph_name\x18\x03 \x01(\t\x12\x0f\n\x07rank_id\x18\x04 \x01(\x05\x12\x11\n\tdevice_id\x18\x05 \x01(\x05\"\x89\x01\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x12\x12\n\nerror_code\x18\x04 \x01(\x05\x32\x81\x03\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x12<\n\x0fSendMultiGraphs\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3' | |||
| , | |||
| dependencies=[mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.DESCRIPTOR,]) | |||
| @@ -130,7 +128,7 @@ _METADATA = _descriptor.Descriptor( | |||
| _descriptor.FieldDescriptor( | |||
| name='device_name', full_name='debugger.Metadata.device_name', index=0, | |||
| number=1, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| @@ -144,14 +142,14 @@ _METADATA = _descriptor.Descriptor( | |||
| _descriptor.FieldDescriptor( | |||
| name='backend', full_name='debugger.Metadata.backend', index=2, | |||
| number=3, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='cur_node', full_name='debugger.Metadata.cur_node', index=3, | |||
| number=4, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| @@ -172,7 +170,7 @@ _METADATA = _descriptor.Descriptor( | |||
| _descriptor.FieldDescriptor( | |||
| name='ms_version', full_name='debugger.Metadata.ms_version', index=6, | |||
| number=7, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| @@ -203,7 +201,7 @@ _CHUNK = _descriptor.Descriptor( | |||
| _descriptor.FieldDescriptor( | |||
| name='buffer', full_name='debugger.Chunk.buffer', index=0, | |||
| number=1, type=12, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=_b(""), | |||
| has_default_value=False, default_value=b"", | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| @@ -311,7 +309,7 @@ _RUNCMD = _descriptor.Descriptor( | |||
| _descriptor.FieldDescriptor( | |||
| name='run_level', full_name='debugger.RunCMD.run_level', index=0, | |||
| number=1, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| @@ -325,7 +323,7 @@ _RUNCMD = _descriptor.Descriptor( | |||
| _descriptor.FieldDescriptor( | |||
| name='node_name', full_name='debugger.RunCMD.node_name', index=2, | |||
| number=3, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| @@ -442,7 +440,7 @@ _WATCHCONDITION_PARAMETER = _descriptor.Descriptor( | |||
| _descriptor.FieldDescriptor( | |||
| name='name', full_name='debugger.WatchCondition.Parameter.name', index=0, | |||
| number=1, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| @@ -546,14 +544,35 @@ _WATCHNODE = _descriptor.Descriptor( | |||
| _descriptor.FieldDescriptor( | |||
| name='node_name', full_name='debugger.WatchNode.node_name', index=0, | |||
| number=1, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='node_type', full_name='debugger.WatchNode.node_type', index=1, | |||
| number=2, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=_b("").decode('utf-8'), | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='graph_name', full_name='debugger.WatchNode.graph_name', index=2, | |||
| number=3, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='rank_id', full_name='debugger.WatchNode.rank_id', index=3, | |||
| number=4, type=5, cpp_type=1, label=1, | |||
| has_default_value=False, default_value=0, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='device_id', full_name='debugger.WatchNode.device_id', index=4, | |||
| number=5, type=5, cpp_type=1, label=1, | |||
| has_default_value=False, default_value=0, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| @@ -570,7 +589,7 @@ _WATCHNODE = _descriptor.Descriptor( | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=1335, | |||
| serialized_end=1384, | |||
| serialized_end=1440, | |||
| ) | |||
| @@ -621,8 +640,8 @@ _WATCHPOINTHIT = _descriptor.Descriptor( | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=1387, | |||
| serialized_end=1524, | |||
| serialized_start=1443, | |||
| serialized_end=1580, | |||
| ) | |||
| _EVENTREPLY.fields_by_name['status'].enum_type = _EVENTREPLY_STATUS | |||
| @@ -750,8 +769,8 @@ _EVENTLISTENER = _descriptor.ServiceDescriptor( | |||
| file=DESCRIPTOR, | |||
| index=0, | |||
| serialized_options=None, | |||
| serialized_start=1527, | |||
| serialized_end=1912, | |||
| serialized_start=1583, | |||
| serialized_end=1968, | |||
| methods=[ | |||
| _descriptor.MethodDescriptor( | |||
| name='WaitCMD', | |||
| @@ -1,5 +1,4 @@ | |||
| # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! | |||
| """Client and server classes corresponding to protobuf-defined services.""" | |||
| import grpc | |||
| from mindinsight.debugger.proto import debug_grpc_pb2 as mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2 | |||
| @@ -7,7 +6,7 @@ from mindinsight.debugger.proto import ms_graph_pb2 as mindinsight_dot_debugger_ | |||
| class EventListenerStub(object): | |||
| """Missing associated documentation comment in .proto file.""" | |||
| """Missing associated documentation comment in .proto file""" | |||
| def __init__(self, channel): | |||
| """Constructor. | |||
| @@ -48,40 +47,40 @@ class EventListenerStub(object): | |||
| class EventListenerServicer(object): | |||
| """Missing associated documentation comment in .proto file.""" | |||
| """Missing associated documentation comment in .proto file""" | |||
| def WaitCMD(self, request, context): | |||
| """Missing associated documentation comment in .proto file.""" | |||
| """Missing associated documentation comment in .proto file""" | |||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||
| context.set_details('Method not implemented!') | |||
| raise NotImplementedError('Method not implemented!') | |||
| def SendMetadata(self, request, context): | |||
| """Missing associated documentation comment in .proto file.""" | |||
| """Missing associated documentation comment in .proto file""" | |||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||
| context.set_details('Method not implemented!') | |||
| raise NotImplementedError('Method not implemented!') | |||
| def SendGraph(self, request_iterator, context): | |||
| """Missing associated documentation comment in .proto file.""" | |||
| """Missing associated documentation comment in .proto file""" | |||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||
| context.set_details('Method not implemented!') | |||
| raise NotImplementedError('Method not implemented!') | |||
| def SendTensors(self, request_iterator, context): | |||
| """Missing associated documentation comment in .proto file.""" | |||
| """Missing associated documentation comment in .proto file""" | |||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||
| context.set_details('Method not implemented!') | |||
| raise NotImplementedError('Method not implemented!') | |||
| def SendWatchpointHits(self, request_iterator, context): | |||
| """Missing associated documentation comment in .proto file.""" | |||
| """Missing associated documentation comment in .proto file""" | |||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||
| context.set_details('Method not implemented!') | |||
| raise NotImplementedError('Method not implemented!') | |||
| def SendMultiGraphs(self, request_iterator, context): | |||
| """Missing associated documentation comment in .proto file.""" | |||
| """Missing associated documentation comment in .proto file""" | |||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||
| context.set_details('Method not implemented!') | |||
| raise NotImplementedError('Method not implemented!') | |||
| @@ -127,7 +126,7 @@ def add_EventListenerServicer_to_server(servicer, server): | |||
| # This class is part of an EXPERIMENTAL API. | |||
| class EventListener(object): | |||
| """Missing associated documentation comment in .proto file.""" | |||
| """Missing associated documentation comment in .proto file""" | |||
| @staticmethod | |||
| def WaitCMD(request, | |||
| @@ -135,7 +134,6 @@ class EventListener(object): | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| insecure=False, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| @@ -144,7 +142,7 @@ class EventListener(object): | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| options, channel_credentials, | |||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| @staticmethod | |||
| def SendMetadata(request, | |||
| @@ -152,7 +150,6 @@ class EventListener(object): | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| insecure=False, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| @@ -161,7 +158,7 @@ class EventListener(object): | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| options, channel_credentials, | |||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| @staticmethod | |||
| def SendGraph(request_iterator, | |||
| @@ -169,7 +166,6 @@ class EventListener(object): | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| insecure=False, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| @@ -178,7 +174,7 @@ class EventListener(object): | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString, | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| options, channel_credentials, | |||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| @staticmethod | |||
| def SendTensors(request_iterator, | |||
| @@ -186,7 +182,6 @@ class EventListener(object): | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| insecure=False, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| @@ -195,7 +190,7 @@ class EventListener(object): | |||
| mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.SerializeToString, | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| options, channel_credentials, | |||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| @staticmethod | |||
| def SendWatchpointHits(request_iterator, | |||
| @@ -203,7 +198,6 @@ class EventListener(object): | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| insecure=False, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| @@ -212,7 +206,7 @@ class EventListener(object): | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.SerializeToString, | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| options, channel_credentials, | |||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| @staticmethod | |||
| def SendMultiGraphs(request_iterator, | |||
| @@ -220,7 +214,6 @@ class EventListener(object): | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| insecure=False, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| @@ -229,4 +222,4 @@ class EventListener(object): | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString, | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| options, channel_credentials, | |||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| @@ -229,6 +229,9 @@ message NodeProto { | |||
| // full name with scope | |||
| optional string full_name = 8; | |||
| // The corresponding source code for this node. | |||
| optional string source_address = 9; | |||
| } | |||
| // Models | |||
| @@ -0,0 +1,172 @@ | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Implement the session manager.""" | |||
| import os | |||
| import threading | |||
| from urllib.parse import unquote | |||
| import _thread | |||
| from mindinsight.conf import settings | |||
| from mindinsight.debugger.common.log import LOGGER as logger | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \ | |||
| DebuggerSessionNotFoundError | |||
| from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext | |||
| from mindinsight.debugger.debugger_session import DebuggerSession | |||
| class SessionManager: | |||
| """The server manager of debugger.""" | |||
| ONLINE_TYPE = "ONLINE" | |||
| MAX_SESSION_NUM = 2 | |||
| ONLINE_SESSION_ID = "0" | |||
| _instance = None | |||
| _cls_lock = threading.Lock() | |||
| def __init__(self): | |||
| self.train_jobs = {} | |||
| self.sessions = {} | |||
| self.session_id = 1 | |||
| self.online_session = None | |||
| self._lock = threading.Lock() | |||
| self._exiting = False | |||
| enable_debugger = getattr(settings, 'ENABLE_DEBUGGER', False) | |||
| if enable_debugger: | |||
| self.creat_session(self.ONLINE_TYPE) | |||
| @classmethod | |||
| def get_instance(cls): | |||
| """Get the singleton instance.""" | |||
| with cls._cls_lock: | |||
| if cls._instance is None: | |||
| cls._instance = SessionManager() | |||
| return cls._instance | |||
| def exit(self): | |||
| """ | |||
| Called when the gunicorn worker process is exiting. | |||
| """ | |||
| with self._lock: | |||
| logger.info("Start to exit sessions.") | |||
| self._exiting = True | |||
| for session in self.sessions.values(): | |||
| session.stop() | |||
| if self.online_session is not None: | |||
| self.online_session.stop() | |||
| logger.info("Exited.") | |||
| def get_session(self, session_id): | |||
| """ | |||
| Get the session by session id. | |||
| Args: | |||
| session_id (Union[None, str]): The session id. | |||
| Returns: | |||
| DebuggerSession, debugger session object. | |||
| """ | |||
| with self._lock: | |||
| if session_id == self.ONLINE_SESSION_ID and self.online_session is not None: | |||
| return self.online_session | |||
| if session_id in self.sessions: | |||
| return self.sessions.get(session_id) | |||
| raise DebuggerSessionNotFoundError("session id {}".format(session_id)) | |||
| def creat_session(self, session_type, train_job=None): | |||
| """ | |||
| Create session by the train job info. | |||
| Args: | |||
| session_type (str): The session type, e.g. 'ONLINE'. | |||
| train_job (str): The relative directory of the train job. Default: None. | |||
| Returns: | |||
| str, session id. | |||
| """ | |||
| with self._lock: | |||
| if self._exiting: | |||
| logger.info( | |||
| "System is exiting, will terminate the thread.") | |||
| _thread.exit() | |||
| if session_type == self.ONLINE_TYPE: | |||
| if self.online_session is None: | |||
| context = DebuggerServerContext(dbg_mode='online') | |||
| self.online_session = DebuggerSession(context) | |||
| self.online_session.start() | |||
| return self.ONLINE_SESSION_ID | |||
| if train_job in self.train_jobs: | |||
| return self.train_jobs.get(train_job) | |||
| self._check_session_num() | |||
| summary_base_dir = settings.SUMMARY_BASE_DIR | |||
| unquote_path = unquote(train_job, errors='strict') | |||
| whole_path = os.path.join(summary_base_dir, unquote_path) | |||
| normalized_path = validate_and_normalize_path(whole_path) | |||
| context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path) | |||
| session = DebuggerSession(context) | |||
| session.start() | |||
| session_id = str(self.session_id) | |||
| self.sessions[session_id] = session | |||
| self.train_jobs[train_job] = session_id | |||
| self.session_id += 1 | |||
| return session_id | |||
| def delete_session(self, session_id): | |||
| """Delete session by session id.""" | |||
| with self._lock: | |||
| if session_id == self.ONLINE_SESSION_ID: | |||
| self.online_session.stop() | |||
| self.online_session = None | |||
| return | |||
| if session_id not in self.sessions: | |||
| raise DebuggerSessionNotFoundError("session id {}".format(session_id)) | |||
| session = self.sessions.get(session_id) | |||
| session.stop() | |||
| self.sessions.pop(session_id) | |||
| self.train_jobs.pop(session.train_job) | |||
| return | |||
| def get_sessions(self): | |||
| """get all sessions""" | |||
| return {"train_jobs": self.train_jobs} | |||
| def _check_session_num(self): | |||
| """Check the amount of sessions.""" | |||
| if len(self.sessions) >= self.MAX_SESSION_NUM: | |||
| raise DebuggerSessionNumOverBoundError() | |||
| def validate_and_normalize_path(path): | |||
| """Validate and normalize_path""" | |||
| if not path: | |||
| raise ValueError("The path is invalid!") | |||
| path_str = str(path) | |||
| if not path_str.startswith("/"): | |||
| raise ValueError("The path is invalid!") | |||
| try: | |||
| normalized_path = os.path.realpath(path) | |||
| except ValueError: | |||
| raise ValueError("The path is invalid!") | |||
| return normalized_path | |||
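| # A minimal usage sketch of the manager above (hedged): it assumes a dump | |||
| # directory 'normal/run1' exists under settings.SUMMARY_BASE_DIR, and relies | |||
| # on any session_type other than ONLINE_TYPE taking the offline code path. | |||
| if __name__ == '__main__': | |||
|     manager = SessionManager.get_instance() | |||
|     session_id = manager.creat_session('OFFLINE', train_job='normal/run1') | |||
|     session = manager.get_session(session_id) | |||
|     manager.delete_session(session_id) | |||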
| @@ -0,0 +1,210 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """This file is used to define the DataLoader.""" | |||
| import os | |||
| import json | |||
| from mindinsight.debugger.proto.ms_graph_pb2 import ModelProto | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from mindinsight.debugger.common.utils import DumpSettings | |||
| class DataLoader: | |||
| """The DataLoader object provides interface to load graphs and device information from base_dir.""" | |||
| def __init__(self, base_dir): | |||
| self._debugger_base_dir = base_dir | |||
| self._graph_protos = [] | |||
| self._device_info = {} | |||
| self._step_num = {} | |||
| # flag for whether the data is from sync dump or async dump, True for sync dump, False for async dump. | |||
| self._is_sync = None | |||
| self._net_dir = "" | |||
| self._net_name = "" | |||
| self.initialize() | |||
| def initialize(self): | |||
| """Initialize the data_mode and net_dir of DataLoader.""" | |||
| dump_config_file = os.path.join(self._debugger_base_dir, os.path.join(".metadata", "data_dump.json")) | |||
| with open(dump_config_file, 'r') as load_f: | |||
| dump_config = json.load(load_f) | |||
| common_settings = dump_config.get(DumpSettings.COMMON_DUMP_SETTINGS.value) | |||
| if not common_settings: | |||
| raise ParamValueError('common_dump_settings not found in dump_config file.') | |||
| self._net_name = common_settings['net_name'] | |||
| if dump_config.get(DumpSettings.E2E_DUMP_SETTINGS.value) and \ | |||
| dump_config[DumpSettings.E2E_DUMP_SETTINGS.value]['enable']: | |||
| self._is_sync = True | |||
| self._net_dir = os.path.join(self._debugger_base_dir, self._net_name) | |||
| elif dump_config.get(DumpSettings.ASYNC_DUMP_SETTINGS.value) and \ | |||
| dump_config[DumpSettings.ASYNC_DUMP_SETTINGS.value]['enable']: | |||
| self._is_sync = False | |||
| self._net_dir = self._debugger_base_dir | |||
| else: | |||
| raise ParamValueError('The data must be generated from sync dump or async dump.') | |||
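| # For reference, a minimal dump config that initialize() accepts could look | |||
| # like the following (key names assumed to match the DumpSettings values used | |||
| # above; the net name is hypothetical): | |||
| #     {"common_dump_settings": {"net_name": "Lenet"}, | |||
| #      "e2e_dump_settings": {"enable": true}} | |||
| # For async dump, an enabled "async_dump_settings" entry is read instead. | |||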
| def load_graphs(self): | |||
| """Load graphs from the debugger_base_dir.""" | |||
| files = os.listdir(self._net_dir) | |||
| for file in files: | |||
| if not self.is_device_dir(file): | |||
| continue | |||
| device_id, device_dir = self.get_device_id_and_dir(file) | |||
| graphs_dir = os.path.join(device_dir, 'graphs') | |||
| if not os.path.isdir(graphs_dir): | |||
| log.debug("Directory '%s' does not exist.", graphs_dir) | |||
| self._graph_protos.append({'device_id': device_id, 'graph_protos': []}) | |||
| continue | |||
| graph_protos = get_graph_protos_from_dir(graphs_dir) | |||
| self._graph_protos.append({'device_id': device_id, 'graph_protos': graph_protos}) | |||
| return self._graph_protos | |||
| def load_device_info(self): | |||
| """Load device_info from file""" | |||
| hccl_json_file = os.path.join(self._debugger_base_dir, '.metadata', 'hccl.json') | |||
| if not os.path.isfile(hccl_json_file): | |||
| device = [] | |||
| device_ids = self.get_all_device_id() | |||
| device_ids.sort() | |||
| for i, device_id in enumerate(device_ids): | |||
| rank_id = i | |||
| device.append({'device_id': str(device_id), 'rank_id': str(rank_id)}) | |||
| device_target = 'Ascend' | |||
| self._device_info = {'device_target': device_target, | |||
| 'server_list': [{'server_id': 'localhost', 'device': device}]} | |||
| else: | |||
| with open(hccl_json_file, 'r') as load_f: | |||
| load_dict = json.load(load_f) | |||
| self._device_info = {'device_target': 'Ascend', 'server_list': load_dict['server_list']} | |||
| return self._device_info | |||
| def load_step_number(self): | |||
| """Load step number in the directory""" | |||
| files = os.listdir(self._net_dir) | |||
| for file in files: | |||
| if not self.is_device_dir(file): | |||
| continue | |||
| device_id, device_dir = self.get_device_id_and_dir(file) | |||
| max_step = 0 | |||
| files_in_device = os.listdir(device_dir) | |||
| if self._is_sync: | |||
| for file_in_device in files_in_device: | |||
| abs_file_in_device = os.path.join(device_dir, file_in_device) | |||
| if os.path.isdir(abs_file_in_device) and file_in_device.startswith("iteration_"): | |||
| step_id_str = file_in_device.split('_')[-1] | |||
| max_step = update_max_step(step_id_str, max_step) | |||
| self._step_num[str(device_id)] = max_step | |||
| else: | |||
| net_graph_dir = [] | |||
| for file_in_device in files_in_device: | |||
| abs_file_in_device = os.path.join(device_dir, file_in_device) | |||
| if os.path.isdir(abs_file_in_device) and file_in_device.startswith(self._net_name): | |||
| net_graph_dir.append(abs_file_in_device) | |||
| if not net_graph_dir: | |||
| self._step_num[str(device_id)] = max_step | |||
| continue | |||
| if len(net_graph_dir) > 1: | |||
| log.warning("There is more than one graph directory under %s. " | |||
| "The offline debugger uses the data in %s.", device_dir, net_graph_dir[0]) | |||
| net_graph_dir_to_use = net_graph_dir[0] | |||
| graph_id = net_graph_dir_to_use.split('_')[-1] | |||
| graph_id_dir = os.path.join(net_graph_dir_to_use, graph_id) | |||
| step_ids = os.listdir(graph_id_dir) | |||
| for step_id_str in step_ids: | |||
| max_step = update_max_step(step_id_str, max_step) | |||
| self._step_num[str(device_id)] = max_step | |||
| return self._step_num | |||
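| # Directory layouts implied by the parsing above (paths are illustrative): | |||
| #     sync dump:  <base_dir>/<net_name>/device_<id>/iteration_<step>/... | |||
| #     async dump: <base_dir>/device_<id>/<net_name>..._<graph_id>/<graph_id>/<step>/... | |||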
| def is_device_dir(self, file_name): | |||
| """Judge if the file_name is a sub directory named 'device_x'.""" | |||
| if not file_name.startswith("device_"): | |||
| return False | |||
| id_str = file_name.split("_")[-1] | |||
| if not id_str.isdigit(): | |||
| return False | |||
| device_dir = os.path.join(self._net_dir, file_name) | |||
| if not os.path.isdir(device_dir): | |||
| return False | |||
| return True | |||
| def get_device_id_and_dir(self, file_name): | |||
| """Get device_id and absolute directory of file_name.""" | |||
| id_str = file_name.split("_")[-1] | |||
| device_id = int(id_str) | |||
| device_dir = os.path.join(self._net_dir, file_name) | |||
| return device_id, device_dir | |||
| def get_all_device_id(self): | |||
| """Get all device_id int the debugger_base_dir""" | |||
| device_ids = [] | |||
| files = os.listdir(self._net_dir) | |||
| for file in files: | |||
| if not self.is_device_dir(file): | |||
| continue | |||
| id_str = file.split("_")[-1] | |||
| device_id = int(id_str) | |||
| device_ids.append(device_id) | |||
| return device_ids | |||
| def get_net_dir(self): | |||
| """Get graph_name directory of the data.""" | |||
| return self._net_dir | |||
| def get_sync_flag(self): | |||
| """Get the sync flag of the data.""" | |||
| return self._is_sync | |||
| def get_net_name(self): | |||
| """Get net_name of the data.""" | |||
| return self._net_name | |||
| def load_graph_from_file(graph_file_name): | |||
| """Load graph from file.""" | |||
| with open(graph_file_name, 'rb') as file_handler: | |||
| model_bytes = file_handler.read() | |||
| model = ModelProto.FromString(model_bytes) | |||
| graph = model.graph | |||
| return graph | |||
| def get_graph_protos_from_dir(graphs_dir): | |||
| """ | |||
| Get graph protos from the graph directory. | |||
| Args: | |||
| graphs_dir (str): The absolute directory of graph files. | |||
| Returns: | |||
| list, list of 'GraphProto' object. | |||
| """ | |||
| files_in_graph_dir = os.listdir(graphs_dir) | |||
| graph_protos = [] | |||
| pre_file_name = "ms_output_trace_code_graph_" | |||
| for file_in_device in files_in_graph_dir: | |||
| if file_in_device.startswith(pre_file_name) and file_in_device.endswith(".pb"): | |||
| abs_graph_file = os.path.join(graphs_dir, file_in_device) | |||
| graph_proto = load_graph_from_file(abs_graph_file) | |||
| graph_protos.append(graph_proto) | |||
| return graph_protos | |||
| def update_max_step(step_id_str, max_step): | |||
| """Update max_step by compare step_id_str and max_step.""" | |||
| res = max_step | |||
| if step_id_str.isdigit(): | |||
| step_id = int(step_id_str) | |||
| if step_id > max_step: | |||
| res = step_id | |||
| return res | |||
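| # A hedged usage sketch of the loader above; the dump path is hypothetical. | |||
| if __name__ == '__main__': | |||
|     loader = DataLoader('/path/to/dump_output') | |||
|     graphs = loader.load_graphs()            # [{'device_id': 0, 'graph_protos': [...]}, ...] | |||
|     device_info = loader.load_device_info()  # {'device_target': 'Ascend', 'server_list': [...]} | |||
|     step_num = loader.load_step_number()     # {'0': <max step per device>, ...} | |||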
| @@ -14,7 +14,6 @@ | |||
| # ============================================================================ | |||
| """The definition of tensor stream.""" | |||
| from abc import abstractmethod, ABC | |||
| import numpy as np | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||
| @@ -149,7 +148,10 @@ class OpTensor(BaseTensor): | |||
| @property | |||
| def shape(self): | |||
| """The property of tensor shape.""" | |||
| return list(self._tensor_proto.dims) | |||
| dims = list(self._tensor_proto.dims) | |||
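| # A dims value of [0] in the proto is treated as a scalar, i.e. an empty shape. | |||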
| if dims == [0]: | |||
| dims = [] | |||
| return dims | |||
| @property | |||
| def value(self): | |||
| @@ -254,12 +256,13 @@ class OpTensor(BaseTensor): | |||
| class ConstTensor(BaseTensor): | |||
| """Tensor data structure for Const Node.""" | |||
| _STRING_TYPE = 'DT_STRING' | |||
| _DT_TYPE = 'DT_TYPE' | |||
| def __init__(self, const_proto): | |||
| # the type of const_proto is NamedValueProto | |||
| super(ConstTensor, self).__init__() | |||
| self._const_proto = const_proto | |||
| self._value = self.generate_value_from_proto(const_proto) | |||
| self._value = self.generate_value_from_proto(const_proto.value) | |||
| def set_step(self, step): | |||
| """Set step value.""" | |||
| @@ -295,16 +298,25 @@ class ConstTensor(BaseTensor): | |||
| Returns: | |||
| Union[None, str, np.ndarray], the value of the tensor. | |||
| """ | |||
| fields = tensor_proto.value.ListFields() | |||
| fields = tensor_proto.ListFields() | |||
| if len(fields) != 2: | |||
| log.warning("Unexpected const proto <%s>.\n Please check offline.", tensor_proto) | |||
| tensor_value = None | |||
| for field_obj, field_value in fields: | |||
| if field_obj.name != 'dtype': | |||
| tensor_value = field_value | |||
| if tensor_proto.dtype == DataType.DT_TUPLE: | |||
| tensor_values = [] | |||
| for field_value_element in field_value: | |||
| value_element = self.generate_value_from_proto(field_value_element) | |||
| tensor_values.append(value_element) | |||
| tensor_value = tensor_values | |||
| elif tensor_proto.dtype == DataType.DT_TYPE: | |||
| tensor_value = DataType.Name(field_value.data_type) | |||
| else: | |||
| tensor_value = field_value | |||
| break | |||
| if tensor_value is not None and self.dtype != self._STRING_TYPE: | |||
| tensor_value = np.array(tensor_value, dtype=NUMPY_TYPE_MAP.get(self.dtype)) | |||
| if tensor_value is not None and DataType.Name(tensor_proto.dtype) != self._STRING_TYPE: | |||
| tensor_value = np.array(tensor_value, dtype=NUMPY_TYPE_MAP.get(DataType.Name(tensor_proto.dtype))) | |||
| return tensor_value | |||
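| # For intuition: a DT_TUPLE const such as (1, 2) with DT_INT32 elements resolves | |||
| # element by element through the recursion above into a list of numpy scalars, | |||
| # while a DT_TYPE const resolves to the type name string, e.g. 'DT_INT32'. | |||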
| def get_tensor_value_by_shape(self, shape=None): | |||
| @@ -328,7 +340,8 @@ class ConstTensor(BaseTensor): | |||
| Returns: | |||
| dict, overall statistics. | |||
| """ | |||
| if self.empty or self.dtype == self._STRING_TYPE: | |||
| if self.empty or self.dtype in (self._STRING_TYPE, self._DT_TYPE): | |||
| log.debug("The tensor dtype is: %s, skip getting statistics.", self.dtype) | |||
| return {} | |||
| stats = TensorUtils.get_statistics_from_tensor(self.value) | |||
| statistics = TensorUtils.get_overall_statistic_dict(stats) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -184,7 +184,7 @@ class Watchpoint: | |||
| def __init__(self, watchpoint_id, watch_condition, name=None): | |||
| self._id = watchpoint_id | |||
| self._condition = watch_condition | |||
| self._watch_node = WatchNodeTree() | |||
| self._watch_node = {0: WatchNodeTree()} | |||
| self.name = name | |||
| @property | |||
| @@ -214,32 +214,36 @@ class Watchpoint: | |||
| else: | |||
| self._watch_node = other_watchpoint.nodes | |||
| def add_nodes(self, nodes): | |||
| def add_nodes(self, nodes, rank_id): | |||
| """Add node into watchpoint.""" | |||
| if not nodes: | |||
| log.warning("Add empty nodes.") | |||
| return | |||
| if rank_id not in self._watch_node: | |||
| self._watch_node[rank_id] = WatchNodeTree() | |||
| if not isinstance(nodes, list): | |||
| nodes = [nodes] | |||
| for node in nodes: | |||
| self._watch_node.add_node(node.name, node.type, node.full_name) | |||
| watch_node = self._watch_node.get(rank_id) | |||
| watch_node.add_node(node.name, node.type, node.full_name) | |||
| def remove_nodes(self, nodes): | |||
| def remove_nodes(self, nodes, rank_id): | |||
| """Remove nodes from watchpoint.""" | |||
| if not nodes: | |||
| return | |||
| self.validate_rank_id(rank_id) | |||
| if not isinstance(nodes, list): | |||
| nodes = [nodes] | |||
| for node in nodes: | |||
| self._watch_node.remove_node(node.name) | |||
| self._watch_node.get(rank_id).remove_node(node.name) | |||
| def get_node_status(self, node_name, node_type, full_name): | |||
| def get_node_status(self, node_name, node_type, full_name, rank_id): | |||
| """Judge if the node is in watch nodes.""" | |||
| if is_cst_type(node_type): | |||
| return WatchNodeTree.INVALID | |||
| scope_names = node_name.split('/') | |||
| cur_node = self._watch_node | |||
| self.validate_rank_id(rank_id) | |||
| cur_node = self._watch_node.get(rank_id) | |||
| status = 1 | |||
| for scope_name in scope_names: | |||
| cur_node = cur_node.get(scope_name) | |||
| @@ -250,7 +254,7 @@ class Watchpoint: | |||
| status = WatchNodeTree.TOTAL_WATCH | |||
| break | |||
| if status == WatchNodeTree.TOTAL_WATCH and cur_node.node_name != node_name: | |||
| self._watch_node.add_node(node_name, node_type, full_name) | |||
| self._watch_node.get(rank_id).add_node(node_name, node_type, full_name) | |||
| return status | |||
| @@ -278,11 +282,14 @@ class Watchpoint: | |||
| Returns: | |||
| dict, the watch node basic infos for each rank, in the format {rank_id (int): list[NodeBasicInfo]}. | |||
| """ | |||
| watch_nodes = [] | |||
| self._get_watch_node(self._watch_node, watch_nodes) | |||
| return watch_nodes | |||
| def get_pending_cmd(self, watch_nodes): | |||
| watch_nodes_for_devices = {} | |||
| for rank_id, watch_node_tree in self._watch_node.items(): | |||
| watch_nodes = [] | |||
| self._get_watch_node(watch_node_tree, watch_nodes) | |||
| watch_nodes_for_devices[rank_id] = watch_nodes | |||
| return watch_nodes_for_devices | |||
| def get_pending_cmd(self, watch_nodes_for_devices): | |||
| """Return the watchpoint in proto format.""" | |||
| # construct SetCMD | |||
| condition_id = self._condition.get('id') | |||
| @@ -309,10 +316,12 @@ class Watchpoint: | |||
| param_proto.name = param_name | |||
| param_proto.disabled = True | |||
| for watch_node in watch_nodes: | |||
| event_node = set_cmd.watch_nodes.add() | |||
| event_node.node_name = watch_node.full_name | |||
| event_node.node_type = watch_node.type | |||
| for rank_id, watch_nodes in watch_nodes_for_devices.items(): | |||
| for watch_node in watch_nodes: | |||
| event_node = set_cmd.watch_nodes.add() | |||
| event_node.node_name = watch_node.full_name | |||
| event_node.node_type = watch_node.type | |||
| event_node.rank_id = rank_id | |||
| return set_cmd | |||
| def get_watch_condition_info(self): | |||
| @@ -325,6 +334,11 @@ class Watchpoint: | |||
| watchpoint_info['name'] = self.name | |||
| return watchpoint_info | |||
| def validate_rank_id(self, rank_id): | |||
| """Check whether the rank id has watch nodes recorded.""" | |||
| if rank_id not in self._watch_node: | |||
| log.warning("Rank id %s does not exist in the watch nodes.", rank_id) | |||
| return | |||
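| # A hedged sketch of the per-rank flow above; node_infos_rank0/1 stand for | |||
| # hypothetical lists of NodeBasicInfo, and get_watch_nodes() is assumed to be | |||
| # the accessor returning the per-rank dict documented earlier: | |||
| # | |||
| #     watchpoint = Watchpoint(watchpoint_id=1, | |||
| #                             watch_condition={'id': 'tensor_too_large', 'params': []}) | |||
| #     watchpoint.add_nodes(node_infos_rank0, rank_id=0) | |||
| #     watchpoint.add_nodes(node_infos_rank1, rank_id=1) | |||
| #     set_cmd = watchpoint.get_pending_cmd(watchpoint.get_watch_nodes()) | |||
| #     # each WatchNode in set_cmd.watch_nodes now carries its rank_id | |||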
| class WatchpointHit: | |||
| """The watchpoint hit structure.""" | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -15,9 +15,10 @@ | |||
| """Import the streams handlers.""" | |||
| from .event_handler import EventHandler | |||
| from .metadata_handler import MetadataHandler | |||
| from .graph_handler import GraphHandler | |||
| from .tensor_handler import TensorHandler | |||
| from .watchpoint_handler import WatchpointHandler, WatchpointHitHandler | |||
| from .graph_handler import GraphHandler, MultiCardGraphHandler | |||
| from .tensor_handler import TensorHandler, MultiCardTensorHandler | |||
| from .watchpoint_handler import WatchpointHandler, WatchpointHitHandler, MultiCardWatchpointHitHandler | |||
| __all__ = ['EventHandler', 'MetadataHandler', 'GraphHandler', 'TensorHandler', | |||
| 'WatchpointHandler', 'WatchpointHitHandler'] | |||
| __all__ = ['EventHandler', 'MetadataHandler', 'GraphHandler', 'TensorHandler', 'WatchpointHitHandler', | |||
| 'MultiCardGraphHandler', 'MultiCardTensorHandler', | |||
| 'WatchpointHandler', 'MultiCardWatchpointHitHandler'] | |||
| @@ -0,0 +1,198 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define the Device stream handler.""" | |||
| from collections import defaultdict | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, DeviceIdUnregistered, \ | |||
| DebuggerParamTypeError | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||
| class DeviceHandler(StreamHandlerBase): | |||
| """Metadata Handler.""" | |||
| def __init__(self): | |||
| # All device infos, in the format Dict[<rank_id> (int), DeviceInfo]. | |||
| self._rank_info = defaultdict(DeviceInfo) | |||
| self._device_rank_map = {} | |||
| @property | |||
| def rank_ids(self): | |||
| """The rank ids.""" | |||
| return list(self._rank_info) | |||
| @property | |||
| def device_amount(self): | |||
| """The rank ids.""" | |||
| return len(self._rank_info) | |||
| def put(self, value): | |||
| """ | |||
| Put value into device info cache. | |||
| Args: | |||
| value (list): The list of server info. Each item is in the format: | |||
| { | |||
| "server_id": str, | |||
| "device": list[<Device Info>] | |||
| }, | |||
| The format of <Device Info> is like: | |||
| { | |||
| "device_id": str, | |||
| "device_ip": str, | |||
| "rank_id": str | |||
| }. | |||
| """ | |||
| if not isinstance(value, list): | |||
| log.error("Invalid input type. list object is expected.") | |||
| raise DebuggerParamTypeError("List object is expected.") | |||
| try: | |||
| self._extract_rank_info(value) | |||
| except TypeError as err: | |||
| log.exception(err) | |||
| log.error("Invalid Device info.") | |||
| raise DebuggerParamValueError("Invalid device info.") | |||
| log.debug("Put Device into cache") | |||
| def _extract_rank_info(self, value): | |||
| """Extract rank info and save.""" | |||
| for server_info in value: | |||
| server_ip = server_info.get('server_id') | |||
| for device_info in server_info.get('device', []): | |||
| rank_id = int(device_info.get('rank_id')) | |||
| if rank_id in self._rank_info: | |||
| log.error("Repeated rank info for rank_id: %d", rank_id) | |||
| raise DebuggerParamValueError("Repeated rank info.") | |||
| device_info_obj = self._rank_info[rank_id] | |||
| device_info_obj.rank_id = rank_id | |||
| device_info_obj.server_ip = server_ip | |||
| device_info_obj.device_id = int(device_info.get('device_id')) | |||
| device_info_obj.device_ip = device_info.get('device_ip') | |||
| self._device_rank_map[device_info_obj.device_id] = rank_id | |||
| def add_step_num_info(self, step_info): | |||
| """ | |||
| Add step number information for each device. | |||
| Args: | |||
| step_info (dict): Step info per device. The key is the device id, the value | |||
| is the relative step number. | |||
| """ | |||
| if not step_info: | |||
| log.warning("No step number information.") | |||
| return | |||
| if len(step_info) == 1 and not self._rank_info: | |||
| device_id = int(list(step_info)[0]) | |||
| log.info("Default registered device %d as rank 0.", device_id) | |||
| self._rank_info[0].device_id = device_id | |||
| self._device_rank_map[device_id] = 0 | |||
| if len(step_info) > 1 and not self._rank_info: | |||
| log.error("Missing device info for multi-card training.") | |||
| raise DeviceIdUnregistered("all") | |||
| for device_id, step_num in step_info.items(): | |||
| device_id = int(device_id) | |||
| rank_id = self.get_rank_id_by_device_id(device_id) | |||
| self._rank_info[rank_id].step_num = step_num | |||
| def add_graph_name_info(self, graphs): | |||
| """ | |||
| Add graph name per device. | |||
| Args: | |||
| graphs (dict): Graph infos for all ranks, mapping each rank id to its | |||
| graph info keyed by graph name. | |||
| """ | |||
| for rank_id, graph_info in graphs.items(): | |||
| graph_names = list(graph_info) | |||
| self._rank_info[rank_id].graph_names = graph_names | |||
| def get(self, filter_condition=None): | |||
| """ | |||
| Get device information according to filter_condition. | |||
| Args: | |||
| filter_condition (Union[int, list, None]): The rank id(s). If None, all ranks are returned. | |||
| Returns: | |||
| dict, the device info. | |||
| """ | |||
| if filter_condition is None: | |||
| filter_condition = self.rank_ids | |||
| if not isinstance(filter_condition, list): | |||
| filter_condition = [filter_condition] | |||
| device_infos = [] | |||
| for rank_id in filter_condition: | |||
| device_info = self._rank_info.get(rank_id) | |||
| if device_info is None: | |||
| log.error("Invalid rank id.") | |||
| raise DeviceIdUnregistered(rank_id) | |||
| device_infos.append(device_info.to_dict()) | |||
| return {'devices': device_infos} | |||
| def get_rank_id_by_device_id(self, device_id): | |||
| """ | |||
| Get rank id by device id. | |||
| Args: | |||
| device_id (int): The device id. | |||
| Returns: | |||
| int, the rank id. | |||
| """ | |||
| rank_id = self._device_rank_map.get(device_id) | |||
| if rank_id is None: | |||
| log.error("Failed to find rank_id for device_id %s", device_id) | |||
| raise DeviceIdUnregistered(device_id) | |||
| return rank_id | |||
| def get_device_id_by_rank_id(self, rank_id): | |||
| """ | |||
| Get device id by rank id. | |||
| Args: | |||
| rank_id (int): The rank id. | |||
| Returns: | |||
| int, the device id. | |||
| """ | |||
| device_info = self._rank_info.get(rank_id) | |||
| if device_info: | |||
| return device_info.device_id | |||
| log.error("Failed to find device id according to rank_id %s", rank_id) | |||
| raise DeviceIdUnregistered(rank_id) | |||
| class DeviceInfo: | |||
| """Device info object.""" | |||
| def __init__(self): | |||
| self.rank_id = 0 | |||
| self.device_id = 0 | |||
| self.server_ip = '' | |||
| self.graph_names = [] | |||
| self.device_ip = '' | |||
| self.step_num = 0 | |||
| def to_dict(self): | |||
| """Convert device info to dict.""" | |||
| res = { | |||
| 'rank_id': self.rank_id, | |||
| 'server_ip': self.server_ip, | |||
| 'device_id': self.device_id, | |||
| 'device_ip': self.device_ip, | |||
| 'graph_names': self.graph_names, | |||
| 'total_step_num': self.step_num | |||
| } | |||
| return res | |||
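| # A hedged usage sketch of the handler above; the server list mirrors the | |||
| # format documented in put(), with hypothetical addresses: | |||
| # | |||
| #     handler = DeviceHandler() | |||
| #     handler.put([{'server_id': 'localhost', | |||
| #                   'device': [{'device_id': '0', 'device_ip': '127.0.0.1', 'rank_id': '0'}, | |||
| #                              {'device_id': '1', 'device_ip': '127.0.0.2', 'rank_id': '1'}]}]) | |||
| #     handler.add_step_num_info({'0': 10, '1': 10}) | |||
| #     assert handler.get_rank_id_by_device_id(1) == 1 | |||
| #     assert handler.get()['devices'][1]['total_step_num'] == 10 | |||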
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -24,6 +24,55 @@ from mindinsight.debugger.stream_cache.debugger_multigraph import DebuggerMultiG | |||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||
| class MultiCardGraphHandler: | |||
| """Multi-card Graph Handler.""" | |||
| def __init__(self): | |||
| self._graph_handlers = {0: GraphHandler()} | |||
| @property | |||
| def graph_handlers(self): | |||
| """The property of whole_graph.""" | |||
| return self._graph_handlers | |||
| def get_graph_handler_by_rank_id(self, rank_id=0): | |||
| """Get handler by rank id""" | |||
| if rank_id in self._graph_handlers: | |||
| return self._graph_handlers.get(rank_id) | |||
| log.error("There is no rank id %d.", rank_id) | |||
| raise ValueError | |||
| def put(self, value): | |||
| """put graphs into graph_handlers""" | |||
| for rank_id, graph in value.items(): | |||
| if rank_id not in self._graph_handlers: | |||
| self._graph_handlers[rank_id] = GraphHandler() | |||
| self._graph_handlers[rank_id].put(graph) | |||
| def get(self, filter_condition=None, rank_id=0): | |||
| """Get the graph of specific node for specific device.""" | |||
| if rank_id in self._graph_handlers: | |||
| return self._graph_handlers.get(rank_id).get(filter_condition) | |||
| log.error("There is no rank id %d.", rank_id) | |||
| raise ValueError | |||
| def has_graph(self): | |||
| """Check whether any graph handler holds a graph.""" | |||
| return any(graph_handler.graph for graph_handler in self._graph_handlers.values()) | |||
| def register_graph_handler(self, rank_id, graph_handler): | |||
| """Register graph handler.""" | |||
| self._graph_handlers[rank_id] = graph_handler | |||
| def clean(self): | |||
| """Clean cache.""" | |||
| self.__init__() | |||
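| # A hedged sketch of routing per-rank graphs through the multi-card handler; | |||
| # graphs_rank0/1 stand for hypothetical {graph_name: GraphProto} mappings: | |||
| # | |||
| #     multi_handler = MultiCardGraphHandler() | |||
| #     multi_handler.put({0: graphs_rank0, 1: graphs_rank1}) | |||
| #     graph_handler = multi_handler.get_graph_handler_by_rank_id(1) | |||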
| class GraphHandler(StreamHandlerBase): | |||
| """Metadata Handler.""" | |||
| @@ -68,7 +117,7 @@ class GraphHandler(StreamHandlerBase): | |||
| Put value into graph cache. Called by grpc server. | |||
| Args: | |||
| value (GraphProto): The Graph proto message. | |||
| value (dict): The graph protos, in the format {graph_name (str): GraphProto}. | |||
| """ | |||
| log.info("Put graph into cache.") | |||
| sorted_value_list = self._sort_graph(value) | |||
| @@ -430,8 +479,8 @@ class GraphHandler(StreamHandlerBase): | |||
| graph_name, node_name = self._parse_node_name(scope_name, graph_name) | |||
| graph = self._get_graph(graph_name) | |||
| # to make sure fully match the scope name | |||
| node_name = node_name + '/' if not node_name.endswith('/') else node_name | |||
| nodes = graph.search_leaf_nodes_by_pattern(node_name) | |||
| node_name = node_name + '/' if node_name and not node_name.endswith('/') else node_name | |||
| nodes = graph.search_leaf_nodes_by_pattern(node_name, True) | |||
| res = [self.construct_node_basic_info(full_name=node.full_name, | |||
| graph_name=graph_name, | |||
| node_name=node.name, | |||
| @@ -448,45 +497,6 @@ class GraphHandler(StreamHandlerBase): | |||
| log.debug("Get empty full name.") | |||
| return node_name | |||
| def get_node_by_bfs_order(self, node_name=None, ascend=True): | |||
| """ | |||
| Traverse the graph in order of breadth-first search by the given node. | |||
| Args: | |||
| node_name (str): The name of current chosen leaf node. | |||
| ascend (bool): If True, traverse the input nodes; | |||
| If False, traverse the output nodes. Default is True. | |||
| Returns: | |||
| Union[None, dict], the next node object in dict type or None. | |||
| """ | |||
| bfs_order = self.bfs_order | |||
| length = len(bfs_order) | |||
| if not bfs_order: | |||
| log.error('Cannot get the BFS order of the graph!') | |||
| msg = 'Cannot get the BFS order of the graph!' | |||
| raise DebuggerParamValueError(msg) | |||
| if node_name is None: | |||
| if ascend is False: | |||
| next_node = None | |||
| else: | |||
| next_node = bfs_order[0] | |||
| else: | |||
| try: | |||
| index = bfs_order.index(node_name) | |||
| log.debug("The index of the node in BFS list is: %d", index) | |||
| except ValueError as err: | |||
| log.error('Cannot find the node: %s. Please check ' | |||
| 'the node name: %s', node_name, err) | |||
| msg = f'Cannot find the node: {node_name}. ' \ | |||
| f'Please check the node name {err}.' | |||
| raise DebuggerParamValueError(msg) | |||
| next_node = self._get_next_node_in_bfs(index, length, ascend) | |||
| return next_node | |||
| def _get_next_node_in_bfs(self, index, length, ascend): | |||
| """ | |||
| Get the next node in bfs order. | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -13,8 +13,9 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define the metadata stream handler.""" | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import ServerStatus | |||
| from mindinsight.debugger.common.utils import ServerStatus, DebuggerServerMode | |||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||
| @@ -24,28 +25,36 @@ class MetadataHandler(StreamHandlerBase): | |||
| def __init__(self): | |||
| self._state = ServerStatus.PENDING | |||
| self._device_name = "" | |||
| self._step = 0 | |||
| self.step = 0 | |||
| self._client_ip = "" | |||
| self._cur_node_name = "" | |||
| self._cur_full_name = "" | |||
| self._backend = "" | |||
| self.backend = "" | |||
| self._enable_recheck = False | |||
| self._cur_graph_name = "" | |||
| # If recommendation_confirmed is true, it only means the user has answered yes or no to the question, | |||
| # it does not necessarily mean that the user will use the recommended watch points. | |||
| self._recommendation_confirmed = False | |||
| self._debugger_version = {} | |||
| # maximum step number among all devices | |||
| self._max_step_num = 0 | |||
| self._debugger_type = DebuggerServerMode.ONLINE.value | |||
| @property | |||
| def debugger_type(self): | |||
| """The property of debugger_type.""" | |||
| return self._debugger_type | |||
| @debugger_type.setter | |||
| def debugger_type(self, debugger_type): | |||
| """The property of debugger_type.""" | |||
| self._debugger_type = debugger_type | |||
| @property | |||
| def device_name(self): | |||
| """The property of device name.""" | |||
| return self._device_name | |||
| @property | |||
| def step(self): | |||
| """The property of current step.""" | |||
| return self._step | |||
| @property | |||
| def node_name(self): | |||
| """The property of current node name.""" | |||
| @@ -71,11 +80,6 @@ class MetadataHandler(StreamHandlerBase): | |||
| """The property of current node name.""" | |||
| return self._cur_full_name | |||
| @property | |||
| def backend(self): | |||
| """The property of current backend.""" | |||
| return self._backend | |||
| @property | |||
| def state(self): | |||
| """The property of state.""" | |||
| @@ -152,6 +156,16 @@ class MetadataHandler(StreamHandlerBase): | |||
| """ | |||
| self._debugger_version = value | |||
| @property | |||
| def max_step_num(self): | |||
| """The property of max_step_num.""" | |||
| return self._max_step_num | |||
| @max_step_num.setter | |||
| def max_step_num(self, max_step_num): | |||
| """Set the property of max_step_num.""" | |||
| self._max_step_num = max_step_num | |||
| def put(self, value): | |||
| """ | |||
| Put value into metadata cache. Called by grpc server. | |||
| @@ -160,10 +174,10 @@ class MetadataHandler(StreamHandlerBase): | |||
| value (MetadataProto): The Metadata proto message. | |||
| """ | |||
| self._device_name = value.device_name.split(':')[0] | |||
| self._step = value.cur_step | |||
| self.step = value.cur_step | |||
| self._cur_full_name = value.cur_node | |||
| self._backend = value.backend if value.backend else "Ascend" | |||
| log.debug("Put metadata into cache at the %d-th step.", self._step) | |||
| self.backend = value.backend if value.backend else "Ascend" | |||
| log.debug("Put metadata into cache at the %d-th step.", self.step) | |||
| def get(self, filter_condition=None): | |||
| """ | |||
| @@ -190,6 +204,8 @@ class MetadataHandler(StreamHandlerBase): | |||
| 'recommendation_confirmed': self._recommendation_confirmed, | |||
| 'debugger_version': self.debugger_version | |||
| } | |||
| if self.debugger_type == 'offline': | |||
| metadata['total_step_num'] = self.max_step_num | |||
| else: | |||
| if not isinstance(filter_condition, list): | |||
| filter_condition = [filter_condition] | |||
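For offline sessions the metadata reply now carries `total_step_num`, which is what the frontend's updated `inputTip`/`inpStepTip` strings interpolate later in this diff. A hedged sketch of the reply a consumer might see; only `recommendation_confirmed`, `debugger_version`, and `total_step_num` are visible in this hunk, and the remaining keys are assumed from the handler's properties:

```python
# Illustrative MetadataHandler.get() reply in offline mode (shape partly assumed).
metadata = {
    'state': 'waiting',
    'step': 4,
    'backend': 'Ascend',
    'recommendation_confirmed': False,
    'debugger_version': {},
    'total_step_num': 98,  # mirrors MetadataHandler.max_step_num
}

def is_valid_reset_step(step_id, total_step_num):
    """The bound the UI enforces before sending a reset request."""
    return 0 <= step_id <= total_step_num

assert is_valid_reset_step(metadata['step'], metadata['total_step_num'])
```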
| @@ -28,6 +28,46 @@ from mindinsight.utils.tensor import TensorUtils, TensorComparison | |||
| TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter']) | |||
| class MultiCardTensorHandler: | |||
| """Multi-card Tensor Handler.""" | |||
| def __init__(self): | |||
| self.tensor_handlers = {0: TensorHandler()} | |||
| def set_step(self, step_id): | |||
| """Set step id.""" | |||
| for tensor_handler in self.tensor_handlers.values(): | |||
| tensor_handler.cur_step = step_id | |||
| def get_tensor_handler_by_rank_id(self, rank_id=0, create_if_not_exist=False): | |||
| """Get the tensor handler by rank id.""" | |||
| if rank_id in self.tensor_handlers: | |||
| return self.tensor_handlers.get(rank_id) | |||
| if create_if_not_exist: | |||
| tensor_handler = TensorHandler() | |||
| self.tensor_handlers[rank_id] = tensor_handler | |||
| return tensor_handler | |||
| log.error("There is no rank id %d in MultiCardTensorHandler.", rank_id) | |||
| raise ValueError("There is no rank id {} in MultiCardTensorHandler.".format(rank_id)) | |||
| def put(self, value): | |||
| """put graphs into graph_handlers""" | |||
| for rank_id, tensor in value: | |||
| if rank_id not in self.tensor_handlers: | |||
| self.tensor_handlers[rank_id] = TensorHandler() | |||
| self.tensor_handlers[rank_id].put(tensor) | |||
| def get(self, filter_condition=None, rank_id=0): | |||
| """Get the graph of specific node for specific device.""" | |||
| if rank_id in self.tensor_handlers: | |||
| return self.tensor_handlers.get(rank_id).get(filter_condition) | |||
| log.error("There is no rank id %d.", rank_id) | |||
| raise ValueError("There is no rank id {}.".format(rank_id)) | |||
| def clean(self): | |||
| """Clean cache.""" | |||
| self.__init__() | |||
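The tensor registry differs from the graph one in two small ways worth showing: `set_step` fans the current step out to every device, and lookups can lazily create a handler. A self-contained sketch with `TensorHandler` stubbed:

```python
class TensorHandler:
    """Stub for the per-device tensor stream handler."""
    def __init__(self):
        self.cur_step = 0


class MultiCardTensorHandler:
    def __init__(self):
        self.tensor_handlers = {0: TensorHandler()}

    def set_step(self, step_id):
        # Fan the new step out to every device's handler.
        for tensor_handler in self.tensor_handlers.values():
            tensor_handler.cur_step = step_id

    def get_tensor_handler_by_rank_id(self, rank_id=0, create_if_not_exist=False):
        if rank_id in self.tensor_handlers:
            return self.tensor_handlers[rank_id]
        if create_if_not_exist:
            self.tensor_handlers[rank_id] = TensorHandler()
            return self.tensor_handlers[rank_id]
        raise ValueError("There is no rank id {}.".format(rank_id))


mgr = MultiCardTensorHandler()
mgr.get_tensor_handler_by_rank_id(3, create_if_not_exist=True)  # created lazily
mgr.set_step(7)
assert all(h.cur_step == 7 for h in mgr.tensor_handlers.values())
```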
| class TensorHandler(StreamHandlerBase): | |||
| """Metadata Handler.""" | |||
| @@ -46,6 +86,11 @@ class TensorHandler(StreamHandlerBase): | |||
| """The property of current step.""" | |||
| return self._cur_step | |||
| @cur_step.setter | |||
| def cur_step(self, step_id): | |||
| """The property of current step.""" | |||
| self._cur_step = step_id | |||
| @property | |||
| def prev_step(self): | |||
| """The property of previous step.""" | |||
| @@ -172,7 +217,7 @@ class TensorHandler(StreamHandlerBase): | |||
| log.error("No tensor named %s at the step %s", name, step) | |||
| raise DebuggerParamValueError("No tensor named {}".format(name)) | |||
| tensor_info = tensor.get_full_info(shape) | |||
| self._update_has_prev_step_field(tensor_info, name, node_type) | |||
| self._update_has_prev_step_field(tensor_info, name, node_type, self.cur_step) | |||
| return {'tensor_value': tensor_info} | |||
| def _get_tensor(self, tensor_name, node_type=None, step=None): | |||
| @@ -198,20 +243,21 @@ class TensorHandler(StreamHandlerBase): | |||
| return tensor | |||
| def _get_basic_info(self, tensor_name, node_type=None): | |||
| def _get_basic_info(self, tensor_name, node_type, step): | |||
| """Get the latest basic tensor info by tensor name.""" | |||
| tensor = self._get_tensor(tensor_name, node_type) | |||
| tensor = self._get_tensor(tensor_name, node_type, step) | |||
| if tensor: | |||
| return tensor.get_basic_info() | |||
| return None | |||
| def update_tensor_history(self, tensor_history): | |||
| def update_tensor_history(self, tensor_history, step=None): | |||
| """ | |||
| Add tensor basic info in tensor_history. | |||
| Args: | |||
| tensor_history (dict): Tensor history, including a list of tensor name and type. | |||
| step (int): The step of tensor info. Default: None. | |||
| Returns: | |||
| list[dict], the list of tensor basic info cache. | |||
| @@ -220,9 +266,9 @@ class TensorHandler(StreamHandlerBase): | |||
| for tensor_info in tensor_history.get('tensor_history'): | |||
| tensor_name = tensor_info.get('full_name') | |||
| node_type = tensor_info.get('node_type') | |||
| basic_info = self._get_basic_info(tensor_name, node_type) | |||
| basic_info = self._get_basic_info(tensor_name, node_type, step) | |||
| # add `has_prev_step` field to tensor basic info. | |||
| missing_tensors_info = self._update_has_prev_step_field(basic_info, tensor_name, node_type) | |||
| missing_tensors_info = self._update_has_prev_step_field(basic_info, tensor_name, node_type, step) | |||
| if basic_info: | |||
| tensor_info.update(basic_info) | |||
| if missing_tensors_info: | |||
| @@ -230,14 +276,14 @@ class TensorHandler(StreamHandlerBase): | |||
| return missed_tensors | |||
| def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type): | |||
| def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type, step=None): | |||
| """Update has_prev_step field in tensor info.""" | |||
| missing_tensors_info = self._get_missing_tensor_info(tensor_name, node_type) | |||
| if not missing_tensors_info and node_type == NodeTypeEnum.PARAMETER.value and self.cur_step > 0: | |||
| missing_tensors_info = self._get_missing_tensor_info(tensor_name, node_type, step) | |||
| if not missing_tensors_info and node_type == NodeTypeEnum.PARAMETER.value and step is not None and step > 0: | |||
| tensor_info['has_prev_step'] = True | |||
| return missing_tensors_info | |||
| def _get_missing_tensor_info(self, tensor_name, node_type): | |||
| def _get_missing_tensor_info(self, tensor_name, node_type, step): | |||
| """ | |||
| Get missing tensor infos. | |||
| @@ -248,7 +294,6 @@ class TensorHandler(StreamHandlerBase): | |||
| Returns: | |||
| list, list of missing tensor basic information. | |||
| """ | |||
| step = self.cur_step | |||
| missing_tensors_info = [] | |||
| # check the current step value is missing | |||
| if self._is_tensor_value_missing(tensor_name, step): | |||
| @@ -278,13 +323,13 @@ class TensorHandler(StreamHandlerBase): | |||
| tensor = self._get_tensor(tensor_name, step=step) | |||
| return bool(not tensor or tensor.empty) | |||
| def get_valid_tensor_by_name(self, tensor_name, prev=False): | |||
| def get_valid_tensor_by_name(self, tensor_name, step, prev=False): | |||
| """Get tensor value by name in numpy type.""" | |||
| step = self.prev_step if prev else self.cur_step | |||
| if step < 0: | |||
| log.warning("%d step has no previous value for tensor: %s", self.cur_step, tensor_name) | |||
| target_step = step - 1 if prev else step | |||
| if target_step < 0: | |||
| log.warning("Step %d has no previous value for tensor: %s", target_step, tensor_name) | |||
| return None | |||
| tensor = self._get_tensor(tensor_name, step=step) | |||
| tensor = self._get_tensor(tensor_name, step=target_step) | |||
| if tensor and tensor.empty: | |||
| log.warning("%s has empty value.", tensor_name) | |||
| return None | |||
| @@ -316,9 +361,9 @@ class TensorHandler(StreamHandlerBase): | |||
| self._tensors.pop(param) | |||
| log.debug("Clean param %s in cache.", param) | |||
| def get_tensors_diff(self, tensor_name, shape, tolerance=0): | |||
| def get_tensors_diff(self, tensor_name, shape, tolerance=0, step=None): | |||
| """ | |||
| Get tensor comparisons data for given name, detail, shape and tolerance. | |||
| Get tensor comparisons data for given name, detail, shape and tolerance. | |||
| Args: | |||
| tensor_name (str): The name of tensor for cache. | |||
| @@ -329,6 +374,7 @@ class TensorHandler(StreamHandlerBase): | |||
| calculate the min and max values of the result of subtracting the previous step tensor | |||
| from the current step tensor. If the absolute value of the result is less than or equal to the | |||
| boundary value, the result will be set to zero (see the numpy sketch after this method). | |||
| step (int): The step of the tensor. Default: None. | |||
| Raises: | |||
| DebuggerParamValueError, If get current step node and previous step node failed or | |||
| @@ -337,8 +383,8 @@ class TensorHandler(StreamHandlerBase): | |||
| Returns: | |||
| dict, the retrieved data. | |||
| """ | |||
| curr_tensor = self.get_valid_tensor_by_name(tensor_name) | |||
| prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True) | |||
| curr_tensor = self.get_valid_tensor_by_name(tensor_name, step=step) | |||
| prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True, step=step) | |||
| if not (curr_tensor and prev_tensor): | |||
| log.error("Get current step and previous step for this tensor name %s failed.", tensor_name) | |||
| raise DebuggerParamValueError(f"Get current step and previous step for this tensor name " | |||
| @@ -386,22 +432,23 @@ class TensorHandler(StreamHandlerBase): | |||
| stats_info['statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=diff_tensor_stats) | |||
| return stats_info | |||
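The tolerance semantics in the docstring above are easy to demonstrate with numpy. A minimal sketch, not the handler's actual implementation (which goes through `TensorUtils`); treating the tolerance as relative to the largest current value is an assumption, consistent with the frontend's `tolerance / 100` scaling seen later in this diff:

```python
import numpy as np

def diff_with_tolerance(curr, prev, tolerance=0.0):
    """Subtract prev from curr and zero out entries within the boundary."""
    diff = curr - prev
    boundary = tolerance * np.max(np.abs(curr))  # assumed scaling
    diff[np.abs(diff) <= boundary] = 0.0
    return diff, diff.min(), diff.max()

curr = np.array([1.0, 2.0, 3.0])
prev = np.array([1.001, 1.0, 3.1])
diff, dmin, dmax = diff_with_tolerance(curr, prev, tolerance=0.01)
print(diff)  # [ 0.   1.  -0.1]: the first entry falls inside the boundary
```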
| def get_tensor_info_for_tensor_graph(self, tensor_name, node_type): | |||
| def get_tensor_info_for_tensor_graph(self, tensor_name, node_type, step): | |||
| """ | |||
| Get Tensor info for tensor graphs. | |||
| Args: | |||
| tensor_name (str): Tensor name, format like `node_name:slot`. | |||
| node_type (str): Node type. | |||
| step (int): The step of tensor info. | |||
| Returns: | |||
| dict, tensor infos, including overall statistics, tensor shape and has_prev_step info. | |||
| list, list of missing tensor basic information. | |||
| """ | |||
| res = {} | |||
| tensor = self._get_tensor(tensor_name, node_type) | |||
| tensor = self._get_tensor(tensor_name, node_type, step) | |||
| if tensor and not tensor.empty: | |||
| res['statistics'] = tensor.get_tensor_statistics() | |||
| res['shape'] = tensor.shape | |||
| missing_tensors = self._update_has_prev_step_field(res, tensor_name, node_type) | |||
| missing_tensors = self._update_has_prev_step_field(res, tensor_name, node_type, step) | |||
| return res, missing_tensors | |||
| @@ -105,12 +105,12 @@ class WatchpointHandler(StreamHandlerBase): | |||
| return {'watch_points': reply} | |||
| def get_pending_commands(self, graph_stream): | |||
| def get_pending_commands(self, multi_card_graph_stream): | |||
| """ | |||
| Get all watchpoint in SetCMD proto format. | |||
| Args: | |||
| graph_stream (GraphHandler): Graph handler. | |||
| multi_card_graph_stream (MultiCardGraphHandler): Multi card graph handler. | |||
| Returns: | |||
| list[SetCMD], updated watchpoint to be sent to MindSpore. | |||
| @@ -118,9 +118,13 @@ class WatchpointHandler(StreamHandlerBase): | |||
| newly_set_cmds = [] | |||
| for _, watchpoint in self._updated_watchpoints.items(): | |||
| # construct set command with leaf nodes | |||
| watch_nodes = watchpoint.get_watch_nodes() | |||
| leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes) | |||
| newly_set_cmds.append(watchpoint.get_pending_cmd(leaf_watch_nodes)) | |||
| watch_nodes_for_devices = watchpoint.get_watch_nodes() | |||
| leaf_watch_nodes_for_devices = {} | |||
| for rank_id, watch_nodes in watch_nodes_for_devices.items(): | |||
| graph_stream = multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id) | |||
| leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes) | |||
| leaf_watch_nodes_for_devices[rank_id] = leaf_watch_nodes | |||
| newly_set_cmds.append(watchpoint.get_pending_cmd(leaf_watch_nodes_for_devices)) | |||
| newly_set_cmds.extend(self._deleted_watchpoints) | |||
| self.sync_set_cmd(newly_set_cmds) | |||
| @@ -161,7 +165,7 @@ class WatchpointHandler(StreamHandlerBase): | |||
| """ | |||
| return self._outdated | |||
| def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None): | |||
| def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None, rank_id=0): | |||
| """ | |||
| Set watch nodes for the graph. | |||
| @@ -170,23 +174,24 @@ class WatchpointHandler(StreamHandlerBase): | |||
| graph_stream (GraphHandler): The graph handler. | |||
| watch_point_id (int): The id of watchpoint. | |||
| graph_name (str): The graph name. | |||
| rank_id (int): The rank id. | |||
| """ | |||
| if not (watch_point_id and graph): | |||
| return | |||
| log.debug("add watch flags") | |||
| watchpoint = self._watchpoints.get(watch_point_id) | |||
| self._set_watch_status_recursively(graph, graph_stream, watchpoint, graph_name) | |||
| self._set_watch_status_recursively(graph, graph_stream, watchpoint, graph_name, rank_id) | |||
| def _set_watch_status_recursively(self, graph, graph_stream, watchpoint, graph_name=None): | |||
| def _set_watch_status_recursively(self, graph, graph_stream, watchpoint, graph_name=None, rank_id=0): | |||
| """Set watch status to graph.""" | |||
| if graph.get('children'): | |||
| self._set_watch_status_recursively( | |||
| graph.get('children'), graph_stream, watchpoint, graph_name) | |||
| graph.get('children'), graph_stream, watchpoint, graph_name, rank_id=rank_id) | |||
| if graph.get('nodes'): | |||
| _ = self._set_watch_state_for_nodes(graph['nodes'], graph_stream, watchpoint, graph_name) | |||
| _ = self._set_watch_state_for_nodes(graph['nodes'], graph_stream, watchpoint, graph_name, rank_id) | |||
| def _set_watch_state_for_nodes(self, nodes, graph_stream, watchpoint, graph_name): | |||
| def _set_watch_state_for_nodes(self, nodes, graph_stream, watchpoint, graph_name, rank_id=0): | |||
| """ | |||
| Set watch state for nodes. | |||
| @@ -204,11 +209,11 @@ class WatchpointHandler(StreamHandlerBase): | |||
| node_name = node.get('name') | |||
| # search result could have `nodes` in nodes object | |||
| if node.get('nodes'): | |||
| flag = self._set_watch_state_for_nodes(node.get('nodes'), graph_stream, watchpoint, graph_name) | |||
| flag = self._set_watch_state_for_nodes(node.get('nodes'), graph_stream, watchpoint, graph_name, rank_id) | |||
| else: | |||
| full_name = graph_stream.get_full_name(node_name, graph_name) | |||
| new_node_name = node_name if graph_name is None else '/'.join([graph_name, node_name]) | |||
| flag = watchpoint.get_node_status(new_node_name, node.get('type'), full_name) | |||
| flag = watchpoint.get_node_status(new_node_name, node.get('type'), full_name, rank_id) | |||
| node['watched'] = flag | |||
| if flag == WatchNodeTree.NOT_WATCH: | |||
| continue | |||
| @@ -224,7 +229,8 @@ class WatchpointHandler(StreamHandlerBase): | |||
| state = WatchNodeTree.TOTAL_WATCH | |||
| return state | |||
| def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None, name=None): | |||
| def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None, name=None, | |||
| device_amount=8): | |||
| """ | |||
| Create watchpoint. | |||
| Args: | |||
| @@ -241,9 +247,10 @@ class WatchpointHandler(StreamHandlerBase): | |||
| } | |||
| - id (str): Id of condition. | |||
| - param (list[dict]): The list of param for this condition. | |||
| watch_nodes (list[NodeBasicInfo]): The list of node basic info. | |||
| watch_nodes (dict[int, list[NodeBasicInfo]]): The watch nodes for each device, keyed by rank id. | |||
| watch_point_id (int): The id of watchpoint. | |||
| name (str): The name of watchpoint. | |||
| device_amount (int): The number of devices. | |||
| Returns: | |||
| int, the new id of watchpoint. | |||
| @@ -253,7 +260,9 @@ class WatchpointHandler(StreamHandlerBase): | |||
| new_id = self._latest_id + 1 | |||
| watchpoint = Watchpoint(new_id, watch_condition, name) | |||
| if watch_nodes: | |||
| watchpoint.add_nodes(watch_nodes) | |||
| for rank_id, watch_nodes_for_device in watch_nodes.items(): | |||
| validate_rank_id(rank_id, device_amount) | |||
| watchpoint.add_nodes(watch_nodes_for_device, rank_id) | |||
| elif watch_point_id: | |||
| self.validate_watchpoint_id(watch_point_id) | |||
| watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id)) | |||
| @@ -261,7 +270,7 @@ class WatchpointHandler(StreamHandlerBase): | |||
| self._outdated = True | |||
| return new_id | |||
| def update_watchpoint(self, watch_point_id, watch_nodes, watched=False): | |||
| def update_watchpoint(self, watch_point_id, watch_nodes, watched=False, rank_id=0): | |||
| """ | |||
| Update watchpoint. | |||
| @@ -270,13 +279,14 @@ class WatchpointHandler(StreamHandlerBase): | |||
| watch_nodes (list[NodeBasicInfo]): The list of node basic info. | |||
| watched (bool): The update operator on nodes. If False, remove nodes from watch nodes. | |||
| If True, add nodes to watch nodes. Default: False. | |||
| rank_id (int): The rank id. | |||
| """ | |||
| self.validate_watchpoint_id(watch_point_id) | |||
| watchpoint = self._watchpoints.get(watch_point_id) | |||
| if watched: | |||
| watchpoint.add_nodes(watch_nodes) | |||
| watchpoint.add_nodes(watch_nodes, rank_id) | |||
| else: | |||
| watchpoint.remove_nodes(watch_nodes) | |||
| watchpoint.remove_nodes(watch_nodes, rank_id) | |||
| self._updated_watchpoints[watch_point_id] = watchpoint | |||
| self._outdated = True | |||
| log.debug("Update watchpoint %d in cache.", watch_point_id) | |||
| @@ -328,6 +338,58 @@ class WatchpointHandler(StreamHandlerBase): | |||
| raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id)) | |||
| class MultiCardWatchpointHitHandler: | |||
| """Multi-card Watchpoint-hit Handler.""" | |||
| def __init__(self): | |||
| self.watchpoint_hit_handlers = {0: WatchpointHitHandler()} | |||
| def get_hit_handler_by_rank_id(self, rank_id=0): | |||
| """Get handler by rank id.""" | |||
| if rank_id in self.watchpoint_hit_handlers: | |||
| return self.watchpoint_hit_handlers.get(rank_id) | |||
| log.error("There is no rank id %d.", rank_id) | |||
| raise ValueError("There is no rank id {}.".format(rank_id)) | |||
| def put(self, value): | |||
| """Put watchpoint hit into cache.""" | |||
| for rank_id, tensor_hit_values in value.items(): | |||
| if rank_id not in self.watchpoint_hit_handlers: | |||
| self.watchpoint_hit_handlers[rank_id] = WatchpointHitHandler() | |||
| cur_hit_handler = self.watchpoint_hit_handlers[rank_id] | |||
| for tensor_hit_value in tensor_hit_values: | |||
| cur_hit_handler.put(tensor_hit_value) | |||
| def get(self, filter_condition=None, rank_id=0): | |||
| """Get the graph of specific node for specific device.""" | |||
| if rank_id in self.watchpoint_hit_handlers: | |||
| return self.watchpoint_hit_handlers.get(rank_id).get(filter_condition) | |||
| log.error("There is no rank id %d.", rank_id) | |||
| raise ValueError("There is no rank id {}.".format(rank_id)) | |||
| def update_tensor_history(self, tensor_history, rank_id): | |||
| """ | |||
| Add hit flag to tensor history. | |||
| Args: | |||
| tensor_history (dict): The tensor history. | |||
| rank_id (int): The rank id. | |||
| """ | |||
| if rank_id in self.watchpoint_hit_handlers: | |||
| self.watchpoint_hit_handlers[rank_id].update_tensor_history(tensor_history) | |||
| else: | |||
| for tensor_info in tensor_history.get('tensor_history'): | |||
| tensor_info['is_hit'] = False | |||
| def check_rank_id(self, rank_id): | |||
| """check if has the rank id.""" | |||
| return rank_id in self.watchpoint_hit_handlers | |||
| def clean(self): | |||
| """Clean cache.""" | |||
| self.__init__() | |||
| class WatchpointHitHandler(StreamHandlerBase): | |||
| """Watchpoint hit handler.""" | |||
| @@ -743,3 +805,9 @@ def _get_error_list(error_code): | |||
| error_list.append(error_str) | |||
| return error_list | |||
| def validate_rank_id(rank_id, device_amount): | |||
| """validate rank id""" | |||
| if rank_id >= device_amount: | |||
| log.debug("The rank id %d over device amount.", rank_id) | |||
| @@ -23,17 +23,19 @@ class TensorDetailInfo: | |||
| def __init__(self, cache): | |||
| self._put_command = cache.put_command | |||
| self._tensor_stream = cache.get_stream_handler(Streams.TENSOR) | |||
| self._graph_stream = cache.get_stream_handler(Streams.GRAPH) | |||
| self._hit_stream = cache.get_stream_handler(Streams.WATCHPOINT_HIT) | |||
| self._metadata_stream = cache.get_stream_handler(Streams.METADATA) | |||
| self._multi_card_tensor_stream = cache.get_stream_handler(Streams.TENSOR) | |||
| self._multi_card_graph_stream = cache.get_stream_handler(Streams.GRAPH) | |||
| self._multi_card_hit_stream = cache.get_stream_handler(Streams.WATCHPOINT_HIT) | |||
| def validate_tensor_name(self, tensor_name, graph_name): | |||
| def validate_tensor_name(self, tensor_name, graph_name, rank_id): | |||
| """ | |||
| Get the graph id of the tensor. | |||
| Args: | |||
| tensor_name (str): The tensor name on UI. | |||
| graph_name (str): The graph name. | |||
| rank_id (int): The rank id. | |||
| """ | |||
| # validate tensor name format | |||
| if not isinstance(tensor_name, str) or ':' not in tensor_name: | |||
| @@ -41,15 +43,17 @@ class TensorDetailInfo: | |||
| raise DebuggerParamValueError("Invalid tensor name.") | |||
| node_name, _ = tensor_name.rsplit(':', 1) | |||
| # check if the node name is in graph | |||
| self._graph_stream.validate_node_name(node_name=node_name, graph_name=graph_name) | |||
| self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).validate_node_name(node_name=node_name, | |||
| graph_name=graph_name) | |||
| def get_tensor_graph(self, tensor_name, graph_name): | |||
| def get_tensor_graph(self, tensor_name, graph_name, rank_id=0): | |||
| """ | |||
| Get the graph related to specific tensor. | |||
| Args: | |||
| tensor_name (str): The ui name of tensor. Format like {node_name}:{slot}. | |||
| graph_name (str): The graph name. | |||
| rank_id (int): The rank id. | |||
| Returns: | |||
| dict, tensor graph, format is {'nodes': [Node object]}. | |||
| @@ -68,8 +72,9 @@ class TensorDetailInfo: | |||
| 'slot_mapping': list[pair<slot, slot>], | |||
| }. | |||
| """ | |||
| self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name) | |||
| graph = self._graph_stream.get_tensor_graph(tensor_name, graph_name) | |||
| self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name, rank_id=rank_id) | |||
| graph = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_tensor_graph(tensor_name, | |||
| graph_name) | |||
| # add watchpoint hits info and statistics info for each tensor in tensor graph. | |||
| # record missing tensor basic info | |||
| nodes = graph.get('graph', {}).get('nodes', []) | |||
| @@ -77,13 +82,13 @@ class TensorDetailInfo: | |||
| for node in nodes: | |||
| node['graph_name'] = graph_name | |||
| for slot_info in node.get('slots', []): | |||
| self._add_watchpoint_hit_info(slot_info, node, graph_name) | |||
| self._add_tensor_info(slot_info, node, missing_tensors) | |||
| self._add_watchpoint_hit_info(slot_info, node, graph_name, rank_id) | |||
| self._add_tensor_info(slot_info, node, missing_tensors, rank_id) | |||
| # query missing tensor values from client | |||
| self._ask_for_missing_tensor_value(missing_tensors, tensor_name, graph_name) | |||
| return graph | |||
| def _add_watchpoint_hit_info(self, slot_info, node, graph_name): | |||
| def _add_watchpoint_hit_info(self, slot_info, node, graph_name, rank_id): | |||
| """ | |||
| Add watchpoint hit info for the tensor. | |||
| @@ -93,9 +98,12 @@ class TensorDetailInfo: | |||
| graph_name (str): Graph name. | |||
| rank_id (int): The rank id. | |||
| """ | |||
| tensor_name = ':'.join([node.get('name'), slot_info.get('slot')]) | |||
| slot_info.update(self._hit_stream.get_tensor_hit_infos(tensor_name, graph_name)) | |||
| if self._multi_card_hit_stream.check_rank_id(rank_id=rank_id): | |||
| slot_info.update( | |||
| self._multi_card_hit_stream.get_hit_handler_by_rank_id(rank_id).get_tensor_hit_infos(tensor_name, | |||
| graph_name)) | |||
| def _add_tensor_info(self, slot_info, node, missing_tensors): | |||
| def _add_tensor_info(self, slot_info, node, missing_tensors, rank_id): | |||
| """ | |||
| Add the tensor info and query for missed tensors. | |||
| @@ -106,7 +114,8 @@ class TensorDetailInfo: | |||
| """ | |||
| tensor_name = ':'.join([node.get('full_name'), slot_info.get('slot')]) | |||
| node_type = node.get('type') | |||
| tensor_info, cur_missing_tensors = self._tensor_stream.get_tensor_info_for_tensor_graph(tensor_name, node_type) | |||
| tensor_info, cur_missing_tensors = self._multi_card_tensor_stream.get_tensor_handler_by_rank_id( | |||
| rank_id).get_tensor_info_for_tensor_graph(tensor_name, node_type, self._metadata_stream.step) | |||
| slot_info.update(tensor_info) | |||
| if cur_missing_tensors: | |||
| log.debug("Get missing tensor basic infos for %s", tensor_name) | |||
| @@ -128,20 +137,24 @@ class TensorDetailInfo: | |||
| self._put_command({'view_cmd': view_cmd, 'tensor_name': tensor_name, 'graph_name': graph_name}) | |||
| log.debug("Send view cmd for tensor-graphs.") | |||
| def get_tensor_watch_points(self, tensor_name, graph_name): | |||
| def get_tensor_watch_points(self, tensor_name, graph_name, rank_id=0): | |||
| """ | |||
| Get all watchpoints that the tensor hit. | |||
| Args: | |||
| tensor_name (str): Tensor name from UI. | |||
| graph_name (str): The graph name. | |||
| rank_id (int): The rank id. | |||
| Returns: | |||
| list, watchpoint hit infos. | |||
| """ | |||
| # validate tensor_name | |||
| self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name) | |||
| self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name, rank_id=rank_id) | |||
| # get watchpoint info that the tensor hit | |||
| tensor_hit_info = self._hit_stream.get_tensor_hit_infos(tensor_name, graph_name) | |||
| if not self._multi_card_hit_stream.check_rank_id(rank_id=rank_id): | |||
| return [] | |||
| tensor_hit_info = self._multi_card_hit_stream.get_hit_handler_by_rank_id(rank_id).get_tensor_hit_infos( | |||
| tensor_name, graph_name) | |||
| watch_points = tensor_hit_info.get('watch_points', []) | |||
| return watch_points | |||
| @@ -18,7 +18,8 @@ import enum | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerContinueError, DebuggerParamValueError, \ | |||
| DebuggerPauseError, DebuggerRecheckError, DebuggerStepNumError | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import Streams, get_ack_reply, ServerStatus, RunLevel, is_scope_type | |||
| from mindinsight.debugger.common.utils import Streams, get_ack_reply, ServerStatus, RunLevel, is_scope_type, \ | |||
| DebuggerServerMode | |||
| from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD | |||
| from mindinsight.utils.exceptions import MindInsightException | |||
| @@ -29,6 +30,7 @@ class ControlTypeEnum(enum.Enum): | |||
| CONTINUE = 'continue' # continue to run training | |||
| PAUSE = 'pause' # suspend training | |||
| TERMINATE = 'terminate' # terminate training | |||
| RESET = 'reset' # reset the step_id in offline debugger | |||
| class TrainingControlOperator: | |||
| @@ -39,7 +41,7 @@ class TrainingControlOperator: | |||
| def __init__(self, cache_store): | |||
| self._cache_store = cache_store | |||
| self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH) | |||
| self._multi_card_graph_stream = cache_store.get_stream_handler(Streams.GRAPH) | |||
| self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) | |||
| @staticmethod | |||
| @@ -71,6 +73,9 @@ class TrainingControlOperator: | |||
| """ | |||
| if mode == ControlTypeEnum.CONTINUE.value: | |||
| reply = self.continue_training(params) | |||
| elif mode == ControlTypeEnum.RESET.value: | |||
| step_id = params['steps'] | |||
| reply = self.reset_training_step(step_id) | |||
| else: | |||
| mode_mapping = { | |||
| ControlTypeEnum.PAUSE.value: self.pause_training, | |||
| @@ -150,13 +155,15 @@ class TrainingControlOperator: | |||
| if level == RunLevel.NODE.value: | |||
| node_name = params.get('name') | |||
| graph_name = params.get('graph_name') | |||
| self._validate_continue_node_name(node_name, graph_name) | |||
| rank_id = params.get('rank_id', 0) | |||
| self._validate_continue_node_name(node_name, graph_name, rank_id) | |||
| def _validate_continue_node_name(self, node_name, graph_name): | |||
| def _validate_continue_node_name(self, node_name, graph_name, rank_id): | |||
| """Validate if the node is a leaf node.""" | |||
| if not node_name: | |||
| return | |||
| node_type = self._graph_stream.get_node_type(node_name, graph_name) | |||
| node_type = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_node_type(node_name, | |||
| graph_name) | |||
| if is_scope_type(node_type): | |||
| log.error("Scope type node has no tensor history.") | |||
| raise DebuggerParamValueError("Invalid leaf node name.") | |||
| @@ -188,7 +195,9 @@ class TrainingControlOperator: | |||
| name = params.get('name', '') | |||
| graph_name = params.get('graph_name') | |||
| if name: | |||
| name = self._cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name, graph_name) | |||
| rank_id = params.get('rank_id', 0) | |||
| name = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_full_name(name, | |||
| graph_name) | |||
| run_cmd = RunCMD(run_level='node', node_name=name) | |||
| else: | |||
| run_cmd = RunCMD(run_level='recheck') | |||
| @@ -199,7 +208,7 @@ class TrainingControlOperator: | |||
| def _send_watchpoints(self): | |||
| """Send watchpoints to client.""" | |||
| set_commands = self._watchpoint_stream.get_pending_commands(self._graph_stream) | |||
| set_commands = self._watchpoint_stream.get_pending_commands(self._multi_card_graph_stream) | |||
| if not set_commands: | |||
| return | |||
| for set_cmd in set_commands: | |||
| @@ -274,3 +283,30 @@ class TrainingControlOperator: | |||
| else: | |||
| log.debug("Send the recheck to command queue.") | |||
| return metadata_stream.get(['state', 'enable_recheck']) | |||
| def reset_training_step(self, step_id): | |||
| """ | |||
| Reset the training step. | |||
| Args: | |||
| step_id (int): The target step_id. | |||
| Returns: | |||
| dict, metadata info. | |||
| """ | |||
| metadata_stream = self._metadata_stream | |||
| if metadata_stream.debugger_type == DebuggerServerMode.ONLINE.value: | |||
| log.error("'step_id' can not be changed manually in online debugger.") | |||
| return metadata_stream.get(['state', 'enable_recheck', 'step']) | |||
| if step_id > metadata_stream.max_step_num: | |||
| log.error("Invalid step_id, step_id should be less than %d.", metadata_stream.max_step_num) | |||
| raise DebuggerParamValueError("Invalid step_id.") | |||
| metadata_stream.state = ServerStatus.SENDING.value | |||
| metadata_stream.step = step_id | |||
| self._cache_store.get_stream_handler(Streams.TENSOR).set_step(step_id) | |||
| self._cache_store.clean_data() | |||
| self._cache_store.clean_command() | |||
| metadata_stream.enable_recheck = False | |||
| metadata_stream.state = ServerStatus.WAITING.value | |||
| log.debug("Send the Change_training_step CMD.") | |||
| return metadata_stream.get(['state', 'enable_recheck', 'step']) | |||
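Two guardrails in `reset_training_step` deserve emphasis: resets are refused outright in online mode, and the target step may be at most `max_step_num` (equality is allowed, hence the corrected log message above). A self-contained restatement with the metadata stream stubbed, not the real handler:

```python
class MetadataStub:
    """Stub for the metadata stream used by reset_training_step."""
    def __init__(self, debugger_type='offline', max_step_num=98):
        self.debugger_type = debugger_type
        self.max_step_num = max_step_num
        self.step = 0

def reset_training_step(metadata, step_id):
    if metadata.debugger_type == 'online':
        # Online training owns the step counter; refuse manual resets.
        raise RuntimeError("'step_id' cannot be changed manually in the online debugger.")
    if step_id > metadata.max_step_num:
        raise ValueError("step_id should be less than or equal to %d." % metadata.max_step_num)
    metadata.step = step_id
    return metadata.step

assert reset_training_step(MetadataStub(), 98) == 98  # the boundary step is accepted
```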
| @@ -31,8 +31,9 @@ class WatchpointOperator: | |||
| def __init__(self, cache_store, condition_mgr): | |||
| self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH) | |||
| self._multi_card_graph_stream = cache_store.get_stream_handler(Streams.GRAPH) | |||
| self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) | |||
| self._device_stream = cache_store.get_stream_handler(Streams.DEVICE) | |||
| self._condition_mgr = condition_mgr | |||
| def create_watchpoint(self, params): | |||
| @@ -70,11 +71,6 @@ class WatchpointOperator: | |||
| "Failed to create watchpoint as the MindSpore is not in waiting state.") | |||
| self._validate_watch_condition(watch_condition) | |||
| watch_nodes = self._get_watch_node_with_basic_info( | |||
| node_names=params.get('watch_nodes'), | |||
| search_pattern=params.get('search_pattern'), | |||
| graph_name=params.get('graph_name')) | |||
| validate_watch_condition(self._condition_mgr, watch_condition) | |||
| condition_id = watch_condition.get('id') | |||
| condition = self._condition_mgr.get_condition(condition_id) | |||
| @@ -84,10 +80,11 @@ class WatchpointOperator: | |||
| raise DebuggerConditionUnavailableError( | |||
| "Failed to create watchpoint as the condition is not available.") | |||
| watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._graph_stream).copy() | |||
| watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._multi_card_graph_stream) | |||
| watchpoint_stream = self._watchpoint_stream | |||
| watch_point_id = watchpoint_stream.create_watchpoint( | |||
| self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id')) | |||
| watch_point_id = watchpoint_stream.create_watchpoint(self._condition_mgr, watch_condition, watch_nodes, | |||
| params.get('watch_point_id'), | |||
| self._device_stream.device_amount) | |||
| log.info("Create watchpoint %d", watch_point_id) | |||
| metadata_stream.enable_recheck = watchpoint_stream.is_recheckable() | |||
| @@ -115,6 +112,7 @@ class WatchpointOperator: | |||
| 1 for add nodes to watch nodes. | |||
| - search_pattern (dict): The search pattern. | |||
| - graph_name (str): The relative graph_name of the watched node. | |||
| - rank_id (int): The rank id. | |||
| Returns: | |||
| dict, the metadata info. | |||
| @@ -137,13 +135,14 @@ class WatchpointOperator: | |||
| watch_nodes = self._get_watch_node_with_basic_info( | |||
| node_names=params.get('watch_nodes'), | |||
| search_pattern=params.get('search_pattern'), | |||
| graph_name=params.get('graph_name')) | |||
| watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, params.get('mode')) | |||
| graph_name=params.get('graph_name'), | |||
| rank_id=params.get('rank_id', 0)) | |||
| watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, params.get('mode'), params.get('rank_id', 0)) | |||
| metadata_stream.enable_recheck = watchpoint_stream.is_recheckable() | |||
| log.info("Update watchpoint with id: %d", watch_point_id) | |||
| return metadata_stream.get(['state', 'enable_recheck']) | |||
| def _get_watch_node_with_basic_info(self, node_names, search_pattern=None, graph_name=None): | |||
| def _get_watch_node_with_basic_info(self, node_names, search_pattern=None, graph_name=None, rank_id=0): | |||
| """ | |||
| Get watch node with basic info. | |||
| @@ -151,20 +150,21 @@ class WatchpointOperator: | |||
| node_names (list[str]): A list of node names. | |||
| search_pattern (dict): Get watch node with search pattern. Default: None | |||
| graph_name (str): The relative graph_name of the watched node. Default: None. | |||
| rank_id (int): The rank id. | |||
| Returns: | |||
| list[NodeBasicInfo], a list of node basic infos. | |||
| """ | |||
| if not node_names: | |||
| return [] | |||
| graph_name = self._graph_stream.validate_graph_name(graph_name) | |||
| graph_name = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).validate_graph_name(graph_name) | |||
| if search_pattern is not None: | |||
| watch_nodes = self._get_watch_nodes_by_search(node_names, search_pattern, graph_name) | |||
| watch_nodes = self._get_watch_nodes_by_search(node_names, search_pattern, graph_name, rank_id) | |||
| else: | |||
| watch_nodes = self._get_node_basic_infos(node_names, graph_name=graph_name) | |||
| watch_nodes = self._get_node_basic_infos(node_names, graph_name=graph_name, rank_id=rank_id) | |||
| return watch_nodes | |||
| def _get_watch_nodes_by_search(self, node_names, search_pattern, graph_name): | |||
| def _get_watch_nodes_by_search(self, node_names, search_pattern, graph_name, rank_id): | |||
| """ | |||
| Get watched leaf nodes by search name. | |||
| @@ -180,7 +180,7 @@ class WatchpointOperator: | |||
| list[NodeBasicInfo], a list of node basic infos. | |||
| """ | |||
| search_pattern['graph_name'] = graph_name | |||
| search_nodes = self._graph_stream.search_nodes(search_pattern) | |||
| search_nodes = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).search_nodes(search_pattern) | |||
| watch_node_names = set() | |||
| for name in node_names: | |||
| names = self._get_watch_names_by_search(search_nodes, name) | |||
| @@ -260,7 +260,7 @@ class WatchpointOperator: | |||
| log.info("Delete watchpoint with id: %s", watch_point_id) | |||
| return metadata_stream.get(['state', 'enable_recheck']) | |||
| def _get_node_basic_infos(self, node_names, graph_name=None): | |||
| def _get_node_basic_infos(self, node_names, graph_name=None, rank_id=0): | |||
| """ | |||
| Get watch node info according to node names. | |||
| @@ -273,7 +273,7 @@ class WatchpointOperator: | |||
| """ | |||
| if not node_names: | |||
| return [] | |||
| graph_stream = self._graph_stream | |||
| graph_stream = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id) | |||
| node_infos = [] | |||
| for node_name in node_names: | |||
| node_info = graph_stream.get_node_basic_info(node_name, graph_name) | |||
| @@ -26,7 +26,7 @@ limitations under the License. | |||
| <div class="cl-center" | |||
| :class="showWarmText ? 'cl-center-height' : ''"> | |||
| <router-view></router-view> | |||
| <router-view :key="$route.fullPath"></router-view> | |||
| </div> | |||
| </div> | |||
| </template> | |||
| @@ -362,8 +362,9 @@ export default { | |||
| const params = { | |||
| tensor_name: this.curRowObj.name, | |||
| graph_name: this.curRowObj.graph_name, | |||
| rank_id: this.curRowObj.rank_id, | |||
| }; | |||
| RequestService.getTensorGraphData(params).then( | |||
| RequestService.getTensorGraphData(params, this.curRowObj.sessionId).then( | |||
| (res) => { | |||
| if (res && res.data && res.data.graph && res.data.graph.nodes && res.data.graph.nodes.length) { | |||
| this.graphShow = true; | |||
| @@ -419,8 +420,9 @@ export default { | |||
| const params = { | |||
| tensor_name: this.curRowObj.name, | |||
| graph_name: this.curRowObj.graph_name, | |||
| rank_id: this.curRowObj.rank_id, | |||
| }; | |||
| RequestService.tensorHitsData(params).then( | |||
| RequestService.tensorHitsData(params, this.curRowObj.sessionId).then( | |||
| (res) => { | |||
| if (res && res.data && res.data.watch_points && res.data.watch_points.length) { | |||
| this.leftDataShow = true; | |||
| @@ -995,11 +997,12 @@ export default { | |||
| shape: encodeURIComponent(shape), | |||
| tolerance: this.tolerance / 100, | |||
| graph_name: row.graph_name, | |||
| rank_id: row.rank_id, | |||
| }; | |||
| if (loadingFlag) { | |||
| this.loadingInstance = this.$loading(this.loadingOption); | |||
| } | |||
| RequestService.tensorComparisons(params).then( | |||
| RequestService.tensorComparisons(params, row.sessionId).then( | |||
| (res) => { | |||
| if (res && res.data && res.data.tensor_value) { | |||
| if (row.shape === '[]') { | |||
| @@ -1088,11 +1091,12 @@ export default { | |||
| shape: encodeURIComponent(shape), | |||
| graph_name: row.graph_name, | |||
| prev: this.gridType === 'preStep' ? true : false, | |||
| rank_id: row.rank_id, | |||
| }; | |||
| if (loadingFlag) { | |||
| this.loadingInstance = this.$loading(this.loadingOption); | |||
| } | |||
| RequestService.tensors(params).then( | |||
| RequestService.tensors(params, row.sessionId).then( | |||
| (res) => { | |||
| if (row.shape === '[]') { | |||
| this.showFilterInput = false; | |||
| @@ -24,7 +24,9 @@ | |||
| "dataLoading": "Loading data...", | |||
| "notice": "Information", | |||
| "caseMode": "Not case sensitive", | |||
| "all": "All" | |||
| "all": "All", | |||
| "details": "Details", | |||
| "delete": "Delete" | |||
| }, | |||
| "symbols": { | |||
| "leftbracket": "(", | |||
| @@ -52,12 +54,14 @@ | |||
| "operation": "Operation", | |||
| "viewDashboard": "Training Dashboard", | |||
| "viewProfiler": "Profiling", | |||
| "viewOfflineDebugger": "Offline Debugger", | |||
| "modelTraceback": "Model Lineage", | |||
| "dataTraceback": "Dataset Lineage", | |||
| "comparePlate": "Comparison Dashboard", | |||
| "disableProfilerTip": "Failed to view profiling because no profiler log is available.", | |||
| "disableDashboardTip": "Failed to view training dashboard because no summary log or pb files are available.", | |||
| "disableParameterTip": "Failed to view parameter details because no lineage log is available.", | |||
| "disableOfflineDebugger": "Failed to view offline debugger because no debugger log is available.", | |||
| "openNewTab": "Open Link in New Tab", | |||
| "paramDetails": "Parameter Details", | |||
| "trainingParamDetails": "Training Parameter Details", | |||
| @@ -80,7 +84,12 @@ | |||
| "tensorUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#tensor-visualization", | |||
| "graphUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#computational-graph-visualization", | |||
| "dataProcessUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#dataset-graph-visualization", | |||
| "imageUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#image-visualization" | |||
| "imageUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#image-visualization", | |||
| "sessionLimit": "The number of sessions of the offline debugger exceeds the number of online sessions", | |||
| "sessionLimitNum": "At most 2 exist at the same time", | |||
| "sessionLists": "List of currently existing sessions", | |||
| "deleteSessionConfirm": "This operation will delete the current session, do you want to continue?", | |||
| "deleteSessionSuccess": "Delete session successfully!" | |||
| }, | |||
| "modelTraceback": { | |||
| "summaryPath": "Summary Path", | |||
| @@ -561,7 +570,7 @@ | |||
| "terminate": "TERMINATE", | |||
| "selectCondition": "Select a condition", | |||
| "inputStep": "Enter a step value", | |||
| "inputTip": "A positive integer less than 2147483648", | |||
| "inputTip": "A positive integer less than or equal to {total_step_num}", | |||
| "curHitNode": "Watch Point Hit List", | |||
| "backstageStatus": "The backend running status is ", | |||
| "view": "View", | |||
| @@ -830,7 +839,9 @@ | |||
| "allPositive": "he parameter value must be greater than 0.", | |||
| "watchOverflow": "The asynchronous full overflow watching function must be enabled before the training starts." | |||
| }, | |||
| "paramValueTip": "Preset Value: {value}" | |||
| "paramValueTip": "Preset Value: {value}", | |||
| "logicCard": "Logic card", | |||
| "inpStepTip": "Step:0~{total_step_num}" | |||
| }, | |||
| "explain": { | |||
| "explain": "Model Explanation", | |||
| @@ -952,6 +963,7 @@ | |||
| "5054B183": "Backend training is in progress or has ended. Please try again later", | |||
| "5054B184": "The operation is too fast, the backend service has been suspended.", | |||
| "5054B189": "Do not set the value repeatedly.", | |||
| "5054B083": "Failed to create the watchpoint. Do not use invalid rules." | |||
| "5054B083": "Failed to create the watchpoint. Do not use invalid rules.", | |||
| "5054B202": "The debugger offline server module was not found" | |||
| } | |||
| } | |||
| @@ -24,7 +24,9 @@ | |||
| "dataLoading": "数据加载中", | |||
| "notice": "提示", | |||
| "caseMode": "不区分大小写", | |||
| "all": "全部" | |||
| "all": "全部", | |||
| "details": "详情", | |||
| "delete": "删除" | |||
| }, | |||
| "symbols": { | |||
| "leftbracket": "(", | |||
| @@ -52,12 +54,14 @@ | |||
| "operation": "操作", | |||
| "viewDashboard": "训练看板", | |||
| "viewProfiler": "性能分析", | |||
| "viewOfflineDebugger": "离线调试器", | |||
| "modelTraceback": "模型溯源", | |||
| "dataTraceback": "数据溯源", | |||
| "comparePlate": "对比看板", | |||
| "disableProfilerTip": "无profiler日志,无法查看性能分析", | |||
| "disableDashboardTip": "无summary日志或pb文件,无法查看训练看板", | |||
| "disableParameterTip": "无lineage日志,无法查看参数详情", | |||
| "disableOfflineDebugger": "无Debugger日志,无法查看离线调试器", | |||
| "openNewTab": "打开新页签", | |||
| "paramDetails": "参数详情", | |||
| "trainingParamDetails": "训练参数详情", | |||
| @@ -80,7 +84,12 @@ | |||
| "tensorUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id8", | |||
| "graphUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id5", | |||
| "dataProcessUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id6", | |||
| "imageUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id7" | |||
| "imageUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id7", | |||
| "sessionLimit": "离线调试器的session个数超过上线", | |||
| "sessionLimitNum": "最多同时存在2个", | |||
| "sessionLists": "目前存在的session列表", | |||
| "deleteSessionConfirm": "此操作将删除当前session, 是否继续?", | |||
| "deleteSessionSuccess": "删除session成功!" | |||
| }, | |||
| "modelTraceback": { | |||
| "summaryPath": "训练日志路径", | |||
| @@ -560,7 +569,7 @@ | |||
| "terminate": "结束", | |||
| "selectCondition": "请选择条件", | |||
| "inputStep": "请输入轮次值", | |||
| "inputTip": "小于2147483648的正整数", | |||
| "inputTip": "小于等于{total_step_num}的正整数", | |||
| "curHitNode": "命中的监测点", | |||
| "backstageStatus": "后台运行状态是", | |||
| "view": "查看", | |||
| @@ -825,7 +834,9 @@ | |||
| "allPositive": "此参数值必须大于0", | |||
| "watchOverflow": "训练开始前需开启异步全量溢出监测功能" | |||
| }, | |||
| "paramValueTip": "设置值为:{value}" | |||
| "paramValueTip": "设置值为:{value}", | |||
| "logicCard": "逻辑卡", | |||
| "inpStepTip": "可输入当前轮次:0~{total_step_num}" | |||
| }, | |||
| "explain": { | |||
| "explain": "模型解释", | |||
| @@ -947,6 +958,7 @@ | |||
| "5054B183": "后台训练运行中,请稍后重试", | |||
| "5054B184": "操作过快,后台服务已暂停。", | |||
| "5054B189": "请勿重复设置。", | |||
| "5054B083": "监测点创建失败,请勿使用已失效规则。" | |||
| "5054B083": "监测点创建失败,请勿使用已失效规则。", | |||
| "5054B202": "未找到调试器离线服务器模块" | |||
| } | |||
| } | |||
| @@ -157,6 +157,10 @@ export default new Router({ | |||
| path: '/debugger', | |||
| component: () => import('./views/debugger/debugger.vue'), | |||
| }, | |||
| { | |||
| path: '/offline-debugger', | |||
| component: () => import('./views/debugger/debugger.vue'), | |||
| }, | |||
| { | |||
| path: '/explain', | |||
| component: () => import('./views/explain/summary-list.vue'), | |||
| @@ -62,7 +62,14 @@ axios.interceptors.response.use( | |||
| const errorData = i18n.messages[i18n.locale].error; | |||
| const path = router.currentRoute.path; | |||
| if (path === '/debugger') { | |||
| if (path === '/debugger' || path === '/offline-debugger') { | |||
| if ( | |||
| error.response && | |||
| error.response.data && | |||
| error.response.data.error_code === '5054B281' | |||
| ) { | |||
| router.push('/'); | |||
| } | |||
| return Promise.reject(error); | |||
| } | |||
| // error returned by backend | |||
| @@ -309,55 +309,74 @@ export default { | |||
| }); | |||
| }, | |||
| // debugger | |||
| pollData(params) { | |||
| getSession(params) { | |||
| return axios({ | |||
| method: 'post', | |||
| url: 'v1/mindinsight/debugger/sessions', | |||
| data: params, | |||
| }); | |||
| }, | |||
| deleteSession(sessionId) { | |||
| return axios({ | |||
| method: 'post', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/delete`, | |||
| }); | |||
| }, | |||
| checkSessions() { | |||
| return axios({ | |||
| method: 'get', | |||
| url: 'v1/mindinsight/debugger/poll-data', | |||
| url: `v1/mindinsight/debugger/sessions`, | |||
| }); | |||
| }, | |||
| pollData(params, sessionId) { | |||
| return axios({ | |||
| method: 'get', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/poll-data`, | |||
| params: params, | |||
| headers: { | |||
| ignoreError: true, | |||
| }, | |||
| }); | |||
| }, | |||
| retrieve(params) { | |||
| retrieve(params, sessionId) { | |||
| return axios({ | |||
| method: 'post', | |||
| url: 'v1/mindinsight/debugger/retrieve', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/retrieve`, | |||
| data: params, | |||
| }); | |||
| }, | |||
| createWatchpoint(params) { | |||
| createWatchpoint(params, sessionId) { | |||
| return axios({ | |||
| method: 'post', | |||
| url: 'v1/mindinsight/debugger/create-watchpoint', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/create-watchpoint`, | |||
| data: params, | |||
| }); | |||
| }, | |||
| updateWatchpoint(params) { | |||
| updateWatchpoint(params, sessionId) { | |||
| return axios({ | |||
| method: 'post', | |||
| url: 'v1/mindinsight/debugger/update-watchpoint', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/update-watchpoint`, | |||
| data: params, | |||
| }); | |||
| }, | |||
| deleteWatchpoint(params) { | |||
| deleteWatchpoint(params, sessionId) { | |||
| return axios({ | |||
| method: 'post', | |||
| url: 'v1/mindinsight/debugger/delete-watchpoint', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/delete-watchpoint`, | |||
| data: params, | |||
| }); | |||
| }, | |||
| control(params) { | |||
| control(params, sessionId) { | |||
| return axios({ | |||
| method: 'post', | |||
| url: 'v1/mindinsight/debugger/control', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/control`, | |||
| data: params, | |||
| }); | |||
| }, | |||
| search(params) { | |||
| search(params, sessionId) { | |||
| return axios({ | |||
| method: 'get', | |||
| url: 'v1/mindinsight/debugger/search', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/search`, | |||
| params: params, | |||
| }); | |||
| }, | |||
| @@ -368,43 +387,43 @@ export default { | |||
| params: params, | |||
| }); | |||
| }, | |||
| tensorComparisons(params) { | |||
| tensorComparisons(params, sessionId) { | |||
| return axios({ | |||
| method: 'get', | |||
| url: 'v1/mindinsight/debugger/tensor-comparisons', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-comparisons`, | |||
| params: params, | |||
| }); | |||
| }, | |||
| tensors(params) { | |||
| tensors(params, sessionId) { | |||
| return axios({ | |||
| method: 'get', | |||
| url: 'v1/mindinsight/debugger/tensors', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/tensors`, | |||
| params: params, | |||
| }); | |||
| }, | |||
| retrieveTensorHistory(params) { | |||
| retrieveTensorHistory(params, sessionId) { | |||
| return axios({ | |||
| method: 'post', | |||
| url: 'v1/mindinsight/debugger/tensor-history', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-history`, | |||
| data: params, | |||
| }); | |||
| }, | |||
| queryConditions(trainId) { | |||
| queryConditions(sessionId) { | |||
| return axios({ | |||
| method: 'get', | |||
| url: `v1/mindinsight/conditionmgr/train-jobs/${trainId}/condition-collections`, | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/condition-collections`, | |||
| }); | |||
| }, | |||
| recheckWatchPoints() { | |||
| recheckWatchPoints(sessionId) { | |||
| return axios({ | |||
| method: 'post', | |||
| url: `v1/mindinsight/debugger/recheck`, | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/recheck`, | |||
| }); | |||
| }, | |||
| searchWatchpointHits(params) { | |||
| searchWatchpointHits(params, sessionId) { | |||
| return axios({ | |||
| method: 'post', | |||
| url: `v1/mindinsight/debugger/search-watchpoint-hits`, | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/search-watchpoint-hits`, | |||
| data: params, | |||
| }); | |||
| }, | |||
| @@ -447,33 +466,25 @@ export default { | |||
| data: params, | |||
| }); | |||
| }, | |||
| tensorHitsData(params) { | |||
| tensorHitsData(params, sessionId) { | |||
| return axios({ | |||
| method: 'get', | |||
| url: 'v1/mindinsight/debugger/tensor-hits', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-hits`, | |||
| params: params, | |||
| }); | |||
| }, | |||
| getTensorGraphData(params) { | |||
| getTensorGraphData(params, sessionId) { | |||
| return axios({ | |||
| method: 'get', | |||
| url: 'v1/mindinsight/debugger/tensor-graphs', | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-graphs`, | |||
| params: params, | |||
| }); | |||
| }, | |||
| getCpuUtilization(params) { | |||
| return axios({ | |||
| method: 'post', | |||
| url: 'v1/mindinsight/profile/minddata-cpu-utilization-summary', | |||
| params: params.params, | |||
| data: params.body, | |||
| }); | |||
| }, | |||
| setRecommendWatchPoints(params) { | |||
| setRecommendWatchPoints(params, sessionId) { | |||
| return axios({ | |||
| method: 'post', | |||
| url: `v1/mindinsight/conditionmgr/train-jobs/${params.trainId}/set-recommended-watch-points`, | |||
| data: params.body, | |||
| url: `v1/mindinsight/debugger/sessions/${sessionId}/set-recommended-watch-points`, | |||
| data: params, | |||
| }); | |||
| }, | |||
| // memory-datail apis | |||
| @@ -46,6 +46,17 @@ limitations under the License. | |||
| </div> | |||
| <div class="content" | |||
| v-show="radio1==='tree'"> | |||
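| <!-- Logic card (rank) selector: switching cards reloads the graph file list for that device --> | |||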
| <div class="node-type"> | |||
| <div class="label">{{ $t('debugger.logicCard') }}</div> | |||
| <el-select v-model="logicCard.value" | |||
| @change="logicCardChange" | |||
| :disabled="!trainId"> | |||
| <el-option v-for="item in logicCard.options" | |||
| :key="item" | |||
| :value="item"> | |||
| </el-option> | |||
| </el-select> | |||
| </div> | |||
| <div class="node-type"> | |||
| <div class="label">{{ $t('debugger.graphFile') }}</div> | |||
| <el-select v-model="graphFiles.value" | |||
| @@ -209,6 +220,17 @@ limitations under the License. | |||
| </div> | |||
| <div class="content" | |||
| v-show="radio1==='hit'"> | |||
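| <!-- Logic card selector for the hit list: changing it also re-queries the watchpoint hits --> | |||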
| <div class="node-type"> | |||
| <div class="label">{{ $t('debugger.logicCard') }}</div> | |||
| <el-select v-model="logicCard.value" | |||
| :disabled="!trainId" | |||
| @change="logicCardChange();searchWatchpointHits(true);"> | |||
| <el-option v-for="item in logicCard.options" | |||
| :key="item" | |||
| :value="item"> | |||
| </el-option> | |||
| </el-select> | |||
| </div> | |||
| <div class="hit-list-wrap"> | |||
| <el-table class="watchpoint-table" | |||
| :data="watchPointHits" | |||
| @@ -261,7 +283,7 @@ limitations under the License. | |||
| <div class="step"> | |||
| <el-tooltip class="item" | |||
| effect="light" | |||
| :content="$t('debugger.inputTip')" | |||
| :content="$t('debugger.inputTip',{total_step_num:metadata.total_step_num})" | |||
| placement="top-start"> | |||
| <el-input v-model="step" | |||
| :placeholder="$t('debugger.inputStep')" | |||
| @@ -330,6 +352,25 @@ limitations under the License. | |||
| v-show="metadata.state === state.sending"> | |||
| <i class="el-icon-time"></i> | |||
| </el-tooltip> | |||
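| <!-- Inline step editor: the edit icon shows an input (the tooltip gives the valid range up to total_step_num); check saves, close cancels --> | |||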
| <i class="el-icon-edit" | |||
| v-if="trainId && !isShowInp" | |||
| :title="$t('debugger.inpStepTip',{total_step_num:metadata.total_step_num})" | |||
| @click="editStep"></i> | |||
| <el-tooltip class="item" | |||
| effect="light" | |||
| :content="$t('debugger.inputTip',{total_step_num:metadata.total_step_num})" | |||
| placement="top-start" | |||
| v-if="trainId && isShowInp"> | |||
| <el-input v-model="newStep" | |||
| type="text" | |||
| @input="newStepChange"></el-input> | |||
| </el-tooltip> | |||
| <i class="el-icon-check" | |||
| v-if="trainId && isShowInp" | |||
| @click="saveStepValue"></i> | |||
| <i class="el-icon-close" | |||
| v-if="trainId && isShowInp" | |||
| @click="isShowInp=false"></i> | |||
| </div> | |||
| <div class="svg-wrap" | |||
| :class="{collapse: collapseTable}"> | |||
| @@ -505,7 +546,7 @@ limitations under the License. | |||
| :close-on-click-modal="false" | |||
| :modal-append-to-body="false" | |||
| class="creat-watch-point-dialog" | |||
| width="890px"> | |||
| width="930px"> | |||
| <div class="conditions-container"> | |||
| <div class="condition-item" | |||
| @@ -787,6 +828,11 @@ export default { | |||
| value: '', | |||
| graphs: {}, | |||
| }, | |||
| logicCard: { | |||
| options: [], | |||
| value: '', | |||
| }, | |||
| devices: [], | |||
| allGraphData: {}, // Graph Original input data | |||
| firstFloorNodes: [], // ID array of the first layer node. | |||
| nodesCountLimit: 1500, // Maximum number of sub-nodes in a namespace. | |||
| @@ -830,7 +876,7 @@ export default { | |||
| expandKeys: [], | |||
| isHitIntoView: true, | |||
| searchedWord: '', | |||
| trainId: '', | |||
| trainId: this.$route.query.dir, | |||
| recommendWatchPointDialog: false, | |||
| hitsOutdated: false, | |||
| conflictFlag: false, | |||
| @@ -859,6 +905,9 @@ export default { | |||
| }, | |||
| loadingInstance: null, | |||
| paramErrorMsg: '', | |||
| sessionId: this.$route.query.sessionId, | |||
| isShowInp: false, | |||
| newStep: '', | |||
| }; | |||
| }, | |||
| components: {debuggerTensor, tree}, | |||
| @@ -866,6 +915,12 @@ export default { | |||
| mounted() { | |||
| document.title = `${this.$t('debugger.debugger')}-MindInsight`; | |||
| this.nodeTypes.label = this.$t('debugger.nodeType'); | |||
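| // A train dir in the route means an offline-debugger entry: load data directly; | |||
| // otherwise request a debugger session first. | |||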
| if (this.trainId) { | |||
| document.title = `${this.trainId}-${this.$t('debugger.debugger')}-MindInsight`; | |||
| this.retrieveAll(); | |||
| } else { | |||
| this.getSession(); | |||
| } | |||
| }, | |||
| watch: { | |||
| 'metadata.state': { | |||
| @@ -896,7 +951,7 @@ export default { | |||
| if (newValue === this.state.waiting) { | |||
| if (this.oldState === this.state.pending || oldValue === this.state.pending) { | |||
| this.loadNode(this.node, this.resolve); | |||
| this.retrieveAll(); | |||
| } else if (this.oldState === this.state.running || oldValue === this.state.running) { | |||
| this.pagination.currentPage = 1; | |||
| this.watchPointHits = []; | |||
| @@ -914,6 +969,8 @@ export default { | |||
| this.curRowObj.type = type; | |||
| this.curRowObj.curFileName = this.graphFiles.value; | |||
| this.curRowObj.step = this.metadata.step; | |||
| this.curRowObj.rank_id = this.logicCard.value; | |||
| this.curRowObj.sessionId = this.sessionId; | |||
| this.tensorCompareFlag = true; | |||
| }, | |||
| closeTensor(tensor, graphName) { | |||
| @@ -922,6 +979,19 @@ export default { | |||
| this.queryAllTreeData(tensor, true, graphName, true); | |||
| } | |||
| }, | |||
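| // Rebuild the graph-file options from the device selected by rank_id, | |||
| // refresh the metadata ip and device id, then reload the graph. | |||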
| logicCardChange() { | |||
| this.graphFiles.options = JSON.parse( | |||
| JSON.stringify(this.devices.find((val) => val.rank_id === this.logicCard.value).graph_names), | |||
| ); | |||
| if (this.graphFiles.options.length > 1) { | |||
| this.graphFiles.options.unshift(this.$t('debugger.all')); | |||
| } | |||
| this.graphFiles.value = this.graphFiles.options[0]; | |||
| const device = this.devices.find((val) => val.rank_id === this.logicCard.value); | |||
| this.metadata.ip = device.server_ip; | |||
| this.metadata.device_name = device.device_id; | |||
| this.queryGraphByFile(); | |||
| }, | |||
| queryGraphByFile() { | |||
| this.searchWord = ''; | |||
| this.nodeTypes.value = 'all'; | |||
| @@ -931,12 +1001,13 @@ export default { | |||
| params: { | |||
| watch_point_id: this.curWatchPointId ? this.curWatchPointId : 0, | |||
| graph_name: this.graphFiles.value, | |||
| rank_id: this.logicCard.value, | |||
| }, | |||
| }; | |||
| if (this.graphFiles.value === this.$t('debugger.all')) { | |||
| delete params.params.graph_name; | |||
| } | |||
| RequestService.retrieve(params).then( | |||
| RequestService.retrieve(params, this.sessionId).then( | |||
| (res) => { | |||
| if (res.data && res.data.metadata) { | |||
| this.dealMetadata(res.data.metadata); | |||
| @@ -975,6 +1046,7 @@ export default { | |||
| d3.select('#graph svg').remove(); | |||
| this.selectedNode.name = ''; | |||
| this.dealGraphData(JSON.parse(JSON.stringify(graph.nodes))); | |||
| this.tableData = []; | |||
| } | |||
| }, | |||
| (err) => { | |||
| @@ -1015,11 +1087,12 @@ export default { | |||
| watch_nodes: watchNodes, | |||
| mode: type ? 1 : 0, | |||
| graph_name: this.graphFiles.value, | |||
| rank_id: this.logicCard.value, | |||
| }; | |||
| if (this.graphFiles.value === this.$t('debugger.all')) { | |||
| delete params.graph_name; | |||
| } | |||
| RequestService.updateWatchpoint(params).then( | |||
| RequestService.updateWatchpoint(params, this.sessionId).then( | |||
| (res) => { | |||
| this.defaultCheckedArr = this.$refs.tree.getCheckedKeys(); | |||
| if (res && res.data && res.data.metadata && res.data.metadata.enable_recheck !== undefined) { | |||
| @@ -1049,12 +1122,16 @@ export default { | |||
| queryGraphByWatchpoint(id) { | |||
| const params = { | |||
| mode: 'watchpoint', | |||
| params: {watch_point_id: id, graph_name: this.graphFiles.value}, | |||
| params: { | |||
| watch_point_id: id, | |||
| graph_name: this.graphFiles.value, | |||
| rank_id: this.logicCard.value, | |||
| }, | |||
| }; | |||
| if (this.graphFiles.value === this.$t('debugger.all')) { | |||
| delete params.params.graph_name; | |||
| } | |||
| RequestService.retrieve(params).then( | |||
| RequestService.retrieve(params, this.sessionId).then( | |||
| (res) => { | |||
| if (res.data && res.data.graph) { | |||
| const graph = res.data.graph; | |||
| @@ -1306,11 +1383,12 @@ export default { | |||
| level: 'node', | |||
| name: this.selectedNode.name.replace('_unfold', ''), | |||
| graph_name: this.graphFiles.value, | |||
| rank_id: this.logicCard.value, | |||
| }; | |||
| if (this.graphFiles.value === this.$t('debugger.all')) { | |||
| delete params.graph_name; | |||
| } | |||
| RequestService.control(params).then( | |||
| RequestService.control(params, this.sessionId).then( | |||
| (res) => { | |||
| if (res && res.data) { | |||
| } | |||
| @@ -1387,12 +1465,13 @@ export default { | |||
| node_type: type, | |||
| single_node: false, | |||
| graph_name: this.graphFiles.value, | |||
| rank_id: this.logicCard.value, | |||
| }; | |||
| if (this.graphFiles.value === this.$t('debugger.all')) { | |||
| delete params.params.graph_name; | |||
| } | |||
| } | |||
| RequestService.retrieve(params) | |||
| RequestService.retrieve(params, this.sessionId) | |||
| .then( | |||
| (response) => { | |||
| if (response && response.data && response.data.graph) { | |||
| @@ -1560,7 +1639,12 @@ export default { | |||
| graphName = key.split('/')[0]; | |||
| key = key.replace(`${graphName}/`, ''); | |||
| } | |||
| const obj = {name: key, IOType: 'output', graph_name: graphName}; | |||
| const obj = { | |||
| name: key, | |||
| IOType: 'output', | |||
| graph_name: graphName, | |||
| rank_id: this.logicCard.value, | |||
| }; | |||
| IOInfo.push(obj); | |||
| this.selectedNode.outputNum++; | |||
| }); | |||
| @@ -1572,7 +1656,12 @@ export default { | |||
| graphName = key.split('/')[0]; | |||
| key = key.replace(`${graphName}/`, ''); | |||
| } | |||
| const obj = {name: key, IOType: 'input', graph_name: graphName}; | |||
| const obj = { | |||
| name: key, | |||
| IOType: 'input', | |||
| graph_name: graphName, | |||
| rank_id: this.logicCard.value, | |||
| }; | |||
| IOInfo.push(obj); | |||
| this.selectedNode.inputNum++; | |||
| }); | |||
| @@ -1606,11 +1695,7 @@ export default { | |||
| `translate(${this.graph.transform.x},` + `${this.graph.transform.y}) scale(${this.graph.transform.k})`, | |||
| ); | |||
| const transitionTime = Math.min( | |||
| Math.abs(screenChange.x) * 2, | |||
| Math.abs(screenChange.y) * 2, | |||
| needDelay ? 800 : 0, | |||
| ); | |||
| const transitionTime = Math.min(Math.abs(screenChange.x) * 2, Math.abs(screenChange.y) * 2, needDelay ? 800 : 0); | |||
| this.graph.dom.style.transition = `${transitionTime / 1000}s`; | |||
| this.graph.dom.style['transition-timing-function'] = 'linear'; | |||
| @@ -1829,8 +1914,8 @@ export default { | |||
| height: calc(100% - 145px); | |||
| } | |||
| .deb-wrap .left-wrap .left .content .node-type { | |||
| height: 50px; | |||
| padding: 15px 15px 0 15px; | |||
| height: 40px; | |||
| padding: 10px 15px 0 15px; | |||
| } | |||
| .deb-wrap .left-wrap .left .content .node-type .label { | |||
| display: inline-block; | |||
| @@ -1855,7 +1940,7 @@ export default { | |||
| font-size: 12px; | |||
| } | |||
| .deb-wrap .left-wrap .left .content .tree-wrap { | |||
| height: calc(70% - 155px); | |||
| height: calc(70% - 172px); | |||
| overflow-y: auto; | |||
| padding: 0 15px 15px; | |||
| position: relative; | |||
| @@ -1973,12 +2058,13 @@ export default { | |||
| color: red; | |||
| } | |||
| .deb-wrap .left-wrap .left .content .hit-list-wrap { | |||
| height: 100%; | |||
| height: calc(100% - 40px); | |||
| padding: 10px; | |||
| } | |||
| .deb-wrap .left-wrap .left .content .hit-list-wrap .watchpoint-table { | |||
| max-height: calc(100% - 45px); | |||
| overflow: auto; | |||
| margin-top: 10px; | |||
| } | |||
| .deb-wrap .left-wrap .left .content .hit-list-wrap .el-table::before { | |||
| height: 0; | |||
| @@ -2096,7 +2182,7 @@ export default { | |||
| /* Opera */ | |||
| } | |||
| .deb-wrap .right .header { | |||
| padding: 15px; | |||
| line-height: 51px; | |||
| border-bottom: 1px solid #ebeef5; | |||
| position: relative; | |||
| background: #fff; | |||
| @@ -2113,6 +2199,25 @@ export default { | |||
| .deb-wrap .right .header .item + .item { | |||
| margin-left: 15px; | |||
| } | |||
| .deb-wrap .right .header .el-icon-edit { | |||
| margin-left: 5px; | |||
| } | |||
| .deb-wrap .right .header i { | |||
| font-size: 18px; | |||
| margin: 0 2px; | |||
| color: #00a5a7; | |||
| cursor: pointer; | |||
| } | |||
| .deb-wrap .right .header .el-icon-close { | |||
| color: #f56c6c; | |||
| } | |||
| .deb-wrap .right .header .el-input { | |||
| width: 45px; | |||
| } | |||
| .deb-wrap .right .header .el-input input { | |||
| padding: 0; | |||
| text-align: center; | |||
| } | |||
| .deb-wrap .right .header .tooltip { | |||
| margin-left: 5px; | |||
| cursor: pointer; | |||
| @@ -2343,13 +2448,13 @@ export default { | |||
| display: none; | |||
| } | |||
| .deb-wrap .creat-watch-point-dialog .conditions-container .collection { | |||
| width: 200px; | |||
| width: 210px; | |||
| } | |||
| .deb-wrap .creat-watch-point-dialog .conditions-container .condition, | |||
| .deb-wrap .creat-watch-point-dialog .conditions-container .param, | |||
| .deb-wrap .creat-watch-point-dialog .conditions-container .param-value { | |||
| margin-left: 10px; | |||
| width: 200px; | |||
| width: 210px; | |||
| } | |||
| .deb-wrap .creat-watch-point-dialog .conditions-container .percent-sign { | |||
| display: inline-block; | |||
| @@ -96,6 +96,16 @@ limitations under the License. | |||
| :title="$t('summaryManage.disableProfilerTip')"> | |||
| {{$t('summaryManage.viewProfiler')}} | |||
| </span> | |||
| <span class="menu-item operate-btn" | |||
| v-if="scope.row.viewOfflineDebugger" | |||
| @contextmenu.prevent="rightClick(scope.row, $event, 2)" | |||
| @click.stop="goToOfflineDebugger(scope.row)"> | |||
| {{$t('summaryManage.viewOfflineDebugger')}} </span> | |||
| <span class="menu-item operate-btn button-disable" | |||
| v-else | |||
| :title="$t('summaryManage.disableOfflineDebugger')"> | |||
| {{$t('summaryManage.viewOfflineDebugger')}} | |||
| </span> | |||
| <span class="menu-item operate-btn" | |||
| v-if="scope.row.paramDetails" | |||
| @click.stop="showModelDialog(scope.row)"> | |||
| @@ -157,6 +167,45 @@ limitations under the License. | |||
| <li @click="doRightClick()">{{$t('summaryManage.openNewTab')}}</li> | |||
| </ul> | |||
| </div> | |||
| <el-dialog :visible.sync="debuggerDialog.showDialogModel" | |||
| width="50%" | |||
| :close-on-click-modal="false" | |||
| class="details-data-list"> | |||
| <span slot="title"> | |||
| <span class="sessionMsg">{{ debuggerDialog.title }}</span> | |||
| <el-tooltip placement="right" | |||
| effect="light" | |||
| popper-class="legend-tip" | |||
| :content="$t('summaryManage.sessionLimitNum')"> | |||
| <i class="el-icon-info"></i> | |||
| </el-tooltip> | |||
| </span> | |||
| <div class="session-title">{{ $t('summaryManage.sessionLists') }}</div> | |||
| <el-table :data="debuggerDialog.trainJobs"> | |||
| <el-table-column width="50" | |||
| type="index" | |||
| :label="$t('summaryManage.sorting')"> | |||
| </el-table-column> | |||
| <el-table-column min-width="300" | |||
| prop="relative_path" | |||
| :label="$t('summaryManage.summaryPath')" | |||
| show-overflow-tooltip> | |||
| </el-table-column> | |||
| <!-- operate --> | |||
| <el-table-column prop="operate" | |||
| :label="$t('summaryManage.operation')" | |||
| class-name="operate-container"> | |||
| <template slot-scope="scope"> | |||
| <span class="menu-item operate-btn first-btn" | |||
| @click="deleteSession(scope.row.session_id)"> | |||
| {{$t('public.delete')}} </span> | |||
| <span class="menu-item operate-btn first-btn" | |||
| @click="viewSession(scope.row)"> | |||
| {{$t('debugger.view')}} </span> | |||
| </template> | |||
| </el-table-column> | |||
| </el-table> | |||
| </el-dialog> | |||
| </div> | |||
| </template> | |||
| @@ -223,7 +272,12 @@ export default { | |||
| type: 0, | |||
| }, | |||
| tableDom: null, | |||
| operateWidth: localStorage.getItem('milang') === 'en-us' ? 400 : 290, | |||
| operateWidth: localStorage.getItem('milang') === 'en-us' ? 550 : 400, | |||
| debuggerDialog: { | |||
| title: this.$t('summaryManage.sessionLimit'), | |||
| showDialogModel: false, | |||
| trainJobs: [], | |||
| }, | |||
| }; | |||
| }, | |||
| computed: {}, | |||
| @@ -286,6 +340,7 @@ export default { | |||
| i.update_time = i.update_time ? i.update_time : '--'; | |||
| i.viewProfiler = i.profiler_dir && i.profiler_dir.length; | |||
| i.viewDashboard = i.summary_files || i.graph_files || i.lineage_files; | |||
| i.viewOfflineDebugger = i.dump_dir; | |||
| i.paramDetails = i.lineage_files; | |||
| }); | |||
| this.currentFolder = res.data.name ? res.data.name : '--'; | |||
| @@ -363,7 +418,83 @@ export default { | |||
| }, | |||
| }); | |||
| }, | |||
| /** | |||
| * go to Offline Debugger | |||
| * @param {Object} row select row | |||
| */ | |||
| goToOfflineDebugger(row) { | |||
| this.contextMenu.show = false; | |||
| const debuggerDir = row.dump_dir; | |||
| const params = { | |||
| session_type: 'OFFLINE', | |||
| dump_dir: debuggerDir, | |||
| }; | |||
| this.getSessionId(params).then((value) => { | |||
| if (value !== undefined) { | |||
| this.$router.push({ | |||
| path: '/offline-debugger', | |||
| query: { | |||
| dir: debuggerDir, | |||
| sessionId: value, | |||
| }, | |||
| }); | |||
| } | |||
| }); | |||
| }, | |||
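| // Resolve a session id for the given dump dir; when the backend reports the | |||
| // session limit (error code 5054B280), list the active sessions instead. | |||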
| getSessionId(params) { | |||
| return RequestService.getSession(params).then( | |||
| (res) => { | |||
| if (res && res.data) { | |||
| return res.data; | |||
| } | |||
| }, | |||
| (error) => { | |||
| if (error && error.response && error.response.data && error.response.data.error_code === '5054B280') { | |||
| this.checkSessions(); | |||
| } | |||
| }, | |||
| ); | |||
| }, | |||
| deleteSession(sessionId) { | |||
| this.$confirm(this.$t('summaryManage.deleteSessionConfirm'), this.$t('public.notice'), { | |||
| confirmButtonText: this.$t('public.sure'), | |||
| cancelButtonText: this.$t('public.cancel'), | |||
| type: 'warning', | |||
| }).then(() => { | |||
| RequestService.deleteSession(sessionId).then(() => { | |||
| this.$message({ | |||
| type: 'success', | |||
| message: this.$t('summaryManage.deleteSessionSuccess'), | |||
| }); | |||
| this.checkSessions(); | |||
| }); | |||
| }); | |||
| }, | |||
| checkSessions() { | |||
| RequestService.checkSessions().then((res) => { | |||
| if (res && res.data && res.data.train_jobs) { | |||
| const trainJobs = res.data.train_jobs; | |||
| this.debuggerDialog.trainJobs = Object.keys(trainJobs).map((val) => { | |||
| return { | |||
| relative_path: decodeURIComponent(val), | |||
| session_id: trainJobs[val], | |||
| }; | |||
| }); | |||
| this.debuggerDialog.showDialogModel = true; | |||
| } | |||
| }); | |||
| }, | |||
| viewSession(row) { | |||
| const dir = row.relative_path; | |||
| this.$router.push({ | |||
| path: '/offline-debugger', | |||
| query: { | |||
| dir, | |||
| sessionId: row.session_id, | |||
| }, | |||
| }); | |||
| }, | |||
| rightClick(row, event, type) { | |||
| const maxWidth = 175; | |||
| this.contextMenu.data = row; | |||
| @@ -380,7 +511,28 @@ export default { | |||
| if (!row) { | |||
| return; | |||
| } | |||
| if (this.contextMenu.type) { | |||
| if (this.contextMenu.type === 2) { | |||
| // open offline debugger | |||
| this.contextMenu.show = false; | |||
| const debuggerDir = row.dump_dir; | |||
| const params = { | |||
| session_type: 'OFFLINE', | |||
| dump_dir: debuggerDir, | |||
| }; | |||
| this.getSessionId(params).then((value) => { | |||
| if (value !== undefined) { | |||
| const routeUrl = this.$router.resolve({ | |||
| path: '/offline-debugger', | |||
| query: { | |||
| dir: debuggerDir, | |||
| sessionId: value, | |||
| }, | |||
| }); | |||
| window.open(routeUrl.href, '_blank'); | |||
| } | |||
| }); | |||
| } else if (this.contextMenu.type === 1) { | |||
| // open profiling | |||
| this.contextMenu.show = false; | |||
| const profilerDir = encodeURIComponent(row.profiler_dir); | |||
| const trainId = encodeURIComponent(row.train_id); | |||
| @@ -400,7 +552,7 @@ export default { | |||
| }, | |||
| }); | |||
| window.open(routeUrl.href, '_blank'); | |||
| } else { | |||
| } else { // open training dashboard | |||
| this.contextMenu.show = false; | |||
| const trainId = encodeURIComponent(row.train_id); | |||
| @@ -693,6 +845,16 @@ export default { | |||
| #cl-summary-manage .details-data-list .el-dialog__body .details-data-title { | |||
| margin-bottom: 20px; | |||
| } | |||
| #cl-summary-manage .details-data-list .sessionMsg { | |||
| color: #333; | |||
| font-weight: bold; | |||
| font-size: 16px; | |||
| margin-right: 5px; | |||
| } | |||
| #cl-summary-manage .details-data-list .session-title { | |||
| margin-bottom: 10px; | |||
| color: #333; | |||
| } | |||
| #cl-summary-manage .is-disabled.custom-btn { | |||
| background-color: #f5f5f6; | |||
| border: 1px solid #dfe1e6 !important; | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -12,15 +12,10 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Module init file.""" | |||
| from mindinsight.backend.conditionmgr.conditionmgr_api import init_module as init_query_module | |||
| """Train job register.""" | |||
| def init_module(app): | |||
| """ | |||
| Init module entry. | |||
| Args: | |||
| app (Flask): A Flask instance. | |||
| """ | |||
| init_query_module(app) | |||
| class FolderAnalyzer: | |||
| """Train job register. The subclass should implement the analyze method and return update info.""" | |||
| def analyze(self, entry, summary_base_dir, relative_path): | |||
| """Analyze file.""" | |||
| @@ -18,4 +18,6 @@ six>=1.12.0 | |||
| Werkzeug>=1.0.0 | |||
| pandas>=1.0.4 | |||
| yapf>=0.30.0 | |||
| grpcio>=1.27.3 | |||
| treelib>=1.6.1 | |||
| grpcio>=1.27.3 | |||
| XlsxWriter>=1.2.9 | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -25,13 +25,15 @@ from mindinsight.conf import settings | |||
| from mindinsight.datavisual.utils import tools | |||
| from mindinsight.debugger.proto import ms_graph_pb2 | |||
| from mindinsight.debugger.stream_handler.graph_handler import GraphHandler | |||
| from mindinsight.debugger.session_manager import SessionManager | |||
| GRAPH_PROTO_FILE = os.path.join( | |||
| os.path.dirname(__file__), '../../../utils/resource/graph_pb/lenet.pb' | |||
| ) | |||
| DEBUGGER_BASE_URL = '/v1/mindinsight/debugger' | |||
| DEBUGGER_EXPECTED_RESULTS = os.path.join(os.path.dirname(__file__), 'expect_results') | |||
| DEBUGGER_BASE_URL = '/v1/mindinsight/debugger/sessions/0/' | |||
| DEBUGGER_TEST_BASE_DIR = os.path.dirname(__file__) | |||
| DEBUGGER_EXPECTED_RESULTS = os.path.join(DEBUGGER_TEST_BASE_DIR, 'expect_results') | |||
| def init_graph_handler(): | |||
| @@ -51,14 +53,13 @@ def init_graph_handler(): | |||
| @pytest.fixture(scope='session') | |||
| def app_client(): | |||
| """This fixture is flask server.""" | |||
| packages = ["mindinsight.backend.debugger", "mindinsight.backend.conditionmgr"] | |||
| packages = ["mindinsight.backend.debugger"] | |||
| settings.ENABLE_DEBUGGER = True | |||
| mock_obj = Mock(return_value=packages) | |||
| tools.find_app_package = mock_obj | |||
| from mindinsight.backend.application import APP | |||
| from mindinsight.backend.debugger.debugger_api import BACKEND_SERVER | |||
| APP.response_class = Response | |||
| client = APP.test_client() | |||
| original_val = settings.ENABLE_RECOMMENDED_WATCHPOINTS | |||
| @@ -67,4 +68,4 @@ def app_client(): | |||
| yield client | |||
| finally: | |||
| settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_val | |||
| BACKEND_SERVER.stop() | |||
| SessionManager.get_instance().online_session.stop() | |||
| @@ -0,0 +1,20 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Function: | |||
| Test debugger services. | |||
| Usage: | |||
| pytest tests/st/func/debugger | |||
| """ | |||
| @@ -0,0 +1,141 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| The module DbgServices mocks the offline debugger APIs for testing. | |||
| """ | |||
| from unittest.mock import MagicMock | |||
| import numpy as np | |||
| import mindinsight | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| def get_version(): | |||
| """Get version.""" | |||
| return mindinsight.__version__ | |||
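| # Test double for the dbg_services backend bindings: watchpoints live in a | |||
| # plain dict and tensor data is fabricated, so no real dump files are needed. | |||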
| class DbgServices: | |||
| """ | |||
| DbgServices. | |||
| Args: | |||
| dump_file_path (str): Directory where the dump files are saved. | |||
| verbose (bool): Whether to print log messages. Default: True. | |||
| """ | |||
| def __init__(self, dump_file_path, verbose=True): | |||
| self._verbose = verbose | |||
| self.dump_file_path = dump_file_path | |||
| self.dbg_instance = MagicMock() | |||
| self._watchpoints = {} | |||
| self.print_mes("in Python __init__, file path is {}".format(dump_file_path)) | |||
| self.version = get_version() | |||
| self.initialized = False | |||
| self.is_sync = True | |||
| def print_mes(self, mes): | |||
| """Print message.""" | |||
| if self._verbose: | |||
| log.info(mes) | |||
| def initialize(self, net_name, is_sync_mode): | |||
| """Initialize.""" | |||
| self.print_mes(" Python Initialize dump_file_path: {}, is_sync: {}".format(net_name, is_sync_mode)) | |||
| self.initialized = True | |||
| def add_watchpoint(self, watchpoint_id, watch_condition, check_node_list, parameter_list): | |||
| """Add watchpoint.""" | |||
| self.print_mes("Add watchpoint with watchpoint id: {}".format(watchpoint_id)) | |||
| self._watchpoints[watchpoint_id] = {'watch_condition': watch_condition, | |||
| 'check_nodes': check_node_list, | |||
| 'parameter_list': parameter_list} | |||
| return 0 | |||
| def remove_watchpoint(self, watchpoint_id): | |||
| """Remove watchpoints.""" | |||
| self.print_mes("Remove watchpoint with watchpoint id: {}".format(watchpoint_id)) | |||
| return self._watchpoints.pop(watchpoint_id) | |||
| def check_watchpoints(self, iteration): | |||
| """Check watchpoints.""" | |||
| self.print_mes("Check watchpoint at iteration: {}".format(iteration)) | |||
| watch_hits = [] | |||
| for watchpoint_id, watchpoint in self._watchpoints.items(): | |||
| # add param hit info | |||
| for param in watchpoint.get('parameter_list'): | |||
| param.hit = True | |||
| param.value = 0.0 | |||
| for watch_node_name, node_info in watchpoint.get('check_nodes').items(): | |||
| for device_id in node_info.get('device_id'): | |||
| hit = WatchpointHit(watch_node_name, | |||
| 0, | |||
| watchpoint.get('watch_condition'), | |||
| watchpoint_id, | |||
| watchpoint.get('parameter_list'), | |||
| 0, | |||
| device_id) | |||
| watch_hits.append(hit) | |||
| return watch_hits | |||
| def read_tensors(self, info): | |||
| """Read tensor values.""" | |||
| value = np.asarray(list(range(12)), dtype=np.int32).tobytes() | |||
| info_list_inst = [] | |||
| for _ in info: | |||
| tensor_data = TensorData(value, len(value), 4, [2, 2, 3]) | |||
| info_list_inst.append(tensor_data) | |||
| return info_list_inst | |||
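| # Plain data classes mirroring the structures the offline debugger expects | |||
| # from the backend bindings. | |||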
| class TensorInfo: | |||
| """Tensor Information.""" | |||
| def __init__(self, node_name, slot, iteration, device_id, is_parameter): | |||
| self.node_name = node_name | |||
| self.slot = slot | |||
| self.iteration = iteration | |||
| self.device_id = device_id | |||
| self.is_parameter = is_parameter | |||
| class TensorData: | |||
| """Tensor data structure.""" | |||
| def __init__(self, data_ptr, data_size, dtype, shape): | |||
| self.data_ptr = data_ptr | |||
| self.data_size = data_size | |||
| self.dtype = dtype | |||
| self.shape = shape | |||
| class Parameter: | |||
| """Parameter structure.""" | |||
| def __init__(self, name, disabled, value, hit=False, actual_value=0.0): | |||
| self.name = name | |||
| self.disabled = disabled | |||
| self.value = value | |||
| self.hit = hit | |||
| self.actual_value = actual_value | |||
| class WatchpointHit: | |||
| """Watchpoint hit structure.""" | |||
| def __init__(self, name, slot, condition, watchpoint_id, parameters, error_code, device_id): | |||
| self.name = name | |||
| self.slot = slot | |||
| self.condition = condition | |||
| self.watchpoint_id = watchpoint_id | |||
| self.parameters = parameters | |||
| self.error_code = error_code | |||
| self.device_id = device_id | |||
| @@ -0,0 +1,77 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Function: | |||
| Test the debugger server factory. | |||
| Usage: | |||
| pytest tests/st/func/debugger | |||
| """ | |||
| import os | |||
| import shutil | |||
| from unittest import mock | |||
| import pytest | |||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||
| from mindinsight.debugger.debugger_services.debugger_server_factory import \ | |||
| DebuggerServerFactory, DebuggerServerContext | |||
| from tests.st.func.debugger.utils import build_dump_file_structure | |||
| from tests.st.func.debugger.debugger_services import mock_dbg_services | |||
| class TestDebuggerServerFactory: | |||
| """Test debugger on Ascend backend.""" | |||
| @classmethod | |||
| def setup_class(cls): | |||
| """Setup class.""" | |||
| cls.debugger_tmp_dir, cls.dump_files_dir = build_dump_file_structure() | |||
| cls._dbg_dir = os.path.join(cls.dump_files_dir, 'Ascend/sync') | |||
| cls._dbg_server_factory = DebuggerServerFactory() | |||
| @classmethod | |||
| def teardown_class(cls): | |||
| """Run after test this class.""" | |||
| if os.path.exists(cls.debugger_tmp_dir): | |||
| shutil.rmtree(cls.debugger_tmp_dir) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| def test_get_dbg_online_server(self): | |||
| """Get debugger online server""" | |||
| context = DebuggerServerContext(dbg_mode='online') | |||
| server_obj = self._dbg_server_factory.get_debugger_server(DebuggerCache(), context) | |||
| server_obj.start() | |||
| server_obj.stop() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @mock.patch('mindinsight.debugger.debugger_services.debugger_offline_server.import_module') | |||
| def test_get_dbg_offline_server(self, mock_import): | |||
| """Get debugger offline server""" | |||
| mock_import.return_value = mock_dbg_services | |||
| context = DebuggerServerContext(dbg_mode='offline', dbg_dir=self._dbg_dir) | |||
| server_obj = self._dbg_server_factory.get_debugger_server(DebuggerCache(), context) | |||
| server_obj.start() | |||
| server_obj.stop() | |||
| @@ -0,0 +1,15 @@ | |||
| { | |||
| "common_dump_settings": { | |||
| "dump_mode": 0, | |||
| "path": "/absolute_path", | |||
| "net_name": "Lenet", | |||
| "iteration": 0, | |||
| "input_output": 0, | |||
| "kernels": ["Default/Conv-op12"], | |||
| "support_device": [0,1,2,3,4,5,6,7] | |||
| }, | |||
| "async_dump_settings": { | |||
| "enable": true, | |||
| "op_debug_mode": 0 | |||
| } | |||
| } | |||
| @@ -0,0 +1,23 @@ | |||
| { | |||
| "version": "1.0", | |||
| "server_count": "1", | |||
| "server_list": [ | |||
| { | |||
| "server_id": "0.0.0.0", | |||
| "device": [ | |||
| { | |||
| "device_id": "0", | |||
| "device_ip": "0.0.0.1", | |||
| "rank_id": "0" | |||
| }, | |||
| { | |||
| "device_id": "1", | |||
| "device_ip": "0.0.0.2", | |||
| "rank_id": "1" | |||
| } | |||
| ], | |||
| "host_nic_ip": "reserve" | |||
| } | |||
| ], | |||
| "status": "completed" | |||
| } | |||
| @@ -0,0 +1,15 @@ | |||
| { | |||
| "common_dump_settings": { | |||
| "dump_mode": 0, | |||
| "path": "/absolute_path", | |||
| "net_name": "Lenet", | |||
| "iteration": 0, | |||
| "input_output": 0, | |||
| "kernels": ["Default/Conv-op12"], | |||
| "support_device": [0,1,2,3,4,5,6,7] | |||
| }, | |||
| "e2e_dump_settings": { | |||
| "enable": true, | |||
| "trans_flag": false | |||
| } | |||
| } | |||
| @@ -0,0 +1,23 @@ | |||
| { | |||
| "version": "1.0", | |||
| "server_count": "1", | |||
| "server_list": [ | |||
| { | |||
| "server_id": "0.0.0.0", | |||
| "device": [ | |||
| { | |||
| "device_id": "0", | |||
| "device_ip": "0.0.0.1", | |||
| "rank_id": "0" | |||
| }, | |||
| { | |||
| "device_id": "1", | |||
| "device_ip": "0.0.0.2", | |||
| "rank_id": "1" | |||
| } | |||
| ], | |||
| "host_nic_ip": "reserve" | |||
| } | |||
| ], | |||
| "status": "completed" | |||
| } | |||
| @@ -0,0 +1,15 @@ | |||
| { | |||
| "common_dump_settings": { | |||
| "dump_mode": 0, | |||
| "path": "/absolute_path", | |||
| "net_name": "Lenet", | |||
| "iteration": 0, | |||
| "input_output": 0, | |||
| "kernels": ["Default/Conv-op12"], | |||
| "support_device": [0,1,2,3,4,5,6,7] | |||
| }, | |||
| "e2e_dump_settings": { | |||
| "enable": true, | |||
| "trans_flag": false | |||
| } | |||
| } | |||
| @@ -0,0 +1,21 @@ | |||
| { | |||
| "device_target": "Ascend", | |||
| "server_list": [ | |||
| { | |||
| "server_id": "0.0.0.0", | |||
| "device": [ | |||
| { | |||
| "device_id": "0", | |||
| "device_ip": "0.0.0.1", | |||
| "rank_id": "0" | |||
| }, | |||
| { | |||
| "device_id": "1", | |||
| "device_ip": "0.0.0.2", | |||
| "rank_id": "1" | |||
| } | |||
| ], | |||
| "host_nic_ip": "reserve" | |||
| } | |||
| ] | |||
| } | |||
| @@ -1 +1 @@ | |||
| {"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false, "debugger_version": {"ms": "1.2.0"}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []} | |||
| {"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": ["graph_0", "graph_1"]}], "watch_points": []} | |||
| @@ -1 +1 @@ | |||
| {"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {"ms": "1.2.0"}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []} | |||
| {"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": ["graph_0", "graph_1"]}], "watch_points": []} | |||
| @@ -1,29 +1 @@ | |||
| { | |||
| "tensor_value": { | |||
| "full_name": "Default/TransData-op99:0", | |||
| "step": 1, | |||
| "dtype": "DT_FLOAT32", | |||
| "shape": [ | |||
| 2, | |||
| 3 | |||
| ], | |||
| "has_prev_step": false, | |||
| "statistics": { | |||
| "overall_max": 6.0, | |||
| "overall_min": 1.0, | |||
| "overall_avg": 3.5, | |||
| "overall_count": 6, | |||
| "overall_nan_count": 0, | |||
| "overall_neg_inf_count": 0, | |||
| "overall_pos_inf_count": 0, | |||
| "overall_zero_count": 0.0, | |||
| "overall_neg_zero_count": 0.0, | |||
| "overall_pos_zero_count": 6.0 | |||
| }, | |||
| "value": [ | |||
| 5.0, | |||
| 6.0 | |||
| ], | |||
| "name": "Default/TransData-op99:0" | |||
| } | |||
| } | |||
| {"tensor_value": {"full_name": "Default/TransData-op99:0", "step": 1, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "value": [5.0, 6.0], "statistics": {"overall_max": 6.0, "overall_min": 1.0, "overall_avg": 3.5, "overall_count": 6, "overall_nan_count": 0, "overall_neg_inf_count": 0, "overall_pos_inf_count": 0, "overall_zero_count": 0.0, "overall_neg_zero_count": 0.0, "overall_pos_zero_count": 6.0}, "name": "Default/TransData-op99:0"}} | |||
| @@ -1 +1 @@ | |||
| {"metadata": {"state": "mismatch", "step": 0, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {"ms": "1.0.0"}}, "graph": {}, "watch_points": []} | |||
| {"metadata": {"state": "mismatch", "step": 0, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": []}], "watch_points": []} | |||
| @@ -0,0 +1,149 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test DataLoader of offline debugger.""" | |||
| import os | |||
| import shutil | |||
| import pytest | |||
| from mindinsight.debugger.stream_cache.data_loader import DataLoader | |||
| from tests.st.func.debugger.conftest import GRAPH_PROTO_FILE | |||
| from tests.st.func.debugger.utils import build_dump_file_structure | |||
| from tests.utils.tools import compare_result_with_file, compare_result_with_binary_file | |||
| class TestDataLoader: | |||
| """Test DataLoader.""" | |||
| @classmethod | |||
| def setup_class(cls): | |||
| """Init TestDataLoader for DataLoader unittest.""" | |||
| cls.debugger_tmp_dir, cls.dump_files_dir = build_dump_file_structure() | |||
| cls.expected_results_dir = os.path.join(os.path.dirname(__file__), | |||
| 'expect_results/offline_debugger') | |||
| cls.dump_files_dir_ascend = os.path.join(cls.dump_files_dir, | |||
| 'Ascend/sync') | |||
| cls.data_loader_ascend = DataLoader(cls.dump_files_dir_ascend) | |||
| cls.data_loader_ascend.initialize() | |||
| cls.dump_files_dir_gpu = os.path.join(cls.dump_files_dir, | |||
| 'GPU/sync') | |||
| cls.data_loader_gpu = DataLoader(cls.dump_files_dir_gpu) | |||
| cls.data_loader_gpu.initialize() | |||
| cls.dump_files_dir_ascend_async = os.path.join(cls.dump_files_dir, | |||
| 'Ascend/async') | |||
| cls.data_loader_ascend_async = DataLoader(cls.dump_files_dir_ascend_async) | |||
| cls.data_loader_ascend_async.initialize() | |||
| @classmethod | |||
| def teardown_class(cls): | |||
| """Run after test this class.""" | |||
| if os.path.exists(cls.debugger_tmp_dir): | |||
| shutil.rmtree(cls.debugger_tmp_dir) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| def test_load_graphs_ascend(self): | |||
| """Test load_graphs function of offline-debugger.""" | |||
| res = self.data_loader_ascend.load_graphs() | |||
| expected_result0 = GRAPH_PROTO_FILE | |||
| res0 = res[0]['graph_protos'][0].SerializeToString() | |||
| compare_result_with_binary_file(res0, expected_result0) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| def test_load_device_info_ascend(self): | |||
| """Test load_device_info of ascend chip for offline-debugger.""" | |||
| res = self.data_loader_ascend.load_device_info() | |||
| expected_result = os.path.join(self.expected_results_dir, 'load_device_info_ascend.json') | |||
| compare_result_with_file(res, expected_result) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| def test_load_step_num_ascend(self): | |||
| """Test load_step_num of ascend chip for offline-debugger.""" | |||
| res = self.data_loader_ascend.load_step_number() | |||
| expected_result = {"0": 4, "1": 4} | |||
| assert res == expected_result | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| def test_get_net_name_ascend(self): | |||
| """Test get_net_name of ascend chip for offline-debugger.""" | |||
| res = self.data_loader_ascend.get_net_name() | |||
| assert res == 'Lenet' | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| def test_get_sync_flag(self): | |||
| """Test get_sync_flag of ascend chip for offline-debugger.""" | |||
| res = self.data_loader_ascend.get_sync_flag() | |||
| assert res | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| def test_load_graphs_gpu(self): | |||
| """Test load_graphs function of offline-debugger.""" | |||
| res = self.data_loader_gpu.load_graphs() | |||
| expected_result0 = GRAPH_PROTO_FILE | |||
| res0 = res[0]['graph_protos'][0].SerializeToString() | |||
| compare_result_with_binary_file(res0, expected_result0) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| def test_load_step_num_gpu(self): | |||
| """Test load_step_num of ascend chip for offline-debugger.""" | |||
| res = self.data_loader_gpu.load_step_number() | |||
| expected_result = {"0": 3, "1": 3} | |||
| assert res == expected_result | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| def test_load_step_num_ascend_async(self): | |||
| """Test load_step_num of ascend chip for offline-debugger.""" | |||
| res = self.data_loader_ascend_async.load_step_number() | |||
| expected_result = {"0": 3, "1": 3} | |||
| assert res == expected_result | |||
| @@ -84,7 +84,7 @@ class TestAscendDebugger: | |||
| def test_get_conditions(self, app_client): | |||
| """Test get conditions for ascend.""" | |||
| url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/condition-collections' | |||
| url = '/v1/mindinsight/debugger/sessions/0/condition-collections' | |||
| body_data = {} | |||
| expect_file = 'get_conditions_for_ascend.json' | |||
| with self._debugger_client.get_thread_instance(): | |||
| @@ -191,7 +191,7 @@ class TestAscendDebugger: | |||
| check_state(app_client) | |||
| # prepare tensor value | |||
| url = 'tensor-history' | |||
| body_data = {'name': node_name} | |||
| body_data = {'name': node_name, 'rank_id': 0} | |||
| expect_file = 'retrieve_empty_tensor_history.json' | |||
| send_and_compare_result(app_client, url, body_data, expect_file) | |||
| # check full tensor history from poll data | |||
| @@ -229,7 +229,7 @@ class TestAscendDebugger: | |||
| get_request_result(app_client, url, body_data) | |||
| check_state(app_client) | |||
| get_request_result( | |||
| app_client=app_client, url='tensor-history', body_data={'name': node_name}) | |||
| app_client=app_client, url='tensor-history', body_data={'name': node_name, 'rank_id': 0}) | |||
| res = get_request_result( | |||
| app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get') | |||
| assert res.get('receive_tensor', {}).get('node_name') == node_name | |||
| @@ -239,30 +239,12 @@ class TestAscendDebugger: | |||
| 'name': node_name + ':0', | |||
| 'detail': 'data', | |||
| 'shape': quote('[:, :]'), | |||
| 'tolerance': 1 | |||
| } | |||
| 'tolerance': 1, | |||
| 'rank_id': 0} | |||
| expect_file = 'compare_tensors.json' | |||
| send_and_compare_result(app_client, url, body_data, expect_file, method='get') | |||
| send_terminate_cmd(app_client) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.parametrize("body_data, expect_file", [ | |||
| ({'ascend': True}, 'retrieve_node_by_bfs_ascend.json'), | |||
| ({'name': 'Default/args0', 'ascend': False}, 'retrieve_node_by_bfs.json') | |||
| ]) | |||
| def test_retrieve_bfs_node(self, app_client, body_data, expect_file): | |||
| """Test retrieve bfs node.""" | |||
| with self._debugger_client.get_thread_instance(): | |||
| check_state(app_client) | |||
| # prepare tensor values | |||
| url = 'retrieve_node_by_bfs' | |||
| send_and_compare_result(app_client, url, body_data, expect_file, method='get') | |||
| send_terminate_cmd(app_client) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @@ -441,7 +423,7 @@ class TestGPUDebugger: | |||
| def test_get_conditions(self, app_client): | |||
| """Test get conditions for gpu.""" | |||
| url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/condition-collections' | |||
| url = '/v1/mindinsight/debugger/sessions/0/condition-collections' | |||
| body_data = {} | |||
| expect_file = 'get_conditions_for_gpu.json' | |||
| with self._debugger_client.get_thread_instance(): | |||
| @@ -16,8 +16,10 @@ | |||
| import json | |||
| import os | |||
| import time | |||
| from tests.st.func.debugger.conftest import DEBUGGER_EXPECTED_RESULTS, DEBUGGER_BASE_URL | |||
| import shutil | |||
| import tempfile | |||
| from mindinsight.debugger.proto import ms_graph_pb2 | |||
| from tests.st.func.debugger.conftest import DEBUGGER_EXPECTED_RESULTS, DEBUGGER_BASE_URL, GRAPH_PROTO_FILE | |||
| from tests.utils.tools import compare_result_with_file, get_url | |||
| @@ -74,10 +76,57 @@ def send_and_save_result(app_client, url, body_data, file_path, method='post'): | |||
| def delete_random_items(res): | |||
| """delete the random items in metadata.""" | |||
| if isinstance(res, dict) and res.get('metadata'): | |||
| if res['metadata'].get('ip'): | |||
| res['metadata'].pop('ip') | |||
| if res['metadata'].get('pos'): | |||
| res['metadata'].pop('pos') | |||
| if res['metadata'].get('debugger_version') and res['metadata']['debugger_version'].get('mi'): | |||
| res['metadata']['debugger_version'].pop('mi') | |||
| if isinstance(res, dict): | |||
| if res.get('metadata'): | |||
| if res['metadata'].get('ip'): | |||
| res['metadata'].pop('ip') | |||
| if res['metadata'].get('pos'): | |||
| res['metadata'].pop('pos') | |||
| if res['metadata'].get('debugger_version') and res['metadata']['debugger_version'].get('mi'): | |||
| res['metadata']['debugger_version'].pop('mi') | |||
| res['metadata']['debugger_version'].pop('ms') | |||
| if res.get('devices'): | |||
| for device in res.get('devices'): | |||
| if device.get('server_ip'): | |||
| device.pop('server_ip') | |||
| def build_dump_file_structure(): | |||
| """Build the dump file structure.""" | |||
| async_file_structure = { | |||
| "Ascend/async/device_0/Lenet_graph_1/1": 3, | |||
| "Ascend/async/device_1/Lenet_graph_1/1": 3 | |||
| } | |||
| sync_file_structure = { | |||
| "Ascend/sync/Lenet/device_0": 4, | |||
| "Ascend/sync/Lenet/device_1": 4, | |||
| "GPU/sync/Lenet/device_0": 3, | |||
| "GPU/sync/Lenet/device_1": 3 | |||
| } | |||
| debugger_tmp_dir = tempfile.mkdtemp(suffix='debugger_tmp') | |||
| dump_files_dir = os.path.join(debugger_tmp_dir, 'dump_files') | |||
| shutil.copytree(os.path.join(os.path.dirname(__file__), 'dump_files'), dump_files_dir) | |||
| for sub_dir, steps in async_file_structure.items(): | |||
| for step in range(1, steps + 1): | |||
| os.makedirs(os.path.join(dump_files_dir, sub_dir, str(step)), exist_ok=True) | |||
| for sub_dir, steps in sync_file_structure.items(): | |||
| for step in range(1, steps + 1): | |||
| os.makedirs(os.path.join(dump_files_dir, sub_dir, 'iteration_' + str(step)), exist_ok=True) | |||
| graph_dir_path = os.path.join(dump_files_dir, sub_dir, 'graphs') | |||
| os.makedirs(graph_dir_path, exist_ok=True) | |||
| graph_path = os.path.join(graph_dir_path, 'ms_output_trace_code_graph_0.pb') | |||
| with open(GRAPH_PROTO_FILE, 'rb') as file_handler: | |||
| content = file_handler.read() | |||
| model = ms_graph_pb2.ModelProto() | |||
| model.graph.ParseFromString(content) | |||
| model_str = model.SerializeToString() | |||
| with open(graph_path, 'wb') as file_handler: | |||
| file_handler.write(model_str) | |||
| return debugger_tmp_dir, dump_files_dir | |||
| @@ -1 +1 @@ | |||
| {"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "watch_points": []} | |||
| {"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "devices": [{"rank_id": 0, "server_ip": "", "device_id": "", "graph_names": []}], "watch_points": []} | |||
| @@ -1,5 +1,80 @@ | |||
| [ | |||
| {"watchCondition": {"condition": "tensor_too_small", "value": 1.0, "params": [{"name": "abs_mean_lt", "disabled": true}, {"name": "max_lt", "value": 1.0}, {"name": "min_lt", "disabled": true}, {"name": "mean_lt", "disabled": true}]}, "id": 1, "watch_nodes_num": 0}, | |||
| {"watchCondition": {"condition": "tensor_too_small", "value": 1.0, "params": [{"name": "abs_mean_lt", "disabled": true}, {"name": "max_lt", "disabled": true}, {"name": "min_lt", "value": 1.0}, {"name": "mean_lt", "disabled": true}]}, "id": 2, "watch_nodes_num": 172}, | |||
| {"watchCondition": {"condition": "tensor_too_large", "value": 1.0, "params": [{"name": "abs_mean_gt", "disabled": true}, {"name": "max_gt", "value": 1.0}, {"name": "min_gt", "disabled": true}, {"name": "mean_gt", "disabled": true}]}, "id": 3, "watch_nodes_num": 1} | |||
| { | |||
| "watchCondition": { | |||
| "condition": "tensor_too_small", | |||
| "value": 1.0, | |||
| "params": [ | |||
| { | |||
| "name": "abs_mean_lt", | |||
| "disabled": true | |||
| }, | |||
| { | |||
| "name": "max_lt", | |||
| "value": 1.0 | |||
| }, | |||
| { | |||
| "name": "min_lt", | |||
| "disabled": true | |||
| }, | |||
| { | |||
| "name": "mean_lt", | |||
| "disabled": true | |||
| } | |||
| ] | |||
| }, | |||
| "id": 1, | |||
| "watch_nodes_num": 0 | |||
| }, | |||
| { | |||
| "watchCondition": { | |||
| "condition": "tensor_too_small", | |||
| "value": 1.0, | |||
| "params": [ | |||
| { | |||
| "name": "abs_mean_lt", | |||
| "disabled": true | |||
| }, | |||
| { | |||
| "name": "max_lt", | |||
| "disabled": true | |||
| }, | |||
| { | |||
| "name": "min_lt", | |||
| "value": 1.0 | |||
| }, | |||
| { | |||
| "name": "mean_lt", | |||
| "disabled": true | |||
| } | |||
| ] | |||
| }, | |||
| "id": 2, | |||
| "watch_nodes_num": 142 | |||
| }, | |||
| { | |||
| "watchCondition": { | |||
| "condition": "tensor_too_large", | |||
| "value": 1.0, | |||
| "params": [ | |||
| { | |||
| "name": "abs_mean_gt", | |||
| "disabled": true | |||
| }, | |||
| { | |||
| "name": "max_gt", | |||
| "value": 1.0 | |||
| }, | |||
| { | |||
| "name": "min_gt", | |||
| "disabled": true | |||
| }, | |||
| { | |||
| "name": "mean_gt", | |||
| "disabled": true | |||
| } | |||
| ] | |||
| }, | |||
| "id": 3, | |||
| "watch_nodes_num": 1 | |||
| } | |||
| ] | |||
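| # Sanity check for the expected-results file above (sketch only; the path is | |||
| # hypothetical): loading it should yield three watchpoints with ids 1..3. | |||
| #   with open('expected_results/watchpoint/watchpoints.json') as f: | |||
| #       assert [wp['id'] for wp in json.load(f)] == [1, 2, 3] | |||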
| @@ -111,20 +111,6 @@ class TestGraphHandler: | |||
| node_name = self.graph_handler.get_node_name_by_full_name(full_name, 'kernel_graph_0') | |||
| assert node_name == expect_node_name | |||
| @pytest.mark.parametrize("node_name, ascend, expect_next", [ | |||
| (None, True, | |||
| "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0"), | |||
| (None, False, None), | |||
| ("Default/tuple_getitem[10]_0/tuple_getitem-op206", True, | |||
| "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89"), | |||
| ("Default/tuple_getitem[10]_0/tuple_getitem-op206", False, | |||
| "Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op205") | |||
| ]) | |||
| def test_get_node_by_bfs_order(self, node_name, ascend, expect_next): | |||
| """Test get node by BFS order.""" | |||
| next_node = self.graph_handler.get_node_by_bfs_order(node_name, ascend) | |||
| assert next_node == expect_next | |||
| @pytest.mark.parametrize("tensor_name, expect_file", [ | |||
| ("Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0:0", "get_tensor_graph-0.json"), | |||
| ("Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89:1", "get_tensor_graph-1.json"), | |||
| @@ -40,7 +40,7 @@ class TestTensorHandler: | |||
| def test_get_tensor_value_by_name_none(self): | |||
| """Test get_tensor_value_by_name.""" | |||
| res = self.tensor_handler.get_valid_tensor_by_name('tensor_name', True) | |||
| res = self.tensor_handler.get_valid_tensor_by_name('tensor_name', step=0, prev=True) | |||
| assert res is None | |||
| @mock.patch.object(log, "error") | |||
| @@ -49,5 +49,5 @@ class TestTensorHandler: | |||
| """Test get_tensors_diff.""" | |||
| mock_error.return_value = None | |||
| with pytest.raises(DebuggerParamValueError) as ex: | |||
| self.tensor_handler.get_tensors_diff(tensor_name, {1, 1}) | |||
| self.tensor_handler.get_tensors_diff(tensor_name, {1, 1}, step=0) | |||
| assert f"Get current step and previous step for this tensor name {tensor_name} failed." in str(ex.value) | |||
| @@ -30,6 +30,7 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue | |||
| DebuggerParamTypeError | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.stream_cache.watchpoint import Watchpoint | |||
| from mindinsight.debugger.stream_handler import MultiCardGraphHandler | |||
| from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHandler, \ | |||
| WatchpointHitHandler, validate_watch_condition, validate_watch_condition_params | |||
| from tests.ut.debugger.configurations import init_graph_handler, mock_tensor_proto, \ | |||
| @@ -48,7 +49,9 @@ class TestWatchpointHandler: | |||
| '../expected_results/watchpoint') | |||
| cls.graph_results_dir = os.path.join(os.path.dirname(__file__), | |||
| '../expected_results/graph') | |||
| cls.multi_graph_stream = MultiCardGraphHandler() | |||
| cls.graph_stream = init_graph_handler() | |||
| cls.multi_graph_stream.register_graph_handler(0, cls.graph_stream) | |||
| cls.conditionmgr = None | |||
| cls.handler = None | |||
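| # MultiCardGraphHandler keys one GraphHandler per rank id; registering the | |||
| # existing single-card handler under rank 0 keeps these tests single-device. | |||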
| @@ -69,7 +72,7 @@ class TestWatchpointHandler: | |||
| ] | |||
| for watch_condition, watch_nodes, watch_point_id, expect_new_id in watchpoints: | |||
| watch_nodes = get_node_basic_infos(watch_nodes) | |||
| watch_point_id = self.handler.create_watchpoint(self.conditionmgr, watch_condition, watch_nodes, | |||
| watch_point_id = self.handler.create_watchpoint(self.conditionmgr, watch_condition, {0: watch_nodes}, | |||
| watch_point_id) | |||
| assert watch_point_id == expect_new_id | |||
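| # create_watchpoint now takes watch nodes keyed by rank id, hence the | |||
| # {0: watch_nodes} wrapper around the previously flat node list. | |||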
| @@ -105,7 +108,7 @@ class TestWatchpointHandler: | |||
| file_path = os.path.join(self.results_dir, result_file) | |||
| with open(file_path, 'r') as file_handler: | |||
| contents = json.load(file_handler) | |||
| protos = self.handler.get_pending_commands(self.graph_stream) | |||
| protos = self.handler.get_pending_commands(self.multi_graph_stream) | |||
| for proto in protos: | |||
| msg_dict = json_format.MessageToDict(proto) | |||
| msg_dict['watch_nodes_num'] = len(msg_dict.pop('watchNodes', [])) | |||
| @@ -48,7 +48,8 @@ class TestTrainingControlOperator: | |||
| """Test validate leaf name.""" | |||
| args[0].return_value = 'name_scope' | |||
| with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'): | |||
| self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name') | |||
| self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name', | |||
| rank_id=0) | |||
| @pytest.mark.parametrize('mode, cur_state, state', [ | |||
| ('continue', 'waiting', 'sending'), | |||
| @@ -64,3 +65,12 @@ class TestTrainingControlOperator: | |||
| """Test construct run event.""" | |||
| res = self._server._construct_run_event({'level': 'node'}) | |||
| assert res.run_cmd == RunCMD(run_level='node', node_name='') | |||
| @pytest.mark.parametrize('mode, state', [ | |||
| ('reset', 'waiting')]) | |||
| def test_control_reset_step(self, mode, state): | |||
| """Test control request, in 'reset' mode.""" | |||
| with mock.patch.object(MetadataHandler, 'max_step_num', 10), \ | |||
| mock.patch.object(MetadataHandler, 'debugger_type', 'offline'): | |||
| res = self._server.control(mode=mode, params={'steps': 9}) | |||
| assert res == {'metadata': {'enable_recheck': False, 'state': state, 'step': 9}} | |||
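| # 'reset' appears to rewind an offline session to a target step; mocking | |||
| # max_step_num=10 lets the requested step 9 pass range validation here. | |||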
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -26,7 +26,7 @@ import numpy as np | |||
| from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr | |||
| from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus | |||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||
| from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer | |||
| from mindinsight.debugger.debugger_services.debugger_grpc_server import DebuggerGrpcServer | |||
| from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply, SetCMD, Chunk, WatchpointHit | |||
| from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType | |||
| from mindinsight.debugger.stream_handler import WatchpointHitHandler, GraphHandler, \ | |||
| @@ -30,11 +30,11 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue | |||
| DebuggerCompareTensorError, DebuggerCreateWatchPointError, DebuggerDeleteWatchPointError | |||
| from mindinsight.debugger.common.utils import Streams | |||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||
| from mindinsight.debugger.debugger_server import DebuggerServer | |||
| from mindinsight.debugger.debugger_server import grpc_server_base | |||
| from mindinsight.debugger.stream_operator import watchpoint_operator | |||
| from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext | |||
| from mindinsight.debugger.debugger_session import DebuggerSession as DebuggerServer | |||
| from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \ | |||
| TensorHandler | |||
| from mindinsight.debugger.stream_operator import watchpoint_operator | |||
| from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history | |||
| @@ -48,12 +48,12 @@ class TestDebuggerServer: | |||
| def setup_method(self): | |||
| """Prepare debugger server object.""" | |||
| self._server = DebuggerServer() | |||
| context = DebuggerServerContext(dbg_mode='online') | |||
| self._server = DebuggerServer(context) | |||
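| # DebuggerSession now requires an explicit server context; dbg_mode='online' | |||
| # presumably mirrors the old argument-less constructor's behaviour. | |||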
| @mock.patch.object(signal, 'signal') | |||
| @mock.patch.object(Thread, 'join') | |||
| @mock.patch.object(Thread, 'start') | |||
| @mock.patch.object(grpc_server_base, 'add_EventListenerServicer_to_server') | |||
| @mock.patch.object(grpc, 'server') | |||
| def test_stop_server(self, *args): | |||
| """Test stop debugger server.""" | |||
| @@ -62,7 +62,6 @@ class TestDebuggerServer: | |||
| self._server.start() | |||
| self._server._stop_handler(MagicMock(), MagicMock()) | |||
| assert self._server.back_server is not None | |||
| assert self._server.grpc_server_manager == mock_grpc_server_manager | |||
| @mock.patch.object(DebuggerCache, 'get_data') | |||
| def test_poll_data(self, *args): | |||
| @@ -186,7 +185,6 @@ class TestDebuggerServer: | |||
| self._server.create_watchpoint({'watch_condition': {'id': 'inf'}}) | |||
| @mock.patch.object(MetadataHandler, 'state', 'waiting') | |||
| @mock.patch.object(MetadataHandler, 'backend', 'GPU') | |||
| @mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=MagicMock()) | |||
| @mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope') | |||
| @mock.patch.object(watchpoint_operator, 'get_basic_node_info', return_value=MagicMock()) | |||
| @@ -194,6 +192,7 @@ class TestDebuggerServer: | |||
| def test_create_watchpoint(self, *args): | |||
| """Test create watchpoint.""" | |||
| args[0].return_value = 1 | |||
| self._server.cache_store.get_stream_handler(Streams.METADATA).backend = 'GPU' | |||
| res = self._server.create_watchpoint( | |||
| {'watch_condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]}, | |||
| 'watch_nodes': ['watch_node_name']}) | |||
| @@ -68,6 +68,13 @@ def compare_result_with_file(result, expected_file_path): | |||
| assert result == expected_results | |||
| def compare_result_with_binary_file(result, expected_file_path): | |||
| """Compare result with binary file which contain the expected results.""" | |||
| with open(expected_file_path, 'rb') as file: | |||
| expected_results = file.read() | |||
| assert result == expected_results | |||
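| # Typical use (illustrative path): | |||
| #   compare_result_with_binary_file(reply, os.path.join(results_dir, 'tensor.bin')) | |||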
| def deal_float_for_dict(res: dict, expected_res: dict, decimal_num): | |||
| """ | |||
| Round float values in the dicts to the specified number of decimals. | |||