
add debugger module

tags/v1.0.0
yelihua 5 years ago
commit 50e1400505
39 changed files with 7244 additions and 235 deletions
  1. +26 -0  mindinsight/backend/debugger/__init__.py
  2. +300 -0  mindinsight/backend/debugger/debugger_api.py
  3. +6 -0  mindinsight/conf/defaults.py
  4. +2 -116  mindinsight/datavisual/data_transform/tensor_container.py
  5. +9 -105  mindinsight/datavisual/processors/tensor_processor.py
  6. +19 -0  mindinsight/datavisual/utils/tools.py
  7. +20 -0  mindinsight/debugger/__init__.py
  8. +14 -0  mindinsight/debugger/common/__init__.py
  9. +14 -0  mindinsight/debugger/common/exceptions/__init__.py
  10. +56 -0  mindinsight/debugger/common/exceptions/error_code.py
  11. +117 -0  mindinsight/debugger/common/exceptions/exceptions.py
  12. +20 -0  mindinsight/debugger/common/log.py
  13. +168 -0  mindinsight/debugger/common/utils.py
  14. +154 -0  mindinsight/debugger/debugger_cache.py
  15. +309 -0  mindinsight/debugger/debugger_grpc_server.py
  16. +752 -0  mindinsight/debugger/debugger_server.py
  17. +113 -0  mindinsight/debugger/proto/debug_grpc.proto
  18. +683 -0  mindinsight/debugger/proto/debug_grpc_pb2.py
  19. +193 -0  mindinsight/debugger/proto/debug_grpc_pb2_grpc.py
  20. +322 -0  mindinsight/debugger/proto/ms_graph.proto
  21. +1395 -0  mindinsight/debugger/proto/ms_graph_pb2.py
  22. +14 -0  mindinsight/debugger/stream_cache/__init__.py
  23. +289 -0  mindinsight/debugger/stream_cache/debugger_graph.py
  24. +61 -0  mindinsight/debugger/stream_cache/node.py
  25. +233 -0  mindinsight/debugger/stream_cache/tensor.py
  26. +300 -0  mindinsight/debugger/stream_cache/watchpoint.py
  27. +23 -0  mindinsight/debugger/stream_handler/__init__.py
  28. +34 -0  mindinsight/debugger/stream_handler/base_handler.py
  29. +159 -0  mindinsight/debugger/stream_handler/event_handler.py
  30. +314 -0  mindinsight/debugger/stream_handler/graph_handler.py
  31. +131 -0  mindinsight/debugger/stream_handler/metadata_handler.py
  32. +298 -0  mindinsight/debugger/stream_handler/tensor_handler.py
  33. +333 -0  mindinsight/debugger/stream_handler/watchpoint_handler.py
  34. +45 -3  mindinsight/scripts/start.py
  35. +5 -0  mindinsight/utils/constant.py
  36. +298 -0  mindinsight/utils/tensor.py
  37. +2 -1  requirements.txt
  38. +8 -4  tests/ut/datavisual/data_transform/test_tensor_container.py
  39. +5 -6  tests/ut/datavisual/processors/test_tensor_processor.py

+26 -0  mindinsight/backend/debugger/__init__.py

@@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Module init file."""
from mindinsight.backend.debugger.debugger_api import init_module as init_query_module


def init_module(app):
"""
Init module entry.

Args:
app (Flask): A Flask instance.
"""
init_query_module(app)

+300 -0  mindinsight/backend/debugger/debugger_api.py

@@ -0,0 +1,300 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Debugger restful api."""
import json

from flask import Blueprint, jsonify, request

from mindinsight.conf import settings
from mindinsight.debugger.debugger_server import DebuggerServer
from mindinsight.utils.exceptions import ParamValueError

BLUEPRINT = Blueprint("debugger", __name__,
url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX)


def _initialize_debugger_server():
"""Initialize a debugger server instance."""
port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else None
enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False
server = None
if port and enable_debugger:
server = DebuggerServer(port)
return server


def _read_post_request(post_request):
"""
Extract the body of post request.

Args:
post_request (object): The post request.

Returns:
dict, the deserialized body of request.
"""
body = post_request.stream.read()
try:
body = json.loads(body if body else "{}")
except Exception:
raise ParamValueError("Json data parse failed.")
return body


def _wrap_reply(func, *args, **kwargs):
"""Serialize reply."""
reply = func(*args, **kwargs)
return jsonify(reply)


@BLUEPRINT.route("/debugger/poll_data", methods=["GET"])
def poll_data():
"""
Poll the updated data from the debugger server.

Wait until new data is ready in the server cache, then return it for the UI to display.

Returns:
str, the updated data.

Examples:
>>> GET http://xxxx/v1/mindinsight/debugger/poll_data?pos=xx
"""
pos = request.args.get('pos')

reply = _wrap_reply(BACKEND_SERVER.poll_data, pos)

return reply


@BLUEPRINT.route("/debugger/search", methods=["GET"])
def search():
"""
Search nodes in specified watchpoint.

Returns:
str, the required data.

Examples:
>>> GET http://xxxx/v1/mindinsight/debugger/search?name=node_name&watch_point_id=1
"""
name = request.args.get('name')
watch_point_id = int(request.args.get('watch_point_id', 0))
reply = _wrap_reply(BACKEND_SERVER.search, name, watch_point_id)

return reply


@BLUEPRINT.route("/debugger/retrieve_node_by_bfs", methods=["GET"])
def retrieve_node_by_bfs():
"""
Search a single node by breadth-first search (BFS).

Returns:
str, the required data.

Examples:
>>> GET http://xxxx/v1/mindinsight/debugger/retrieve_node_by_bfs?name=node_name&ascend=true
"""
name = request.args.get('name')
ascend = request.args.get('ascend', 'false')
ascend = ascend == 'true'
reply = _wrap_reply(BACKEND_SERVER.retrieve_node_by_bfs, name, ascend)

return reply


@BLUEPRINT.route("/debugger/tensor-comparisons", methods=["GET"])
def tensor_comparisons():
"""
Get tensor comparisons.

Returns:
str, the required data.

Examples:
>>> GET http://xxxx/v1/mindinsight/debugger/tensor-comparisons?
name=node_name&detail=data&shape=[0, 0, :, :]&tolerance=0.5
"""
name = request.args.get('name')
detail = request.args.get('detail', 'data')
shape = request.args.get('shape')
tolerance = request.args.get('tolerance', '0')
reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance)

return reply


@BLUEPRINT.route("/debugger/retrieve", methods=["POST"])
def retrieve():
"""
Retrieve data according to mode and params.

Returns:
str, the required data.

Examples:
>>> POST http://xxxx/v1/mindinsight/debugger/retrieve
"""
body = _read_post_request(request)
mode = body.get('mode')
params = body.get('params')
reply = _wrap_reply(BACKEND_SERVER.retrieve, mode, params)
return reply


@BLUEPRINT.route("/debugger/retrieve_tensor_history", methods=["POST"])
def retrieve_tensor_history():
"""
Retrieve tensor history of the specified leaf node.

Returns:
str, the required data.

Examples:
>>> POST http://xxxx/v1/mindinsight/debugger/retrieve_tensor_history
"""
body = _read_post_request(request)
name = body.get('name')
reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_history, name)
return reply


@BLUEPRINT.route("/debugger/tensors", methods=["GET"])
def retrieve_tensor_value():
"""
Retrieve tensor value according to name and shape.

Returns:
str, the required data.

Examples:
>>> GET http://xxxx/v1/mindinsight/debugger/tensors?name=node_name&detail=data&shape=[1,1,:,:]
"""
name = request.args.get('name')
detail = request.args.get('detail')
shape = request.args.get('shape')
reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_value, name, detail, shape)
return reply


@BLUEPRINT.route("/debugger/create_watchpoint", methods=["POST"])
def create_watchpoint():
"""
Create watchpoint.

Returns:
str, watchpoint id.

Raises:
MindInsightException: If method fails to be called.
ParamValueError: If parsing the JSON body fails.

Examples:
>>> POST http://xxxx/v1/mindinsight/debugger/create_watchpoint
"""
body = _read_post_request(request)

condition = body.get('condition')
watch_nodes = body.get('watch_nodes')
watch_point_id = body.get('watch_point_id')
reply = _wrap_reply(BACKEND_SERVER.create_watchpoint, condition, watch_nodes, watch_point_id)
return reply


@BLUEPRINT.route("/debugger/update_watchpoint", methods=["POST"])
def update_watchpoint():
"""
Update watchpoint.

Returns:
str, reply message.

Raises:
MindInsightException: If method fails to be called.
ParamValueError: If parsing the JSON body fails.

Examples:
>>> POST http://xxxx/v1/mindinsight/debugger/update_watchpoint
"""
body = _read_post_request(request)

watch_point_id = body.get('watch_point_id')
watch_nodes = body.get('watch_nodes')
mode = body.get('mode')
name = body.get('name')
reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, watch_point_id, watch_nodes, mode, name)

return reply


@BLUEPRINT.route("/debugger/delete_watchpoint", methods=["POST"])
def delete_watchpoint():
"""
Delete watchpoint.

Returns:
str, reply message.

Raises:
MindInsightException: If method fails to be called.
ParamValueError: If parsing the JSON body fails.

Examples:
>>> POST http://xxxx/v1/mindinsight/debugger/delete_watchpoint
"""
body = _read_post_request(request)

watch_point_id = body.get('watch_point_id')

reply = _wrap_reply(BACKEND_SERVER.delete_watchpoint, watch_point_id)

return reply


@BLUEPRINT.route("/debugger/control", methods=["POST"])
def control():
"""
Control the training process, e.g. continue or pause training.

Returns:
str, reply message.

Raises:
MindInsightException: If method fails to be called.
ParamValueError: If parsing the JSON body fails.

Examples:
>>> POST http://xxxx/v1/mindinsight/debugger/control
"""
params = _read_post_request(request)
reply = _wrap_reply(BACKEND_SERVER.control, params)

return reply


BACKEND_SERVER = _initialize_debugger_server()


def init_module(app):
"""
Init module entry.

Args:
app (Flask): A Flask instance.
"""
app.register_blueprint(BLUEPRINT)
if BACKEND_SERVER:
BACKEND_SERVER.start()
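
A quick way to exercise these endpoints once the server is up is plain HTTP. The sketch below is illustrative only: it assumes MindInsight serves at http://127.0.0.1:8080 with the URL prefix shown in the docstring examples, uses the third-party `requests` package, and the payload values are made up (the accepted schema is defined by DebuggerServer, not here).

import requests

BASE = "http://127.0.0.1:8080/v1/mindinsight/debugger"  # assumed host and port

# Long-poll for updated data, starting from position '0'.
print(requests.get(BASE + "/poll_data", params={"pos": "0"}).json())

# Create a watchpoint; the keys match what create_watchpoint() reads above,
# the values are placeholders.
payload = {"condition": {"condition": "INF"}, "watch_nodes": ["Default/network"]}
print(requests.post(BASE + "/create_watchpoint", json=payload).json())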

+6 -0  mindinsight/conf/defaults.py

@@ -26,6 +26,12 @@ WORKSPACE = os.path.join(os.environ['HOME'], 'mindinsight')
PORT = 8080
URL_PATH_PREFIX = ''

####################################
# Debugger default settings.
####################################
DEBUGGER_PORT = '50051'
ENABLE_DEBUGGER = False

####################################
# Datavisual default settings.
####################################


+2 -116  mindinsight/datavisual/data_transform/tensor_container.py

@@ -15,128 +15,14 @@
"""Tensor data container."""
import numpy as np

from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket
from mindinsight.datavisual.utils.utils import calc_histogram_bins
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.utils.tensor import TensorUtils

F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max
MAX_TENSOR_COUNT = 10000000


class Statistics:
"""Statistics data class.

Args:
max_value (float): max value of tensor data.
min_value (float): min value of tensor data.
avg_value (float): avg value of tensor data.
count (int): total count of tensor data.
nan_count (int): count of NAN.
neg_inf_count (int): count of negative INF.
pos_inf_count (int): count of positive INF.
"""

def __init__(self, max_value=0, min_value=0, avg_value=0,
count=0, nan_count=0, neg_inf_count=0, pos_inf_count=0):
self._max = max_value
self._min = min_value
self._avg = avg_value
self._count = count
self._nan_count = nan_count
self._neg_inf_count = neg_inf_count
self._pos_inf_count = pos_inf_count

@property
def max(self):
"""Get max value of tensor."""
return self._max

@property
def min(self):
"""Get min value of tensor."""
return self._min

@property
def avg(self):
"""Get avg value of tensor."""
return self._avg

@property
def count(self):
"""Get total count of tensor."""
return self._count

@property
def nan_count(self):
"""Get count of NAN."""
return self._nan_count

@property
def neg_inf_count(self):
"""Get count of negative INF."""
return self._neg_inf_count

@property
def pos_inf_count(self):
"""Get count of positive INF."""
return self._pos_inf_count


def get_statistics_from_tensor(tensors):
"""
Calculates statistics data of tensor.

Args:
tensors (numpy.ndarray): An numpy.ndarray of tensor data.

Returns:
an instance of Statistics.
"""
ma_value = np.ma.masked_invalid(tensors)
total, valid = tensors.size, ma_value.count()
invalids = []
for isfn in np.isnan, np.isposinf, np.isneginf:
if total - valid > sum(invalids):
count = np.count_nonzero(isfn(tensors))
invalids.append(count)
else:
invalids.append(0)

nan_count, pos_inf_count, neg_inf_count = invalids
if not valid:
logger.warning('There are no valid values in the tensors(size=%d, shape=%s)', total, tensors.shape)
statistics = Statistics(max_value=0,
min_value=0,
avg_value=0,
count=total,
nan_count=nan_count,
neg_inf_count=neg_inf_count,
pos_inf_count=pos_inf_count)
return statistics

# BUG: max of a masked array with dtype np.float16 returns inf
# See numpy issue#15077
if issubclass(tensors.dtype.type, np.floating):
tensor_min = ma_value.min(fill_value=np.PINF)
tensor_max = ma_value.max(fill_value=np.NINF)
if tensor_min < F32_MIN or tensor_max > F32_MAX:
logger.warning('Values(%f, %f) are too large, you may encounter some undefined '
'behaviours hereafter.', tensor_min, tensor_max)
else:
tensor_min = ma_value.min()
tensor_max = ma_value.max()
tensor_sum = ma_value.sum(dtype=np.float64)
statistics = Statistics(max_value=tensor_max,
min_value=tensor_min,
avg_value=tensor_sum / valid,
count=total,
nan_count=nan_count,
neg_inf_count=neg_inf_count,
pos_inf_count=pos_inf_count)
return statistics


def calc_original_buckets(np_value, stats):
"""
Calculate buckets from tensor data.
@@ -188,7 +74,7 @@ class TensorContainer:
self._dims = tuple(tensor_message.dims)
self._data_type = tensor_message.data_type
self._np_array = self.get_ndarray(tensor_message.float_data)
self._stats = get_statistics_from_tensor(self._np_array)
self._stats = TensorUtils.get_statistics_from_tensor(self._np_array)
original_buckets = calc_original_buckets(self._np_array, self._stats)
self._count = sum(bucket.count for bucket in original_buckets)
self._max = self._stats.max
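
The statistics logic that used to live here (now TensorUtils.get_statistics_from_tensor) is built on numpy masked arrays, which exclude NaN and Inf from min/max/sum. A minimal standalone sketch of that technique:

import numpy as np

data = np.array([1.0, np.nan, np.inf, -np.inf, 2.0], dtype=np.float32)
masked = np.ma.masked_invalid(data)  # masks NaN and +/-Inf
valid = masked.count()               # 2 valid values
print(masked.min(), masked.max())    # 1.0 2.0
print(float(masked.sum(dtype=np.float64)) / valid)  # avg of valid values: 1.5
print(np.count_nonzero(np.isnan(data)),     # 1 NaN
      np.count_nonzero(np.isposinf(data)),  # 1 positive Inf
      np.count_nonzero(np.isneginf(data)))  # 1 negative Inf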


+9 -105  mindinsight/datavisual/processors/tensor_processor.py

@@ -19,97 +19,16 @@ import numpy as np

from mindinsight.datavisual.utils.tools import to_int
from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError
from mindinsight.utils.tensor import TensorUtils
from mindinsight.conf.constants import MAX_TENSOR_RESPONSE_DATA_SIZE
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.common.exceptions import StepTensorDataNotInCacheError, TensorNotExistError
from mindinsight.datavisual.common.exceptions import ResponseDataExceedMaxValueError
from mindinsight.datavisual.data_transform.tensor_container import TensorContainer, get_statistics_from_tensor
from mindinsight.datavisual.data_transform.tensor_container import TensorContainer
from mindinsight.datavisual.processors.base_processor import BaseProcessor
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2


def convert_array_from_str(dims, limit=0):
"""
Convert string of dims data to array.

Args:
dims (str): Specify dims of tensor.
limit (int): The max flexible dimension count, default value is 0 which means that there is no limitation.

Returns:
list, a string like this: "[0, 0, :, :]" will convert to this value: [0, 0, None, None].

Raises:
ParamValueError, If flexible dimensions exceed limit value.
"""
dims = dims.strip().lstrip('[').rstrip(']')
dims_list = []
count = 0
for dim in dims.split(','):
dim = dim.strip()
if dim == ':':
dims_list.append(None)
count += 1
else:
dims_list.append(to_int(dim, "dim"))
if limit and count > limit:
raise ParamValueError("Flexible dimensions cannot exceed limit value: {}, size: {}"
.format(limit, count))
return dims_list


def get_specific_dims_data(ndarray, dims, tensor_dims):
"""
Get specific dims data.

Args:
ndarray (numpy.ndarray): An ndarray of numpy.
dims (list): A list of specific dims.
tensor_dims (list): A list of tensor dims.

Returns:
numpy.ndarray, an ndarray of specific dims tensor data.

Raises:
ParamValueError, If the length of param dims is not equal to the length of tensor dims or
the index of param dims out of range.
"""
if len(dims) != len(tensor_dims):
raise ParamValueError("The length of param dims: {}, is not equal to the "
"length of tensor dims: {}.".format(len(dims), len(tensor_dims)))
indices = []
for k, d in enumerate(dims):
if d is not None:
if d >= tensor_dims[k]:
raise ParamValueError("The index: {} of param dims out of range: {}.".format(d, tensor_dims[k]))
indices.append(d)
else:
indices.append(slice(0, tensor_dims[k]))
return ndarray[tuple(indices)]


def get_statistics_dict(stats):
"""
Get statistics dict according to statistics value.

Args:
stats (Statistics): An instance of Statistics.

Returns:
dict, a dict including 'max', 'min', 'avg', 'count', 'nan_count', 'neg_inf_count', 'pos_inf_count'.
"""
statistics = {
"max": float(stats.max),
"min": float(stats.min),
"avg": float(stats.avg),
"count": stats.count,
"nan_count": stats.nan_count,
"neg_inf_count": stats.neg_inf_count,
"pos_inf_count": stats.pos_inf_count
}
return statistics


class TensorProcessor(BaseProcessor):
"""Tensor Processor."""
def get_tensors(self, train_ids, tags, step, dims, detail):
@@ -130,22 +49,7 @@ class TensorProcessor(BaseProcessor):
UrlDecodeError, If unquote train id error with strict mode.
"""
Validation.check_param_empty(train_id=train_ids, tag=tags)
if dims is not None:
if not isinstance(dims, str):
raise ParamValueError('The type of dims must be str, but got {}.'.format(type(dims)))
dims = dims.strip()
if not (dims.startswith('[') and dims.endswith(']')):
raise ParamValueError('The value: {} of dims must be '
'start with `[` and end with `]`.'.format(dims))
for dim in dims[1:-1].split(','):
dim = dim.strip()
if dim == ":":
continue
if dim.startswith('-'):
dim = dim[1:]
if not dim.isdigit():
raise ParamValueError('The value: {} of dims in the square brackets '
'must be int or `:`.'.format(dims))
TensorUtils.validate_dims_format(dims)

for index, train_id in enumerate(train_ids):
try:
@@ -248,7 +152,7 @@ class TensorProcessor(BaseProcessor):
"data_type": anf_ir_pb2.DataType.Name(value.data_type)
}
if detail and detail == 'stats':
stats = get_statistics_dict(value.stats)
stats = TensorUtils.get_statistics_dict(value.stats)
value_dict.update({"statistics": stats})

values.append({
@@ -295,14 +199,14 @@ class TensorProcessor(BaseProcessor):
"""
values = []
step_in_cache = False
dims = convert_array_from_str(dims, limit=2)
dims = TensorUtils.convert_array_from_str_dims(dims, limit=2)
for tensor in tensors:
# This value is an instance of TensorContainer
value = tensor.value
if step != tensor.step:
continue
step_in_cache = True
res_data = get_specific_dims_data(value.ndarray, dims, list(value.dims))
res_data = TensorUtils.get_specific_dims_data(value.ndarray, dims, list(value.dims))
flatten_data = res_data.flatten().tolist()
if len(flatten_data) > MAX_TENSOR_RESPONSE_DATA_SIZE:
raise ResponseDataExceedMaxValueError("the size of response data: {} exceed max value: {}."
@@ -328,7 +232,7 @@ class TensorProcessor(BaseProcessor):
transfer_data[index] = float(data)
return transfer_data

stats = get_statistics_from_tensor(res_data)
stats = TensorUtils.get_statistics_from_tensor(res_data)
if stats.nan_count + stats.neg_inf_count + stats.pos_inf_count > 0:
tensor_data = transfer(res_data)
else:
@@ -340,7 +244,7 @@ class TensorProcessor(BaseProcessor):
"dims": value.dims,
"data_type": anf_ir_pb2.DataType.Name(value.data_type),
"data": tensor_data,
"statistics": get_statistics_dict(stats)
"statistics": TensorUtils.get_statistics_dict(stats)
}
})
break
@@ -389,7 +293,7 @@ class TensorProcessor(BaseProcessor):
"dims": value.dims,
"data_type": anf_ir_pb2.DataType.Name(value.data_type),
"histogram_buckets": buckets,
"statistics": get_statistics_dict(value.stats)
"statistics": TensorUtils.get_statistics_dict(value.stats)
}
})
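
These call sites now delegate to shared helpers in mindinsight/utils/tensor.py (file 36 of this commit). A hedged usage sketch of the dims helpers, assuming the TensorUtils versions keep the semantics of the functions removed above:

import numpy as np
from mindinsight.utils.tensor import TensorUtils

# "[0, 0, :, :]" -> [0, 0, None, None]; at most two flexible dims allowed here.
dims = TensorUtils.convert_array_from_str_dims("[0, 0, :, :]", limit=2)
print(dims)

# Fix dims 0 and 1, keep dims 2 and 3 -> result shape (4, 5).
ndarray = np.zeros((2, 3, 4, 5))
print(TensorUtils.get_specific_dims_data(ndarray, dims, list(ndarray.shape)).shape)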



+19 -0  mindinsight/datavisual/utils/tools.py

@@ -80,6 +80,25 @@ def to_int(param, param_name):
return param


def to_float(param, param_name):
"""
Transfer param to float type.

Args:
param (Any): The param to be transformed.
param_name (str): Param name.

Returns:
float, the value after transformation.
"""
try:
param = float(param)
except ValueError:
raise exceptions.ParamTypeError(param_name, 'Float')
return param


def str_to_bool(param, param_name):
"""
Check param and transform it to bool.


+20 -0  mindinsight/debugger/__init__.py

@@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Debugger Module Introduction.

This module provides Python APIs to retrieve debugger info and control the training process.
The APIs help users understand the training process and locate bugs in training scripts.
"""

+14 -0  mindinsight/debugger/common/__init__.py

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

+14 -0  mindinsight/debugger/common/exceptions/__init__.py

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

+56 -0  mindinsight/debugger/common/exceptions/error_code.py

@@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Debugger error code and messages."""
from enum import Enum, unique
from mindinsight.utils.constant import DebuggerErrors as DebuggerErrorCodes


_PARAM_ERROR_MASK = 0b00001 << 7
_DEBUGGER_GRAPH_ERROR = 0b00010 << 7
_DEBUGGER_RUNNING_ERROR = 0b00011 << 7


@unique
class DebuggerErrors(DebuggerErrorCodes):
"""Debugger error codes."""
PARAM_TYPE_ERROR = 0 | _PARAM_ERROR_MASK
PARAM_VALUE_ERROR = 1 | _PARAM_ERROR_MASK

NODE_NOT_IN_GRAPH_ERROR = 0 | _DEBUGGER_GRAPH_ERROR
GRAPH_NOT_EXIST_ERROR = 1 | _DEBUGGER_GRAPH_ERROR

CREATE_WATCHPOINT_ERROR = 0 | _DEBUGGER_RUNNING_ERROR
UPDATE_WATCHPOINT_ERROR = 1 | _DEBUGGER_RUNNING_ERROR
DELETE_WATCHPOINT_ERROR = 2 | _DEBUGGER_RUNNING_ERROR
CONTINUE_ERROR = 3 | _DEBUGGER_RUNNING_ERROR
PAUSE_ERROR = 4 | _DEBUGGER_RUNNING_ERROR
COMPARE_TENSOR_ERROR = 5 | _DEBUGGER_RUNNING_ERROR


@unique
class DebuggerErrorMsg(Enum):
"""Debugger error messages."""
PARAM_TYPE_ERROR = "TypeError. {}"
PARAM_VALUE_ERROR = "ValueError. {}"
PARAM_MISSING_ERROR = "MissingError. {}"
UNEXPECTED_EXCEPTION_ERROR = "Unexpected exception. {}"

GRAPH_NOT_EXIST_ERROR = "The graph does not exist."

CREATE_WATCHPOINT_ERROR = "Create watchpoint failed. {}"
UPDATE_WATCHPOINT_ERROR = "Update watchpoint failed. {}"
DELETE_WATCHPOINT_ERROR = "Delete watchpoint failed. {}"
CONTINUE_ERROR = "Continue debugging failed. {}"
PAUSE_ERROR = "Pause debugging failed. {}"

+117 -0  mindinsight/debugger/common/exceptions/exceptions.py

@@ -0,0 +1,117 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Definition of error code and relative messages in debugger module."""
from mindinsight.utils.exceptions import MindInsightException
from mindinsight.debugger.common.exceptions.error_code import DebuggerErrors, DebuggerErrorMsg


class DebuggerParamTypeError(MindInsightException):
"""The parameter type error in debugger module."""

def __init__(self, msg):
super(DebuggerParamTypeError, self).__init__(
error=DebuggerErrors.PARAM_TYPE_ERROR,
message=DebuggerErrorMsg.PARAM_TYPE_ERROR.value.format(msg)
)


class DebuggerParamValueError(MindInsightException):
"""The parameter value error in debugger module."""

def __init__(self, msg):
super(DebuggerParamValueError, self).__init__(
error=DebuggerErrors.PARAM_VALUE_ERROR,
message=DebuggerErrorMsg.PARAM_VALUE_ERROR.value.format(msg)
)


class DebuggerCreateWatchPointError(MindInsightException):
"""The error about creating watch point."""

def __init__(self, msg):
super(DebuggerCreateWatchPointError, self).__init__(
error=DebuggerErrors.CREATE_WATCHPOINT_ERROR,
message=DebuggerErrorMsg.CREATE_WATCHPOINT_ERROR.value.format(msg)
)


class DebuggerUpdateWatchPointError(MindInsightException):
"""The error about updating watch point."""

def __init__(self, msg):
super(DebuggerUpdateWatchPointError, self).__init__(
error=DebuggerErrors.UPDATE_WATCHPOINT_ERROR,
message=DebuggerErrorMsg.UPDATE_WATCHPOINT_ERROR.value.format(msg)
)


class DebuggerDeleteWatchPointError(MindInsightException):
"""The error about deleting watch point."""

def __init__(self, msg):
super(DebuggerDeleteWatchPointError, self).__init__(
error=DebuggerErrors.DELETE_WATCHPOINT_ERROR,
message=DebuggerErrorMsg.DELETE_WATCHPOINT_ERROR.value.format(msg)
)


class DebuggerCompareTensorError(MindInsightException):
"""The error about comparing tensors."""

def __init__(self, msg):
super(DebuggerCompareTensorError, self).__init__(
error=DebuggerErrors.COMPARE_TENSOR_ERROR,
message=DebuggerErrorMsg.COMPARE_TENSOR_ERROR.value.format(msg)
)


class DebuggerContinueError(MindInsightException):
"""The error about continuing debugging."""
def __init__(self, msg):
super(DebuggerContinueError, self).__init__(
error=DebuggerErrors.CONTINUE_ERROR,
message=DebuggerErrorMsg.CONTINUE_ERROR.value.format(msg)
)


class DebuggerPauseError(MindInsightException):
"""The error about pausing debugging."""
def __init__(self, msg):
super(DebuggerPauseError, self).__init__(
error=DebuggerErrors.PAUSE_ERROR,
message=DebuggerErrorMsg.PAUSE_ERROR.value.format(msg)
)


class DebuggerNodeNotInGraphError(MindInsightException):
"""The node is not in the graph."""
def __init__(self, node_name, node_type=None):
if node_type is not None:
err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}, type: {node_type}."
else:
err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}."
super(DebuggerNodeNotInGraphError, self).__init__(
error=DebuggerErrors.NODE_NOT_IN_GRAPH_ERROR,
message=err_msg
)


class DebuggerGraphNotExistError(MindInsightException):
"""The graph does not exist."""
def __init__(self):
super(DebuggerGraphNotExistError, self).__init__(
error=DebuggerErrors.GRAPH_NOT_EXIST_ERROR,
message=DebuggerErrorMsg.GRAPH_NOT_EXIST_ERROR.value
)
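
All of these classes funnel into MindInsightException with a DebuggerErrors code and a formatted message, so callers can catch the shared base class. A hedged usage sketch (it assumes the exception's string form carries the message passed in; the exact format is defined in mindinsight/utils/exceptions.py, not shown here):

from mindinsight.utils.exceptions import MindInsightException
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError

try:
    raise DebuggerParamValueError("tolerance should be a float")
except MindInsightException as err:  # DebuggerParamValueError subclasses it
    print(err)  # e.g. "ValueError. tolerance should be a float"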

+20 -0  mindinsight/debugger/common/log.py

@@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Import mindinsight unified log module."""
from mindinsight.utils.log import setup_logger

LOG_NAME = "debugger"
LOG_MODULE = "debugger"
logger = setup_logger(sub_module=LOG_MODULE, log_name=LOG_NAME)
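
setup_logger hands back a logger bound to the debugger sub-module, so the rest of the package uses ordinary printf-style calls, as seen throughout the files below:

from mindinsight.debugger.common.log import logger as log

log.info("Start grpc server %s", "127.0.0.1:50051")
log.debug("Wait for %s-th command", 0)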

+168 -0  mindinsight/debugger/common/utils.py

@@ -0,0 +1,168 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the utils."""
import enum
from collections import namedtuple

import numpy as np

from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply
from mindinsight.debugger.stream_cache.debugger_graph import NodeTypeEnum

# translate the MindSpore type to numpy type.
NUMPY_TYPE_MAP = {
'DT_BOOL': np.bool,

'DT_INT8': np.int8,
'DT_INT16': np.int16,
'DT_INT32': np.int32,
'DT_INT64': np.int64,

'DT_UINT8': np.uint8,
'DT_UINT16': np.uint16,
'DT_UINT32': np.uint32,
'DT_UINT64': np.uint64,

'DT_FLOAT16': np.float16,
'DT_FLOAT32': np.float32,
'DT_FLOAT64': np.float64,

'DT_STRING': np.str
}


@enum.unique
class ReplyStates(enum.Enum):
"""Define the status of reply."""
SUCCESS = 0
FAILED = -1


@enum.unique
class ServerStatus(enum.Enum):
"""The status of debugger server."""
PENDING = 'pending' # no client session has been connected
RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph
WAITING = 'waiting' # the client session is ready
RUNNING = 'running' # the client session is running a script


@enum.unique
class Streams(enum.Enum):
"""Define the enable streams to be deal with."""

COMMAND = "command"
DATA = "data"
METADATA = "metadata"
GRAPH = 'node'
TENSOR = 'tensor'
WATCHPOINT = 'watchpoint'
WATCHPOINT_HIT = 'watchpoint_hit'


NodeBasicInfo = namedtuple('node_basic_info', ['name', 'full_name', 'type'])


def get_ack_reply(state=0):
"""The the ack EventReply."""
reply = EventReply()
state_mapping = {
0: EventReply.Status.OK,
1: EventReply.Status.FAILED,
2: EventReply.Status.PENDING
}
reply.status = state_mapping[state]

return reply


def wrap_reply_response(error_code=None, error_message=None):
"""
Wrap reply response.

Args:
error_code (str): Error code. Default: None.
error_message (str): Error message. Default: None.

Returns:
dict, the reply response.
"""
if error_code is None:
reply = {'state': ReplyStates.SUCCESS.value}
else:
reply = {
'state': ReplyStates.FAILED.value,
'error_code': error_code,
'error_message': error_message
}

return reply


def create_view_event_from_tensor_history(tensor_history):
"""
Create view event reply according to tensor names.

Args:
tensor_history (list[dict]): The list of tensor history. Each element has keys:
`name`, `node_type`.

Returns:
EventReply, the event reply with view cmd.
"""
view_event = get_ack_reply()
for tensor_info in tensor_history:
node_type = tensor_info.get('node_type')
if node_type == NodeTypeEnum.CONST.value:
continue
truncate_tag = tensor_info.get('node_type') == NodeTypeEnum.PARAMETER.value
tensor_name = tensor_info.get('full_name', '')
# create view command
ms_tensor = view_event.view_cmd.tensors.add()
ms_tensor.node_name, ms_tensor.slot = tensor_name.rsplit(':', 1)
ms_tensor.truncate = truncate_tag
ms_tensor.iter = 'prev' if tensor_info.get('iter') else ''

return view_event


def is_scope_type(node_type):
"""Judge whether the type is scope type."""
scope_types = [NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.AGGREGATION_SCOPE.value]
return node_type in scope_types


def str_to_slice_or_int(input_str):
"""
Translate param from string to slice or int.

Args:
input_str (str): The string to be translated.

Returns:
Union[int, slice], the transformed param.
"""
try:
if ':' in input_str:
ret = slice(*map(lambda x: int(x.strip()) if x.strip() else None, input_str.split(':')))
else:
ret = int(input_str)
except ValueError as err:
log.error("Failed to create slice from %s", input_str)
log.exception(err)
raise DebuggerParamValueError("Invalid shape.")
return ret
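
str_to_slice_or_int turns each comma-separated element of a shape string into an int or a slice, which can then index a numpy array directly. A small usage sketch:

import numpy as np
from mindinsight.debugger.common.utils import str_to_slice_or_int

arr = np.arange(24).reshape(2, 3, 4)
index = tuple(str_to_slice_or_int(part) for part in "0, :, 1:3".split(","))
print(index)       # (0, slice(None, None, None), slice(1, 3, None))
print(arr[index])  # block 0, every row, columns 1..2 -> shape (3, 2)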

+154 -0  mindinsight/debugger/debugger_cache.py

@@ -0,0 +1,154 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement the debugger data cache manager."""
import sys

from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import Streams
from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, GraphHandler, \
TensorHandler, WatchpointHandler, WatchpointHitHandler

STREAM_HANDLER_MAP = {
Streams.COMMAND.value: EventHandler,
Streams.DATA.value: EventHandler,
Streams.METADATA.value: MetadataHandler,
Streams.GRAPH.value: GraphHandler,
Streams.TENSOR.value: TensorHandler,
Streams.WATCHPOINT.value: WatchpointHandler,
Streams.WATCHPOINT_HIT.value: WatchpointHitHandler
}


class DebuggerCache:
"""The debugger data cache manager."""

def __init__(self):
self._stream_handler = {}

def initialize(self):
"""Initialize the stream handlers."""
self._stream_handler = {}
for stream in Streams:
mode = stream.value
stream_handler = STREAM_HANDLER_MAP.get(mode)
self._stream_handler[mode] = stream_handler()

def clean(self):
"""Clean cache for all stream."""
for _, stream_handler in self._stream_handler.items():
stream_handler.clean()

def get_stream_handler(self, mode):
"""
Get the stream handler object.

Args:
mode (Streams): The type of stream handler.

Returns:
StreamHandlerBase, the stream handler object.
"""
return self._stream_handler.get(mode.value)

def _get(self, mode, pos):
"""
Get updated data or command from cache.

Args:
mode (Streams): The type of info. `Streams.DATA` or `Streams.COMMAND`.
pos (int): The index of info.

Returns:
object, the pos-th message about `mode` type of info.
"""
stream_handler = self.get_stream_handler(mode)

return stream_handler.get(pos)

def _put(self, mode, value):
"""
Set updated data or command from cache.

Args:
mode (Streams): The type of info. `Streams.DATA` or `Streams.COMMAND`.
value (object): The info to be record in cache.
"""
stream_handler = self.get_stream_handler(mode)

return stream_handler.put(value)

def get_command(self, pos):
"""
Get the pos-th command in command stream.

Args:
pos (int): The index of command.

Returns:
int, the position of next message.
EventReply, the command object.
"""
content = self._get(Streams.COMMAND, pos)
next_pos = content.get('metadata').get('pos')
reply = content.get('cmd')
return next_pos, reply

def put_command(self, cmd):
"""
Set command to command stream.

Args:
cmd (EventReply): The command EventReply.
"""
log.debug("Set command %s", cmd)
return self._put(Streams.COMMAND, {'cmd': cmd})

def has_command(self, pos):
"""Judge if the number of command is no less than `pos`."""
event = self.get_stream_handler(Streams.COMMAND).has_pos(pos)

return event

def clean_command(self):
"""Clean command queue."""
self.get_stream_handler(Streams.COMMAND).clean()
log.debug("Clean command.")

def clean_data(self):
"""Clean command queue."""
self.get_stream_handler(Streams.DATA).clean()
log.debug("Clean data queue.")

def get_data(self, pos):
"""
Get updated data from data stream.

Args:
pos (int): The index of data.

Returns:
object, updated data_value.
"""
return self._get(Streams.DATA, pos)

def put_data(self, value):
"""
Set updated data to data stream.

Args:
value (dict): The updated data.
"""
log.debug("Set <%d> bytes data", sys.getsizeof(value))
return self._put(Streams.DATA, value)
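
DebuggerCache is essentially a registry keyed by Streams values, with queue-like wrappers around the COMMAND and DATA streams. A toy model of that queue pattern (self-contained; the real EventHandler arrives in stream_handler/event_handler.py, file 29, and additionally blocks until the requested position is filled):

class ToyEventHandler:
    """Append-only event list with positional reads, mimicking get_command()."""

    def __init__(self):
        self._events = []

    def put(self, value):
        self._events.append(value)

    def has_pos(self, pos):
        return int(pos) < len(self._events)

    def get(self, pos):
        index = int(pos)
        return {'metadata': {'pos': str(index + 1)}, 'cmd': self._events[index]}

handler = ToyEventHandler()
handler.put({'reset': True})
print(handler.has_pos('0'))  # True
print(handler.get('0'))      # {'metadata': {'pos': '1'}, 'cmd': {'reset': True}}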

+309 -0  mindinsight/debugger/debugger_grpc_server.py

@@ -0,0 +1,309 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement the debugger grpc server."""
from functools import wraps

from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
create_view_event_from_tensor_history, Streams
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto


def debugger_wrap(func):
"""Wrapper for catch exception."""

@wraps(func)
def record_log(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as err:
log.exception(err)
raise err

return record_log


class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
"""The grpc server used to interactive with grpc client."""

def __init__(self, cache_store):
"""
Initialize.

Args:
cache_store (DebuggerCache): Debugger cache store.
"""
cache_store.initialize()
self._cache_store = cache_store
self._pos = None
self._status = None
self._view_event = None
self._view_round = None
self._continue_steps = None
self.init()

def init(self):
"""Init debugger grpc server."""
self._pos = '0'
self._status = ServerStatus.PENDING
self._view_event = None
self._view_round = True
self._continue_steps = 0
self._cache_store.clean()

@debugger_wrap
def WaitCMD(self, request, context):
"""Wait for a command in DebuggerCache."""
# check if the graph has already been received.
log.info("Received WaitCMD at %s-th step.", request.cur_step)
if self._status == ServerStatus.PENDING:
log.warning("No graph received before WaitCMD.")
reply = get_ack_reply(1)
return reply
# send graph if it has not been sent before
self._pre_process(request)
# deal with old command
reply = self._deal_with_old_command()
if reply:
log.info("Reply to WaitCMD with old command: %s", reply)
return reply
# send view cmd
if self._view_round and self._view_event:
self._view_round = False
reply = self._view_event
log.debug("Send ViewCMD.")
# continue multiple steps training
elif self._continue_steps != 0:
reply = get_ack_reply()
reply.run_cmd.run_steps = 1
reply.run_cmd.run_level = 'step'
self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
log.debug("Send RunCMD. Clean watchpoint hit.")
# wait for command
else:
reply = self._wait_for_next_command()

if reply is None:
reply = get_ack_reply(1)
log.warning("Failed to get command event.")
else:
log.info("Reply to WaitCMD: %s", reply)
return reply

def _pre_process(self, request):
"""Send graph and metadata when WaitCMD first called."""
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
if self._status == ServerStatus.RECEIVE_GRAPH:
self._status = ServerStatus.WAITING
metadata_stream.state = 'waiting'
metadata = metadata_stream.get()
self._cache_store.clean_command()
res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
res.update(metadata)
self._cache_store.put_data(res)
log.info("Put graph into data queue.")

if metadata_stream.step < request.cur_step or metadata_stream.full_name != request.cur_node:
# clean tensor cache and DataQueue at the beginning of each step
self._update_metadata(metadata_stream, request)

def _update_metadata(self, metadata_stream, metadata_proto):
"""Update metadata."""
# reset view round and clean cache data
self._view_round = True
if metadata_stream.step < metadata_proto.cur_step:
self._cache_store.clean_data()
self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(
metadata_proto.cur_step)
# put new metadata into cache
metadata_stream.put(metadata_proto)
cur_node = self._cache_store.get_stream_handler(Streams.GRAPH).get_node_name_by_full_name(
metadata_proto.cur_node) if metadata_proto.cur_node else ''
metadata_stream.node_name = cur_node
metadata = metadata_stream.get()
self._cache_store.put_data(metadata)
log.info("Put new metadata into data queue.")

def _deal_with_old_command(self):
"""Deal with old command."""
event = None
while self._cache_store.has_command(self._pos) and event is None:
event = self._get_next_command()
log.debug("Deal with old %s-th command:\n%s.", self._pos, event)

return event

def _wait_for_next_command(self):
"""
Wait for next command.

Returns:
EventReply, the command event.
"""
log.info("Start to wait for command.")
self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting'
self._cache_store.put_data({'metadata': {'state': 'waiting'}})
event = None
while event is None and self._status == ServerStatus.WAITING:
log.debug("Wait for %s-th command", self._pos)
event = self._get_next_command()
return event

def _get_next_command(self):
"""Get next command."""
self._pos, event = self._cache_store.get_command(self._pos)
log.debug("Received event :%s", event)
if event is None:
return event
if isinstance(event, dict) and event.get('reset'):
self._set_view_event(event)
event = None
elif event.HasField('run_cmd'):
event = self._deal_with_run_cmd(event)
elif event.HasField('view_cmd'):
self._view_round = False
elif event.HasField('exit'):
self._cache_store.clean()
log.info("Clean cache for exit cmd.")

return event

def _deal_with_run_cmd(self, event):
"""Deal with run cmd."""
run_cmd = event.run_cmd
# receive step command
if run_cmd.run_level == 'step':
# receive pause cmd
if run_cmd.run_steps == 0:
log.debug("Pause training and wait for next command.")
self._continue_steps = 0
return None
# receive step cmd
self._continue_steps = run_cmd.run_steps - 1
event.run_cmd.run_steps = 1
self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
log.debug("Receive RunCMD. Clean watchpoint hit cache.")

return event

def _set_view_event(self, event):
"""Create view event for view cmd."""
# the first tensor in view cmd is always the output
node_name = event.get('node_name')
tensor_history = event.get('tensor_history')
if not node_name or not tensor_history:
self._view_event = None
log.info("Reset view command to None.")
else:
# create view event and set
self._view_event = create_view_event_from_tensor_history(tensor_history)
log.info("Reset view command to %s.", node_name)

@debugger_wrap
def SendMetadata(self, request, context):
"""Send metadata into DebuggerCache."""
log.info("Received Metadata.")
if self._status != ServerStatus.PENDING:
log.info("Re-initialize cache store when new session comes.")
self.init()

client_ip = context.peer().split(':', 1)[-1]
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
metadata_stream.put(request)
metadata_stream.client_ip = client_ip
metadata = metadata_stream.get()
# put metadata into data queue
self._cache_store.put_data(metadata)
log.info("Put new metadata to DataQueue.")
reply = get_ack_reply()
log.info("Send the reply to %s.", client_ip)
return reply

@debugger_wrap
def SendGraph(self, request_iterator, context):
"""Send graph into DebuggerCache."""
log.info("Received graph.")
serial_graph = b""
for chunk in request_iterator:
serial_graph += chunk.buffer
graph = GraphProto.FromString(serial_graph)
log.debug("Deserialize the graph. Receive %s nodes", len(graph.node))
self._cache_store.get_stream_handler(Streams.GRAPH).put(graph)
self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
self._status = ServerStatus.RECEIVE_GRAPH
reply = get_ack_reply()
log.info("Send the reply for graph.")
return reply

@debugger_wrap
def SendTensors(self, request_iterator, context):
"""Send tensors into DebuggerCache."""
log.info("Received tensor.")
tensor_construct = []
tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
tensor_names = []
step = metadata_stream.step
for tensor in request_iterator:
tensor_construct.append(tensor)
if tensor.finished:
tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
tensor_construct = []
tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
continue
# send back the tensor-finished flag once all requested tensors have values.
tensor_history = tensor_stream.get_tensor_history(tensor_names)
self._add_node_name_for_tensor_history(tensor_history)
metadata = metadata_stream.get()
tensor_history.update(metadata)
self._cache_store.put_data({}) # reply to the listening request
self._cache_store.put_data(tensor_history)
log.info("Send updated tensor history to data queue.")
reply = get_ack_reply()
return reply

def _add_node_name_for_tensor_history(self, tensor_history):
"""Add node name for tensor history."""
graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
for tensor_info in tensor_history.get('tensor_history'):
if tensor_info:
full_name, slot = tensor_info.get('full_name', '').rsplit(':', 1)
node_name = graph_stream.get_node_name_by_full_name(full_name)
tensor_info['name'] = node_name + ':' + slot

@debugger_wrap
def SendWatchpointHits(self, request_iterator, context):
"""Send watchpoint hits info DebuggerCache."""
log.info("Received WatchpointHits. Left steps %d change to 0.", self._continue_steps)
self._continue_steps = 0
self._view_event = None
watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
for watchpoint_hit_proto in request_iterator:
watchpoint_hit = {
'tensor_proto': watchpoint_hit_proto.tensor,
'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
'node_name': graph_stream.get_node_name_by_full_name(
watchpoint_hit_proto.tensor.node_name)
}
watchpoint_hit_stream.put(watchpoint_hit)
watchpoint_hits_info = watchpoint_hit_stream.get()
self._cache_store.put_data(watchpoint_hits_info)
log.info("Send the watchpoint hits to DataQueue.\nSend the reply.")
reply = get_ack_reply()
return reply
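
The run-command bookkeeping above turns one "run N steps" request into N single-step replies: _deal_with_run_cmd keeps N-1 in _continue_steps and rewrites the outgoing command to one step, and each later WaitCMD drains the counter (a negative counter means "keep running until a new command arrives"). A pure-Python model of that loop, with no gRPC involved:

def simulate_run(requested_steps, wait_calls):
    """Model _deal_with_run_cmd plus the continue branch of WaitCMD."""
    continue_steps = requested_steps - 1  # the first reply consumes one step
    replies = [1]                         # the rewritten one-step RunCMD
    for _ in range(wait_calls):
        if continue_steps != 0:
            replies.append(1)
            continue_steps = continue_steps - 1 if continue_steps > 0 else -1
        else:
            replies.append(None)          # the real server now blocks for a new command
    return replies

print(simulate_run(3, 4))  # [1, 1, 1, None, None] -> exactly three single steps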

+752 -0  mindinsight/debugger/debugger_server.py

@@ -0,0 +1,752 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement the debugger server."""
import signal
from concurrent import futures
from threading import Thread

import grpc

from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.utils.tools import to_float
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \
DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, DebuggerCompareTensorError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo, \
str_to_slice_or_int
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
from mindinsight.utils.exceptions import MindInsightException


class DebuggerServer:
"""The server manager of debugger."""

def __init__(self, grpc_port=None):
self.grpc_port = grpc_port
self.cache_store = DebuggerCache()
self.grpc_server = DebuggerGrpcServer(self.cache_store)
self.grpc_server_manager = None
self.back_server = None
self._watch_point_id = 0

def start(self):
"""Start server."""
grpc_port = self.grpc_port if self.grpc_port else "50051"
host = settings.HOST if hasattr(settings, 'HOST') else '[::]'
hostname = "{}:{}".format(host, grpc_port)
# initialize a grpc server
grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
grpc_server_base.add_EventListenerServicer_to_server(self.grpc_server, grpc_server_manager)
grpc_server_manager.add_insecure_port(hostname)
grpc_server_manager.start()
my_server_thread = Thread(target=grpc_server_manager.wait_for_termination)
# start grpc server
my_server_thread.start()
self.back_server = my_server_thread
self.grpc_server_manager = grpc_server_manager
# register stop server handler
signal.signal(signal.SIGINT, self._stop_handler)
log.info("Start grpc server %s", hostname)

def _stop_handler(self, signum, frame):
"""Register stop server handler."""
self.stop()
log.debug("Deal with stop signal: %s, %s", signum, frame)

def stop(self):
"""Stop debugger server."""
self.grpc_server_manager.stop(grace=None)
self.back_server.join()
log.info("Stop debugger server.")

def poll_data(self, pos):
"""
Get the pos-th data from DebuggerCache.

Args:
pos (int): The index of data.

Returns:
dict, the data to be updated.
"""
if not isinstance(pos, str):
log.error("Pos should be string. Received: %s", pos)
raise DebuggerParamValueError("Pos should be string.")

reply = self.cache_store.get_data(pos)

return reply

def search(self, name, watch_point_id):
"""Search for single node in graph."""
log.info("receive search request for node:%s, in watchpoint:%d", name, watch_point_id)
graph = self.cache_store.get_stream_handler(Streams.GRAPH).search_nodes(name)
self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes(
graph, watch_point_id)
return graph

def tensor_comparisons(self, name, shape, detail='data', tolerance='0'):
"""
Get tensor comparisons data for given name, detail, shape and tolerance.

Args:
name (str): The name of tensor for ui.
detail (str): Specify which data to query. Current available value is 'data' which means
concrete tensor data. Histogram or unique count can be supported in the future.
shape (str): Specify concrete dimensions of shape.
tolerance (str): Specify tolerance of difference between current step tensor and previous
step tensor. Default value is 0.

Raises:
DebuggerParamValueError: If node type is not parameter or the value of detail is not supported.
DebuggerCompareTensorError: If MindSpore is not in waiting state.

Returns:
dict, the retrieved data.
"""
if self.cache_store.get_stream_handler(
Streams.METADATA).state != ServerStatus.WAITING.value:
log.error("Failed to compare tensors as the MindSpore is not in waiting state.")
raise DebuggerCompareTensorError(
"Failed to compare tensors as the MindSpore is not in waiting state."
)
self.validate_tensor_param(name, detail)
parsed_shape = self.parse_shape(shape)
node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
tolerance = to_float(tolerance, 'tolerance')
tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
if detail == 'data':
if node_type == NodeTypeEnum.PARAMETER.value:
reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance)
else:
raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type))
else:
raise DebuggerParamValueError("The value of detail: {} is not support.".format(detail))
return reply

def retrieve(self, mode, filter_condition=None):
"""
Retrieve data according to mode and params.

Args:
mode (str): The type of info message.
filter_condition (dict): The filter condition.

Returns:
dict, the retrieved data.
"""
log.info("receive retrieve request for mode:%s\n, filter_condition: %s", mode,
filter_condition)
# validate watchpoint_id

mode_mapping = {
'all': self._retrieve_all,
'node': self._retrieve_node,
'watchpoint': self._retrieve_watchpoint,
'watchpoint_hit': self._retrieve_watchpoint_hit
}
# validate param <mode>
if mode not in mode_mapping.keys():
log.error("Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint', "
"'watchpoint_hit', 'tensor'], but got %s.", mode_mapping)
raise DebuggerParamTypeError("Invalid mode.")
filter_condition = {} if filter_condition is None else filter_condition
self._watch_point_id = filter_condition.get('watch_point_id', 0)
reply = mode_mapping[mode](filter_condition)

return reply

def _retrieve_all(self, filter_condition=None):
"""Retrieve metadata, root graph and watchpoint list."""
if filter_condition:
log.error("No filter condition required for retrieve all request.")
raise DebuggerParamTypeError("filter_condition should be empty.")
result = {}
self.cache_store.clean_data()
log.info("Clean data queue cache when retrieve all request.")
self.cache_store.put_command({'reset': True})
for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]:
sub_res = self.cache_store.get_stream_handler(stream).get()
result.update(sub_res)

return result

def _retrieve_node(self, filter_condition):
"""
Retrieve node info.

Args:
filter_condition (dict): Filter condition.

- name (str): The name of single node.

- watch_point_id (int): The id of watchpoint.

- single_node (bool): If False, return the sub-layer of single node. If True, return
the node list from root node to single node.

Returns:
dict, the node info.
"""
log.info("Retrieve node %s.", filter_condition)
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
# validate parameters
node_name = filter_condition.get('name')
if not node_name:
node_type = NodeTypeEnum.NAME_SCOPE.value
else:
node_type = graph_stream.get_node_type(node_name)
filter_condition['node_type'] = node_type
filter_condition['single_node'] = bool(filter_condition.get('single_node'))
# get graph for scope node
if is_scope_type(node_type):
reply = self._get_nodes_info(filter_condition)
# get tensor history for leaf node
else:
reply = self._get_tensor_history(node_name)
if filter_condition.get('single_node'):
graph = self._get_nodes_info(filter_condition)
reply.update(graph)
return reply

def _get_nodes_info(self, filter_condition):
"""
Get nodes info.

Args:
filter_condition (dict): The filter condition.

- name (str): The node name.

- single_node (bool): If False, return the sub-layer of single node. If True, return
the node list from root node to single node.

- watch_point_id (int): The id of watchpoint.

Returns:
dict, reply with graph.
"""
# get graph
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
reply = graph_stream.get(filter_condition)
graph = reply.get('graph')
# add watched label
self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes(
graph, self._watch_point_id)
return reply

def retrieve_tensor_history(self, node_name):
"""
Retrieve tensor history for leaf node.

Args:
node_name (str): The name of leaf node.

Returns:
dict, the tensor history and metadata.
"""
log.info("Retrieve tensor history for node: %s.", node_name)
self._validate_leaf_name(node_name)
res = self._get_tensor_history(node_name)
return res

def _validate_leaf_name(self, node_name):
"""Validate if the node is a leaf node."""
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
node_type = graph_stream.get_node_type(node_name)
if is_scope_type(node_type):
log.error("Scope type node has no tensor history.")
raise DebuggerParamValueError("Invalid leaf node name.")

def _get_tensor_history(self, node_name):
"""
Get tensor history for single node.

Args:
node_name (str): The name of leaf node.

Returns:
dict, the tensor history and metadata.
"""
# get basic tensor history
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
tensor_history = graph_stream.get_tensor_history(node_name)
# set the view event
self.cache_store.put_command(
{'reset': True,
'node_name': node_name,
'tensor_history': tensor_history.get('tensor_history')})
# add tensor value for tensor history
self._add_tensor_value_for_tensor_history(tensor_history)
# add hit label for tensor history
watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
watchpoint_hit_stream.update_tensor_history(tensor_history)
# add metadata
metadata = self.cache_store.get_stream_handler(Streams.METADATA).get()
tensor_history.update(metadata)
return tensor_history

def _add_tensor_value_for_tensor_history(self, tensor_history):
"""
Add tensor values for the tensor history and send a ViewCMD if any tensor value is missing.

Args:
tensor_history (dict): The tensor history, which contains a list of tensor info
including name and type.
"""
tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
missed_tensors = tensor_stream.update_tensor_history(tensor_history)
if missed_tensors:
view_cmd = create_view_event_from_tensor_history(missed_tensors)
self.cache_store.put_command(view_cmd)
log.debug("Send view cmd.")

def retrieve_tensor_value(self, name, detail, shape):
"""Retrieve the tensor value."""
log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape)
self.validate_tensor_param(name, detail)
parsed_shape = self.parse_shape(shape)
node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
reply = self.cache_store.get_stream_handler(Streams.TENSOR).get(
{'name': tensor_name,
'node_type': node_type,
'shape': parsed_shape}
)
reply['tensor_value']['name'] = name

return reply
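# Example call (editorial sketch; the tensor name is an assumption):
#     server.retrieve_tensor_value(name='Default/network/fc1:0', detail='data',
#                                  shape='[0:2,0:2]')
# fetches a 2x2 slice of the first output slot of the node.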

def _get_tensor_name_and_type_by_ui_name(self, name):
"""
Get inner tensor name and type by UI name.

Args:
name (str): Node name shown in UI.

Returns:
tuple[str, str], the node type and the inner tensor name.
"""
node_name, slot = name.rsplit(':', 1)
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
node_type = graph_stream.get_node_type(node_name)
full_name = graph_stream.get_full_name(node_name)
tensor_name = full_name + ':' + slot
return node_type, tensor_name
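# Illustrative mapping (all names are assumptions): a UI name like
# 'network/fc1:0' splits into node 'network/fc1' and slot '0'; if the graph
# stream resolves the full name to 'Default/network/fc1-Dense/MatMul-op1',
# the inner tensor name becomes 'Default/network/fc1-Dense/MatMul-op1:0'.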

@staticmethod
def validate_tensor_param(name, detail):
"""Validate params for retrieve tensor request."""
# validate name
if not isinstance(name, str) or ':' not in name:
log.error("Invalid tensor name. Received: %s", name)
raise DebuggerParamValueError("Invalid tensor name.")
# validate data
if detail != 'data':
log.error("Invalid detail value. Received: %s", detail)
raise DebuggerParamValueError("Invalid detail value.")

@staticmethod
def parse_shape(shape):
"""Parse shape."""
if shape is None:
return shape
if not (isinstance(shape, str) and shape.startswith('[') and shape.endswith(']')):
log.error("Invalid shape. Received: %s", shape)
raise DebuggerParamValueError("Invalid shape.")
shape = shape.strip('[]')
if shape.count(':') > 2:
log.error("Invalid shape. At most two dimensions are specified.")
raise DebuggerParamValueError("Invalid shape.")
parsed_shape = tuple(
str_to_slice_or_int(dim) for dim in shape.split(',')) if shape else tuple()
log.info("Parsed shape: %s from %s", parsed_shape, shape)
return parsed_shape
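# Doctest-style sketch of the accepted formats (illustrative; assumes
# str_to_slice_or_int maps '1:3' to slice(1, 3) and '0' to 0):
#     parse_shape('[0,1:3]')  # -> (0, slice(1, 3))
#     parse_shape('[]')       # -> ()
#     parse_shape(None)       # -> None
#     parse_shape('0,1')      # raises DebuggerParamValueError (missing brackets)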

def _retrieve_watchpoint(self, filter_condition):
"""
Retrieve watchpoint.

Args:
filter_condition (dict): Filter condition.

- watch_point_id (int): The id of watchpoint. If not given, return all watchpoints.

- name (str): The name of single node.

- single_node (bool): If False, return the sub-layer of single node. If True, return
the node list from root node to single node.

Returns:
dict, the watchpoint list or the related graph.
"""
watchpoint_id = filter_condition.get('watch_point_id')
if watchpoint_id is None:
reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get()
log.debug("Get condition of watchpoints.")
else:
reply = self._retrieve_node(filter_condition)
log.debug("Get graph of %d-th watchpoint.", watchpoint_id)

return reply

def _retrieve_watchpoint_hit(self, filter_condition):
"""
Retrieve watchpoint hit.

Args:
filter_condition (dict): Filter condition.

- name (str): The name of single node.

- single_node (bool): If False, return the sub-layer of single node. If True, return
the node list from root node to single node.

Returns:
dict, the watchpoint hit list or the related graph.
"""
node_name = filter_condition.get('name')
# get watchpoint hit list
if node_name is None:
reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get()
return reply

self._validate_leaf_name(node_name)
# get tensor history
reply = self._get_tensor_history(node_name)
log.debug("Get tensor history for watchpoint hit node.")
# get single graph
if filter_condition.get('single_node'):
graph = self._get_nodes_info(filter_condition)
reply.update(graph)
log.debug("Get tensor history for watchpoint hit node.")

return reply

def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None):
"""
Create watchpoint.

Args:
watch_condition (dict): The watch condition.

- condition (str): Accepts `INF` or `NAN`.

- param (list[float]): Not defined yet.
watch_nodes (list[str]): The list of node names.
watch_point_id (int): The id of watchpoint.

Returns:
dict, the id of new watchpoint.
"""
log.info("Received create watchpoint request. WatchCondition: %s", watch_condition)
metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
if metadata_stream.state != ServerStatus.WAITING.value:
log.error("Failed to create watchpoint as the MindSpore is not in waiting state.")
raise DebuggerCreateWatchPointError(
"Failed to create watchpoint as the MindSpore is not in waiting state."
)
if metadata_stream.backend == 'GPU' and watch_condition.get('condition') == 'OVERFLOW':
log.error("GPU doesn't support OVERFLOW watch condition.")
raise DebuggerParamValueError("GPU doesn't support OVERFLOW watch condition.")

watch_nodes = self._get_node_basic_infos(watch_nodes)
watch_point_id = self.cache_store.get_stream_handler(Streams.WATCHPOINT).create_watchpoint(
watch_condition, watch_nodes, watch_point_id)
log.info("Create watchpoint %d", watch_point_id)
return {'id': watch_point_id}
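# Example request (field values are illustrative assumptions):
#     server.create_watchpoint(watch_condition={'condition': 'INF'},
#                              watch_nodes=['Default/network/fc1'])
#     # -> {'id': 1}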

def update_watchpoint(self, watch_point_id, watch_nodes, mode, name=None):
"""
Update watchpoint.

Args:
watch_point_id (int): The id of watchpoint.
watch_nodes (list[str]): The list of node names.
mode (int): The update operator on nodes. 0 for remove nodes from watch nodes.
1 for add nodes to watch nodes.
name (str): The search name. Default: None.

Returns:
dict, empty response.
"""
if self.cache_store.get_stream_handler(
Streams.METADATA).state != ServerStatus.WAITING.value:
log.error("Failed to update watchpoint as the MindSpore is not in waiting state.")
raise DebuggerUpdateWatchPointError(
"Failed to update watchpoint as the MindSpore is not in waiting state."
)
# validate
if not watch_nodes or not watch_point_id:
log.error("Invalid parameter for update watchpoint.")
raise DebuggerParamValueError("Invalid parameter for update watchpoint.")
# update watch node
if name is not None:
watch_nodes = self._get_watch_nodes_by_search(watch_nodes)
elif mode == 1:
watch_nodes = self._get_node_basic_infos(watch_nodes)

self.cache_store.get_stream_handler(Streams.WATCHPOINT).update_watchpoint(
watch_point_id, watch_nodes, mode)
log.info("Update watchpoint with id: %d", watch_point_id)
return {}
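# Example (illustrative): add one more node to watchpoint 1.
#     server.update_watchpoint(watch_point_id=1,
#                              watch_nodes=['Default/network/fc2'],
#                              mode=1)
#     # -> {}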

def _get_watch_nodes_by_search(self, watch_nodes):
"""Get watched leaf nodes by search name."""
watched_leaf_nodes = []
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
for search_name in watch_nodes:
search_nodes = graph_stream.get_searched_node_list()
search_node_names = [
NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
for node in search_nodes
if node.name.startswith(search_name)]
watched_leaf_nodes.extend(search_node_names)

log.debug("Update nodes: %s", watched_leaf_nodes)

return watched_leaf_nodes

def delete_watchpoint(self, watch_point_id):
"""
Delete watchpoint.

Args:
watch_point_id (int): The id of watchpoint.

Returns:
dict, empty response.
"""
if self.cache_store.get_stream_handler(
Streams.METADATA).state != ServerStatus.WAITING.value:
log.error("Failed to delete watchpoint as the MindSpore is not in waiting state.")
raise DebuggerDeleteWatchPointError(
"Failed to delete watchpoint as the MindSpore is not in waiting state."
)
self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint(
watch_point_id)
log.info("Delete watchpoint with id: %d", watch_point_id)
return {}

def _get_node_basic_infos(self, node_names):
"""Get node info according to node names."""
if not node_names:
return []
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
node_infos = []
for node_name in node_names:
node_type = graph_stream.get_node_type(node_name)
# optimizer later
if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value:
sub_nodes = graph_stream.get_nodes(node_name)
sub_infos = [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
for node in sub_nodes]
node_infos.extend(sub_infos)
continue
full_name = graph_stream.get_full_name(node_name)
node_infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type))
return node_infos

def control(self, params=None):
"""
Control the training process.

Args:
params (dict): The control params.

- mode (str): The control command, one of `continue`,
`pause` and `terminate`.

- level (str): The control granularity, `node` level or `step` level.
Default: `step`.

- steps (int): Specify the steps that training should run.
Used when `level` is `step`.

- name (str): Specify the name of the node. Used when `level` is `node`.

Returns:
dict, the response.
"""
log.info("Receive control request: %s.", params)
mode = params.get('mode')
metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
if mode == 'continue':
reply = self._continue(metadata_stream, params)
elif mode in ['pause', 'terminate']:
mode_mapping = {
'pause': self._pause,
'terminate': self._terminate
}
reply = mode_mapping.get(mode)(metadata_stream)
else:
log.error("Invalid control mode %s", mode)
raise DebuggerParamValueError("Invalid control mode.")

return reply
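# Example control requests (payloads are illustrative assumptions):
#     server.control({'mode': 'continue', 'level': 'step', 'steps': 10})
#     server.control({'mode': 'continue', 'level': 'node', 'name': 'Default/network/fc1'})
#     server.control({'mode': 'pause'})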

def _continue(self, metadata_stream, params):
"""
Send RunCMD to MindSpore.

Args:
metadata_stream (MetadataHandler): The metadata_handler
params (dict): The control params.
"""
if metadata_stream.state != ServerStatus.WAITING.value:
log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state)
raise DebuggerContinueError(
"MindSpore is not ready to run or is running currently."
)
metadata_stream.state = ServerStatus.RUNNING.value
current_state = ServerStatus.RUNNING.value
try:
event = self._construct_run_event(params)
self._send_watchpoints()
self.cache_store.put_command(event)
except MindInsightException as err:
log.error("Failed to send run event.")
log.exception(err)
current_state = ServerStatus.WAITING.value
metadata_stream.state = current_state
raise DebuggerContinueError("Failed to send run command.")
else:
log.debug("Send the RunCMD to command queue.")

return {'metadata': {'state': current_state}}

def _validate_node_type(self, node_name):
"""Check the node type in node control."""
if not node_name:
return
node_type = self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name)
unsupported_types = [item.value for item in list(NodeTypeEnum)]
if node_type in unsupported_types:
log.error("Invalid node type. %s", node_name)
raise DebuggerParamValueError(f"The type of node {node_name} is unsupported for "
"continue to command.")

def _construct_run_event(self, params):
"""
Construct run cmd from input control params.

Args:
params (dict): The control params.

- level (str): The control granularity, `node` level or `step` level.
Default: `step`.

- steps (int): Specify the steps that training should run.
Used when `level` is `step`.

- full_name (str): Specify the name of the node. Used when `level` is `node`.

Returns:
EventReply, control event with run command.
"""
level = params.get('level', 'step')
event = get_ack_reply()
if level == 'step':
steps = params.get('steps')
if not steps:
steps = 1
run_cmd = RunCMD(run_level='step', run_steps=steps)
elif level == 'node':
self._validate_node_type(params.get('name'))
name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(
params['name'])
if not name:
name = ''
run_cmd = RunCMD(run_level='node', node_name=name)
else:
log.error("Invalid Value. `level` should be `step` or `node`. Got %s", level)
raise DebuggerParamValueError("level` should be `step` or `node`")

event.run_cmd.CopyFrom(run_cmd)
log.debug("Construct run event. %s", event)
return event
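# Sketch of the constructed commands (values are illustrative):
#     {'level': 'step', 'steps': 2}    -> run_cmd {run_level: 'step', run_steps: 2}
#     {'level': 'node', 'name': 'fc1'} -> run_cmd {run_level: 'node', node_name: <full name>}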

def _send_watchpoints(self):
"""Set watchpoints."""
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
watchpoints = watchpoint_stream.get(filter_condition=True).get('watch_points')
if watchpoints:
for watchpoint in watchpoints:
event = get_ack_reply()
event.set_cmd.CopyFrom(watchpoint)
self.cache_store.put_command(event)
watchpoint_stream.sync_set_cmd()
log.debug("Send SetCMD to MindSpore. %s", event)

def _pause(self, metadata_stream):
"""
Pause the training.

Args:
metadata_stream (MetadataHandler): The metadata stream handler.
"""
if metadata_stream.state != ServerStatus.RUNNING.value:
log.error("The MindSpore is not running.")
raise DebuggerPauseError("The MindSpore is not running.")
metadata_stream.state = 'waiting'
event = get_ack_reply()
event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
self.cache_store.put_command(event)
log.debug("Send the Pause command")
return {'metadata': {'state': 'waiting'}}

def _terminate(self, metadata_stream):
"""
Terminate the training.

Args:
metadata_stream (MetadataHandler): The metadata stream handler.
"""
metadata_stream.state = 'pending'
event = get_ack_reply()
event.exit = True
self.cache_store.put_command(event)
log.debug("Send the ExitCMD.")
return {'metadata': {'state': 'pending'}}

def retrieve_node_by_bfs(self, node_name, ascend=False):
"""Get the graph and tensor history of the next node name according to node_name."""
log.info("Retrieve node <%s> by bfs, `ascend` is :%s",
node_name, ascend)
reply = {}
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
# no next node
if next_node_name is None:
return reply
# add graph and tensor history for next node
filter_condition = {
'name': next_node_name,
'single_node': True
}
search_graph = self._get_nodes_info(filter_condition)
tensor_history = self._get_tensor_history(next_node_name)
reply = {'name': next_node_name}
reply.update(search_graph)
reply.update(tensor_history)

return reply

+ 113
- 0
mindinsight/debugger/proto/debug_grpc.proto View File

@@ -0,0 +1,113 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

syntax = "proto3";

package debugger;

import "mindinsight/debugger/proto/ms_graph.proto";


service EventListener {
rpc WaitCMD (Metadata) returns (EventReply) {};
rpc SendMetadata (Metadata) returns (EventReply) {};
rpc SendGraph (stream Chunk) returns (EventReply) {};
rpc SendTensors (stream TensorProto) returns (EventReply) {};
rpc SendWatchpointHits (stream WatchpointHit) returns (EventReply) {};
}

message Metadata {
string device_name = 1;
int32 cur_step = 2;
// the backend type, either 'GPU' or 'Ascend'
string backend = 3;
// the full name of current node
string cur_node = 4;
// whether training is done.
bool training_done = 5;
}

message Chunk {
bytes buffer = 1;
}
message EventReply {
enum Status {
OK = 0;
FAILED = 1;
PENDING = 2;
}

Status status = 1;

oneof cmd {
bool exit = 2;
RunCMD run_cmd = 3;
SetCMD set_cmd = 4;
ViewCMD view_cmd = 5;
}
}

message RunCMD {
// running level. 'step' or 'node'
string run_level = 1;

oneof cmd {
int32 run_steps = 2;

// the full name of next node
string node_name = 3;
}
}

message SetCMD {
repeated WatchNode watch_nodes = 1;
WatchCondition watch_condition = 2;
bool delete = 3;
int32 id = 4;
}

message ViewCMD {
repeated TensorProto tensors = 1;
}

message WatchCondition {
enum Condition {
nan = 0;
inf = 1;
overflow = 2;
max_gt = 3;
max_lt = 4;
min_gt = 5;
min_lt = 6;
max_min_gt = 7;
max_min_lt = 8;
mean_gt = 9;
mean_lt = 10;
}
Condition condition = 1;
float value = 2; // for a 'between' condition, there will be two values
}
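// Illustrative example (editorial note, not part of the schema): a watchpoint
// that fires when the max value of a watched tensor exceeds 1.0 would set
// condition: max_gt, value: 1.0.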

message WatchNode {
string node_name = 1;
string node_type = 2;
}

message WatchpointHit {
TensorProto tensor = 1;
WatchCondition watch_condition = 2;
int32 id = 3;
}
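For orientation, here is a minimal sketch of how a trainee-side client could
exercise this service through the generated Python stub. The channel address
and message payloads are illustrative assumptions, not part of this commit.

import grpc
from mindinsight.debugger.proto import debug_grpc_pb2
from mindinsight.debugger.proto import debug_grpc_pb2_grpc

# Connect to the debugger server (the address is an assumption).
channel = grpc.insecure_channel('127.0.0.1:50051')
stub = debug_grpc_pb2_grpc.EventListenerStub(channel)
# Report metadata, then block until the server replies with a command.
metadata = debug_grpc_pb2.Metadata(device_name='Ascend0', cur_step=1, backend='Ascend')
stub.SendMetadata(metadata)
reply = stub.WaitCMD(metadata)
if reply.HasField('run_cmd'):
    print('run', reply.run_cmd.run_level)
elif reply.exit:
    print('exit requested')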

+ 683
- 0
mindinsight/debugger/proto/debug_grpc_pb2.py View File

@@ -0,0 +1,683 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: mindinsight/debugger/proto/debug_grpc.proto

from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


from mindinsight.debugger.proto import ms_graph_pb2 as mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2


DESCRIPTOR = _descriptor.FileDescriptor(
name='mindinsight/debugger/proto/debug_grpc.proto',
package='debugger',
syntax='proto3',
serialized_options=None,
serialized_pb=b'\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"k\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\"\x17\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\"\xec\x01\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\xee\x01\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\"\x95\x01\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x07\n\x03inf\x10\x01\x12\x0c\n\x08overflow\x10\x02\x12\n\n\x06max_gt\x10\x03\x12\n\n\x06max_lt\x10\x04\x12\n\n\x06min_gt\x10\x05\x12\n\n\x06min_lt\x10\x06\x12\x0e\n\nmax_min_gt\x10\x07\x12\x0e\n\nmax_min_lt\x10\x08\x12\x0b\n\x07mean_gt\x10\t\x12\x0b\n\x07mean_lt\x10\n\"1\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\"u\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x32\xc3\x02\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3'
,
dependencies=[mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.DESCRIPTOR,])



_EVENTREPLY_STATUS = _descriptor.EnumDescriptor(
name='Status',
full_name='debugger.EventReply.Status',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='OK', index=0, number=0,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FAILED', index=1, number=1,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PENDING', index=2, number=2,
serialized_options=None,
type=None),
],
containing_type=None,
serialized_options=None,
serialized_start=423,
serialized_end=464,
)
_sym_db.RegisterEnumDescriptor(_EVENTREPLY_STATUS)

_WATCHCONDITION_CONDITION = _descriptor.EnumDescriptor(
name='Condition',
full_name='debugger.WatchCondition.Condition',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='nan', index=0, number=0,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='inf', index=1, number=1,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='overflow', index=2, number=2,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='max_gt', index=3, number=3,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='max_lt', index=4, number=4,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='min_gt', index=5, number=5,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='min_lt', index=6, number=6,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='max_min_gt', index=7, number=7,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='max_min_lt', index=8, number=8,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='mean_gt', index=9, number=9,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='mean_lt', index=10, number=10,
serialized_options=None,
type=None),
],
containing_type=None,
serialized_options=None,
serialized_start=824,
serialized_end=973,
)
_sym_db.RegisterEnumDescriptor(_WATCHCONDITION_CONDITION)


_METADATA = _descriptor.Descriptor(
name='Metadata',
full_name='debugger.Metadata',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='device_name', full_name='debugger.Metadata.device_name', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='cur_step', full_name='debugger.Metadata.cur_step', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='backend', full_name='debugger.Metadata.backend', index=2,
number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='cur_node', full_name='debugger.Metadata.cur_node', index=3,
number=4, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='training_done', full_name='debugger.Metadata.training_done', index=4,
number=5, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=100,
serialized_end=207,
)


_CHUNK = _descriptor.Descriptor(
name='Chunk',
full_name='debugger.Chunk',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='buffer', full_name='debugger.Chunk.buffer', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=209,
serialized_end=232,
)


_EVENTREPLY = _descriptor.Descriptor(
name='EventReply',
full_name='debugger.EventReply',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='status', full_name='debugger.EventReply.status', index=0,
number=1, type=14, cpp_type=8, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='exit', full_name='debugger.EventReply.exit', index=1,
number=2, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='run_cmd', full_name='debugger.EventReply.run_cmd', index=2,
number=3, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='set_cmd', full_name='debugger.EventReply.set_cmd', index=3,
number=4, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='view_cmd', full_name='debugger.EventReply.view_cmd', index=4,
number=5, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
_EVENTREPLY_STATUS,
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
_descriptor.OneofDescriptor(
name='cmd', full_name='debugger.EventReply.cmd',
index=0, containing_type=None, fields=[]),
],
serialized_start=235,
serialized_end=471,
)


_RUNCMD = _descriptor.Descriptor(
name='RunCMD',
full_name='debugger.RunCMD',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='run_level', full_name='debugger.RunCMD.run_level', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='run_steps', full_name='debugger.RunCMD.run_steps', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='node_name', full_name='debugger.RunCMD.node_name', index=2,
number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
_descriptor.OneofDescriptor(
name='cmd', full_name='debugger.RunCMD.cmd',
index=0, containing_type=None, fields=[]),
],
serialized_start=473,
serialized_end=549,
)


_SETCMD = _descriptor.Descriptor(
name='SetCMD',
full_name='debugger.SetCMD',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='watch_nodes', full_name='debugger.SetCMD.watch_nodes', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='watch_condition', full_name='debugger.SetCMD.watch_condition', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='delete', full_name='debugger.SetCMD.delete', index=2,
number=3, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='id', full_name='debugger.SetCMD.id', index=3,
number=4, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=552,
serialized_end=681,
)


_VIEWCMD = _descriptor.Descriptor(
name='ViewCMD',
full_name='debugger.ViewCMD',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensors', full_name='debugger.ViewCMD.tensors', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=683,
serialized_end=732,
)


_WATCHCONDITION = _descriptor.Descriptor(
name='WatchCondition',
full_name='debugger.WatchCondition',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='condition', full_name='debugger.WatchCondition.condition', index=0,
number=1, type=14, cpp_type=8, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='value', full_name='debugger.WatchCondition.value', index=1,
number=2, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
_WATCHCONDITION_CONDITION,
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=735,
serialized_end=973,
)


_WATCHNODE = _descriptor.Descriptor(
name='WatchNode',
full_name='debugger.WatchNode',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='node_name', full_name='debugger.WatchNode.node_name', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='node_type', full_name='debugger.WatchNode.node_type', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=975,
serialized_end=1024,
)


_WATCHPOINTHIT = _descriptor.Descriptor(
name='WatchpointHit',
full_name='debugger.WatchpointHit',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor', full_name='debugger.WatchpointHit.tensor', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='watch_condition', full_name='debugger.WatchpointHit.watch_condition', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='id', full_name='debugger.WatchpointHit.id', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=1026,
serialized_end=1143,
)

_EVENTREPLY.fields_by_name['status'].enum_type = _EVENTREPLY_STATUS
_EVENTREPLY.fields_by_name['run_cmd'].message_type = _RUNCMD
_EVENTREPLY.fields_by_name['set_cmd'].message_type = _SETCMD
_EVENTREPLY.fields_by_name['view_cmd'].message_type = _VIEWCMD
_EVENTREPLY_STATUS.containing_type = _EVENTREPLY
_EVENTREPLY.oneofs_by_name['cmd'].fields.append(
_EVENTREPLY.fields_by_name['exit'])
_EVENTREPLY.fields_by_name['exit'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd']
_EVENTREPLY.oneofs_by_name['cmd'].fields.append(
_EVENTREPLY.fields_by_name['run_cmd'])
_EVENTREPLY.fields_by_name['run_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd']
_EVENTREPLY.oneofs_by_name['cmd'].fields.append(
_EVENTREPLY.fields_by_name['set_cmd'])
_EVENTREPLY.fields_by_name['set_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd']
_EVENTREPLY.oneofs_by_name['cmd'].fields.append(
_EVENTREPLY.fields_by_name['view_cmd'])
_EVENTREPLY.fields_by_name['view_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd']
_RUNCMD.oneofs_by_name['cmd'].fields.append(
_RUNCMD.fields_by_name['run_steps'])
_RUNCMD.fields_by_name['run_steps'].containing_oneof = _RUNCMD.oneofs_by_name['cmd']
_RUNCMD.oneofs_by_name['cmd'].fields.append(
_RUNCMD.fields_by_name['node_name'])
_RUNCMD.fields_by_name['node_name'].containing_oneof = _RUNCMD.oneofs_by_name['cmd']
_SETCMD.fields_by_name['watch_nodes'].message_type = _WATCHNODE
_SETCMD.fields_by_name['watch_condition'].message_type = _WATCHCONDITION
_VIEWCMD.fields_by_name['tensors'].message_type = mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2._TENSORPROTO
_WATCHCONDITION.fields_by_name['condition'].enum_type = _WATCHCONDITION_CONDITION
_WATCHCONDITION_CONDITION.containing_type = _WATCHCONDITION
_WATCHPOINTHIT.fields_by_name['tensor'].message_type = mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2._TENSORPROTO
_WATCHPOINTHIT.fields_by_name['watch_condition'].message_type = _WATCHCONDITION
DESCRIPTOR.message_types_by_name['Metadata'] = _METADATA
DESCRIPTOR.message_types_by_name['Chunk'] = _CHUNK
DESCRIPTOR.message_types_by_name['EventReply'] = _EVENTREPLY
DESCRIPTOR.message_types_by_name['RunCMD'] = _RUNCMD
DESCRIPTOR.message_types_by_name['SetCMD'] = _SETCMD
DESCRIPTOR.message_types_by_name['ViewCMD'] = _VIEWCMD
DESCRIPTOR.message_types_by_name['WatchCondition'] = _WATCHCONDITION
DESCRIPTOR.message_types_by_name['WatchNode'] = _WATCHNODE
DESCRIPTOR.message_types_by_name['WatchpointHit'] = _WATCHPOINTHIT
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

Metadata = _reflection.GeneratedProtocolMessageType('Metadata', (_message.Message,), {
'DESCRIPTOR' : _METADATA,
'__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2'
# @@protoc_insertion_point(class_scope:debugger.Metadata)
})
_sym_db.RegisterMessage(Metadata)

Chunk = _reflection.GeneratedProtocolMessageType('Chunk', (_message.Message,), {
'DESCRIPTOR' : _CHUNK,
'__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2'
# @@protoc_insertion_point(class_scope:debugger.Chunk)
})
_sym_db.RegisterMessage(Chunk)

EventReply = _reflection.GeneratedProtocolMessageType('EventReply', (_message.Message,), {
'DESCRIPTOR' : _EVENTREPLY,
'__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2'
# @@protoc_insertion_point(class_scope:debugger.EventReply)
})
_sym_db.RegisterMessage(EventReply)

RunCMD = _reflection.GeneratedProtocolMessageType('RunCMD', (_message.Message,), {
'DESCRIPTOR' : _RUNCMD,
'__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2'
# @@protoc_insertion_point(class_scope:debugger.RunCMD)
})
_sym_db.RegisterMessage(RunCMD)

SetCMD = _reflection.GeneratedProtocolMessageType('SetCMD', (_message.Message,), {
'DESCRIPTOR' : _SETCMD,
'__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2'
# @@protoc_insertion_point(class_scope:debugger.SetCMD)
})
_sym_db.RegisterMessage(SetCMD)

ViewCMD = _reflection.GeneratedProtocolMessageType('ViewCMD', (_message.Message,), {
'DESCRIPTOR' : _VIEWCMD,
'__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2'
# @@protoc_insertion_point(class_scope:debugger.ViewCMD)
})
_sym_db.RegisterMessage(ViewCMD)

WatchCondition = _reflection.GeneratedProtocolMessageType('WatchCondition', (_message.Message,), {
'DESCRIPTOR' : _WATCHCONDITION,
'__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2'
# @@protoc_insertion_point(class_scope:debugger.WatchCondition)
})
_sym_db.RegisterMessage(WatchCondition)

WatchNode = _reflection.GeneratedProtocolMessageType('WatchNode', (_message.Message,), {
'DESCRIPTOR' : _WATCHNODE,
'__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2'
# @@protoc_insertion_point(class_scope:debugger.WatchNode)
})
_sym_db.RegisterMessage(WatchNode)

WatchpointHit = _reflection.GeneratedProtocolMessageType('WatchpointHit', (_message.Message,), {
'DESCRIPTOR' : _WATCHPOINTHIT,
'__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2'
# @@protoc_insertion_point(class_scope:debugger.WatchpointHit)
})
_sym_db.RegisterMessage(WatchpointHit)



_EVENTLISTENER = _descriptor.ServiceDescriptor(
name='EventListener',
full_name='debugger.EventListener',
file=DESCRIPTOR,
index=0,
serialized_options=None,
serialized_start=1146,
serialized_end=1469,
methods=[
_descriptor.MethodDescriptor(
name='WaitCMD',
full_name='debugger.EventListener.WaitCMD',
index=0,
containing_service=None,
input_type=_METADATA,
output_type=_EVENTREPLY,
serialized_options=None,
),
_descriptor.MethodDescriptor(
name='SendMetadata',
full_name='debugger.EventListener.SendMetadata',
index=1,
containing_service=None,
input_type=_METADATA,
output_type=_EVENTREPLY,
serialized_options=None,
),
_descriptor.MethodDescriptor(
name='SendGraph',
full_name='debugger.EventListener.SendGraph',
index=2,
containing_service=None,
input_type=_CHUNK,
output_type=_EVENTREPLY,
serialized_options=None,
),
_descriptor.MethodDescriptor(
name='SendTensors',
full_name='debugger.EventListener.SendTensors',
index=3,
containing_service=None,
input_type=mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2._TENSORPROTO,
output_type=_EVENTREPLY,
serialized_options=None,
),
_descriptor.MethodDescriptor(
name='SendWatchpointHits',
full_name='debugger.EventListener.SendWatchpointHits',
index=4,
containing_service=None,
input_type=_WATCHPOINTHIT,
output_type=_EVENTREPLY,
serialized_options=None,
),
])
_sym_db.RegisterServiceDescriptor(_EVENTLISTENER)

DESCRIPTOR.services_by_name['EventListener'] = _EVENTLISTENER

# @@protoc_insertion_point(module_scope)

+ 193
- 0
mindinsight/debugger/proto/debug_grpc_pb2_grpc.py View File

@@ -0,0 +1,193 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
import grpc

from mindinsight.debugger.proto import debug_grpc_pb2 as mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2
from mindinsight.debugger.proto import ms_graph_pb2 as mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2


class EventListenerStub(object):
"""Missing associated documentation comment in .proto file"""

def __init__(self, channel):
"""Constructor.

Args:
channel: A grpc.Channel.
"""
self.WaitCMD = channel.unary_unary(
'/debugger.EventListener/WaitCMD',
request_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString,
response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
)
self.SendMetadata = channel.unary_unary(
'/debugger.EventListener/SendMetadata',
request_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString,
response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
)
self.SendGraph = channel.stream_unary(
'/debugger.EventListener/SendGraph',
request_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString,
response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
)
self.SendTensors = channel.stream_unary(
'/debugger.EventListener/SendTensors',
request_serializer=mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.SerializeToString,
response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
)
self.SendWatchpointHits = channel.stream_unary(
'/debugger.EventListener/SendWatchpointHits',
request_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.SerializeToString,
response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
)


class EventListenerServicer(object):
"""Missing associated documentation comment in .proto file"""

def WaitCMD(self, request, context):
"""Missing associated documentation comment in .proto file"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def SendMetadata(self, request, context):
"""Missing associated documentation comment in .proto file"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def SendGraph(self, request_iterator, context):
"""Missing associated documentation comment in .proto file"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def SendTensors(self, request_iterator, context):
"""Missing associated documentation comment in .proto file"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def SendWatchpointHits(self, request_iterator, context):
"""Missing associated documentation comment in .proto file"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_EventListenerServicer_to_server(servicer, server):
rpc_method_handlers = {
'WaitCMD': grpc.unary_unary_rpc_method_handler(
servicer.WaitCMD,
request_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.FromString,
response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString,
),
'SendMetadata': grpc.unary_unary_rpc_method_handler(
servicer.SendMetadata,
request_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.FromString,
response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString,
),
'SendGraph': grpc.stream_unary_rpc_method_handler(
servicer.SendGraph,
request_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.FromString,
response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString,
),
'SendTensors': grpc.stream_unary_rpc_method_handler(
servicer.SendTensors,
request_deserializer=mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.FromString,
response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString,
),
'SendWatchpointHits': grpc.stream_unary_rpc_method_handler(
servicer.SendWatchpointHits,
request_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.FromString,
response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'debugger.EventListener', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
class EventListener(object):
"""Missing associated documentation comment in .proto file"""

@staticmethod
def WaitCMD(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/debugger.EventListener/WaitCMD',
mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString,
mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
options, channel_credentials,
call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def SendMetadata(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/debugger.EventListener/SendMetadata',
mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString,
mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
options, channel_credentials,
call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def SendGraph(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(request_iterator, target, '/debugger.EventListener/SendGraph',
mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString,
mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
options, channel_credentials,
call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def SendTensors(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(request_iterator, target, '/debugger.EventListener/SendTensors',
mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.SerializeToString,
mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
options, channel_credentials,
call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def SendWatchpointHits(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(request_iterator, target, '/debugger.EventListener/SendWatchpointHits',
mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.SerializeToString,
mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
options, channel_credentials,
call_credentials, compression, wait_for_ready, timeout, metadata)

+ 322
- 0
mindinsight/debugger/proto/ms_graph.proto View File

@@ -0,0 +1,322 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

syntax = "proto2";

package debugger;

// Versioning
enum Version {
// unknown version
UNKNOWN_VERSION = 0;
// Initial version (IR VERSION 1), published on Sep 23, 2019
IR_VERSION = 0x0000000000000001;
}

// Data type definition
enum DataType {
DT_UNDEFINED = 0;
// Basic types.
DT_BOOL = 1; // bool
DT_INT8 = 2; // int8_t
DT_INT16 = 3; // int16_t
DT_INT32 = 4; // int32_t
DT_INT64 = 5; // int64_t
DT_UINT8 = 6; // uint8_t
DT_UINT16 = 7; // uint16_t
DT_UINT32 = 8; // uint32_t
DT_UINT64 = 9; // uint64_t
DT_FLOAT16 = 10; // float 16
DT_FLOAT32 = 11; // float 32
DT_FLOAT64 = 12; // float 64
DT_STRING = 13; // string
DT_TENSOR = 14; // tensor
DT_GRAPH = 15; // graph
// list type
DT_BOOLS = 16; // list of bool
DT_INTS8 = 17; // list of int8_t
DT_INTS16 = 18; // list of int16_t
DT_INTS32 = 19; // list of int32_t
DT_INTS64 = 20; // list of int64_t
DT_UINTS8 = 21; // list of uint8_t
DT_UINTS16 = 22; // list of uint16_t
DT_UINTS32 = 23; // list of uint32_t
DT_UINTS64 = 24; // list of uint64_t
DT_FLOATS16 = 25; // list of float16
DT_FLOATS32 = 26; // list of float32
DT_FLOATS64 = 27; // list of float64
DT_STRINGS = 28; // list of string
DT_TENSORS = 29; // list of tensor
DT_GRAPHS = 30; // list of graph
DT_TUPLE = 31; // tuple
DT_LIST = 32; // list
DT_DICT = 33; // dictionary
// other types
DT_NONE = 34; // None
DT_SYM_INST = 35; // Symbolic Key Instance
// type related type
DT_BASE_INT = 36; // type generic int
DT_BASE_UINT = 37; // type generic unsigned int
DT_BASE_FLOAT = 38; // type generic float
DT_TYPE = 39; // type type
DT_ANYTHING = 40; // type anything
DT_REFKEY = 41; // type refkey
DT_REF = 42; // type ref
}

// Value definition for attribute value or parameter default value
message ValueProto {
// data type of value
optional DataType dtype = 1; // discriminator that indicates which field below is in use

// Exactly ONE of the following fields must be present for this version of the IR
optional bool bool_val = 2; // bool
optional int64 int_val = 3; // int
optional uint64 uint_val = 4; // uint
optional float float_val = 5; // float
optional double double_val = 6; // double
optional string str_val = 7; // string
optional TensorProto tensor_val = 8; // tensor value
optional GraphProto graph = 9; // graph
repeated bool bool_vals = 10; // list of bool
repeated int64 int_vals = 11; // list of int
repeated uint64 uint_vals = 12; // list of uint
repeated float float_vals = 13; // list of float
repeated double double_vals = 14; // list of double
repeated string str_vals = 15; // list of string
repeated TensorProto tensor_vals = 16; // list of tensor value
repeated GraphProto graphs = 17; // list of graph
// tuple or list
repeated ValueProto values = 18; // tuple, list of value
// dictionary
repeated NamedValueProto dict_val = 19; // dictionary info
// field for type type
optional TypeProto type_val = 20; // type type info
}

message AttributeProto {
optional string name = 1; // attribute name
optional ValueProto value = 2; // attribute value
}

message NamedValueProto {
optional string key = 1; // attribute name
optional ValueProto value = 2; // attribute value
}

// Defines a tensor shape.
message TensorShapeProto {
// One dimension of the tensor.
message Dimension {
// Size of the tensor in that dimension.
// This value must be >= -1; a value of -1 means the
// dimension is "unknown".
optional int64 size = 1;

// Optional name of the tensor dimension.
optional string name = 2;
};
repeated Dimension dim = 1;
}
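// Illustrative encoding (editorial note): a 2x3 tensor shape would be
// dim { size: 2 } dim { size: 3 } in text format.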

// Types for graph input(parameter) and output
message TypeProto {

message Tensor {
// This field MUST have a valid DataType value except DT_TENSOR
optional DataType elem_type = 1;
optional TensorShapeProto shape = 2; // for scalar, this field is not set
}

// tuple type
message Sequence {
// The type and optional shape of elements of the tuple.
repeated TypeProto elem_types = 1;
};

// data type
optional DataType data_type = 1;
oneof value {
// The type of a tensor.
Tensor tensor_type = 2;
// The type of a tuple.
Sequence sequence_type = 3;
}
}

// Defines information on graph parameters, including the name, the type, and
// the default value of parameter if exists.
message ParameterProto {
optional string name = 1; // parameter name
optional TypeProto type = 2; // parameter type
optional ValueProto default_val = 3; // default value of parameter if exists
}

// Defines graph output information
message OutputProto {
optional string name = 1; // output node name
optional TypeProto type = 2; // output node type
}

// Define node input information
message InputProto {
enum EdgeType {
DATA_EDGE = 0; // data edge
CONTROL_EDGE = 1; // control edge
}

optional string name = 1;
optional EdgeType type = 2;
}

// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated InputProto input = 1; // namespace Value
optional string name = 2; // namespace Value

// The symbolic identifier of the Operator to execute.
optional string op_type = 3; // namespace Operator
// The domain of the OperatorSet that specifies the operator named by op_type.
optional string scope = 4; // namespace Domain

// Additional named attributes.
repeated AttributeProto attribute = 5;
// Optional type info of this node
optional TypeProto output_type = 6;
// other fields for debug
optional uint64 output_i = 7;

// full name with scope
optional string full_name = 8;
}
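// Illustrative example (textproto, not normative): a convolution node
// with one data-edge input might read as
//   name: "op17"
//   op_type: "Conv2D"
//   scope: "Default"
//   full_name: "Default/network/Conv2D-op17"
//   input { name: "op16"  type: DATA_EDGE }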

// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto.
message ModelProto {
// ir version
optional int64 ir_version = 1;

// Domain name of the model.
// We use reverse domain names as name space indicators. For example:
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
//
// Together with `model_version` and GraphProto.name, this forms the unique identity of
// the graph.
optional string domain = 2;

// The version of the graph encoded. See Version enum below.
optional int64 model_version = 3;

// The parameterized graph that is evaluated to execute the model.
optional GraphProto graph = 4;

// metadata info of operators
optional OperatorSetProto metadata_operators = 5;
};

message OperatorProto {
optional string name = 1; // used as key, must be distinct
optional bytes config = 2; // operator config info
optional bytes obj_info = 3; // operator related object info, e.g. content of operator binary or name
};

message OperatorSetProto {
repeated OperatorProto operators = 1;
optional string summary = 2; // summary info of operators, e.g. file position of operators file
}

// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
message GraphProto {
// The nodes in the graph, sorted topologically.
repeated NodeProto node = 1;

// The name of the graph.
optional string name = 2; // namespace Graph

// The parameters(inputs) and outputs of the graph.
repeated ParameterProto parameters = 3;
repeated OutputProto outputs = 4;
// Constants used in this graph
repeated NamedValueProto const_vals = 5;
}

// Tensors
//
// A serialized tensor value.
message TensorProto {
// The node name of the tensor.
optional string node_name = 1;

// The slot of the tensor in its node.
optional string slot = 2;

// The serialized tensor content.
optional bytes tensor_content = 3;

// The shape of the tensor.
repeated int64 dims = 4;

// The data type of the tensor.
// This field MUST have a valid DataType value except DT_TENSOR
optional DataType data_type = 5;

// If the tensor content transferring is finished.
optional bool finished = 6;

// The iteration of the tensor. Supported: "prev" or leave empty.
optional string iter = 7;

// If the tensor name should be truncated.
optional bool truncate = 8;
}
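Reviewer note: a minimal sketch of how tensor_content round-trips through the generated module, assuming the scalar DT_FLOAT32 enum value defined earlier in this file (the server itself maps dtype names through NUMPY_TYPE_MAP in mindinsight/debugger/common/utils.py):

import numpy as np
from mindinsight.debugger.proto import ms_graph_pb2

# Build a TensorProto by hand (illustrative values only).
proto = ms_graph_pb2.TensorProto()
proto.node_name = 'Default/Conv2D-op17'
proto.slot = '0'
proto.data_type = ms_graph_pb2.DT_FLOAT32
proto.dims.extend([2, 2])
proto.tensor_content = np.arange(4, dtype=np.float32).tobytes()

# Decode the raw bytes back into an array of the declared shape.
value = np.frombuffer(proto.tensor_content, dtype=np.float32).reshape(list(proto.dims))
print(value)  # [[0. 1.] [2. 3.]]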

+ 1395
- 0
mindinsight/debugger/proto/ms_graph_pb2.py
File diff suppressed because it is too large
View File


+ 14
- 0
mindinsight/debugger/stream_cache/__init__.py View File

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

+ 289
- 0
mindinsight/debugger/stream_cache/debugger_graph.py View File

@@ -0,0 +1,289 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""This file is used to define the basic graph."""
from collections import deque

from mindinsight.datavisual.data_transform.graph.msgraph import MSGraph
from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum
from mindinsight.debugger.common.exceptions.exceptions import \
DebuggerNodeNotInGraphError, DebuggerParamValueError
from mindinsight.debugger.common.log import logger as log
from .node import NodeTree


class DebuggerGraph(MSGraph):
"""The `DebuggerGraph` object provides interfaces to describe a debugger graph."""
def __init__(self):
super(DebuggerGraph, self).__init__()
self._node_tree = None

def get_node_name_by_full_name(self, full_name):
"""Get node name by full names."""
inner_name = self._full_name_map_name.get(full_name, '')
if not inner_name:
log.warning("Node %s does not find the relative inner node name.", full_name)

return inner_name

def get_full_name_by_node_name(self, node_name):
"""Get full name by node name for leaf nodes."""
node = self._normal_node_map.get(node_name)
if not node:
log.warning("Node %s is not leaf node.", node_name)

return node.full_name if node else ''

def get_nodes(self, searched_node_list):
"""
Organize the searched leaf nodes into a nested scope structure.

Args:
searched_node_list (list[Node]): A list of leaf nodes that
match the given search pattern.

Returns:
list[dict], the searched nodes organized by scope, e.g.
[{
"name": "Default",
"type": "name_scope",
"nodes": [{
"name": "Default/Conv2D1",
"type": "name_scope",
"nodes": [{
...
}]
}]
},
{
"name": "Gradients",
"type": "name_scope",
"nodes": [{
"name": "Gradients/Default",
"type": "name_scope",
"nodes": [{
...
}]
}]
}]
"""
# save the node in the NodeTree
self._node_tree = NodeTree()
for node in searched_node_list:
self._build_node_tree(node.name, node.type)

# get the searched nodes in the NodeTree and reorganize them
searched_list = []
self._traverse_node_tree(self._node_tree, searched_list)

return searched_list

def search_nodes_by_pattern(self, pattern):
"""
Search node names by a given pattern.

Args:
pattern (Union[str, None]): The pattern of the node to search,
if None, return all node names.

Returns:
list[Node], the matched leaf nodes.
"""
if pattern is not None:
pattern = pattern.lower()
searched_nodes = [
node for name, node in self._leaf_nodes.items()
if pattern in name.lower()
]
else:
searched_nodes = [node for name, node in self._leaf_nodes.items()]
return searched_nodes

def _build_node_tree(self, node_name, node_type):
"""Build node tree."""
scope_names = node_name.split('/')
cur_node = self._node_tree
for scope_name in scope_names[:-1]:
sub_node = cur_node.get(scope_name)
if not sub_node:
sub_node = cur_node.add(scope_name)
cur_node = sub_node
cur_node.add(scope_names[-1], node_type)

def _traverse_node_tree(self, cur_node, search_node_list):
"""Traverse the watch nodes and update the total watched node list."""
if not cur_node.get_children():
return
for _, sub_node in cur_node.get_children():
sub_nodes = []
self._traverse_node_tree(sub_node, sub_nodes)
sub_node_dict = {
'name': sub_node.node_name,
'type': sub_node.node_type,
'nodes': sub_nodes
}
search_node_list.append(sub_node_dict)

def get_node_type(self, node_name):
"""
Get the type of the node.

Args:
node_name (str): The full name of the node with its scope.

Returns:
A string, leaf or name_scope.
"""
if node_name and not self.exist_node(name=node_name):
raise DebuggerNodeNotInGraphError(node_name=node_name)

node = self._leaf_nodes.get(node_name)
if node is not None:
node_type = node.type
else:
node_type = NodeTypeEnum.NAME_SCOPE.value

return node_type

def get_tensor_history(self, node_name, depth=0):
"""
Get the tensor history of a specified node.

Args:
node_name (str): The debug name of the node.
depth (int): The number of layers the user wants to trace. Default is 0.

Returns:
list, a list of the traced tensors' name and node type,
arranged in order from leaf node to root node.
int, the number of output tensors.
"""
node = self._leaf_nodes.get(node_name)
tensor_history = self._get_tensor_infos_of_node(node)
cur_outputs_nums = len(tensor_history)
cur_depth = 0
trace_list = deque([(node, cur_depth)])
while trace_list:
cur_node, cur_depth = trace_list.popleft()
tensors_info = self._get_input_tensors_of_node(cur_node)
if tensors_info:
tensor_history.extend(tensors_info)
if cur_depth < depth:
for name in cur_node.input.keys():
trace_list.append((self._leaf_nodes[name], cur_depth + 1))

return tensor_history, cur_outputs_nums

@staticmethod
def _get_tensor_infos_of_node(cur_node, slot=None):
"""Get tensors info of specified node."""
tensors_info = []
if slot is None:
slots = range(cur_node.output_nums)
elif slot >= 0:
slots = [slot]
else:
log.info("Skip get tensor info for %s:%s.", cur_node.name, slot)
return tensors_info
for num in slots:
tensor_info = {
'name': cur_node.name + ':' + str(num),
'full_name': cur_node.full_name + ':' + str(num),
'node_type': cur_node.type
}
tensors_info.append(tensor_info)

return tensors_info

def _get_input_tensors_of_node(self, cur_node):
"""Get input tensors of node."""
tensors_info = []
for name in cur_node.input.keys():
node = self._leaf_nodes.get(name)
tensor_info = self._get_tensor_infos_of_node(node)
tensors_info.extend(tensor_info)

return tensors_info

def get_bfs_order(self):
"""
Traverse the graph in order of breadth-first search.

Returns:
list, including the leaf nodes arranged in BFS order.
"""
root = self.get_default_root()
log.info('Choose node %s as the root to do BFS.', root.name)

bfs_order = []
self.get_bfs_graph(root.name, bfs_order)
length = len(self._leaf_nodes.keys())
# Find rest un-traversed nodes
for node_name, _ in self._leaf_nodes.items():
if node_name not in bfs_order:
self.get_bfs_graph(node_name, bfs_order)

if len(bfs_order) != length:
log.error("The length of bfs and leaf nodes are not equal.")
msg = "Not all nodes are traversed!"
raise DebuggerParamValueError(msg)

return bfs_order

def get_bfs_graph(self, node_name, bfs_order):
"""
Traverse the graph in order of breath-first search.

Returns:
list, including the leaf nodes arranged in BFS order.
"""
temp_list = deque()
temp_list.append(node_name)
while temp_list:
node_name = temp_list.popleft()
node = self._leaf_nodes.get(node_name)

if not node:
log.warning('Cannot find node %s in graph. Ignored.', node_name)
continue

bfs_order.append(node_name)
if node.input:
for name in node.input.keys():
if name not in temp_list and name not in bfs_order:
temp_list.append(name)
if node.output:
for name in node.output.keys():
if name not in temp_list and name not in bfs_order:
temp_list.append(name)

def get_default_root(self):
"""
Get a node as default root for BFS in graph. Using the
leaf node with the smallest node id as the default root.

Returns:
str, the name of the default root.
"""
default_root = None
for _, item in self._leaf_nodes.items():
if item.node_id == '1':
default_root = item
break

if default_root is None:
log.error("Abnormal graph. Invalid node for BFS.")
msg = 'Abnormal graph. Invalid node for BFS.'
raise DebuggerParamValueError(msg)

return default_root
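Reviewer note: a minimal usage sketch, assuming a GraphProto message has already been received from the training script (build_graph is inherited from MSGraph; graph_proto below is a placeholder, not a real message):

from mindinsight.debugger.stream_cache.debugger_graph import DebuggerGraph

graph = DebuggerGraph()
graph.build_graph(graph_proto)                      # graph_proto: ms_graph_pb2.GraphProto
conv_nodes = graph.search_nodes_by_pattern('conv')  # case-insensitive substring match
order = graph.get_bfs_order()                       # leaf node names in breadth-first order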

+ 61
- 0
mindinsight/debugger/stream_cache/node.py View File

@@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
This file is used to define the node of graph and associated base types.
"""
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import logger as log


class NodeTree:
"""A class for building a node tree."""
def __init__(self, node_name='', node_type=None):
self.node_name = node_name
self._node_type = node_type
self._children = {}

@property
def node_type(self):
"""The property of node type."""
return self._node_type

@node_type.setter
def node_type(self, value):
"""Set the node type."""
self._node_type = value

def add(self, name, node_type=None):
"""Add sub node."""
sub_name = '/'.join([self.node_name, name]) if self.node_name else name
sub_node = NodeTree(sub_name, node_type)
self._children[name] = sub_node
return sub_node

def get(self, sub_name):
"""Get sub node."""
return self._children.get(sub_name)

def get_children(self):
"""Get all childrens."""
for name_scope, sub_node in self._children.items():
yield name_scope, sub_node

def remove(self, sub_name):
"""Remove sub node."""
try:
self._children.pop(sub_name)
except KeyError as err:
log.error("Failed to find node %s. %s", sub_name, err)
raise DebuggerParamValueError("Failed to find node {}".format(sub_name))

+ 233
- 0
mindinsight/debugger/stream_cache/tensor.py View File

@@ -0,0 +1,233 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The definition of tensor stream."""
from abc import abstractmethod, ABC

import numpy as np

from mindinsight.utils.tensor import TensorUtils
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import NUMPY_TYPE_MAP
from mindinsight.debugger.proto.ms_graph_pb2 import DataType


class BaseTensor(ABC):
"""Tensor data structure."""

def __init__(self, step=0):
self._step = step

@property
@abstractmethod
def name(self):
"""The property of tensor name."""

@property
@abstractmethod
def dtype(self):
"""The property of tensor dtype."""

@property
@abstractmethod
def shape(self):
"""The property of tensor shape."""

@property
@abstractmethod
def value(self):
"""The property of tensor shape."""

@abstractmethod
def get_tensor_value_by_shape(self, shape=None):
"""Get tensor value by shape."""

def _to_dict(self):
"""Get tensor info in dict format."""
res = {
'full_name': self.name,
'step': self._step,
'dtype': self.dtype,
'shape': self.shape,
'has_prev_step': False
}
return res

def get_basic_info(self):
"""Return basic info about tensor info."""
if not self.shape:
value = self.value
else:
value = 'click to view'
res = self._to_dict()
res['value'] = value
return res

def get_full_info(self, shape=None):
"""Get tensor info with value."""
res = self._to_dict()
value_info = self.get_tensor_serializable_value_by_shape(shape)
res.update(value_info)
return res


class OpTensor(BaseTensor):
"""Tensor data structure for operator Node."""
max_number_data_show_on_ui = 100000

def __init__(self, tensor_proto, step=0):
# the type of tensor_proto is TensorProto
super(OpTensor, self).__init__(step)
self._tensor_proto = tensor_proto
self._value = self.generate_value(tensor_proto)

@property
def name(self):
"""The property of tensor name."""
node_name = self._tensor_proto.node_name
slot = self._tensor_proto.slot
return ':'.join([node_name, slot])

@property
def dtype(self):
"""The property of tensor dtype."""
tensor_type = DataType.Name(self._tensor_proto.data_type)

return tensor_type

@property
def shape(self):
"""The property of tensor shape."""
return list(self._tensor_proto.dims)

@property
def value(self):
"""The property of tensor value."""
tensor_value = None
if self._value is not None:
tensor_value = self._value.tolist()

return tensor_value

@property
def numpy_value(self):
"""The property of tensor value in numpy type."""
return self._value

def generate_value(self, tensor_proto):
"""Generate tensor value from proto."""
tensor_value = None
if tensor_proto.tensor_content:
tensor_value = tensor_proto.tensor_content
np_type = NUMPY_TYPE_MAP.get(self.dtype)
tensor_value = np.frombuffer(tensor_value, dtype=np_type)
tensor_value = tensor_value.reshape(self.shape)
return tensor_value

def get_tensor_serializable_value_by_shape(self, shape=None):
"""
Get tensor value info by shape.

Args:
shape (tuple): The specified range of tensor value.

Returns:
dict, the specified tensor value and value statistics.
"""
tensor_value = self.get_tensor_value_by_shape(shape)
res = {}
if isinstance(tensor_value, np.ndarray):
statistics = TensorUtils.get_statistics_from_tensor(tensor_value)
res['statistics'] = TensorUtils.get_statistics_dict(statistics)
res['value'] = tensor_value.tolist()
return res
return res

def get_tensor_value_by_shape(self, shape=None):
"""
Get tensor value by shape.

Args:
shape (tuple): The specified shape.

Returns:
Union[None, str, numpy.ndarray], the sub-tensor.
"""
if self._value is None:
log.warning("%s has no value yet.", self.name)
return None
if shape is None or not isinstance(shape, tuple):
log.info("Get the whole tensor value with shape is %s", shape)
return self._value
if len(shape) != len(self.shape):
log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape)
raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
try:
value = self._value[shape]
except IndexError as err:
log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape)
log.exception(err)
raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
if isinstance(value, np.ndarray):
if value.size > self.max_number_data_show_on_ui:
value = "Too large to show."
log.info("The tensor size is %s, which is too large to show on UI.")
else:
value = np.asarray(value)
return value

class ConstTensor(BaseTensor):
"""Tensor data structure for Const Node."""

def __init__(self, const_proto):
# the type of const_proto is NamedValueProto
super(ConstTensor, self).__init__()
self._const_proto = const_proto

def set_step(self, step):
"""Set step value."""
self._step = step

@property
def name(self):
"""The property of tensor name."""
return self._const_proto.key + ':0'

@property
def dtype(self):
"""The property of tensor dtype."""
return DataType.Name(self._const_proto.value.dtype)

@property
def shape(self):
"""The property of tensor shape."""
return []

@property
def value(self):
"""The property of tensor shape."""
fields = self._const_proto.value.ListFields()
if len(fields) != 2:
log.warning("Unexpected const proto <%s>.\n Please check offline.", self._const_proto)
for field_name, field_value in fields:
if field_name != 'dtype':
return field_value
return None

def get_tensor_value_by_shape(self, shape=None):
"""Get tensor info with value."""
if shape is not None:
log.warning("Invalid shape for const value.")
return self.value
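Reviewer note: the shape argument of get_tensor_value_by_shape is really a tuple of indices or slices applied to the cached numpy array, one entry per dimension; a plain-numpy sketch of the same semantics (not an API call):

import numpy as np

value = np.arange(24).reshape(2, 3, 4)
# Equivalent to OpTensor.get_tensor_value_by_shape((0, slice(None), slice(1, 3)))
# on a tensor of shape [2, 3, 4]:
sub = value[(0, slice(None), slice(1, 3))]
print(sub.shape)  # (3, 2)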

+ 300
- 0
mindinsight/debugger/stream_cache/watchpoint.py View File

@@ -0,0 +1,300 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the watchpoint stream."""
from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition

WATCHPOINT_CONDITION_MAPPING = {
'INF': WatchCondition.Condition.inf,
'NAN': WatchCondition.Condition.nan,
'OVERFLOW': WatchCondition.Condition.overflow,
'MAX_GT': WatchCondition.Condition.max_gt,
'MAX_LT': WatchCondition.Condition.max_lt,
'MIN_GT': WatchCondition.Condition.min_gt,
'MIN_LT': WatchCondition.Condition.min_lt,
'MAX_MIN_GT': WatchCondition.Condition.max_min_gt,
'MAX_MIN_LT': WatchCondition.Condition.max_min_lt,
'MEAN_GT': WatchCondition.Condition.mean_gt,
'MEAN_LT': WatchCondition.Condition.mean_lt
}


class WatchNodeTree:
"""The WatchNode Node Structure."""
NOT_WATCH = 0 # the scope node and the nodes below are not watched
PARTIAL_WATCH = 1 # at least one node under the scope node is not watched
TOTAL_WATCH = 2 # the scope node and the nodes below are all watched

def __init__(self, node_name='', node_type=None, full_name='', watch_status=1):
self._node_name = node_name
self._full_name = full_name
self._node_type = self._translate_node_type(node_type)
self._watch_status = watch_status
self._children = {}

@property
def node_name(self):
"""The property of node name."""
return self._node_name

@property
def full_name(self):
"""The property of node name."""
return self._full_name

@property
def node_type(self):
"""The property of node type."""
return self._node_type

@node_type.setter
def node_type(self, value):
"""Set the node type."""
self._node_type = self._translate_node_type(value)

@property
def watch_status(self):
"""The property of watch status about current node."""
return self._watch_status

def enable_watch_status(self):
"""The property of watch status about current node."""
self._watch_status = WatchNodeTree.TOTAL_WATCH

@staticmethod
def _translate_node_type(node_type):
"""Translate node type to watch node type."""
if not node_type or node_type == NodeTypeEnum.NAME_SCOPE.value or \
node_type == NodeTypeEnum.AGGREGATION_SCOPE.value:
return 'scope'
return 'leaf'

def get(self, sub_name):
"""Get sub node."""
return self._children.get(sub_name)

def get_children(self):
"""Get all childrens."""
for name_scope, sub_watch_node in self._children.items():
yield name_scope, sub_watch_node

def add_node(self, node_name, node_type, full_name=''):
"""
Add watch node to watch node tree.

Args:
node_name (str): The node name.
node_type (str): The node type.
full_name (str): The full name of node.
"""
log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name)
scope_names = node_name.split('/', 1)
if len(scope_names) == 1:
if not self.get(node_name):
self.add(node_name, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH)
else:
self.get(node_name).enable_watch_status()
return

scope_name, sub_names = scope_names
sub_tree = self.get(scope_name)
if not sub_tree:
sub_tree = self.add(scope_name, watch_status=WatchNodeTree.PARTIAL_WATCH)
sub_tree.add_node(sub_names, node_type, full_name)

def add(self, name, node_type=None, full_name='', watch_status=1):
"""Add sub WatchPointTree."""
sub_name = '/'.join([self._node_name, name]) if self._node_name else name
sub_tree = WatchNodeTree(sub_name, node_type, full_name, watch_status)
self._children[name] = sub_tree

return sub_tree

def remove_node(self, node_name):
"""Remove sub node from current tree."""
log.debug("Remove %s", node_name)
scope_names = node_name.split('/', 1)
sub_tree_name = scope_names[0]
sub_tree = self._children.get(sub_tree_name)
if not sub_tree:
log.error("Failed to find node %s in WatchNodeTree.", sub_tree_name)
raise DebuggerParamValueError("Failed to find node {}".format(sub_tree_name))

if len(scope_names) > 1:
sub_tree.remove_node(scope_names[1])

if sub_tree.watch_status == WatchNodeTree.NOT_WATCH or len(scope_names) == 1:
self._children.pop(sub_tree_name)

self._watch_status = WatchNodeTree.PARTIAL_WATCH if self._children else \
WatchNodeTree.NOT_WATCH


class Watchpoint:
"""
The class of watchpoint stream.

Args:
watchpoint_id (int): The id of Watchpoint.
watch_condition (dict): The condition of Watchpoint.

- condition (str): The condition name, one of the keys of
WATCHPOINT_CONDITION_MAPPING, such as `INF` or `NAN`.

- param (list[float]): Not defined yet.
"""

def __init__(self, watchpoint_id, watch_condition):
self._id = watchpoint_id
self._condition = watch_condition
self._watch_node = WatchNodeTree()

@property
def watchpoint_id(self):
"""The property of watchpoint id."""
return self._id

@property
def nodes(self):
"""The property of watch nodes."""
return self._watch_node

@property
def condition(self):
"""The property of watch condition."""
return self._condition

def copy_nodes_from(self, other_watchpoint):
"""
Copy nodes from other watchpoint.
Args:
other_watchpoint (Watchpoint): Other watchpoint.
"""
self._watch_node = other_watchpoint.nodes

def add_nodes(self, nodes):
"""Add node into watchcpoint."""
if not nodes:
log.warning("Add empty nodes.")
return

if not isinstance(nodes, list):
nodes = [nodes]
for node in nodes:
self._watch_node.add_node(node.name, node.type, node.full_name)

def remove_nodes(self, nodes):
"""Remove nodes from watchpoint."""
if not nodes:
return
if not isinstance(nodes, list):
nodes = [nodes]
for node in nodes:
node_name = node.split(':')[0]
self._watch_node.remove_node(node_name)

def get_node_status(self, node_name, node_type, full_name):
"""Judge if the node is in watch nodes."""
scope_names = node_name.split('/')
cur_node = self._watch_node
status = WatchNodeTree.PARTIAL_WATCH
for scope_name in scope_names:
cur_node = cur_node.get(scope_name)
if cur_node is None:
status = WatchNodeTree.NOT_WATCH
break
if cur_node.watch_status == WatchNodeTree.TOTAL_WATCH:
status = WatchNodeTree.TOTAL_WATCH
break
if status == WatchNodeTree.TOTAL_WATCH and cur_node.node_name != node_name:
self._watch_node.add_node(node_name, node_type, full_name)

return status

def get_watch_node(self, cur_watch_node, watch_node_list):
"""
Traverse the watch nodes and add total watched node list to `watch_node_list`.

Args:
cur_watch_node (WatchNodeTree): The current watch node.
watch_node_list (list[WatchNodeTree]): The list of total watched node.
"""
if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH:
watch_node_list.append(cur_watch_node)
return
for _, watch_node in cur_watch_node.get_children():
self.get_watch_node(watch_node, watch_node_list)

def get_set_cmd(self):
"""Return the watchpoint in proto format."""
# get watch nodes.
watch_nodes = []
self.get_watch_node(self._watch_node, watch_nodes)
# construct SetCMD
set_cmd = SetCMD()
set_cmd.id = self._id
set_cmd.delete = False
set_cmd.watch_condition.condition = WATCHPOINT_CONDITION_MAPPING.get(
self._condition.get('condition'))
if self._condition.get('param'):
# at most one param is provided
set_cmd.watch_condition.value = self._condition.get('param')
for watch_node in watch_nodes:
event_node = set_cmd.watch_nodes.add()
event_node.node_name = watch_node.full_name
event_node.node_type = watch_node.node_type

return set_cmd

def get_watch_condition_info(self):
"""Get watch condition info."""
watchpoint_info = {
'id': self._id,
'watch_condition': self._condition
}
return watchpoint_info


class WatchpointHit:
"""The watchpoint hit structure."""

def __init__(self, tensor_proto, watchpoint, node_name):
self._node_name = node_name
self._full_name = tensor_proto.node_name
self._slot = tensor_proto.slot
self._watchpoint = watchpoint

@property
def tensor_full_name(self):
"""The property of tensor_name."""
tensor_name = ':'.join([self._full_name, self._slot])
return tensor_name

@property
def tensor_name(self):
"""The property of tensor_name."""
return ':'.join([self._node_name, self._slot])

@property
def watchpoint(self):
"""The property of watchpoint."""
watchpoint = self._watchpoint.get_watch_condition_info()
return watchpoint

def __eq__(self, other):
"""Define the equal condition."""
flag = self.tensor_full_name == other.tensor_full_name and self.watchpoint == other.watchpoint
return flag
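Reviewer note: a hedged sketch of how a watchpoint is assembled and serialized; the Node stand-in below is hypothetical (in the server the nodes come from the graph stream):

from collections import namedtuple

from mindinsight.debugger.stream_cache.watchpoint import Watchpoint

Node = namedtuple('Node', ['name', 'type', 'full_name'])  # hypothetical stand-in

watchpoint = Watchpoint(watchpoint_id=1, watch_condition={'condition': 'INF'})
watchpoint.add_nodes(Node('Default/Conv2D-op17', 'Conv2D', 'Default/network/Conv2D-op17'))
set_cmd = watchpoint.get_set_cmd()  # SetCMD proto, watch_condition.condition == inf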

+ 23
- 0
mindinsight/debugger/stream_handler/__init__.py View File

@@ -0,0 +1,23 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Import the streams handlers."""
from .event_handler import EventHandler
from .metadata_handler import MetadataHandler
from .graph_handler import GraphHandler
from .tensor_handler import TensorHandler
from .watchpoint_handler import WatchpointHandler, WatchpointHitHandler

__all__ = ['EventHandler', 'MetadataHandler', 'GraphHandler', 'TensorHandler',
'WatchpointHandler', 'WatchpointHitHandler']

+ 34
- 0
mindinsight/debugger/stream_handler/base_handler.py View File

@@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the stream handler base."""
from abc import abstractmethod


class StreamHandlerBase:
"""The stream handler base."""

@abstractmethod
def put(self, value):
"""Abstract method of set data."""
return NotImplementedError

@abstractmethod
def get(self, filter_condition):
"""Abstract method of get data."""
return NotImplementedError

def clean(self):
"""Clean cache."""
self.__init__()

+ 159
- 0
mindinsight/debugger/stream_handler/event_handler.py View File

@@ -0,0 +1,159 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the message handler."""
import uuid
from queue import Queue, Empty
from threading import Lock

from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase


class EventHandler(StreamHandlerBase):
"""Message Handler."""

max_limit = 1000 # the max number of items in cache

def __init__(self):
self._prev_flag = str(uuid.uuid4())
self._cur_flag = str(uuid.uuid4())
self._next_idx = 0
self._event_cache = [None] * self.max_limit
self._pending_requests = {}
self._lock = Lock()

@property
def next_pos(self):
"""The next pos to be updated in cache."""
return ':'.join([self._cur_flag, str(self._next_idx)])

def has_pos(self, pos):
"""Get the event according to pos."""
cur_flag, cur_idx = self._parse_pos(pos)
event = self._event_cache[cur_idx]
if event is not None:
if not cur_flag or (cur_flag == self._cur_flag and cur_idx < self._next_idx) or \
(cur_flag == self._prev_flag and cur_idx >= self._next_idx):
return event

return None

def clean(self):
"""Clean event cache."""
with self._lock:
self._prev_flag = str(uuid.uuid4())
self._cur_flag = str(uuid.uuid4())
self._next_idx = 0
self._event_cache = [None] * self.max_limit
value = {'metadata': {'pos': '0'}}
self.clean_pending_requests(value)
log.debug("Clean event cache.")

def put(self, value):
"""
Put value into event_cache.

Args:
value (dict): The event to be put into cache.
"""
if not isinstance(value, dict):
log.error("Dict type required when put event message.")
raise DebuggerParamValueError("Dict type required when put event message.")

with self._lock:
log.debug("Put the %d-th message into queue. \n %d requests is waiting.",
self._next_idx, len(self._pending_requests))
cur_pos = self._next_idx
# update next pos
self._next_idx += 1
if self._next_idx >= self.max_limit:
self._next_idx = 0
self._prev_flag = self._cur_flag
self._cur_flag = str(uuid.uuid4())
# set next pos
if not value.get('metadata'):
value['metadata'] = {}
value['metadata']['pos'] = self.next_pos
self._event_cache[cur_pos] = value
# feed the value for pending requests
self.clean_pending_requests(value)

def clean_pending_requests(self, value):
"""Clean pending requests."""
for _, request in self._pending_requests.items():
request.put(value)
self._pending_requests = {}

def get(self, filter_condition=None):
"""
Get the pos-th value from event_cache according to filter_condition.

Args:
filter_condition (str): The index of event in cache. Default: None.

Returns:
object, the pos-th event.
"""
flag, idx = self._parse_pos(filter_condition)
cur_id = str(uuid.uuid4())
with self._lock:
# reset the pos after the cache is re-initialized.
if not flag or flag not in [self._cur_flag, self._prev_flag]:
idx = 0
# get event from cache immediately
if idx != self._next_idx and self._event_cache[idx]:
return self._event_cache[idx]
# wait for the event
cur_queue = Queue(maxsize=1)
self._pending_requests[cur_id] = cur_queue
# block until event has been received
event = self._wait_for_event(cur_id, cur_queue, filter_condition)

return event

def _parse_pos(self, pos):
"""Get next pos according to input position."""
elements = pos.split(':')
try:
idx = int(elements[-1])
except ValueError:
log.error("Invalid index. The index in pos should be digit but get pos:%s", pos)
raise DebuggerParamValueError("Invalid pos.")

if idx < 0 or idx >= self.max_limit:
log.error("Invalid index. The index in pos should between [0, %d)", self.max_limit)
raise DebuggerParamValueError(f"Invalid pos. {idx}")
flag = elements[0] if len(elements) == 2 else ''

return flag, idx

def _wait_for_event(self, cur_id, cur_queue, pos):
"""Wait for the pos-th event."""
try:
# set the timeout to 25 seconds, which is less than the timeout limit from the UI
event = cur_queue.get(timeout=25)
except Empty:
event = None

if event is None:
with self._lock:
if self._pending_requests.get(cur_id):
self._pending_requests.pop(cur_id)
log.debug("Clean timeout request. Left pending requests: %d",
len(self._pending_requests))
event = {'metadata': {'pos': pos}}

return event
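Reviewer note: the pos handed to the UI is '<flag>:<index>', where the flag is regenerated each time the ring buffer wraps; a minimal round trip that returns immediately (the 25-second long-poll only kicks in when the requested event is not cached yet):

from mindinsight.debugger.stream_handler.event_handler import EventHandler

handler = EventHandler()
pos = handler.next_pos                     # e.g. '9f86d081-...:0'
handler.put({'watch_point_hit': 'demo'})   # any dict payload is accepted
event = handler.get(pos)                   # the cached 0th event, no waiting
print(event['metadata']['pos'])            # pos of the next event to poll for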

+ 314
- 0
mindinsight/debugger/stream_handler/graph_handler.py View File

@@ -0,0 +1,314 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the graph stream handler."""
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
DebuggerNodeNotInGraphError, DebuggerGraphNotExistError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.stream_cache.debugger_graph import DebuggerGraph
from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase


class GraphHandler(StreamHandlerBase):
"""Metadata Handler."""

def __init__(self):
self._graph_proto = None
self._graph = None
self._searched_node_list = []
self.bfs_order = []

@property
def graph(self):
"""The property of graph."""
return self._graph_proto

def put(self, value):
"""
Put value into graph cache. Called by grpc server.

Args:
value (GraphProto): The Graph proto message.
"""
self._graph_proto = value
log.info("Put graph into cache.")

# build graph
graph = DebuggerGraph()
graph.build_graph(value)
self._graph = graph
self.bfs_order = self._graph.get_bfs_order()

def get(self, filter_condition=None):
"""
Get the graph of specific node.

Args:
filter_condition (dict):

- name (str): The full debug node name.

- single_node (bool): If True, return the graph from root
to the specific node; else, return the sublayer of the
graph. Default: False.

Returns:
dict, the graph data.
"""
try:
self._graph_exists()
except DebuggerGraphNotExistError:
log.warning('The graph is empty. To view a graph, '
'please start the training script first.')
return {'graph': {}}

if filter_condition is None:
filter_condition = {}
single_node = filter_condition.get('single_node', False)
name = filter_condition.get('name')

graph = {}
if single_node is True:
nodes = self.get_single_node(name)
else:
nodes = self.list_nodes(name)
graph.update(nodes)

return {'graph': graph}

def get_tensor_history(self, node_name, depth=0):
"""
Get the tensor history of a specified node.

Args:
node_name (str): The debug name of the node.
depth (int): The number of layers the user
wants to trace. Default is 0.

Returns:
dict, the basic tensor history, including only the tensor name, tensor type and node type.
"""
self._graph_exists()
if not self._graph.exist_node(node_name):
raise DebuggerNodeNotInGraphError(node_name)

tensor_history, cur_outputs_nums = self._graph.get_tensor_history(
node_name, depth
)
# add the tensor type for tensor history
self._update_tensor_history(tensor_history[0:cur_outputs_nums], 'output')
self._update_tensor_history(tensor_history[cur_outputs_nums:], 'input')
log.debug("Get %d tensors in tensor history for node <%s>.", len(tensor_history), node_name)
return {'tensor_history': tensor_history}

@staticmethod
def _update_tensor_history(tensor_history, tensor_type):
"""
Add tensor source type for tensor history.

Args:
tensor_history (list[dict]): Tensor history from Graph stream. Each element has two
keys: `node_type` and `name`. `node_type` refers to the type of the node which
the tensor come from. `name` refers to the tensor name.
tensor_type (str): The source type of the tensor. `input` or `output`.
"""
for single_tensor_info in tensor_history:
single_tensor_info['type'] = tensor_type

def search_nodes(self, pattern):
"""
Search nodes by given pattern.

Args:
pattern (Union[str, None]): The pattern of the node to search,
if None, return all node names.

Returns:
dict, the searched nodes organized by scope.
"""
self._graph_exists()
self._searched_node_list = self._graph.search_nodes_by_pattern(pattern)
nodes = self._graph.get_nodes(self._searched_node_list)

return {'nodes': nodes}

def get_node_names(self, pattern=None):
"""Get graph nodes according to pattern."""
return self._graph.search_nodes_by_pattern(pattern)

def get_searched_node_list(self):
"""Get searched node list."""
return self._searched_node_list

def get_node_type(self, node_name):
"""
Get the type of the specified node.

Args:
node_name (str): The debug name of the node.

Returns:
A string of the node type, name_scope or leaf.
"""
self._graph_exists()
node_type = self._graph.get_node_type(node_name)

return node_type

def get_full_name(self, node_name):
"""Get full name according to ui node name."""
full_name = self._graph.get_full_name_by_node_name(node_name) if node_name else ''
return full_name

def get_node_name_by_full_name(self, full_name):
"""Get UI node name by full name."""
if self._graph:
node_name = self._graph.get_node_name_by_full_name(full_name)
else:
node_name = ''
log.info("No graph received yet.")
return node_name

def list_nodes(self, scope):
"""
Get the nodes of every layer in graph.

Args:
scope (str): The name of a scope.

Returns:
TypedDict('Nodes', {'nodes': list[Node]}), format is {'nodes': [<Node object>]}.
example:
{
"nodes" : [
{
"attr" :
{
"index" : "i: 0\n"
},
"input" : {},
"name" : "input_tensor",
"output" :
{
"Default/TensorAdd-op17" :
{
"edge_type" : "data",
"scope" : "name_scope",
"shape" : [1, 16, 128, 128]
}
},
"output_i" : -1,
"proxy_input" : {},
"proxy_output" : {},
"independent_layout" : False,
"subnode_count" : 0,
"type" : "Data"
}
]
}
"""
if scope and not self._graph.exist_node(scope):
raise DebuggerNodeNotInGraphError(node_name=scope)

nodes = self._graph.list_node_by_scope(scope=scope)
return {'nodes': nodes}

def get_node_by_bfs_order(self, node_name=None, ascend=True):
"""
Traverse the graph in order of breadth-first search by the given node.

Args:
node_name (str): The node name which will be regarded
as the start node in graph.
ascend (bool): If True, traverse the input nodes;
If False, traverse the output nodes. Default is True.

Returns:
Union[None, str], the name of the next node in BFS order.
"""
self._graph_exists()
bfs_order = self.bfs_order
length = len(bfs_order)

if not bfs_order:
log.error('Cannot get the BFS order of the graph!')
msg = 'Cannot get the BFS order of the graph!'
raise DebuggerParamValueError(msg)

if node_name is None:
if ascend is False:
next_node = bfs_order[-1]
else:
next_node = bfs_order[0]
else:
try:
index = bfs_order.index(node_name)
log.debug("The index of the node in BFS list is: %d", index)
except ValueError as err:
log.error('Cannot find the node: %s. Please check '
'the node name: %s', node_name, err)
msg = f'Cannot find the node: {node_name}. ' \
f'Please check the node name {err}.'
raise DebuggerParamValueError(msg)

next_node = self.get_next_node_in_bfs(index, length, ascend)

return next_node

def get_next_node_in_bfs(self, index, length, ascend):
"""
Get the next node in bfs order.

Args:
index (int): The current index.
length (int): The number of all leaf nodes.
ascend (bool): Whether get the node in ascend order or not.

Returns:
Union[None, str], the next node name or None.
"""
next_node = None
if 0 <= index < length:
if ascend is True and index < length - 1:
next_node = self.bfs_order[index + 1]
elif ascend is False and index > 0:
next_node = self.bfs_order[index - 1]

return next_node

def get_single_node(self, name):
"""
Search node, and return every layer nodes until this node.

Args:
name (str): The name of node.

Returns:
dict, every layer nodes until this node.
"""
nodes = self._graph.search_single_node(name)

return nodes

def _graph_exists(self):
"""
Check if the graph has been loaded in the debugger cache.

Raises:
DebuggerGraphNotExistError: If the graph does not exist.
"""
if self._graph is None:
log.error('The graph does not exist. Please start the '
'training script and try again.')
raise DebuggerGraphNotExistError

+ 131
- 0
mindinsight/debugger/stream_handler/metadata_handler.py View File

@@ -0,0 +1,131 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the metadata stream handler."""
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import ServerStatus
from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase


class MetadataHandler(StreamHandlerBase):
"""Metadata Handler."""

def __init__(self):
self._state = ServerStatus.PENDING
self._device_name = ""
self._step = 0
self._client_ip = ""
self._cur_node_name = ""
self._cur_full_name = ""
self._backend = ""

@property
def device_name(self):
"""The property of device name."""
return self._device_name

@property
def step(self):
"""The property of current step."""
return self._step

@property
def node_name(self):
"""The property of current node name."""
return self._cur_node_name

@node_name.setter
def node_name(self, node_name):
"""The property of current node name."""
self._cur_node_name = node_name

@property
def full_name(self):
"""The property of current node name."""
return self._cur_full_name

@property
def backend(self):
"""The property of current backend."""
return self._backend

@property
def state(self):
"""The property of state."""
return self._state.value

@state.setter
def state(self, value):
"""
Set the property of state.

Args:
value (str): The new state.
"""
self._state = ServerStatus(value)

@property
def client_ip(self):
"""The property of client ip."""
return self._client_ip

@client_ip.setter
def client_ip(self, value):
"""
Set the property of client ip.

Args:
value (str): The new ip.
"""
self._client_ip = str(value)

def put(self, value):
"""
Put value into metadata cache. Called by grpc server.

Args:
value (MetadataProto): The Metadata proto message.
"""
self._device_name = value.device_name.split(':')[0]
self._step = value.cur_step
self._cur_full_name = value.cur_node
self._backend = value.backend if value.backend else "Ascend"
log.debug("Put metadata into cache at the %d-th step.", self._step)

def get(self, filter_condition=None):
"""
Get updated value. Called by main server.

Args:
filter_condition (str): The filter property.

Returns:
dict, the metadata.
"""
metadata = {}
if filter_condition is None:
metadata = {
'state': self.state,
'step': self.step,
'device_name': self.device_name,
'pos': '0',
'ip': self.client_ip,
'node_name': self.node_name,
'backend': self.backend
}
else:
metadata[filter_condition] = getattr(self, filter_condition) if \
hasattr(self, filter_condition) else ''

return {'metadata': metadata}
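Reviewer note: a self-contained sketch of the filter behavior; before any put, the handler reports its initial state (the exact 'pending' string is the assumed value of ServerStatus.PENDING):

from mindinsight.debugger.stream_handler.metadata_handler import MetadataHandler

handler = MetadataHandler()
print(handler.get('state'))  # {'metadata': {'state': 'pending'}}
print(handler.get())         # full metadata dict: step 0, empty names, pos '0'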

+ 298
- 0
mindinsight/debugger/stream_handler/tensor_handler.py View File

@@ -0,0 +1,298 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the tensor stream handler."""
import numpy as np

from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.proto.ms_graph_pb2 import DataType
from mindinsight.debugger.stream_cache.tensor import OpTensor, ConstTensor
from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
from mindinsight.utils.tensor import TensorUtils


class TensorHandler(StreamHandlerBase):
"""Metadata Handler."""

def __init__(self):
self._const_vals = {}
self._tensors = {}
self._cur_step = 0

def put(self, value):
"""
Put value into tensor cache. Called by grpc server.

Args:
value (dict): The received tensor event value, including:

- step (int): The current step of tensor.

- tensor_protos (list[TensorProto]): The tensor proto.
"""
tensor_protos = value.get('tensor_protos')
merged_tensor = self._get_merged_tensor(tensor_protos)
step = value.get('step', 0)
if merged_tensor.iter and step > 0:
log.debug("Received previous tensor.")
step -= 1
tensor = OpTensor(merged_tensor, step)
self._put_tensor_into_cache(tensor, step)
log.debug("Put tensor %s of step: %d, into cache", tensor.name, step)

@staticmethod
def _get_merged_tensor(tensor_protos):
"""
Merge a list of tensor protos into one.

Args:
tensor_protos (list[TensorProto]): List of tensor proto.

Returns:
TensorProto, merged tensor proto.
"""
merged_tensor = tensor_protos[-1]
if len(tensor_protos) > 1:
tensor_value = bytes()
for tensor_proto in tensor_protos:
if not tensor_proto.tensor_content:
log.warning("Doesn't find tensor value for %s:%s",
tensor_proto.node_name, tensor_proto.slot)
break
tensor_value += tensor_proto.tensor_content
merged_tensor.tensor_content = tensor_value
log.debug("Merge multi tensor values into one.")
return merged_tensor

def _put_tensor_into_cache(self, tensor, step):
"""
Put tensor into cache.

Args:
tensor (OpTensor): The tensor value.
step (int): The step of the tensor.
"""
cache_tensor = self._tensors.get(tensor.name)
if cache_tensor is None:
cache_tensor = {}
self._tensors[tensor.name] = cache_tensor
cache_tensor[step] = tensor

def put_const_vals(self, const_vals):
"""
Put const value into tensor cache.

Args:
const_vals (list[NamedValueProto]): List of const values.
"""
for const_val in const_vals:
if not (const_val.value and const_val.key):
continue
if DataType.Name(const_val.value.dtype) == "DT_TENSOR":
tensor_proto = const_val.value.tensor_val
tensor_proto.node_name = const_val.key
tensor_proto.slot = '0'
const_tensor = OpTensor(tensor_proto)
else:
const_tensor = ConstTensor(const_val)
self._const_vals[const_tensor.name] = const_tensor

def get(self, filter_condition=None):
"""
Get full tensor value.

Args:
filter_condition (dict): Filter condition.

- name (str): The name of tensor.

- node_type (str): The type of the node.

- shape (tuple): The specified sub-range (indices or slices) of the tensor value.
Returns:
dict, the tensor_value.
"""
name = filter_condition.get('name')
node_type = filter_condition.get('node_type')
shape = filter_condition.get('shape')
tensor = self._get_tensor(name, node_type)
if not tensor:
log.error("No tensor named %s", name)
raise DebuggerParamValueError("No tensor named {}".format(name))
tensor_info = tensor.get_full_info(shape)
self._update_has_prev_step_field(tensor_info, name, node_type)
return {'tensor_value': tensor_info}

def _get_tensor(self, tensor_name, node_type=None, step=None):
"""
Get tensor according to tensor name and node_type.

Args:
tensor_name (str): Tensor name, format like `node_name:slot`.
node_type (str): Node type.
step (int): The step of tensor info. Default: None.

Returns:
Union[OpTensor, ConstTensor], the tensor object.
"""
if step is None:
step = self._cur_step
tensor = self._tensors.get(tensor_name, {}).get(step)
if not tensor and node_type == NodeTypeEnum.CONST.value:
const_name = tensor_name.rsplit('/', 1)[-1]
tensor = self._const_vals.get(const_name)
self._tensors[tensor_name] = {step: tensor}

return tensor

def _get_basic_info(self, tensor_name, node_type=None):
"""Get the latest basic tensor info by tensor name."""
tensor = self._get_tensor(tensor_name, node_type)
if tensor:
return tensor.get_basic_info()

return None

def get_tensor_history(self, tensor_names):
"""Get tensor history for tensor names."""
# only used by grpc server, could be removed later
tensor_infos = []
for tensor_name in tensor_names:
tensor_info = self._get_basic_info(tensor_name)
tensor_infos.append(tensor_info)

return {'tensor_history': tensor_infos}

def update_tensor_history(self, tensor_history):
"""
Add tensor basic info in tensor_history.

Args:
tensor_history (dict): Tensor history, including a list of tensor name and type.

Returns:
list[dict], the list of tensor basic info cache.
"""
missed_tensors = []
for tensor_info in tensor_history.get('tensor_history'):
tensor_name = tensor_info.get('full_name')
node_type = tensor_info.get('node_type')
basic_info = self._get_basic_info(tensor_name, node_type)
flag = self._update_has_prev_step_field(basic_info, tensor_name, node_type)
if flag is False:
missed_tensor = tensor_info.copy()
missed_tensor['iter'] = 'prev'
missed_tensors.append(missed_tensor)
log.debug("Add previous view cmd for %s", tensor_name)
# add `has_prev_step` field to tensor basic info.
if basic_info:
tensor_info.update(basic_info)
else:
missed_tensors.append(tensor_info)
log.debug("Add view cmd for %s", tensor_name)

return missed_tensors

def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type):
"""Update has_prev_step field in tensor info."""
flag = None
if node_type == NodeTypeEnum.PARAMETER.value:
flag = self._has_prev_tensor_value(tensor_name)
if flag and tensor_info:
tensor_info['has_prev_step'] = True
return flag

def _has_prev_tensor_value(self, tensor_name):
"""
Check if the tensor has valid value of previous step.

Args:
tensor_name (str): Tensor name.

Returns:
bool, whether the tensor has valid tensor value.
"""
flag = None
# check if the tensor has previous step value.
prev_step = self._cur_step - 1
if prev_step < 0:
return flag
tensor = self._get_tensor(tensor_name, step=prev_step)
flag = bool(tensor and tensor.value)
return flag

def get_tensor_value_by_name(self, tensor_name, prev=False):
"""Get tensor value by name in numpy type."""
cur_step = self._cur_step
step = cur_step - 1 if prev else cur_step
if step < 0:
log.warning("%d step has no previous value for tensor: %s", cur_step, tensor_name)
return None
tensor = self._get_tensor(tensor_name, step=step)

return tensor

def clean_tensors(self, cur_step):
"""Clean the tensor cache."""
self._cur_step = cur_step
expired_tensor = []
for tensor_name, tensor in self._tensors.items():
expired_step = [step for step in tensor.keys() if step <= cur_step - 2]
for step in expired_step:
tensor.pop(step)
if not tensor:
expired_tensor.append(tensor_name)
for tensor_name in expired_tensor:
self._tensors.pop(tensor_name)

def get_tensors_diff(self, tensor_name, shape, tolerance=0):
"""
Get tensor comparisons data for given name, detail, shape and tolerance.

Args:
tensor_name (str): The name of tensor for cache.
shape (tuple): Specify concrete dimensions of shape.
tolerance (float): The tolerance of the difference between the current step
tensor and the previous step tensor, expressed as a ratio. Default: 0.
The boundary value equals max(abs(min), abs(max)) * tolerance, where min
and max are taken over the element-wise result of subtracting the previous
step tensor from the current step tensor. Any element of the result whose
absolute value is less than or equal to the boundary value is set to zero.

Raises:
DebuggerParamValueError: If the tensor value of the current step or the previous step is missing.

Returns:
dict, the retrieved data.
"""
curr_tensor = self.get_tensor_value_by_name(tensor_name)
prev_tensor = self.get_tensor_value_by_name(tensor_name, prev=True)
if not (curr_tensor and prev_tensor):
log.error("Get current step and previous step for this tensor name %s failed.", tensor_name)
raise DebuggerParamValueError(f"Get current step and previous step for this tensor name "
f"{tensor_name} failed.")
curr_tensor_slice = curr_tensor.get_tensor_value_by_shape(shape)
prev_tensor_slice = prev_tensor.get_tensor_value_by_shape(shape)
tensor_info = curr_tensor.get_basic_info()
if isinstance(curr_tensor_slice, np.ndarray) and isinstance(prev_tensor_slice, np.ndarray):
diff_tensor = TensorUtils.calc_diff_between_two_tensor(curr_tensor_slice, prev_tensor_slice, tolerance)
result = np.stack([prev_tensor_slice, curr_tensor_slice, diff_tensor], axis=-1)
tensor_info['diff'] = result.tolist()
stats = TensorUtils.get_statistics_from_tensor(diff_tensor)
tensor_info['statistics'] = TensorUtils.get_statistics_dict(stats)
del tensor_info['has_prev_step']
del tensor_info['value']
reply = {'tensor_value': tensor_info}
return reply
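The tolerance rule above can be shown with a small self-contained numpy sketch (the array values and the 1% tolerance are made up for illustration; the handler itself applies the same rule to cached tensor slices):

import numpy as np

# stand-ins for the previous-step and current-step tensor slices
prev_step = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
curr_step = np.array([1.001, 2.0, 2.5, 8.0], dtype=np.float32)
tolerance = 0.01  # 1 percent

diff = curr_step - prev_step
# boundary value: max(abs(min), abs(max)) of the raw difference, scaled by tolerance
boundary = max(abs(diff.min()), abs(diff.max())) * tolerance
# differences within the boundary are treated as noise and zeroed out
diff[np.abs(diff) <= boundary] = 0.0
print(diff)  # [ 0.   0.  -0.5  4. ]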

+ 333
- 0
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -0,0 +1,333 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the watchpoint stream handler."""
import numpy as np

from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
DebuggerParamTypeError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD
from mindinsight.debugger.stream_cache.watchpoint import Watchpoint, WatchpointHit, \
WATCHPOINT_CONDITION_MAPPING
from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase


class WatchpointHandler(StreamHandlerBase):
"""watchpoint Handler."""

def __init__(self):
self._watchpoints = {}
self._deleted_watchpoints = []
self._updated_watchpoints = {}
self._latest_id = 0

def put(self, value):
"""
Put Watchpoint into watchpoint handler.

Args:
value (Watchpoint): The watchpoint to cache.
"""
new_id = value.watchpoint_id
self._watchpoints[new_id] = value
self._updated_watchpoints[new_id] = value
self._latest_id = new_id
log.debug("Put watchpoint %d into cache.", new_id)

def sync_set_cmd(self):
"""Clean temp watchpoints."""
self._deleted_watchpoints = []
self._updated_watchpoints = {}

def get_watchpoint_by_id(self, watchpoint_id):
"""Get watchpoint by watchpoint id."""
watchpoint = self._watchpoints.get(watchpoint_id)
if not watchpoint:
log.error("Invalid watchpoint id %d", watchpoint_id)
raise DebuggerParamValueError("Invalid watchpoint id {}".format(watchpoint_id))

return watchpoint

def get(self, filter_condition=False):
"""
Get the watchpoints.

Args:
filter_condition (bool): If False, get the watch condition info of all watchpoints.
If True, get the updated and deleted watchpoints in SetCMD proto format. Default: False.

Returns:
dict, the watchpoints.
"""
reply = []
if not filter_condition:
# get watch condition list
for _, watchpoint in self._watchpoints.items():
watchpoint_info = watchpoint.get_watch_condition_info()
reply.append(watchpoint_info)
else:
# get updated watchpoint list
for _, watchpoint in self._updated_watchpoints.items():
set_cmd = watchpoint.get_set_cmd()
reply.append(set_cmd)
reply.extend(self._deleted_watchpoints)

log.debug("get the watch points with filter_condition:%s", filter_condition)

return {'watch_points': reply}

def set_watch_nodes(self, graph, watch_point_id):
"""
Set watch nodes for the graph.

Args:
graph (dict): The graph with list of nodes.
watch_point_id (int): The id of watchpoint.
"""
if not (watch_point_id and graph):
return
self._validate_watchpoint_id(watch_point_id)
log.debug("add watch flags")
watchpoint = self._watchpoints.get(watch_point_id)
self._set_watch_status_recursively(graph, watchpoint)

def _set_watch_status_recursively(self, graph, watchpoint):
"""Set watch status to graph."""
if not isinstance(graph, dict):
log.warning("The graph is not dict.")
return
if graph.get('children'):
self._set_watch_status_recursively(graph.get('children'), watchpoint)

for node in graph.get('nodes', []):
if not isinstance(node, dict):
log.warning("The node is not dict.")
return
node_name = node.get('name')
if not node_name:
continue
flag = watchpoint.get_node_status(node_name, node.get('type'), node.get('full_name'))
node['watched'] = flag
if node.get('nodes'):
self._set_watch_status_recursively(node, watchpoint)

def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None):
"""
Create watchpoint.

Args:
watch_condition (dict): The watch condition.

- condition (str): Condition type. Should be a key of WATCHPOINT_CONDITION_MAPPING, such as `INF` or `NAN`.

- param (float): Condition value. Required for comparison conditions only.
watch_nodes (list[NodeBasicInfo]): The list of node basic info.
watch_point_id (int): The id of an existing watchpoint whose watch nodes are copied when `watch_nodes` is not given.

Returns:
int, the new id of watchpoint.
"""
validate_watch_condition(watch_condition)
new_id = self._latest_id + 1
watchpoint = Watchpoint(new_id, watch_condition)
if watch_nodes:
watchpoint.add_nodes(watch_nodes)
elif watch_point_id:
self._validate_watchpoint_id(watch_point_id)
watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id))
self.put(watchpoint)

return new_id

def update_watchpoint(self, watch_point_id, watch_nodes, watched=False):
"""
Update watchpoint.

Args:
watch_point_id (int): The id of watchpoint.
watch_nodes (list[str]): The list of node names.
watched (bool): The update operator on nodes. If False, remove nodes from watch nodes.
If True, add nodes to watch nodes. Default: False.

Returns:
dict, empty response.
"""
self._validate_watchpoint_id(watch_point_id)
watchpoint = self._watchpoints.get(watch_point_id)
if watched:
watchpoint.add_nodes(watch_nodes)
else:
watchpoint.remove_nodes(watch_nodes)
self._updated_watchpoints[watch_point_id] = watchpoint
log.debug("Update watchpoint %d in cache.", watch_point_id)

def delete_watchpoint(self, watch_point_id):
"""
Delete watchpoint.

Args:
watch_point_id (int): The id of watchpoint.

Returns:
dict, empty response.
"""
self._validate_watchpoint_id(watch_point_id)
self._watchpoints.pop(watch_point_id)
set_cmd = SetCMD()
set_cmd.id = watch_point_id
set_cmd.delete = True
self._deleted_watchpoints.append(set_cmd)
log.debug("Delete watchpoint %d in cache.", watch_point_id)

def _validate_watchpoint_id(self, watch_point_id):
"""Validate watchpoint id."""
if watch_point_id and watch_point_id not in self._watchpoints:
log.error("Invalid watchpoint id: %d.", watch_point_id)
raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))


class WatchpointHitHandler(StreamHandlerBase):
"""Watchpoint hit handler."""

def __init__(self):
self._hits = {}

def put(self, value):
"""
Put value into watchpoint hit cache. Called by grpc server.

Args:
value (dict): The watchpoint hit info.

- tensor_proto (TensorProto): The message about hit tensor.

- watchpoint (Watchpoint): The Watchpoint that a node hit.
"""
watchpoint_hit = WatchpointHit(
tensor_proto=value.get('tensor_proto'),
watchpoint=value.get('watchpoint'),
node_name=value.get('node_name')
)

node_name = value.get('node_name')
hit_tensors = self._hits.get(node_name)
if hit_tensors is None:
hit_tensors = []
self._hits[node_name] = hit_tensors
if watchpoint_hit not in hit_tensors:
hit_tensors.append(watchpoint_hit)

def get(self, filter_condition=None):
"""
Get watchpoint hit list.

Args:
filter_condition (str): Get the watchpoint hits according to the specified node name.
If not given, get all watchpoint hits. Default: None.

Returns:
dict, the watchpoint hit list.
"""
if filter_condition is None:
log.debug("Get all watchpoint hit list.")
reply = self.get_watchpoint_hits()
else:
log.debug("Get the watchpoint for node: <%s>.", filter_condition)
reply = self._hits.get(filter_condition)

return reply

def get_watchpoint_hits(self):
"""Return the list of watchpoint hits."""
watch_point_hits = []
for node_name, watchpoint_hits in self._hits.items():
watch_points = [watchpoint_hit.watchpoint for watchpoint_hit in watchpoint_hits]
watch_point_hits.append({
'node_name': node_name,
'watch_points': watch_points
})

return {'watch_point_hits': watch_point_hits}

def _is_tensor_hit(self, tensor_name):
"""Check if the tensor is record in hit cache."""
node_name = tensor_name.split(':')[0]
watchpoint_hits = self.get(node_name)
if watchpoint_hits is None:
return False

for watchpoint_hit in watchpoint_hits:
if tensor_name == watchpoint_hit.tensor_name:
return True

return False

def update_tensor_history(self, tensor_history):
"""
Add hit flag to tensor history.

Args:
tensor_history (dict): The tensor history.
"""
if not self._hits:
return

# add an `is_hit` flag to each tensor info in the history
for tensor_info in tensor_history.get('tensor_history'):
tensor_name = tensor_info['full_name']
hit_flag = self._is_tensor_hit(tensor_name)
tensor_info['is_hit'] = hit_flag


def validate_watch_condition(watch_condition):
"""Validate watch condition."""
if not isinstance(watch_condition, dict):
log.error("<watch_condition> should be dict. %s received.", watch_condition)
raise DebuggerParamTypeError("<watch_condition> should be dict.")
# validate condition
condition = watch_condition.get('condition')
if condition not in WATCHPOINT_CONDITION_MAPPING.keys():
log.error("Invalid watch condition. Acceptable values are <%s>.",
str(WATCHPOINT_CONDITION_MAPPING.keys()))
raise DebuggerParamValueError("Invalid watch condition value.")
# validate param
validate_watch_condition_params(watch_condition)


def validate_watch_condition_params(watch_condition):
"""
Validate watch condition parameters.

Args:
watch_condition (dict): Watch condition.

- condition (str): Condition type. Should be in WATCHPOINT_CONDITION_MAPPING.

- param (list): Condition value. Should be given for comparison condition. The value will
be translated to np.float32.
"""
condition = watch_condition.get('condition')
param = watch_condition.get('param')
if condition in ['NAN', 'INF', 'OVERFLOW']:
if param:
log.error("No param is expected for %s condition.", condition)
raise DebuggerParamValueError("No param is expected.")
else:
if not isinstance(param, (float, int)):
log.error("Number param should be given for condition <%s>.",
condition)
raise DebuggerParamValueError("Number param should be given.")
if np.isinf(np.float32(param)):
log.error("Condition param should be float32.")
raise DebuggerParamValueError("The value of condition param should be within float32.")

+ 45
- 3
mindinsight/scripts/start.py View File

@@ -14,19 +14,28 @@
# ============================================================================
"""Start mindinsight service."""

import argparse
import os
import re
import sys
from importlib import import_module

import psutil

from mindinsight.conf import settings
from mindinsight.utils.command import BaseCommand
from mindinsight.utils.hook import HookUtils
from mindinsight.utils.hook import init
from mindinsight.utils.exceptions import PortNotAvailableError


def str2bool(string):
"""Convert str to bool"""
if string.lower() == 'false':
return False
if string.lower() == 'true':
return True
raise ValueError


class ConfigAction(argparse.Action):
@@ -146,6 +155,23 @@ class UrlPathPrefixAction(argparse.Action):
setattr(namespace, self.dest, prefix)


class EnableDebuggerAction(argparse.Action):
"""SSL certificate action class definition."""

def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.

Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
enable_debugger = values
setattr(namespace, self.dest, enable_debugger)


class Command(BaseCommand):
"""
Start mindinsight service.
@@ -186,6 +212,14 @@ class Command(BaseCommand):
Custom port ranging from %s to %s. Default value is %s.
""" % (PortAction.MIN_PORT, PortAction.MAX_PORT, settings.PORT))

parser.add_argument(
'--debugger_port',
type=int,
action=PortAction,
help="""
Debugger port ranging from %s to %s. Default value is %s.
""" % (PortAction.MIN_PORT, PortAction.MAX_PORT, settings.DEBUGGER_PORT))

parser.add_argument(
'--url-path-prefix',
type=str,
@@ -197,6 +231,14 @@ class Command(BaseCommand):
dot or double dots. Default value is ''.
""")

parser.add_argument(
'--enable_debugger',
type=str2bool,
action=EnableDebuggerAction,
default=False,
help="""
Enable debugger or not.
Default is False.""")

for hook in HookUtils.instance().hooks():
hook.register_startup_arguments(parser)
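With these arguments registered, the debugger can be switched on at startup, for example (the port value here is illustrative; the default comes from settings.DEBUGGER_PORT):

mindinsight start --enable_debugger true --debugger_port 50051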



+ 5
- 0
mindinsight/utils/constant.py View File

@@ -33,6 +33,7 @@ class MindInsightModules(Enum):
SCRIPTCONVERTER = 7
WIZARD = 9
OPTIMIZER = 10
DEBUGGER = 11


class GeneralErrors(Enum):
@@ -56,6 +57,10 @@ class LineageMgrErrors(Enum):
"""Enum definition for lineage errors."""


class DebuggerErrors(Enum):
"""Enum definition for debugger errors."""


class DataVisualErrors(Enum):
"""Enum definition for datavisual errors."""
RESTFUL_API_NOT_EXIST = 1


+ 298
- 0
mindinsight/utils/tensor.py View File

@@ -0,0 +1,298 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tensor utils."""

import numpy as np

from mindinsight.datavisual.utils.tools import to_int
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.utils.exceptions import ParamTypeError
from mindinsight.utils.log import utils_logger as logger

F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max


class Statistics:
"""Statistics data class.

Args:
max_value (float): max value of tensor data.
min_value (float): min value of tensor data.
avg_value (float): avg value of tensor data.
count (int): total count of tensor data.
nan_count (int): count of NAN.
neg_inf_count (int): count of negative INF.
pos_inf_count (int): count of positive INF.
"""

def __init__(self, max_value=0, min_value=0, avg_value=0,
count=0, nan_count=0, neg_inf_count=0, pos_inf_count=0):
self._max = max_value
self._min = min_value
self._avg = avg_value
self._count = count
self._nan_count = nan_count
self._neg_inf_count = neg_inf_count
self._pos_inf_count = pos_inf_count

@property
def max(self):
"""Get max value of tensor."""
return self._max

@property
def min(self):
"""Get min value of tensor."""
return self._min

@property
def avg(self):
"""Get avg value of tensor."""
return self._avg

@property
def count(self):
"""Get total count of tensor."""
return self._count

@property
def nan_count(self):
"""Get count of NAN."""
return self._nan_count

@property
def neg_inf_count(self):
"""Get count of negative INF."""
return self._neg_inf_count

@property
def pos_inf_count(self):
"""Get count of positive INF."""
return self._pos_inf_count


class TensorUtils:
"""Tensor Utils class."""

@staticmethod
def validate_dims_format(dims):
"""
Validate the format of the dims parameter.

Args:
dims (str): Dims of tensor. Its format is something like this "[0, 0, :, :]".

Raises:
ParamValueError: If format of dims is not correct.
"""
if dims is not None:
if not isinstance(dims, str):
raise ParamTypeError(dims, str)
dims = dims.strip()
if not (dims.startswith('[') and dims.endswith(']')):
raise ParamValueError('The value: {} of dims must be '
'start with `[` and end with `]`.'.format(dims))
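# every comma-separated entry must be `:` or an (optionally negative) integer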
for dim in dims[1:-1].split(','):
dim = dim.strip()
if dim == ":":
continue
if dim.startswith('-'):
dim = dim[1:]
if not dim.isdigit():
raise ParamValueError('The value: {} of dims in the square brackets '
'must be int or `:`.'.format(dims))

@staticmethod
def convert_array_from_str_dims(dims, limit=0):
"""
Convert string of dims data to array.

Args:
dims (str): Specify dims of tensor.
limit (int): The max flexible dimension count, default value is 0 which means that there is no limitation.

Returns:
list, a string like this: "[0, 0, :, :]" will convert to this value: [0, 0, None, None].

Raises:
ParamValueError, If flexible dimensions exceed limit value.
"""
dims = dims.strip().lstrip('[').rstrip(']')
dims_list = []
count = 0
for dim in dims.split(','):
dim = dim.strip()
if dim == ':':
dims_list.append(None)
count += 1
else:
dims_list.append(to_int(dim, "dim"))
if limit and count > limit:
raise ParamValueError("Flexible dimensions cannot exceed limit value: {}, size: {}"
.format(limit, count))
return dims_list

@staticmethod
def get_specific_dims_data(ndarray, dims, tensor_dims):
"""
Get specific dims data.

Args:
ndarray (numpy.ndarray): An ndarray of numpy.
dims (list): A list of specific dims.
tensor_dims (list): A list of tensor dims.

Returns:
numpy.ndarray, an ndarray of specific dims tensor data.

Raises:
ParamValueError, If the length of param dims is not equal to the length of tensor dims or
the index of param dims out of range.
"""
if len(dims) != len(tensor_dims):
raise ParamValueError("The length of param dims: {}, is not equal to the "
"length of tensor dims: {}.".format(len(dims), len(tensor_dims)))
indices = []
for k, d in enumerate(dims):
if d is not None:
if d >= tensor_dims[k]:
raise ParamValueError("The index: {} of param dims out of range: {}.".format(d, tensor_dims[k]))
indices.append(d)
else:
indices.append(slice(0, tensor_dims[k]))
result = ndarray[tuple(indices)]
# Make sure the return type is numpy.ndarray.
if not isinstance(result, np.ndarray):
result = np.array(result)
return result

@staticmethod
def get_statistics_from_tensor(tensors):
"""
Calculate statistics of the tensor data.

Args:
tensors (numpy.ndarray): A numpy.ndarray of tensor data.

Returns:
Statistics, an instance holding the computed statistics.
"""
ma_value = np.ma.masked_invalid(tensors)
total, valid = tensors.size, ma_value.count()
invalids = []
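# count NaN, +Inf and -Inf separately; once the running counts cover every
# masked (invalid) element, the remaining scans can be skipped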
for isfn in np.isnan, np.isposinf, np.isneginf:
if total - valid > sum(invalids):
count = np.count_nonzero(isfn(tensors))
invalids.append(count)
else:
invalids.append(0)

nan_count, pos_inf_count, neg_inf_count = invalids
if not valid:
logger.warning('There are no valid values in the tensors(size=%d, shape=%s)', total, tensors.shape)
statistics = Statistics(max_value=0,
min_value=0,
avg_value=0,
count=total,
nan_count=nan_count,
neg_inf_count=neg_inf_count,
pos_inf_count=pos_inf_count)
return statistics

# BUG: max of a masked array with dtype np.float16 returns inf
# See numpy issue#15077
if issubclass(tensors.dtype.type, np.floating):
tensor_min = ma_value.min(fill_value=np.PINF)
tensor_max = ma_value.max(fill_value=np.NINF)
if tensor_min < F32_MIN or tensor_max > F32_MAX:
logger.warning('Values(%f, %f) are too large, you may encounter some undefined '
'behaviours hereafter.', tensor_min, tensor_max)
else:
tensor_min = ma_value.min()
tensor_max = ma_value.max()
tensor_sum = ma_value.sum(dtype=np.float64)
statistics = Statistics(max_value=tensor_max,
min_value=tensor_min,
avg_value=tensor_sum / valid,
count=total,
nan_count=nan_count,
neg_inf_count=neg_inf_count,
pos_inf_count=pos_inf_count)
return statistics

@staticmethod
def get_statistics_dict(stats):
"""
Get statistics dict according to statistics value.

Args:
stats (Statistics): An instance of Statistics.

Returns:
dict, a dict including 'max', 'min', 'avg', 'count', 'nan_count', 'neg_inf_count', 'pos_inf_count'.
"""
statistics = {
"max": float(stats.max),
"min": float(stats.min),
"avg": float(stats.avg),
"count": stats.count,
"nan_count": stats.nan_count,
"neg_inf_count": stats.neg_inf_count,
"pos_inf_count": stats.pos_inf_count
}
return statistics

@staticmethod
def calc_diff_between_two_tensor(first_tensor, second_tensor, tolerance):
"""
Calculate the difference between the first tensor and the second tensor.

Args:
first_tensor (numpy.ndarray): Specify the first tensor.
second_tensor (numpy.ndarray): Specify the second tensor.
tolerance (float): The tolerance of the difference between the first tensor and the second
tensor. It is a percentage. The boundary value equals max(abs(min), abs(max)) * tolerance,
where min and max are computed over the result of subtracting the second tensor from the
first tensor. If the absolute value of an element of the result is less than or equal to
the boundary value, that element is set to zero.

Returns:
numpy.ndarray, the first tensor minus the second tensor, with elements whose absolute
difference is within the boundary set to zero.

Raises:
ParamTypeError: If the type of these two tensors is not the numpy.ndarray.
ParamValueError: If the shape or dtype is not the same of these two tensors.
"""
if not isinstance(first_tensor, np.ndarray):
raise ParamTypeError('first_tensor', np.ndarray)

if not isinstance(second_tensor, np.ndarray):
raise ParamTypeError('second_tensor', np.ndarray)

if first_tensor.shape != second_tensor.shape:
raise ParamValueError("the shape: {} of first tensor is not equal to shape: {} of second tensor."
.format(first_tensor.shape, second_tensor.shape))

if first_tensor.dtype != second_tensor.dtype:
raise ParamValueError("the dtype: {} of first tensor is not equal to dtype: {} of second tensor."
.format(first_tensor.dtype, second_tensor.dtype))

diff_tensor = np.subtract(first_tensor, second_tensor)
stats = TensorUtils.get_statistics_from_tensor(diff_tensor)
boundary_value = max(abs(stats.max), abs(stats.min)) * tolerance
is_close = np.isclose(first_tensor, second_tensor, atol=boundary_value, rtol=0)
result = np.multiply(diff_tensor, ~is_close)
return result
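A quick usage sketch of the utilities above (a sketch assuming the mindinsight package from this commit is on the path; the arrays and tolerance are illustrative):

import numpy as np

from mindinsight.utils.tensor import TensorUtils

first = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
second = np.array([[1.001, 2.0], [2.5, 4.0]], dtype=np.float32)

# statistics of a tensor: max/min/avg plus NaN and Inf counts
stats = TensorUtils.get_statistics_from_tensor(first)
print(TensorUtils.get_statistics_dict(stats))

# take row 0 of the 2x2 array, i.e. dims "[0, :]" -> [0, None]
row = TensorUtils.get_specific_dims_data(first, [0, None], list(first.shape))
print(row)  # [1. 2.]

# element-wise difference with a 1 percent tolerance; the small 0.001
# difference falls within the boundary, so only the 0.5 entry survives
diff = TensorUtils.calc_diff_between_two_tensor(first, second, 0.01)
print(diff)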

+ 2
- 1
requirements.txt View File

@@ -16,4 +16,5 @@ Werkzeug>=1.0.0
tabulate>=0.8.6
pandas>=1.0.4
yapf>=0.30.0
treelib>=1.6.1
grpcio>=1.29.0

+ 8
- 4
tests/ut/datavisual/data_transform/test_tensor_container.py View File

@@ -14,9 +14,11 @@
# ============================================================================
"""Test tensor container."""
import unittest.mock as mock

import numpy as np

from mindinsight.datavisual.data_transform import tensor_container as tensor
from mindinsight.utils.tensor import TensorUtils


class TestTensorContainer:
@@ -34,8 +36,9 @@ class TestTensorContainer:

def test_get_statistics_from_tensor(self):
"""Tests get statistics from tensor."""
ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape(
[2, 2, 2])
statistics = TensorUtils.get_statistics_from_tensor(ndarray)
assert (statistics.max, statistics.min, statistics.avg, statistics.count,
statistics.nan_count, statistics.neg_inf_count, statistics.pos_inf_count) == \
(5, 1, 3, 8,
@@ -43,8 +46,9 @@ class TestTensorContainer:

def test_calc_original_buckets(self):
"""Tests calculate original buckets."""
ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape(
[2, 2, 2])
statistics = TensorUtils.get_statistics_from_tensor(ndarray)
buckets = tensor.calc_original_buckets(ndarray, statistics)

assert (buckets[0].left, buckets[0].width, buckets[0].count) == (1, 2, 2)


+ 5
- 6
tests/ut/datavisual/processors/test_tensor_processor.py View File

@@ -29,10 +29,8 @@ from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.common.exceptions import TensorNotExistError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.tensor_container import calc_original_buckets
from mindinsight.datavisual.processors.tensor_processor import TensorProcessor
from mindinsight.utils.tensor import TensorUtils
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.utils.exceptions import ParamMissError
@@ -187,7 +185,7 @@ class TestTensorProcessor:
dims = expected_values.get('value').get("dims")
expected_data = np.array(expected_values.get('value').get("float_data")).reshape(dims)
recv_tensor = np.array(recv_values.get('value').get("data"))
expected_tensor = TensorUtils.get_specific_dims_data(expected_data, [0, 0, None, None], dims)
assert np.sum(np.isclose(recv_tensor, expected_tensor, rtol=1e-6) == 0) == 0

@pytest.mark.usefixtures('load_tensor_record')
@@ -204,7 +202,8 @@ class TestTensorProcessor:
assert recv_values.get('wall_time') == expected_values.get('wall_time')
assert recv_values.get('step') == expected_values.get('step')
expected_data = expected_values.get('value').get("float_data")
expected_statistic_instance = TensorUtils.get_statistics_from_tensor(expected_data)
expected_statistic = TensorUtils.get_statistics_dict(expected_statistic_instance)
recv_statistic = recv_values.get('value').get("statistics")
assert recv_statistic.get("max") - expected_statistic.get("max") < 1e-6
assert recv_statistic.get("min") - expected_statistic.get("min") < 1e-6
@@ -225,7 +224,7 @@ class TestTensorProcessor:
assert recv_values.get('wall_time') == expected_values.get('wall_time')
assert recv_values.get('step') == expected_values.get('step')
expected_data = expected_values.get('value').get("float_data")
expected_statistic = TensorUtils.get_statistics_from_tensor(expected_data)
expected_buckets = calc_original_buckets(expected_data, expected_statistic)
recv_buckets = recv_values.get('value').get("histogram_buckets")


