
add offline debugger feature

pull/1270/head
jiangshuqiang committed 4 years ago
commit 8bb3851e49
79 changed files with 3951 additions and 950 deletions
  1. mindinsight/backend/conditionmgr/conditionmgr_api.py (+0, -61)
  2. mindinsight/backend/config/gunicorn_conf.py (+2, -0)
  3. mindinsight/backend/data_manager/__init__.py (+8, -1)
  4. mindinsight/backend/debugger/debugger_api.py (+137, -95)
  5. mindinsight/datavisual/data_transform/data_manager.py (+19, -2)
  6. mindinsight/datavisual/data_transform/graph/msgraph.py (+4, -2)
  7. mindinsight/datavisual/data_transform/summary_watcher.py (+46, -8)
  8. mindinsight/datavisual/processors/train_task_manager.py (+3, -2)
  9. mindinsight/debugger/common/exceptions/error_code.py (+17, -1)
  10. mindinsight/debugger/common/exceptions/exceptions.py (+56, -1)
  11. mindinsight/debugger/common/utils.py (+29, -2)
  12. mindinsight/debugger/conditionmgr/recommender.py (+26, -15)
  13. mindinsight/debugger/debugger_cache.py (+10, -10)
  14. mindinsight/debugger/debugger_folder_analyzer.py (+41, -0)
  15. mindinsight/debugger/debugger_services/__init__.py (+15, -0)
  16. mindinsight/debugger/debugger_services/debugger_grpc_server.py (+19, -25)
  17. mindinsight/debugger/debugger_services/debugger_offline_server.py (+613, -0)
  18. mindinsight/debugger/debugger_services/debugger_online_server.py (+58, -0)
  19. mindinsight/debugger/debugger_services/debugger_server_base.py (+58, -0)
  20. mindinsight/debugger/debugger_services/debugger_server_factory.py (+92, -0)
  21. mindinsight/debugger/debugger_session.py (+85, -103)
  22. mindinsight/debugger/proto/debug_grpc.proto (+3, -0)
  23. mindinsight/debugger/proto/debug_grpc_pb2.py (+37, -18)
  24. mindinsight/debugger/proto/debug_grpc_pb2_grpc.py (+15, -22)
  25. mindinsight/debugger/proto/ms_graph.proto (+3, -0)
  26. mindinsight/debugger/proto/ms_graph_pb2.py (+44, -39)
  27. mindinsight/debugger/session_manager.py (+172, -0)
  28. mindinsight/debugger/stream_cache/data_loader.py (+210, -0)
  29. mindinsight/debugger/stream_cache/tensor.py (+21, -8)
  30. mindinsight/debugger/stream_cache/watchpoint.py (+33, -19)
  31. mindinsight/debugger/stream_handler/__init__.py (+7, -6)
  32. mindinsight/debugger/stream_handler/device_handler.py (+198, -0)
  33. mindinsight/debugger/stream_handler/graph_handler.py (+53, -43)
  34. mindinsight/debugger/stream_handler/metadata_handler.py (+33, -17)
  35. mindinsight/debugger/stream_handler/tensor_handler.py (+70, -23)
  36. mindinsight/debugger/stream_handler/watchpoint_handler.py (+87, -19)
  37. mindinsight/debugger/stream_operator/tensor_detail_info.py (+30, -17)
  38. mindinsight/debugger/stream_operator/training_control_operator.py (+43, -7)
  39. mindinsight/debugger/stream_operator/watchpoint_operator.py (+19, -19)
  40. mindinsight/ui/src/app.vue (+1, -1)
  41. mindinsight/ui/src/components/debugger-tensor.vue (+8, -4)
  42. mindinsight/ui/src/locales/en-us.json (+17, -5)
  43. mindinsight/ui/src/locales/zh-cn.json (+17, -5)
  44. mindinsight/ui/src/mixins/debugger-mixin.vue (+436, -163)
  45. mindinsight/ui/src/router.js (+4, -0)
  46. mindinsight/ui/src/services/fetcher.js (+8, -1)
  47. mindinsight/ui/src/services/request-service.js (+52, -41)
  48. mindinsight/ui/src/views/debugger/debugger.vue (+129, -24)
  49. mindinsight/ui/src/views/train-manage/summary-manage.vue (+166, -4)
  50. mindinsight/utils/folder_analyzer.py (+6, -11)
  51. requirements.txt (+3, -1)
  52. tests/st/func/debugger/conftest.py (+7, -6)
  53. tests/st/func/debugger/debugger_services/__init__.py (+20, -0)
  54. tests/st/func/debugger/debugger_services/mock_dbg_services.py (+141, -0)
  55. tests/st/func/debugger/debugger_services/test_debugger_services.py (+77, -0)
  56. tests/st/func/debugger/dump_files/Ascend/async/.metadata/data_dump.json (+15, -0)
  57. tests/st/func/debugger/dump_files/Ascend/async/.metadata/hccl.json (+23, -0)
  58. tests/st/func/debugger/dump_files/Ascend/sync/.metadata/data_dump.json (+15, -0)
  59. tests/st/func/debugger/dump_files/Ascend/sync/.metadata/hccl.json (+23, -0)
  60. tests/st/func/debugger/dump_files/GPU/sync/.metadata/data_dump.json (+15, -0)
  61. tests/st/func/debugger/expect_results/offline_debugger/load_device_info_ascend.json (+21, -0)
  62. tests/st/func/debugger/expect_results/restful_results/multi_next_node.json (+1, -1)
  63. tests/st/func/debugger/expect_results/restful_results/multi_retrieve_all.json (+1, -1)
  64. tests/st/func/debugger/expect_results/restful_results/retrieve_all.json (+1, -1)
  65. tests/st/func/debugger/expect_results/restful_results/retrieve_next_node_on_gpu.json (+1, -1)
  66. tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_value.json (+1, -29)
  67. tests/st/func/debugger/expect_results/restful_results/version_mismatch.json (+1, -1)
  68. tests/st/func/debugger/test_data_loader.py (+149, -0)
  69. tests/st/func/debugger/test_restful_api.py (+6, -24)
  70. tests/st/func/debugger/utils.py (+58, -9)
  71. tests/ut/debugger/expected_results/debugger_server/retrieve_all.json (+1, -1)
  72. tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_0.json (+78, -3)
  73. tests/ut/debugger/stream_handler/test_graph_handler.py (+0, -14)
  74. tests/ut/debugger/stream_handler/test_tensor_handler.py (+2, -2)
  75. tests/ut/debugger/stream_handler/test_watchpoint_handler.py (+5, -2)
  76. tests/ut/debugger/stream_operator/test_training_control_operator.py (+11, -1)
  77. tests/ut/debugger/test_debugger_grpc_server.py (+2, -2)
  78. tests/ut/debugger/test_debugger_server.py (+6, -7)
  79. tests/utils/tools.py (+7, -0)

mindinsight/backend/conditionmgr/conditionmgr_api.py (+0, -61)

@@ -1,61 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""Conditionmgr restful api."""
-import json
-
-from flask import Blueprint, request
-
-from mindinsight.conf import settings
-from mindinsight.utils.exceptions import ParamValueError
-from mindinsight.utils.exceptions import ParamMissError
-from mindinsight.backend.debugger.debugger_api import BACKEND_SERVER, _wrap_reply
-
-BLUEPRINT = Blueprint("conditionmgr", __name__,
-                      url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX)
-
-
-@BLUEPRINT.route("/conditionmgr/train-jobs/<train_id>/condition-collections", methods=["GET"])
-def get_condition_collections(train_id):
-    """get condition collections"""
-    reply = _wrap_reply(BACKEND_SERVER.get_condition_collections, train_id)
-    return reply
-
-
-@BLUEPRINT.route("/conditionmgr/train-jobs/<train_id>/set-recommended-watch-points", methods=["POST"])
-def set_recommended_watch_points(train_id):
-    """set recommended watch points."""
-    body = request.stream.read()
-    try:
-        body = json.loads(body if body else "{}")
-    except json.JSONDecodeError:
-        raise ParamValueError("Json data parse failed.")
-
-    request_body = body.get('requestBody')
-    if request_body is None:
-        raise ParamMissError('requestBody')
-
-    set_recommended = request_body.get('set_recommended')
-    reply = _wrap_reply(BACKEND_SERVER.set_recommended_watch_points, set_recommended, train_id)
-    return reply
-
-
-def init_module(app):
-    """
-    Init module entry.
-
-    Args:
-        app (Flask): The application obj.
-    """
-    app.register_blueprint(BLUEPRINT)

mindinsight/backend/config/gunicorn_conf.py (+2, -0)

@@ -26,6 +26,7 @@ import psutil
 import gunicorn

 from mindinsight.utils.computing_resource_mgr import terminate
+from mindinsight.debugger.session_manager import SessionManager


 gunicorn.SERVER_SOFTWARE = 'unknown'
@@ -110,4 +111,5 @@ def worker_int(worker):
     global LISTEN_PROCESS
     if LISTEN_PROCESS is not None:
         LISTEN_PROCESS.terminate()
+    SessionManager.get_instance().exit()
     worker.log.info("Worker int processed.")

mindinsight/backend/data_manager/__init__.py (+8, -1)

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,6 +19,11 @@ from mindinsight.conf import settings
 from mindinsight.datavisual.common.log import logger
 from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
 from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
+from mindinsight.debugger.debugger_folder_analyzer import DebuggerFolderAnalyzer
+
+ANALYZERS = {
+    "debugger_folder_analyzer": DebuggerFolderAnalyzer()
+}


 def init_module(app):
@@ -31,6 +36,8 @@ def init_module(app):
     """
     # Just to suppress pylint warning about unused arg.
     logger.debug("App: %s", type(app))
+    for analyzer in ANALYZERS.values():
+        DATA_MANAGER.register_folder_analyzer(analyzer)
     DATA_MANAGER.register_brief_cache_item_updater(LineageCacheItemUpdater())
     # Let gunicorn load other modules first.
     time.sleep(1)


mindinsight/backend/debugger/debugger_api.py (+137, -95)

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,22 +19,13 @@ from urllib.parse import unquote
 from flask import Blueprint, jsonify, request

 from mindinsight.conf import settings
-from mindinsight.debugger.debugger_server import DebuggerServer
-from mindinsight.utils.exceptions import ParamValueError
+from mindinsight.debugger.session_manager import SessionManager
+from mindinsight.utils.exceptions import ParamMissError, ParamValueError

 BLUEPRINT = Blueprint("debugger", __name__,
                       url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX)


-def _initialize_debugger_server():
-    """Initialize a debugger server instance."""
-    enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False
-    server = None
-    if enable_debugger:
-        server = DebuggerServer()
-    return server
-
-
 def _unquote_param(param):
     """
     Decode parameter value.
@@ -77,8 +68,8 @@ def _wrap_reply(func, *args, **kwargs):
     return jsonify(reply)


-@BLUEPRINT.route("/debugger/poll-data", methods=["GET"])
-def poll_data():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/poll-data", methods=["GET"])
+def poll_data(session_id):
     """
     Wait for data to be updated on UI.
@@ -88,17 +79,17 @@ def poll_data():
         str, the updated data.

     Examples:
-        >>> Get http://xxxx/v1/mindinsight/debugger/poll-data?pos=xx
+        >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/poll-data?pos=xx
     """
     pos = request.args.get('pos')

-    reply = _wrap_reply(BACKEND_SERVER.poll_data, pos)
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).poll_data, pos)

     return reply


-@BLUEPRINT.route("/debugger/search", methods=["GET"])
-def search():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/search", methods=["GET"])
+def search(session_id):
     """
     Search nodes in specified watchpoint.
@@ -106,42 +97,25 @@ def search():
         str, the required data.

     Examples:
-        >>> Get http://xxxx/v1/mindinsight/debugger/search?name=mock_name&watch_point_id=1
+        >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/search?name=mock_name&watch_point_id=1
     """
     name = request.args.get('name')
     graph_name = request.args.get('graph_name')
     watch_point_id = int(request.args.get('watch_point_id', 0))
     node_category = request.args.get('node_category')
-    reply = _wrap_reply(BACKEND_SERVER.search, {'name': name,
-                                                'graph_name': graph_name,
-                                                'watch_point_id': watch_point_id,
-                                                'node_category': node_category})
+    rank_id = int(request.args.get('rank_id', 0))
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).search,
+                        {'name': name,
+                         'graph_name': graph_name,
+                         'watch_point_id': watch_point_id,
+                         'node_category': node_category,
+                         'rank_id': rank_id})

     return reply


-@BLUEPRINT.route("/debugger/retrieve_node_by_bfs", methods=["GET"])
-def retrieve_node_by_bfs():
-    """
-    Search node by bfs.
-
-    Returns:
-        str, the required data.
-
-    Examples:
-        >>> Get http://xxxx/v1/mindinsight/debugger/retrieve_node_by_bfs?name=node_name&ascend=true
-    """
-    name = request.args.get('name')
-    graph_name = request.args.get('graph_name')
-    ascend = request.args.get('ascend', 'false')
-    ascend = ascend == 'true'
-    reply = _wrap_reply(BACKEND_SERVER.retrieve_node_by_bfs, name, graph_name, ascend)
-
-    return reply
-
-
-@BLUEPRINT.route("/debugger/tensor-comparisons", methods=["GET"])
-def tensor_comparisons():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-comparisons", methods=["GET"])
+def tensor_comparisons(session_id):
     """
     Get tensor comparisons.
@@ -149,19 +123,21 @@ def tensor_comparisons():
         str, the required data.

     Examples:
-        >>> Get http://xxxx/v1/mindinsight/debugger/tensor-comparisons
+        >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-comparisons
     """
     name = request.args.get('name')
     detail = request.args.get('detail', 'data')
     shape = _unquote_param(request.args.get('shape'))
     tolerance = request.args.get('tolerance', '0')
-    reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance)
+    rank_id = int(request.args.get('rank_id', 0))
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).tensor_comparisons, name, shape, detail,
+                        tolerance, rank_id)

     return reply


-@BLUEPRINT.route("/debugger/retrieve", methods=["POST"])
-def retrieve():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/retrieve", methods=["POST"])
+def retrieve(session_id):
     """
     Retrieve data according to mode and params.
@@ -169,17 +145,17 @@ def retrieve():
         str, the required data.

     Examples:
-        >>> POST http://xxxx/v1/mindinsight/debugger/retrieve
+        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/retrieve
     """
     body = _read_post_request(request)
     mode = body.get('mode')
     params = body.get('params')
-    reply = _wrap_reply(BACKEND_SERVER.retrieve, mode, params)
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve, mode, params)
     return reply


-@BLUEPRINT.route("/debugger/tensor-history", methods=["POST"])
-def retrieve_tensor_history():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-history", methods=["POST"])
+def retrieve_tensor_history(session_id):
     """
     Retrieve data according to mode and params.
@@ -187,17 +163,19 @@ def retrieve_tensor_history():
         str, the required data.

     Examples:
-        >>> POST http://xxxx/v1/mindinsight/debugger/tensor-history
+        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-history
     """
     body = _read_post_request(request)
     name = body.get('name')
     graph_name = body.get('graph_name')
-    reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_history, name, graph_name)
+    rank_id = body.get('rank_id')
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_history, name, graph_name,
+                        rank_id)
     return reply


-@BLUEPRINT.route("/debugger/tensors", methods=["GET"])
-def retrieve_tensor_value():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/tensors", methods=["GET"])
+def retrieve_tensor_value(session_id):
     """
     Retrieve tensor value according to name and shape.
@@ -205,20 +183,22 @@ def retrieve_tensor_value():
         str, the required data.

     Examples:
-        >>> GET http://xxxx/v1/mindinsight/debugger/tensors?name=tensor_name&detail=data&shape=[1,1,:,:]
+        >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensors?name=tensor_name&detail=data&shape=[1,1,:,:]
     """
     name = request.args.get('name')
     detail = request.args.get('detail')
     shape = _unquote_param(request.args.get('shape'))
     graph_name = request.args.get('graph_name')
     prev = bool(request.args.get('prev') == 'true')
+    rank_id = int(request.args.get('rank_id', 0))

-    reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_value, name, detail, shape, graph_name, prev)
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_value, name, detail,
+                        shape, graph_name, prev, rank_id)
     return reply


-@BLUEPRINT.route("/debugger/create-watchpoint", methods=["POST"])
-def create_watchpoint():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/create-watchpoint", methods=["POST"])
+def create_watchpoint(session_id):
     """
     Create watchpoint.
@@ -229,16 +209,16 @@ def create_watchpoint():
         MindInsightException: If method fails to be called.

     Examples:
-        >>> POST http://xxxx/v1/mindinsight/debugger/create-watchpoint
+        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/create-watchpoint
     """
     params = _read_post_request(request)
     params['watch_condition'] = params.pop('condition', None)
-    reply = _wrap_reply(BACKEND_SERVER.create_watchpoint, params)
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).create_watchpoint, params)
     return reply


-@BLUEPRINT.route("/debugger/update-watchpoint", methods=["POST"])
-def update_watchpoint():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/update-watchpoint", methods=["POST"])
+def update_watchpoint(session_id):
     """
     Update watchpoint.
@@ -249,17 +229,17 @@ def update_watchpoint():
         MindInsightException: If method fails to be called.

     Examples:
-        >>> POST http://xxxx/v1/mindinsight/debugger/update-watchpoint
+        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/update-watchpoint
     """
     params = _read_post_request(request)
-    reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, params)
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).update_watchpoint, params)
     return reply


-@BLUEPRINT.route("/debugger/delete-watchpoint", methods=["POST"])
-def delete_watchpoint():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/delete-watchpoint", methods=["POST"])
+def delete_watchpoint(session_id):
     """
-    delete watchpoint.
+    Delete watchpoint.

     Returns:
         str, reply message.
@@ -268,19 +248,19 @@ def delete_watchpoint():
         MindInsightException: If method fails to be called.

     Examples:
-        >>> POST http://xxxx/v1/mindinsight/debugger/delete-watchpoint
+        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/delete-watchpoint
     """
     body = _read_post_request(request)

     watch_point_id = body.get('watch_point_id')

-    reply = _wrap_reply(BACKEND_SERVER.delete_watchpoint, watch_point_id)
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).delete_watchpoint, watch_point_id)

     return reply


-@BLUEPRINT.route("/debugger/control", methods=["POST"])
-def control():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/control", methods=["POST"])
+def control(session_id):
     """
     Control request.
@@ -291,16 +271,16 @@ def control():
         MindInsightException: If method fails to be called.

     Examples:
-        >>> POST http://xxxx/v1/mindinsight/debugger/control
+        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/control
     """
     params = _read_post_request(request)
-    reply = _wrap_reply(BACKEND_SERVER.control, params)
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).control, params)

     return reply


-@BLUEPRINT.route("/debugger/recheck", methods=["POST"])
-def recheck():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/recheck", methods=["POST"])
+def recheck(session_id):
     """
     Recheck request.
@@ -311,15 +291,15 @@ def recheck():
         MindInsightException: If method fails to be called.

     Examples:
-        >>> POST http://xxxx/v1/mindinsight/debugger/recheck
+        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/recheck
     """
-    reply = _wrap_reply(BACKEND_SERVER.recheck)
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).recheck)

     return reply


-@BLUEPRINT.route("/debugger/tensor-graphs", methods=["GET"])
-def retrieve_tensor_graph():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-graphs", methods=["GET"])
+def retrieve_tensor_graph(session_id):
     """
     Retrieve tensor value according to name and shape.
@@ -327,16 +307,18 @@ def retrieve_tensor_graph():
         str, the required data.

     Examples:
-        >>> GET http://xxxx/v1/mindinsight/debugger/tensor-graphs?tensor_name=tensor_name&graph_name=graph_name
+        >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-graphs?tensor_name=xxx&graph_name=xxx
     """
     tensor_name = request.args.get('tensor_name')
     graph_name = request.args.get('graph_name')
-    reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_graph, tensor_name, graph_name)
+    rank_id = int(request.args.get('rank_id', 0))
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_graph, tensor_name,
+                        graph_name, rank_id)
     return reply


-@BLUEPRINT.route("/debugger/tensor-hits", methods=["GET"])
-def retrieve_tensor_hits():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-hits", methods=["GET"])
+def retrieve_tensor_hits(session_id):
     """
     Retrieve tensor value according to name and shape.
@@ -344,16 +326,18 @@ def retrieve_tensor_hits():
         str, the required data.

     Examples:
-        >>> GET http://xxxx/v1/mindinsight/debugger/tensor-hits?tensor_name=tensor_name&graph_name=graph_name
+        >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-hits?tensor_name=xxx&graph_name=xxx
     """
     tensor_name = request.args.get('tensor_name')
     graph_name = request.args.get('graph_name')
-    reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_hits, tensor_name, graph_name)
+    rank_id = int(request.args.get('rank_id', 0))
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_hits, tensor_name,
+                        graph_name, rank_id)
    return reply


-@BLUEPRINT.route("/debugger/search-watchpoint-hits", methods=["POST"])
-def search_watchpoint_hits():
+@BLUEPRINT.route("/debugger/sessions/<session_id>/search-watchpoint-hits", methods=["POST"])
+def search_watchpoint_hits(session_id):
     """
     Search watchpoint hits by group condition.
@@ -361,15 +345,75 @@ def search_watchpoint_hits():
         str, the required data.

     Examples:
-        >>> POST http://xxxx/v1/mindinsight/debugger/search-watchpoint-hits
+        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/search-watchpoint-hits
     """
     body = _read_post_request(request)
     group_condition = body.get('group_condition')
-    reply = _wrap_reply(BACKEND_SERVER.search_watchpoint_hits, group_condition)
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).search_watchpoint_hits, group_condition)
+    return reply
+
+
+@BLUEPRINT.route("/debugger/sessions/<session_id>/condition-collections", methods=["GET"])
+def get_condition_collections(session_id):
+    """Get condition collections."""
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).get_condition_collections)
+    return reply
+
+
+@BLUEPRINT.route("/debugger/sessions/<session_id>/set-recommended-watch-points", methods=["POST"])
+def set_recommended_watch_points(session_id):
+    """Set recommended watch points."""
+    body = _read_post_request(request)
+    request_body = body.get('requestBody')
+    if request_body is None:
+        raise ParamMissError('requestBody')
+
+    set_recommended = request_body.get('set_recommended')
+    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).set_recommended_watch_points,
+                        set_recommended)
     return reply


-BACKEND_SERVER = _initialize_debugger_server()
+@BLUEPRINT.route("/debugger/sessions", methods=["POST"])
+def create_session():
+    """
+    Get the session id if the session exists; otherwise create a new session.
+
+    Returns:
+        str, session id.
+
+    Examples:
+        >>> POST http://xxxx/v1/mindinsight/debugger/sessions
+    """
+    body = _read_post_request(request)
+    summary_dir = body.get('dump_dir')
+    session_type = body.get('session_type')
+    reply = _wrap_reply(SessionManager.get_instance().create_session, session_type, summary_dir)
+    return reply
+
+
+@BLUEPRINT.route("/debugger/sessions", methods=["GET"])
+def get_sessions():
+    """
+    Check the current active sessions.
+
+    Examples:
+        >>> GET http://xxxx/v1/mindinsight/debugger/sessions
+    """
+    reply = _wrap_reply(SessionManager.get_instance().get_sessions)
+    return reply
+
+
+@BLUEPRINT.route("/debugger/sessions/<session_id>/delete", methods=["POST"])
+def delete_session(session_id):
+    """
+    Delete the session by session id.
+
+    Examples:
+        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/delete
+    """
+    reply = _wrap_reply(SessionManager.get_instance().delete_session, session_id)
+    return reply


 def init_module(app):
@@ -380,5 +424,3 @@ def init_module(app):
         app (Flask): The application obj.
     """
     app.register_blueprint(BLUEPRINT)
-    if BACKEND_SERVER:
-        BACKEND_SERVER.start()
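
With these routes, every debugger call is scoped to a session: a client first creates (or looks up) a session, then addresses it by id. A minimal sketch using requests; the host, port, session_type value, and response handling below are assumptions for illustration, not part of this commit:

    import requests

    BASE = 'http://127.0.0.1:8080/v1/mindinsight'  # assumed local MindInsight address

    # Create (or reuse) an offline session bound to a dump directory.
    resp = requests.post(BASE + '/debugger/sessions',
                         json={'session_type': 'OFFLINE', 'dump_dir': './run1'})
    session_id = resp.json()  # the reply carries the session id

    # Subsequent calls are scoped by that id.
    requests.get(BASE + '/debugger/sessions/{}/poll-data'.format(session_id),
                 params={'pos': 0})
    requests.post(BASE + '/debugger/sessions/{}/delete'.format(session_id))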

mindinsight/datavisual/data_transform/data_manager.py (+19, -2)

@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -112,6 +112,11 @@ class _BasicTrainJob:
         """Get the lineage files count in the summary dir."""
         return self._entry['lineage_files']

+    @property
+    def dump_dir(self):
+        """Get the dump file path in the summary dir."""
+        return self._entry.get('dump_dir', None)
+

 class CachedTrainJob:
     """
@@ -369,6 +374,10 @@ class _BaseCacheManager:
 class _BriefCacheManager(_BaseCacheManager):
     """A cache manager that holds all disk train jobs on disk."""

+    def __init__(self, summary_base_dir):
+        super(_BriefCacheManager, self).__init__(summary_base_dir)
+        self._summary_watcher = SummaryWatcher()
+
     def cache_train_job(self, train_id):
         """
         Cache given train job.
@@ -386,7 +395,7 @@ class _BriefCacheManager(_BaseCacheManager):
     def update_cache(self, executor):
         """Update cache."""
         logger.info('Start to update BriefCacheManager.')
-        summaries_info = SummaryWatcher().list_summary_directories(self._summary_base_dir)
+        summaries_info = self._summary_watcher.list_summary_directories(self._summary_base_dir)

         basic_train_jobs = []
         for info in summaries_info:
@@ -425,6 +434,10 @@ class _BriefCacheManager(_BaseCacheManager):

         return new_cache_items

+    def register_folder_analyzer(self, analyzer):
+        """Register folder analyzer."""
+        self._summary_watcher.register_folder_analyzer(analyzer)
+
     @property
     def cache_items(self):
         """Get cache items."""
@@ -1028,6 +1041,10 @@ class DataManager:
         """Register brief cache item updater for brief cache manager."""
         self._brief_cache.register_cache_item_updater(updater)

+    def register_folder_analyzer(self, analyzer):
+        """Register folder analyzer."""
+        self._brief_cache.register_folder_analyzer(analyzer)
+
     def get_brief_cache(self):
         """Get brief cache."""
         return self._brief_cache


mindinsight/datavisual/data_transform/graph/msgraph.py (+4, -2)

@@ -254,22 +254,24 @@ class MSGraph(Graph):

         return searched_list

-    def search_leaf_nodes_by_pattern(self, pattern):
+    def search_leaf_nodes_by_pattern(self, pattern, scope_pattern=False):
         """
         Search leaf node by a given pattern.

         Args:
             pattern (Union[str, None]): The pattern of the node to search,
                 if None, return all node names.
+            scope_pattern (bool): If true, return the children nodes of the scope. Default: False.

         Returns:
             list[Node], a list of nodes.
         """
+        is_match = lambda x, y: x.lower().startswith(y) if scope_pattern else y in x.lower()
         if pattern is not None:
             pattern = pattern.lower()
             searched_nodes = [
                 node for name, node in self._leaf_nodes.items()
-                if pattern in name.lower()
+                if is_match(name, pattern)
             ]
         else:
             searched_nodes = [node for node in self._leaf_nodes.values()]
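
The effect of the new scope_pattern flag, in a small self-contained sketch (the node names here are made up for illustration):

    # Substring match (default) vs. scope-prefix match (scope_pattern=True).
    names = ['Default/conv1/Cast-op5', 'Gradients/cast/Add-op1']

    print([n for n in names if 'cast' in n.lower()])              # both: substring hits anywhere
    print([n for n in names if n.lower().startswith('default')])  # only children of the Default/ scope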


mindinsight/datavisual/data_transform/summary_watcher.py (+46, -8)

@@ -29,6 +29,7 @@ from mindinsight.utils.exceptions import FileSystemPermissionError
 LINEAGE_SUMMARY_SUFFIX = '_lineage'
 EXPLAIN_SUMMARY_SUFFIX = '_explain'
+DUMP_FILE_PREFIX = 'dump_'


 class SummaryWatcher:
@@ -45,6 +46,13 @@ class SummaryWatcher:
     # to avoid long-time blocking
     MAX_SCAN_COUNT = 20000

+    def __init__(self):
+        self._analyzers = []
+
+    def register_folder_analyzer(self, analyzer):
+        """Register folder analyzer."""
+        self._analyzers.append(analyzer)
+
     def list_summary_directories(self, summary_base_dir, overall=True, list_explain=False):
         """
         List summary directories within base directory.
@@ -104,7 +112,7 @@ class SummaryWatcher:
             elif entry.is_dir():
                 self._update_summary_dict(summary_dict, summary_base_dir, relative_path, entry, list_explain)
                 entry_path = os.path.realpath(os.path.join(summary_base_dir, entry.name))
-                self._scan_subdir_entries(summary_dict, summary_base_dir, entry_path, entry.name, counter, list_explain)
+                self._scan_subdir_entries(summary_dict, summary_base_dir, entry_path, entry, counter, list_explain)

         directories = []
         for key, value in summary_dict.items():
@@ -119,7 +127,7 @@ class SummaryWatcher:

         return directories

-    def _scan_subdir_entries(self, summary_dict, summary_base_dir, entry_path, entry_name, counter, list_explain):
+    def _scan_subdir_entries(self, summary_dict, summary_base_dir, entry_path, entry, counter, list_explain):
         """
         Scan subdir entries.
@@ -134,7 +142,7 @@ class SummaryWatcher:
         try:
             subdir_entries = os.scandir(entry_path)
         except PermissionError:
-            logger.warning('Path of %s under summary base directory is not accessible.', entry_name)
+            logger.warning('Path of %s under summary base directory is not accessible.', entry.name)
             return

         # sort in ascending order according to modification time.
@@ -149,11 +157,14 @@ class SummaryWatcher:
                 logger.info('Stop further scanning due to overall is False and '
                             'number of scanned files exceeds upper limit.')
                 break
-            subdir_relative_path = os.path.join('.', entry_name)
+            subdir_relative_path = os.path.join('.', entry.name)
             if subdir_entry.is_symlink():
                 pass
             self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry, list_explain)
+        relative_path = './'
+        self._check_by_analyzers(entry, summary_base_dir, relative_path, summary_dict)

     def _is_valid_summary_directory(self, summary_base_dir, relative_path):
         """
         Check if the given summary directory is valid.
@@ -198,13 +209,11 @@ class SummaryWatcher:
             list_explain (bool): Indicates whether to list only the mindexplain folder.
         """
         try:
-            stat = entry.stat()
+            ctime, mtime = self._get_stat_time(entry)
         except FileNotFoundError:
             logger.warning('File %s not found', entry.name)
             return

-        ctime = datetime.datetime.fromtimestamp(stat.st_ctime).astimezone()
-        mtime = datetime.datetime.fromtimestamp(stat.st_mtime).astimezone()
-
         if entry.is_file():
             summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name)
             pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name)
@@ -238,7 +247,10 @@ class SummaryWatcher:
                 summary_dict[relative_path]['explain_files'] += 1
             else:
                 summary_dict[relative_path]['summary_files'] += 1
+            self._check_by_analyzers(entry, summary_base_dir, relative_path, summary_dict)
         elif entry.is_dir():
+            self._check_by_analyzers(entry, summary_base_dir, relative_path, summary_dict)
             if list_explain:
                 return
@@ -261,6 +273,28 @@ class SummaryWatcher:
         else:
             summary_dict[relative_path] = _new_entry(ctime, mtime, profiler)

+    def _check_by_analyzers(self, entry, summary_base_dir, relative_path, summary_dict):
+        """Check by all analyzers."""
+        try:
+            ctime, mtime = self._get_stat_time(entry)
+        except FileNotFoundError:
+            logger.warning('File %s not found', entry.name)
+            return
+        for analyzer in self._analyzers:
+            register_info = analyzer.analyze(entry, summary_base_dir, relative_path)
+            if register_info:
+                if relative_path not in summary_dict:
+                    summary_dict[relative_path] = _new_entry(ctime, mtime)
+                summary_dict[relative_path].update(register_info)
+
+    def _get_stat_time(self, entry):
+        """Get ctime and mtime."""
+        stat = entry.stat()
+        ctime = datetime.datetime.fromtimestamp(stat.st_ctime).astimezone()
+        mtime = datetime.datetime.fromtimestamp(stat.st_mtime).astimezone()
+        return ctime, mtime
+
     def _find_profiler_dir(self, entry, summary_base_dir, relative_path):
         """Find profiler dir by the given relative path."""
         profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name)
@@ -342,6 +376,9 @@ class SummaryWatcher:
         if self._is_valid_profiler_directory(full_path)[0] or \
                 self._is_valid_cluster_profiler_directory(full_path)[0]:
             return True
+        if os.path.exists(os.path.join(summary_directory, os.path.join(entry.name, ".metadata"))):
+            return True

         return False

     def _is_valid_profiler_directory(self, directory):
@@ -515,7 +552,8 @@ def _new_entry(ctime, mtime, profiler=None):
         'lineage_files': 0,
         'explain_files': 0,
         'graph_files': 0,
-        'profiler': profiler
+        'profiler': profiler,
+        'dump_dir': None
     }


mindinsight/datavisual/processors/train_task_manager.py (+3, -2)

@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -150,7 +150,8 @@ class TrainTaskManager(BaseProcessor):
                 profiler_type=basic_info.profiler_type,
                 summary_files=basic_info.summary_files,
                 graph_files=basic_info.graph_files,
-                lineage_files=basic_info.lineage_files
+                lineage_files=basic_info.lineage_files,
+                dump_dir=basic_info.dump_dir
             )

             if train_job.cache_status != CacheStatus.NOT_IN_CACHE:


mindinsight/debugger/common/exceptions/error_code.py (+17, -1)

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,6 +20,8 @@ from mindinsight.utils.constant import DebuggerErrors as DebuggerErrorCodes
 _PARAM_ERROR_MASK = 0b00001 << 7
 _DEBUGGER_GRAPH_ERROR = 0b00010 << 7
 _DEBUGGER_RUNNING_ERROR = 0b00011 << 7
+_DEBUGGER_SERVER_ERROR = 0b00100 << 7
+_DEBUGGER_SESSION_ERROR = 0b00101 << 7


 @unique
@@ -44,6 +46,13 @@ class DebuggerErrors(DebuggerErrorCodes):
     TENSOR_HIT_ERROR = 8 | _DEBUGGER_RUNNING_ERROR
     SET_RECOMMEND_WATCHPOINT_ERROR = 9 | _DEBUGGER_RUNNING_ERROR

+    DEBUGGER_SERVER_RUNNING_ERROR = 0 | _DEBUGGER_SERVER_ERROR
+    DEVICE_ID_UNREGISTERED = 1 | _DEBUGGER_SERVER_ERROR
+    MODULE_NOT_FOUND_ERROR = 2 | _DEBUGGER_SERVER_ERROR
+
+    DEBUGGER_SESSION_OVER_BOUND_ERROR = 0 | _DEBUGGER_SESSION_ERROR
+    DEBUGGER_SESSION_NOT_FOUND_ERROR = 1 | _DEBUGGER_SESSION_ERROR
+

 @unique
 class DebuggerErrorMsg(Enum):
@@ -63,3 +72,10 @@ class DebuggerErrorMsg(Enum):
     TENSOR_GRAPH_ERROR = "Get tensor graphs failed."
     TENSOR_HIT_ERROR = "Get tensor hits failed."
     SET_RECOMMEND_WATCHPOINT_ERROR = "Set Recommend Watchpoints failed."
+
+    DEBUGGER_SERVER_RUNNING_ERROR = "Debugger server running error. {}"
+    DEVICE_ID_UNREGISTERED = "Device id unregistered. Device id: {}"
+    MODULE_NOT_FOUND_ERROR = "{} module not found."
+
+    DEBUGGER_SESSION_OVER_BOUND_ERROR = "The amount of sessions is over limitation."
+    DEBUGGER_SESSION_NOT_FOUND_ERROR = "Session {} not found."
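
The new masks slot the server and session families above the existing ones; the resulting integer codes work out as follows (plain arithmetic on the masks above, shown for reference):

    _DEBUGGER_SERVER_ERROR = 0b00100 << 7    # 4 * 128 = 512
    _DEBUGGER_SESSION_ERROR = 0b00101 << 7   # 5 * 128 = 640

    assert (0 | _DEBUGGER_SERVER_ERROR) == 512   # DEBUGGER_SERVER_RUNNING_ERROR
    assert (2 | _DEBUGGER_SERVER_ERROR) == 514   # MODULE_NOT_FOUND_ERROR
    assert (1 | _DEBUGGER_SESSION_ERROR) == 641  # DEBUGGER_SESSION_NOT_FOUND_ERROR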

mindinsight/debugger/common/exceptions/exceptions.py (+56, -1)

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -190,3 +190,58 @@ class DebuggerConditionUnavailableError(MindInsightException):
             message=DebuggerErrorMsg.DEBUGGER_CONDITION_UNAVAILABLE_ERROR.value.format(msg),
             http_code=400
         )
+
+
+class DebuggerServerRunningError(MindInsightException):
+    """The debugger server running error in debugger module."""
+
+    def __init__(self, msg):
+        super(DebuggerServerRunningError, self).__init__(
+            error=DebuggerErrors.DEBUGGER_SERVER_RUNNING_ERROR,
+            message=DebuggerErrorMsg.DEBUGGER_SERVER_RUNNING_ERROR.value.format(msg),
+            http_code=500
+        )
+
+
+class DeviceIdUnregistered(MindInsightException):
+    """The unregistered device id error in debugger module."""
+
+    def __init__(self, msg):
+        super(DeviceIdUnregistered, self).__init__(
+            error=DebuggerErrors.DEVICE_ID_UNREGISTERED,
+            message=DebuggerErrorMsg.DEVICE_ID_UNREGISTERED.value.format(msg),
+            http_code=400
+        )
+
+
+class DebuggerModuleNotFoundError(MindInsightException):
+    """The module not found error in debugger module."""
+
+    def __init__(self, msg):
+        super(DebuggerModuleNotFoundError, self).__init__(
+            error=DebuggerErrors.MODULE_NOT_FOUND_ERROR,
+            message=DebuggerErrorMsg.MODULE_NOT_FOUND_ERROR.value.format(msg),
+            http_code=500
+        )
+
+
+class DebuggerSessionNumOverBoundError(MindInsightException):
+    """The session number over limitation error in debugger module."""
+
+    def __init__(self):
+        super(DebuggerSessionNumOverBoundError, self).__init__(
+            error=DebuggerErrors.DEBUGGER_SESSION_OVER_BOUND_ERROR,
+            message=DebuggerErrorMsg.DEBUGGER_SESSION_OVER_BOUND_ERROR.value,
+            http_code=400
+        )
+
+
+class DebuggerSessionNotFoundError(MindInsightException):
+    """The session not found error in debugger module."""
+
+    def __init__(self, msg):
+        super(DebuggerSessionNotFoundError, self).__init__(
+            error=DebuggerErrors.DEBUGGER_SESSION_NOT_FOUND_ERROR,
+            message=DebuggerErrorMsg.DEBUGGER_SESSION_NOT_FOUND_ERROR.value.format(msg),
+            http_code=400
+        )

mindinsight/debugger/common/utils.py (+29, -2)

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -38,9 +38,12 @@ NUMPY_TYPE_MAP = {
     'DT_FLOAT32': np.float32,
     'DT_FLOAT64': np.float64,

-    'DT_STRING': np.str
+    'DT_STRING': np.str,
+    'DT_TYPE': np.str
 }

+MS_VERSION = '1.0.x'
+

 @enum.unique
 class ReplyStates(enum.Enum):
@@ -71,6 +74,7 @@ class Streams(enum.Enum):
     TENSOR = 'tensor'
     WATCHPOINT = 'watchpoint'
     WATCHPOINT_HIT = 'watchpoint_hit'
+    DEVICE = 'device'


 class RunLevel(enum.Enum):
@@ -152,3 +156,26 @@ def is_scope_type(node_type):
 def is_cst_type(node_type):
     """Judge whether the type is const type."""
     return node_type == NodeTypeEnum.CONST.value
+
+
+def version_match(ms_version, mi_version):
+    """Judge whether the versions of MindSpore and MindInsight match."""
+    if not ms_version:
+        ms_version = MS_VERSION
+    mi_major, mi_minor = mi_version.split('.')[:2]
+    ms_major, ms_minor = ms_version.split('.')[:2]
+    return mi_major == ms_major and mi_minor == ms_minor
+
+
+@enum.unique
+class DebuggerServerMode(enum.Enum):
+    """Debugger Server Mode."""
+    ONLINE = 'online'
+    OFFLINE = 'offline'
+
+
+class DumpSettings(enum.Enum):
+    """Dump settings."""
+    E2E_DUMP_SETTINGS = 'e2e_dump_settings'
+    COMMON_DUMP_SETTINGS = 'common_dump_settings'
+    ASYNC_DUMP_SETTINGS = 'async_dump_settings'
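
version_match compares only the major and minor fields; a quick sketch of the intended behaviour, assuming the function above:

    print(version_match('1.1.0', '1.1.1'))  # True: both are 1.1.x, patch level is ignored
    print(version_match('1.0.1', '1.1.0'))  # False: minor versions differ
    print(version_match(None, '1.0.0'))     # True: empty MindSpore version falls back to MS_VERSION '1.0.x'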

mindinsight/debugger/conditionmgr/recommender.py (+26, -15)

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -64,13 +64,13 @@ class _ConditionParameterValue:
         return self.parameter.name


-def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_context):
+def recommend_watchpoints(condition_mgr: ConditionMgr, multi_card_graph_stream, condition_context):
     """
     Recommend watchpoints.

     Args:
         condition_mgr (ConditionMgr): Condition manager instance.
-        graph_stream (GraphHandler): Graph handler instance.
+        multi_card_graph_stream (GraphHandler): Multi card graph handler instance.
         condition_context (ConditionContext): Context for condition.

     Returns:
@@ -78,7 +78,7 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c
     """
     watch_points = []

-    if not graph_stream.graph:
+    if not multi_card_graph_stream.has_graph:
         logger.warning("Given graph is None.")
         return watch_points

@@ -86,7 +86,7 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c
         return watch_points

     # add weight watch points
-    merged_info = get_basic_node_info(TargetTypeEnum.WEIGHT.value, graph_stream)
+    merged_info = get_basic_node_info(TargetTypeEnum.WEIGHT.value, multi_card_graph_stream)
     _recommend_weight_initialization(merged_info, condition_mgr, watch_points, condition_context)
     _recommend_weight_change_too_large(merged_info, condition_mgr, watch_points, condition_context)

@@ -97,25 +97,27 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c
     _recommend_weight_change_too_small(condition_mgr, trainable_weight_nodes, watch_points, condition_context)

     # add gradient watch points
-    merged_info = get_basic_node_info(TargetTypeEnum.GRADIENT.value, graph_stream)
+    merged_info = get_basic_node_info(TargetTypeEnum.GRADIENT.value, multi_card_graph_stream)
     _recommend_gradient_vanishing(merged_info, condition_mgr, watch_points, condition_context)

     # add tensor watch points
-    merged_info = get_basic_node_info(TargetTypeEnum.TENSOR.value, graph_stream)
+    merged_info = get_basic_node_info(TargetTypeEnum.TENSOR.value, multi_card_graph_stream)
     _recommend_operator_overflow(merged_info, condition_mgr, watch_points, condition_context)
     _recommend_tensor_overflow(merged_info, condition_mgr, watch_points, condition_context)
     _recommend_tensor_all_zero(merged_info, condition_mgr, watch_points, condition_context)

     # add activation watch points
-    merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.TANH.value)
+    merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream,
+                                      ActivationFuncEnum.TANH.value)
     _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context,
                                 ActivationFuncEnum.TANH.value)

-    merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.SIGMOID.value)
+    merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream,
+                                      ActivationFuncEnum.SIGMOID.value)
     _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context,
                                 ActivationFuncEnum.SIGMOID.value)

-    merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream,
+    merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, multi_card_graph_stream,
                                       [ActivationFuncEnum.RELU.value, ActivationFuncEnum.RELUV2.value])
     _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context,
                                 ActivationFuncEnum.RELU.value)
@@ -318,12 +320,21 @@ def _recommend_activation_range(basic_info_nodes, condition_mgr, watch_points, c
         watch_points.append(activation_range_watchpoint)


-def get_basic_node_info(node_category, graph_stream, activation_func=None):
+def get_basic_node_info(node_category, multi_card_graph_stream, activation_func=None):
     """Get node merged info."""
-    basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func)
-    merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph)
-    merged_info = _add_graph_name(merged_info, graph_stream)
-    return merged_info
+    nodes_for_devices = {}
+    has_node = False
+    for rank_id, graph_stream in multi_card_graph_stream.graph_handlers.items():
+        basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func)
+        merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph)
+        merged_info = _add_graph_name(merged_info, graph_stream)
+        nodes_for_devices[rank_id] = merged_info
+        has_node = has_node or merged_info
+
+    if has_node:
+        return nodes_for_devices
+
+    return {}


 def _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func=None):
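
After this change get_basic_node_info keys the merged node info by rank instead of returning one flat result; roughly (the rank ids and node names below are hypothetical):

    # Shape of the return value for a two-card run; {} when no card has matches.
    nodes_for_devices = {
        0: ['Default/network/conv1.weight', 'Default/network/fc1.weight'],  # rank 0
        1: ['Default/network/conv1.weight', 'Default/network/fc1.weight'],  # rank 1
    }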


mindinsight/debugger/debugger_cache.py (+10, -10)

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,17 +17,19 @@ import sys

 from mindinsight.debugger.common.log import LOGGER as log
 from mindinsight.debugger.common.utils import Streams
-from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, GraphHandler, \
-    TensorHandler, WatchpointHandler, WatchpointHitHandler
+from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, MultiCardGraphHandler, \
+    MultiCardTensorHandler, WatchpointHandler, MultiCardWatchpointHitHandler
+from mindinsight.debugger.stream_handler.device_handler import DeviceHandler

 STREAM_HANDLER_MAP = {
     Streams.COMMAND.value: EventHandler,
     Streams.DATA.value: EventHandler,
     Streams.METADATA.value: MetadataHandler,
-    Streams.GRAPH.value: GraphHandler,
-    Streams.TENSOR.value: TensorHandler,
+    Streams.GRAPH.value: MultiCardGraphHandler,
+    Streams.TENSOR.value: MultiCardTensorHandler,
     Streams.WATCHPOINT.value: WatchpointHandler,
-    Streams.WATCHPOINT_HIT.value: WatchpointHitHandler
+    Streams.WATCHPOINT_HIT.value: MultiCardWatchpointHitHandler,
+    Streams.DEVICE.value: DeviceHandler
 }


@@ -40,10 +42,8 @@ class DebuggerCache:
     def initialize(self):
         """Initialize the stream handlers."""
         self._stream_handler = {}
-        for stream in Streams:
-            mode = stream.value
-            stream_handler = STREAM_HANDLER_MAP.get(mode)
-            self._stream_handler[mode] = stream_handler()
+        for mode, stream_class in STREAM_HANDLER_MAP.items():
+            self._stream_handler[mode] = stream_class()

     def clean(self):
         """Clean cache for all stream."""

+ 41
- 0
mindinsight/debugger/debugger_folder_analyzer.py View File

@@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Debugger train job register."""

import os

from mindinsight.utils.folder_analyzer import FolderAnalyzer
from mindinsight.datavisual.common.log import logger

class DebuggerFolderAnalyzer(FolderAnalyzer):
"""Debugger train job register."""
def analyze(self, entry, summary_base_dir, relative_path):
"""Check dir by debugger register."""
update_info = {}
if entry.is_dir():
sub_relative_path = os.path.join(relative_path, entry.name)
entry_path = os.path.join(summary_base_dir, sub_relative_path)
try:
subdir_entries = os.scandir(entry_path)
except PermissionError:
logger.warning('Path of %s under summary base directory is not accessible.', entry.name)
return update_info
subdir_entries = [subdir_entry for subdir_entry in subdir_entries if not subdir_entry.is_symlink()]
subdir_entries = sorted(subdir_entries, key=lambda x: x.stat().st_mtime)
for subdir_entry in subdir_entries:
if subdir_entry.is_dir() and subdir_entry.name.startswith(".metadata"):
update_info = {'dump_dir': sub_relative_path}
return update_info
return update_info
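A minimal sketch of driving the analyzer by hand (the base directory is hypothetical; in MindInsight it is invoked by the summary watcher):

# Hedged sketch: scan a summary base directory for offline debugger dumps.
import os

from mindinsight.debugger.debugger_folder_analyzer import DebuggerFolderAnalyzer

analyzer = DebuggerFolderAnalyzer()
summary_base_dir = '/tmp/summary'  # hypothetical
with os.scandir(summary_base_dir) as entries:
    for entry in entries:
        update_info = analyzer.analyze(entry, summary_base_dir, './')
        if update_info:
            print('offline debugger dump found:', update_info['dump_dir'])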

+15 -0  mindinsight/debugger/debugger_services/__init__.py

@@ -0,0 +1,15 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Debugger server module."""

mindinsight/debugger/debugger_grpc_server.py → mindinsight/debugger/debugger_services/debugger_grpc_server.py

@@ -19,7 +19,7 @@ from functools import wraps
import mindinsight
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
-    Streams, RunLevel
+    Streams, RunLevel, version_match
from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum, ParamNameEnum
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto
@@ -117,9 +117,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
        # clean cache data at the beginning of new step or node has been changed.
        if is_new_step or is_new_node:
            self._cache_store.clean_data()
-            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
+            self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).clean_tensors(
+                request.cur_step)
        if is_new_step:
-            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
+            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(0).clean()
        # receive graph at the beginning of the training
        if self._status == ServerStatus.RECEIVE_GRAPH:
            self._send_graph_flag(metadata_stream)
@@ -141,7 +142,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
        self._status = ServerStatus.WAITING
        metadata_stream.state = ServerStatus.WAITING.value
        metadata = metadata_stream.get()
-        res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
+        res = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).get()
        res.update(metadata)
        self._cache_store.put_data(res)
        log.debug("Put graph into data queue.")
@@ -157,7 +158,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
        # put new metadata into cache
        metadata_stream.put(metadata_proto)
        # update current node name and graph name
-        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
+        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0)
        full_name = metadata_proto.cur_node
        graph_name = graph_stream.get_graph_id_by_full_name(
            full_name) if full_name else metadata_stream.graph_name
@@ -182,7 +183,8 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):


    def _send_watchpoint_hit_flag(self):
        """Send Watchpoint hit flag."""
-        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
+        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(
+            0)
        if not self._received_hit:
            return
        watchpoint_hits = self._received_hit
@@ -344,7 +346,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
            run_cmd.node_name = ''
        # clean watchpoint hit cache
        if run_cmd.run_level == RunLevel.RECHECK.value:
-            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
+            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get_hit_handler_by_rank_id(0).clean()
            log.debug("Receive RunCMD. Clean watchpoint hit cache.")
        # update metadata state from sending to running
        metadata_stream.state = ServerStatus.RUNNING.value
@@ -365,8 +367,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
log.info("The training from %s has finished.", client_ip) log.info("The training from %s has finished.", client_ip)
else: else:
ms_version = request.ms_version ms_version = request.ms_version
if not ms_version:
ms_version = '1.0.x'
if version_match(ms_version, mindinsight.__version__) is False: if version_match(ms_version, mindinsight.__version__) is False:
log.info("Version is mismatched, mindspore is: %s, mindinsight is: %s", log.info("Version is mismatched, mindspore is: %s, mindinsight is: %s",
ms_version, mindinsight.__version__) ms_version, mindinsight.__version__)
@@ -403,8 +403,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
        graph = GraphProto.FromString(serial_graph)
        log.debug("Deserialize the graph %s. Receive %s nodes", graph.name, len(graph.node))
        graph_dict = {graph.name: graph}
-        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
-        self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
+        self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).put(graph_dict)
+        self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals(
+            graph.const_vals)
        self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name
        self._record_parameter_names()
        self._status = ServerStatus.RECEIVE_GRAPH
@@ -429,10 +430,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
log.debug("Deserialize the graph %s. Receive %s nodes", sub_graph.name, log.debug("Deserialize the graph %s. Receive %s nodes", sub_graph.name,
len(sub_graph.node)) len(sub_graph.node))
serial_graph = b"" serial_graph = b""
self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(
self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals(
sub_graph.const_vals) sub_graph.const_vals)


self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).put(graph_dict)
self._record_parameter_names() self._record_parameter_names()
self._status = ServerStatus.RECEIVE_GRAPH self._status = ServerStatus.RECEIVE_GRAPH
log.debug("Send the reply for graph.") log.debug("Send the reply for graph.")
@@ -440,9 +441,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):


    def _record_parameter_names(self):
        """Record parameter full names in tensor handler."""
-        parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).search_in_graph(
-            pattern={'node_category': TargetTypeEnum.PARAMETER.value})
-        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
+        parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0)\
+            .search_in_graph(pattern={'node_category': TargetTypeEnum.PARAMETER.value})
+        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0)
        for node in parameter_nodes:
            tensor_name = [node.full_name + ':0']
            tensor_stream.record_parameter_names(tensor_name)
@@ -452,7 +453,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
"""Send tensors into DebuggerCache.""" """Send tensors into DebuggerCache."""
log.info("Received tensor.") log.info("Received tensor.")
tensor_contents = [] tensor_contents = []
tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0)
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
step = metadata_stream.step step = metadata_stream.step
for tensor in request_iterator: for tensor in request_iterator:
@@ -482,7 +483,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
        # save the watchpoint_hits data
        watchpoint_hits = []
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
-        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
+        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0)
        for watchpoint_hit_proto in request_iterator:
            node_full_name = watchpoint_hit_proto.tensor.node_name
            graph_name = graph_stream.get_graph_id_by_full_name(node_full_name)
@@ -517,10 +518,3 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
        self._received_hit = watchpoint_hits
        reply = get_ack_reply()
        return reply
-
-
-def version_match(mi_version, ms_version):
-    """Judge if the version of Mindinsight and Mindspore is matched"""
-    mi_major, mi_minor = mi_version.split('.')[:2]
-    ms_major, ms_minor = ms_version.split('.')[:2]
-    return mi_major == ms_major and mi_minor == ms_minor
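`version_match` itself moved into `mindinsight.debugger.common.utils` (see the import change at the top of this file); only the major and minor components have to agree. A quick sketch of the rule:

# Hedged sketch of the matching rule: the patch level is ignored.
from mindinsight.debugger.common.utils import version_match

assert version_match('1.3.0', '1.3.1')      # same major.minor
assert not version_match('1.3.0', '1.2.0')  # minor differs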

+613 -0  mindinsight/debugger/debugger_services/debugger_offline_server.py

@@ -0,0 +1,613 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Debugger Offline server."""
import copy
from collections import defaultdict
from importlib import import_module
from threading import Event
from multiprocessing import Process, Manager

import mindinsight
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.debugger.common.exceptions.exceptions import DebuggerModuleNotFoundError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import Streams, ServerStatus, version_match, DebuggerServerMode, get_ack_reply, \
RunLevel
from mindinsight.debugger.conditionmgr.condition import ParamNameEnum
from mindinsight.debugger.debugger_services.debugger_server_base import DebuggerServerBase, debugger_server_wrap
from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply
from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto
from mindinsight.debugger.stream_cache.data_loader import DataLoader
from mindinsight.utils.exceptions import MindInsightException


class DebuggerOfflineServer(DebuggerServerBase):
"""Debugger Offline Server."""
_MAX_TRY_EXCEPT_COUNT = 500

def __init__(self, cache_store, context):
super(DebuggerOfflineServer, self).__init__(cache_store, context)
self._offline_server_manager = DebuggerOfflineManager(cache_store, context.dbg_dir)
self._running = Event()
self._running.clear()

def run(self):
"""Start the debugger offline server."""
log.info("Initialize Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir)
self._offline_server_manager.initialize()
self._running.set()
log.info("Start Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir)
try_count = 0
while self._running.is_set() and try_count < self._MAX_TRY_EXCEPT_COUNT:
try:
self._offline_server_manager.wait_for_termination()
if not self._offline_server_manager.is_runnable():
break
except MindInsightException as err:
log.exception(err)
log.warning("Error happens during listening on user commands. Restart listening again.")
finally:
try_count += 1
# protect server from too much failure commands.
if try_count == self._MAX_TRY_EXCEPT_COUNT:
self._cache_store.clean()
metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
self._cache_store.put_data(metadata)
log.warning("Exception exceed %d times, stop server.", try_count)

def stop(self):
"""Stop offline debugger server."""
log.debug("Start to wait for thread started.")
self._running.wait()
log.info("Start to stop offline debugger server.")
self._running.clear()
self._offline_server_manager.stop()
self.join()


class DebuggerOfflineManager:
"""Debugger offline manager which is used to handle user commands."""

def __init__(self, cache_store, dbg_dir):
cache_store.initialize()
self._cache_store = cache_store
self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA)

self._dbg_dir = dbg_dir
self._dbg_services_module = self._get_dbg_service_module()
self._dbg_service = None

self._command_listener = CommandListener(cache_store)
self._data_loader = DataLoader(dbg_dir)
self._is_running_flag = False
self._old_run_cmd = {}

def stop(self):
"""Stop server."""
self._is_running_flag = False
self._command_listener.stop()
self._cache_store.clean()
event = get_ack_reply()
event.exit = True
self._cache_store.put_command(event)
log.info("Stop debugger offline manager.")

def is_runnable(self):
"""Check if the offline manager is runnable."""
state = self._metadata_stream.state
flag = self._is_running_flag and state not in [ServerStatus.MISMATCH.value, ServerStatus.PENDING.value]
if not flag:
log.debug("The offline manager is not runnable, is_running_flag: %s, metadata state: %s",
self._is_running_flag, state)
return flag

@staticmethod
def _get_dbg_service_module():
"""Get dbg service module from MindSpore."""
try:
dbg_services_module = import_module('mindspore.offline_debug.dbg_services')
except ModuleNotFoundError as err:
log.error("Failed to find module dbg_services. %s", err)
raise DebuggerModuleNotFoundError("dbg_services")
return dbg_services_module

@debugger_server_wrap
def initialize(self):
"""Start to load offline debugger data."""
self._data_loader.initialize()
is_sync = self._data_loader.get_sync_flag()
net_name = self._data_loader.get_net_name()
net_dir = self._data_loader.get_net_dir()
self._dbg_service = self._dbg_services_module.DbgServices(net_dir)
self._dbg_service.initialize(net_name=net_name, is_sync_mode=is_sync)
self._cache_store.clean()
self._command_listener.start()
self._is_running_flag = True
self._check_version()
if self._metadata_stream.state == ServerStatus.MISMATCH.value:
log.info("The MindSpore and MindInsight version are mismatched. Failed to initialize offline server.")
return
self._load_metadata()
self._load_graphs()
log.info("Success initialize offline server for %s", self._dbg_dir)

def _check_version(self):
"""Check version."""
ms_version = self._dbg_services_module.get_version()
mi_version = mindinsight.__version__
self._metadata_stream.debugger_version = {'ms': ms_version, 'mi': mi_version}
if version_match(ms_version, mi_version) is False:
log.info("Version is mismatched, dbg_services is: %s, mindinsight is: %s",
ms_version, mi_version)
self._metadata_stream.state = ServerStatus.MISMATCH.value
metadata = self._metadata_stream.get(['state', 'debugger_version'])
self._cache_store.put_data(metadata)

def _load_metadata(self):
"""Load metadata."""
self._metadata_stream.debugger_type = DebuggerServerMode.OFFLINE.value
device_info = self._data_loader.load_device_info()
        # The backend refers to the running environment on which the offline
        # debugger data was generated. Currently supported options: `GPU`, `Ascend`.
backend = device_info.get('device_target', 'Ascend')
self._metadata_stream.backend = backend
device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)
device_stream.put(device_info.get('server_list'))
rank_id = 0
rank_0_info = device_stream.get(rank_id)['devices'][0]
self._metadata_stream.client_ip = rank_0_info.get('server_id')
        # get step number per device: dict(device_id, step_num); it may increase as training proceeds
step_num_per_device = self._data_loader.load_step_number()
device_stream.add_step_num_info(step_num_per_device)
self._metadata_stream.max_step_num = max(step_num_per_device.values())

def _load_graphs(self):
"""Load graphs."""
        # the format of graphs is a list of {'device_id': int, 'graph_protos': [GraphProto]}
graphs = self._data_loader.load_graphs()
device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)
graph_per_rank = {}
for graph in graphs:
device_id = int(graph.get('device_id'))
rank_id = device_stream.get_rank_id_by_device_id(device_id)
graph_per_rank[rank_id] = {}
tensor_stream_per_rank = self._cache_store.get_stream_handler(Streams.TENSOR).\
get_tensor_handler_by_rank_id(rank_id, create_if_not_exit=True)
for graph_proto in graph.get('graph_protos'):
graph_per_rank[rank_id][graph_proto.name] = graph_proto
tensor_stream_per_rank.put_const_vals(graph_proto.const_vals)
        # graph_per_rank is formatted like: Dict[<rank_id>, Dict[<graph_name>, <GraphProto>]]
self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_per_rank)
device_stream.add_graph_name_info(graph_per_rank)
self._metadata_stream.state = ServerStatus.RECEIVE_GRAPH.value

@debugger_server_wrap
def wait_for_termination(self):
"""Begin to listen on command event."""
log.info("Begin to listen for user commands.")
self._send_graph()
while self.is_runnable():
if not self._command_listener.has_new_command() and self._old_run_cmd:
self._deal_with_old_run_cmd()
continue
cmd = self._command_listener.get_next_command()
self.deal_with_cmd(cmd)

def _send_graph(self):
"""Put graph and metadata info into data queue."""
if not self.is_runnable():
return
self._metadata_stream.state = ServerStatus.WAITING.value
metadata = self._metadata_stream.get()
res = self._cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(0).get()
res.update(metadata)
self._cache_store.put_data(res)

def _deal_with_old_run_cmd(self):
"""Deal with old run command."""
left_step_count = self._old_run_cmd.get('left_step_count')
if left_step_count:
self._execute_one_step()
# if old_run_cmd is not cleared due to hit.
if self._old_run_cmd:
self._old_run_cmd['left_step_count'] = left_step_count - 1 if left_step_count > 0 else -1
if not self._old_run_cmd.get('left_step_count'):
self._old_run_cmd.clear()

def deal_with_cmd(self, cmd):
"""Deal with command."""
if cmd is None:
return
if isinstance(cmd, dict):
self._deal_with_view_cmd(cmd)
elif isinstance(cmd, EventReply):
self._on_event(cmd)

def _on_event(self, event):
"""
Deal with different command event.

Args:
event (EventReply): Command Event.
"""
if event.HasField('run_cmd'):
self._deal_with_run_cmd(event)
elif event.HasField('exit'):
self._cache_store.clean()
self._update_state(ServerStatus.PENDING)
log.debug("Clean cache for exit cmd.")
else:
self._deal_with_set_cmd(event)
log.debug("Deal with set cmd.")

def _deal_with_view_cmd(self, event):
"""
Deal with view cmd.

Args:
event (dict): View command params.

- view_cmd (EventReply): EventReply with view command.
- node_name (str): The center node name for view command.
- tensor_name (str): The center tensor name for view command.
- graph_name (str): The graph name of center node.
- rank_id (int): The device id of the tensor.
"""
        # guard against a missing view_cmd entry before dereferencing it
        view_event = event.pop('view_cmd', None)
        view_cmd = view_event.view_cmd if view_event is not None else None
node_info = event
log.debug("Receive view cmd for node: %s.", event)
if not (view_cmd and node_info):
log.info("Invalid view command. Ignore it.")
return
# read tensor value by dbg_service
rank_id = node_info.get('rank_id', 0)
device_id = self._cache_store.get_stream_handler(Streams.DEVICE).get_device_id_by_rank_id(rank_id)
cur_step = self._metadata_stream.step
tensor_protos = view_cmd.tensors
root_graph_id = self.get_root_graph_id()
tensor_infos = [
self._dbg_services_module.TensorInfo(
node_name=tensor_proto.node_name,
slot=int(tensor_proto.slot),
iteration=cur_step - 1 if tensor_proto.iter == 'prev' else cur_step,
device_id=device_id,
is_parameter=tensor_proto.truncate,
root_graph_id=root_graph_id
) for tensor_proto in tensor_protos]
res = self._dbg_service.read_tensors(tensor_infos)
# put tensor into cache
for tensor_proto, tensor_data in zip(tensor_protos, res):
log.debug("Tensor name: %s:%s, tensor type: %s, tensor size: %s", tensor_proto.node_name, tensor_proto.slot,
tensor_data.dtype, tensor_data.data_size)
tensor_proto.tensor_content = tensor_data.data_ptr
tensor_proto.ClearField('dims')
tensor_proto.dims.extend(tensor_data.shape)
tensor_proto.data_type = tensor_data.dtype
self._put_tensor_value_into_cache(cur_step, node_info, rank_id, tensor_protos)
log.info("Put tensor value into cache.")

def get_root_graph_id(self):
"""Get root graph id."""
is_sync = self._data_loader.get_sync_flag()
graph_id = 0 if is_sync else 1
return graph_id

def _put_tensor_value_into_cache(self, cur_step, node_info, rank_id, tensor_protos):
"""Put tensor value into tensor cache."""

tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR). \
get_tensor_handler_by_rank_id(rank_id)
update_data_flag = False
for tensor_proto in tensor_protos:
if not tensor_proto.tensor_content:
log.warning("Tensor %s:%s is empty.",
tensor_proto.node_name, tensor_proto.slot)
try:
has_update = tensor_stream.put({
'step': cur_step,
'tensor_proto': tensor_proto,
'tensor_contents': [tensor_proto.tensor_content]
})
except ValueError as err:
log.warning("Failed to put %s:%s into cache. Ignore it. %s",
tensor_proto.node_name, tensor_proto.slot, str(err))
continue
if has_update:
update_data_flag = True
if update_data_flag:
# send message to frontend
metadata = self._metadata_stream.get(['step', 'state'])
ret = {'receive_tensor': node_info.copy()}
ret.update(metadata)
self._cache_store.put_data(ret)

def _deal_with_run_cmd(self, event):
"""Deal with run cmd."""
run_cmd = event.run_cmd
parsed_run_cmd = self._get_parsed_run_cmd(run_cmd)
if parsed_run_cmd.run_steps > 0:
self._execute_one_step()
elif run_cmd.run_level == RunLevel.RECHECK.value:
log.info("Deal with recheck command.")
self._check_watchpoint(self._metadata_stream.step)

def _execute_one_step(self):
"""Execute on step."""
new_step = self._metadata_stream.step + 1
if new_step > self._metadata_stream.max_step_num:
self._old_run_cmd.clear()
log.info("The server is already at the last step. %s", self._metadata_stream.max_step_num)
return
log.info("Go to next step: %s.", new_step)
self._check_watchpoint(new_step)
self._metadata_stream.step = new_step
self._cache_store.get_stream_handler(Streams.TENSOR).set_step(new_step)
self._cache_store.put_data(self._metadata_stream.get('step'))

def _get_parsed_run_cmd(self, run_cmd):
"""Get parsed run command."""
if run_cmd.run_level == RunLevel.STEP.value:
# receive pause cmd
if not run_cmd.run_steps:
log.debug("Pause training and wait for next command.")
self._old_run_cmd.clear()
# update metadata state from sending to waiting
self._update_state(ServerStatus.WAITING)
return run_cmd
# receive step cmd
left_steps = run_cmd.run_steps - 1
run_cmd.run_steps = 1
if left_steps:
self._old_run_cmd['left_step_count'] = left_steps if left_steps > 0 else -1
elif run_cmd.node_name:
self._old_run_cmd['node_name'] = run_cmd.node_name
run_cmd.node_name = ''
return run_cmd
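    # Hedged illustration (not part of the module): a RunCMD with
    # run_level='step' and run_steps=5 is executed one step at a time;
    # the remainder is parked in _old_run_cmd and drained later by
    # _deal_with_old_run_cmd:
    #     run_steps = 5
    #     left_steps = run_steps - 1        # 4 steps remain
    #     _old_run_cmd == {'left_step_count': 4}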

def _check_watchpoint(self, step):
"""Save watchpoint hits into cache."""
self._update_state(ServerStatus.RUNNING)
# Clean watchpoint_hits in cache
multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
multi_card_hit_streams.clean()
hits = Manager().list()
check_watchpoints_process = Process(target=self._check_watchpoint_work, args=(hits, step,))
check_watchpoints_process.start()
check_watchpoints_process.join()
log.info("finish check watchpoint of %s", step)
if hits:
log.info("Received WatchpointHits. Left run cmd %s change to empty.", self._old_run_cmd)
self._old_run_cmd.clear()
self._update_state(ServerStatus.WAITING)
self._save_watchpoint_hits(hits)

def _save_watchpoint_hits(self, hits):
"""Save watchpoint hits."""
multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
multi_card_graph_streams = self._cache_store.get_stream_handler(Streams.GRAPH)
device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)
watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)

watchpoint_hits = defaultdict(list)
for hit in hits:
log.info("Received hit\n: "
"name:%s, slot:%s, condition:%s, "
"watchpoint_id:%s"
"error_code:%s, device_id:%s",
hit['name'], hit['slot'], hit['condition'],
hit['watchpoint_id'], hit['error_code'], hit['device_id'])
rank_id = device_stream.get_rank_id_by_device_id(hit['device_id'])
watchpoint_hit = {}
self._add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit)
if not watchpoint_hit:
continue
self._add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit)
watchpoint_hit['error_code'] = hit['error_code']
watchpoint_hits[rank_id].append(watchpoint_hit)
# save hit info into cache
multi_card_hit_streams.put(watchpoint_hits)
self._cache_store.put_data({'receive_watchpoint_hits': True})
log.debug("Send the watchpoint hits to DataQueue.")

@staticmethod
def _add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit):
"""Add hit node info."""
graph_stream = multi_card_graph_streams.get_graph_handler_by_rank_id(rank_id)
node_full_name = hit['name']
graph_name = graph_stream.get_graph_id_by_full_name(node_full_name)
if not graph_name:
log.warning("Cannot find node %s in graph. Skip it.", node_full_name)
return
ui_node_name = graph_stream.get_node_name_by_full_name(node_full_name, graph_name)
log.debug("Receive watch point hit: %s:%s", node_full_name, hit['slot'])
if not ui_node_name:
log.info("Not support to show %s on graph.", node_full_name)
return
watchpoint_hit.update({
'tensor_proto': TensorProto(node_name=node_full_name, slot=str(hit['slot'])),
'node_name': ui_node_name,
'graph_name': graph_name
})

@staticmethod
def _add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit):
"""Add watchpoint hit info."""
watchpoint = copy.deepcopy(watchpoint_stream.get_watchpoint_by_id(hit['watchpoint_id']))
hit_params = {}
# get hit actual value
for param in hit['parameters']:
if param['name'] not in (ParamNameEnum.RTOL.value, ParamNameEnum.RANGE_START_INCLUSIVE.value,
ParamNameEnum.RANGE_END_INCLUSIVE.value) \
and hit['error_code'] == 0:
hit_params[param['name']] = param['actual_value']
# update actual value into watchpoint
watchpoint_condition_params = watchpoint.condition['params']
for i, param in enumerate(watchpoint_condition_params):
name = param['name']
if name in hit_params.keys():
watchpoint_condition_params[i]['actual_value'] = hit_params[name]
else:
watchpoint_condition_params[i]['actual_value'] = None

watchpoint_hit['watchpoint'] = watchpoint

def _deal_with_set_cmd(self, event):
"""
Deal with set cmd.

Args:
event (EventReply): User command event including set_cmd.
"""
set_cmd = event.set_cmd
set_cmd_id = set_cmd.id
delete = set_cmd.delete
if not delete:
log.info("Add watchpoint by using dbg_server.")
watch_condition = set_cmd.watch_condition
param_list = []
for param in watch_condition.params:
param_list.append(
self._dbg_services_module.Parameter(param.name, param.disabled, param.value))
watch_nodes = set_cmd.watch_nodes
check_nodes = self._get_check_nodes(watch_nodes)
log.debug("Watchpoint %s, condition: %s, watch nodes: %s",
set_cmd_id, watch_condition.condition, check_nodes)
self._dbg_service.add_watchpoint(set_cmd_id, watch_condition.condition, check_nodes, param_list)
else:
log.info("Remove watchpoint by using dbg_server.")
self._dbg_service.remove_watchpoint(set_cmd_id)

def _get_check_nodes(self, watch_nodes):
"""Get check nodes format"""
check_nodes = {}
device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)
root_graph_id = self.get_root_graph_id()
for watch_node in watch_nodes:
node_name = watch_node.node_name
rank_id = watch_node.rank_id
device_id = device_stream.get_device_id_by_rank_id(rank_id)
if node_name not in check_nodes:
is_parameter = bool(watch_node.node_type == NodeTypeEnum.PARAMETER.value)
check_nodes[node_name] = {
"device_id": [device_id],
"is_parameter": is_parameter,
"root_graph_id": [root_graph_id]
}
else:
check_nodes[node_name]["device_id"].append(device_id)
return check_nodes

def _update_state(self, server_status):
"""
Update state in metadata stream.

Args:
server_status (ServerStatus): The enum value in ServerStatus.
"""
if self._metadata_stream.state != server_status.value:
self._metadata_stream.state = server_status.value
self._cache_store.put_data(self._metadata_stream.get())

def _check_watchpoint_work(self, hits, step):
"""The check WatchPoint function work in another process."""
log.info("Start checking WatchPointHit process.")
res = self._dbg_service.check_watchpoints(step)
for watchpoint_hit in res:
hit_dict = convert_watchpointhit(watchpoint_hit)
hits.append(hit_dict)
log.info("Checking WatchPointHit process is finished.")


class CommandListener:
"""Event listener."""

def __init__(self, cache_store):
self._cache_store = cache_store
self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA)
# the next position of command queue to be queried
self._pos = '0'
self._is_waiting = Event()

def start(self):
"""Start event listener."""
self._pos = '0'
self._is_waiting.set()

def stop(self):
"""Stop event listener."""
# stop waiting for new user commands but can still get old commands.
self._is_waiting.clear()

def has_new_command(self):
"""Check if there is new command in command queue."""
return self._cache_store.has_command(self._pos)

def get_next_command(self):
"""Get next command."""
event = None
while event is None and self.has_new_command():
self._pos, event = self._cache_store.get_command(self._pos)
log.debug("Deal with old %s-th command:\n%s.", self._pos, event)
if event is None:
event = self._wait_for_next_command()
return event

def _wait_for_next_command(self):
"""
Wait for next command.

Returns:
EventReply, the command event.
"""
if not self._is_waiting.is_set():
self._metadata_stream.state = ServerStatus.PENDING.value
return None
log.info("Start to wait for command.")
if self._metadata_stream.state != ServerStatus.WAITING.value:
self._metadata_stream.state = ServerStatus.WAITING.value
self._cache_store.put_data(self._metadata_stream.get())
log.debug("Wait for %s-th command", self._pos)
event = None
while event is None and self._is_waiting.is_set():
self._pos, event = self._cache_store.get_command(self._pos)
return event


def convert_watchpointhit(watchpointhit):
"""Convert watchpointhit object to dict."""
parameters = watchpointhit.parameters
param_list = []
for param in parameters:
param_dict = convert_param(param)
param_list.append(param_dict)
watchpointhit_dict = {'condition': watchpointhit.condition,
'device_id': watchpointhit.device_id,
'error_code': watchpointhit.error_code,
'name': watchpointhit.name,
'parameters': param_list,
'slot': watchpointhit.slot,
'watchpoint_id': watchpointhit.watchpoint_id}
return watchpointhit_dict


def convert_param(param):
"""Convert parameter object to dict"""
param_dict = {'actual_value': param.actual_value,
'disabled': param.disabled,
'hit': param.hit,
'name': param.name,
'value': param.value}
return param_dict
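For reference, the dict shape these converters hand to `_save_watchpoint_hits` looks like this (a sketch; field names come from the functions above, values are made up):

hit = {
    'condition': 6,
    'device_id': 0,
    'error_code': 0,
    'name': 'Default/network/conv1/Conv2D-op1',
    'parameters': [{'actual_value': 0.52, 'disabled': False, 'hit': True,
                    'name': 'abs_mean_gt', 'value': 0.1}],
    'slot': '0',
    'watchpoint_id': 1,
}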

+58 -0  mindinsight/debugger/debugger_services/debugger_online_server.py

@@ -0,0 +1,58 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Debugger Online server."""
from concurrent import futures

import grpc

from mindinsight.conf import settings
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.debugger_services.debugger_grpc_server import DebuggerGrpcServer
from mindinsight.debugger.debugger_services.debugger_server_base import DebuggerServerBase
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base


def get_debugger_hostname():
"""Get hostname for online debugger server."""
grpc_port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else 50051
host = settings.HOST if hasattr(settings, 'HOST') else '[::]'
hostname = "{}:{}".format(host, grpc_port)
return hostname


class DebuggerOnlineServer(DebuggerServerBase):
"""Debugger Online Server."""

def __init__(self, cache_store, context):
super(DebuggerOnlineServer, self).__init__(cache_store, context)
self._grpc_server_manager = self.get_grpc_server_manager()

def run(self):
self._grpc_server_manager.start()
log.info("Start grpc server %s", self._context.hostname)
self._grpc_server_manager.wait_for_termination()

def get_grpc_server_manager(self):
"""Get grpc server instance according to hostname."""
if self._context.hostname is None:
self._context.hostname = get_debugger_hostname()
grpc_server = DebuggerGrpcServer(self._cache_store, None)
grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
grpc_server_base.add_EventListenerServicer_to_server(grpc_server, grpc_server_manager)
grpc_server_manager.add_insecure_port(self._context.hostname)
return grpc_server_manager

def stop(self):
self._grpc_server_manager.stop(grace=None)
self.join()

+58 -0  mindinsight/debugger/debugger_services/debugger_server_base.py

@@ -0,0 +1,58 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DebuggerServerBase."""
import threading
from abc import abstractmethod
from functools import wraps

from mindinsight.debugger.common.exceptions.exceptions import DebuggerServerRunningError
from mindinsight.debugger.common.log import LOGGER as log


def debugger_server_wrap(func):
"""Wrapper for catch exception."""
@wraps(func)
def record_log(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as err:
log.exception(err)
raise DebuggerServerRunningError(str(err))
return record_log


class DebuggerServerBase(threading.Thread):
"""
Debugger Server Base.

Args:
cache_store (DebuggerCacheStore): Cache store for debugger server.
context (DebuggerServerContext): Context for initialize debugger server.
"""

def __init__(self, cache_store, context):
super(DebuggerServerBase, self).__init__()
self._cache_store = cache_store
self._context = context

@abstractmethod
@debugger_server_wrap
def run(self):
"""Function that should be called when thread started."""

@abstractmethod
@debugger_server_wrap
def stop(self):
"""Stop debugger server."""

+92 -0  mindinsight/debugger/debugger_services/debugger_server_factory.py

@@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Debugger server factory."""
import threading

from mindinsight.debugger.common.utils import DebuggerServerMode
from mindinsight.debugger.debugger_services.debugger_offline_server import DebuggerOfflineServer
from mindinsight.debugger.debugger_services.debugger_online_server import DebuggerOnlineServer


class DebuggerServerFactory:
"""Create debugger server according to debugger mode."""

_lock = threading.Lock()
_instance = None

def __init__(self):
self._server_map = {
DebuggerServerMode.ONLINE.value: DebuggerOnlineServer,
DebuggerServerMode.OFFLINE.value: DebuggerOfflineServer
}

def __new__(cls, *args, **kwargs):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance

def get_debugger_server(self, cache_store, context):
"""
Get debugger server according to debugger_context and cache_store.

Args:
cache_store (DebuggerCacheStore): Cache store for debugger server.
context (DebuggerServerContext): Context for initialize debugger server.

Returns:
DebuggerServerBase, Debugger server object.
"""
dbg_server = None
dbg_server_class = self._server_map.get(context.dbg_mode)
if dbg_server_class:
dbg_server = dbg_server_class(cache_store, context)
return dbg_server


class DebuggerServerContext:
"""
Debugger server context.

Args:
        dbg_mode (str): The debugger mode. Optional: `online` or `offline`.
        train_job (str): The relative directory of debugger dump data for one training.
            Used only when dbg_mode is `offline`.
        dbg_dir (str): The base directory of debugger dump data for one training.
            Used only when dbg_mode is `offline`.
        hostname (str): The hostname used for the online debugger server.
            Used only when dbg_mode is `online`.
"""
def __init__(self, dbg_mode, train_job=None, dbg_dir=None, hostname=None):
self._dbg_mode = dbg_mode
self._train_job = train_job
self._dbg_dir = dbg_dir
self.hostname = hostname

@property
def dbg_mode(self):
"""Property of debugger mode."""
return self._dbg_mode

@property
def dbg_dir(self):
"""Property of debugger mode."""
return self._dbg_dir

@property
def train_job(self):
"""The property of train job."""
return self._train_job
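Putting the pieces together, a minimal sketch of building an offline session backend through the factory (paths and job names are hypothetical):

from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_services.debugger_server_factory import \
    DebuggerServerContext, DebuggerServerFactory

cache_store = DebuggerCache()
context = DebuggerServerContext(dbg_mode='offline', train_job='job_1', dbg_dir='/tmp/dump')
server = DebuggerServerFactory().get_debugger_server(cache_store, context)
server.start()  # DebuggerServerBase is a Thread, so this invokes run()
# ... interact with the session ...
server.stop()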

mindinsight/debugger/debugger_server.py → mindinsight/debugger/debugger_session.py

@@ -13,17 +13,8 @@
# limitations under the License.
# ============================================================================
"""Implement the debugger server."""
-import signal
-from concurrent import futures
from functools import wraps
-from threading import Thread
-
-import grpc
-
-from mindinsight.debugger.conditionmgr.condition import ConditionContext
-from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr
-from mindinsight.debugger.conditionmgr.recommender import recommend_watchpoints
-from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.utils.tools import to_float
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
@@ -32,9 +23,11 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import ServerStatus, \
    create_view_event_from_tensor_basic_info, Streams
+from mindinsight.debugger.conditionmgr.condition import ConditionContext
+from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr
+from mindinsight.debugger.conditionmgr.recommender import recommend_watchpoints
from mindinsight.debugger.debugger_cache import DebuggerCache
-from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
-from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
+from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerFactory
from mindinsight.debugger.stream_operator.tensor_detail_info import TensorDetailInfo
from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator
from mindinsight.debugger.stream_operator.watchpoint_operator import WatchpointOperator
@@ -57,25 +50,29 @@ def try_except(func):
    return send_latest_metadata


-class DebuggerServer:
+class DebuggerSession:
    """The server manager of debugger."""

-    def __init__(self):
+    def __init__(self, context):
        self.condition_mgr = ConditionMgr()
        self.cache_store = DebuggerCache()
-        self.grpc_server = DebuggerGrpcServer(self.cache_store, self.condition_mgr)
-        self.grpc_server_manager = None
-        self.back_server = None
+        self.context = context
+        self.back_server = DebuggerServerFactory().get_debugger_server(self.cache_store, context)
+
+    @property
+    def train_job(self):
+        """The property of train job."""
+        return self.context.train_job

-    def get_condition_collections(self, train_id):
+    def get_condition_collections(self, train_id=""):
        """Get default condition_collections"""
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step)
        log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend)
        return self.condition_mgr.get_all_collections(condition_context)

-    def set_recommended_watch_points(self, set_recommended, train_id):
-        """set recommended watch points."""
+    def set_recommended_watch_points(self, set_recommended, train_id=""):
+        """Set recommended watch points."""
        if not isinstance(set_recommended, bool):
            log.error("Bool param should be given for set_recommended")
            raise DebuggerParamValueError("Bool param should be given.")
@@ -97,38 +94,28 @@ class DebuggerServer:
    def _add_recommended_watchpoints(self, condition_context):
        """Add predefined watchpoints."""
        log.debug("Add predefined watchpoints.")
-        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
-        watchpoints = recommend_watchpoints(self.condition_mgr, graph_stream, condition_context)
+        multi_card_graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
+        watchpoints = recommend_watchpoints(self.condition_mgr, multi_card_graph_stream, condition_context)
        watch_point_stream_handler = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
+        device_stream = self.cache_store.get_stream_handler(Streams.DEVICE)
        watch_points_ids = []
        for watchpoint in watchpoints:
            watch_points_id = watch_point_stream_handler.create_watchpoint(
                watch_condition=watchpoint.get_watch_condition_dict(),
                watch_nodes=watchpoint.watch_nodes,
                name=watchpoint.name,
-                condition_mgr=self.condition_mgr
+                condition_mgr=self.condition_mgr,
+                device_amount=device_stream.device_amount
            )
            watch_points_ids.append(watch_points_id)
        return watch_points_ids


    def start(self):
        """Start server."""
-        grpc_port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else 50051
-        host = settings.HOST if hasattr(settings, 'HOST') else '[::]'
-        hostname = "{}:{}".format(host, grpc_port)
-        # initialize a grpc server
-        grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
-        grpc_server_base.add_EventListenerServicer_to_server(self.grpc_server, grpc_server_manager)
-        grpc_server_manager.add_insecure_port(hostname)
-        grpc_server_manager.start()
-        my_server_thread = Thread(target=grpc_server_manager.wait_for_termination)
-        # start grpc server
-        my_server_thread.start()
-        self.back_server = my_server_thread
-        self.grpc_server_manager = grpc_server_manager
+        self.back_server.start()
        # register stop server handler
-        signal.signal(signal.SIGINT, self._stop_handler)
-        log.info("Start grpc server %s", hostname)
+        #signal.signal(signal.SIGINT, self._stop_handler)
+        log.info("Start debugger backend server.")

    def _stop_handler(self, signum, frame):
        """Register stop server handler."""
@@ -139,8 +126,7 @@ class DebuggerServer:
"""Stop debugger server.""" """Stop debugger server."""
log.info("Send terminate info to client.") log.info("Send terminate info to client.")
self.control({'mode': 'terminate'}) self.control({'mode': 'terminate'})
self.grpc_server_manager.stop(grace=None)
self.back_server.join()
self.back_server.stop()
log.info("Stop debugger server.") log.info("Stop debugger server.")


def poll_data(self, pos): def poll_data(self, pos):
@@ -172,6 +158,7 @@ class DebuggerServer:
            - graph_name (str): The graph name.
            - watch_point_id (int): The id of watchpoint. Default: 0.
            - node_category (str): The node_category. Default: None
+            - rank_id (int): The id of rank. Default: 0.

        Returns:
            dict, the searched nodes.
@@ -179,19 +166,20 @@ class DebuggerServer:
log.info("receive search request with filter_condition: %s", filter_condition) log.info("receive search request with filter_condition: %s", filter_condition)
# validate watchpoint id # validate watchpoint id
watch_point_id = filter_condition.pop('watch_point_id', 0) watch_point_id = filter_condition.pop('watch_point_id', 0)
rank_id = filter_condition.pop('rank_id', 0)
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
watchpoint_stream.validate_watchpoint_id(watch_point_id) watchpoint_stream.validate_watchpoint_id(watch_point_id)
# validate and update graph name # validate and update graph name
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name')) graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name'))
filter_condition['graph_name'] = graph_name filter_condition['graph_name'] = graph_name
# get searched graph # get searched graph
graph = graph_stream.search_nodes(filter_condition) graph = graph_stream.search_nodes(filter_condition)
# add watched label to graph # add watched label to graph
watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, graph_name)
watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, graph_name, rank_id)
return graph return graph
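The rank plumbing is visible end to end here; a hedged sketch of calling the session's search entry point with the new filter key (assuming the surrounding method is `search`, with illustrative filter values):

# Hedged sketch; `session` is a DebuggerSession built as in the factory
# example earlier in this diff.
graph = session.search({'name': 'conv', 'watch_point_id': 1, 'rank_id': 0})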


-    def tensor_comparisons(self, name, shape, detail='data', tolerance='0'):
+    def tensor_comparisons(self, name, shape, detail='data', tolerance='0', rank_id=0):
        """
        Get tensor comparisons data for given name, detail, shape and tolerance.


@@ -202,6 +190,7 @@ class DebuggerServer:
            shape (str): Specify concrete dimensions of shape.
            tolerance (str): Specify tolerance of difference between current step tensor and previous
                step tensor. Default value is 0.
+            rank_id (int): The id of rank. Default: 0.

        Raises:
            DebuggerParamValueError, If node type is not parameter or value of detail is not support.
@@ -220,9 +209,10 @@ class DebuggerServer:
        parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
        tolerance = to_float(tolerance, 'tolerance')
-        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
+        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id)
+        cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step
        if node_type == NodeTypeEnum.PARAMETER.value:
-            reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance)
+            reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance, cur_step)
        else:
            raise DebuggerParamValueError(
                "The node type must be parameter, but got {}.".format(node_type))
@@ -270,10 +260,18 @@ class DebuggerServer:
        self.cache_store.clean_data()
        log.info("Clean data queue cache when retrieve all request.")
        result = {}
-        for stream in [Streams.METADATA, Streams.GRAPH]:
+        for stream in [Streams.METADATA, Streams.GRAPH, Streams.DEVICE]:
            sub_res = self.cache_store.get_stream_handler(stream).get()
            result.update(sub_res)
+
+        devices = result['devices']
+        if not devices:
+            graph = result['graph']
+            metadata = result['metadata']
+            device = {'rank_id': 0, 'server_ip': metadata.get('ip', 'localhost'),
+                      'device_id': metadata.get('device_name', ''),
+                      'graph_names': graph.get('graph_names', [])}
+            devices.append(device)
        sub_res = self._hide_parameters_for_ui()
        result.update(sub_res)


@@ -298,7 +296,8 @@ class DebuggerServer:
log.debug("Retrieve node %s.", filter_condition) log.debug("Retrieve node %s.", filter_condition)
# validate node name # validate node name
node_name = filter_condition.get('name') node_name = filter_condition.get('name')
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
rank_id = filter_condition.get('rank_id', 0)
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name')) graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name'))
if node_name: if node_name:
# validate node name # validate node name
@@ -325,24 +324,27 @@ class DebuggerServer:
            dict, reply with graph.
        """
        # validate watch_point_id
+        rank_id = filter_condition.get('rank_id', 0)
        watch_point_id = filter_condition.get('watch_point_id', 0)
        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watch_point_id)
        # get graph
-        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
+        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
        reply = graph_stream.get(filter_condition)
        graph = reply.get('graph')
        # add watched label to graph
-        watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, filter_condition.get('graph_name'))
+        watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, filter_condition.get('graph_name'),
+                                          rank_id)
        return reply

-    def retrieve_tensor_history(self, node_name, graph_name=None):
+    def retrieve_tensor_history(self, node_name, graph_name=None, rank_id=0):
        """
        Retrieve tensor history for leaf node.

        Args:
            node_name (str): The name of leaf node.
            graph_name (str): The graph name. Default: None.
+            rank_id (int): The id of rank. Default: 0.

        Returns:
            dict, the tensor history and metadata.
@@ -352,34 +354,34 @@ class DebuggerServer:
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get(['state', 'step'])
-       res = self._get_tensor_history(node_name, graph_name)
+       res = self._get_tensor_history(node_name, graph_name, rank_id)
        return res

-   def _get_tensor_history(self, node_name, graph_name=None):
+   def _get_tensor_history(self, node_name, graph_name=None, rank_id=0):
        """
        Get tensor history for single node.

        Args:
            node_name (str): The name of leaf node.
            graph_name (str): The graph name. Default: None.
+           rank_id (int): The id of rank. Default: 0.

        Returns:
            dict, the tensor history and metadata.
        """
        # get basic tensor history
-       graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
+       graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
        tensor_history = graph_stream.get_tensor_history(node_name, graph_name)
        # add tensor value for tensor history
-       self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name)
+       self._add_tensor_value_for_tensor_history(tensor_history, node_name, graph_name, rank_id)
        # add hit label for tensor history
-       watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
-       watchpoint_hit_stream.update_tensor_history(tensor_history)
+       self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).update_tensor_history(tensor_history, rank_id)
        # add metadata
        metadata = self.cache_store.get_stream_handler(Streams.METADATA).get(['step'])
        tensor_history.update(metadata)
        return tensor_history

-   def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name):
+   def _add_tensor_value_for_tensor_history(self, tensor_history, node_name, graph_name, rank_id):
        """
        Add tensor value for tensor_history and send ViewCMD if tensor value missed.

@@ -387,48 +389,53 @@ class DebuggerServer:
            tensor_history (list[dict]): A list of tensor info, including name and type.
            node_name (str): The UI node name.
            graph_name (str): The graph name. Default: None.
+           rank_id (int): The id of rank. Default: 0.

        Returns:
            dict, the tensor info.
        """
-       tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
-       missed_tensors = tensor_stream.update_tensor_history(tensor_history)
+       tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id)
+       cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step
+       missed_tensors = tensor_stream.update_tensor_history(tensor_history, cur_step)
        if missed_tensors:
            view_cmd = create_view_event_from_tensor_basic_info(missed_tensors)
-           self.cache_store.put_command({'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name})
+           self.cache_store.put_command(
+               {'view_cmd': view_cmd, 'node_name': node_name, 'graph_name': graph_name, 'rank_id': rank_id})
            log.debug("Send view cmd.")

-   def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False):
+   def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False, rank_id=0):
        """Retrieve the tensor value."""
        log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape)
        self.validate_tensor_param(name, detail)
        # Limit to query max two dimensions for tensor in table view.
        parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR)
-       node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name, graph_name)
+       node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name, graph_name, rank_id)
        reply = self.cache_store.get_stream_handler(Streams.TENSOR).get(
            {'name': tensor_name,
             'node_type': node_type,
             'shape': parsed_shape,
-            'prev': prev}
+            'prev': prev},
+           rank_id
        )
        reply['tensor_value']['name'] = name

        return reply

-   def _get_tensor_name_and_type_by_ui_name(self, name, graph_name=None):
+   def _get_tensor_name_and_type_by_ui_name(self, name, graph_name=None, rank_id=0):
        """
        Get inner tensor name and type by UI name.

        Args:
            name (str): Node name shown in UI.
            graph_name (Union[str, None]): The graph name. Default: None.
+           rank_id (int): The id of rank. Default: 0.

        Returns:
            str, full name of tensor.
            str, node type of tensor.
        """
        node_name, slot = name.rsplit(':', 1)
-       graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
+       graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
        graph_name = graph_name if graph_name else graph_stream.get_graph_id_by_name(node_name)
        node_type = graph_stream.get_node_type(node_name, graph_name)
        full_name = graph_stream.get_full_name(node_name, graph_name)
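The rsplit in the hunk above is what keeps scope separators intact: only the last ':' divides the node name from the output slot. A tiny standalone illustration (the node name is an example value):

    # Standalone illustration of the UI-name split above.
    name = 'Default/network/Conv2D-op1:0'
    node_name, slot = name.rsplit(':', 1)
    print(node_name)  # Default/network/Conv2D-op1
    print(slot)       # 0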
@@ -483,6 +490,7 @@ class DebuggerServer:
                - offset (int): The offset of current page.
                - node_name (str): The retrieved node name.
                - graph_name (str): The retrieved graph name.
+               - rank_id (int): The rank id.

        Returns:
            dict, watch point list or relative graph.
@@ -496,7 +504,13 @@ class DebuggerServer:
log.info("The backend is in pending status.") log.info("The backend is in pending status.")
return metadata_stream.get() return metadata_stream.get()


reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).group_by(group_condition)
rank_id = group_condition.pop('rank_id', 0)
reply = {}
multi_watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
if multi_watchpoint_hit_stream.check_rank_id(rank_id):
watchpoint_hit_stream = multi_watchpoint_hit_stream.get_hit_handler_by_rank_id(rank_id)
reply = watchpoint_hit_stream.group_by(group_condition)

reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable() reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable()
return reply return reply
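The lookup pattern above — check whether the rank exists, then fetch its per-rank handler — is the shape shared by the multi-rank streams this change introduces. A hypothetical standalone sketch (the class and its storage are illustrative, not the PR's actual handler code):

    # Hypothetical sketch of the multi-rank stream shape used above.
    class MultiRankHitStream:
        def __init__(self):
            self._handlers = {}  # rank_id -> per-rank hit handler

        def check_rank_id(self, rank_id):
            """Return True only if hits were received for this rank."""
            return rank_id in self._handlers

        def get_hit_handler_by_rank_id(self, rank_id):
            return self._handlers[rank_id]

    stream = MultiRankHitStream()
    print(stream.check_rank_id(0))  # False -> caller falls back to an empty reply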


@@ -591,40 +605,6 @@ class DebuggerServer:
        training_controller.validate_mode(mode)
        return training_controller.control(mode, params)

-   def retrieve_node_by_bfs(self, node_name, graph_name=None, ascend=False):
-       """
-       Get the graph of the next node according to node_name.
-
-       Args:
-           node_name (str): The name of current chosen leaf node.
-           graph_name (str): The graph name.
-           ascend (bool): If True, traverse the input nodes;
-               If False, traverse the output nodes. Default is True.
-
-       Returns:
-           dict, the next node information.
-       """
-       log.info("Retrieve node <%s> by bfs, `ascend` is :%s",
-                node_name, ascend)
-       reply = {}
-       graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
-       graph_name = graph_stream.validate_graph_name(graph_name)
-       next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
-       # no next node
-       if next_node_name is None:
-           return reply
-       # add graph and tensor history for next node
-       filter_condition = {
-           'name': next_node_name,
-           'graph_name': graph_name,
-           'single_node': True
-       }
-       search_graph = self._get_nodes_info(filter_condition)
-       reply = {'name': next_node_name}
-       reply.update(search_graph)
-
-       return reply

    @try_except
    def recheck(self):
        """
@@ -635,13 +615,14 @@ class DebuggerServer:
""" """
return TrainingControlOperator(self.cache_store).recheck() return TrainingControlOperator(self.cache_store).recheck()


def retrieve_tensor_graph(self, tensor_name, graph_name):
def retrieve_tensor_graph(self, tensor_name, graph_name, rank_id=0):
""" """
Retrieve tensor graph. Retrieve tensor graph.


Args: Args:
tensor_name (str): The tensor name from UI. tensor_name (str): The tensor name from UI.
graph_name (str): The graph name. graph_name (str): The graph name.
rank_id (int): The id of rank. Default: 0.


Returns: Returns:
dict, tensor graph object. dict, tensor graph object.
@@ -650,16 +631,17 @@ class DebuggerServer:
log.error("Failed to get tensor graph the MindSpore is not in waiting state.") log.error("Failed to get tensor graph the MindSpore is not in waiting state.")
raise DebuggerTensorGraphError raise DebuggerTensorGraphError
log.info("Retrieve tensor graph for %s from %s", tensor_name, graph_name) log.info("Retrieve tensor graph for %s from %s", tensor_name, graph_name)
tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name)
tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(tensor_name, graph_name, rank_id)
return tensor_graph_ops return tensor_graph_ops


def retrieve_tensor_hits(self, tensor_name, graph_name):
def retrieve_tensor_hits(self, tensor_name, graph_name, rank_id=0):
""" """
Retrieve tensor hit information. Retrieve tensor hit information.


Args: Args:
tensor_name (str): The tensor name from UI. tensor_name (str): The tensor name from UI.
graph_name (str): The graph name. graph_name (str): The graph name.
rank_id (int): The id of rank. Default: 0.


Returns: Returns:
dict, tensor hit info. dict, tensor hit info.
@@ -668,7 +650,7 @@ class DebuggerServer:
log.error("Failed to get tensor hits as the MindSpore is not in waiting state.") log.error("Failed to get tensor hits as the MindSpore is not in waiting state.")
raise DebuggerTensorHitError raise DebuggerTensorHitError
log.info("Retrieve tensor hits for %s from %s", tensor_name, graph_name) log.info("Retrieve tensor hits for %s from %s", tensor_name, graph_name)
watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name)
watch_points = TensorDetailInfo(self.cache_store).get_tensor_watch_points(tensor_name, graph_name, rank_id)
return {'watch_points': watch_points} return {'watch_points': watch_points}


def _hide_parameters_for_ui(self): def _hide_parameters_for_ui(self):

+3 -0  mindinsight/debugger/proto/debug_grpc.proto

@@ -122,6 +122,9 @@ message WatchCondition {
 message WatchNode {
     string node_name = 1;
     string node_type = 2;
+    string graph_name = 3;
+    int32 rank_id = 4;
+    int32 device_id = 5;
 }
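A hypothetical sketch of filling the three new fields from Python, assuming the regenerated debug_grpc_pb2 module is importable; the values are examples, not taken from the diff:

    # Hypothetical sketch: populating the new WatchNode fields.
    from mindinsight.debugger.proto import debug_grpc_pb2

    node = debug_grpc_pb2.WatchNode(
        node_name='Default/network/Conv2D-op1',
        node_type='Conv2D',
        graph_name='graph_0',  # new field 3
        rank_id=0,             # new field 4
        device_id=0,           # new field 5
    )
    print(node)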


 message WatchpointHit {

+37 -18  mindinsight/debugger/proto/debug_grpc_pb2.py

@@ -2,8 +2,6 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: mindinsight/debugger/proto/debug_grpc.proto

-import sys
-_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
@@ -21,7 +19,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
  package='debugger',
  syntax='proto3',
  serialized_options=None,
serialized_pb=_b('\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"\x92\x01\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\x12\x11\n\tgraph_num\x18\x06 \x01(\x05\x12\x12\n\nms_version\x18\x07 \x01(\t\")\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\"\x87\x02\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\x12\x19\n\x0fversion_matched\x18\x06 \x01(\x08H\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\x81\x04\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x32\n\x06params\x18\x04 \x03(\x0b\x32\".debugger.WatchCondition.Parameter\x1a]\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64isabled\x18\x02 \x01(\x08\x12\r\n\x05value\x18\x03 \x01(\x01\x12\x0b\n\x03hit\x18\x04 \x01(\x08\x12\x14\n\x0c\x61\x63tual_value\x18\x05 \x01(\x01\"\x95\x02\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x0c\n\x08overflow\x10\x02\x12\t\n\x05sd_gt\x10\x0b\x12\t\n\x05sd_lt\x10\x0c\x12\x1b\n\x17tensor_general_overflow\x10\r\x12\x19\n\x15tensor_initialization\x10\x0e\x12\x14\n\x10tensor_too_large\x10\x0f\x12\x14\n\x10tensor_too_small\x10\x10\x12\x13\n\x0ftensor_all_zero\x10\x11\x12\x1b\n\x17tensor_change_too_large\x10\x12\x12\x1b\n\x17tensor_change_too_small\x10\x13\x12\x16\n\x12tensor_not_changed\x10\x14\x12\x10\n\x0ctensor_range\x10\x15\"1\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\"\x89\x01\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x12\x12\n\nerror_code\x18\x04 \x01(\x05\x32\x81\x03\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x12<\n\x0fSendMultiGraphs\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3')
serialized_pb=b'\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"\x92\x01\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\x12\x11\n\tgraph_num\x18\x06 \x01(\x05\x12\x12\n\nms_version\x18\x07 \x01(\t\")\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\"\x87\x02\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\x12\x19\n\x0fversion_matched\x18\x06 \x01(\x08H\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\x81\x04\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x32\n\x06params\x18\x04 \x03(\x0b\x32\".debugger.WatchCondition.Parameter\x1a]\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64isabled\x18\x02 \x01(\x08\x12\r\n\x05value\x18\x03 \x01(\x01\x12\x0b\n\x03hit\x18\x04 \x01(\x08\x12\x14\n\x0c\x61\x63tual_value\x18\x05 \x01(\x01\"\x95\x02\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x0c\n\x08overflow\x10\x02\x12\t\n\x05sd_gt\x10\x0b\x12\t\n\x05sd_lt\x10\x0c\x12\x1b\n\x17tensor_general_overflow\x10\r\x12\x19\n\x15tensor_initialization\x10\x0e\x12\x14\n\x10tensor_too_large\x10\x0f\x12\x14\n\x10tensor_too_small\x10\x10\x12\x13\n\x0ftensor_all_zero\x10\x11\x12\x1b\n\x17tensor_change_too_large\x10\x12\x12\x1b\n\x17tensor_change_too_small\x10\x13\x12\x16\n\x12tensor_not_changed\x10\x14\x12\x10\n\x0ctensor_range\x10\x15\"i\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\x12\x12\n\ngraph_name\x18\x03 \x01(\t\x12\x0f\n\x07rank_id\x18\x04 \x01(\x05\x12\x11\n\tdevice_id\x18\x05 \x01(\x05\"\x89\x01\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x12\x12\n\nerror_code\x18\x04 \x01(\x05\x32\x81\x03\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x12<\n\x0fSendMultiGraphs\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3'
  ,
  dependencies=[mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.DESCRIPTOR,])

@@ -130,7 +128,7 @@ _METADATA = _descriptor.Descriptor(
  _descriptor.FieldDescriptor(
    name='device_name', full_name='debugger.Metadata.device_name', index=0,
    number=1, type=9, cpp_type=9, label=1,
-   has_default_value=False, default_value=_b("").decode('utf-8'),
+   has_default_value=False, default_value=b"".decode('utf-8'),
    message_type=None, enum_type=None, containing_type=None,
    is_extension=False, extension_scope=None,
    serialized_options=None, file=DESCRIPTOR),
@@ -144,14 +142,14 @@ _METADATA = _descriptor.Descriptor(
  _descriptor.FieldDescriptor(
    name='backend', full_name='debugger.Metadata.backend', index=2,
    number=3, type=9, cpp_type=9, label=1,
-   has_default_value=False, default_value=_b("").decode('utf-8'),
+   has_default_value=False, default_value=b"".decode('utf-8'),
    message_type=None, enum_type=None, containing_type=None,
    is_extension=False, extension_scope=None,
    serialized_options=None, file=DESCRIPTOR),
  _descriptor.FieldDescriptor(
    name='cur_node', full_name='debugger.Metadata.cur_node', index=3,
    number=4, type=9, cpp_type=9, label=1,
-   has_default_value=False, default_value=_b("").decode('utf-8'),
+   has_default_value=False, default_value=b"".decode('utf-8'),
    message_type=None, enum_type=None, containing_type=None,
    is_extension=False, extension_scope=None,
    serialized_options=None, file=DESCRIPTOR),
@@ -172,7 +170,7 @@ _METADATA = _descriptor.Descriptor(
  _descriptor.FieldDescriptor(
    name='ms_version', full_name='debugger.Metadata.ms_version', index=6,
    number=7, type=9, cpp_type=9, label=1,
-   has_default_value=False, default_value=_b("").decode('utf-8'),
+   has_default_value=False, default_value=b"".decode('utf-8'),
    message_type=None, enum_type=None, containing_type=None,
    is_extension=False, extension_scope=None,
    serialized_options=None, file=DESCRIPTOR),
@@ -203,7 +201,7 @@ _CHUNK = _descriptor.Descriptor(
  _descriptor.FieldDescriptor(
    name='buffer', full_name='debugger.Chunk.buffer', index=0,
    number=1, type=12, cpp_type=9, label=1,
-   has_default_value=False, default_value=_b(""),
+   has_default_value=False, default_value=b"",
    message_type=None, enum_type=None, containing_type=None,
    is_extension=False, extension_scope=None,
    serialized_options=None, file=DESCRIPTOR),
@@ -311,7 +309,7 @@ _RUNCMD = _descriptor.Descriptor(
  _descriptor.FieldDescriptor(
    name='run_level', full_name='debugger.RunCMD.run_level', index=0,
    number=1, type=9, cpp_type=9, label=1,
-   has_default_value=False, default_value=_b("").decode('utf-8'),
+   has_default_value=False, default_value=b"".decode('utf-8'),
    message_type=None, enum_type=None, containing_type=None,
    is_extension=False, extension_scope=None,
    serialized_options=None, file=DESCRIPTOR),
@@ -325,7 +323,7 @@ _RUNCMD = _descriptor.Descriptor(
  _descriptor.FieldDescriptor(
    name='node_name', full_name='debugger.RunCMD.node_name', index=2,
    number=3, type=9, cpp_type=9, label=1,
-   has_default_value=False, default_value=_b("").decode('utf-8'),
+   has_default_value=False, default_value=b"".decode('utf-8'),
    message_type=None, enum_type=None, containing_type=None,
    is_extension=False, extension_scope=None,
    serialized_options=None, file=DESCRIPTOR),
@@ -442,7 +440,7 @@ _WATCHCONDITION_PARAMETER = _descriptor.Descriptor(
  _descriptor.FieldDescriptor(
    name='name', full_name='debugger.WatchCondition.Parameter.name', index=0,
    number=1, type=9, cpp_type=9, label=1,
-   has_default_value=False, default_value=_b("").decode('utf-8'),
+   has_default_value=False, default_value=b"".decode('utf-8'),
    message_type=None, enum_type=None, containing_type=None,
    is_extension=False, extension_scope=None,
    serialized_options=None, file=DESCRIPTOR),
@@ -546,14 +544,35 @@ _WATCHNODE = _descriptor.Descriptor(
  _descriptor.FieldDescriptor(
    name='node_name', full_name='debugger.WatchNode.node_name', index=0,
    number=1, type=9, cpp_type=9, label=1,
-   has_default_value=False, default_value=_b("").decode('utf-8'),
+   has_default_value=False, default_value=b"".decode('utf-8'),
    message_type=None, enum_type=None, containing_type=None,
    is_extension=False, extension_scope=None,
    serialized_options=None, file=DESCRIPTOR),
  _descriptor.FieldDescriptor(
    name='node_type', full_name='debugger.WatchNode.node_type', index=1,
    number=2, type=9, cpp_type=9, label=1,
-   has_default_value=False, default_value=_b("").decode('utf-8'),
+   has_default_value=False, default_value=b"".decode('utf-8'),
+   message_type=None, enum_type=None, containing_type=None,
+   is_extension=False, extension_scope=None,
+   serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+   name='graph_name', full_name='debugger.WatchNode.graph_name', index=2,
+   number=3, type=9, cpp_type=9, label=1,
+   has_default_value=False, default_value=b"".decode('utf-8'),
+   message_type=None, enum_type=None, containing_type=None,
+   is_extension=False, extension_scope=None,
+   serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+   name='rank_id', full_name='debugger.WatchNode.rank_id', index=3,
+   number=4, type=5, cpp_type=1, label=1,
+   has_default_value=False, default_value=0,
+   message_type=None, enum_type=None, containing_type=None,
+   is_extension=False, extension_scope=None,
+   serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+   name='device_id', full_name='debugger.WatchNode.device_id', index=4,
+   number=5, type=5, cpp_type=1, label=1,
+   has_default_value=False, default_value=0,
    message_type=None, enum_type=None, containing_type=None,
    is_extension=False, extension_scope=None,
    serialized_options=None, file=DESCRIPTOR),
@@ -570,7 +589,7 @@ _WATCHNODE = _descriptor.Descriptor(
  oneofs=[
  ],
  serialized_start=1335,
- serialized_end=1384,
+ serialized_end=1440,
)

@@ -621,8 +640,8 @@ _WATCHPOINTHIT = _descriptor.Descriptor(
  extension_ranges=[],
  oneofs=[
  ],
- serialized_start=1387,
- serialized_end=1524,
+ serialized_start=1443,
+ serialized_end=1580,
)

_EVENTREPLY.fields_by_name['status'].enum_type = _EVENTREPLY_STATUS
@@ -750,8 +769,8 @@ _EVENTLISTENER = _descriptor.ServiceDescriptor(
  file=DESCRIPTOR,
  index=0,
  serialized_options=None,
- serialized_start=1527,
- serialized_end=1912,
+ serialized_start=1583,
+ serialized_end=1968,
  methods=[
  _descriptor.MethodDescriptor(
      name='WaitCMD',


+15 -22  mindinsight/debugger/proto/debug_grpc_pb2_grpc.py

@@ -1,5 +1,4 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
-"""Client and server classes corresponding to protobuf-defined services."""
import grpc

from mindinsight.debugger.proto import debug_grpc_pb2 as mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2
@@ -7,7 +6,7 @@ from mindinsight.debugger.proto import ms_graph_pb2 as mindinsight_dot_debugger_




class EventListenerStub(object):
-   """Missing associated documentation comment in .proto file."""
+   """Missing associated documentation comment in .proto file"""

    def __init__(self, channel):
        """Constructor.
@@ -48,40 +47,40 @@ class EventListenerStub(object):




class EventListenerServicer(object):
-   """Missing associated documentation comment in .proto file."""
+   """Missing associated documentation comment in .proto file"""

    def WaitCMD(self, request, context):
-       """Missing associated documentation comment in .proto file."""
+       """Missing associated documentation comment in .proto file"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendMetadata(self, request, context):
-       """Missing associated documentation comment in .proto file."""
+       """Missing associated documentation comment in .proto file"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendGraph(self, request_iterator, context):
-       """Missing associated documentation comment in .proto file."""
+       """Missing associated documentation comment in .proto file"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendTensors(self, request_iterator, context):
-       """Missing associated documentation comment in .proto file."""
+       """Missing associated documentation comment in .proto file"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendWatchpointHits(self, request_iterator, context):
-       """Missing associated documentation comment in .proto file."""
+       """Missing associated documentation comment in .proto file"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def SendMultiGraphs(self, request_iterator, context):
-       """Missing associated documentation comment in .proto file."""
+       """Missing associated documentation comment in .proto file"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')
@@ -127,7 +126,7 @@ def add_EventListenerServicer_to_server(servicer, server):


# This class is part of an EXPERIMENTAL API.
class EventListener(object):
-   """Missing associated documentation comment in .proto file."""
+   """Missing associated documentation comment in .proto file"""

    @staticmethod
    def WaitCMD(request,
@@ -135,7 +134,6 @@ class EventListener(object):
            options=(),
            channel_credentials=None,
            call_credentials=None,
-           insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
@@ -144,7 +142,7 @@ class EventListener(object):
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString,
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
            options, channel_credentials,
-           insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+           call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def SendMetadata(request,
@@ -152,7 +150,6 @@ class EventListener(object):
            options=(),
            channel_credentials=None,
            call_credentials=None,
-           insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
@@ -161,7 +158,7 @@ class EventListener(object):
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString,
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
            options, channel_credentials,
-           insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+           call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def SendGraph(request_iterator,
@@ -169,7 +166,6 @@ class EventListener(object):
            options=(),
            channel_credentials=None,
            call_credentials=None,
-           insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
@@ -178,7 +174,7 @@ class EventListener(object):
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString,
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
            options, channel_credentials,
-           insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+           call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def SendTensors(request_iterator,
@@ -186,7 +182,6 @@ class EventListener(object):
            options=(),
            channel_credentials=None,
            call_credentials=None,
-           insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
@@ -195,7 +190,7 @@ class EventListener(object):
            mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.SerializeToString,
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
            options, channel_credentials,
-           insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+           call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def SendWatchpointHits(request_iterator,
@@ -203,7 +198,6 @@ class EventListener(object):
            options=(),
            channel_credentials=None,
            call_credentials=None,
-           insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
@@ -212,7 +206,7 @@ class EventListener(object):
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.SerializeToString,
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
            options, channel_credentials,
-           insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+           call_credentials, compression, wait_for_ready, timeout, metadata)

    @staticmethod
    def SendMultiGraphs(request_iterator,
@@ -220,7 +214,6 @@ class EventListener(object):
            options=(),
            channel_credentials=None,
            call_credentials=None,
-           insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
@@ -229,4 +222,4 @@ class EventListener(object):
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString,
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
            options, channel_credentials,
-           insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+           call_credentials, compression, wait_for_ready, timeout, metadata)

+3 -0  mindinsight/debugger/proto/ms_graph.proto

@@ -229,6 +229,9 @@ message NodeProto {


   // full name with scope
   optional string full_name = 8;
+
+  // The corresponding source code for this node.
+  optional string source_address = 9;
 }

 // Models

+44 -39  mindinsight/debugger/proto/ms_graph_pb2.py (diff suppressed: file too large)


+172 -0  mindinsight/debugger/session_manager.py

@@ -0,0 +1,172 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement the session manager."""
import os
import threading
from urllib.parse import unquote

import _thread

from mindinsight.conf import settings
from mindinsight.debugger.common.log import LOGGER as logger
from mindinsight.debugger.common.exceptions.exceptions import DebuggerSessionNumOverBoundError, \
DebuggerSessionNotFoundError
from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext
from mindinsight.debugger.debugger_session import DebuggerSession


class SessionManager:
"""The server manager of debugger."""

ONLINE_TYPE = "ONLINE"
MAX_SESSION_NUM = 2
ONLINE_SESSION_ID = "0"
_instance = None
_cls_lock = threading.Lock()

def __init__(self):
self.train_jobs = {}
self.sessions = {}
self.session_id = 1
self.online_session = None
self._lock = threading.Lock()
self._exiting = False
enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False
if enable_debugger:
self.create_session(self.ONLINE_TYPE)

@classmethod
def get_instance(cls):
"""Get the singleton instance."""
with cls._cls_lock:
if cls._instance is None:
cls._instance = SessionManager()
return cls._instance

def exit(self):
    """Stop all sessions when the gunicorn worker process is exiting."""
    with self._lock:
        logger.info("Start to exit sessions.")
        self._exiting = True
        for session in self.sessions.values():
            session.stop()
        if self.online_session is not None:
            self.online_session.stop()
        logger.info("Exited.")

def get_session(self, session_id):
"""
Get the session by session id.

Args:
    session_id (Union[None, str]): The id of the session.

Returns:
DebuggerSession, debugger session object.
"""
with self._lock:
if session_id == self.ONLINE_SESSION_ID and self.online_session is not None:
return self.online_session

if session_id in self.sessions:
return self.sessions.get(session_id)

raise DebuggerSessionNotFoundError("{}".format(session_id))

def create_session(self, session_type, train_job=None):
"""
Create session by the train job info.

Args:
session_type (str): The session type. ONLINE_TYPE reuses the single online session; any other value creates an offline session.
train_job (str): The relative path of the train job under the summary base dir. Default: None.

Returns:
str, session id.
"""
with self._lock:
if self._exiting:
logger.info(
"System is exiting, will terminate the thread.")
_thread.exit()

if session_type == self.ONLINE_TYPE:
if self.online_session is None:
context = DebuggerServerContext(dbg_mode='online')
self.online_session = DebuggerSession(context)
self.online_session.start()
return self.ONLINE_SESSION_ID

if train_job in self.train_jobs:
return self.train_jobs.get(train_job)

self._check_session_num()
summary_base_dir = settings.SUMMARY_BASE_DIR
unquote_path = unquote(train_job, errors='strict')
whole_path = os.path.join(summary_base_dir, unquote_path)
normalized_path = validate_and_normalize_path(whole_path)
context = DebuggerServerContext(dbg_mode='offline', train_job=train_job, dbg_dir=normalized_path)
session = DebuggerSession(context)
session.start()
session_id = str(self.session_id)
self.sessions[session_id] = session
self.train_jobs[train_job] = session_id
self.session_id += 1
return session_id

def delete_session(self, session_id):
"""Delete session by session id."""
with self._lock:
if session_id == self.ONLINE_SESSION_ID:
self.online_session.stop()
self.online_session = None
return

if session_id not in self.sessions:
raise DebuggerSessionNotFoundError("session id {}".format(session_id))

session = self.sessions.get(session_id)
session.stop()
self.sessions.pop(session_id)
self.train_jobs.pop(session.train_job)
return

def get_sessions(self):
"""get all sessions"""
return {"train_jobs": self.train_jobs}

def _check_session_num(self):
"""Check the amount of sessions."""
if len(self.sessions) >= self.MAX_SESSION_NUM:
raise DebuggerSessionNumOverBoundError()


def validate_and_normalize_path(path):
"""Validate and normalize_path"""
if not path:
raise ValueError("The path is invalid!")

path_str = str(path)

if not path_str.startswith("/"):
raise ValueError("The path is invalid!")

try:
normalized_path = os.path.realpath(path)
except ValueError:
raise ValueError("The path is invalid!")

return normalized_path
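A hypothetical usage sketch of the new manager: the 'OFFLINE' type string and the 'alexnet_dump' directory (expected under SUMMARY_BASE_DIR) are example values, not taken from the diff.

    # Hypothetical usage sketch; session type and train job are example values.
    from mindinsight.debugger.session_manager import SessionManager

    manager = SessionManager.get_instance()
    session_id = manager.create_session('OFFLINE', train_job='alexnet_dump')
    try:
        session = manager.get_session(session_id)
        # drive the offline debugger session here ...
    finally:
        manager.delete_session(session_id)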

+210 -0  mindinsight/debugger/stream_cache/data_loader.py

@@ -0,0 +1,210 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""This file is used to define the DataLoader."""
import os
import json
from mindinsight.debugger.proto.ms_graph_pb2 import ModelProto
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.debugger.common.utils import DumpSettings


class DataLoader:
"""The DataLoader object provides interface to load graphs and device information from base_dir."""
def __init__(self, base_dir):
self._debugger_base_dir = base_dir
self._graph_protos = []
self._device_info = {}
self._step_num = {}
# flag for whether the data is from sync dump or async dump, True for sync dump, False for async dump.
self._is_sync = None
self._net_dir = ""
self._net_name = ""
self.initialize()

def initialize(self):
"""Initialize the data_mode and net_dir of DataLoader."""
dump_config_file = os.path.join(self._debugger_base_dir, '.metadata', 'data_dump.json')
with open(dump_config_file, 'r') as load_f:
dump_config = json.load(load_f)
common_settings = dump_config.get(DumpSettings.COMMON_DUMP_SETTINGS.value)
if not common_settings:
raise ParamValueError('common_dump_settings not found in dump_config file.')
self._net_name = common_settings['net_name']
if dump_config.get(DumpSettings.E2E_DUMP_SETTINGS.value) and \
dump_config[DumpSettings.E2E_DUMP_SETTINGS.value]['enable']:
self._is_sync = True
self._net_dir = os.path.join(self._debugger_base_dir, self._net_name)
elif dump_config.get(DumpSettings.ASYNC_DUMP_SETTINGS.value) and \
dump_config[DumpSettings.ASYNC_DUMP_SETTINGS.value]['enable']:
self._is_sync = False
self._net_dir = self._debugger_base_dir
else:
raise ParamValueError('The data must be generated from sync dump or async dump.')

def load_graphs(self):
"""Load graphs from the debugger_base_dir."""
files = os.listdir(self._net_dir)
for file in files:
if not self.is_device_dir(file):
continue
device_id, device_dir = self.get_device_id_and_dir(file)
graphs_dir = os.path.join(device_dir, 'graphs')
if not os.path.isdir(graphs_dir):
    log.debug("Directory '%s' does not exist.", graphs_dir)
self._graph_protos.append({'device_id': device_id, 'graph_protos': []})
continue
graph_protos = get_graph_protos_from_dir(graphs_dir)
self._graph_protos.append({'device_id': device_id, 'graph_protos': graph_protos})

return self._graph_protos

def load_device_info(self):
"""Load device_info from file"""
hccl_json_file = os.path.join(self._debugger_base_dir, '.metadata', 'hccl.json')
if not os.path.isfile(hccl_json_file):
device = []
device_ids = self.get_all_device_id()
device_ids.sort()
for i, device_id in enumerate(device_ids):
rank_id = i
device.append({'device_id': str(device_id), 'rank_id': str(rank_id)})
device_target = 'Ascend'
self._device_info = {'device_target': device_target,
'server_list': [{'server_id': 'localhost', 'device': device}]}
else:
with open(hccl_json_file, 'r') as load_f:
load_dict = json.load(load_f)
self._device_info = {'device_target': 'Ascend', 'server_list': load_dict['server_list']}
return self._device_info

def load_step_number(self):
"""Load step number in the directory"""
files = os.listdir(self._net_dir)
for file in files:
if not self.is_device_dir(file):
continue
device_id, device_dir = self.get_device_id_and_dir(file)
max_step = 0
files_in_device = os.listdir(device_dir)
if self._is_sync:
for file_in_device in files_in_device:
abs_file_in_device = os.path.join(device_dir, file_in_device)
if os.path.isdir(abs_file_in_device) and file_in_device.startswith("iteration_"):
step_id_str = file_in_device.split('_')[-1]
max_step = update_max_step(step_id_str, max_step)
self._step_num[str(device_id)] = max_step
else:
net_graph_dir = []
for file_in_device in files_in_device:
abs_file_in_device = os.path.join(device_dir, file_in_device)
if os.path.isdir(abs_file_in_device) and file_in_device.startswith(self._net_name):
net_graph_dir.append(abs_file_in_device)
if not net_graph_dir:
    self._step_num[str(device_id)] = max_step
    continue
if len(net_graph_dir) > 1:
    log.warning("There is more than one graph directory in device_dir: %s. "
                "OfflineDebugger uses the data in %s.", device_dir, net_graph_dir[0])
net_graph_dir_to_use = net_graph_dir[0]
graph_id = net_graph_dir_to_use.split('_')[-1]
graph_id_dir = os.path.join(net_graph_dir_to_use, graph_id)
step_ids = os.listdir(graph_id_dir)
for step_id_str in step_ids:
max_step = update_max_step(step_id_str, max_step)
self._step_num[str(device_id)] = max_step

return self._step_num

def is_device_dir(self, file_name):
"""Judge if the file_name is a sub directory named 'device_x'."""
if not file_name.startswith("device_"):
return False
id_str = file_name.split("_")[-1]
if not id_str.isdigit():
return False
device_dir = os.path.join(self._net_dir, file_name)
if not os.path.isdir(device_dir):
return False
return True

def get_device_id_and_dir(self, file_name):
"""Get device_id and absolute directory of file_name."""
id_str = file_name.split("_")[-1]
device_id = int(id_str)
device_dir = os.path.join(self._net_dir, file_name)
return device_id, device_dir

def get_all_device_id(self):
"""Get all device_id int the debugger_base_dir"""
device_ids = []
files = os.listdir(self._net_dir)
for file in files:
if not self.is_device_dir(file):
continue
id_str = file.split("_")[-1]
device_id = int(id_str)
device_ids.append(device_id)
return device_ids

def get_net_dir(self):
"""Get graph_name directory of the data."""
return self._net_dir

def get_sync_flag(self):
"""Get the sync flag of the data."""
return self._is_sync

def get_net_name(self):
"""Get net_name of the data."""
return self._net_name


def load_graph_from_file(graph_file_name):
"""Load graph from file."""
with open(graph_file_name, 'rb') as file_handler:
model_bytes = file_handler.read()
model = ModelProto.FromString(model_bytes)
graph = model.graph

return graph


def get_graph_protos_from_dir(graphs_dir):
"""
Get graph protos from a graph directory.

Args:
    graphs_dir (str): The absolute directory of graph files.

Returns:
list, list of 'GraphProto' object.
"""
files_in_graph_dir = os.listdir(graphs_dir)
graph_protos = []
pre_file_name = "ms_output_trace_code_graph_"
for file_in_device in files_in_graph_dir:
if file_in_device.startswith(pre_file_name) and file_in_device.endswith(".pb"):
abs_graph_file = os.path.join(graphs_dir, file_in_device)
graph_proto = load_graph_from_file(abs_graph_file)
graph_protos.append(graph_proto)
return graph_protos


def update_max_step(step_id_str, max_step):
"""Update max_step by compare step_id_str and max_step."""
res = max_step
if step_id_str.isdigit():
step_id = int(step_id_str)
if step_id > max_step:
res = step_id
return res
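A hypothetical usage sketch of the loader; the path is an example value, and the directory must contain the .metadata/data_dump.json described in initialize() above.

    # Hypothetical usage sketch; the dump path is an example value.
    from mindinsight.debugger.stream_cache.data_loader import DataLoader

    loader = DataLoader('/data/dump/alexnet')
    graphs = loader.load_graphs()            # [{'device_id': 0, 'graph_protos': [...]}, ...]
    device_info = loader.load_device_info()  # {'device_target': 'Ascend', 'server_list': [...]}
    step_num = loader.load_step_number()     # {'0': <max step on device 0>, ...}
    print(loader.get_sync_flag(), loader.get_net_name())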

+21 -8  mindinsight/debugger/stream_cache/tensor.py

@@ -14,7 +14,6 @@
# ============================================================================
"""The definition of tensor stream."""
from abc import abstractmethod, ABC
-
import numpy as np

from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
@@ -149,7 +148,10 @@ class OpTensor(BaseTensor):
    @property
    def shape(self):
        """The property of tensor shape."""
-       return list(self._tensor_proto.dims)
+       dims = list(self._tensor_proto.dims)
+       if dims == [0]:
+           dims = []
+       return dims

    @property
    def value(self):
@@ -254,12 +256,13 @@ class OpTensor(BaseTensor):
class ConstTensor(BaseTensor):
    """Tensor data structure for Const Node."""
    _STRING_TYPE = 'DT_STRING'
+   _DT_TYPE = 'DT_TYPE'

    def __init__(self, const_proto):
        # the type of const_proto is NamedValueProto
        super(ConstTensor, self).__init__()
        self._const_proto = const_proto
-       self._value = self.generate_value_from_proto(const_proto)
+       self._value = self.generate_value_from_proto(const_proto.value)

    def set_step(self, step):
        """Set step value."""
@@ -295,16 +298,25 @@ class ConstTensor(BaseTensor):
        Returns:
            Union[None, str, np.ndarray], the value of the tensor.
        """
-       fields = tensor_proto.value.ListFields()
+       fields = tensor_proto.ListFields()
        if len(fields) != 2:
            log.warning("Unexpected const proto <%s>.\n Please check offline.", tensor_proto)
        tensor_value = None
        for field_obj, field_value in fields:
            if field_obj.name != 'dtype':
-               tensor_value = field_value
+               if tensor_proto.dtype == DataType.DT_TUPLE:
+                   tensor_values = []
+                   for field_value_element in field_value:
+                       value_element = self.generate_value_from_proto(field_value_element)
+                       tensor_values.append(value_element)
+                   tensor_value = tensor_values
+               elif tensor_proto.dtype == DataType.DT_TYPE:
+                   tensor_value = DataType.Name(field_value.data_type)
+               else:
+                   tensor_value = field_value
                break
-       if tensor_value is not None and self.dtype != self._STRING_TYPE:
-           tensor_value = np.array(tensor_value, dtype=NUMPY_TYPE_MAP.get(self.dtype))
+       if tensor_value is not None and tensor_proto.dtype != self._STRING_TYPE:
+           tensor_value = np.array(tensor_value, dtype=NUMPY_TYPE_MAP.get(tensor_proto.dtype))
        return tensor_value

    def get_tensor_value_by_shape(self, shape=None):
@@ -328,7 +340,8 @@ class ConstTensor(BaseTensor):
        Returns:
            dict, overall statistics.
        """
-       if self.empty or self.dtype == self._STRING_TYPE:
+       if self.empty or self.dtype == self._STRING_TYPE or self.dtype == self._DT_TYPE:
+           log.debug("The tensor dtype is: %s, skip getting statistics.", self.dtype)
            return {}
        stats = TensorUtils.get_statistics_from_tensor(self.value)
        statistics = TensorUtils.get_overall_statistic_dict(stats)
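The shape hunk earlier in this file normalizes a dims field of [0] to an empty list; reading that change, [0] appears to be how the proto encodes a scalar (an assumption, not stated in the diff), and an empty list matches NumPy's 0-d shape. A standalone sketch:

    # Standalone sketch of the scalar-shape normalization; treating
    # dims == [0] as a scalar is an assumption read off the change above.
    import numpy as np

    def normalize_dims(dims):
        return [] if dims == [0] else dims

    print(normalize_dims([0]))     # [] -> scalar, like np.array(1.0).shape == ()
    print(normalize_dims([2, 3]))  # [2, 3] stays as-is
    print(np.array(1.0).shape)     # ()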


+33 -19  mindinsight/debugger/stream_cache/watchpoint.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -184,7 +184,7 @@ class Watchpoint:
    def __init__(self, watchpoint_id, watch_condition, name=None):
        self._id = watchpoint_id
        self._condition = watch_condition
-        self._watch_node = WatchNodeTree()
+        self._watch_node = {0: WatchNodeTree()}
        self.name = name


    @property
@@ -214,32 +214,36 @@ class Watchpoint:
        else:
            self._watch_node = other_watchpoint.nodes


-    def add_nodes(self, nodes):
+    def add_nodes(self, nodes, rank_id):
        """Add node into watchpoint."""
        if not nodes:
            log.warning("Add empty nodes.")
            return
+        if rank_id not in self._watch_node:
+            self._watch_node[rank_id] = WatchNodeTree()
        if not isinstance(nodes, list):
            nodes = [nodes]
        for node in nodes:
-            self._watch_node.add_node(node.name, node.type, node.full_name)
+            watch_node = self._watch_node.get(rank_id)
+            watch_node.add_node(node.name, node.type, node.full_name)


-    def remove_nodes(self, nodes):
+    def remove_nodes(self, nodes, rank_id):
        """Remove nodes from watchpoint."""
        if not nodes:
            return
+        self.validate_rank_id(rank_id)
        if not isinstance(nodes, list):
            nodes = [nodes]
        for node in nodes:
-            self._watch_node.remove_node(node.name)
+            self._watch_node.get(rank_id).remove_node(node.name)


-    def get_node_status(self, node_name, node_type, full_name):
+    def get_node_status(self, node_name, node_type, full_name, rank_id):
        """Judge if the node is in watch nodes."""
        if is_cst_type(node_type):
            return WatchNodeTree.INVALID
        scope_names = node_name.split('/')
-        cur_node = self._watch_node
+        self.validate_rank_id(rank_id)
+        cur_node = self._watch_node.get(rank_id)
        status = 1
        for scope_name in scope_names:
            cur_node = cur_node.get(scope_name)
@@ -250,7 +254,7 @@ class Watchpoint:
                status = WatchNodeTree.TOTAL_WATCH
                break
        if status == WatchNodeTree.TOTAL_WATCH and cur_node.node_name != node_name:
-            self._watch_node.add_node(node_name, node_type, full_name)
+            self._watch_node.get(rank_id).add_node(node_name, node_type, full_name)


        return status


@@ -278,11 +282,14 @@ class Watchpoint:
        Returns:
            list[NodeBasicInfo], the list of watch node basic infos.
        """
-        watch_nodes = []
-        self._get_watch_node(self._watch_node, watch_nodes)
-        return watch_nodes
+        watch_nodes_for_devices = {}
+        for rank_id, watch_node_tree in self._watch_node.items():
+            watch_nodes = []
+            self._get_watch_node(watch_node_tree, watch_nodes)
+            watch_nodes_for_devices[rank_id] = watch_nodes
+        return watch_nodes_for_devices

-    def get_pending_cmd(self, watch_nodes):
+    def get_pending_cmd(self, watch_nodes_for_devices):
        """Return the watchpoint in proto format."""
        # construct SetCMD
        condition_id = self._condition.get('id')
@@ -309,10 +316,12 @@ class Watchpoint:
            param_proto.name = param_name
            param_proto.disabled = True

-        for watch_node in watch_nodes:
-            event_node = set_cmd.watch_nodes.add()
-            event_node.node_name = watch_node.full_name
-            event_node.node_type = watch_node.type
+        for rank_id, watch_nodes in watch_nodes_for_devices.items():
+            for watch_node in watch_nodes:
+                event_node = set_cmd.watch_nodes.add()
+                event_node.node_name = watch_node.full_name
+                event_node.node_type = watch_node.type
+                event_node.rank_id = rank_id
        return set_cmd


    def get_watch_condition_info(self):
@@ -325,6 +334,11 @@ class Watchpoint:
        watchpoint_info['name'] = self.name
        return watchpoint_info


+    def validate_rank_id(self, rank_id):
+        """Check whether the given rank id has watch nodes."""
+        if rank_id not in self._watch_node:
+            log.warning("Rank id %s does not exist.", rank_id)
+            return
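With `_watch_node` now a dict keyed by rank id, one logical watchpoint fans out to a separate watch-node tree per device. A standalone sketch of that dict-of-trees shape (toy classes, not the real `WatchNodeTree`/`Watchpoint`):

```python
from collections import defaultdict

class ToyNodeTree:
    """Stand-in for WatchNodeTree: just records watched node names."""
    def __init__(self):
        self.nodes = set()

    def add_node(self, name):
        self.nodes.add(name)

class ToyWatchpoint:
    def __init__(self, watchpoint_id):
        self._id = watchpoint_id
        self._watch_node = defaultdict(ToyNodeTree)  # rank_id -> tree

    def add_nodes(self, names, rank_id):
        for name in names:
            self._watch_node[rank_id].add_node(name)

    def get_watch_nodes(self):
        # Mirrors the per-device dict returned by the real get_watch_nodes().
        return {rank: sorted(tree.nodes) for rank, tree in self._watch_node.items()}

wp = ToyWatchpoint(1)
wp.add_nodes(['Default/conv1'], rank_id=0)
wp.add_nodes(['Default/conv1', 'Default/fc1'], rank_id=1)
print(wp.get_watch_nodes())
# {0: ['Default/conv1'], 1: ['Default/conv1', 'Default/fc1']}
```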



class WatchpointHit:
    """The watchpoint hit structure."""


+7 -6  mindinsight/debugger/stream_handler/__init__.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,9 +15,10 @@
"""Import the streams handlers.""" """Import the streams handlers."""
from .event_handler import EventHandler from .event_handler import EventHandler
from .metadata_handler import MetadataHandler from .metadata_handler import MetadataHandler
from .graph_handler import GraphHandler
from .tensor_handler import TensorHandler
from .watchpoint_handler import WatchpointHandler, WatchpointHitHandler
from .graph_handler import GraphHandler, MultiCardGraphHandler
from .tensor_handler import TensorHandler, MultiCardTensorHandler
from .watchpoint_handler import WatchpointHandler, WatchpointHitHandler, MultiCardWatchpointHitHandler


__all__ = ['EventHandler', 'MetadataHandler', 'GraphHandler', 'TensorHandler',
'WatchpointHandler', 'WatchpointHitHandler']
__all__ = ['EventHandler', 'MetadataHandler', 'GraphHandler', 'TensorHandler', 'WatchpointHitHandler',
'MultiCardGraphHandler', 'MultiCardTensorHandler',
'WatchpointHandler', 'MultiCardWatchpointHitHandler']

+198 -0  mindinsight/debugger/stream_handler/device_handler.py

@@ -0,0 +1,198 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Define the Device stream handler."""
+from collections import defaultdict
+
+from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, DeviceIdUnregistered, \
+    DebuggerParamTypeError
+from mindinsight.debugger.common.log import LOGGER as log
+from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
+
+
+class DeviceHandler(StreamHandlerBase):
+    """Device info handler."""
+
+    def __init__(self):
+        # contains all device infos, format: Dict[<rank_id>, <DeviceInfo>]
+        self._rank_info = defaultdict(DeviceInfo)
+        self._device_rank_map = {}
+
+    @property
+    def rank_ids(self):
+        """The rank ids."""
+        return list(self._rank_info)
+
+    @property
+    def device_amount(self):
+        """The amount of devices."""
+        return len(self._rank_info)
+
+    def put(self, value):
+        """
+        Put value into device info cache.
+
+        Args:
+            value (list): The list of server info. Each item is in the format:
+                {
+                    "server_id": str,
+                    "device": list[<Device Info>]
+                },
+                The format of <Device Info> is like:
+                {
+                    "device_id": str,
+                    "device_ip": str,
+                    "rank_id": str
+                }.
+        """
+        if not isinstance(value, list):
+            log.error("Invalid input type. list object is expected.")
+            raise DebuggerParamTypeError("List object is expected.")
+        try:
+            self._extract_rank_info(value)
+        except TypeError as err:
+            log.exception(err)
+            log.error("Invalid Device info.")
+            raise DebuggerParamValueError("Invalid device info.")
+        log.debug("Put Device into cache")
+
+    def _extract_rank_info(self, value):
+        """Extract rank info and save."""
+        for server_info in value:
+            server_ip = server_info.get('server_id')
+            for device_info in server_info.get('device', []):
+                rank_id = int(device_info.get('rank_id'))
+                if rank_id in self._rank_info:
+                    log.error("Repeated rank info for rank_id: %d", rank_id)
+                    raise DebuggerParamValueError("Repeated rank info.")
+                device_info_obj = self._rank_info[rank_id]
+                device_info_obj.rank_id = rank_id
+                device_info_obj.server_ip = server_ip
+                device_info_obj.device_id = int(device_info.get('device_id'))
+                device_info_obj.device_ip = device_info.get('device_ip')
+                self._device_rank_map[device_info_obj.device_id] = rank_id
+
+    def add_step_num_info(self, step_info):
+        """
+        Add step number information for each device.
+
+        Args:
+            step_info (dict): Step info per device. The key is the device id, the value
+                is the relative step number.
+        """
+        if not step_info:
+            log.warning("No step number information.")
+            return
+        if len(step_info) == 1 and not self._rank_info:
+            device_id = int(list(step_info)[0])
+            log.info("Default registered device %d as rank 0.", device_id)
+            self._rank_info[0].device_id = device_id
+        if len(step_info) > 1 and not self._rank_info:
+            log.error("Missing device info for multi-card training.")
+            raise DeviceIdUnregistered("all")
+
+        for device_id, step_num in step_info.items():
+            device_id = int(device_id)
+            rank_id = self.get_rank_id_by_device_id(device_id)
+            self._rank_info[rank_id].step_num = step_num
+
+    def add_graph_name_info(self, graphs):
+        """
+        Add graph name per device.
+
+        Args:
+            graphs (dict): Graph infos of all ranks. The key is the rank id and the
+                value is the graph info dict keyed by graph name.
+        """
+        for rank_id, graph_info in graphs.items():
+            graph_names = list(graph_info)
+            self._rank_info[rank_id].graph_names = graph_names
+
+    def get(self, filter_condition=None):
+        """
+        Get device information according to filter_condition.
+
+        Args:
+            filter_condition (list): The rank id.
+
+        Returns:
+            dict, the device info.
+        """
+        if filter_condition is None:
+            filter_condition = self.rank_ids
+        if not isinstance(filter_condition, list):
+            filter_condition = [filter_condition]
+        device_infos = []
+        for rank_id in filter_condition:
+            device_info = self._rank_info.get(rank_id)
+            if device_info is None:
+                log.error("Invalid rank id.")
+                raise DeviceIdUnregistered(rank_id)
+            device_infos.append(device_info.to_dict())
+        return {'devices': device_infos}
+
+    def get_rank_id_by_device_id(self, device_id):
+        """
+        Get rank id by device id.
+
+        Args:
+            device_id (int): The device id.
+
+        Returns:
+            int, the rank id.
+        """
+        rank_id = self._device_rank_map.get(device_id)
+        if rank_id is None:
+            log.error("Failed to find rank_id for device_id %s", device_id)
+            raise DeviceIdUnregistered(device_id)
+        return rank_id
+
+    def get_device_id_by_rank_id(self, rank_id):
+        """
+        Get device id by rank id.
+
+        Args:
+            rank_id (int): The rank id.
+
+        Returns:
+            int, the device id.
+        """
+        device_info = self._rank_info.get(rank_id)
+        if device_info:
+            return device_info.device_id
+        log.error("Failed to find device id according to rank_id %s", rank_id)
+        raise DeviceIdUnregistered(rank_id)
+
+
+class DeviceInfo:
+    """Device info object."""
+
+    def __init__(self):
+        self.rank_id = 0
+        self.device_id = 0
+        self.server_ip = ''
+        self.graph_names = []
+        self.device_ip = ''
+        self.step_num = 0
+
+    def to_dict(self):
+        """Convert device info to dict."""
+        res = {
+            'rank_id': self.rank_id,
+            'server_ip': self.server_ip,
+            'device_id': self.device_id,
+            'device_ip': self.device_ip,
+            'graph_names': self.graph_names,
+            'total_step_num': self.step_num
+        }
+        return res
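A hedged usage sketch of the handler above; the host and device values are made up, the input shape follows the `put` docstring, and the import assumes this commit's module layout:

```python
from mindinsight.debugger.stream_handler.device_handler import DeviceHandler

handler = DeviceHandler()
# Register two devices on one (made-up) server.
handler.put([
    {
        'server_id': '10.10.10.1',
        'device': [
            {'device_id': '0', 'device_ip': '192.168.1.100', 'rank_id': '0'},
            {'device_id': '1', 'device_ip': '192.168.1.101', 'rank_id': '1'},
        ],
    },
])
# Per-device step counts, keyed by device id as in the dump metadata.
handler.add_step_num_info({'0': 100, '1': 100})

print(handler.rank_ids)                     # [0, 1]
print(handler.get_rank_id_by_device_id(1))  # 1
print(handler.get(filter_condition=0))      # {'devices': [<rank 0 info dict>]}
```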

+53 -43  mindinsight/debugger/stream_handler/graph_handler.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,6 +24,55 @@ from mindinsight.debugger.stream_cache.debugger_multigraph import DebuggerMultiG
from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase




+class MultiCardGraphHandler:
+    """Multi-card Graph Handler."""
+
+    def __init__(self):
+        self._graph_handlers = {0: GraphHandler()}
+
+    @property
+    def graph_handlers(self):
+        """The property of graph handlers."""
+        return self._graph_handlers
+
+    def get_graph_handler_by_rank_id(self, rank_id=0):
+        """Get handler by rank id."""
+        if rank_id in self._graph_handlers:
+            return self._graph_handlers.get(rank_id)
+        log.error("There is no rank id %d.", rank_id)
+        raise ValueError
+
+    def put(self, value):
+        """Put graphs into the per-rank graph handlers."""
+        for rank_id, graph in value.items():
+            if rank_id not in self._graph_handlers:
+                self._graph_handlers[rank_id] = GraphHandler()
+            self._graph_handlers[rank_id].put(graph)
+
+    def get(self, filter_condition=None, rank_id=0):
+        """Get the graph of specific node for specific device."""
+        if rank_id in self._graph_handlers:
+            return self._graph_handlers.get(rank_id).get(filter_condition)
+        log.error("There is no rank id %d.", rank_id)
+        raise ValueError
+
+    def has_graph(self):
+        """Check if any device has a graph."""
+        res = False
+        for graph_handler in self._graph_handlers.values():
+            res = res or bool(graph_handler.graph)
+        return res
+
+    def register_graph_handler(self, rank_id, graph_handler):
+        """Register graph handler."""
+        self._graph_handlers[rank_id] = graph_handler
+
+    def clean(self):
+        """Clean cache."""
+        self.__init__()


class GraphHandler(StreamHandlerBase):
    """Graph Handler."""


@@ -68,7 +117,7 @@ class GraphHandler(StreamHandlerBase):
        Put value into graph cache. Called by grpc server.

        Args:
-            value (GraphProto): The Graph proto message.
+            value (dict): The Graph proto messages. Each item is in the format (<graph_name>, GraphProto).
        """
        log.info("Put graph into cache.")
        sorted_value_list = self._sort_graph(value)
@@ -430,8 +479,8 @@ class GraphHandler(StreamHandlerBase):
        graph_name, node_name = self._parse_node_name(scope_name, graph_name)
        graph = self._get_graph(graph_name)
        # to make sure fully match the scope name
-        node_name = node_name + '/' if not node_name.endswith('/') else node_name
-        nodes = graph.search_leaf_nodes_by_pattern(node_name)
+        node_name = node_name + '/' if node_name and not node_name.endswith('/') else node_name
+        nodes = graph.search_leaf_nodes_by_pattern(node_name, True)
        res = [self.construct_node_basic_info(full_name=node.full_name,
                                              graph_name=graph_name,
                                              node_name=node.name,
@@ -448,45 +497,6 @@ class GraphHandler(StreamHandlerBase):
log.debug("Get empty full name.") log.debug("Get empty full name.")
return node_name return node_name


-    def get_node_by_bfs_order(self, node_name=None, ascend=True):
-        """
-        Traverse the graph in order of breath-first search by given node.
-
-        Args:
-            node_name (str): The name of current chosen leaf node.
-            ascend (bool): If True, traverse the input nodes;
-                If False, traverse the output nodes. Default is True.
-        Returns:
-            Union[None, dict], the next node object in dict type or None.
-        """
-        bfs_order = self.bfs_order
-        length = len(bfs_order)
-
-        if not bfs_order:
-            log.error('Cannot get the BFS order of the graph!')
-            msg = 'Cannot get the BFS order of the graph!'
-            raise DebuggerParamValueError(msg)
-
-        if node_name is None:
-            if ascend is False:
-                next_node = None
-            else:
-                next_node = bfs_order[0]
-        else:
-            try:
-                index = bfs_order.index(node_name)
-                log.debug("The index of the node in BFS list is: %d", index)
-            except ValueError as err:
-                log.error('Cannot find the node: %s. Please check '
-                          'the node name: %s', node_name, err)
-                msg = f'Cannot find the node: {node_name}. ' \
-                      f'Please check the node name {err}.'
-                raise DebuggerParamValueError(msg)
-
-            next_node = self._get_next_node_in_bfs(index, length, ascend)
-
-        return next_node

    def _get_next_node_in_bfs(self, index, length, ascend):
        """
        Get the next node in bfs order.
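The multi-card wrapper added above is a thin rank-id dispatcher that creates per-device handlers on demand. A standalone sketch of the dispatch pattern with a toy inner handler (not the real `GraphHandler`):

```python
class ToyGraphHandler:
    def __init__(self):
        self.graph = None

    def put(self, graph):
        self.graph = graph

    def get(self, filter_condition=None):
        return {'graph': self.graph, 'filter': filter_condition}

class ToyMultiCardHandler:
    """Dispatches by rank id, creating per-rank handlers on demand."""
    def __init__(self):
        self._handlers = {0: ToyGraphHandler()}

    def put(self, value):
        for rank_id, graph in value.items():
            self._handlers.setdefault(rank_id, ToyGraphHandler()).put(graph)

    def get(self, filter_condition=None, rank_id=0):
        if rank_id not in self._handlers:
            raise ValueError(f'There is no rank id {rank_id}.')
        return self._handlers[rank_id].get(filter_condition)

multi = ToyMultiCardHandler()
multi.put({0: 'graph-of-rank-0', 1: 'graph-of-rank-1'})
print(multi.get(rank_id=1))  # {'graph': 'graph-of-rank-1', 'filter': None}
```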


+33 -17  mindinsight/debugger/stream_handler/metadata_handler.py

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,8 +13,9 @@
# limitations under the License.
# ============================================================================
"""Define the metadata stream handler."""

from mindinsight.debugger.common.log import LOGGER as log
-from mindinsight.debugger.common.utils import ServerStatus
+from mindinsight.debugger.common.utils import ServerStatus, DebuggerServerMode
from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase




@@ -24,28 +25,36 @@ class MetadataHandler(StreamHandlerBase):
    def __init__(self):
        self._state = ServerStatus.PENDING
        self._device_name = ""
-        self._step = 0
+        self.step = 0
        self._client_ip = ""
        self._cur_node_name = ""
        self._cur_full_name = ""
-        self._backend = ""
+        self.backend = ""
        self._enable_recheck = False
        self._cur_graph_name = ""
        # If recommendation_confirmed is true, it only means the user has answered yes or no to the question,
        # it does not necessarily mean that the user will use the recommended watch points.
        self._recommendation_confirmed = False
        self._debugger_version = {}
+        # maximum step number among all devices
+        self._max_step_num = 0
+        self._debugger_type = DebuggerServerMode.ONLINE.value
+
+    @property
+    def debugger_type(self):
+        """The property of debugger_type."""
+        return self._debugger_type
+
+    @debugger_type.setter
+    def debugger_type(self, debugger_type):
+        """Set the property of debugger_type."""
+        self._debugger_type = debugger_type

    @property
    def device_name(self):
        """The property of device name."""
        return self._device_name


-    @property
-    def step(self):
-        """The property of current step."""
-        return self._step

    @property
    def node_name(self):
        """The property of current node name."""
@@ -71,11 +80,6 @@ class MetadataHandler(StreamHandlerBase):
"""The property of current node name.""" """The property of current node name."""
return self._cur_full_name return self._cur_full_name


-    @property
-    def backend(self):
-        """The property of current backend."""
-        return self._backend

    @property
    def state(self):
        """The property of state."""
@@ -152,6 +156,16 @@ class MetadataHandler(StreamHandlerBase):
""" """
self._debugger_version = value self._debugger_version = value


+    @property
+    def max_step_num(self):
+        """The property of max_step_num."""
+        return self._max_step_num
+
+    @max_step_num.setter
+    def max_step_num(self, max_step_num):
+        """Set the property of max_step_num."""
+        self._max_step_num = max_step_num

    def put(self, value):
        """
        Put value into metadata cache. Called by grpc server.
@@ -160,10 +174,10 @@ class MetadataHandler(StreamHandlerBase):
            value (MetadataProto): The Metadata proto message.
        """
        self._device_name = value.device_name.split(':')[0]
-        self._step = value.cur_step
+        self.step = value.cur_step
        self._cur_full_name = value.cur_node
-        self._backend = value.backend if value.backend else "Ascend"
-        log.debug("Put metadata into cache at the %d-th step.", self._step)
+        self.backend = value.backend if value.backend else "Ascend"
+        log.debug("Put metadata into cache at the %d-th step.", self.step)


    def get(self, filter_condition=None):
        """
@@ -190,6 +204,8 @@ class MetadataHandler(StreamHandlerBase):
                'recommendation_confirmed': self._recommendation_confirmed,
                'debugger_version': self.debugger_version
            }
+            if self.debugger_type == 'offline':
+                metadata['total_step_num'] = self.max_step_num
        else:
            if not isinstance(filter_condition, list):
                filter_condition = [filter_condition]
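In offline mode the metadata reply gains a `total_step_num` field so the front end can bound its step selector. A standalone sketch of the branch with a toy metadata dict:

```python
def build_metadata_reply(debugger_type, max_step_num, base_metadata):
    """Attach total_step_num only for the offline debugger, as above."""
    metadata = dict(base_metadata)
    if debugger_type == 'offline':
        metadata['total_step_num'] = max_step_num
    return metadata

reply = build_metadata_reply('offline', 98, {'state': 'waiting', 'step': 0})
print(reply)  # {'state': 'waiting', 'step': 0, 'total_step_num': 98}
```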


+70 -23  mindinsight/debugger/stream_handler/tensor_handler.py

@@ -28,6 +28,46 @@ from mindinsight.utils.tensor import TensorUtils, TensorComparison
TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter'])




+class MultiCardTensorHandler:
+    """Multi-card Tensor Handler."""
+
+    def __init__(self):
+        self.tensor_handlers = {0: TensorHandler()}
+
+    def set_step(self, step_id):
+        """Set step id."""
+        for tensor_handler in self.tensor_handlers.values():
+            tensor_handler.cur_step = step_id
+
+    def get_tensor_handler_by_rank_id(self, rank_id=0, create_if_not_exit=False):
+        """Get handler by rank id."""
+        if rank_id in self.tensor_handlers:
+            return self.tensor_handlers.get(rank_id)
+        if create_if_not_exit:
+            tensor_handler = TensorHandler()
+            self.tensor_handlers[rank_id] = tensor_handler
+            return tensor_handler
+        log.error("There is no rank id %d in MultiCardTensorHandler.", rank_id)
+        raise ValueError
+
+    def put(self, value):
+        """Put tensors into the per-rank tensor handlers."""
+        for rank_id, tensor in value:
+            if rank_id not in self.tensor_handlers:
+                self.tensor_handlers[rank_id] = TensorHandler()
+            self.tensor_handlers[rank_id].put(tensor)
+
+    def get(self, filter_condition=None, rank_id=0):
+        """Get the tensor value of specific node for specific device."""
+        if rank_id in self.tensor_handlers:
+            return self.tensor_handlers.get(rank_id).get(filter_condition)
+        log.error("There is no rank id %d.", rank_id)
+        raise ValueError
+
+    def clean(self):
+        """Clean cache."""
+        self.__init__()


class TensorHandler(StreamHandlerBase):
    """Tensor Handler."""


@@ -46,6 +86,11 @@ class TensorHandler(StreamHandlerBase):
"""The property of current step.""" """The property of current step."""
return self._cur_step return self._cur_step


@cur_step.setter
def cur_step(self, step_id):
"""The property of current step."""
self._cur_step = step_id

    @property
    def prev_step(self):
        """The property of previous step."""
@@ -172,7 +217,7 @@ class TensorHandler(StreamHandlerBase):
log.error("No tensor named %s at the step %s", name, step) log.error("No tensor named %s at the step %s", name, step)
raise DebuggerParamValueError("No tensor named {}".format(name)) raise DebuggerParamValueError("No tensor named {}".format(name))
tensor_info = tensor.get_full_info(shape) tensor_info = tensor.get_full_info(shape)
self._update_has_prev_step_field(tensor_info, name, node_type)
self._update_has_prev_step_field(tensor_info, name, node_type, self.cur_step)
return {'tensor_value': tensor_info} return {'tensor_value': tensor_info}


    def _get_tensor(self, tensor_name, node_type=None, step=None):
@@ -198,20 +243,21 @@ class TensorHandler(StreamHandlerBase):


        return tensor


-    def _get_basic_info(self, tensor_name, node_type=None):
+    def _get_basic_info(self, tensor_name, node_type, step):
        """Get the latest basic tensor info by tensor name."""
-        tensor = self._get_tensor(tensor_name, node_type)
+        tensor = self._get_tensor(tensor_name, node_type, step)
        if tensor:
            return tensor.get_basic_info()

        return None


-    def update_tensor_history(self, tensor_history):
+    def update_tensor_history(self, tensor_history, step=None):
        """
        Add tensor basic info in tensor_history.

        Args:
            tensor_history (dict): Tensor history, including a list of tensor name and type.
+            step (int): The step of tensor info. Default: None.

        Returns:
            list[dict], the list of tensor basic info cache.
@@ -220,9 +266,9 @@ class TensorHandler(StreamHandlerBase):
        for tensor_info in tensor_history.get('tensor_history'):
            tensor_name = tensor_info.get('full_name')
            node_type = tensor_info.get('node_type')
-            basic_info = self._get_basic_info(tensor_name, node_type)
+            basic_info = self._get_basic_info(tensor_name, node_type, step)
            # add `has_prev_step` field to tensor basic info.
-            missing_tensors_info = self._update_has_prev_step_field(basic_info, tensor_name, node_type)
+            missing_tensors_info = self._update_has_prev_step_field(basic_info, tensor_name, node_type, step)
            if basic_info:
                tensor_info.update(basic_info)
            if missing_tensors_info:
@@ -230,14 +276,14 @@ class TensorHandler(StreamHandlerBase):


        return missed_tensors


-    def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type):
+    def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type, step=None):
        """Update has_prev_step field in tensor info."""
-        missing_tensors_info = self._get_missing_tensor_info(tensor_name, node_type)
-        if not missing_tensors_info and node_type == NodeTypeEnum.PARAMETER.value and self.cur_step > 0:
+        missing_tensors_info = self._get_missing_tensor_info(tensor_name, node_type, step)
+        if not missing_tensors_info and node_type == NodeTypeEnum.PARAMETER.value and step > 0:
            tensor_info['has_prev_step'] = True
        return missing_tensors_info


-    def _get_missing_tensor_info(self, tensor_name, node_type):
+    def _get_missing_tensor_info(self, tensor_name, node_type, step):
        """
        Get missing tensor infos.


@@ -248,7 +294,6 @@ class TensorHandler(StreamHandlerBase):
        Returns:
            list, list of missing tensor basic information.
        """
-        step = self.cur_step
        missing_tensors_info = []
        # check the current step value is missing
        if self._is_tensor_value_missing(tensor_name, step):
@@ -278,13 +323,13 @@ class TensorHandler(StreamHandlerBase):
        tensor = self._get_tensor(tensor_name, step=step)
        return bool(not tensor or tensor.empty)


-    def get_valid_tensor_by_name(self, tensor_name, prev=False):
+    def get_valid_tensor_by_name(self, tensor_name, step, prev=False):
        """Get tensor value by name in numpy type."""
-        step = self.prev_step if prev else self.cur_step
-        if step < 0:
-            log.warning("%d step has no previous value for tensor: %s", self.cur_step, tensor_name)
+        target_step = step - 1 if prev else step
+        if target_step < 0:
+            log.warning("Step %d has no previous value for tensor: %s", target_step, tensor_name)
            return None
-        tensor = self._get_tensor(tensor_name, step=step)
+        tensor = self._get_tensor(tensor_name, step=target_step)
        if tensor and tensor.empty:
            log.warning("%s has empty value.", tensor_name)
            return None
@@ -316,9 +361,9 @@ class TensorHandler(StreamHandlerBase):
            self._tensors.pop(param)
            log.debug("Clean param %s in cache.", param)


-    def get_tensors_diff(self, tensor_name, shape, tolerance=0):
+    def get_tensors_diff(self, tensor_name, shape, tolerance=0, step=None):
        """
        Get tensor comparisons data for given name, detail, shape and tolerance.

        Args:
            tensor_name (str): The name of tensor for cache.
@@ -329,6 +374,7 @@ class TensorHandler(StreamHandlerBase):
                calculate the min value and max value of the result of subtracting the previous step
                tensor from the current step tensor. If the absolute value of the result is less than
                or equal to the boundary value, the result will be set to zero.
+            step (int): The step of the tensor. Default: None.

        Raises:
            DebuggerParamValueError, If get current step node and previous step node failed or
@@ -337,8 +383,8 @@ class TensorHandler(StreamHandlerBase):
        Returns:
            dict, the retrieved data.
        """
-        curr_tensor = self.get_valid_tensor_by_name(tensor_name)
-        prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True)
+        curr_tensor = self.get_valid_tensor_by_name(tensor_name, step=step)
+        prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True, step=step)
        if not (curr_tensor and prev_tensor):
            log.error("Get current step and previous step for this tensor name %s failed.", tensor_name)
            raise DebuggerParamValueError(f"Get current step and previous step for this tensor name "
@@ -386,22 +432,23 @@ class TensorHandler(StreamHandlerBase):
        stats_info['statistics'] = TensorUtils.get_overall_statistic_dict(overall_stats=diff_tensor_stats)
        return stats_info


-    def get_tensor_info_for_tensor_graph(self, tensor_name, node_type):
+    def get_tensor_info_for_tensor_graph(self, tensor_name, node_type, step):
        """
        Get Tensor info for tensor graphs.

        Args:
            tensor_name (str): Tensor name, format like `node_name:slot`.
            node_type (str): Node type.
+            step (int): The step of tensor info.

        Returns:
            dict, tensor infos, including overall statistics, tensor shape and has_prev_step info.
            list, list of missing tensor basic information.
        """
        res = {}
-        tensor = self._get_tensor(tensor_name, node_type)
+        tensor = self._get_tensor(tensor_name, node_type, step)
        if tensor and not tensor.empty:
            res['statistics'] = tensor.get_tensor_statistics()
            res['shape'] = tensor.shape
-        missing_tensors = self._update_has_prev_step_field(res, tensor_name, node_type)
+        missing_tensors = self._update_has_prev_step_field(res, tensor_name, node_type, step)
        return res, missing_tensors
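Threading an explicit `step` through the tensor queries is what lets the offline debugger inspect arbitrary steps instead of only the live one; `prev` now simply resolves to `step - 1`. A standalone sketch of the resolution rule (toy cache, not the real `TensorHandler`):

```python
class ToyTensorCache:
    """Stores tensor values per (name, step); resolves prev as step - 1."""
    def __init__(self):
        self._values = {}  # (name, step) -> value

    def put(self, name, step, value):
        self._values[(name, step)] = value

    def get_valid_tensor_by_name(self, name, step, prev=False):
        target_step = step - 1 if prev else step
        if target_step < 0:
            return None  # step 0 has no previous value
        return self._values.get((name, target_step))

cache = ToyTensorCache()
cache.put('conv1.weight:0', 4, 'value@4')
cache.put('conv1.weight:0', 5, 'value@5')
print(cache.get_valid_tensor_by_name('conv1.weight:0', step=5))             # value@5
print(cache.get_valid_tensor_by_name('conv1.weight:0', step=5, prev=True))  # value@4
print(cache.get_valid_tensor_by_name('conv1.weight:0', step=0, prev=True))  # None
```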

+87 -19  mindinsight/debugger/stream_handler/watchpoint_handler.py

@@ -105,12 +105,12 @@ class WatchpointHandler(StreamHandlerBase):


        return {'watch_points': reply}


-    def get_pending_commands(self, graph_stream):
+    def get_pending_commands(self, multi_card_graph_stream):
        """
        Get all watchpoint in SetCMD proto format.

        Args:
-            graph_stream (GraphHandler): Graph handler.
+            multi_card_graph_stream (MultiCardGraphHandler): Multi card graph handler.

        Returns:
            list[SetCMD], updated watchpoint to be sent to MindSpore.
@@ -118,9 +118,13 @@ class WatchpointHandler(StreamHandlerBase):
        newly_set_cmds = []
        for _, watchpoint in self._updated_watchpoints.items():
            # construct set command with leaf nodes
-            watch_nodes = watchpoint.get_watch_nodes()
-            leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes)
-            newly_set_cmds.append(watchpoint.get_pending_cmd(leaf_watch_nodes))
+            watch_nodes_for_devices = watchpoint.get_watch_nodes()
+            leaf_watch_nodes_for_devices = {}
+            for rank_id, watch_nodes in watch_nodes_for_devices.items():
+                graph_stream = multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id)
+                leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes)
+                leaf_watch_nodes_for_devices[rank_id] = leaf_watch_nodes
+            newly_set_cmds.append(watchpoint.get_pending_cmd(leaf_watch_nodes_for_devices))
        newly_set_cmds.extend(self._deleted_watchpoints)
        self.sync_set_cmd(newly_set_cmds)


@@ -161,7 +165,7 @@ class WatchpointHandler(StreamHandlerBase):
""" """
return self._outdated return self._outdated


-    def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None):
+    def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None, rank_id=0):
        """
        set watch nodes for graph.


@@ -170,23 +174,24 @@ class WatchpointHandler(StreamHandlerBase):
            graph_stream (GraphHandler): The graph handler.
            watch_point_id (int): The id of watchpoint.
            graph_name (str): The graph name.
+            rank_id (int): The rank id.
        """
        if not (watch_point_id and graph):
            return
        log.debug("add watch flags")
        watchpoint = self._watchpoints.get(watch_point_id)
-        self._set_watch_status_recursively(graph, graph_stream, watchpoint, graph_name)
+        self._set_watch_status_recursively(graph, graph_stream, watchpoint, graph_name, rank_id)

-    def _set_watch_status_recursively(self, graph, graph_stream, watchpoint, graph_name=None):
+    def _set_watch_status_recursively(self, graph, graph_stream, watchpoint, graph_name=None, rank_id=0):
        """Set watch status to graph."""
        if graph.get('children'):
            self._set_watch_status_recursively(
-                graph.get('children'), graph_stream, watchpoint, graph_name)
+                graph.get('children'), graph_stream, watchpoint, graph_name, rank_id=rank_id)

        if graph.get('nodes'):
            _ = self._set_watch_state_for_nodes(graph['nodes'], graph_stream, watchpoint, graph_name, rank_id)

-    def _set_watch_state_for_nodes(self, nodes, graph_stream, watchpoint, graph_name):
+    def _set_watch_state_for_nodes(self, nodes, graph_stream, watchpoint, graph_name, rank_id=0):
        """
        Set watch state for nodes.


@@ -204,11 +209,11 @@ class WatchpointHandler(StreamHandlerBase):
            node_name = node.get('name')
            # search result could have `nodes` in nodes object
            if node.get('nodes'):
-                flag = self._set_watch_state_for_nodes(node.get('nodes'), graph_stream, watchpoint, graph_name)
+                flag = self._set_watch_state_for_nodes(node.get('nodes'), graph_stream, watchpoint, graph_name, rank_id)
            else:
                full_name = graph_stream.get_full_name(node_name, graph_name)
                new_node_name = node_name if graph_name is None else '/'.join([graph_name, node_name])
-                flag = watchpoint.get_node_status(new_node_name, node.get('type'), full_name)
+                flag = watchpoint.get_node_status(new_node_name, node.get('type'), full_name, rank_id)
            node['watched'] = flag
            if flag == WatchNodeTree.NOT_WATCH:
                continue
@@ -224,7 +229,8 @@ class WatchpointHandler(StreamHandlerBase):
state = WatchNodeTree.TOTAL_WATCH state = WatchNodeTree.TOTAL_WATCH
return state return state


def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None, name=None):
def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None, name=None,
device_amount=8):
""" """
Create watchpoint. Create watchpoint.
Args: Args:
@@ -241,9 +247,10 @@ class WatchpointHandler(StreamHandlerBase):
                }
            - id (str): Id of condition.
            - param (list[dict]): The list of param for this condition.
-            watch_nodes (list[NodeBasicInfo]): The list of node basic info.
+            watch_nodes (dict[int, list[NodeBasicInfo]]): The watch nodes for each device, keyed by rank id.
            watch_point_id (int): The id of watchpoint.
            name (str): The name of watchpoint.
+            device_amount (int): The amount of devices.


        Returns:
            int, the new id of watchpoint.
@@ -253,7 +260,9 @@ class WatchpointHandler(StreamHandlerBase):
        new_id = self._latest_id + 1
        watchpoint = Watchpoint(new_id, watch_condition, name)
        if watch_nodes:
-            watchpoint.add_nodes(watch_nodes)
+            for rank_id, watch_nodes_for_device in watch_nodes.items():
+                validate_rank_id(rank_id, device_amount)
+                watchpoint.add_nodes(watch_nodes_for_device, rank_id)
        elif watch_point_id:
            self.validate_watchpoint_id(watch_point_id)
            watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id))
@@ -261,7 +270,7 @@ class WatchpointHandler(StreamHandlerBase):
        self._outdated = True
        return new_id


-    def update_watchpoint(self, watch_point_id, watch_nodes, watched=False):
+    def update_watchpoint(self, watch_point_id, watch_nodes, watched=False, rank_id=0):
        """
        Update watchpoint.


@@ -270,13 +279,14 @@ class WatchpointHandler(StreamHandlerBase):
            watch_nodes (list[NodeBasicInfo]): The list of node basic info.
            watched (bool): The update operator on nodes. If False, remove nodes from watch nodes.
                If True, add nodes to watch nodes. Default: False.
+            rank_id (int): The rank id.
        """
        self.validate_watchpoint_id(watch_point_id)
        watchpoint = self._watchpoints.get(watch_point_id)
        if watched:
-            watchpoint.add_nodes(watch_nodes)
+            watchpoint.add_nodes(watch_nodes, rank_id)
        else:
-            watchpoint.remove_nodes(watch_nodes)
+            watchpoint.remove_nodes(watch_nodes, rank_id)
        self._updated_watchpoints[watch_point_id] = watchpoint
        self._outdated = True
        log.debug("Update watchpoint %d in cache.", watch_point_id)
@@ -328,6 +338,58 @@ class WatchpointHandler(StreamHandlerBase):
            raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))




+class MultiCardWatchpointHitHandler:
+    """Multi-card Watchpoint-hit Handler."""
+
+    def __init__(self):
+        self.watchpoint_hit_handlers = {0: WatchpointHitHandler()}
+
+    def get_hit_handler_by_rank_id(self, rank_id=0):
+        """Get handler by rank id."""
+        if rank_id in self.watchpoint_hit_handlers:
+            return self.watchpoint_hit_handlers.get(rank_id)
+        log.error("There is no rank id %d.", rank_id)
+        raise ValueError
+
+    def put(self, value):
+        """Put watchpoint hit into cache."""
+        for rank_id, tensor_hit_values in value.items():
+            if rank_id not in self.watchpoint_hit_handlers:
+                self.watchpoint_hit_handlers[rank_id] = WatchpointHitHandler()
+            cur_hit_handler = self.watchpoint_hit_handlers[rank_id]
+            for tensor_hit_value in tensor_hit_values:
+                cur_hit_handler.put(tensor_hit_value)
+
+    def get(self, filter_condition=None, rank_id=0):
+        """Get the watchpoint hits of specific node for specific device."""
+        if rank_id in self.watchpoint_hit_handlers:
+            return self.watchpoint_hit_handlers.get(rank_id).get(filter_condition)
+        log.error("There is no rank id %d.", rank_id)
+        raise ValueError
+
+    def update_tensor_history(self, tensor_history, rank_id):
+        """
+        Add hit flag to tensor history.
+
+        Args:
+            tensor_history (dict): The tensor history.
+            rank_id (int): The rank id.
+        """
+        if rank_id in self.watchpoint_hit_handlers:
+            self.watchpoint_hit_handlers[rank_id].update_tensor_history(tensor_history)
+        else:
+            for tensor_info in tensor_history.get('tensor_history'):
+                tensor_info['is_hit'] = False
+
+    def check_rank_id(self, rank_id):
+        """Check whether the rank id exists."""
+        return rank_id in self.watchpoint_hit_handlers
+
+    def clean(self):
+        """Clean cache."""
+        self.__init__()


class WatchpointHitHandler(StreamHandlerBase):
    """Watchpoint hit handler."""


@@ -743,3 +805,9 @@ def _get_error_list(error_code):
        error_list.append(error_str)


    return error_list


+def validate_rank_id(rank_id, device_amount):
+    """Validate rank id."""
+    if rank_id >= device_amount:
+        log.debug("The rank id %d exceeds the device amount.", rank_id)
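Putting the watchpoint changes together: `create_watchpoint` now accepts watch nodes keyed by rank id and checks each key against the device amount. A standalone sketch of that validate-and-fan-out step (toy types; the real code stores `WatchNodeTree` objects):

```python
from collections import namedtuple

# Illustrative stand-in for the debugger's NodeBasicInfo namedtuple.
NodeBasicInfo = namedtuple('NodeBasicInfo', ['name', 'full_name', 'type'])

def fan_out_watch_nodes(watch_nodes, device_amount):
    """Validate rank ids and fan watch nodes out per device, as create_watchpoint now does."""
    accepted = {}
    for rank_id, nodes in watch_nodes.items():
        if rank_id >= device_amount:
            # The real validate_rank_id only logs; it does not reject the rank.
            print(f'warning: rank id {rank_id} over device amount {device_amount}')
        accepted[rank_id] = [node.name for node in nodes]
    return accepted

conv1 = NodeBasicInfo('Default/conv1', 'Default/network/conv1', 'Conv2D')
print(fan_out_watch_nodes({0: [conv1], 1: [conv1]}, device_amount=2))
# {0: ['Default/conv1'], 1: ['Default/conv1']}
```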

+30 -17  mindinsight/debugger/stream_operator/tensor_detail_info.py

@@ -23,17 +23,19 @@ class TensorDetailInfo:


    def __init__(self, cache):
        self._put_command = cache.put_command
-        self._tensor_stream = cache.get_stream_handler(Streams.TENSOR)
-        self._graph_stream = cache.get_stream_handler(Streams.GRAPH)
-        self._hit_stream = cache.get_stream_handler(Streams.WATCHPOINT_HIT)
+        self._metadata_stream = cache.get_stream_handler(Streams.METADATA)
+        self._multi_card_tensor_stream = cache.get_stream_handler(Streams.TENSOR)
+        self._multi_card_graph_stream = cache.get_stream_handler(Streams.GRAPH)
+        self._multi_card_hit_stream = cache.get_stream_handler(Streams.WATCHPOINT_HIT)


-    def validate_tensor_name(self, tensor_name, graph_name):
+    def validate_tensor_name(self, tensor_name, graph_name, rank_id):
        """
        Get the graph id of the tensor.

        Args:
            tensor_name (str): The tensor name on UI.
            graph_name (str): The graph name.
+            rank_id (int): The rank id.
        """
        # validate tensor name format
        if not isinstance(tensor_name, str) or ':' not in tensor_name:
@@ -41,15 +43,17 @@ class TensorDetailInfo:
            raise DebuggerParamValueError("Invalid tensor name.")
        node_name, _ = tensor_name.rsplit(':', 1)
        # check if the node name is in graph
-        self._graph_stream.validate_node_name(node_name=node_name, graph_name=graph_name)
+        self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).validate_node_name(
+            node_name=node_name, graph_name=graph_name)


-    def get_tensor_graph(self, tensor_name, graph_name):
+    def get_tensor_graph(self, tensor_name, graph_name, rank_id=0):
        """
        Get the graph related to specific tensor.

        Args:
            tensor_name (str): The ui name of tensor. Format like {node_name}:{slot}.
            graph_name (str): The graph name.
+            rank_id (int): The rank id.


        Returns:
            dict, tensor graph, format is {'nodes': [Node object]}.
@@ -68,8 +72,9 @@ class TensorDetailInfo:
                'slot_mapping': list[pair<slot, slot>],
            }.
        """
-        self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name)
-        graph = self._graph_stream.get_tensor_graph(tensor_name, graph_name)
+        self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name, rank_id=rank_id)
+        graph = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_tensor_graph(
+            tensor_name, graph_name)
        # add watchpoint hits info and statistics info for each tensor in tensor graph.
        # record missing tensor basic info
        nodes = graph.get('graph', {}).get('nodes', [])
@@ -77,13 +82,13 @@ class TensorDetailInfo:
        for node in nodes:
            node['graph_name'] = graph_name
            for slot_info in node.get('slots', []):
-                self._add_watchpoint_hit_info(slot_info, node, graph_name)
-                self._add_tensor_info(slot_info, node, missing_tensors)
+                self._add_watchpoint_hit_info(slot_info, node, graph_name, rank_id)
+                self._add_tensor_info(slot_info, node, missing_tensors, rank_id)
        # query missing tensor values from client
        self._ask_for_missing_tensor_value(missing_tensors, tensor_name, graph_name)
        return graph


-    def _add_watchpoint_hit_info(self, slot_info, node, graph_name):
+    def _add_watchpoint_hit_info(self, slot_info, node, graph_name, rank_id):
        """
        Add watchpoint hit info for the tensor.


@@ -93,9 +98,12 @@ class TensorDetailInfo:
            graph_name (str): Graph name.
        """
        tensor_name = ':'.join([node.get('name'), slot_info.get('slot')])
-        slot_info.update(self._hit_stream.get_tensor_hit_infos(tensor_name, graph_name))
+        if self._multi_card_hit_stream.check_rank_id(rank_id=rank_id):
+            slot_info.update(
+                self._multi_card_hit_stream.get_hit_handler_by_rank_id(rank_id).get_tensor_hit_infos(
+                    tensor_name, graph_name))

-    def _add_tensor_info(self, slot_info, node, missing_tensors):
+    def _add_tensor_info(self, slot_info, node, missing_tensors, rank_id):
        """
        Add the tensor info and query for missed tensors.


@@ -106,7 +114,8 @@ class TensorDetailInfo:
""" """
tensor_name = ':'.join([node.get('full_name'), slot_info.get('slot')]) tensor_name = ':'.join([node.get('full_name'), slot_info.get('slot')])
node_type = node.get('type') node_type = node.get('type')
tensor_info, cur_missing_tensors = self._tensor_stream.get_tensor_info_for_tensor_graph(tensor_name, node_type)
tensor_info, cur_missing_tensors = self._multi_card_tensor_stream.get_tensor_handler_by_rank_id(
rank_id).get_tensor_info_for_tensor_graph(tensor_name, node_type, self._metadata_stream.step)
slot_info.update(tensor_info) slot_info.update(tensor_info)
if cur_missing_tensors: if cur_missing_tensors:
log.debug("Get missing tensor basic infos for %s", tensor_name) log.debug("Get missing tensor basic infos for %s", tensor_name)
@@ -128,20 +137,24 @@ class TensorDetailInfo:
        self._put_command({'view_cmd': view_cmd, 'tensor_name': tensor_name, 'graph_name': graph_name})
        log.debug("Send view cmd for tensor-graphs.")


-    def get_tensor_watch_points(self, tensor_name, graph_name):
+    def get_tensor_watch_points(self, tensor_name, graph_name, rank_id=0):
        """
        Get all watchpoints that the tensor hit.

        Args:
            tensor_name (str): Tensor name from UI.
            graph_name (str): The graph name.
+            rank_id (int): The rank id.


        Returns:
            list, watchpoint hit infos.
        """
        # validate tensor_name
-        self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name)
+        self.validate_tensor_name(tensor_name=tensor_name, graph_name=graph_name, rank_id=rank_id)
        # get watchpoint info that the tensor hit
-        tensor_hit_info = self._hit_stream.get_tensor_hit_infos(tensor_name, graph_name)
+        if not self._multi_card_hit_stream.check_rank_id(rank_id=rank_id):
+            return []
+        tensor_hit_info = self._multi_card_hit_stream.get_hit_handler_by_rank_id(rank_id).get_tensor_hit_infos(
+            tensor_name, graph_name)
        watch_points = tensor_hit_info.get('watch_points', [])
        return watch_points

+43 -7  mindinsight/debugger/stream_operator/training_control_operator.py

@@ -18,7 +18,8 @@ import enum
from mindinsight.debugger.common.exceptions.exceptions import DebuggerContinueError, DebuggerParamValueError, \
    DebuggerPauseError, DebuggerRecheckError, DebuggerStepNumError
from mindinsight.debugger.common.log import LOGGER as log
-from mindinsight.debugger.common.utils import Streams, get_ack_reply, ServerStatus, RunLevel, is_scope_type
+from mindinsight.debugger.common.utils import Streams, get_ack_reply, ServerStatus, RunLevel, is_scope_type, \
+    DebuggerServerMode
from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
from mindinsight.utils.exceptions import MindInsightException


@@ -29,6 +30,7 @@ class ControlTypeEnum(enum.Enum):
    CONTINUE = 'continue'  # continue to run training
    PAUSE = 'pause'  # suspend training
    TERMINATE = 'terminate'  # terminate training
+    RESET = 'reset'  # reset the step_id in offline debugger




class TrainingControlOperator:
@@ -39,7 +41,7 @@ class TrainingControlOperator:
    def __init__(self, cache_store):
        self._cache_store = cache_store
        self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT)
-        self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH)
+        self._multi_card_graph_stream = cache_store.get_stream_handler(Streams.GRAPH)
        self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA)


    @staticmethod
@@ -71,6 +73,9 @@ class TrainingControlOperator:
""" """
if mode == ControlTypeEnum.CONTINUE.value: if mode == ControlTypeEnum.CONTINUE.value:
reply = self.continue_training(params) reply = self.continue_training(params)
elif mode == ControlTypeEnum.RESET.value:
step_id = params['steps']
reply = self.reset_training_step(step_id)
else: else:
mode_mapping = { mode_mapping = {
ControlTypeEnum.PAUSE.value: self.pause_training, ControlTypeEnum.PAUSE.value: self.pause_training,
@@ -150,13 +155,15 @@ class TrainingControlOperator:
        if level == RunLevel.NODE.value:
            node_name = params.get('name')
            graph_name = params.get('graph_name')
-            self._validate_continue_node_name(node_name, graph_name)
+            rank_id = params.get('rank_id', 0)
+            self._validate_continue_node_name(node_name, graph_name, rank_id)

-    def _validate_continue_node_name(self, node_name, graph_name):
+    def _validate_continue_node_name(self, node_name, graph_name, rank_id):
        """Validate if the node is a leaf node."""
        if not node_name:
            return
-        node_type = self._graph_stream.get_node_type(node_name, graph_name)
+        node_type = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_node_type(
+            node_name, graph_name)
        if is_scope_type(node_type):
            log.error("Scope type node has no tensor history.")
            raise DebuggerParamValueError("Invalid leaf node name.")
@@ -188,7 +195,9 @@ class TrainingControlOperator:
        name = params.get('name', '')
        graph_name = params.get('graph_name')
        if name:
-            name = self._cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name, graph_name)
+            rank_id = params.get('rank_id', 0)
+            name = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_full_name(
+                name, graph_name)
            run_cmd = RunCMD(run_level='node', node_name=name)
        else:
            run_cmd = RunCMD(run_level='recheck')
@@ -199,7 +208,7 @@ class TrainingControlOperator:


    def _send_watchpoints(self):
        """Send watchpoints to client."""
-        set_commands = self._watchpoint_stream.get_pending_commands(self._graph_stream)
+        set_commands = self._watchpoint_stream.get_pending_commands(self._multi_card_graph_stream)
        if not set_commands:
            return
        for set_cmd in set_commands:
@@ -274,3 +283,30 @@ class TrainingControlOperator:
        else:
            log.debug("Send the recheck to command queue.")
        return metadata_stream.get(['state', 'enable_recheck'])

def reset_training_step(self, step_id):
"""
Reset the training step.

Args:
step_id (int): The target step_id.

Returns:
dict, metadata info.
"""
metadata_stream = self._metadata_stream
if metadata_stream.debugger_type == DebuggerServerMode.ONLINE.value:
log.error("'step_id' can not be changed manually in online debugger.")
return metadata_stream.get(['state', 'enable_recheck', 'step'])
if step_id > metadata_stream.max_step_num:
log.error("Invalid step_id, step_id should be less than %d.", metadata_stream.max_step_num)
raise DebuggerParamValueError("Invalid step_id.")
metadata_stream.state = ServerStatus.SENDING.value
metadata_stream.step = step_id
self._cache_store.get_stream_handler(Streams.TENSOR).set_step(step_id)
self._cache_store.clean_data()
self._cache_store.clean_command()
metadata_stream.enable_recheck = False
metadata_stream.state = ServerStatus.WAITING.value
log.debug("Send the Change_training_step CMD.")
return metadata_stream.get(['state', 'enable_recheck', 'step'])
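For orientation, the new RESET branch is driven by the same control route the UI uses. A minimal client-side sketch of resetting an offline session to a given step, assuming the 'reset' string mirrors ControlTypeEnum.RESET.value and a local MindInsight instance (host, port, and session id are assumptions):

    import requests

    # Hypothetical values: a local MindInsight instance and offline session "1".
    BASE = 'http://localhost:8080/v1/mindinsight/debugger'

    reply = requests.post(f'{BASE}/sessions/1/control',
                          json={'mode': 'reset', 'steps': 5})
    # On success the reply mirrors metadata_stream.get(['state', 'enable_recheck', 'step']).
    print(reply.json())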

+19 -19  mindinsight/debugger/stream_operator/watchpoint_operator.py

@@ -31,8 +31,9 @@ class WatchpointOperator:


def __init__(self, cache_store, condition_mgr):
    self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT)
-     self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH)
+     self._multi_card_graph_stream = cache_store.get_stream_handler(Streams.GRAPH)
    self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA)
+     self._device_stream = cache_store.get_stream_handler(Streams.DEVICE)
    self._condition_mgr = condition_mgr


def create_watchpoint(self, params):
@@ -70,11 +71,6 @@ class WatchpointOperator:
"Failed to create watchpoint as the MindSpore is not in waiting state.") "Failed to create watchpoint as the MindSpore is not in waiting state.")
self._validate_watch_condition(watch_condition) self._validate_watch_condition(watch_condition)


- watch_nodes = self._get_watch_node_with_basic_info(
-     node_names=params.get('watch_nodes'),
-     search_pattern=params.get('search_pattern'),
-     graph_name=params.get('graph_name'))

validate_watch_condition(self._condition_mgr, watch_condition)
condition_id = watch_condition.get('id')
condition = self._condition_mgr.get_condition(condition_id)
@@ -84,10 +80,11 @@ class WatchpointOperator:
    raise DebuggerConditionUnavailableError(
        "Failed to create watchpoint as the condition is not available.")


- watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._graph_stream).copy()
+ watch_nodes = get_basic_node_info(condition.supported_target_type.value, self._multi_card_graph_stream)
  watchpoint_stream = self._watchpoint_stream
- watch_point_id = watchpoint_stream.create_watchpoint(
-     self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id'))
+ watch_point_id = watchpoint_stream.create_watchpoint(self._condition_mgr, watch_condition, watch_nodes,
+                                                      params.get('watch_point_id'),
+                                                      self._device_stream.device_amount)
  log.info("Create watchpoint %d", watch_point_id)


metadata_stream.enable_recheck = watchpoint_stream.is_recheckable()
@@ -115,6 +112,7 @@ class WatchpointOperator:
        1 for add nodes to watch nodes.
    - search_pattern (dict): The search pattern.
    - graph_name (str): The relative graph_name of the watched node.
+     - rank_id (int): The rank id.


Returns:
    dict, the metadata info.
@@ -137,13 +135,14 @@ class WatchpointOperator:
watch_nodes = self._get_watch_node_with_basic_info(
    node_names=params.get('watch_nodes'),
    search_pattern=params.get('search_pattern'),
-     graph_name=params.get('graph_name'))
- watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, params.get('mode'))
+     graph_name=params.get('graph_name'),
+     rank_id=params.get('rank_id', 0))
+ watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, params.get('mode'), params.get('rank_id', 0))
metadata_stream.enable_recheck = watchpoint_stream.is_recheckable()
log.info("Update watchpoint with id: %d", watch_point_id)
return metadata_stream.get(['state', 'enable_recheck'])


- def _get_watch_node_with_basic_info(self, node_names, search_pattern=None, graph_name=None):
+ def _get_watch_node_with_basic_info(self, node_names, search_pattern=None, graph_name=None, rank_id=0):
""" """
Get watch node with basic info. Get watch node with basic info.


@@ -151,20 +150,21 @@ class WatchpointOperator:
node_names (list[str]): A list of node names.
search_pattern (dict): Get watch node with search pattern. Default: None
graph_name (str): The relative graph_name of the watched node. Default: None.
+ rank_id (int): The rank id.


Returns:
    list[NodeBasicInfo], a list of node basic infos.
"""
if not node_names:
    return []
- graph_name = self._graph_stream.validate_graph_name(graph_name)
+ graph_name = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).validate_graph_name(graph_name)
if search_pattern is not None:
-     watch_nodes = self._get_watch_nodes_by_search(node_names, search_pattern, graph_name)
+     watch_nodes = self._get_watch_nodes_by_search(node_names, search_pattern, graph_name, rank_id)
else:
-     watch_nodes = self._get_node_basic_infos(node_names, graph_name=graph_name)
+     watch_nodes = self._get_node_basic_infos(node_names, graph_name=graph_name, rank_id=rank_id)
return watch_nodes


- def _get_watch_nodes_by_search(self, node_names, search_pattern, graph_name):
+ def _get_watch_nodes_by_search(self, node_names, search_pattern, graph_name, rank_id):
    """
    Get watched leaf nodes by search name.


@@ -180,7 +180,7 @@ class WatchpointOperator:
    list[NodeBasicInfo], a list of node basic infos.
"""
search_pattern['graph_name'] = graph_name
- search_nodes = self._graph_stream.search_nodes(search_pattern)
+ search_nodes = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).search_nodes(search_pattern)
watch_node_names = set()
for name in node_names:
    names = self._get_watch_names_by_search(search_nodes, name)
@@ -260,7 +260,7 @@ class WatchpointOperator:
log.info("Delete watchpoint with id: %s", watch_point_id) log.info("Delete watchpoint with id: %s", watch_point_id)
return metadata_stream.get(['state', 'enable_recheck']) return metadata_stream.get(['state', 'enable_recheck'])


- def _get_node_basic_infos(self, node_names, graph_name=None):
+ def _get_node_basic_infos(self, node_names, graph_name=None, rank_id=0):
    """
    Get watch node info according to node names.


@@ -273,7 +273,7 @@ class WatchpointOperator:
""" """
if not node_names: if not node_names:
return [] return []
graph_stream = self._graph_stream
graph_stream = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id)
node_infos = [] node_infos = []
for node_name in node_names: for node_name in node_names:
node_info = graph_stream.get_node_basic_info(node_name, graph_name) node_info = graph_stream.get_node_basic_info(node_name, graph_name)
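Every `get_graph_handler_by_rank_id` call site above follows one pattern: the former single `_graph_stream` becomes a per-rank lookup on a multi-card stream. A minimal sketch of that dispatch, with illustrative names except for `get_graph_handler_by_rank_id` itself:

    class MultiCardGraphStreamSketch:
        """Illustrative only: holds one graph handler per logic card (rank)."""

        def __init__(self):
            self._handlers = {}  # rank_id -> per-card graph handler

        def register(self, rank_id, graph_handler):
            self._handlers[rank_id] = graph_handler

        def get_graph_handler_by_rank_id(self, rank_id=0):
            # Every per-graph query is routed through the rank first.
            if rank_id not in self._handlers:
                raise ValueError('Invalid rank_id: %s' % rank_id)
            return self._handlers[rank_id]

    # Usage mirrors the call sites above, e.g.:
    # graph_stream = stream.get_graph_handler_by_rank_id(rank_id)
    # node_info = graph_stream.get_node_basic_info(node_name, graph_name)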


+1 -1  mindinsight/ui/src/app.vue

@@ -26,7 +26,7 @@ limitations under the License.
<div class="cl-center" <div class="cl-center"
:class="showWarmText ? 'cl-center-height' : ''"> :class="showWarmText ? 'cl-center-height' : ''">


<router-view></router-view>
<router-view :key="$route.fullPath"></router-view>
</div> </div>
</div> </div>
</template> </template>


+8 -4  mindinsight/ui/src/components/debugger-tensor.vue

@@ -362,8 +362,9 @@ export default {
const params = {
  tensor_name: this.curRowObj.name,
  graph_name: this.curRowObj.graph_name,
+   rank_id: this.curRowObj.rank_id,
};
- RequestService.getTensorGraphData(params).then(
+ RequestService.getTensorGraphData(params, this.curRowObj.sessionId).then(
  (res) => {
    if (res && res.data && res.data.graph && res.data.graph.nodes && res.data.graph.nodes.length) {
      this.graphShow = true;
@@ -419,8 +420,9 @@ export default {
const params = {
  tensor_name: this.curRowObj.name,
  graph_name: this.curRowObj.graph_name,
+   rank_id: this.curRowObj.rank_id,
};
- RequestService.tensorHitsData(params).then(
+ RequestService.tensorHitsData(params, this.curRowObj.sessionId).then(
  (res) => {
    if (res && res.data && res.data.watch_points && res.data.watch_points.length) {
      this.leftDataShow = true;
@@ -995,11 +997,12 @@ export default {
shape: encodeURIComponent(shape),
tolerance: this.tolerance / 100,
graph_name: row.graph_name,
+ rank_id: row.rank_id,
};
if (loadingFlag) {
  this.loadingInstance = this.$loading(this.loadingOption);
}
- RequestService.tensorComparisons(params).then(
+ RequestService.tensorComparisons(params, row.sessionId).then(
  (res) => {
    if (res && res.data && res.data.tensor_value) {
      if (row.shape === '[]') {
@@ -1088,11 +1091,12 @@ export default {
shape: encodeURIComponent(shape),
graph_name: row.graph_name,
prev: this.gridType === 'preStep' ? true : false,
+ rank_id: row.rank_id,
};
if (loadingFlag) {
  this.loadingInstance = this.$loading(this.loadingOption);
}
- RequestService.tensors(params).then(
+ RequestService.tensors(params, row.sessionId).then(
  (res) => {
    if (row.shape === '[]') {
      this.showFilterInput = false;


+17 -5  mindinsight/ui/src/locales/en-us.json

@@ -24,7 +24,9 @@
"dataLoading": "Loading data...", "dataLoading": "Loading data...",
"notice": "Information", "notice": "Information",
"caseMode": "Not case sensitive", "caseMode": "Not case sensitive",
"all": "All"
"all": "All",
"details": "Details",
"delete": "Delete"
}, },
"symbols": { "symbols": {
"leftbracket": "(", "leftbracket": "(",
@@ -52,12 +54,14 @@
"operation": "Operation", "operation": "Operation",
"viewDashboard": "Training Dashboard", "viewDashboard": "Training Dashboard",
"viewProfiler": "Profiling", "viewProfiler": "Profiling",
"viewOfflineDebugger": "Offline Debugger",
"modelTraceback": "Model Lineage", "modelTraceback": "Model Lineage",
"dataTraceback": "Dataset Lineage", "dataTraceback": "Dataset Lineage",
"comparePlate": "Comparison Dashboard", "comparePlate": "Comparison Dashboard",
"disableProfilerTip": "Failed to view profiling because no profiler log is available.", "disableProfilerTip": "Failed to view profiling because no profiler log is available.",
"disableDashboardTip": "Failed to view training dashboard because no summary log or pb files are available.", "disableDashboardTip": "Failed to view training dashboard because no summary log or pb files are available.",
"disableParameterTip": "Failed to view parameter details because no lineage log is available.", "disableParameterTip": "Failed to view parameter details because no lineage log is available.",
"disableOfflineDebugger": "Failed to view offline debugger because no debugger log is available.",
"openNewTab": "Open Link in New Tab", "openNewTab": "Open Link in New Tab",
"paramDetails": "Parameter Details", "paramDetails": "Parameter Details",
"trainingParamDetails": "Training Parameter Details", "trainingParamDetails": "Training Parameter Details",
@@ -80,7 +84,12 @@
"tensorUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#tensor-visualization", "tensorUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#tensor-visualization",
"graphUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#computational-graph-visualization", "graphUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#computational-graph-visualization",
"dataProcessUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#dataset-graph-visualization", "dataProcessUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#dataset-graph-visualization",
"imageUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#image-visualization"
"imageUrl": "https://www.mindspore.cn/tutorial/training/en/master/advanced_use/dashboard.html#image-visualization",
"sessionLimit": "The number of sessions of the offline debugger exceeds the number of online sessions",
"sessionLimitNum": "At most 2 exist at the same time",
"sessionLists": "List of currently existing sessions",
"deleteSessionConfirm": "This operation will delete the current session, do you want to continue?",
"deleteSessionSuccess": "Delete session successfully!"
}, },
"modelTraceback": { "modelTraceback": {
"summaryPath": "Summary Path", "summaryPath": "Summary Path",
@@ -561,7 +570,7 @@
"terminate": "TERMINATE", "terminate": "TERMINATE",
"selectCondition": "Select a condition", "selectCondition": "Select a condition",
"inputStep": "Enter a step value", "inputStep": "Enter a step value",
"inputTip": "A positive integer less than 2147483648",
"inputTip": "A positive integer less than or equal to {total_step_num}",
"curHitNode": "Watch Point Hit List", "curHitNode": "Watch Point Hit List",
"backstageStatus": "The backend running status is ", "backstageStatus": "The backend running status is ",
"view": "View", "view": "View",
@@ -830,7 +839,9 @@
"allPositive": "he parameter value must be greater than 0.", "allPositive": "he parameter value must be greater than 0.",
"watchOverflow": "The asynchronous full overflow watching function must be enabled before the training starts." "watchOverflow": "The asynchronous full overflow watching function must be enabled before the training starts."
}, },
"paramValueTip": "Preset Value: {value}"
"paramValueTip": "Preset Value: {value}",
"logicCard": "Logic card",
"inpStepTip": "Step:0~{total_step_num}"
}, },
"explain": { "explain": {
"explain": "Model Explanation", "explain": "Model Explanation",
@@ -952,6 +963,7 @@
"5054B183": "Backend training is in progress or has ended. Please try again later", "5054B183": "Backend training is in progress or has ended. Please try again later",
"5054B184": "The operation is too fast, the backend service has been suspended.", "5054B184": "The operation is too fast, the backend service has been suspended.",
"5054B189": "Do not set the value repeatedly.", "5054B189": "Do not set the value repeatedly.",
"5054B083": "Failed to create the watchpoint. Do not use invalid rules."
"5054B083": "Failed to create the watchpoint. Do not use invalid rules.",
"5054B202": "The debugger offline server module was not found"
} }
} }

+17 -5  mindinsight/ui/src/locales/zh-cn.json

@@ -24,7 +24,9 @@
"dataLoading": "数据加载中", "dataLoading": "数据加载中",
"notice": "提示", "notice": "提示",
"caseMode": "不区分大小写", "caseMode": "不区分大小写",
"all": "全部"
"all": "全部",
"details": "详情",
"delete": "删除"
}, },
"symbols": { "symbols": {
"leftbracket": "(", "leftbracket": "(",
@@ -52,12 +54,14 @@
"operation": "操作", "operation": "操作",
"viewDashboard": "训练看板", "viewDashboard": "训练看板",
"viewProfiler": "性能分析", "viewProfiler": "性能分析",
"viewOfflineDebugger": "离线调试器",
"modelTraceback": "模型溯源", "modelTraceback": "模型溯源",
"dataTraceback": "数据溯源", "dataTraceback": "数据溯源",
"comparePlate": "对比看板", "comparePlate": "对比看板",
"disableProfilerTip": "无profiler日志,无法查看性能分析", "disableProfilerTip": "无profiler日志,无法查看性能分析",
"disableDashboardTip": "无summary日志或pb文件,无法查看训练看板", "disableDashboardTip": "无summary日志或pb文件,无法查看训练看板",
"disableParameterTip": "无lineage日志,无法查看参数详情", "disableParameterTip": "无lineage日志,无法查看参数详情",
"disableOfflineDebugger": "无Debugger日志,无法查看离线调试器",
"openNewTab": "打开新页签", "openNewTab": "打开新页签",
"paramDetails": "参数详情", "paramDetails": "参数详情",
"trainingParamDetails": "训练参数详情", "trainingParamDetails": "训练参数详情",
@@ -80,7 +84,12 @@
"tensorUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id8", "tensorUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id8",
"graphUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id5", "graphUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id5",
"dataProcessUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id6", "dataProcessUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id6",
"imageUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id7"
"imageUrl": "https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/dashboard.html#id7",
"sessionLimit": "离线调试器的session个数超过上线",
"sessionLimitNum": "最多同时存在2个",
"sessionLists": "目前存在的session列表",
"deleteSessionConfirm": "此操作将删除当前session, 是否继续?",
"deleteSessionSuccess": "删除session成功!"
}, },
"modelTraceback": { "modelTraceback": {
"summaryPath": "训练日志路径", "summaryPath": "训练日志路径",
@@ -560,7 +569,7 @@
"terminate": "结束", "terminate": "结束",
"selectCondition": "请选择条件", "selectCondition": "请选择条件",
"inputStep": "请输入轮次值", "inputStep": "请输入轮次值",
"inputTip": "小于2147483648的正整数",
"inputTip": "小于等于{total_step_num}的正整数",
"curHitNode": "命中的监测点", "curHitNode": "命中的监测点",
"backstageStatus": "后台运行状态是", "backstageStatus": "后台运行状态是",
"view": "查看", "view": "查看",
@@ -825,7 +834,9 @@
"allPositive": "此参数值必须大于0", "allPositive": "此参数值必须大于0",
"watchOverflow": "训练开始前需开启异步全量溢出监测功能" "watchOverflow": "训练开始前需开启异步全量溢出监测功能"
}, },
"paramValueTip": "设置值为:{value}"
"paramValueTip": "设置值为:{value}",
"logicCard": "逻辑卡",
"inpStepTip": "可输入当前轮次:0~{total_step_num}"
}, },
"explain": { "explain": {
"explain": "模型解释", "explain": "模型解释",
@@ -947,6 +958,7 @@
"5054B183": "后台训练运行中,请稍后重试", "5054B183": "后台训练运行中,请稍后重试",
"5054B184": "操作过快,后台服务已暂停。", "5054B184": "操作过快,后台服务已暂停。",
"5054B189": "请勿重复设置。", "5054B189": "请勿重复设置。",
"5054B083": "监测点创建失败,请勿使用已失效规则。"
"5054B083": "监测点创建失败,请勿使用已失效规则。",
"5054B202": "未找到调试器离线服务器模块"
} }
} }

+436 -163  mindinsight/ui/src/mixins/debugger-mixin.vue  (file diff suppressed because it is too large)


+4 -0  mindinsight/ui/src/router.js

@@ -157,6 +157,10 @@ export default new Router({
path: '/debugger',
component: () => import('./views/debugger/debugger.vue'),
},
+ {
+   path: '/offline-debugger',
+   component: () => import('./views/debugger/debugger.vue'),
+ },
{
  path: '/explain',
  component: () => import('./views/explain/summary-list.vue'),


+8 -1  mindinsight/ui/src/services/fetcher.js

@@ -62,7 +62,14 @@ axios.interceptors.response.use(
const errorData = i18n.messages[i18n.locale].error;
const path = router.currentRoute.path;

- if (path === '/debugger') {
+ if (path === '/debugger' || path === '/offline-debugger') {
+   if (
+     error.response &&
+     error.response.data &&
+     error.response.data.error_code === '5054B281'
+   ) {
+     router.push('/');
+   }
  return Promise.reject(error);
}
// error returned by backend


+52 -41  mindinsight/ui/src/services/request-service.js

@@ -309,55 +309,74 @@ export default {
});
},
// debugger
- pollData(params) {
+ getSession(params) {
+   return axios({
+     method: 'post',
+     url: 'v1/mindinsight/debugger/sessions',
+     data: params,
+   });
+ },
+ deleteSession(sessionId) {
+   return axios({
+     method: 'post',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/delete`,
+   });
+ },
+ checkSessions() {
  return axios({
    method: 'get',
-     url: 'v1/mindinsight/debugger/poll-data',
+     url: `v1/mindinsight/debugger/sessions`,
+   });
+ },
+ pollData(params, sessionId) {
+   return axios({
+     method: 'get',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/poll-data`,
    params: params,
    headers: {
      ignoreError: true,
    },
  });
},
- retrieve(params) {
+ retrieve(params, sessionId) {
  return axios({
    method: 'post',
-     url: 'v1/mindinsight/debugger/retrieve',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/retrieve`,
    data: params,
  });
},
- createWatchpoint(params) {
+ createWatchpoint(params, sessionId) {
  return axios({
    method: 'post',
-     url: 'v1/mindinsight/debugger/create-watchpoint',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/create-watchpoint`,
    data: params,
  });
},
- updateWatchpoint(params) {
+ updateWatchpoint(params, sessionId) {
  return axios({
    method: 'post',
-     url: 'v1/mindinsight/debugger/update-watchpoint',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/update-watchpoint`,
    data: params,
  });
},
- deleteWatchpoint(params) {
+ deleteWatchpoint(params, sessionId) {
  return axios({
    method: 'post',
-     url: 'v1/mindinsight/debugger/delete-watchpoint',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/delete-watchpoint`,
    data: params,
  });
},
- control(params) {
+ control(params, sessionId) {
  return axios({
    method: 'post',
-     url: 'v1/mindinsight/debugger/control',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/control`,
    data: params,
  });
},
- search(params) {
+ search(params, sessionId) {
  return axios({
    method: 'get',
-     url: 'v1/mindinsight/debugger/search',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/search`,
    params: params,
  });
},
@@ -368,43 +387,43 @@ export default {
  params: params,
});
},
- tensorComparisons(params) {
+ tensorComparisons(params, sessionId) {
  return axios({
    method: 'get',
-     url: 'v1/mindinsight/debugger/tensor-comparisons',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-comparisons`,
    params: params,
  });
},
- tensors(params) {
+ tensors(params, sessionId) {
  return axios({
    method: 'get',
-     url: 'v1/mindinsight/debugger/tensors',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/tensors`,
    params: params,
  });
},
- retrieveTensorHistory(params) {
+ retrieveTensorHistory(params, sessionId) {
  return axios({
    method: 'post',
-     url: 'v1/mindinsight/debugger/tensor-history',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-history`,
    data: params,
  });
},
- queryConditions(trainId) {
+ queryConditions(sessionId) {
  return axios({
    method: 'get',
-     url: `v1/mindinsight/conditionmgr/train-jobs/${trainId}/condition-collections`,
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/condition-collections`,
  });
},
- recheckWatchPoints() {
+ recheckWatchPoints(sessionId) {
  return axios({
    method: 'post',
-     url: `v1/mindinsight/debugger/recheck`,
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/recheck`,
  });
},
- searchWatchpointHits(params) {
+ searchWatchpointHits(params, sessionId) {
  return axios({
    method: 'post',
-     url: `v1/mindinsight/debugger/search-watchpoint-hits`,
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/search-watchpoint-hits`,
    data: params,
  });
},
@@ -447,33 +466,25 @@ export default {
  data: params,
});
},
- tensorHitsData(params) {
+ tensorHitsData(params, sessionId) {
  return axios({
    method: 'get',
-     url: 'v1/mindinsight/debugger/tensor-hits',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-hits`,
    params: params,
  });
},
- getTensorGraphData(params) {
+ getTensorGraphData(params, sessionId) {
  return axios({
    method: 'get',
-     url: 'v1/mindinsight/debugger/tensor-graphs',
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/tensor-graphs`,
    params: params,
  });
},
- getCpuUtilization(params) {
-   return axios({
-     method: 'post',
-     url: 'v1/mindinsight/profile/minddata-cpu-utilization-summary',
-     params: params.params,
-     data: params.body,
-   });
- },
- setRecommendWatchPoints(params) {
+ setRecommendWatchPoints(params, sessionId) {
  return axios({
    method: 'post',
-     url: `v1/mindinsight/conditionmgr/train-jobs/${params.trainId}/set-recommended-watch-points`,
-     data: params.body,
+     url: `v1/mindinsight/debugger/sessions/${sessionId}/set-recommended-watch-points`,
+     data: params,
  });
},
// memory-detail apis
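Every debugger route is now scoped under /sessions/{sessionId}, so each request carries the session it belongs to. A minimal sketch of the same calls from Python (host and port are assumptions; the endpoints are the ones registered above):

    import requests

    BASE = 'http://localhost:8080/v1/mindinsight/debugger'  # assumed host/port
    SESSION_ID = '0'  # an existing session id

    # Retrieve the graph nodes hit by watchpoint 0, as queryGraphByWatchpoint does.
    payload = {'mode': 'watchpoint', 'params': {'watch_point_id': 0}}
    graph = requests.post(f'{BASE}/sessions/{SESSION_ID}/retrieve', json=payload).json()

    # Poll for state changes on the same session.
    events = requests.get(f'{BASE}/sessions/{SESSION_ID}/poll-data').json()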


+129 -24  mindinsight/ui/src/views/debugger/debugger.vue

@@ -46,6 +46,17 @@ limitations under the License.
</div>
<div class="content"
     v-show="radio1==='tree'">
+   <div class="node-type">
+     <div class="label">{{ $t('debugger.logicCard') }}</div>
+     <el-select v-model="logicCard.value"
+                @change="logicCardChange"
+                :disabled="!trainId">
+       <el-option v-for="item in logicCard.options"
+                  :key="item"
+                  :value="item">
+       </el-option>
+     </el-select>
+   </div>
  <div class="node-type">
    <div class="label">{{ $t('debugger.graphFile') }}</div>
    <el-select v-model="graphFiles.value"
@@ -209,6 +220,17 @@ limitations under the License.
</div>
<div class="content"
     v-show="radio1==='hit'">
+   <div class="node-type">
+     <div class="label">{{ $t('debugger.logicCard') }}</div>
+     <el-select v-model="logicCard.value"
+                :disabled="!trainId"
+                @change="logicCardChange();searchWatchpointHits(true);">
+       <el-option v-for="item in logicCard.options"
+                  :key="item"
+                  :value="item">
+       </el-option>
+     </el-select>
+   </div>
  <div class="hit-list-wrap">
    <el-table class="watchpoint-table"
              :data="watchPointHits"
@@ -261,7 +283,7 @@ limitations under the License.
<div class="step"> <div class="step">
<el-tooltip class="item" <el-tooltip class="item"
effect="light" effect="light"
:content="$t('debugger.inputTip')"
:content="$t('debugger.inputTip',{total_step_num:metadata.total_step_num})"
placement="top-start"> placement="top-start">
<el-input v-model="step" <el-input v-model="step"
:placeholder="$t('debugger.inputStep')" :placeholder="$t('debugger.inputStep')"
@@ -330,6 +352,25 @@ limitations under the License.
v-show="metadata.state === state.sending"> v-show="metadata.state === state.sending">
<i class="el-icon-time"></i> <i class="el-icon-time"></i>
</el-tooltip> </el-tooltip>
<i class="el-icon-edit"
v-if="trainId && !isShowInp"
:title="$t('debugger.inpStepTip',{total_step_num:metadata.total_step_num})"
@click="editStep"></i>
<el-tooltip class="item"
effect="light"
:content="$t('debugger.inputTip',{total_step_num:metadata.total_step_num})"
placement="top-start"
v-if="trainId && isShowInp">
<el-input v-model="newStep"
type="text"
@input="newStepChange"></el-input>
</el-tooltip>
<i class="el-icon-check"
v-if="trainId && isShowInp"
@click="saveStepValue"></i>
<i class="el-icon-close"
v-if="trainId && isShowInp"
@click="isShowInp=false"></i>
</div> </div>
<div class="svg-wrap" <div class="svg-wrap"
:class="{collapse: collapseTable}"> :class="{collapse: collapseTable}">
@@ -505,7 +546,7 @@ limitations under the License.
:close-on-click-modal="false" :close-on-click-modal="false"
:modal-append-to-body="false" :modal-append-to-body="false"
class="creat-watch-point-dialog" class="creat-watch-point-dialog"
width="890px">
width="930px">


<div class="conditions-container"> <div class="conditions-container">
<div class="condition-item" <div class="condition-item"
@@ -787,6 +828,11 @@ export default {
value: '',
graphs: {},
},
+ logicCard: {
+   options: [],
+   value: '',
+ },
+ devices: [],
allGraphData: {}, // Graph Original input data
firstFloorNodes: [], // ID array of the first layer node.
nodesCountLimit: 1500, // Maximum number of sub-nodes in a namespace.
@@ -830,7 +876,7 @@ export default {
expandKeys: [],
isHitIntoView: true,
searchedWord: '',
- trainId: '',
+ trainId: this.$route.query.dir,
recommendWatchPointDialog: false,
hitsOutdated: false,
conflictFlag: false,
@@ -859,6 +905,9 @@ export default {
},
loadingInstance: null,
paramErrorMsg: '',
+ sessionId: this.$route.query.sessionId,
+ isShowInp: false,
+ newStep: '',
};
},
components: {debuggerTensor, tree},
@@ -866,6 +915,12 @@ export default {
mounted() {
  document.title = `${this.$t('debugger.debugger')}-MindInsight`;
  this.nodeTypes.label = this.$t('debugger.nodeType');
+   if (this.trainId) {
+     document.title = `${this.trainId}-${this.$t('debugger.debugger')}-MindInsight`;
+     this.retrieveAll();
+   } else {
+     this.getSession();
+   }
},
watch: {
  'metadata.state': {
@@ -896,7 +951,7 @@ export default {


if (newValue === this.state.waiting) {
  if (this.oldState === this.state.pending || oldValue === this.state.pending) {
-     this.loadNode(this.node, this.resolve);
+     this.retrieveAll();
  } else if (this.oldState === this.state.running || oldValue === this.state.running) {
    this.pagination.currentPage = 1;
    this.watchPointHits = [];
@@ -914,6 +969,8 @@ export default {
this.curRowObj.type = type;
this.curRowObj.curFileName = this.graphFiles.value;
this.curRowObj.step = this.metadata.step;
+ this.curRowObj.rank_id = this.logicCard.value;
+ this.curRowObj.sessionId = this.sessionId;
this.tensorCompareFlag = true;
},
closeTensor(tensor, graphName) {
@@ -922,6 +979,19 @@ export default {
  this.queryAllTreeData(tensor, true, graphName, true);
}
},
+ logicCardChange() {
+   this.graphFiles.options = JSON.parse(
+     JSON.stringify(this.devices.find((val) => val.rank_id === this.logicCard.value).graph_names),
+   );
+   if (this.graphFiles.options.length > 1) {
+     this.graphFiles.options.unshift(this.$t('debugger.all'));
+   }
+   this.graphFiles.value = this.graphFiles.options[0];
+   const device = this.devices.find((val) => val.rank_id === this.logicCard.value);
+   this.metadata.ip = device.server_ip;
+   this.metadata.device_name = device.device_id;
+   this.queryGraphByFile();
+ },
queryGraphByFile() {
  this.searchWord = '';
  this.nodeTypes.value = 'all';
@@ -931,12 +1001,13 @@ export default {
params: {
  watch_point_id: this.curWatchPointId ? this.curWatchPointId : 0,
  graph_name: this.graphFiles.value,
+   rank_id: this.logicCard.value,
},
};
if (this.graphFiles.value === this.$t('debugger.all')) {
  delete params.params.graph_name;
}
- RequestService.retrieve(params).then(
+ RequestService.retrieve(params, this.sessionId).then(
  (res) => {
    if (res.data && res.data.metadata) {
      this.dealMetadata(res.data.metadata);
@@ -975,6 +1046,7 @@ export default {
  d3.select('#graph svg').remove();
  this.selectedNode.name = '';
  this.dealGraphData(JSON.parse(JSON.stringify(graph.nodes)));
+   this.tableData = [];
}
},
(err) => {
@@ -1015,11 +1087,12 @@ export default {
watch_nodes: watchNodes,
mode: type ? 1 : 0,
graph_name: this.graphFiles.value,
+ rank_id: this.logicCard.value,
};
if (this.graphFiles.value === this.$t('debugger.all')) {
  delete params.graph_name;
}
- RequestService.updateWatchpoint(params).then(
+ RequestService.updateWatchpoint(params, this.sessionId).then(
  (res) => {
    this.defaultCheckedArr = this.$refs.tree.getCheckedKeys();
    if (res && res.data && res.data.metadata && res.data.metadata.enable_recheck !== undefined) {
@@ -1049,12 +1122,16 @@ export default {
queryGraphByWatchpoint(id) {
  const params = {
    mode: 'watchpoint',
-     params: {watch_point_id: id, graph_name: this.graphFiles.value},
+     params: {
+       watch_point_id: id,
+       graph_name: this.graphFiles.value,
+       rank_id: this.logicCard.value,
+     },
  };
  if (this.graphFiles.value === this.$t('debugger.all')) {
    delete params.params.graph_name;
  }
-   RequestService.retrieve(params).then(
+   RequestService.retrieve(params, this.sessionId).then(
    (res) => {
      if (res.data && res.data.graph) {
        const graph = res.data.graph;
@@ -1306,11 +1383,12 @@ export default {
level: 'node',
name: this.selectedNode.name.replace('_unfold', ''),
graph_name: this.graphFiles.value,
+ rank_id: this.logicCard.value,
};
if (this.graphFiles.value === this.$t('debugger.all')) {
  delete params.graph_name;
}
- RequestService.control(params).then(
+ RequestService.control(params, this.sessionId).then(
  (res) => {
    if (res && res.data) {
    }
@@ -1387,12 +1465,13 @@ export default {
node_type: type,
single_node: false,
graph_name: this.graphFiles.value,
+ rank_id: this.logicCard.value,
};
if (this.graphFiles.value === this.$t('debugger.all')) {
  delete params.params.graph_name;
}
}
- RequestService.retrieve(params)
+ RequestService.retrieve(params, this.sessionId)
  .then(
    (response) => {
      if (response && response.data && response.data.graph) {
@@ -1560,7 +1639,12 @@ export default {
  graphName = key.split('/')[0];
  key = key.replace(`${graphName}/`, '');
}
- const obj = {name: key, IOType: 'output', graph_name: graphName};
+ const obj = {
+   name: key,
+   IOType: 'output',
+   graph_name: graphName,
+   rank_id: this.logicCard.value,
+ };
IOInfo.push(obj);
this.selectedNode.outputNum++;
});
@@ -1572,7 +1656,12 @@ export default {
  graphName = key.split('/')[0];
  key = key.replace(`${graphName}/`, '');
}
- const obj = {name: key, IOType: 'input', graph_name: graphName};
+ const obj = {
+   name: key,
+   IOType: 'input',
+   graph_name: graphName,
+   rank_id: this.logicCard.value,
+ };
IOInfo.push(obj);
this.selectedNode.inputNum++;
});
@@ -1606,11 +1695,7 @@ export default {
  `translate(${this.graph.transform.x},` + `${this.graph.transform.y}) scale(${this.graph.transform.k})`,
);

- const transitionTime = Math.min(
-   Math.abs(screenChange.x) * 2,
-   Math.abs(screenChange.y) * 2,
-   needDelay ? 800 : 0,
- );
+ const transitionTime = Math.min(Math.abs(screenChange.x) * 2, Math.abs(screenChange.y) * 2, needDelay ? 800 : 0);

this.graph.dom.style.transition = `${transitionTime / 1000}s`;
this.graph.dom.style['transition-timing-function'] = 'linear';
@@ -1829,8 +1914,8 @@ export default {
  height: calc(100% - 145px);
}
.deb-wrap .left-wrap .left .content .node-type {
-   height: 50px;
-   padding: 15px 15px 0 15px;
+   height: 40px;
+   padding: 10px 15px 0 15px;
}
.deb-wrap .left-wrap .left .content .node-type .label {
  display: inline-block;
@@ -1855,7 +1940,7 @@ export default {
  font-size: 12px;
}
.deb-wrap .left-wrap .left .content .tree-wrap {
-   height: calc(70% - 155px);
+   height: calc(70% - 172px);
  overflow-y: auto;
  padding: 0 15px 15px;
  position: relative;
@@ -1973,12 +2058,13 @@ export default {
  color: red;
}
.deb-wrap .left-wrap .left .content .hit-list-wrap {
-   height: 100%;
+   height: calc(100% - 40px);
  padding: 10px;
}
.deb-wrap .left-wrap .left .content .hit-list-wrap .watchpoint-table {
  max-height: calc(100% - 45px);
  overflow: auto;
+   margin-top: 10px;
}
.deb-wrap .left-wrap .left .content .hit-list-wrap .el-table::before {
  height: 0;
@@ -2096,7 +2182,7 @@ export default {
  /* Opera */
}
.deb-wrap .right .header {
-   padding: 15px;
+   line-height: 51px;
  border-bottom: 1px solid #ebeef5;
  position: relative;
  background: #fff;
@@ -2113,6 +2199,25 @@ export default {
.deb-wrap .right .header .item + .item {
  margin-left: 15px;
}
+ .deb-wrap .right .header .el-icon-edit {
+   margin-left: 5px;
+ }
+ .deb-wrap .right .header i {
+   font-size: 18px;
+   margin: 0 2px;
+   color: #00a5a7;
+   cursor: pointer;
+ }
+ .deb-wrap .right .header .el-icon-close {
+   color: #f56c6c;
+ }
+ .deb-wrap .right .header .el-input {
+   width: 45px;
+ }
+ .deb-wrap .right .header .el-input input {
+   padding: 0;
+   text-align: center;
+ }
.deb-wrap .right .header .tooltip {
  margin-left: 5px;
  cursor: pointer;
@@ -2343,13 +2448,13 @@ export default {
  display: none;
}
.deb-wrap .creat-watch-point-dialog .conditions-container .collection {
-   width: 200px;
+   width: 210px;
}
.deb-wrap .creat-watch-point-dialog .conditions-container .condition,
.deb-wrap .creat-watch-point-dialog .conditions-container .param,
.deb-wrap .creat-watch-point-dialog .conditions-container .param-value {
  margin-left: 10px;
-   width: 200px;
+   width: 210px;
}
.deb-wrap .creat-watch-point-dialog .conditions-container .percent-sign {
  display: inline-block;
display: inline-block; display: inline-block;


+166 -4  mindinsight/ui/src/views/train-manage/summary-manage.vue

@@ -96,6 +96,16 @@ limitations under the License.
:title="$t('summaryManage.disableProfilerTip')"> :title="$t('summaryManage.disableProfilerTip')">
{{$t('summaryManage.viewProfiler')}} {{$t('summaryManage.viewProfiler')}}
</span> </span>
<span class="menu-item operate-btn"
v-if="scope.row.viewOfflineDebugger"
@contextmenu.prevent="rightClick(scope.row, $event, 2)"
@click.stop="goToOfflineDebugger(scope.row)">
{{$t('summaryManage.viewOfflineDebugger')}} </span>
<span class="menu-item operate-btn button-disable"
v-else
:title="$t('summaryManage.disableOfflineDebugger')">
{{$t('summaryManage.viewOfflineDebugger')}}
</span>
<span class="menu-item operate-btn" <span class="menu-item operate-btn"
v-if="scope.row.paramDetails" v-if="scope.row.paramDetails"
@click.stop="showModelDialog(scope.row)"> @click.stop="showModelDialog(scope.row)">
@@ -157,6 +167,45 @@ limitations under the License.
<li @click="doRightClick()">{{$t('summaryManage.openNewTab')}}</li> <li @click="doRightClick()">{{$t('summaryManage.openNewTab')}}</li>
</ul> </ul>
</div> </div>
<el-dialog :visible.sync="debuggerDialog.showDialogModel"
width="50%"
:close-on-click-modal="false"
class="details-data-list">
<span slot="title">
<span class="sessionMsg">{{ debuggerDialog.title }}</span>
<el-tooltip placement="right"
effect="light"
popper-class="legend-tip"
:content="$t('summaryManage.sessionLimitNum')">
<i class="el-icon-info"></i>
</el-tooltip>
</span>
<div class="session-title">{{ $t('summaryManage.sessionLists') }}</div>
<el-table :data="debuggerDialog.trainJobs">
<el-table-column width="50"
type=index
:label="$t('summaryManage.sorting')">
</el-table-column>
<el-table-column min-width="300"
prop="relative_path"
:label="$t('summaryManage.summaryPath')"
show-overflow-tooltip>
</el-table-column>
<!-- operate -->
<el-table-column prop="operate"
:label="$t('summaryManage.operation')"
class-name="operate-container">
<template slot-scope="scope">
<span class="menu-item operate-btn first-btn"
@click="deleteSession(scope.row.session_id)">
{{$t('public.delete')}} </span>
<span class="menu-item operate-btn first-btn"
@click="viewSession(scope.row)">
{{$t('debugger.view')}} </span>
</template>
</el-table-column>
</el-table>
</el-dialog>
</div> </div>
</template> </template>


@@ -223,7 +272,12 @@ export default {
type: 0,
},
tableDom: null,
- operateWidth: localStorage.getItem('milang') === 'en-us' ? 400 : 290,
+ operateWidth: localStorage.getItem('milang') === 'en-us' ? 550 : 400,
+ debuggerDialog: {
+   title: this.$t('summaryManage.sessionLimit'),
+   showDialogModel: false,
+   trainJobs: [],
+ },
};
},
computed: {},
@@ -286,6 +340,7 @@ export default {
i.update_time = i.update_time ? i.update_time : '--';
i.viewProfiler = i.profiler_dir && i.profiler_dir.length;
i.viewDashboard = i.summary_files || i.graph_files || i.lineage_files;
+ i.viewOfflineDebugger = i.dump_dir;
i.paramDetails = i.lineage_files;
});
this.currentFolder = res.data.name ? res.data.name : '--';
@@ -363,7 +418,83 @@ export default {
},
});
},

+ /**
+  * go to Offline Debugger
+  * @param {Object} row select row
+  */
+ goToOfflineDebugger(row) {
+   this.contextMenu.show = false;
+   const debuggerDir = row.dump_dir;
+   const params = {
+     session_type: 'OFFLINE',
+     dump_dir: debuggerDir,
+   };
+   this.getSessionId(params).then((value) => {
+     if (value !== undefined) {
+       this.$router.push({
+         path: '/offline-debugger',
+         query: {
+           dir: debuggerDir,
+           sessionId: value,
+         },
+       });
+     }
+   });
+ },
+ getSessionId(params) {
+   return RequestService.getSession(params).then(
+     (res) => {
+       if (res && res.data) {
+         const sessionId = res.data;
+         return sessionId;
+       }
+     },
+     (error) => {
+       if (error && error.response && error.response.data && error.response.data.error_code === '5054B280') {
+         this.checkSessions();
+       }
+     },
+   );
+ },
+ deleteSession(sessionId) {
+   this.$confirm(this.$t('summaryManage.deleteSessionConfirm'), this.$t('public.notice'), {
+     confirmButtonText: this.$t('public.sure'),
+     cancelButtonText: this.$t('public.cancel'),
+     type: 'warning',
+   }).then(() => {
+     RequestService.deleteSession(sessionId).then((res) => {
+       this.$message({
+         type: 'success',
+         message: this.$t('summaryManage.deleteSessionSuccess'),
+       });
+       this.checkSessions();
+     });
+   });
+ },
+ checkSessions() {
+   RequestService.checkSessions().then((res) => {
+     if (res && res.data && res.data.train_jobs) {
+       const trainJobs = res.data.train_jobs;
+       this.debuggerDialog.trainJobs = Object.keys(trainJobs).map((val) => {
+         return {
+           relative_path: decodeURIComponent(val),
+           session_id: trainJobs[val],
+         };
+       });
+       this.debuggerDialog.showDialogModel = true;
+     }
+   });
+ },
+ viewSession(row) {
+   const dir = row.relative_path;
+   this.$router.push({
+     path: '/offline-debugger',
+     query: {
+       dir,
+       sessionId: row.session_id,
+     },
+   });
+ },
rightClick(row, event, type) {
  const maxWidth = 175;
  this.contextMenu.data = row;
@@ -380,7 +511,28 @@ export default {
if (!row) {
  return;
}
- if (this.contextMenu.type) {
+ if (this.contextMenu.type === 2) {
+   // open offline debugger
+   this.contextMenu.show = false;
+   const debuggerDir = row.dump_dir;
+   const params = {
+     session_type: 'OFFLINE',
+     dump_dir: debuggerDir,
+   };
+   this.getSessionId(params).then((value) => {
+     if (value !== undefined) {
+       const routeUrl = this.$router.resolve({
+         path: '/offline-debugger',
+         query: {
+           dir: debuggerDir,
+           sessionId: value,
+         },
+       });
+       window.open(routeUrl.href, '_blank');
+     }
+   });
+ } else if (this.contextMenu.type === 1) {
+   // open profiling
  this.contextMenu.show = false;
  const profilerDir = encodeURIComponent(row.profiler_dir);
  const trainId = encodeURIComponent(row.train_id);
@@ -400,7 +552,7 @@ export default {
  },
});
window.open(routeUrl.href, '_blank');
- } else {
+ } else { // open training dashboard
  this.contextMenu.show = false;
  const trainId = encodeURIComponent(row.train_id);


@@ -693,6 +845,16 @@ export default {
#cl-summary-manage .details-data-list .el-dialog__body .details-data-title {
  margin-bottom: 20px;
}
+ #cl-summary-manage .details-data-list .sessionMsg {
+   color: #333;
+   font-weight: bold;
+   font-size: 16px;
+   margin-right: 5px;
+ }
+ #cl-summary-manage .details-data-list .session-title {
+   margin-bottom: 10px;
+   color: #333;
+ }
#cl-summary-manage .is-disabled.custom-btn {
  background-color: #f5f5f6;
  border: 1px solid #dfe1e6 !important;
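Opening the offline debugger is therefore a create-or-recover flow: try to create a session, and when error code 5054B280 reports the session limit (at most 2 at the same time), list the existing sessions so one can be reused or deleted. A compact sketch of the same flow outside the UI; endpoints and the error code come from the code above, while host and port are assumptions:

    import requests

    BASE = 'http://localhost:8080/v1/mindinsight/debugger'  # assumed host/port

    def open_offline_session(dump_dir):
        resp = requests.post(f'{BASE}/sessions',
                             json={'session_type': 'OFFLINE', 'dump_dir': dump_dir})
        if resp.ok:
            return resp.json()  # the session id for /offline-debugger?sessionId=...
        if resp.json().get('error_code') == '5054B280':
            # Session limit reached: each entry maps a relative path to its
            # session id and can be released via POST .../sessions/{id}/delete.
            print(requests.get(f'{BASE}/sessions').json().get('train_jobs'))
        return None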


mindinsight/backend/conditionmgr/__init__.py → mindinsight/utils/folder_analyzer.py

@@ -1,4 +1,4 @@
- # Copyright 2020 Huawei Technologies Co., Ltd
+ # Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,15 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
- """Module init file."""
- from mindinsight.backend.conditionmgr.conditionmgr_api import init_module as init_query_module
+ """Train job register."""


- def init_module(app):
-     """
-     Init module entry.
-
-     Args:
-         app (Flask): A Flask instance.
-     """
-     init_query_module(app)
+ class FolderAnalyzer:
+     """Train job register. The subclass should implement the analyze method and return update info."""
+     def analyze(self, entry, summary_base_dir, relative_path):
+         """Analyze file."""

+3 -1  requirements.txt

@@ -18,4 +18,6 @@ six>=1.12.0
Werkzeug>=1.0.0
pandas>=1.0.4
yapf>=0.30.0
- grpcio>=1.27.3
+ treelib>=1.6.1
+ grpcio>=1.27.3
+ XlsxWriter>=1.2.9

+7 -6  tests/st/func/debugger/conftest.py

@@ -1,4 +1,4 @@
- # Copyright 2020 Huawei Technologies Co., Ltd
+ # Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,13 +25,15 @@ from mindinsight.conf import settings
from mindinsight.datavisual.utils import tools
from mindinsight.debugger.proto import ms_graph_pb2
from mindinsight.debugger.stream_handler.graph_handler import GraphHandler
+ from mindinsight.debugger.session_manager import SessionManager

GRAPH_PROTO_FILE = os.path.join(
    os.path.dirname(__file__), '../../../utils/resource/graph_pb/lenet.pb'
)

- DEBUGGER_BASE_URL = '/v1/mindinsight/debugger'
- DEBUGGER_EXPECTED_RESULTS = os.path.join(os.path.dirname(__file__), 'expect_results')
+ DEBUGGER_BASE_URL = '/v1/mindinsight/debugger/sessions/0/'
+ DEBUGGER_TEST_BASE_DIR = os.path.dirname(__file__)
+ DEBUGGER_EXPECTED_RESULTS = os.path.join(DEBUGGER_TEST_BASE_DIR, 'expect_results')


def init_graph_handler():
@@ -51,14 +53,13 @@ def init_graph_handler():
@pytest.fixture(scope='session')
def app_client():
    """This fixture is flask server."""
-     packages = ["mindinsight.backend.debugger", "mindinsight.backend.conditionmgr"]
+     packages = ["mindinsight.backend.debugger"]
    settings.ENABLE_DEBUGGER = True


mock_obj = Mock(return_value=packages)
tools.find_app_package = mock_obj

from mindinsight.backend.application import APP
- from mindinsight.backend.debugger.debugger_api import BACKEND_SERVER
APP.response_class = Response
client = APP.test_client()
original_val = settings.ENABLE_RECOMMENDED_WATCHPOINTS
@@ -67,4 +68,4 @@ def app_client():
    yield client
finally:
    settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_val
-     BACKEND_SERVER.stop()
+     SessionManager.get_instance().online_session.stop()

+20 -0  tests/st/func/debugger/debugger_services/__init__.py

@@ -0,0 +1,20 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
Test debugger services.
Usage:
pytest tests/st/func/debugger
"""

+141 -0  tests/st/func/debugger/debugger_services/mock_dbg_services.py

@@ -0,0 +1,141 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The module DbgServices provides offline debugger APIs.
"""
from unittest.mock import MagicMock

import numpy as np

import mindinsight
from mindinsight.debugger.common.log import LOGGER as log


def get_version():
"""Get version."""
return mindinsight.__version__


class DbgServices:
"""
DbgServices.

Args:
dump_file_path (str): dir where the dump files are saved.
"""
def __init__(self, dump_file_path, verbose=True):
self._verbose = verbose
self.dump_file_path = dump_file_path
self.dbg_instance = MagicMock()
self._watchpoints = {}
self.print_mes("in Python __init__, file path is {}".format(dump_file_path))
self.version = get_version()
self.initialized = False
self.is_sync = True

def print_mes(self, mes):
"""Print message."""
if self._verbose:
log.info(mes)

def initialize(self, net_name, is_sync_mode):
"""Initialize."""
self.print_mes(" Python Initialize dump_file_path: {}, is_sync: {}".format(net_name, is_sync_mode))
self.initialized = True

def add_watchpoint(self, watchpoint_id, watch_condition, check_node_list, parameter_list):
"""Add watchpoint."""
self.print_mes("Add watchpoint with watchpoint id: {}".format(watchpoint_id))
self._watchpoints[watchpoint_id] = {'watch_condition': watch_condition,
'check_nodes': check_node_list,
'parameter_list': parameter_list}
return 0

def remove_watchpoint(self, watchpoint_id):
"""Remove watchpoints."""
self.print_mes("Remove watchpoint with watchpoint id: {}".format(watchpoint_id))
return self._watchpoints.pop(watchpoint_id)

def check_watchpoints(self, iteration):
"""Check watchpoints."""
self.print_mes("Check watchpoint at iteration: {}".format(iteration))
watch_hits = []
for watchpoint_id, watchpoint in self._watchpoints.items():
# add param hit info
for param in watchpoint.get('parameter_list'):
param.hit = True
param.value = 0.0
for watch_node_name, node_info in watchpoint.get('check_nodes').items():
for device_id in node_info.get('device_id'):
hit = WatchpointHit(watch_node_name,
0,
watchpoint.get('watch_condition'),
watchpoint_id,
watchpoint.get('parameter_list'),
0,
device_id)
watch_hits.append(hit)

return watch_hits

def read_tensors(self, info):
"""Read tensor values."""
value = np.asarray(list(range(12)), dtype=np.int32).tobytes()
info_list_inst = []
for _ in info:
tensor_data = TensorData(value, len(value), 4, [2, 2, 3])
info_list_inst.append(tensor_data)
return info_list_inst


class TensorInfo:
"""Tensor Information."""
def __init__(self, node_name, slot, iteration, device_id, is_parameter):
self.node_name = node_name
self.slot = slot
self.iteration = iteration
self.device_id = device_id
self.is_parameter = is_parameter


class TensorData:
"""Tensor data structure."""
def __init__(self, data_ptr, data_size, dtype, shape):
self.data_ptr = data_ptr
self.data_size = data_size
self.dtype = dtype
self.shape = shape


class Parameter:
"""Parameter structure."""
def __init__(self, name, disabled, value, hit=False, actual_value=0.0):
self.name = name
self.disabled = disabled
self.value = value
self.hit = hit
self.actual_value = actual_value


class WatchpointHit:
"""Watchpoint hit structure."""
def __init__(self, name, slot, condition, watchpoint_id, parameters, error_code, device_id):
self.name = name
self.slot = slot
self.condition = condition
self.watchpoint_id = watchpoint_id
self.parameters = parameters
self.error_code = error_code
self.device_id = device_id
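
For orientation, a minimal sketch of how this mock mirrors the offline dbg_services API end to end. The watchpoint id, condition id, and node name below are illustrative placeholders, not fixtures from this change:

from tests.st.func.debugger.debugger_services.mock_dbg_services import (
    DbgServices, Parameter, TensorInfo)

dbg = DbgServices(dump_file_path='/tmp/dump_files', verbose=False)
dbg.initialize(net_name='Lenet', is_sync_mode=True)

# Watch one node on one device; condition id 17 is a placeholder.
params = [Parameter(name='max_gt', disabled=False, value=1.0)]
dbg.add_watchpoint(watchpoint_id=1, watch_condition=17,
                   check_node_list={'Default/Conv-op12': {'device_id': [0]}},
                   parameter_list=params)

hits = dbg.check_watchpoints(iteration=0)  # one WatchpointHit per node/device
tensors = dbg.read_tensors([TensorInfo('Default/Conv-op12', 0, 0, 0, False)])
assert len(hits) == 1 and tensors[0].shape == [2, 2, 3]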

+ 77
- 0
tests/st/func/debugger/debugger_services/test_debugger_services.py View File

@@ -0,0 +1,77 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
Test query debugger services.
Usage:
pytest tests/st/func/debugger
"""

import os
import shutil
from unittest import mock

import pytest

from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_services.debugger_server_factory import \
DebuggerServerFactory, DebuggerServerContext
from tests.st.func.debugger.utils import build_dump_file_structure
from tests.st.func.debugger.debugger_services import mock_dbg_services


class TestDebuggerServerFactory:
"""Test debugger on Ascend backend."""

@classmethod
def setup_class(cls):
"""Setup class."""
cls.debugger_tmp_dir, cls.dump_files_dir = build_dump_file_structure()
cls._dbg_dir = os.path.join(cls.dump_files_dir, 'Ascend/sync')
cls._dbg_server_factory = DebuggerServerFactory()

@classmethod
def teardown_class(cls):
"""Run after test this class."""
if os.path.exists(cls.debugger_tmp_dir):
shutil.rmtree(cls.debugger_tmp_dir)

@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
def test_get_dbg_online_server(self):
"""Get debugger online server"""
context = DebuggerServerContext(dbg_mode='online')
server_obj = self._dbg_server_factory.get_debugger_server(DebuggerCache(), context)
server_obj.start()
server_obj.stop()

@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@mock.patch('mindinsight.debugger.debugger_services.debugger_offline_server.import_module')
def test_get_dbg_offline_server(self, mock_import):
"""Get debugger offline server"""
mock_import.return_value = mock_dbg_services
context = DebuggerServerContext(dbg_mode='offline', dbg_dir=self._dbg_dir)
server_obj = self._dbg_server_factory.get_debugger_server(DebuggerCache(), context)
server_obj.start()
server_obj.stop()

+ 15
- 0
tests/st/func/debugger/dump_files/Ascend/async/.metadata/data_dump.json View File

@@ -0,0 +1,15 @@
{
"common_dump_settings": {
"dump_mode": 0,
"path": "/absolute_path",
"net_name": "Lenet",
"iteration": 0,
"input_output": 0,
"kernels": ["Default/Conv-op12"],
"support_device": [0,1,2,3,4,5,6,7]
},
"async_dump_settings": {
"enable": true,
"op_debug_mode": 0
}
}

+ 23
- 0
tests/st/func/debugger/dump_files/Ascend/async/.metadata/hccl.json View File

@@ -0,0 +1,23 @@
{
"version": "1.0",
"server_count": "1",
"server_list": [
{
"server_id": "0.0.0.0",
"device": [
{
"device_id": "0",
"device_ip": "0.0.0.1",
"rank_id": "0"
},
{
"device_id": "1",
"device_ip": "0.0.0.2",
"rank_id": "1"
}
],
"host_nic_ip": "reserve"
}
],
"status": "completed"
}

+ 15
- 0
tests/st/func/debugger/dump_files/Ascend/sync/.metadata/data_dump.json View File

@@ -0,0 +1,15 @@
{
"common_dump_settings": {
"dump_mode": 0,
"path": "/absolute_path",
"net_name": "Lenet",
"iteration": 0,
"input_output": 0,
"kernels": ["Default/Conv-op12"],
"support_device": [0,1,2,3,4,5,6,7]
},
"e2e_dump_settings": {
"enable": true,
"trans_flag": false
}
}
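
A hedged sketch of how a loader might consume these .metadata fixtures; the helper name is ours, but the key names follow the files above (sync configs carry 'e2e_dump_settings', async configs carry 'async_dump_settings'):

import json
import os

def load_dump_config(dump_dir):
    """Return (net_name, is_sync) parsed from .metadata/data_dump.json."""
    config_path = os.path.join(dump_dir, '.metadata', 'data_dump.json')
    with open(config_path, 'r') as handler:
        config = json.load(handler)
    net_name = config['common_dump_settings']['net_name']
    # Presence of 'e2e_dump_settings' marks a sync dump; async dumps
    # carry 'async_dump_settings' instead.
    return net_name, 'e2e_dump_settings' in config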

+ 23
- 0
tests/st/func/debugger/dump_files/Ascend/sync/.metadata/hccl.json View File

@@ -0,0 +1,23 @@
{
"version": "1.0",
"server_count": "1",
"server_list": [
{
"server_id": "0.0.0.0",
"device": [
{
"device_id": "0",
"device_ip": "0.0.0.1",
"rank_id": "0"
},
{
"device_id": "1",
"device_ip": "0.0.0.2",
"rank_id": "1"
}
],
"host_nic_ip": "reserve"
}
],
"status": "completed"
}

+ 15
- 0
tests/st/func/debugger/dump_files/GPU/sync/.metadata/data_dump.json View File

@@ -0,0 +1,15 @@
{
"common_dump_settings": {
"dump_mode": 0,
"path": "/absolute_path",
"net_name": "Lenet",
"iteration": 0,
"input_output": 0,
"kernels": ["Default/Conv-op12"],
"support_device": [0,1,2,3,4,5,6,7]
},
"e2e_dump_settings": {
"enable": true,
"trans_flag": false
}
}

+ 21
- 0
tests/st/func/debugger/expect_results/offline_debugger/load_device_info_ascend.json View File

@@ -0,0 +1,21 @@
{
"device_target": "Ascend",
"server_list": [
{
"server_id": "0.0.0.0",
"device": [
{
"device_id": "0",
"device_ip": "0.0.0.1",
"rank_id": "0"
},
{
"device_id": "1",
"device_ip": "0.0.0.2",
"rank_id": "1"
}
],
"host_nic_ip": "reserve"
}
]
}
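
The expected device info above appears to be the hccl.json 'server_list' wrapped with a 'device_target' field; a sketch of that mapping under that assumption (the function name is ours, not from this PR):

import json

def build_device_info(hccl_path, device_target='Ascend'):
    """Assemble device info in the shape of load_device_info_ascend.json."""
    with open(hccl_path, 'r') as handler:
        hccl = json.load(handler)
    return {'device_target': device_target, 'server_list': hccl['server_list']}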

+ 1
- 1
tests/st/func/debugger/expect_results/restful_results/multi_next_node.json View File

@@ -1 +1 @@
{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false, "debugger_version": {"ms": "1.2.0"}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []}
{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0", "backend": "GPU", "enable_recheck": false, "graph_name": "graph_1", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": ["graph_0", "graph_1"]}], "watch_points": []}

+ 1
- 1
tests/st/func/debugger/expect_results/restful_results/multi_retrieve_all.json View File

@@ -1 +1 @@
{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {"ms": "1.2.0"}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "watch_points": []}
{"metadata": {"state": "waiting", "step": 1, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {"graph_names": ["graph_0", "graph_1"], "nodes": [{"name": "graph_0", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}, {"name": "graph_1", "type": "name_scope", "attr": {}, "input": {}, "output": {}, "output_i": 0, "proxy_input": {}, "proxy_output": {}, "subnode_count": 2, "independent_layout": false}]}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": ["graph_0", "graph_1"]}], "watch_points": []}

+ 1
- 1
tests/st/func/debugger/expect_results/restful_results/retrieve_all.json
File diff suppressed because it is too large
View File


+ 1
- 1
tests/st/func/debugger/expect_results/restful_results/retrieve_next_node_on_gpu.json
File diff suppressed because it is too large
View File


+ 1
- 29
tests/st/func/debugger/expect_results/restful_results/retrieve_tensor_value.json View File

@@ -1,29 +1 @@
{
"tensor_value": {
"full_name": "Default/TransData-op99:0",
"step": 1,
"dtype": "DT_FLOAT32",
"shape": [
2,
3
],
"has_prev_step": false,
"statistics": {
"overall_max": 6.0,
"overall_min": 1.0,
"overall_avg": 3.5,
"overall_count": 6,
"overall_nan_count": 0,
"overall_neg_inf_count": 0,
"overall_pos_inf_count": 0,
"overall_zero_count": 0.0,
"overall_neg_zero_count": 0.0,
"overall_pos_zero_count": 6.0
},
"value": [
5.0,
6.0
],
"name": "Default/TransData-op99:0"
}
}
{"tensor_value": {"full_name": "Default/TransData-op99:0", "step": 1, "dtype": "DT_FLOAT32", "shape": [2, 3], "has_prev_step": false, "value": [5.0, 6.0], "statistics": {"overall_max": 6.0, "overall_min": 1.0, "overall_avg": 3.5, "overall_count": 6, "overall_nan_count": 0, "overall_neg_inf_count": 0, "overall_pos_inf_count": 0, "overall_zero_count": 0.0, "overall_neg_zero_count": 0.0, "overall_pos_zero_count": 6.0}, "name": "Default/TransData-op99:0"}}

+ 1
- 1
tests/st/func/debugger/expect_results/restful_results/version_mismatch.json View File

@@ -1 +1 @@
{"metadata": {"state": "mismatch", "step": 0, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {"ms": "1.0.0"}}, "graph": {}, "watch_points": []}
{"metadata": {"state": "mismatch", "step": 0, "device_name": "0", "node_name": "", "backend": "Ascend", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "devices": [{"rank_id": 0, "device_id": "0", "graph_names": []}], "watch_points": []}

+ 149
- 0
tests/st/func/debugger/test_data_loader.py View File

@@ -0,0 +1,149 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test DataLoader of offline debugger."""
import os
import shutil

import pytest

from mindinsight.debugger.stream_cache.data_loader import DataLoader
from tests.st.func.debugger.conftest import GRAPH_PROTO_FILE
from tests.st.func.debugger.utils import build_dump_file_structure
from tests.utils.tools import compare_result_with_file, compare_result_with_binary_file


class TestDataLoader:
"""Test DataLoader."""

@classmethod
def setup_class(cls):
"""Init TestDataLoader for DataLoader unittest."""
cls.debugger_tmp_dir, cls.dump_files_dir = build_dump_file_structure()
cls.expected_results_dir = os.path.join(os.path.dirname(__file__),
'expect_results/offline_debugger')
cls.dump_files_dir_ascend = os.path.join(cls.dump_files_dir,
'Ascend/sync')
cls.data_loader_ascend = DataLoader(cls.dump_files_dir_ascend)
cls.data_loader_ascend.initialize()
cls.dump_files_dir_gpu = os.path.join(cls.dump_files_dir,
'GPU/sync')
cls.data_loader_gpu = DataLoader(cls.dump_files_dir_gpu)
cls.data_loader_gpu.initialize()
cls.dump_files_dir_ascend_async = os.path.join(cls.dump_files_dir,
'Ascend/async')
cls.data_loader_ascend_async = DataLoader(cls.dump_files_dir_ascend_async)
cls.data_loader_ascend_async.initialize()

@classmethod
def teardown_class(cls):
"""Run after test this class."""
if os.path.exists(cls.debugger_tmp_dir):
shutil.rmtree(cls.debugger_tmp_dir)

@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
def test_load_graphs_ascend(self):
"""Test load_graphs function of offline-debugger."""
res = self.data_loader_ascend.load_graphs()
expected_result0 = GRAPH_PROTO_FILE
res0 = res[0]['graph_protos'][0].SerializeToString()
compare_result_with_binary_file(res0, expected_result0)

@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
def test_load_device_info_ascend(self):
"""Test load_device_info of ascend chip for offline-debugger."""
res = self.data_loader_ascend.load_device_info()
expected_result = os.path.join(self.expected_results_dir, 'load_device_info_ascend.json')
compare_result_with_file(res, expected_result)

@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
def test_load_step_num_ascend(self):
"""Test load_step_num of ascend chip for offline-debugger."""
res = self.data_loader_ascend.load_step_number()
expected_result = {"0": 4, "1": 4}
assert res == expected_result

@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
def test_get_net_name_ascend(self):
"""Test get_net_name of ascend chip for offline-debugger."""
res = self.data_loader_ascend.get_net_name()
assert res == 'Lenet'

@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
def test_get_sync_flag(self):
"""Test get_sync_flag of ascend chip for offline-debugger."""
res = self.data_loader_ascend.get_sync_flag()
assert res

@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
def test_load_graphs_gpu(self):
"""Test load_graphs function of offline-debugger."""
res = self.data_loader_gpu.load_graphs()
expected_result0 = GRAPH_PROTO_FILE
res0 = res[0]['graph_protos'][0].SerializeToString()
compare_result_with_binary_file(res0, expected_result0)

@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
def test_load_step_num_gpu(self):
"""Test load_step_num of ascend chip for offline-debugger."""
res = self.data_loader_gpu.load_step_number()
expected_result = {"0": 3, "1": 3}
assert res == expected_result

@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
def test_load_step_num_ascend_async(self):
"""Test load_step_num of ascend chip for offline-debugger."""
res = self.data_loader_ascend_async.load_step_number()
expected_result = {"0": 3, "1": 3}
assert res == expected_result
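
For reference, the same DataLoader calls outside pytest, using the dump structure built by the test utilities; the expected values match the assertions above:

import os
import shutil

from mindinsight.debugger.stream_cache.data_loader import DataLoader
from tests.st.func.debugger.utils import build_dump_file_structure

tmp_dir, dump_dir = build_dump_file_structure()
try:
    loader = DataLoader(os.path.join(dump_dir, 'Ascend/sync'))
    loader.initialize()
    assert loader.get_net_name() == 'Lenet'
    assert loader.get_sync_flag()
    assert loader.load_step_number() == {'0': 4, '1': 4}
finally:
    shutil.rmtree(tmp_dir)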

+ 6
- 24
tests/st/func/debugger/test_restful_api.py View File

@@ -84,7 +84,7 @@ class TestAscendDebugger:


def test_get_conditions(self, app_client):
"""Test get conditions for ascend."""
url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/condition-collections'
url = '/v1/mindinsight/debugger/sessions/0/condition-collections'
body_data = {}
expect_file = 'get_conditions_for_ascend.json'
with self._debugger_client.get_thread_instance():
@@ -191,7 +191,7 @@ class TestAscendDebugger:
check_state(app_client)
# prepare tensor value
url = 'tensor-history'
body_data = {'name': node_name}
body_data = {'name': node_name, 'rank_id': 0}
expect_file = 'retrieve_empty_tensor_history.json'
send_and_compare_result(app_client, url, body_data, expect_file)
# check full tensor history from poll data
@@ -229,7 +229,7 @@ class TestAscendDebugger:
get_request_result(app_client, url, body_data)
check_state(app_client)
get_request_result(
app_client=app_client, url='tensor-history', body_data={'name': node_name})
app_client=app_client, url='tensor-history', body_data={'name': node_name, 'rank_id': 0})
res = get_request_result(
app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get')
assert res.get('receive_tensor', {}).get('node_name') == node_name
@@ -239,30 +239,12 @@ class TestAscendDebugger:
'name': node_name + ':0',
'detail': 'data',
'shape': quote('[:, :]'),
'tolerance': 1
}
'tolerance': 1,
'rank_id': 0}
expect_file = 'compare_tensors.json'
send_and_compare_result(app_client, url, body_data, expect_file, method='get')
send_terminate_cmd(app_client)


@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize("body_data, expect_file", [
({'ascend': True}, 'retrieve_node_by_bfs_ascend.json'),
({'name': 'Default/args0', 'ascend': False}, 'retrieve_node_by_bfs.json')
])
def test_retrieve_bfs_node(self, app_client, body_data, expect_file):
"""Test retrieve bfs node."""
with self._debugger_client.get_thread_instance():
check_state(app_client)
# prepare tensor values
url = 'retrieve_node_by_bfs'
send_and_compare_result(app_client, url, body_data, expect_file, method='get')
send_terminate_cmd(app_client)


@pytest.mark.level0
@pytest.mark.env_single
@@ -441,7 +423,7 @@ class TestGPUDebugger:


def test_get_conditions(self, app_client):
"""Test get conditions for gpu."""
url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/condition-collections'
url = '/v1/mindinsight/debugger/sessions/0/condition-collections'
body_data = {}
expect_file = 'get_conditions_for_gpu.json'
with self._debugger_client.get_thread_instance():


+ 58
- 9
tests/st/func/debugger/utils.py View File

@@ -16,8 +16,10 @@
import json
import os
import time

from tests.st.func.debugger.conftest import DEBUGGER_EXPECTED_RESULTS, DEBUGGER_BASE_URL
import shutil
import tempfile
from mindinsight.debugger.proto import ms_graph_pb2
from tests.st.func.debugger.conftest import DEBUGGER_EXPECTED_RESULTS, DEBUGGER_BASE_URL, GRAPH_PROTO_FILE
from tests.utils.tools import compare_result_with_file, get_url




@@ -74,10 +76,57 @@ def send_and_save_result(app_client, url, body_data, file_path, method='post'):


def delete_random_items(res):
"""Delete the random items in metadata."""
if isinstance(res, dict) and res.get('metadata'):
if res['metadata'].get('ip'):
res['metadata'].pop('ip')
if res['metadata'].get('pos'):
res['metadata'].pop('pos')
if res['metadata'].get('debugger_version') and res['metadata']['debugger_version'].get('mi'):
res['metadata']['debugger_version'].pop('mi')
if isinstance(res, dict):
if res.get('metadata'):
if res['metadata'].get('ip'):
res['metadata'].pop('ip')
if res['metadata'].get('pos'):
res['metadata'].pop('pos')
if res['metadata'].get('debugger_version') and res['metadata']['debugger_version'].get('mi'):
res['metadata']['debugger_version'].pop('mi')
res['metadata']['debugger_version'].pop('ms')
if res.get('devices'):
for device in res.get('devices'):
if device.get('server_ip'):
device.pop('server_ip')


def build_dump_file_structure():
"""Build the dump file structure."""
async_file_structure = {
"Ascend/async/device_0/Lenet_graph_1/1": 3,
"Ascend/async/device_1/Lenet_graph_1/1": 3
}

sync_file_structure = {
"Ascend/sync/Lenet/device_0": 4,
"Ascend/sync/Lenet/device_1": 4,
"GPU/sync/Lenet/device_0": 3,
"GPU/sync/Lenet/device_1": 3
}

debugger_tmp_dir = tempfile.mkdtemp(suffix='debugger_tmp')
dump_files_dir = os.path.join(debugger_tmp_dir, 'dump_files')
shutil.copytree(os.path.join(os.path.dirname(__file__), 'dump_files'), dump_files_dir)

for sub_dir, steps in async_file_structure.items():
for step in range(1, steps + 1):
os.makedirs(os.path.join(os.path.join(dump_files_dir, sub_dir), str(step)), exist_ok=True)

for sub_dir, steps in sync_file_structure.items():
for step in range(1, steps + 1):
os.makedirs(os.path.join(os.path.join(dump_files_dir, sub_dir), 'iteration_' + str(step)),
exist_ok=True)
graph_dir_path = os.path.join(os.path.join(dump_files_dir, sub_dir), 'graphs')
os.makedirs(graph_dir_path, exist_ok=True)
graph_path = os.path.join(graph_dir_path, 'ms_output_trace_code_graph_0.pb')
with open(GRAPH_PROTO_FILE, 'rb') as file_handler:
content = file_handler.read()

model = ms_graph_pb2.ModelProto()
model.graph.ParseFromString(content)
model_str = model.SerializeToString()
with open(graph_path, 'wb') as file_handler:
file_handler.write(model_str)

return debugger_tmp_dir, dump_files_dir
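
A quick way to inspect the layout this helper produces; the relative paths in the comments are derived from the two structure dicts above:

import os

from tests.st.func.debugger.utils import build_dump_file_structure

tmp_dir, dump_dir = build_dump_file_structure()
for root, _, files in os.walk(dump_dir):
    print(os.path.relpath(root, dump_dir), files)
# Expected entries include, per device:
#   Ascend/sync/Lenet/device_0/iteration_1 ... iteration_4
#   Ascend/sync/Lenet/device_0/graphs ['ms_output_trace_code_graph_0.pb']
#   Ascend/async/device_0/Lenet_graph_1/1/1 ... 1/3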

+ 1
- 1
tests/ut/debugger/expected_results/debugger_server/retrieve_all.json View File

@@ -1 +1 @@
{"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "watch_points": []}
{"metadata": {"state": "waiting", "step": 0, "device_name": "", "pos": "0", "ip": "", "node_name": "", "backend": "", "enable_recheck": false, "graph_name": "", "recommendation_confirmed": false, "debugger_version": {}}, "graph": {}, "devices": [{"rank_id": 0, "server_ip": "", "device_id": "", "graph_names": []}], "watch_points": []}

+ 78
- 3
tests/ut/debugger/expected_results/watchpoint/watchpoint_handler_get_0.json View File

@@ -1,5 +1,80 @@
[
{"watchCondition": {"condition": "tensor_too_small", "value": 1.0, "params": [{"name": "abs_mean_lt", "disabled": true}, {"name": "max_lt", "value": 1.0}, {"name": "min_lt", "disabled": true}, {"name": "mean_lt", "disabled": true}]}, "id": 1, "watch_nodes_num": 0},
{"watchCondition": {"condition": "tensor_too_small", "value": 1.0, "params": [{"name": "abs_mean_lt", "disabled": true}, {"name": "max_lt", "disabled": true}, {"name": "min_lt", "value": 1.0}, {"name": "mean_lt", "disabled": true}]}, "id": 2, "watch_nodes_num": 172},
{"watchCondition": {"condition": "tensor_too_large", "value": 1.0, "params": [{"name": "abs_mean_gt", "disabled": true}, {"name": "max_gt", "value": 1.0}, {"name": "min_gt", "disabled": true}, {"name": "mean_gt", "disabled": true}]}, "id": 3, "watch_nodes_num": 1}
{
"watchCondition": {
"condition": "tensor_too_small",
"value": 1.0,
"params": [
{
"name": "abs_mean_lt",
"disabled": true
},
{
"name": "max_lt",
"value": 1.0
},
{
"name": "min_lt",
"disabled": true
},
{
"name": "mean_lt",
"disabled": true
}
]
},
"id": 1,
"watch_nodes_num": 0
},
{
"watchCondition": {
"condition": "tensor_too_small",
"value": 1.0,
"params": [
{
"name": "abs_mean_lt",
"disabled": true
},
{
"name": "max_lt",
"disabled": true
},
{
"name": "min_lt",
"value": 1.0
},
{
"name": "mean_lt",
"disabled": true
}
]
},
"id": 2,
"watch_nodes_num": 142
},
{
"watchCondition": {
"condition": "tensor_too_large",
"value": 1.0,
"params": [
{
"name": "abs_mean_gt",
"disabled": true
},
{
"name": "max_gt",
"value": 1.0
},
{
"name": "min_gt",
"disabled": true
},
{
"name": "mean_gt",
"disabled": true
}
]
},
"id": 3,
"watch_nodes_num": 1
}
]

+ 0
- 14
tests/ut/debugger/stream_handler/test_graph_handler.py View File

@@ -111,20 +111,6 @@ class TestGraphHandler:
node_name = self.graph_handler.get_node_name_by_full_name(full_name, 'kernel_graph_0')
assert node_name == expect_node_name

@pytest.mark.parametrize("node_name, ascend, expect_next", [
(None, True,
"Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0"),
(None, False, None),
("Default/tuple_getitem[10]_0/tuple_getitem-op206", True,
"Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89"),
("Default/tuple_getitem[10]_0/tuple_getitem-op206", False,
"Default/network-WithLossCell/_backbone-LeNet5/max_pool2d-MaxPool2d/Cast-op205")
])
def test_get_node_by_bfs_order(self, node_name, ascend, expect_next):
"""Test get node by BFS order."""
next_node = self.graph_handler.get_node_by_bfs_order(node_name, ascend)
assert next_node == expect_next

@pytest.mark.parametrize("tensor_name, expect_file", [ @pytest.mark.parametrize("tensor_name, expect_file", [
("Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0:0", "get_tensor_graph-0.json"), ("Default/network-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/OneHot-op0:0", "get_tensor_graph-0.json"),
("Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89:1", "get_tensor_graph-1.json"), ("Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/ReLUV2-op89:1", "get_tensor_graph-1.json"),


+ 2
- 2
tests/ut/debugger/stream_handler/test_tensor_handler.py View File

@@ -40,7 +40,7 @@ class TestTensorHandler:


def test_get_tensor_value_by_name_none(self):
"""Test get_tensor_value_by_name."""
res = self.tensor_handler.get_valid_tensor_by_name('tensor_name', True)
res = self.tensor_handler.get_valid_tensor_by_name('tensor_name', step=0, prev=True)
assert res is None

@mock.patch.object(log, "error")
@@ -49,5 +49,5 @@ class TestTensorHandler:
"""Test get_tensors_diff.""" """Test get_tensors_diff."""
mock_error.return_value = None mock_error.return_value = None
with pytest.raises(DebuggerParamValueError) as ex: with pytest.raises(DebuggerParamValueError) as ex:
self.tensor_handler.get_tensors_diff(tensor_name, {1, 1})
self.tensor_handler.get_tensors_diff(tensor_name, {1, 1}, step=0)
assert f"Get current step and previous step for this tensor name {tensor_name} failed." in str(ex.value) assert f"Get current step and previous step for this tensor name {tensor_name} failed." in str(ex.value)

+ 5
- 2
tests/ut/debugger/stream_handler/test_watchpoint_handler.py View File

@@ -30,6 +30,7 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue
DebuggerParamTypeError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.stream_cache.watchpoint import Watchpoint
from mindinsight.debugger.stream_handler import MultiCardGraphHandler
from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHandler, \
WatchpointHitHandler, validate_watch_condition, validate_watch_condition_params
from tests.ut.debugger.configurations import init_graph_handler, mock_tensor_proto, \
@@ -48,7 +49,9 @@ class TestWatchpointHandler:
'../expected_results/watchpoint')
cls.graph_results_dir = os.path.join(os.path.dirname(__file__),
'../expected_results/graph')
cls.multi_graph_stream = MultiCardGraphHandler()
cls.graph_stream = init_graph_handler()
cls.multi_graph_stream.register_graph_handler(0, cls.graph_stream)
cls.conditionmgr = None
cls.handler = None


@@ -69,7 +72,7 @@ class TestWatchpointHandler:
]
for watch_condition, watch_nodes, watch_point_id, expect_new_id in watchpoints:
watch_nodes = get_node_basic_infos(watch_nodes)
watch_point_id = self.handler.create_watchpoint(self.conditionmgr, watch_condition, watch_nodes,
watch_point_id = self.handler.create_watchpoint(self.conditionmgr, watch_condition, {0: watch_nodes},
watch_point_id)
assert watch_point_id == expect_new_id


@@ -105,7 +108,7 @@ class TestWatchpointHandler:
file_path = os.path.join(self.results_dir, result_file)
with open(file_path, 'r') as file_handler:
contents = json.load(file_handler)
protos = self.handler.get_pending_commands(self.graph_stream)
protos = self.handler.get_pending_commands(self.multi_graph_stream)
for proto in protos:
msg_dict = json_format.MessageToDict(proto)
msg_dict['watch_nodes_num'] = len(msg_dict.pop('watchNodes', []))


+ 11
- 1
tests/ut/debugger/stream_operator/test_training_control_operator.py View File

@@ -48,7 +48,8 @@ class TestTrainingControlOperator:
"""Test validate leaf name.""" """Test validate leaf name."""
args[0].return_value = 'name_scope' args[0].return_value = 'name_scope'
with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'): with pytest.raises(DebuggerParamValueError, match='Invalid leaf node name.'):
self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name')
self._server._validate_continue_node_name(node_name='mock_node_name', graph_name='mock_graph_name',
rank_id=0)


@pytest.mark.parametrize('mode, cur_state, state', [ @pytest.mark.parametrize('mode, cur_state, state', [
('continue', 'waiting', 'sending'), ('continue', 'waiting', 'sending'),
@@ -64,3 +65,12 @@ class TestTrainingControlOperator:
"""Test construct run event.""" """Test construct run event."""
res = self._server._construct_run_event({'level': 'node'}) res = self._server._construct_run_event({'level': 'node'})
assert res.run_cmd == RunCMD(run_level='node', node_name='') assert res.run_cmd == RunCMD(run_level='node', node_name='')

@pytest.mark.parametrize('mode, state', [
('reset', 'waiting')])
def test_control_reset_step(self, mode, state):
"""Test control request, in 'reset' mode."""
with mock.patch.object(MetadataHandler, 'max_step_num', 10), \
mock.patch.object(MetadataHandler, 'debugger_type', 'offline'):
res = self._server.control(mode=mode, params={'steps': 9})
assert res == {'metadata': {'enable_recheck': False, 'state': state, 'step': 9}}

+ 2
- 2
tests/ut/debugger/test_debugger_grpc_server.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,7 +26,7 @@ import numpy as np
from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
from mindinsight.debugger.debugger_services.debugger_grpc_server import DebuggerGrpcServer
from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply, SetCMD, Chunk, WatchpointHit
from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto, DataType
from mindinsight.debugger.stream_handler import WatchpointHitHandler, GraphHandler, \


+ 6
- 7
tests/ut/debugger/test_debugger_server.py View File

@@ -30,11 +30,11 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue
DebuggerCompareTensorError, DebuggerCreateWatchPointError, DebuggerDeleteWatchPointError
from mindinsight.debugger.common.utils import Streams
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_server import DebuggerServer
from mindinsight.debugger.debugger_server import grpc_server_base
from mindinsight.debugger.stream_operator import watchpoint_operator
from mindinsight.debugger.debugger_services.debugger_server_factory import DebuggerServerContext
from mindinsight.debugger.debugger_session import DebuggerSession as DebuggerServer
from mindinsight.debugger.stream_handler import GraphHandler, WatchpointHandler, MetadataHandler, \
TensorHandler
from mindinsight.debugger.stream_operator import watchpoint_operator
from tests.ut.debugger.configurations import compare_debugger_result_with_file, mock_tensor_history




@@ -48,12 +48,12 @@ class TestDebuggerServer:


def setup_method(self):
"""Prepare debugger server object."""
self._server = DebuggerServer()
context = DebuggerServerContext(dbg_mode='online')
self._server = DebuggerServer(context)


@mock.patch.object(signal, 'signal')
@mock.patch.object(Thread, 'join')
@mock.patch.object(Thread, 'start')
@mock.patch.object(grpc_server_base, 'add_EventListenerServicer_to_server')
@mock.patch.object(grpc, 'server')
def test_stop_server(self, *args):
"""Test stop debugger server.""" """Test stop debugger server."""
@@ -62,7 +62,6 @@ class TestDebuggerServer:
self._server.start()
self._server._stop_handler(MagicMock(), MagicMock())
assert self._server.back_server is not None
assert self._server.grpc_server_manager == mock_grpc_server_manager


@mock.patch.object(DebuggerCache, 'get_data')
def test_poll_data(self, *args):
@@ -186,7 +185,6 @@ class TestDebuggerServer:
self._server.create_watchpoint({'watch_condition': {'id': 'inf'}})


@mock.patch.object(MetadataHandler, 'state', 'waiting')
@mock.patch.object(MetadataHandler, 'backend', 'GPU')
@mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=MagicMock())
@mock.patch.object(GraphHandler, 'get_node_type', return_value='aggregation_scope')
@mock.patch.object(watchpoint_operator, 'get_basic_node_info', return_value=MagicMock())
@@ -194,6 +192,7 @@ class TestDebuggerServer:
def test_create_watchpoint(self, *args):
"""Test create watchpoint."""
args[0].return_value = 1
self._server.cache_store.get_stream_handler((Streams.METADATA)).backend = 'GPU'
res = self._server.create_watchpoint(
{'watch_condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
'watch_nodes': ['watch_node_name']})


+ 7
- 0
tests/utils/tools.py View File

@@ -68,6 +68,13 @@ def compare_result_with_file(result, expected_file_path):
assert result == expected_results




def compare_result_with_binary_file(result, expected_file_path):
"""Compare result with binary file which contain the expected results."""
with open(expected_file_path, 'rb') as file:
expected_results = file.read()
assert result == expected_results
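
A usage sketch in the spirit of test_data_loader.py: serialize a graph proto and byte-compare it against the checked-in fixture (here the fixture is simply round-tripped against itself, assuming a stable protobuf serialization):

from mindinsight.debugger.proto import ms_graph_pb2
from tests.st.func.debugger.conftest import GRAPH_PROTO_FILE
from tests.utils.tools import compare_result_with_binary_file

graph = ms_graph_pb2.GraphProto()
with open(GRAPH_PROTO_FILE, 'rb') as handler:
    graph.ParseFromString(handler.read())
compare_result_with_binary_file(graph.SerializeToString(), GRAPH_PROTO_FILE)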


def deal_float_for_dict(res: dict, expected_res: dict, decimal_num):
"""
Deal float rounded to specified decimals in dict.

