@@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Module init file."""
from mindinsight.backend.debugger.debugger_api import init_module as init_query_module


def init_module(app):
    """
    Init module entry.

    Args:
        app (Flask): A Flask instance.
    """
    init_query_module(app)

@@ -0,0 +1,300 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Debugger restful api."""
import json

from flask import Blueprint, jsonify, request

from mindinsight.conf import settings
from mindinsight.debugger.debugger_server import DebuggerServer
from mindinsight.utils.exceptions import ParamValueError

BLUEPRINT = Blueprint("debugger", __name__,
                      url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX)


def _initialize_debugger_server():
    """Initialize a debugger server instance."""
    port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else None
    enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False
    server = None
    if port and enable_debugger:
        server = DebuggerServer(port)
    return server


def _read_post_request(post_request):
    """
    Extract the body of a POST request.

    Args:
        post_request (object): The POST request.

    Returns:
        dict, the deserialized body of the request.
    """
    body = post_request.stream.read()
    try:
        body = json.loads(body if body else "{}")
    except (TypeError, ValueError) as err:
        raise ParamValueError("Json data parse failed.") from err
    return body


def _wrap_reply(func, *args, **kwargs):
    """Serialize reply."""
    reply = func(*args, **kwargs)
    return jsonify(reply)
| @BLUEPRINT.route("/debugger/poll_data", methods=["GET"]) | |||
| def poll_data(): | |||
| """ | |||
| Wait for data to be updated on UI. | |||
| Get data from server and display the change on UI. | |||
| Returns: | |||
| str, the updated data. | |||
| Examples: | |||
| >>> Get http://xxxx/v1/mindinsight/debugger/poll_data?pos=xx | |||
| """ | |||
| pos = request.args.get('pos') | |||
| reply = _wrap_reply(BACKEND_SERVER.poll_data, pos) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/search", methods=["GET"]) | |||
| def search(): | |||
| """ | |||
| Search nodes in specified watchpoint. | |||
| Returns: | |||
| str, the required data. | |||
| Examples: | |||
| >>> Get http://xxxx/v1/mindinsight/debugger/retrive?mode=all | |||
| """ | |||
| name = request.args.get('name') | |||
| watch_point_id = int(request.args.get('watch_point_id', 0)) | |||
| reply = _wrap_reply(BACKEND_SERVER.search, name, watch_point_id) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/retrieve_node_by_bfs", methods=["GET"]) | |||
| def retrieve_node_by_bfs(): | |||
| """ | |||
| Search node by bfs. | |||
| Returns: | |||
| str, the required data. | |||
| Examples: | |||
| >>> Get http://xxxx/v1/mindinsight/debugger/retrieve_node_by_bfs?name=node_name&ascend=true | |||
| """ | |||
| name = request.args.get('name') | |||
| ascend = request.args.get('ascend', 'false') | |||
| ascend = ascend == 'true' | |||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_node_by_bfs, name, ascend) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/tensor-comparisons", methods=["GET"]) | |||
| def tensor_comparisons(): | |||
| """ | |||
| Get tensor comparisons. | |||
| Returns: | |||
| str, the required data. | |||
| Examples: | |||
| >>> Get http://xxxx/v1/mindinsight/debugger/tensor-comparisons? | |||
| name=node_name&detail=data&shape=[0, 0, :, :]&tolerance=0.5 | |||
| """ | |||
| name = request.args.get('name') | |||
| detail = request.args.get('detail', 'data') | |||
| shape = request.args.get('shape') | |||
| tolerance = request.args.get('tolerance', '0') | |||
| reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/retrieve", methods=["POST"]) | |||
| def retrieve(): | |||
| """ | |||
| Retrieve data according to mode and params. | |||
| Returns: | |||
| str, the required data. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/retrieve | |||
| """ | |||
| body = _read_post_request(request) | |||
| mode = body.get('mode') | |||
| params = body.get('params') | |||
| reply = _wrap_reply(BACKEND_SERVER.retrieve, mode, params) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/retrieve_tensor_history", methods=["POST"]) | |||
| def retrieve_tensor_history(): | |||
| """ | |||
| Retrieve data according to mode and params. | |||
| Returns: | |||
| str, the required data. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/retrieve_tensor_history | |||
| """ | |||
| body = _read_post_request(request) | |||
| name = body.get('name') | |||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_history, name) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/tensors", methods=["GET"]) | |||
| def retrieve_tensor_value(): | |||
| """ | |||
| Retrieve tensor value according to name and shape. | |||
| Returns: | |||
| str, the required data. | |||
| Examples: | |||
| >>> GET http://xxxx/v1/mindinsight/debugger/tensors?name=node_name&detail=data&shape=[1,1,:,:] | |||
| """ | |||
| name = request.args.get('name') | |||
| detail = request.args.get('detail') | |||
| shape = request.args.get('shape') | |||
| reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_value, name, detail, shape) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/create_watchpoint", methods=["POST"]) | |||
| def create_watchpoint(): | |||
| """ | |||
| Create watchpoint. | |||
| Returns: | |||
| str, watchpoint id. | |||
| Raises: | |||
| MindInsightException: If method fails to be called. | |||
| ParamValueError: If parsing json data search_condition fails. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/create_watchpoint | |||
| """ | |||
| body = _read_post_request(request) | |||
| condition = body.get('condition') | |||
| watch_nodes = body.get('watch_nodes') | |||
| watch_point_id = body.get('watch_point_id') | |||
| reply = _wrap_reply(BACKEND_SERVER.create_watchpoint, condition, watch_nodes, watch_point_id) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/update_watchpoint", methods=["POST"]) | |||
| def update_watchpoint(): | |||
| """ | |||
| Update watchpoint. | |||
| Returns: | |||
| str, reply message. | |||
| Raises: | |||
| MindInsightException: If method fails to be called. | |||
| ParamValueError: If parsing json data search_condition fails. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/update_watchpoint | |||
| """ | |||
| body = _read_post_request(request) | |||
| watch_point_id = body.get('watch_point_id') | |||
| watch_nodes = body.get('watch_nodes') | |||
| mode = body.get('mode') | |||
| name = body.get('name') | |||
| reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, watch_point_id, watch_nodes, mode, name) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/delete_watchpoint", methods=["POST"]) | |||
| def delete_watchpoint(): | |||
| """ | |||
| delete watchpoint. | |||
| Returns: | |||
| str, reply message. | |||
| Raises: | |||
| MindInsightException: If method fails to be called. | |||
| ParamValueError: If parsing json data search_condition fails. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/delete_watchpoint | |||
| """ | |||
| body = _read_post_request(request) | |||
| watch_point_id = body.get('watch_point_id') | |||
| reply = _wrap_reply(BACKEND_SERVER.delete_watchpoint, watch_point_id) | |||
| return reply | |||
| @BLUEPRINT.route("/debugger/control", methods=["POST"]) | |||
| def control(): | |||
| """ | |||
| Control request. | |||
| Returns: | |||
| str, reply message. | |||
| Raises: | |||
| MindInsightException: If method fails to be called. | |||
| ParamValueError: If parsing json data search_condition fails. | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/control | |||
| """ | |||
| params = _read_post_request(request) | |||
| reply = _wrap_reply(BACKEND_SERVER.control, params) | |||
| return reply | |||


BACKEND_SERVER = _initialize_debugger_server()


def init_module(app):
    """
    Init module entry.

    Args:
        app (Flask): The application obj.
    """
    app.register_blueprint(BLUEPRINT)
    if BACKEND_SERVER:
        BACKEND_SERVER.start()
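
# A minimal wiring sketch, assuming a bare Flask application (the `app` below
# is illustrative, not part of this change). With ENABLE_DEBUGGER left at its
# default of False, only the blueprint is registered and no DebuggerServer starts.
from flask import Flask

app = Flask(__name__)
init_module(app)  # registers BLUEPRINT; starts BACKEND_SERVER when enabled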

@@ -26,6 +26,12 @@ WORKSPACE = os.path.join(os.environ['HOME'], 'mindinsight')
PORT = 8080
URL_PATH_PREFIX = ''

####################################
# Debugger default settings.
####################################
DEBUGGER_PORT = '50051'
ENABLE_DEBUGGER = False

####################################
# Datavisual default settings.
####################################

@@ -15,128 +15,14 @@
"""Tensor data container."""
import numpy as np

from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket
from mindinsight.datavisual.utils.utils import calc_histogram_bins
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.utils.tensor import TensorUtils

F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max
MAX_TENSOR_COUNT = 10000000


class Statistics:
    """Statistics data class.

    Args:
        max_value (float): Max value of tensor data.
        min_value (float): Min value of tensor data.
        avg_value (float): Avg value of tensor data.
        count (int): Total count of tensor data.
        nan_count (int): Count of NAN.
        neg_inf_count (int): Count of negative INF.
        pos_inf_count (int): Count of positive INF.
    """

    def __init__(self, max_value=0, min_value=0, avg_value=0,
                 count=0, nan_count=0, neg_inf_count=0, pos_inf_count=0):
        self._max = max_value
        self._min = min_value
        self._avg = avg_value
        self._count = count
        self._nan_count = nan_count
        self._neg_inf_count = neg_inf_count
        self._pos_inf_count = pos_inf_count

    @property
    def max(self):
        """Get max value of tensor."""
        return self._max

    @property
    def min(self):
        """Get min value of tensor."""
        return self._min

    @property
    def avg(self):
        """Get avg value of tensor."""
        return self._avg

    @property
    def count(self):
        """Get total count of tensor."""
        return self._count

    @property
    def nan_count(self):
        """Get count of NAN."""
        return self._nan_count

    @property
    def neg_inf_count(self):
        """Get count of negative INF."""
        return self._neg_inf_count

    @property
    def pos_inf_count(self):
        """Get count of positive INF."""
        return self._pos_inf_count


def get_statistics_from_tensor(tensors):
    """
    Calculate statistics data of a tensor.

    Args:
        tensors (numpy.ndarray): A numpy.ndarray of tensor data.

    Returns:
        Statistics, an instance of Statistics.
    """
    ma_value = np.ma.masked_invalid(tensors)
    total, valid = tensors.size, ma_value.count()
    invalids = []
    for isfn in np.isnan, np.isposinf, np.isneginf:
        if total - valid > sum(invalids):
            count = np.count_nonzero(isfn(tensors))
            invalids.append(count)
        else:
            invalids.append(0)

    nan_count, pos_inf_count, neg_inf_count = invalids
    if not valid:
        logger.warning('There are no valid values in the tensors (size=%d, shape=%s).', total, tensors.shape)
        statistics = Statistics(max_value=0,
                                min_value=0,
                                avg_value=0,
                                count=total,
                                nan_count=nan_count,
                                neg_inf_count=neg_inf_count,
                                pos_inf_count=pos_inf_count)
        return statistics

    # BUG: max of a masked array with dtype np.float16 returns inf
    # See numpy issue#15077
    if issubclass(tensors.dtype.type, np.floating):
        tensor_min = ma_value.min(fill_value=np.PINF)
        tensor_max = ma_value.max(fill_value=np.NINF)
        if tensor_min < F32_MIN or tensor_max > F32_MAX:
            logger.warning('Values (%f, %f) are too large, you may encounter some undefined '
                           'behaviours hereafter.', tensor_min, tensor_max)
    else:
        tensor_min = ma_value.min()
        tensor_max = ma_value.max()
    tensor_sum = ma_value.sum(dtype=np.float64)
    statistics = Statistics(max_value=tensor_max,
                            min_value=tensor_min,
                            avg_value=tensor_sum / valid,
                            count=total,
                            nan_count=nan_count,
                            neg_inf_count=neg_inf_count,
                            pos_inf_count=pos_inf_count)
    return statistics
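
# A toy walk-through of the masking logic above (values made up for
# illustration): np.ma.masked_invalid hides NaN and both infinities, so
# count() yields only the finite entries, and fill_value keeps min/max sane.
sample = np.array([1.0, 2.0, np.nan, np.inf, -np.inf], dtype=np.float32)
masked = np.ma.masked_invalid(sample)
assert masked.count() == 2
assert np.count_nonzero(np.isnan(sample)) == 1
assert np.count_nonzero(np.isposinf(sample)) == 1
assert np.count_nonzero(np.isneginf(sample)) == 1
assert masked.max(fill_value=np.NINF) == 2.0
assert masked.min(fill_value=np.PINF) == 1.0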


def calc_original_buckets(np_value, stats):
    """
    Calculate buckets from tensor data.

@@ -188,7 +74,7 @@ class TensorContainer:
        self._dims = tuple(tensor_message.dims)
        self._data_type = tensor_message.data_type
        self._np_array = self.get_ndarray(tensor_message.float_data)
        self._stats = get_statistics_from_tensor(self._np_array)
        self._stats = TensorUtils.get_statistics_from_tensor(self._np_array)
        original_buckets = calc_original_buckets(self._np_array, self._stats)
        self._count = sum(bucket.count for bucket in original_buckets)
        self._max = self._stats.max

@@ -19,97 +19,16 @@ import numpy as np

from mindinsight.datavisual.utils.tools import to_int
from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError
from mindinsight.utils.tensor import TensorUtils
from mindinsight.conf.constants import MAX_TENSOR_RESPONSE_DATA_SIZE
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.common.exceptions import StepTensorDataNotInCacheError, TensorNotExistError
from mindinsight.datavisual.common.exceptions import ResponseDataExceedMaxValueError
from mindinsight.datavisual.data_transform.tensor_container import TensorContainer, get_statistics_from_tensor
from mindinsight.datavisual.data_transform.tensor_container import TensorContainer
from mindinsight.datavisual.processors.base_processor import BaseProcessor
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2


def convert_array_from_str(dims, limit=0):
    """
    Convert a string of dims data to an array.

    Args:
        dims (str): Specify dims of tensor.
        limit (int): The max flexible dimension count. Default value is 0,
            which means that there is no limitation.

    Returns:
        list, the converted dims. For example, the string "[0, 0, :, :]"
            converts to [0, 0, None, None].

    Raises:
        ParamValueError, if the number of flexible dimensions exceeds `limit`.
    """
    dims = dims.strip().lstrip('[').rstrip(']')
    dims_list = []
    count = 0
    for dim in dims.split(','):
        dim = dim.strip()
        if dim == ':':
            dims_list.append(None)
            count += 1
        else:
            dims_list.append(to_int(dim, "dim"))
        if limit and count > limit:
            raise ParamValueError("Flexible dimensions cannot exceed limit value: {}, size: {}"
                                  .format(limit, count))
    return dims_list
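
# A behavior sketch for the parser above: ':' entries become None (flexible),
# numeric entries go through to_int, and `limit` caps the flexible count.
assert convert_array_from_str("[0, 0, :, :]", limit=2) == [0, 0, None, None]
# convert_array_from_str("[:, :, :]", limit=2) would raise ParamValueError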


def get_specific_dims_data(ndarray, dims, tensor_dims):
    """
    Get specific dims data.

    Args:
        ndarray (numpy.ndarray): An ndarray of numpy.
        dims (list): A list of specific dims.
        tensor_dims (list): A list of tensor dims.

    Returns:
        numpy.ndarray, an ndarray of specific dims tensor data.

    Raises:
        ParamValueError, if the length of param dims is not equal to the length of tensor dims,
            or an index in param dims is out of range.
    """
    if len(dims) != len(tensor_dims):
        raise ParamValueError("The length of param dims: {}, is not equal to the "
                              "length of tensor dims: {}.".format(len(dims), len(tensor_dims)))
    indices = []
    for k, d in enumerate(dims):
        if d is not None:
            if d >= tensor_dims[k]:
                raise ParamValueError("The index: {} of param dims out of range: {}.".format(d, tensor_dims[k]))
            indices.append(d)
        else:
            indices.append(slice(0, tensor_dims[k]))
    return ndarray[tuple(indices)]
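
# A small sketch of the slicing behavior with a made-up 2x3x4 array:
# dims [0, None, None] fixes axis 0 at index 0 and keeps axes 1 and 2 whole.
demo = np.arange(24).reshape(2, 3, 4)
assert get_specific_dims_data(demo, [0, None, None], [2, 3, 4]).shape == (3, 4)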


def get_statistics_dict(stats):
    """
    Get statistics dict according to statistics value.

    Args:
        stats (Statistics): An instance of Statistics.

    Returns:
        dict, a dict including 'max', 'min', 'avg', 'count', 'nan_count', 'neg_inf_count', 'pos_inf_count'.
    """
    statistics = {
        "max": float(stats.max),
        "min": float(stats.min),
        "avg": float(stats.avg),
        "count": stats.count,
        "nan_count": stats.nan_count,
        "neg_inf_count": stats.neg_inf_count,
        "pos_inf_count": stats.pos_inf_count
    }
    return statistics


class TensorProcessor(BaseProcessor):
    """Tensor Processor."""

    def get_tensors(self, train_ids, tags, step, dims, detail):
@@ -130,22 +49,7 @@ class TensorProcessor(BaseProcessor):
            UrlDecodeError, If unquote train id error with strict mode.
        """
        Validation.check_param_empty(train_id=train_ids, tag=tags)
        if dims is not None:
            if not isinstance(dims, str):
                raise ParamValueError('The type of dims must be str, but got {}.'.format(type(dims)))
            dims = dims.strip()
            if not (dims.startswith('[') and dims.endswith(']')):
                raise ParamValueError('The value: {} of dims must '
                                      'start with `[` and end with `]`.'.format(dims))
            for dim in dims[1:-1].split(','):
                dim = dim.strip()
                if dim == ":":
                    continue
                if dim.startswith('-'):
                    dim = dim[1:]
                if not dim.isdigit():
                    raise ParamValueError('The value: {} of dims in the square brackets '
                                          'must be int or `:`.'.format(dims))
        TensorUtils.validate_dims_format(dims)

        for index, train_id in enumerate(train_ids):
            try:

@@ -248,7 +152,7 @@ class TensorProcessor(BaseProcessor):
                "data_type": anf_ir_pb2.DataType.Name(value.data_type)
            }
            if detail and detail == 'stats':
                stats = get_statistics_dict(value.stats)
                stats = TensorUtils.get_statistics_dict(value.stats)
                value_dict.update({"statistics": stats})
            values.append({

@@ -295,14 +199,14 @@ class TensorProcessor(BaseProcessor):
        """
        values = []
        step_in_cache = False
        dims = convert_array_from_str(dims, limit=2)
        dims = TensorUtils.convert_array_from_str_dims(dims, limit=2)
        for tensor in tensors:
            # This value is an instance of TensorContainer
            value = tensor.value
            if step != tensor.step:
                continue
            step_in_cache = True
            res_data = get_specific_dims_data(value.ndarray, dims, list(value.dims))
            res_data = TensorUtils.get_specific_dims_data(value.ndarray, dims, list(value.dims))
            flatten_data = res_data.flatten().tolist()
            if len(flatten_data) > MAX_TENSOR_RESPONSE_DATA_SIZE:
                raise ResponseDataExceedMaxValueError("the size of response data: {} exceed max value: {}."
@@ -328,7 +232,7 @@ class TensorProcessor(BaseProcessor):
                    transfer_data[index] = float(data)
                return transfer_data

            stats = get_statistics_from_tensor(res_data)
            stats = TensorUtils.get_statistics_from_tensor(res_data)
            if stats.nan_count + stats.neg_inf_count + stats.pos_inf_count > 0:
                tensor_data = transfer(res_data)
            else:
@@ -340,7 +244,7 @@ class TensorProcessor(BaseProcessor):
                    "dims": value.dims,
                    "data_type": anf_ir_pb2.DataType.Name(value.data_type),
                    "data": tensor_data,
                    "statistics": get_statistics_dict(stats)
                    "statistics": TensorUtils.get_statistics_dict(stats)
                }
            })
            break

@@ -389,7 +293,7 @@ class TensorProcessor(BaseProcessor):
                    "dims": value.dims,
                    "data_type": anf_ir_pb2.DataType.Name(value.data_type),
                    "histogram_buckets": buckets,
                    "statistics": get_statistics_dict(value.stats)
                    "statistics": TensorUtils.get_statistics_dict(value.stats)
                }
            })

@@ -80,6 +80,25 @@ def to_int(param, param_name):
    return param


def to_float(param, param_name):
    """
    Transfer param to float type.

    Args:
        param (Any): The param to be transformed.
        param_name (str): Param name.

    Returns:
        float, the transformed value.
    """
    try:
        param = float(param)
    except ValueError:
        raise exceptions.ParamTypeError(param_name, 'Float')
    return param
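
# A quick sketch of the intended use; query parameters such as `tolerance`
# arrive as strings. to_float('abc', 'tolerance') would raise ParamTypeError.
assert to_float('0.5', 'tolerance') == 0.5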


def str_to_bool(param, param_name):
    """
    Check param and transform it to bool.

@@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Debugger Module Introduction.

This module provides Python APIs to retrieve the debugger info and control the training process.
The APIs can help users to understand the training process and find bugs in training scripts.
"""

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

@@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Debugger error code and messages."""
from enum import Enum, unique

from mindinsight.utils.constant import DebuggerErrors as DebuggerErrorCodes

_PARAM_ERROR_MASK = 0b00001 << 7
_DEBUGGER_GRAPH_ERROR = 0b00010 << 7
_DEBUGGER_RUNNING_ERROR = 0b00011 << 7


@unique
class DebuggerErrors(DebuggerErrorCodes):
    """Debugger error codes."""
    PARAM_TYPE_ERROR = 0 | _PARAM_ERROR_MASK
    PARAM_VALUE_ERROR = 1 | _PARAM_ERROR_MASK

    NODE_NOT_IN_GRAPH_ERROR = 0 | _DEBUGGER_GRAPH_ERROR
    GRAPH_NOT_EXIST_ERROR = 1 | _DEBUGGER_GRAPH_ERROR

    CREATE_WATCHPOINT_ERROR = 0 | _DEBUGGER_RUNNING_ERROR
    UPDATE_WATCHPOINT_ERROR = 1 | _DEBUGGER_RUNNING_ERROR
    DELETE_WATCHPOINT_ERROR = 2 | _DEBUGGER_RUNNING_ERROR
    CONTINUE_ERROR = 3 | _DEBUGGER_RUNNING_ERROR
    PAUSE_ERROR = 4 | _DEBUGGER_RUNNING_ERROR
    COMPARE_TENSOR_ERROR = 5 | _DEBUGGER_RUNNING_ERROR


@unique
class DebuggerErrorMsg(Enum):
    """Debugger error messages."""
    PARAM_TYPE_ERROR = "TypeError. {}"
    PARAM_VALUE_ERROR = "ValueError. {}"
    PARAM_MISSING_ERROR = "MissingError. {}"
    UNEXPECTED_EXCEPTION_ERROR = "Unexpected exception. {}"

    GRAPH_NOT_EXIST_ERROR = "The graph does not exist."

    CREATE_WATCHPOINT_ERROR = "Create watchpoint failed. {}"
    UPDATE_WATCHPOINT_ERROR = "Update watchpoint failed. {}"
    DELETE_WATCHPOINT_ERROR = "Delete watchpoint failed. {}"
    COMPARE_TENSOR_ERROR = "Compare tensors failed. {}"
    CONTINUE_ERROR = "Continue debugging failed. {}"
    PAUSE_ERROR = "Pause debugging failed. {}"
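
# A sketch of how the raw values above compose (ignoring any transformation
# the DebuggerErrorCodes base class may apply): the shifted mask selects the
# error family and the low bits hold the sub-code.
assert _PARAM_ERROR_MASK == 128
assert (1 | _PARAM_ERROR_MASK) == 129            # PARAM_VALUE_ERROR
assert (5 | _DEBUGGER_RUNNING_ERROR) == 389      # COMPARE_TENSOR_ERROR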

@@ -0,0 +1,117 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Definition of error code and relative messages in debugger module."""
from mindinsight.utils.exceptions import MindInsightException

from mindinsight.debugger.common.exceptions.error_code import DebuggerErrors, DebuggerErrorMsg


class DebuggerParamTypeError(MindInsightException):
    """The parameter type error in debugger module."""

    def __init__(self, msg):
        super(DebuggerParamTypeError, self).__init__(
            error=DebuggerErrors.PARAM_TYPE_ERROR,
            message=DebuggerErrorMsg.PARAM_TYPE_ERROR.value.format(msg)
        )


class DebuggerParamValueError(MindInsightException):
    """The parameter value error in debugger module."""

    def __init__(self, msg):
        super(DebuggerParamValueError, self).__init__(
            error=DebuggerErrors.PARAM_VALUE_ERROR,
            message=DebuggerErrorMsg.PARAM_VALUE_ERROR.value.format(msg)
        )


class DebuggerCreateWatchPointError(MindInsightException):
    """The error about creating watch point."""

    def __init__(self, msg):
        super(DebuggerCreateWatchPointError, self).__init__(
            error=DebuggerErrors.CREATE_WATCHPOINT_ERROR,
            message=DebuggerErrorMsg.CREATE_WATCHPOINT_ERROR.value.format(msg)
        )


class DebuggerUpdateWatchPointError(MindInsightException):
    """The error about updating watch point."""

    def __init__(self, msg):
        super(DebuggerUpdateWatchPointError, self).__init__(
            error=DebuggerErrors.UPDATE_WATCHPOINT_ERROR,
            message=DebuggerErrorMsg.UPDATE_WATCHPOINT_ERROR.value.format(msg)
        )


class DebuggerDeleteWatchPointError(MindInsightException):
    """The error about deleting watch point."""

    def __init__(self, msg):
        super(DebuggerDeleteWatchPointError, self).__init__(
            error=DebuggerErrors.DELETE_WATCHPOINT_ERROR,
            message=DebuggerErrorMsg.DELETE_WATCHPOINT_ERROR.value.format(msg)
        )


class DebuggerCompareTensorError(MindInsightException):
    """The error about comparing tensors."""

    def __init__(self, msg):
        super(DebuggerCompareTensorError, self).__init__(
            error=DebuggerErrors.COMPARE_TENSOR_ERROR,
            message=DebuggerErrorMsg.COMPARE_TENSOR_ERROR.value.format(msg)
        )


class DebuggerContinueError(MindInsightException):
    """The error about continuing debugging."""

    def __init__(self, msg):
        super(DebuggerContinueError, self).__init__(
            error=DebuggerErrors.CONTINUE_ERROR,
            message=DebuggerErrorMsg.CONTINUE_ERROR.value.format(msg)
        )


class DebuggerPauseError(MindInsightException):
    """The error about pausing debugging."""

    def __init__(self, msg):
        super(DebuggerPauseError, self).__init__(
            error=DebuggerErrors.PAUSE_ERROR,
            message=DebuggerErrorMsg.PAUSE_ERROR.value.format(msg)
        )


class DebuggerNodeNotInGraphError(MindInsightException):
    """The node is not in the graph."""

    def __init__(self, node_name, node_type=None):
        if node_type is not None:
            err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}, type: {node_type}."
        else:
            err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}."
        super(DebuggerNodeNotInGraphError, self).__init__(
            error=DebuggerErrors.NODE_NOT_IN_GRAPH_ERROR,
            message=err_msg
        )


class DebuggerGraphNotExistError(MindInsightException):
    """The graph does not exist."""

    def __init__(self):
        super(DebuggerGraphNotExistError, self).__init__(
            error=DebuggerErrors.GRAPH_NOT_EXIST_ERROR,
            message=DebuggerErrorMsg.GRAPH_NOT_EXIST_ERROR.value
        )

@@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Import mindinsight unified log module."""
from mindinsight.utils.log import setup_logger

LOG_NAME = "debugger"
LOG_MODULE = "debugger"
logger = setup_logger(sub_module=LOG_MODULE, log_name=LOG_NAME)

@@ -0,0 +1,168 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the utils."""
import enum
from collections import namedtuple

import numpy as np

from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply
from mindinsight.debugger.stream_cache.debugger_graph import NodeTypeEnum

# translate the MindSpore type to numpy type.
NUMPY_TYPE_MAP = {
    'DT_BOOL': np.bool,
    'DT_INT8': np.int8,
    'DT_INT16': np.int16,
    'DT_INT32': np.int32,
    'DT_INT64': np.int64,
    'DT_UINT8': np.uint8,
    'DT_UINT16': np.uint16,
    'DT_UINT32': np.uint32,
    'DT_UINT64': np.uint64,
    'DT_FLOAT16': np.float16,
    'DT_FLOAT32': np.float32,
    'DT_FLOAT64': np.float64,
    'DT_STRING': np.str
}
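
# A minimal sketch of how the map is meant to be used (the values are made
# up): pick the numpy dtype for a MindSpore data-type string when building
# arrays from incoming tensor protos.
float_values = [1.0, 2.5, 3.0]
np_data = np.array(float_values, dtype=NUMPY_TYPE_MAP['DT_FLOAT32'])
assert np_data.dtype == np.float32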


@enum.unique
class ReplyStates(enum.Enum):
    """Define the status of reply."""
    SUCCESS = 0
    FAILED = -1


@enum.unique
class ServerStatus(enum.Enum):
    """The status of debugger server."""
    PENDING = 'pending'  # no client session has been connected
    RECEIVE_GRAPH = 'receive graph'  # the client session has sent the graph
    WAITING = 'waiting'  # the client session is ready
    RUNNING = 'running'  # the client session is running a script


@enum.unique
class Streams(enum.Enum):
    """Define the streams to be dealt with."""
    COMMAND = "command"
    DATA = "data"
    METADATA = "metadata"
    GRAPH = 'node'
    TENSOR = 'tensor'
    WATCHPOINT = 'watchpoint'
    WATCHPOINT_HIT = 'watchpoint_hit'


NodeBasicInfo = namedtuple('node_basic_info', ['name', 'full_name', 'type'])


def get_ack_reply(state=0):
    """Get the ack EventReply."""
    reply = EventReply()
    state_mapping = {
        0: EventReply.Status.OK,
        1: EventReply.Status.FAILED,
        2: EventReply.Status.PENDING
    }
    reply.status = state_mapping[state]
    return reply


def wrap_reply_response(error_code=None, error_message=None):
    """
    Wrap reply response.

    Args:
        error_code (str): Error code. Default: None.
        error_message (str): Error message. Default: None.

    Returns:
        dict, the response body.
    """
    if error_code is None:
        reply = {'state': ReplyStates.SUCCESS.value}
    else:
        reply = {
            'state': ReplyStates.FAILED.value,
            'error_code': error_code,
            'error_message': error_message
        }
    return reply


def create_view_event_from_tensor_history(tensor_history):
    """
    Create view event reply according to tensor history.

    Args:
        tensor_history (list[dict]): The list of tensor history. Each element has keys:
            `name`, `node_type`.

    Returns:
        EventReply, the event reply with view cmd.
    """
    view_event = get_ack_reply()
    for tensor_info in tensor_history:
        node_type = tensor_info.get('node_type')
        if node_type == NodeTypeEnum.CONST.value:
            continue
        truncate_tag = node_type == NodeTypeEnum.PARAMETER.value
        tensor_name = tensor_info.get('full_name', '')
        # create view command
        ms_tensor = view_event.view_cmd.tensors.add()
        ms_tensor.node_name, ms_tensor.slot = tensor_name.rsplit(':', 1)
        ms_tensor.truncate = truncate_tag
        ms_tensor.iter = 'prev' if tensor_info.get('iter') else ''
    return view_event


def is_scope_type(node_type):
    """Judge whether the type is scope type."""
    scope_types = [NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.AGGREGATION_SCOPE.value]
    return node_type in scope_types


def str_to_slice_or_int(input_str):
    """
    Translate param from string to slice or int.

    Args:
        input_str (str): The string to be translated.

    Returns:
        Union[int, slice], the transformed param.
    """
    try:
        if ':' in input_str:
            ret = slice(*map(lambda x: int(x.strip()) if x.strip() else None, input_str.split(':')))
        else:
            ret = int(input_str)
    except ValueError as err:
        log.error("Failed to create slice from %s", input_str)
        log.exception(err)
        raise DebuggerParamValueError("Invalid shape.")
    return ret
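
# The translation behaves as follows for the shape strings accepted above;
# an unparsable entry such as 'x' raises DebuggerParamValueError.
assert str_to_slice_or_int('3') == 3
assert str_to_slice_or_int('1:3') == slice(1, 3)
assert str_to_slice_or_int(':') == slice(None, None)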

@@ -0,0 +1,154 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement the debugger data cache manager."""
import sys

from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import Streams
from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, GraphHandler, \
    TensorHandler, WatchpointHandler, WatchpointHitHandler

STREAM_HANDLER_MAP = {
    Streams.COMMAND.value: EventHandler,
    Streams.DATA.value: EventHandler,
    Streams.METADATA.value: MetadataHandler,
    Streams.GRAPH.value: GraphHandler,
    Streams.TENSOR.value: TensorHandler,
    Streams.WATCHPOINT.value: WatchpointHandler,
    Streams.WATCHPOINT_HIT.value: WatchpointHitHandler
}


class DebuggerCache:
    """The debugger data cache manager."""

    def __init__(self):
        self._stream_handler = {}

    def initialize(self):
        """Initialize the stream handlers."""
        self._stream_handler = {}
        for stream in Streams:
            mode = stream.value
            stream_handler = STREAM_HANDLER_MAP.get(mode)
            self._stream_handler[mode] = stream_handler()

    def clean(self):
        """Clean cache for all streams."""
        for stream_handler in self._stream_handler.values():
            stream_handler.clean()
    def get_stream_handler(self, mode):
        """
        Get the stream handler object.

        Args:
            mode (Streams): The type of stream handler.

        Returns:
            StreamHandlerBase, the stream handler object.
        """
        return self._stream_handler.get(mode.value)

    def _get(self, mode, pos):
        """
        Get updated data or command from cache.

        Args:
            mode (Streams): The type of info. `Streams.DATA` or `Streams.COMMAND`.
            pos (int): The index of info.

        Returns:
            object, the pos-th message about `mode` type of info.
        """
        stream_handler = self.get_stream_handler(mode)
        return stream_handler.get(pos)

    def _put(self, mode, value):
        """
        Set updated data or command to cache.

        Args:
            mode (Streams): The type of info. `Streams.DATA` or `Streams.COMMAND`.
            value (object): The info to be recorded in cache.
        """
        stream_handler = self.get_stream_handler(mode)
        return stream_handler.put(value)
    def get_command(self, pos):
        """
        Get the pos-th command in command stream.

        Args:
            pos (int): The index of command.

        Returns:
            int, the position of next message.
            EventReply, the command object.
        """
        content = self._get(Streams.COMMAND, pos)
        next_pos = content.get('metadata').get('pos')
        reply = content.get('cmd')
        return next_pos, reply

    def put_command(self, cmd):
        """
        Set command to command stream.

        Args:
            cmd (EventReply): The command EventReply.
        """
        log.debug("Set command %s", cmd)
        return self._put(Streams.COMMAND, {'cmd': cmd})

    def has_command(self, pos):
        """Check whether a command exists at position `pos`."""
        event = self.get_stream_handler(Streams.COMMAND).has_pos(pos)
        return event

    def clean_command(self):
        """Clean command queue."""
        self.get_stream_handler(Streams.COMMAND).clean()
        log.debug("Clean command.")
    def clean_data(self):
        """Clean data queue."""
        self.get_stream_handler(Streams.DATA).clean()
        log.debug("Clean data queue.")

    def get_data(self, pos):
        """
        Get updated data from data stream.

        Args:
            pos (int): The index of data.

        Returns:
            object, updated data_value.
        """
        return self._get(Streams.DATA, pos)

    def put_data(self, value):
        """
        Set updated data to data stream.

        Args:
            value (dict): The updated data.
        """
        log.debug("Set <%d> bytes data", sys.getsizeof(value))
        return self._put(Streams.DATA, value)
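
# A minimal round-trip sketch, assuming EventHandler queues values in order
# and reports the next position the way the gRPC server's _get_next_command
# expects; `fake_cmd` is a hypothetical EventReply built elsewhere.
cache = DebuggerCache()
cache.initialize()
cache.put_command(fake_cmd)                 # stored as {'cmd': fake_cmd}
next_pos, reply = cache.get_command('0')
assert reply is fake_cmd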

@@ -0,0 +1,309 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement the debugger grpc server."""
from functools import wraps

from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
    create_view_event_from_tensor_history, Streams
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto


def debugger_wrap(func):
    """Wrapper for catching exceptions."""

    @wraps(func)
    def record_log(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as err:
            log.exception(err)
            raise

    return record_log


class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
    """The grpc server used to interact with the grpc client."""

    def __init__(self, cache_store):
        """
        Initialize.

        Args:
            cache_store (DebuggerCache): Debugger cache store.
        """
        cache_store.initialize()
        self._cache_store = cache_store
        self._pos = None
        self._status = None
        self._view_event = None
        self._view_round = None
        self._continue_steps = None
        self.init()

    def init(self):
        """Init debugger grpc server."""
        self._pos = '0'
        self._status = ServerStatus.PENDING
        self._view_event = None
        self._view_round = True
        self._continue_steps = 0
        self._cache_store.clean()
    @debugger_wrap
    def WaitCMD(self, request, context):
        """Wait for a command in DebuggerCache."""
        # check whether the graph has already been received
        log.info("Received WaitCMD at %s-th step.", request.cur_step)
        if self._status == ServerStatus.PENDING:
            log.warning("No graph received before WaitCMD.")
            reply = get_ack_reply(1)
            return reply
        # send the graph if it has not been sent before
        self._pre_process(request)
        # deal with old command
        reply = self._deal_with_old_command()
        if reply:
            log.info("Reply to WaitCMD with old command: %s", reply)
            return reply
        # send view cmd
        if self._view_round and self._view_event:
            self._view_round = False
            reply = self._view_event
            log.debug("Send ViewCMD.")
        # continue multiple steps of training
        elif self._continue_steps != 0:
            reply = get_ack_reply()
            reply.run_cmd.run_steps = 1
            reply.run_cmd.run_level = 'step'
            self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
            log.debug("Send RunCMD. Clean watchpoint hit.")
        # wait for command
        else:
            reply = self._wait_for_next_command()
        if reply is None:
            reply = get_ack_reply(1)
            log.warning("Failed to get command event.")
        else:
            log.info("Reply to WaitCMD: %s", reply)
        return reply
    def _pre_process(self, request):
        """Send graph and metadata when WaitCMD is first called."""
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        if self._status == ServerStatus.RECEIVE_GRAPH:
            self._status = ServerStatus.WAITING
            metadata_stream.state = 'waiting'
            metadata = metadata_stream.get()
            self._cache_store.clean_command()
            res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
            res.update(metadata)
            self._cache_store.put_data(res)
            log.info("Put graph into data queue.")
        if metadata_stream.step < request.cur_step or metadata_stream.full_name != request.cur_node:
            # clean tensor cache and DataQueue at the beginning of each step
            self._update_metadata(metadata_stream, request)

    def _update_metadata(self, metadata_stream, metadata_proto):
        """Update metadata."""
        # reset view round and clean cached data
        self._view_round = True
        if metadata_stream.step < metadata_proto.cur_step:
            self._cache_store.clean_data()
            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(
                metadata_proto.cur_step)
        # put new metadata into cache
        metadata_stream.put(metadata_proto)
        cur_node = self._cache_store.get_stream_handler(Streams.GRAPH).get_node_name_by_full_name(
            metadata_proto.cur_node) if metadata_proto.cur_node else ''
        metadata_stream.node_name = cur_node
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        log.info("Put new metadata into data queue.")
    def _deal_with_old_command(self):
        """Deal with old command."""
        event = None
        while self._cache_store.has_command(self._pos) and event is None:
            event = self._get_next_command()
            log.debug("Deal with old %s-th command:\n%s.", self._pos, event)
        return event

    def _wait_for_next_command(self):
        """
        Wait for next command.

        Returns:
            EventReply, the command event.
        """
        log.info("Start to wait for command.")
        self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting'
        self._cache_store.put_data({'metadata': {'state': 'waiting'}})
        event = None
        while event is None and self._status == ServerStatus.WAITING:
            log.debug("Wait for %s-th command", self._pos)
            event = self._get_next_command()
        return event
    def _get_next_command(self):
        """Get next command."""
        self._pos, event = self._cache_store.get_command(self._pos)
        log.debug("Received event: %s", event)
        if event is None:
            return event
        if isinstance(event, dict) and event.get('reset'):
            self._set_view_event(event)
            event = None
        elif event.HasField('run_cmd'):
            event = self._deal_with_run_cmd(event)
        elif event.HasField('view_cmd'):
            self._view_round = False
        elif event.HasField('exit'):
            self._cache_store.clean()
            log.info("Clean cache for exit cmd.")
        return event
    def _deal_with_run_cmd(self, event):
        """Deal with run cmd."""
        run_cmd = event.run_cmd
        # receive step command
        if run_cmd.run_level == 'step':
            # receive pause cmd
            if run_cmd.run_steps == 0:
                log.debug("Pause training and wait for next command.")
                self._continue_steps = 0
                return None
            # receive step cmd
            self._continue_steps = run_cmd.run_steps - 1
            event.run_cmd.run_steps = 1
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
            log.debug("Receive RunCMD. Clean watchpoint hit cache.")
        return event
    def _set_view_event(self, event):
        """Create view event for view cmd."""
        # the first tensor in view cmd is always the output
        node_name = event.get('node_name')
        tensor_history = event.get('tensor_history')
        if not node_name or not tensor_history:
            self._view_event = None
            log.info("Reset view command to None.")
        else:
            # create view event and set it
            self._view_event = create_view_event_from_tensor_history(tensor_history)
            log.info("Reset view command to %s.", node_name)
    @debugger_wrap
    def SendMetadata(self, request, context):
        """Send metadata into DebuggerCache."""
        log.info("Received Metadata.")
        if self._status != ServerStatus.PENDING:
            log.info("Re-initialize cache store when a new session comes.")
            self.init()
        client_ip = context.peer().split(':', 1)[-1]
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        metadata_stream.put(request)
        metadata_stream.client_ip = client_ip
        metadata = metadata_stream.get()
        # put metadata into data queue
        self._cache_store.put_data(metadata)
        log.info("Put new metadata to DataQueue.")
        reply = get_ack_reply()
        log.info("Send the reply to %s.", client_ip)
        return reply
    @debugger_wrap
    def SendGraph(self, request_iterator, context):
        """Send graph into DebuggerCache."""
        log.info("Received graph.")
        serial_graph = b""
        for chunk in request_iterator:
            serial_graph += chunk.buffer
        graph = GraphProto.FromString(serial_graph)
        log.debug("Deserialized the graph. Received %s nodes.", len(graph.node))
        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph)
        self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
        self._status = ServerStatus.RECEIVE_GRAPH
        reply = get_ack_reply()
        log.info("Send the reply for graph.")
        return reply
    @debugger_wrap
    def SendTensors(self, request_iterator, context):
        """Send tensors into DebuggerCache."""
        log.info("Received tensor.")
        tensor_construct = []
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        tensor_names = []
        step = metadata_stream.step
        for tensor in request_iterator:
            tensor_construct.append(tensor)
            if tensor.finished:
                tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
                tensor_construct = []
                tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
                continue
        # send back the tensor-finished flag once all awaited tensors have values
        tensor_history = tensor_stream.get_tensor_history(tensor_names)
        self._add_node_name_for_tensor_history(tensor_history)
        metadata = metadata_stream.get()
        tensor_history.update(metadata)
        self._cache_store.put_data({})  # reply to the listening request
        self._cache_store.put_data(tensor_history)
        log.info("Send updated tensor history to data queue.")
        reply = get_ack_reply()
        return reply
    def _add_node_name_for_tensor_history(self, tensor_history):
        """Add node name for tensor history."""
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
        for tensor_info in tensor_history.get('tensor_history'):
            if tensor_info:
                full_name, slot = tensor_info.get('full_name', '').rsplit(':', 1)
                node_name = graph_stream.get_node_name_by_full_name(full_name)
                tensor_info['name'] = node_name + ':' + slot
    @debugger_wrap
    def SendWatchpointHits(self, request_iterator, context):
        """Send watchpoint hits info into DebuggerCache."""
        log.info("Received WatchpointHits. Left steps %d change to 0.", self._continue_steps)
        self._continue_steps = 0
        self._view_event = None
        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
        for watchpoint_hit_proto in request_iterator:
            watchpoint_hit = {
                'tensor_proto': watchpoint_hit_proto.tensor,
                'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
                'node_name': graph_stream.get_node_name_by_full_name(
                    watchpoint_hit_proto.tensor.node_name)
            }
            watchpoint_hit_stream.put(watchpoint_hit)
        watchpoint_hits_info = watchpoint_hit_stream.get()
        self._cache_store.put_data(watchpoint_hits_info)
        log.info("Send the watchpoint hits to DataQueue.\nSend the reply.")
        reply = get_ack_reply()
        return reply

@@ -0,0 +1,752 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implement the debugger server."""
import signal
from concurrent import futures
from threading import Thread

import grpc

from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.utils.tools import to_float
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
    DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \
    DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, DebuggerCompareTensorError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
    create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo, \
    str_to_slice_or_int
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
from mindinsight.utils.exceptions import MindInsightException
| class DebuggerServer: | |||
| """The server manager of debugger.""" | |||
| def __init__(self, grpc_port=None): | |||
| self.grpc_port = grpc_port | |||
| self.cache_store = DebuggerCache() | |||
| self.grpc_server = DebuggerGrpcServer(self.cache_store) | |||
| self.grpc_server_manager = None | |||
| self.back_server = None | |||
| self._watch_point_id = 0 | |||
| def start(self): | |||
| """Start server.""" | |||
| grpc_port = self.grpc_port if self.grpc_port else "50051" | |||
| host = settings.HOST if hasattr(settings, 'HOST') else '[::]' | |||
| hostname = "{}:{}".format(host, grpc_port) | |||
| # initialize a grpc server | |||
| grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) | |||
| grpc_server_base.add_EventListenerServicer_to_server(self.grpc_server, grpc_server_manager) | |||
| grpc_server_manager.add_insecure_port(hostname) | |||
| grpc_server_manager.start() | |||
| my_server_thread = Thread(target=grpc_server_manager.wait_for_termination) | |||
| # start grpc server | |||
| my_server_thread.start() | |||
| self.back_server = my_server_thread | |||
| self.grpc_server_manager = grpc_server_manager | |||
| # register stop server handler | |||
| signal.signal(signal.SIGINT, self._stop_handler) | |||
| log.info("Start grpc server %s", hostname) | |||
| def _stop_handler(self, signum, frame): | |||
| """Register stop server handler.""" | |||
| self.stop() | |||
| log.debug("Deal with stop signal: %s, %s", signum, frame) | |||
| def stop(self): | |||
| """Stop debugger server.""" | |||
| self.grpc_server_manager.stop(grace=None) | |||
| self.back_server.join() | |||
| log.info("Stop debugger server.") | |||
| def poll_data(self, pos): | |||
| """ | |||
| Get the pos-th data from DebuggerCache. | |||
| Args: | |||
| pos (int): The index of data. | |||
| Returns: | |||
| dict, the data to be updated. | |||
| """ | |||
| if not isinstance(pos, str): | |||
| log.error("Pos should be string. Received: %s", pos) | |||
| raise DebuggerParamValueError("Pos should be string.") | |||
| reply = self.cache_store.get_data(pos) | |||
| return reply | |||
| def search(self, name, watch_point_id): | |||
| """Search for single node in graph.""" | |||
| log.info("receive search request for node:%s, in watchpoint:%d", name, watch_point_id) | |||
| graph = self.cache_store.get_stream_handler(Streams.GRAPH).search_nodes(name) | |||
| self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes( | |||
| graph, watch_point_id) | |||
| return graph | |||
| def tensor_comparisons(self, name, shape, detail='data', tolerance='0'): | |||
| """ | |||
| Get tensor comparisons data for given name, detail, shape and tolerance. | |||
| Args: | |||
| name (str): The name of the tensor for UI. | |||
| shape (str): Specify concrete dimensions of shape. | |||
| detail (str): Specify which data to query. Current available value is 'data' which means | |||
| concrete tensor data. Histogram or unique count can be supported in the future. | |||
| tolerance (str): Specify the tolerance of difference between the current step tensor and the | |||
| previous step tensor. Default value is 0. | |||
| Raises: | |||
| DebuggerParamValueError, if the node type is not Parameter or the value of detail is not supported. | |||
| DebuggerCompareTensorError, if MindSpore is not in the waiting state. | |||
| Returns: | |||
| dict, the retrieved data. | |||
| """ | |||
| if self.cache_store.get_stream_handler( | |||
| Streams.METADATA).state != ServerStatus.WAITING.value: | |||
| log.error("Failed to compare tensors as the MindSpore is not in waiting state.") | |||
| raise DebuggerCompareTensorError( | |||
| "Failed to compare tensors as the MindSpore is not in waiting state." | |||
| ) | |||
| self.validate_tensor_param(name, detail) | |||
| parsed_shape = self.parse_shape(shape) | |||
| node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name) | |||
| tolerance = to_float(tolerance, 'tolerance') | |||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) | |||
| if detail == 'data': | |||
| if node_type == NodeTypeEnum.PARAMETER.value: | |||
| reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance) | |||
| else: | |||
| raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type)) | |||
| else: | |||
| raise DebuggerParamValueError("The value of detail: {} is not support.".format(detail)) | |||
| return reply | |||
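| # Illustrative call (node name is hypothetical; requires a Parameter node and | |||
| # MindSpore in the waiting state): | |||
| # | |||
| #     diff = server.tensor_comparisons( | |||
| #         name='Default/optimizer/weight:0', shape='[0,1:3]', tolerance='0.1') | |||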
| def retrieve(self, mode, filter_condition=None): | |||
| """ | |||
| Retrieve data according to mode and params. | |||
| Args: | |||
| mode (str): The type of info message. | |||
| filter_condition (dict): The filter condition. | |||
| Returns: | |||
| dict, the retrieved data. | |||
| """ | |||
| log.info("receive retrieve request for mode:%s\n, filter_condition: %s", mode, | |||
| filter_condition) | |||
| # validate watchpoint_id | |||
| mode_mapping = { | |||
| 'all': self._retrieve_all, | |||
| 'node': self._retrieve_node, | |||
| 'watchpoint': self._retrieve_watchpoint, | |||
| 'watchpoint_hit': self._retrieve_watchpoint_hit | |||
| } | |||
| # validate param <mode> | |||
| if mode not in mode_mapping.keys(): | |||
| log.error("Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint', " | |||
| "'watchpoint_hit', 'tensor'], but got %s.", mode_mapping) | |||
| raise DebuggerParamTypeError("Invalid mode.") | |||
| filter_condition = {} if filter_condition is None else filter_condition | |||
| self._watch_point_id = filter_condition.get('watch_point_id', 0) | |||
| reply = mode_mapping[mode](filter_condition) | |||
| return reply | |||
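| # Illustrative calls, one per mode (filter conditions and node names are | |||
| # hypothetical): | |||
| # | |||
| #     server.retrieve('all') | |||
| #     server.retrieve('node', {'name': 'Default/network', 'watch_point_id': 1}) | |||
| #     server.retrieve('watchpoint', {'watch_point_id': 1}) | |||
| #     server.retrieve('watchpoint_hit', {'name': 'Default/network/conv1'}) | |||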
| def _retrieve_all(self, filter_condition=None): | |||
| """Retrieve metadata, root graph and watchpoint list.""" | |||
| if filter_condition: | |||
| log.error("No filter condition required for retrieve all request.") | |||
| raise DebuggerParamTypeError("filter_condition should be empty.") | |||
| result = {} | |||
| self.cache_store.clean_data() | |||
| log.info("Clean data queue cache when retrieve all request.") | |||
| self.cache_store.put_command({'reset': True}) | |||
| for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]: | |||
| sub_res = self.cache_store.get_stream_handler(stream).get() | |||
| result.update(sub_res) | |||
| return result | |||
| def _retrieve_node(self, filter_condition): | |||
| """ | |||
| Retrieve node info. | |||
| Args: | |||
| filter_condition (dict): Filter condition. | |||
| - name (str): The name of single node. | |||
| - watch_point_id (int): The id of watchpoint. | |||
| - single_node (bool): If False, return the sub-layer of single node. If True, return | |||
| the node list from root node to single node. | |||
| Returns: | |||
| dict, the node info. | |||
| """ | |||
| log.info("Retrieve node %s.", filter_condition) | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| # validate parameters | |||
| node_name = filter_condition.get('name') | |||
| if not node_name: | |||
| node_type = NodeTypeEnum.NAME_SCOPE.value | |||
| else: | |||
| node_type = graph_stream.get_node_type(node_name) | |||
| filter_condition['node_type'] = node_type | |||
| filter_condition['single_node'] = bool(filter_condition.get('single_node')) | |||
| # get graph for scope node | |||
| if is_scope_type(node_type): | |||
| reply = self._get_nodes_info(filter_condition) | |||
| # get tensor history for leaf node | |||
| else: | |||
| reply = self._get_tensor_history(node_name) | |||
| if filter_condition.get('single_node'): | |||
| graph = self._get_nodes_info(filter_condition) | |||
| reply.update(graph) | |||
| return reply | |||
| def _get_nodes_info(self, filter_condition): | |||
| """ | |||
| Get nodes info. | |||
| Args: | |||
| filter_condition (dict): The filter condition. | |||
| - name (str): The node name. | |||
| - single_node (bool): If False, return the sub-layer of single node. If True, return | |||
| the node list from root node to single node. | |||
| - watch_point_id (int): The id of watchpoint. | |||
| Returns: | |||
| dict, reply with graph. | |||
| """ | |||
| # get graph | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| reply = graph_stream.get(filter_condition) | |||
| graph = reply.get('graph') | |||
| # add watched label | |||
| self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes( | |||
| graph, self._watch_point_id) | |||
| return reply | |||
| def retrieve_tensor_history(self, node_name): | |||
| """ | |||
| Retrieve tensor history for leaf node. | |||
| Args: | |||
| node_name (str): The name of leaf node. | |||
| Returns: | |||
| dict, the tensor history and metadata. | |||
| """ | |||
| log.info("Retrieve tensor history for node: %s.", node_name) | |||
| self._validate_leaf_name(node_name) | |||
| res = self._get_tensor_history(node_name) | |||
| return res | |||
| def _validate_leaf_name(self, node_name): | |||
| """Validate if the node is a leaf node.""" | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| node_type = graph_stream.get_node_type(node_name) | |||
| if is_scope_type(node_type): | |||
| log.error("Scope type node has no tensor history.") | |||
| raise DebuggerParamValueError("Invalid leaf node name.") | |||
| def _get_tensor_history(self, node_name): | |||
| """ | |||
| Get tensor history for single node. | |||
| Args: | |||
| node_name (str): The name of leaf node. | |||
| Returns: | |||
| dict, the tensor history and metadata. | |||
| """ | |||
| # get basic tensor history | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| tensor_history = graph_stream.get_tensor_history(node_name) | |||
| # set the view event | |||
| self.cache_store.put_command( | |||
| {'reset': True, | |||
| 'node_name': node_name, | |||
| 'tensor_history': tensor_history.get('tensor_history')}) | |||
| # add tensor value for tensor history | |||
| self._add_tensor_value_for_tensor_history(tensor_history) | |||
| # add hit label for tensor history | |||
| watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||
| watchpoint_hit_stream.update_tensor_history(tensor_history) | |||
| # add metadata | |||
| metadata = self.cache_store.get_stream_handler(Streams.METADATA).get() | |||
| tensor_history.update(metadata) | |||
| return tensor_history | |||
| def _add_tensor_value_for_tensor_history(self, tensor_history): | |||
| """ | |||
| Add tensor values for tensor_history and send a ViewCMD if any tensor value is missing. | |||
| Args: | |||
| tensor_history (dict): The tensor history, containing a list of tensor info with name and type. | |||
| """ | |||
| tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) | |||
| missed_tensors = tensor_stream.update_tensor_history(tensor_history) | |||
| if missed_tensors: | |||
| view_cmd = create_view_event_from_tensor_history(missed_tensors) | |||
| self.cache_store.put_command(view_cmd) | |||
| log.debug("Send view cmd.") | |||
| def retrieve_tensor_value(self, name, detail, shape): | |||
| """Retrieve the tensor value.""" | |||
| log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape) | |||
| self.validate_tensor_param(name, detail) | |||
| parsed_shape = self.parse_shape(shape) | |||
| node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name) | |||
| reply = self.cache_store.get_stream_handler(Streams.TENSOR).get( | |||
| {'name': tensor_name, | |||
| 'node_type': node_type, | |||
| 'shape': parsed_shape} | |||
| ) | |||
| reply['tensor_value']['name'] = name | |||
| return reply | |||
| def _get_tensor_name_and_type_by_ui_name(self, name): | |||
| """ | |||
| Get inner tensor name and type by UI name. | |||
| Args: | |||
| name (str): Node name shown in UI. | |||
| Returns: | |||
| str, node type of the tensor. | |||
| str, the full name of the tensor. | |||
| """ | |||
| node_name, slot = name.rsplit(':', 1) | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| node_type = graph_stream.get_node_type(node_name) | |||
| full_name = graph_stream.get_full_name(node_name) | |||
| tensor_name = full_name + ':' + slot | |||
| return node_type, tensor_name | |||
| @staticmethod | |||
| def validate_tensor_param(name, detail): | |||
| """Validate params for retrieve tensor request.""" | |||
| # validate name | |||
| if not isinstance(name, str) or ':' not in name: | |||
| log.error("Invalid tensor name. Received: %s", name) | |||
| raise DebuggerParamValueError("Invalid tensor name.") | |||
| # validate detail | |||
| if detail != 'data': | |||
| log.error("Invalid detail value. Received: %s", detail) | |||
| raise DebuggerParamValueError("Invalid detail value.") | |||
| @staticmethod | |||
| def parse_shape(shape): | |||
| """Parse shape.""" | |||
| if shape is None: | |||
| return shape | |||
| if not (isinstance(shape, str) and shape.startswith('[') and shape.endswith(']')): | |||
| log.error("Invalid shape. Received: %s", shape) | |||
| raise DebuggerParamValueError("Invalid shape.") | |||
| shape = shape.strip('[]') | |||
| if shape.count(':') > 2: | |||
| log.error("Invalid shape. At most two dimensions are specified.") | |||
| raise DebuggerParamValueError("Invalid shape.") | |||
| parsed_shape = tuple( | |||
| str_to_slice_or_int(dim) for dim in shape.split(',')) if shape else tuple() | |||
| log.info("Parsed shape: %s from %s", parsed_shape, shape) | |||
| return parsed_shape | |||
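| # Examples (assuming str_to_slice_or_int maps '0' -> 0 and '1:3' -> slice(1, 3)): | |||
| # | |||
| #     parse_shape(None)       # -> None | |||
| #     parse_shape('[0,1:3]')  # -> (0, slice(1, 3)) | |||
| #     parse_shape('0,1')      # raises DebuggerParamValueError (missing brackets) | |||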
| def _retrieve_watchpoint(self, filter_condition): | |||
| """ | |||
| Retrieve watchpoint. | |||
| Args: | |||
| filter_condition (dict): Filter condition. | |||
| - watch_point_id (int): The id of the watchpoint. If not given, return all watchpoints. | |||
| - name (str): The name of single node. | |||
| - single_node (bool): If False, return the sub-layer of single node. If True, return | |||
| the node list from root node to single node. | |||
| Returns: | |||
| dict, the watchpoint list or the related graph. | |||
| """ | |||
| watchpoint_id = filter_condition.get('watch_point_id') | |||
| if watchpoint_id is None: | |||
| reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get() | |||
| log.debug("Get condition of watchpoints.") | |||
| else: | |||
| reply = self._retrieve_node(filter_condition) | |||
| log.debug("Get graph of %d-th watchpoint.", watchpoint_id) | |||
| return reply | |||
| def _retrieve_watchpoint_hit(self, filter_condition): | |||
| """ | |||
| Retrieve watchpoint hit. | |||
| Args: | |||
| filter_condition (dict): Filter condition. | |||
| - name (str): The name of single node. | |||
| - single_node (bool): If False, return the sub-layer of single node. If True, return | |||
| the node list from root node to single node. | |||
| Returns: | |||
| dict, the watchpoint hit list or the related tensor history and graph. | |||
| """ | |||
| node_name = filter_condition.get('name') | |||
| # get watchpoint hit list | |||
| if node_name is None: | |||
| reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get() | |||
| return reply | |||
| self._validate_leaf_name(node_name) | |||
| # get tensor history | |||
| reply = self._get_tensor_history(node_name) | |||
| log.debug("Get tensor history for watchpoint hit node.") | |||
| # get single graph | |||
| if filter_condition.get('single_node'): | |||
| graph = self._get_nodes_info(filter_condition) | |||
| reply.update(graph) | |||
| log.debug("Get tensor history for watchpoint hit node.") | |||
| return reply | |||
| def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None): | |||
| """ | |||
| Create watchpoint. | |||
| Args: | |||
| watch_condition (dict): The watch condition. | |||
| - condition (str): Accept `INF` or `NAN`. | |||
| - param (list[float]): Not defined yet. | |||
| watch_nodes (list[str]): The list of node names. | |||
| watch_point_id (int): The id of watchpoint. | |||
| Returns: | |||
| dict, the id of new watchpoint. | |||
| """ | |||
| log.info("Received create watchpoint request. WatchCondition: %s", watch_condition) | |||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | |||
| if metadata_stream.state != ServerStatus.WAITING.value: | |||
| log.error("Failed to create watchpoint as the MindSpore is not in waiting state.") | |||
| raise DebuggerCreateWatchPointError( | |||
| "Failed to create watchpoint as the MindSpore is not in waiting state." | |||
| ) | |||
| if metadata_stream.backend == 'GPU' and watch_condition.get('condition') == 'OVERFLOW': | |||
| log.error("GPU doesn't support OVERFLOW watch condition.") | |||
| raise DebuggerParamValueError("GPU doesn't support OVERFLOW watch condition.") | |||
| watch_nodes = self._get_node_basic_infos(watch_nodes) | |||
| watch_point_id = self.cache_store.get_stream_handler(Streams.WATCHPOINT).create_watchpoint( | |||
| watch_condition, watch_nodes, watch_point_id) | |||
| log.info("Create watchpoint %d", watch_point_id) | |||
| return {'id': watch_point_id} | |||
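| # Illustrative request (node names are hypothetical; `condition` accepts | |||
| # `INF` or `NAN` as documented above): | |||
| # | |||
| #     reply = server.create_watchpoint( | |||
| #         watch_condition={'condition': 'INF'}, | |||
| #         watch_nodes=['Default/network/conv1']) | |||
| #     new_id = reply['id'] | |||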
| def update_watchpoint(self, watch_point_id, watch_nodes, mode, name=None): | |||
| """ | |||
| Update watchpoint. | |||
| Args: | |||
| watch_point_id (int): The id of watchpoint. | |||
| watch_nodes (list[str]): The list of node names. | |||
| mode (int): The update operation on nodes. 0 to remove nodes from the watch nodes; | |||
| 1 to add nodes to the watch nodes. | |||
| name (str): The search name. Default: None. | |||
| Returns: | |||
| dict, empty response. | |||
| """ | |||
| if self.cache_store.get_stream_handler( | |||
| Streams.METADATA).state != ServerStatus.WAITING.value: | |||
| log.error("Failed to update watchpoint as the MindSpore is not in waiting state.") | |||
| raise DebuggerUpdateWatchPointError( | |||
| "Failed to update watchpoint as the MindSpore is not in waiting state." | |||
| ) | |||
| # validate | |||
| if not watch_nodes or not watch_point_id: | |||
| log.error("Invalid parameter for update watchpoint.") | |||
| raise DebuggerParamValueError("Invalid parameter for update watchpoint.") | |||
| # update watch node | |||
| if name is not None: | |||
| watch_nodes = self._get_watch_nodes_by_search(watch_nodes) | |||
| elif mode == 1: | |||
| watch_nodes = self._get_node_basic_infos(watch_nodes) | |||
| self.cache_store.get_stream_handler(Streams.WATCHPOINT).update_watchpoint( | |||
| watch_point_id, watch_nodes, mode) | |||
| log.info("Update watchpoint with id: %d", watch_point_id) | |||
| return {} | |||
| def _get_watch_nodes_by_search(self, watch_nodes): | |||
| """Get watched leaf nodes by search name.""" | |||
| watched_leaf_nodes = [] | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| for search_name in watch_nodes: | |||
| search_nodes = graph_stream.get_searched_node_list() | |||
| search_node_names = [ | |||
| NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type) | |||
| for node in search_nodes | |||
| if node.name.startswith(search_name)] | |||
| watched_leaf_nodes.extend(search_node_names) | |||
| log.debug("Update nodes: %s", watched_leaf_nodes) | |||
| return watched_leaf_nodes | |||
| def delete_watchpoint(self, watch_point_id): | |||
| """ | |||
| Delete watchpoint. | |||
| Args: | |||
| watch_point_id (int): The id of watchpoint. | |||
| Returns: | |||
| dict, empty response. | |||
| """ | |||
| if self.cache_store.get_stream_handler( | |||
| Streams.METADATA).state != ServerStatus.WAITING.value: | |||
| log.error("Failed to delete watchpoint as the MindSpore is not in waiting state.") | |||
| raise DebuggerDeleteWatchPointError( | |||
| "Failed to delete watchpoint as the MindSpore is not in waiting state." | |||
| ) | |||
| self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint( | |||
| watch_point_id) | |||
| log.info("Delete watchpoint with id: %d", watch_point_id) | |||
| return {} | |||
| def _get_node_basic_infos(self, node_names): | |||
| """Get node info according to node names.""" | |||
| if not node_names: | |||
| return [] | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| node_infos = [] | |||
| for node_name in node_names: | |||
| node_type = graph_stream.get_node_type(node_name) | |||
| # to be optimized later | |||
| if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value: | |||
| sub_nodes = graph_stream.get_nodes(node_name) | |||
| sub_infos = [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type) | |||
| for node in sub_nodes] | |||
| node_infos.extend(sub_infos) | |||
| continue | |||
| full_name = graph_stream.get_full_name(node_name) | |||
| node_infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type)) | |||
| return node_infos | |||
| def control(self, params=None): | |||
| """ | |||
| Control the training process. | |||
| Args: | |||
| params (dict): The control params. | |||
| - mode (str): Acceptable control command, including `continue`, | |||
| `pause` and `terminate`. | |||
| - level (str): The control granularity, `node` level or `step` level. | |||
| Default: `step`. | |||
| - steps (int): Specify the steps that training should run. | |||
| Used when `level` is `step`. | |||
| - name (str): Specify the name of the node. Used when `level` is `node`. | |||
| Returns: | |||
| dict, the response. | |||
| """ | |||
| log.info("Receive control request: %s.", params) | |||
| mode = params.get('mode') | |||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | |||
| if mode == 'continue': | |||
| reply = self._continue(metadata_stream, params) | |||
| elif mode in ['pause', 'terminate']: | |||
| mode_mapping = { | |||
| 'pause': self._pause, | |||
| 'terminate': self._terminate | |||
| } | |||
| reply = mode_mapping.get(mode)(metadata_stream) | |||
| else: | |||
| log.error("Invalid control mode %s", mode) | |||
| raise DebuggerParamValueError("Invalid control mode.") | |||
| return reply | |||
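| # Illustrative control requests (node name is hypothetical): | |||
| # | |||
| #     server.control({'mode': 'continue', 'level': 'step', 'steps': 2}) | |||
| #     server.control({'mode': 'continue', 'level': 'node', 'name': 'Default/conv1'}) | |||
| #     server.control({'mode': 'pause'}) | |||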
| def _continue(self, metadata_stream, params): | |||
| """ | |||
| Send RunCMD to MindSpore. | |||
| Args: | |||
| metadata_stream (MetadataHandler): The metadata stream handler. | |||
| params (dict): The control params. | |||
| """ | |||
| if metadata_stream.state != ServerStatus.WAITING.value: | |||
| log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state) | |||
| raise DebuggerContinueError( | |||
| "MindSpore is not ready to run or is running currently." | |||
| ) | |||
| metadata_stream.state = ServerStatus.RUNNING.value | |||
| current_state = ServerStatus.RUNNING.value | |||
| try: | |||
| event = self._construct_run_event(params) | |||
| self._send_watchpoints() | |||
| self.cache_store.put_command(event) | |||
| except MindInsightException as err: | |||
| log.error("Failed to send run event.") | |||
| log.exception(err) | |||
| current_state = ServerStatus.WAITING.value | |||
| metadata_stream.state = current_state | |||
| raise DebuggerContinueError("Failed to send run command.") | |||
| else: | |||
| log.debug("Send the RunCMD to command queue.") | |||
| return {'metadata': {'state': current_state}} | |||
| def _validate_node_type(self, node_name): | |||
| """Check the node type in node control.""" | |||
| if not node_name: | |||
| return | |||
| node_type = self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name) | |||
| unsupported_types = [item.value for item in list(NodeTypeEnum)] | |||
| if node_type in unsupported_types: | |||
| log.error("Invalid node type. %s", node_name) | |||
| raise DebuggerParamValueError(f"The type of node {node_name} is unsupported for " | |||
| "continue to command.") | |||
| def _construct_run_event(self, params): | |||
| """ | |||
| Construct run cmd from input control params. | |||
| Args: | |||
| params (dict): The control params. | |||
| - level (str): The control granularity, `node` level or `step` level. | |||
| Default: `step`. | |||
| - steps (int): Specify the steps that training should run. | |||
| Used when `level` is `step`. | |||
| - full_name (str): Specify the name of the node. Used when `level` is `node`. | |||
| Returns: | |||
| EventReply, control event with run command. | |||
| """ | |||
| level = params.get('level', 'step') | |||
| event = get_ack_reply() | |||
| if level == 'step': | |||
| steps = params.get('steps') | |||
| if not steps: | |||
| steps = 1 | |||
| run_cmd = RunCMD(run_level='step', run_steps=steps) | |||
| elif level == 'node': | |||
| self._validate_node_type(params.get('name')) | |||
| name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name( | |||
| params['name']) | |||
| if not name: | |||
| name = '' | |||
| run_cmd = RunCMD(run_level='node', node_name=name) | |||
| else: | |||
| log.error("Invalid Value. `level` should be `step` or `node`. Got %s", level) | |||
| raise DebuggerParamValueError("level` should be `step` or `node`") | |||
| event.run_cmd.CopyFrom(run_cmd) | |||
| log.debug("Construct run event. %s", event) | |||
| return event | |||
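| # For example, a two-step run request ({'level': 'step', 'steps': 2}) yields | |||
| # an event whose text format is roughly: | |||
| # | |||
| #     run_cmd { run_level: "step" run_steps: 2 } | |||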
| def _send_watchpoints(self): | |||
| """Set watchpoints.""" | |||
| watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| watchpoints = watchpoint_stream.get(filter_condition=True).get('watch_points') | |||
| if watchpoints: | |||
| for watchpoint in watchpoints: | |||
| event = get_ack_reply() | |||
| event.set_cmd.CopyFrom(watchpoint) | |||
| self.cache_store.put_command(event) | |||
| watchpoint_stream.sync_set_cmd() | |||
| log.debug("Send SetCMD to MindSpore. %s", event) | |||
| def _pause(self, metadata_stream): | |||
| """ | |||
| Pause the training. | |||
| Args: | |||
| metadata_stream (MetadataHandler): The metadata stream handler. | |||
| """ | |||
| if metadata_stream.state != ServerStatus.RUNNING.value: | |||
| log.error("The MindSpore is not running.") | |||
| raise DebuggerPauseError("The MindSpore is not running.") | |||
| metadata_stream.state = ServerStatus.WAITING.value | |||
| event = get_ack_reply() | |||
| event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0)) | |||
| self.cache_store.put_command(event) | |||
| log.debug("Send the Pause command") | |||
| return {'metadata': {'state': 'waiting'}} | |||
| def _terminate(self, metadata_stream): | |||
| """ | |||
| Terminate the training. | |||
| Args: | |||
| metadata_stream (MetadataHandler): The metadata stream handler. | |||
| """ | |||
| metadata_stream.state = ServerStatus.PENDING.value | |||
| event = get_ack_reply() | |||
| event.exit = True | |||
| self.cache_store.put_command(event) | |||
| log.debug("Send the ExitCMD.") | |||
| return {'metadata': {'state': 'pending'}} | |||
| def retrieve_node_by_bfs(self, node_name, ascend=False): | |||
| """Get the graph and tensor history of the next node name according to node_name.""" | |||
| log.info("Retrieve node <%s> by bfs, `ascend` is :%s", | |||
| node_name, ascend) | |||
| reply = {} | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend) | |||
| # no next node | |||
| if next_node_name is None: | |||
| return reply | |||
| # add graph and tensor history for next node | |||
| filter_condition = { | |||
| 'name': next_node_name, | |||
| 'single_node': True | |||
| } | |||
| search_graph = self._get_nodes_info(filter_condition) | |||
| tensor_history = self._get_tensor_history(next_node_name) | |||
| reply = {'name': next_node_name} | |||
| reply.update(search_graph) | |||
| reply.update(tensor_history) | |||
| return reply | |||
| @@ -0,0 +1,113 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| syntax = "proto3"; | |||
| package debugger; | |||
| import "mindinsight/debugger/proto/ms_graph.proto"; | |||
| service EventListener { | |||
| rpc WaitCMD (Metadata) returns (EventReply) {}; | |||
| rpc SendMetadata (Metadata) returns (EventReply) {}; | |||
| rpc SendGraph (stream Chunk) returns (EventReply) {}; | |||
| rpc SendTensors (stream TensorProto) returns (EventReply) {}; | |||
| rpc SendWatchpointHits (stream WatchpointHit) returns (EventReply) {}; | |||
| } | |||
| message Metadata { | |||
| string device_name = 1; | |||
| int32 cur_step = 2; | |||
| // the backend, either 'GPU' or 'Ascend' | |||
| string backend = 3; | |||
| // the full name of current node | |||
| string cur_node = 4; | |||
| // indicate whether training is done. | |||
| bool training_done = 5; | |||
| } | |||
| message Chunk { | |||
| bytes buffer = 1; | |||
| } | |||
| message EventReply { | |||
| enum Status { | |||
| OK = 0; | |||
| FAILED = 1; | |||
| PENDING = 2; | |||
| } | |||
| Status status = 1; | |||
| oneof cmd { | |||
| bool exit = 2; | |||
| RunCMD run_cmd = 3; | |||
| SetCMD set_cmd = 4; | |||
| ViewCMD view_cmd = 5; | |||
| } | |||
| } | |||
| message RunCMD { | |||
| // running level. 'step' or 'node' | |||
| string run_level = 1; | |||
| oneof cmd { | |||
| int32 run_steps = 2; | |||
| // the full name of next node | |||
| string node_name = 3; | |||
| } | |||
| } | |||
| message SetCMD { | |||
| repeated WatchNode watch_nodes = 1; | |||
| WatchCondition watch_condition = 2; | |||
| bool delete = 3; | |||
| int32 id = 4; | |||
| } | |||
| message ViewCMD { | |||
| repeated TensorProto tensors = 1; | |||
| } | |||
| message WatchCondition { | |||
| enum Condition { | |||
| nan = 0; | |||
| inf = 1; | |||
| overflow = 2; | |||
| max_gt = 3; | |||
| max_lt = 4; | |||
| min_gt = 5; | |||
| min_lt = 6; | |||
| max_min_gt = 7; | |||
| max_min_lt = 8; | |||
| mean_gt = 9; | |||
| mean_lt = 10; | |||
| } | |||
| Condition condition = 1; | |||
| float value = 2; // for a between condition, there will be two values | |||
| } | |||
| message WatchNode { | |||
| string node_name = 1; | |||
| string node_type = 2; | |||
| } | |||
| message WatchpointHit { | |||
| TensorProto tensor = 1; | |||
| WatchCondition watch_condition = 2; | |||
| int32 id = 3; | |||
| } | |||
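| // Usage sketch (Python, via the generated debug_grpc_pb2 module; a minimal | |||
| // illustration of the `cmd` oneof in EventReply): | |||
| // | |||
| //     reply = EventReply() | |||
| //     reply.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=1)) | |||
| //     assert reply.WhichOneof('cmd') == 'run_cmd' | |||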
| @@ -0,0 +1,683 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # Generated by the protocol buffer compiler. DO NOT EDIT! | |||
| # source: mindinsight/debugger/proto/debug_grpc.proto | |||
| from google.protobuf import descriptor as _descriptor | |||
| from google.protobuf import message as _message | |||
| from google.protobuf import reflection as _reflection | |||
| from google.protobuf import symbol_database as _symbol_database | |||
| # @@protoc_insertion_point(imports) | |||
| _sym_db = _symbol_database.Default() | |||
| from mindinsight.debugger.proto import ms_graph_pb2 as mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2 | |||
| DESCRIPTOR = _descriptor.FileDescriptor( | |||
| name='mindinsight/debugger/proto/debug_grpc.proto', | |||
| package='debugger', | |||
| syntax='proto3', | |||
| serialized_options=None, | |||
| serialized_pb=b'\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"k\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\"\x17\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\"\xec\x01\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\xee\x01\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\"\x95\x01\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x07\n\x03inf\x10\x01\x12\x0c\n\x08overflow\x10\x02\x12\n\n\x06max_gt\x10\x03\x12\n\n\x06max_lt\x10\x04\x12\n\n\x06min_gt\x10\x05\x12\n\n\x06min_lt\x10\x06\x12\x0e\n\nmax_min_gt\x10\x07\x12\x0e\n\nmax_min_lt\x10\x08\x12\x0b\n\x07mean_gt\x10\t\x12\x0b\n\x07mean_lt\x10\n\"1\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\"u\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x32\xc3\x02\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3' | |||
| , | |||
| dependencies=[mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.DESCRIPTOR,]) | |||
| _EVENTREPLY_STATUS = _descriptor.EnumDescriptor( | |||
| name='Status', | |||
| full_name='debugger.EventReply.Status', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| values=[ | |||
| _descriptor.EnumValueDescriptor( | |||
| name='OK', index=0, number=0, | |||
| serialized_options=None, | |||
| type=None), | |||
| _descriptor.EnumValueDescriptor( | |||
| name='FAILED', index=1, number=1, | |||
| serialized_options=None, | |||
| type=None), | |||
| _descriptor.EnumValueDescriptor( | |||
| name='PENDING', index=2, number=2, | |||
| serialized_options=None, | |||
| type=None), | |||
| ], | |||
| containing_type=None, | |||
| serialized_options=None, | |||
| serialized_start=423, | |||
| serialized_end=464, | |||
| ) | |||
| _sym_db.RegisterEnumDescriptor(_EVENTREPLY_STATUS) | |||
| _WATCHCONDITION_CONDITION = _descriptor.EnumDescriptor( | |||
| name='Condition', | |||
| full_name='debugger.WatchCondition.Condition', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| values=[ | |||
| _descriptor.EnumValueDescriptor( | |||
| name='nan', index=0, number=0, | |||
| serialized_options=None, | |||
| type=None), | |||
| _descriptor.EnumValueDescriptor( | |||
| name='inf', index=1, number=1, | |||
| serialized_options=None, | |||
| type=None), | |||
| _descriptor.EnumValueDescriptor( | |||
| name='overflow', index=2, number=2, | |||
| serialized_options=None, | |||
| type=None), | |||
| _descriptor.EnumValueDescriptor( | |||
| name='max_gt', index=3, number=3, | |||
| serialized_options=None, | |||
| type=None), | |||
| _descriptor.EnumValueDescriptor( | |||
| name='max_lt', index=4, number=4, | |||
| serialized_options=None, | |||
| type=None), | |||
| _descriptor.EnumValueDescriptor( | |||
| name='min_gt', index=5, number=5, | |||
| serialized_options=None, | |||
| type=None), | |||
| _descriptor.EnumValueDescriptor( | |||
| name='min_lt', index=6, number=6, | |||
| serialized_options=None, | |||
| type=None), | |||
| _descriptor.EnumValueDescriptor( | |||
| name='max_min_gt', index=7, number=7, | |||
| serialized_options=None, | |||
| type=None), | |||
| _descriptor.EnumValueDescriptor( | |||
| name='max_min_lt', index=8, number=8, | |||
| serialized_options=None, | |||
| type=None), | |||
| _descriptor.EnumValueDescriptor( | |||
| name='mean_gt', index=9, number=9, | |||
| serialized_options=None, | |||
| type=None), | |||
| _descriptor.EnumValueDescriptor( | |||
| name='mean_lt', index=10, number=10, | |||
| serialized_options=None, | |||
| type=None), | |||
| ], | |||
| containing_type=None, | |||
| serialized_options=None, | |||
| serialized_start=824, | |||
| serialized_end=973, | |||
| ) | |||
| _sym_db.RegisterEnumDescriptor(_WATCHCONDITION_CONDITION) | |||
| _METADATA = _descriptor.Descriptor( | |||
| name='Metadata', | |||
| full_name='debugger.Metadata', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='device_name', full_name='debugger.Metadata.device_name', index=0, | |||
| number=1, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='cur_step', full_name='debugger.Metadata.cur_step', index=1, | |||
| number=2, type=5, cpp_type=1, label=1, | |||
| has_default_value=False, default_value=0, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='backend', full_name='debugger.Metadata.backend', index=2, | |||
| number=3, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='cur_node', full_name='debugger.Metadata.cur_node', index=3, | |||
| number=4, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='training_done', full_name='debugger.Metadata.training_done', index=4, | |||
| number=5, type=8, cpp_type=7, label=1, | |||
| has_default_value=False, default_value=False, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=100, | |||
| serialized_end=207, | |||
| ) | |||
| _CHUNK = _descriptor.Descriptor( | |||
| name='Chunk', | |||
| full_name='debugger.Chunk', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='buffer', full_name='debugger.Chunk.buffer', index=0, | |||
| number=1, type=12, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"", | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=209, | |||
| serialized_end=232, | |||
| ) | |||
| _EVENTREPLY = _descriptor.Descriptor( | |||
| name='EventReply', | |||
| full_name='debugger.EventReply', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='status', full_name='debugger.EventReply.status', index=0, | |||
| number=1, type=14, cpp_type=8, label=1, | |||
| has_default_value=False, default_value=0, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='exit', full_name='debugger.EventReply.exit', index=1, | |||
| number=2, type=8, cpp_type=7, label=1, | |||
| has_default_value=False, default_value=False, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='run_cmd', full_name='debugger.EventReply.run_cmd', index=2, | |||
| number=3, type=11, cpp_type=10, label=1, | |||
| has_default_value=False, default_value=None, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='set_cmd', full_name='debugger.EventReply.set_cmd', index=3, | |||
| number=4, type=11, cpp_type=10, label=1, | |||
| has_default_value=False, default_value=None, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='view_cmd', full_name='debugger.EventReply.view_cmd', index=4, | |||
| number=5, type=11, cpp_type=10, label=1, | |||
| has_default_value=False, default_value=None, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| _EVENTREPLY_STATUS, | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| _descriptor.OneofDescriptor( | |||
| name='cmd', full_name='debugger.EventReply.cmd', | |||
| index=0, containing_type=None, fields=[]), | |||
| ], | |||
| serialized_start=235, | |||
| serialized_end=471, | |||
| ) | |||
| _RUNCMD = _descriptor.Descriptor( | |||
| name='RunCMD', | |||
| full_name='debugger.RunCMD', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='run_level', full_name='debugger.RunCMD.run_level', index=0, | |||
| number=1, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='run_steps', full_name='debugger.RunCMD.run_steps', index=1, | |||
| number=2, type=5, cpp_type=1, label=1, | |||
| has_default_value=False, default_value=0, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='node_name', full_name='debugger.RunCMD.node_name', index=2, | |||
| number=3, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| _descriptor.OneofDescriptor( | |||
| name='cmd', full_name='debugger.RunCMD.cmd', | |||
| index=0, containing_type=None, fields=[]), | |||
| ], | |||
| serialized_start=473, | |||
| serialized_end=549, | |||
| ) | |||
| _SETCMD = _descriptor.Descriptor( | |||
| name='SetCMD', | |||
| full_name='debugger.SetCMD', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='watch_nodes', full_name='debugger.SetCMD.watch_nodes', index=0, | |||
| number=1, type=11, cpp_type=10, label=3, | |||
| has_default_value=False, default_value=[], | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='watch_condition', full_name='debugger.SetCMD.watch_condition', index=1, | |||
| number=2, type=11, cpp_type=10, label=1, | |||
| has_default_value=False, default_value=None, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='delete', full_name='debugger.SetCMD.delete', index=2, | |||
| number=3, type=8, cpp_type=7, label=1, | |||
| has_default_value=False, default_value=False, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='id', full_name='debugger.SetCMD.id', index=3, | |||
| number=4, type=5, cpp_type=1, label=1, | |||
| has_default_value=False, default_value=0, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=552, | |||
| serialized_end=681, | |||
| ) | |||
| _VIEWCMD = _descriptor.Descriptor( | |||
| name='ViewCMD', | |||
| full_name='debugger.ViewCMD', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='tensors', full_name='debugger.ViewCMD.tensors', index=0, | |||
| number=1, type=11, cpp_type=10, label=3, | |||
| has_default_value=False, default_value=[], | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=683, | |||
| serialized_end=732, | |||
| ) | |||
| _WATCHCONDITION = _descriptor.Descriptor( | |||
| name='WatchCondition', | |||
| full_name='debugger.WatchCondition', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='condition', full_name='debugger.WatchCondition.condition', index=0, | |||
| number=1, type=14, cpp_type=8, label=1, | |||
| has_default_value=False, default_value=0, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='value', full_name='debugger.WatchCondition.value', index=1, | |||
| number=2, type=2, cpp_type=6, label=1, | |||
| has_default_value=False, default_value=float(0), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| _WATCHCONDITION_CONDITION, | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=735, | |||
| serialized_end=973, | |||
| ) | |||
| _WATCHNODE = _descriptor.Descriptor( | |||
| name='WatchNode', | |||
| full_name='debugger.WatchNode', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='node_name', full_name='debugger.WatchNode.node_name', index=0, | |||
| number=1, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='node_type', full_name='debugger.WatchNode.node_type', index=1, | |||
| number=2, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=975, | |||
| serialized_end=1024, | |||
| ) | |||
| _WATCHPOINTHIT = _descriptor.Descriptor( | |||
| name='WatchpointHit', | |||
| full_name='debugger.WatchpointHit', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='tensor', full_name='debugger.WatchpointHit.tensor', index=0, | |||
| number=1, type=11, cpp_type=10, label=1, | |||
| has_default_value=False, default_value=None, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='watch_condition', full_name='debugger.WatchpointHit.watch_condition', index=1, | |||
| number=2, type=11, cpp_type=10, label=1, | |||
| has_default_value=False, default_value=None, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| _descriptor.FieldDescriptor( | |||
| name='id', full_name='debugger.WatchpointHit.id', index=2, | |||
| number=3, type=5, cpp_type=1, label=1, | |||
| has_default_value=False, default_value=0, | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=1026, | |||
| serialized_end=1143, | |||
| ) | |||
| _EVENTREPLY.fields_by_name['status'].enum_type = _EVENTREPLY_STATUS | |||
| _EVENTREPLY.fields_by_name['run_cmd'].message_type = _RUNCMD | |||
| _EVENTREPLY.fields_by_name['set_cmd'].message_type = _SETCMD | |||
| _EVENTREPLY.fields_by_name['view_cmd'].message_type = _VIEWCMD | |||
| _EVENTREPLY_STATUS.containing_type = _EVENTREPLY | |||
| _EVENTREPLY.oneofs_by_name['cmd'].fields.append( | |||
| _EVENTREPLY.fields_by_name['exit']) | |||
| _EVENTREPLY.fields_by_name['exit'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd'] | |||
| _EVENTREPLY.oneofs_by_name['cmd'].fields.append( | |||
| _EVENTREPLY.fields_by_name['run_cmd']) | |||
| _EVENTREPLY.fields_by_name['run_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd'] | |||
| _EVENTREPLY.oneofs_by_name['cmd'].fields.append( | |||
| _EVENTREPLY.fields_by_name['set_cmd']) | |||
| _EVENTREPLY.fields_by_name['set_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd'] | |||
| _EVENTREPLY.oneofs_by_name['cmd'].fields.append( | |||
| _EVENTREPLY.fields_by_name['view_cmd']) | |||
| _EVENTREPLY.fields_by_name['view_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd'] | |||
| _RUNCMD.oneofs_by_name['cmd'].fields.append( | |||
| _RUNCMD.fields_by_name['run_steps']) | |||
| _RUNCMD.fields_by_name['run_steps'].containing_oneof = _RUNCMD.oneofs_by_name['cmd'] | |||
| _RUNCMD.oneofs_by_name['cmd'].fields.append( | |||
| _RUNCMD.fields_by_name['node_name']) | |||
| _RUNCMD.fields_by_name['node_name'].containing_oneof = _RUNCMD.oneofs_by_name['cmd'] | |||
| _SETCMD.fields_by_name['watch_nodes'].message_type = _WATCHNODE | |||
| _SETCMD.fields_by_name['watch_condition'].message_type = _WATCHCONDITION | |||
| _VIEWCMD.fields_by_name['tensors'].message_type = mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2._TENSORPROTO | |||
| _WATCHCONDITION.fields_by_name['condition'].enum_type = _WATCHCONDITION_CONDITION | |||
| _WATCHCONDITION_CONDITION.containing_type = _WATCHCONDITION | |||
| _WATCHPOINTHIT.fields_by_name['tensor'].message_type = mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2._TENSORPROTO | |||
| _WATCHPOINTHIT.fields_by_name['watch_condition'].message_type = _WATCHCONDITION | |||
| DESCRIPTOR.message_types_by_name['Metadata'] = _METADATA | |||
| DESCRIPTOR.message_types_by_name['Chunk'] = _CHUNK | |||
| DESCRIPTOR.message_types_by_name['EventReply'] = _EVENTREPLY | |||
| DESCRIPTOR.message_types_by_name['RunCMD'] = _RUNCMD | |||
| DESCRIPTOR.message_types_by_name['SetCMD'] = _SETCMD | |||
| DESCRIPTOR.message_types_by_name['ViewCMD'] = _VIEWCMD | |||
| DESCRIPTOR.message_types_by_name['WatchCondition'] = _WATCHCONDITION | |||
| DESCRIPTOR.message_types_by_name['WatchNode'] = _WATCHNODE | |||
| DESCRIPTOR.message_types_by_name['WatchpointHit'] = _WATCHPOINTHIT | |||
| _sym_db.RegisterFileDescriptor(DESCRIPTOR) | |||
| Metadata = _reflection.GeneratedProtocolMessageType('Metadata', (_message.Message,), { | |||
| 'DESCRIPTOR' : _METADATA, | |||
| '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' | |||
| # @@protoc_insertion_point(class_scope:debugger.Metadata) | |||
| }) | |||
| _sym_db.RegisterMessage(Metadata) | |||
| Chunk = _reflection.GeneratedProtocolMessageType('Chunk', (_message.Message,), { | |||
| 'DESCRIPTOR' : _CHUNK, | |||
| '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' | |||
| # @@protoc_insertion_point(class_scope:debugger.Chunk) | |||
| }) | |||
| _sym_db.RegisterMessage(Chunk) | |||
| EventReply = _reflection.GeneratedProtocolMessageType('EventReply', (_message.Message,), { | |||
| 'DESCRIPTOR' : _EVENTREPLY, | |||
| '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' | |||
| # @@protoc_insertion_point(class_scope:debugger.EventReply) | |||
| }) | |||
| _sym_db.RegisterMessage(EventReply) | |||
| RunCMD = _reflection.GeneratedProtocolMessageType('RunCMD', (_message.Message,), { | |||
| 'DESCRIPTOR' : _RUNCMD, | |||
| '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' | |||
| # @@protoc_insertion_point(class_scope:debugger.RunCMD) | |||
| }) | |||
| _sym_db.RegisterMessage(RunCMD) | |||
| SetCMD = _reflection.GeneratedProtocolMessageType('SetCMD', (_message.Message,), { | |||
| 'DESCRIPTOR' : _SETCMD, | |||
| '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' | |||
| # @@protoc_insertion_point(class_scope:debugger.SetCMD) | |||
| }) | |||
| _sym_db.RegisterMessage(SetCMD) | |||
| ViewCMD = _reflection.GeneratedProtocolMessageType('ViewCMD', (_message.Message,), { | |||
| 'DESCRIPTOR' : _VIEWCMD, | |||
| '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' | |||
| # @@protoc_insertion_point(class_scope:debugger.ViewCMD) | |||
| }) | |||
| _sym_db.RegisterMessage(ViewCMD) | |||
| WatchCondition = _reflection.GeneratedProtocolMessageType('WatchCondition', (_message.Message,), { | |||
| 'DESCRIPTOR' : _WATCHCONDITION, | |||
| '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' | |||
| # @@protoc_insertion_point(class_scope:debugger.WatchCondition) | |||
| }) | |||
| _sym_db.RegisterMessage(WatchCondition) | |||
| WatchNode = _reflection.GeneratedProtocolMessageType('WatchNode', (_message.Message,), { | |||
| 'DESCRIPTOR' : _WATCHNODE, | |||
| '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' | |||
| # @@protoc_insertion_point(class_scope:debugger.WatchNode) | |||
| }) | |||
| _sym_db.RegisterMessage(WatchNode) | |||
| WatchpointHit = _reflection.GeneratedProtocolMessageType('WatchpointHit', (_message.Message,), { | |||
| 'DESCRIPTOR' : _WATCHPOINTHIT, | |||
| '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' | |||
| # @@protoc_insertion_point(class_scope:debugger.WatchpointHit) | |||
| }) | |||
| _sym_db.RegisterMessage(WatchpointHit) | |||
| _EVENTLISTENER = _descriptor.ServiceDescriptor( | |||
| name='EventListener', | |||
| full_name='debugger.EventListener', | |||
| file=DESCRIPTOR, | |||
| index=0, | |||
| serialized_options=None, | |||
| serialized_start=1146, | |||
| serialized_end=1469, | |||
| methods=[ | |||
| _descriptor.MethodDescriptor( | |||
| name='WaitCMD', | |||
| full_name='debugger.EventListener.WaitCMD', | |||
| index=0, | |||
| containing_service=None, | |||
| input_type=_METADATA, | |||
| output_type=_EVENTREPLY, | |||
| serialized_options=None, | |||
| ), | |||
| _descriptor.MethodDescriptor( | |||
| name='SendMetadata', | |||
| full_name='debugger.EventListener.SendMetadata', | |||
| index=1, | |||
| containing_service=None, | |||
| input_type=_METADATA, | |||
| output_type=_EVENTREPLY, | |||
| serialized_options=None, | |||
| ), | |||
| _descriptor.MethodDescriptor( | |||
| name='SendGraph', | |||
| full_name='debugger.EventListener.SendGraph', | |||
| index=2, | |||
| containing_service=None, | |||
| input_type=_CHUNK, | |||
| output_type=_EVENTREPLY, | |||
| serialized_options=None, | |||
| ), | |||
| _descriptor.MethodDescriptor( | |||
| name='SendTensors', | |||
| full_name='debugger.EventListener.SendTensors', | |||
| index=3, | |||
| containing_service=None, | |||
| input_type=mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2._TENSORPROTO, | |||
| output_type=_EVENTREPLY, | |||
| serialized_options=None, | |||
| ), | |||
| _descriptor.MethodDescriptor( | |||
| name='SendWatchpointHits', | |||
| full_name='debugger.EventListener.SendWatchpointHits', | |||
| index=4, | |||
| containing_service=None, | |||
| input_type=_WATCHPOINTHIT, | |||
| output_type=_EVENTREPLY, | |||
| serialized_options=None, | |||
| ), | |||
| ]) | |||
| _sym_db.RegisterServiceDescriptor(_EVENTLISTENER) | |||
| DESCRIPTOR.services_by_name['EventListener'] = _EVENTLISTENER | |||
| # @@protoc_insertion_point(module_scope) | |||
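| # Illustrative sketch (not part of the generated file): fields of the `cmd` | |||
| # oneof are mutually exclusive, so setting one command clears the others. | |||
| # `run_steps = 1` assumes an integer step count. | |||
| def _example_event_reply(): | |||
|     reply = EventReply() | |||
|     reply.run_cmd.run_steps = 1  # selects `run_cmd` within the oneof | |||
|     reply.exit = True  # replaces it, so only `exit` remains set | |||
|     return reply.WhichOneof('cmd')  # returns 'exit' | |||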
| @@ -0,0 +1,193 @@ | |||
| # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! | |||
| import grpc | |||
| from mindinsight.debugger.proto import debug_grpc_pb2 as mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2 | |||
| from mindinsight.debugger.proto import ms_graph_pb2 as mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2 | |||
| class EventListenerStub(object): | |||
| """Missing associated documentation comment in .proto file""" | |||
| def __init__(self, channel): | |||
| """Constructor. | |||
| Args: | |||
| channel: A grpc.Channel. | |||
| """ | |||
| self.WaitCMD = channel.unary_unary( | |||
| '/debugger.EventListener/WaitCMD', | |||
| request_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, | |||
| response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| ) | |||
| self.SendMetadata = channel.unary_unary( | |||
| '/debugger.EventListener/SendMetadata', | |||
| request_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, | |||
| response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| ) | |||
| self.SendGraph = channel.stream_unary( | |||
| '/debugger.EventListener/SendGraph', | |||
| request_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString, | |||
| response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| ) | |||
| self.SendTensors = channel.stream_unary( | |||
| '/debugger.EventListener/SendTensors', | |||
| request_serializer=mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.SerializeToString, | |||
| response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| ) | |||
| self.SendWatchpointHits = channel.stream_unary( | |||
| '/debugger.EventListener/SendWatchpointHits', | |||
| request_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.SerializeToString, | |||
| response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| ) | |||
| class EventListenerServicer(object): | |||
| """Missing associated documentation comment in .proto file""" | |||
| def WaitCMD(self, request, context): | |||
| """Missing associated documentation comment in .proto file""" | |||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||
| context.set_details('Method not implemented!') | |||
| raise NotImplementedError('Method not implemented!') | |||
| def SendMetadata(self, request, context): | |||
| """Missing associated documentation comment in .proto file""" | |||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||
| context.set_details('Method not implemented!') | |||
| raise NotImplementedError('Method not implemented!') | |||
| def SendGraph(self, request_iterator, context): | |||
| """Missing associated documentation comment in .proto file""" | |||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||
| context.set_details('Method not implemented!') | |||
| raise NotImplementedError('Method not implemented!') | |||
| def SendTensors(self, request_iterator, context): | |||
| """Missing associated documentation comment in .proto file""" | |||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||
| context.set_details('Method not implemented!') | |||
| raise NotImplementedError('Method not implemented!') | |||
| def SendWatchpointHits(self, request_iterator, context): | |||
| """Missing associated documentation comment in .proto file""" | |||
| context.set_code(grpc.StatusCode.UNIMPLEMENTED) | |||
| context.set_details('Method not implemented!') | |||
| raise NotImplementedError('Method not implemented!') | |||
| def add_EventListenerServicer_to_server(servicer, server): | |||
| rpc_method_handlers = { | |||
| 'WaitCMD': grpc.unary_unary_rpc_method_handler( | |||
| servicer.WaitCMD, | |||
| request_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.FromString, | |||
| response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString, | |||
| ), | |||
| 'SendMetadata': grpc.unary_unary_rpc_method_handler( | |||
| servicer.SendMetadata, | |||
| request_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.FromString, | |||
| response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString, | |||
| ), | |||
| 'SendGraph': grpc.stream_unary_rpc_method_handler( | |||
| servicer.SendGraph, | |||
| request_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.FromString, | |||
| response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString, | |||
| ), | |||
| 'SendTensors': grpc.stream_unary_rpc_method_handler( | |||
| servicer.SendTensors, | |||
| request_deserializer=mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.FromString, | |||
| response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString, | |||
| ), | |||
| 'SendWatchpointHits': grpc.stream_unary_rpc_method_handler( | |||
| servicer.SendWatchpointHits, | |||
| request_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.FromString, | |||
| response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString, | |||
| ), | |||
| } | |||
| generic_handler = grpc.method_handlers_generic_handler( | |||
| 'debugger.EventListener', rpc_method_handlers) | |||
| server.add_generic_rpc_handlers((generic_handler,)) | |||
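| # Illustrative sketch (not part of the generated file): how a backend might | |||
| # host the EventListener service; the port and worker count are assumptions. | |||
| def _example_serve(servicer): | |||
|     from concurrent import futures | |||
|     server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) | |||
|     add_EventListenerServicer_to_server(servicer, server) | |||
|     server.add_insecure_port('[::]:50051') | |||
|     server.start() | |||
|     return server | |||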
| # This class is part of an EXPERIMENTAL API. | |||
| class EventListener(object): | |||
| """Missing associated documentation comment in .proto file""" | |||
| @staticmethod | |||
| def WaitCMD(request, | |||
| target, | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| metadata=None): | |||
| return grpc.experimental.unary_unary(request, target, '/debugger.EventListener/WaitCMD', | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| options, channel_credentials, | |||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| @staticmethod | |||
| def SendMetadata(request, | |||
| target, | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| metadata=None): | |||
| return grpc.experimental.unary_unary(request, target, '/debugger.EventListener/SendMetadata', | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| options, channel_credentials, | |||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| @staticmethod | |||
| def SendGraph(request_iterator, | |||
| target, | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| metadata=None): | |||
| return grpc.experimental.stream_unary(request_iterator, target, '/debugger.EventListener/SendGraph', | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString, | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| options, channel_credentials, | |||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| @staticmethod | |||
| def SendTensors(request_iterator, | |||
| target, | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| metadata=None): | |||
| return grpc.experimental.stream_unary(request_iterator, target, '/debugger.EventListener/SendTensors', | |||
| mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.SerializeToString, | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| options, channel_credentials, | |||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| @staticmethod | |||
| def SendWatchpointHits(request_iterator, | |||
| target, | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| metadata=None): | |||
| return grpc.experimental.stream_unary(request_iterator, target, '/debugger.EventListener/SendWatchpointHits', | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.SerializeToString, | |||
| mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, | |||
| options, channel_credentials, | |||
| call_credentials, compression, wait_for_ready, timeout, metadata) | |||
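| # Illustrative sketch (not part of the generated file): a minimal client that | |||
| # sends an empty Metadata message; the target address is an assumption. | |||
| def _example_send_metadata(): | |||
|     channel = grpc.insecure_channel('localhost:50051') | |||
|     stub = EventListenerStub(channel) | |||
|     request = mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata() | |||
|     return stub.SendMetadata(request)  # blocks until an EventReply arrives | |||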
| @@ -0,0 +1,322 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| syntax = "proto2"; | |||
| package debugger; | |||
| // Versioning | |||
| enum Version { | |||
| // unknown version | |||
| UNKNOWN_VERSION = 0; | |||
| // Initial version (IR VERSION 1), published on Sep 23, 2019 | |||
| IR_VERSION = 0x0000000000000001; | |||
| } | |||
| // Data type definition | |||
| enum DataType { | |||
| DT_UNDEFINED = 0; | |||
| // Basic types. | |||
| DT_BOOL = 1; // bool | |||
| DT_INT8 = 2; // int8_t | |||
| DT_INT16 = 3; // int16_t | |||
| DT_INT32 = 4; // int32_t | |||
| DT_INT64 = 5; // int64_t | |||
| DT_UINT8 = 6; // uint8_t | |||
| DT_UINT16 = 7; // uint16_t | |||
| DT_UINT32 = 8; // uint32_t | |||
| DT_UINT64 = 9; // uint64_t | |||
| DT_FLOAT16 = 10; // float 16 | |||
| DT_FLOAT32 = 11; // float 32 | |||
| DT_FLOAT64 = 12; // float 64 | |||
| DT_STRING = 13; // string | |||
| DT_TENSOR = 14; // tensor | |||
| DT_GRAPH = 15; // graph | |||
| // list type | |||
| DT_BOOLS = 16; // list of bool | |||
| DT_INTS8 = 17; // list of int8_t | |||
| DT_INTS16 = 18; // list of int16_t | |||
| DT_INTS32 = 19; // list of int32_t | |||
| DT_INTS64 = 20; // list of int64_t | |||
| DT_UINTS8 = 21; // list of uint8_t | |||
| DT_UINTS16 = 22; // list of uint16_t | |||
| DT_UINTS32 = 23; // list of uint32_t | |||
| DT_UINTS64 = 24; // list of uint64_t | |||
| DT_FLOATS16 = 25; // list of float16 | |||
| DT_FLOATS32 = 26; // list of float32 | |||
| DT_FLOATS64 = 27; // list of float64 | |||
| DT_STRINGS = 28; // list of string | |||
| DT_TENSORS = 29; // list of tensor | |||
| DT_GRAPHS = 30; // list of graph | |||
| DT_TUPLE = 31; // tuple | |||
| DT_LIST = 32; // list | |||
| DT_DICT = 33; // dictionary | |||
| // other types | |||
| DT_NONE = 34; // None | |||
| DT_SYM_INST = 35; // Symbolic Key Instance | |||
| // type related type | |||
| DT_BASE_INT = 36; // type generic int | |||
| DT_BASE_UINT = 37; // type generic unsigned int | |||
| DT_BASE_FLOAT = 38; // type generic float | |||
| DT_TYPE = 39; // type type | |||
| DT_ANYTHING = 40; // type anything | |||
| DT_REFKEY = 41; // type refkey | |||
| DT_REF = 42; // type ref | |||
| } | |||
| // Value definition for attribute value or parameter default value | |||
| message ValueProto { | |||
| // data type of value | |||
| optional DataType dtype = 1; // discriminator that indicates which field below is in use | |||
| // Exactly ONE of the following fields must be present for this version of the IR | |||
| optional bool bool_val = 2; // bool | |||
| optional int64 int_val = 3; // int | |||
| optional uint64 uint_val = 4; // uint | |||
| optional float float_val = 5; // float | |||
| optional double double_val = 6; // double | |||
| optional string str_val = 7; // string | |||
| optional TensorProto tensor_val = 8; // tensor value | |||
| optional GraphProto graph = 9; // graph | |||
| repeated bool bool_vals = 10; // list of bool | |||
| repeated int64 int_vals = 11; // list of int | |||
| repeated uint64 uint_vals = 12; // list of uint | |||
| repeated float float_vals = 13; // list of float | |||
| repeated double double_vals = 14; // list of double | |||
| repeated string str_vals = 15; // list of string | |||
| repeated TensorProto tensor_vals = 16; // list of tensor value | |||
| repeated GraphProto graphs = 17; // list of graph | |||
| // tuple or list | |||
| repeated ValueProto values = 18; // tuple, list of value | |||
| // dictionary | |||
| repeated NamedValueProto dict_val = 19; // dictionary info | |||
| // field for type type | |||
| optional TypeProto type_val = 20; // type type info | |||
| } | |||
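| // Example in text format (hypothetical scalar): a ValueProto carrying a | |||
| // single float sets `dtype` as the discriminator plus exactly one value field: | |||
| //   dtype: DT_FLOAT32 | |||
| //   float_val: 0.5 | |||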
| message AttributeProto { | |||
| optional string name = 1; // attribute name | |||
| optional ValueProto value = 2; // attribute value | |||
| } | |||
| message NamedValueProto { | |||
| optional string key = 1; // attribute name | |||
| optional ValueProto value = 2; // attribute value | |||
| } | |||
| // Defines a tensor shape. | |||
| message TensorShapeProto { | |||
| // One dimension of the tensor. | |||
| message Dimension { | |||
| // Size of the tensor in that dimension. | |||
| // This value must be >= -1, and a value of -1 means the dimension | |||
| // is unknown. | |||
| optional int64 size = 1; | |||
| // Optional name of the tensor dimension. | |||
| optional string name = 2; | |||
| }; | |||
| repeated Dimension dim = 1; | |||
| } | |||
| // Types for graph input(parameter) and output | |||
| message TypeProto { | |||
| message Tensor { | |||
| // This field MUST have a valid DataType value except DT_TENSOR | |||
| optional DataType elem_type = 1; | |||
| optional TensorShapeProto shape = 2; // for scalar, this field is not set | |||
| } | |||
| // tuple type | |||
| message Sequence { | |||
| // The type and optional shape of elements of the tuple. | |||
| repeated TypeProto elem_types = 1; | |||
| }; | |||
| // data type | |||
| optional DataType data_type = 1; | |||
| oneof value { | |||
| // The type of a tensor. | |||
| Tensor tensor_type = 2; | |||
| // The type of a tuple. | |||
| Sequence sequence_type = 3; | |||
| } | |||
| } | |||
| // Defines information on graph parameters, including the name, the type, and | |||
| // the default value of parameter if exists. | |||
| message ParameterProto { | |||
| optional string name = 1; // parameter name | |||
| optional TypeProto type = 2; // parameter type | |||
| optional ValueProto default_val = 3; // default value of parameter if exists | |||
| } | |||
| // Defines graph output information | |||
| message OutputProto { | |||
| optional string name = 1; // output node name | |||
| optional TypeProto type = 2; // output node type | |||
| } | |||
| // Define node input information | |||
| message InputProto { | |||
| enum EdgeType { | |||
| DATA_EDGE = 0; // data edge | |||
| CONTROL_EDGE = 1; // control edge | |||
| } | |||
| optional string name = 1; | |||
| optional EdgeType type = 2; | |||
| } | |||
| // Nodes | |||
| // | |||
| // Computation graphs are made up of a DAG of nodes, which represent what is | |||
| // commonly called a "layer" or "pipeline stage" in machine learning frameworks. | |||
| // | |||
| // For example, it can be a node of type "Conv" that takes in an image, a filter | |||
| // tensor and a bias tensor, and produces the convolved output. | |||
| message NodeProto { | |||
| repeated InputProto input = 1; // namespace Value | |||
| optional string name = 2; // namespace Value | |||
| // The symbolic identifier of the Operator to execute. | |||
| optional string op_type = 3; // namespace Operator | |||
| // The domain of the OperatorSet that specifies the operator named by op_type. | |||
| optional string scope = 4; // namespace Domain | |||
| // Additional named attributes. | |||
| repeated AttributeProto attribute = 5; | |||
| // Optional type info of this node | |||
| optional TypeProto output_type = 6; | |||
| // other fields for debug | |||
| optional uint64 output_i = 7; | |||
| // full name with scope | |||
| optional string full_name = 8; | |||
| } | |||
| // Models | |||
| // | |||
| // ModelProto is a top-level file/container format for bundling a ML model and | |||
| // associating its computation graph with metadata. | |||
| // | |||
| // The semantics of the model are described by the associated GraphProto. | |||
| message ModelProto { | |||
| // ir version | |||
| optional int64 ir_version = 1; | |||
| // Domain name of the model. | |||
| // We use reverse domain names as name space indicators. For example: | |||
| // `com.facebook.fair` or `com.microsoft.cognitiveservices` | |||
| // | |||
| // Together with `model_version` and GraphProto.name, this forms the unique identity of | |||
| // the graph. | |||
| optional string domain = 2; | |||
| // The version of the graph encoded. See the Version enum above. | |||
| optional int64 model_version = 3; | |||
| // The parameterized graph that is evaluated to execute the model. | |||
| optional GraphProto graph = 4; | |||
| // metadata info of operators | |||
| optional OperatorSetProto metadata_operators = 5; | |||
| }; | |||
| message OperatorProto { | |||
| optional string name = 1; // used as key, must be distinct | |||
| optional bytes config = 2; // operator config info | |||
| optional bytes obj_info = 3; // operator related object info, e.g. content of operator binary or name | |||
| }; | |||
| message OperatorSetProto { | |||
| repeated OperatorProto operators = 1; | |||
| optional string summary = 2; // summary info of operators, e.g. file position of operators file | |||
| } | |||
| // Graphs | |||
| // | |||
| // A graph defines the computational logic of a model and is comprised of a parameterized | |||
| // list of nodes that form a directed acyclic graph based on their inputs and outputs. | |||
| // This is the equivalent of the "network" or "graph" in many deep learning | |||
| // frameworks. | |||
| message GraphProto { | |||
| // The nodes in the graph, sorted topologically. | |||
| repeated NodeProto node = 1; | |||
| // The name of the graph. | |||
| optional string name = 2; // namespace Graph | |||
| // The parameters(inputs) and outputs of the graph. | |||
| repeated ParameterProto parameters = 3; | |||
| repeated OutputProto outputs = 4; | |||
| // Constants used in this graph | |||
| repeated NamedValueProto const_vals = 5; | |||
| } | |||
| // Tensors | |||
| // | |||
| // A serialized tensor value. | |||
| message TensorProto { | |||
| // The node name of the tensor. | |||
| optional string node_name = 1; | |||
| // The slot of the tensor in its node. | |||
| optional string slot = 2; | |||
| // The serialized tensor content. | |||
| optional bytes tensor_content = 3; | |||
| // The shape of the tensor. | |||
| repeated int64 dims = 4; | |||
| // The data type of the tensor. | |||
| // This field MUST have a valid DataType value except DT_TENSOR | |||
| optional DataType data_type = 5; | |||
| // If the tensor content transferring is finished. | |||
| optional bool finished = 6; | |||
| // The iteration of the tensor. Supported: "prev" or leave empty. | |||
| optional string iter = 7; | |||
| // If the tensor name should be truncated. | |||
| optional bool truncate = 8; | |||
| } | |||
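| // Example in text format (hypothetical node name): a 2x2 float32 tensor. | |||
| // `tensor_content` carries the raw bytes of the flattened array, which the | |||
| // backend decodes according to `data_type` and `dims`: | |||
| //   node_name: "Default/conv1" | |||
| //   slot: "0" | |||
| //   dims: 2 | |||
| //   dims: 2 | |||
| //   data_type: DT_FLOAT32 | |||
| //   finished: true | |||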
| @@ -0,0 +1,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -0,0 +1,289 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """This file is used to define the basic graph.""" | |||
| from collections import deque | |||
| from mindinsight.datavisual.data_transform.graph.msgraph import MSGraph | |||
| from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum | |||
| from mindinsight.debugger.common.exceptions.exceptions import \ | |||
| DebuggerNodeNotInGraphError, DebuggerParamValueError | |||
| from mindinsight.debugger.common.log import logger as log | |||
| from .node import NodeTree | |||
| class DebuggerGraph(MSGraph): | |||
| """The `DebuggerGraph` object provides interfaces to describe a debugger graph.""" | |||
| def __init__(self): | |||
| super(DebuggerGraph, self).__init__() | |||
| self._node_tree = None | |||
| def get_node_name_by_full_name(self, full_name): | |||
| """Get node name by full names.""" | |||
| inner_name = self._full_name_map_name.get(full_name, '') | |||
| if not inner_name: | |||
| log.warning("Node %s does not find the relative inner node name.", full_name) | |||
| return inner_name | |||
| def get_full_name_by_node_name(self, node_name): | |||
| """Get full name by node name for leaf nodes.""" | |||
| node = self._normal_node_map.get(node_name) | |||
| if not node: | |||
| log.warning("Node %s is not leaf node.", node_name) | |||
| return node.full_name if node else '' | |||
| def get_nodes(self, searched_node_list): | |||
| """ | |||
| Organize the searched leaf nodes into a nested tree by name scope. | |||
| Args: | |||
| searched_node_list (list[Node]): A list of leaf nodes that | |||
| matches the given search pattern. | |||
| Returns: | |||
| A list of dict including the searched nodes. | |||
| [{ | |||
| "name": "Default", | |||
| "type": "name_scope", | |||
| "nodes": [{ | |||
| "name": "Default/Conv2D1", | |||
| "type": "name_scope", | |||
| "nodes": [{ | |||
| ... | |||
| }] | |||
| }] | |||
| }, | |||
| { | |||
| "name": "Gradients", | |||
| "type": "name_scope", | |||
| "nodes": [{ | |||
| "name": "Gradients/Default", | |||
| "type": "name_scope", | |||
| "nodes": [{ | |||
| ... | |||
| }] | |||
| }] | |||
| }] | |||
| """ | |||
| # save the node in the NodeTree | |||
| self._node_tree = NodeTree() | |||
| for node in searched_node_list: | |||
| self._build_node_tree(node.name, node.type) | |||
| # get the searched nodes in the NodeTree and reorganize them | |||
| searched_list = [] | |||
| self._traverse_node_tree(self._node_tree, searched_list) | |||
| return searched_list | |||
| def search_nodes_by_pattern(self, pattern): | |||
| """ | |||
| Search node names by a given pattern. | |||
| Args: | |||
| pattern (Union[str, None]): The pattern of the node to search, | |||
| if None, return all node names. | |||
| Returns: | |||
| list[Node], the leaf nodes whose names match the given pattern. | |||
| """ | |||
| if pattern is not None: | |||
| pattern = pattern.lower() | |||
| searched_nodes = [ | |||
| node for name, node in self._leaf_nodes.items() | |||
| if pattern in name.lower() | |||
| ] | |||
| else: | |||
| searched_nodes = list(self._leaf_nodes.values()) | |||
| return searched_nodes | |||
| def _build_node_tree(self, node_name, node_type): | |||
| """Build node tree.""" | |||
| scope_names = node_name.split('/') | |||
| cur_node = self._node_tree | |||
| for scope_name in scope_names[:-1]: | |||
| sub_node = cur_node.get(scope_name) | |||
| if not sub_node: | |||
| sub_node = cur_node.add(scope_name) | |||
| cur_node = sub_node | |||
| cur_node.add(scope_names[-1], node_type) | |||
| def _traverse_node_tree(self, cur_node, search_node_list): | |||
| """Traverse the watch nodes and update the total watched node list.""" | |||
| if not cur_node.get_children(): | |||
| return | |||
| for _, sub_node in cur_node.get_children(): | |||
| sub_nodes = [] | |||
| self._traverse_node_tree(sub_node, sub_nodes) | |||
| sub_node_dict = { | |||
| 'name': sub_node.node_name, | |||
| 'type': sub_node.node_type, | |||
| 'nodes': sub_nodes | |||
| } | |||
| search_node_list.append(sub_node_dict) | |||
| def get_node_type(self, node_name): | |||
| """ | |||
| Get the type of the node. | |||
| Args: | |||
| node_name (str): The full name of the node with its scope. | |||
| Returns: | |||
| str, the node type. 'name_scope' for scope nodes, otherwise the leaf node type. | |||
| """ | |||
| if node_name and not self.exist_node(name=node_name): | |||
| raise DebuggerNodeNotInGraphError(node_name=node_name) | |||
| node = self._leaf_nodes.get(node_name) | |||
| if node is not None: | |||
| node_type = node.type | |||
| else: | |||
| node_type = NodeTypeEnum.NAME_SCOPE.value | |||
| return node_type | |||
| def get_tensor_history(self, node_name, depth=0): | |||
| """ | |||
| Get the tensor history of a specified node. | |||
| Args: | |||
| node_name (str): The debug name of the node. | |||
| depth (int): The number of layers the user wants to trace. Default is 0. | |||
| Returns: | |||
| list, a list of the traced tensors' name and node type, | |||
| arranged in order from leaf node to root node. | |||
| int, the number of output tensors. | |||
| """ | |||
| node = self._leaf_nodes.get(node_name) | |||
| tensor_history = self._get_tensor_infos_of_node(node) | |||
| cur_outputs_nums = len(tensor_history) | |||
| cur_depth = 0 | |||
| trace_list = deque([(node, cur_depth)]) | |||
| while trace_list: | |||
| cur_node, cur_depth = trace_list.popleft() | |||
| tensors_info = self._get_input_tensors_of_node(cur_node) | |||
| if tensors_info: | |||
| tensor_history.extend(tensors_info) | |||
| if cur_depth < depth: | |||
| for name in cur_node.input.keys(): | |||
| trace_list.append((self._leaf_nodes[name], cur_depth + 1)) | |||
| return tensor_history, cur_outputs_nums | |||
| @staticmethod | |||
| def _get_tensor_infos_of_node(cur_node, slot=None): | |||
| """Get tensors info of specified node.""" | |||
| tensors_info = [] | |||
| if slot is None: | |||
| slots = range(cur_node.output_nums) | |||
| elif slot >= 0: | |||
| slots = [slot] | |||
| else: | |||
| log.info("Skip get tensor info for %s:%s.", cur_node.name, slot) | |||
| return tensors_info | |||
| for num in slots: | |||
| tensor_info = { | |||
| 'name': cur_node.name + ':' + str(num), | |||
| 'full_name': cur_node.full_name + ':' + str(num), | |||
| 'node_type': cur_node.type | |||
| } | |||
| tensors_info.append(tensor_info) | |||
| return tensors_info | |||
| def _get_input_tensors_of_node(self, cur_node): | |||
| """Get input tensors of node.""" | |||
| tensors_info = [] | |||
| for name in cur_node.input.keys(): | |||
| node = self._leaf_nodes.get(name) | |||
| tensor_info = self._get_tensor_infos_of_node(node) | |||
| tensors_info.extend(tensor_info) | |||
| return tensors_info | |||
| def get_bfs_order(self): | |||
| """ | |||
| Traverse the leaf nodes of the graph in breadth-first order. | |||
| Returns: | |||
| list, including the leaf nodes arranged in BFS order. | |||
| """ | |||
| root = self.get_default_root() | |||
| log.info('Choose node %s as the root to do BFS.', root.name) | |||
| bfs_order = [] | |||
| self.get_bfs_graph(root.name, bfs_order) | |||
| length = len(self._leaf_nodes.keys()) | |||
| # Find rest un-traversed nodes | |||
| for node_name, _ in self._leaf_nodes.items(): | |||
| if node_name not in bfs_order: | |||
| self.get_bfs_graph(node_name, bfs_order) | |||
| if len(bfs_order) != length: | |||
| log.error("The length of bfs and leaf nodes are not equal.") | |||
| msg = "Not all nodes are traversed!" | |||
| raise DebuggerParamValueError(msg) | |||
| return bfs_order | |||
| def get_bfs_graph(self, node_name, bfs_order): | |||
| """ | |||
| Do breadth-first search from the given node and append the traversed | |||
| node names to `bfs_order` in place. | |||
| Args: | |||
| node_name (str): The name of the node to start from. | |||
| bfs_order (list[str]): The list that collects traversed node names. | |||
| """ | |||
| temp_list = deque() | |||
| temp_list.append(node_name) | |||
| while temp_list: | |||
| node_name = temp_list.popleft() | |||
| node = self._leaf_nodes.get(node_name) | |||
| if not node: | |||
| log.warning('Cannot find node %s in graph. Ignored.', node_name) | |||
| continue | |||
| bfs_order.append(node_name) | |||
| if node.input: | |||
| for name in node.input.keys(): | |||
| if name not in temp_list and name not in bfs_order: | |||
| temp_list.append(name) | |||
| if node.output: | |||
| for name in node.output.keys(): | |||
| if name not in temp_list and name not in bfs_order: | |||
| temp_list.append(name) | |||
| def get_default_root(self): | |||
| """ | |||
| Get the default root node for BFS in the graph, i.e. the leaf | |||
| node with the smallest node id. | |||
| Returns: | |||
| Node, the default root node. | |||
| """ | |||
| default_root = None | |||
| for _, item in self._leaf_nodes.items(): | |||
| if item.node_id == '1': | |||
| default_root = item | |||
| break | |||
| if default_root is None: | |||
| log.error("Abnormal graph. Invalid node for BFS.") | |||
| msg = 'Abnormal graph. Invalid node for BFS.' | |||
| raise DebuggerParamValueError(msg) | |||
| return default_root | |||
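| # A minimal usage sketch (hypothetical pattern and graph content): filter the | |||
| # leaf nodes by a substring, then fold the matches into the nested scope tree | |||
| # that is returned to the UI. | |||
| def _example_search(graph): | |||
|     matched_nodes = graph.search_nodes_by_pattern('conv') | |||
|     return graph.get_nodes(matched_nodes) | |||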
| @@ -0,0 +1,61 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| This file is used to define the node of graph and associated base types. | |||
| """ | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||
| from mindinsight.debugger.common.log import logger as log | |||
| class NodeTree: | |||
| """A class for building a node tree.""" | |||
| def __init__(self, node_name='', node_type=None): | |||
| self.node_name = node_name | |||
| self._node_type = node_type | |||
| self._children = {} | |||
| @property | |||
| def node_type(self): | |||
| """The property of node type.""" | |||
| return self._node_type | |||
| @node_type.setter | |||
| def node_type(self, value): | |||
| """Set the node type.""" | |||
| self._node_type = value | |||
| def add(self, name, node_type=None): | |||
| """Add sub node.""" | |||
| sub_name = '/'.join([self.node_name, name]) if self.node_name else name | |||
| sub_node = NodeTree(sub_name, node_type) | |||
| self._children[name] = sub_node | |||
| return sub_node | |||
| def get(self, sub_name): | |||
| """Get sub node.""" | |||
| return self._children.get(sub_name) | |||
| def get_children(self): | |||
| """Get all childrens.""" | |||
| for name_scope, sub_node in self._children.items(): | |||
| yield name_scope, sub_node | |||
| def remove(self, sub_name): | |||
| """Remove sub node.""" | |||
| try: | |||
| self._children.pop(sub_name) | |||
| except KeyError as err: | |||
| log.error("Failed to find node %s. %s", sub_name, err) | |||
| raise DebuggerParamValueError("Failed to find node {}".format(sub_name)) | |||
| @@ -0,0 +1,233 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """The definition of tensor stream.""" | |||
| from abc import abstractmethod, ABC | |||
| import numpy as np | |||
| from mindinsight.utils.tensor import TensorUtils | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||
| from mindinsight.debugger.common.log import logger as log | |||
| from mindinsight.debugger.common.utils import NUMPY_TYPE_MAP | |||
| from mindinsight.debugger.proto.ms_graph_pb2 import DataType | |||
| class BaseTensor(ABC): | |||
| """Tensor data structure.""" | |||
| def __init__(self, step=0): | |||
| self._step = step | |||
| @property | |||
| @abstractmethod | |||
| def name(self): | |||
| """The property of tensor name.""" | |||
| @property | |||
| @abstractmethod | |||
| def dtype(self): | |||
| """The property of tensor dtype.""" | |||
| @property | |||
| @abstractmethod | |||
| def shape(self): | |||
| """The property of tensor shape.""" | |||
| @property | |||
| @abstractmethod | |||
| def value(self): | |||
| """The property of tensor shape.""" | |||
| @abstractmethod | |||
| def get_tensor_value_by_shape(self, shape=None): | |||
| """Get tensor value by shape.""" | |||
| def _to_dict(self): | |||
| """Get tensor info in dict format.""" | |||
| res = { | |||
| 'full_name': self.name, | |||
| 'step': self._step, | |||
| 'dtype': self.dtype, | |||
| 'shape': self.shape, | |||
| 'has_prev_step': False | |||
| } | |||
| return res | |||
| def get_basic_info(self): | |||
| """Return basic info about tensor info.""" | |||
| if not self.shape: | |||
| value = self.value | |||
| else: | |||
| value = 'click to view' | |||
| res = self._to_dict() | |||
| res['value'] = value | |||
| return res | |||
| def get_full_info(self, shape=None): | |||
| """Get tensor info with value.""" | |||
| res = self._to_dict() | |||
| value_info = self.get_tensor_serializable_value_by_shape(shape) | |||
| res.update(value_info) | |||
| return res | |||
| class OpTensor(BaseTensor): | |||
| """Tensor data structure for operator Node.""" | |||
| max_number_data_show_on_ui = 100000 | |||
| def __init__(self, tensor_proto, step=0): | |||
| # the type of tensor_proto is TensorProto | |||
| super(OpTensor, self).__init__(step) | |||
| self._tensor_proto = tensor_proto | |||
| self._value = self.generate_value(tensor_proto) | |||
| @property | |||
| def name(self): | |||
| """The property of tensor name.""" | |||
| node_name = self._tensor_proto.node_name | |||
| slot = self._tensor_proto.slot | |||
| return ':'.join([node_name, slot]) | |||
| @property | |||
| def dtype(self): | |||
| """The property of tensor dtype.""" | |||
| tensor_type = DataType.Name(self._tensor_proto.data_type) | |||
| return tensor_type | |||
| @property | |||
| def shape(self): | |||
| """The property of tensor shape.""" | |||
| return list(self._tensor_proto.dims) | |||
| @property | |||
| def value(self): | |||
| """The property of tensor value.""" | |||
| tensor_value = None | |||
| if self._value is not None: | |||
| tensor_value = self._value.tolist() | |||
| return tensor_value | |||
| @property | |||
| def numpy_value(self): | |||
| """The property of tensor value in numpy type.""" | |||
| return self._value | |||
| def generate_value(self, tensor_proto): | |||
| """Generate tensor value from proto.""" | |||
| tensor_value = None | |||
| if tensor_proto.tensor_content: | |||
| tensor_value = tensor_proto.tensor_content | |||
| np_type = NUMPY_TYPE_MAP.get(self.dtype) | |||
| tensor_value = np.frombuffer(tensor_value, dtype=np_type) | |||
| tensor_value = tensor_value.reshape(self.shape) | |||
| return tensor_value | |||
| def get_tensor_serializable_value_by_shape(self, shape=None): | |||
| """ | |||
| Get tensor value info by shape. | |||
| Args: | |||
| shape (tuple): The specified range of tensor value. | |||
| Returns: | |||
| dict, the specified tensor value and value statistics. | |||
| """ | |||
| tensor_value = self.get_tensor_value_by_shape(shape) | |||
| res = {} | |||
| if isinstance(tensor_value, np.ndarray): | |||
| statistics = TensorUtils.get_statistics_from_tensor(tensor_value) | |||
| res['statistics'] = TensorUtils.get_statistics_dict(statistics) | |||
| res['value'] = tensor_value.tolist() | |||
| return res | |||
| res['value'] = tensor_value | |||
| return res | |||
| def get_tensor_value_by_shape(self, shape=None): | |||
| """ | |||
| Get tensor value by shape. | |||
| Args: | |||
| shape (tuple): The specified shape. | |||
| Returns: | |||
| Union[None, str, numpy.ndarray], the sub-tensor. | |||
| """ | |||
| if self._value is None: | |||
| log.warning("%s has no value yet.", self.name) | |||
| return None | |||
| if shape is None or not isinstance(shape, tuple): | |||
| log.info("Get the whole tensor value with shape is %s", shape) | |||
| return self._value | |||
| if len(shape) != len(self.shape): | |||
| log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape) | |||
| raise DebuggerParamValueError("Invalid shape. Shape unmatched.") | |||
| try: | |||
| value = self._value[shape] | |||
| except IndexError as err: | |||
| log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape) | |||
| log.exception(err) | |||
| raise DebuggerParamValueError("Invalid shape. Shape unmatched.") | |||
| if isinstance(value, np.ndarray): | |||
| if value.size > self.max_number_data_show_on_ui: | |||
| log.info("The tensor size is %s, which is too large to show on UI.", value.size) | |||
| value = "Too large to show." | |||
| else: | |||
| value = np.asarray(value) | |||
| return value | |||
| class ConstTensor(BaseTensor): | |||
| """Tensor data structure for Const Node.""" | |||
| def __init__(self, const_proto): | |||
| # the type of const_proto is NamedValueProto | |||
| super(ConstTensor, self).__init__() | |||
| self._const_proto = const_proto | |||
| def set_step(self, step): | |||
| """Set step value.""" | |||
| self._step = step | |||
| @property | |||
| def name(self): | |||
| """The property of tensor name.""" | |||
| return self._const_proto.key + ':0' | |||
| @property | |||
| def dtype(self): | |||
| """The property of tensor dtype.""" | |||
| return DataType.Name(self._const_proto.value.dtype) | |||
| @property | |||
| def shape(self): | |||
| """The property of tensor shape.""" | |||
| return [] | |||
| @property | |||
| def value(self): | |||
| """The property of tensor shape.""" | |||
| fields = self._const_proto.value.ListFields() | |||
| if len(fields) != 2: | |||
| log.warning("Unexpected const proto <%s>.\n Please check offline.", self._const_proto) | |||
| for field_name, field_value in fields: | |||
| if field_name != 'dtype': | |||
| return field_value | |||
| return None | |||
| def get_tensor_value_by_shape(self, shape=None): | |||
| """Get tensor info with value.""" | |||
| if shape is not None: | |||
| log.warning("Invalid shape for const value.") | |||
| return self.value | |||
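| # An illustrative sketch (assuming NUMPY_TYPE_MAP maps 'DT_FLOAT32' to | |||
| # numpy.float32): wrap a 2x2 float32 TensorProto in OpTensor and read the | |||
| # parsed value back. | |||
| def _example_op_tensor(): | |||
|     from mindinsight.debugger.proto.ms_graph_pb2 import TensorProto | |||
|     proto = TensorProto(node_name='Default/conv1', slot='0', | |||
|                         data_type=DataType.DT_FLOAT32, dims=[2, 2]) | |||
|     proto.tensor_content = np.arange(4, dtype=np.float32).tobytes() | |||
|     tensor = OpTensor(proto) | |||
|     # -> ('Default/conv1:0', [2, 2], [[0.0, 1.0], [2.0, 3.0]]) | |||
|     return tensor.name, tensor.shape, tensor.value | |||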
| @@ -0,0 +1,300 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define the watchpoint stream.""" | |||
| from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||
| from mindinsight.debugger.common.log import logger as log | |||
| from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition | |||
| WATCHPOINT_CONDITION_MAPPING = { | |||
| 'INF': WatchCondition.Condition.inf, | |||
| 'NAN': WatchCondition.Condition.nan, | |||
| 'OVERFLOW': WatchCondition.Condition.overflow, | |||
| 'MAX_GT': WatchCondition.Condition.max_gt, | |||
| 'MAX_LT': WatchCondition.Condition.max_lt, | |||
| 'MIN_GT': WatchCondition.Condition.min_gt, | |||
| 'MIN_LT': WatchCondition.Condition.min_lt, | |||
| 'MAX_MIN_GT': WatchCondition.Condition.max_min_gt, | |||
| 'MAX_MIN_LT': WatchCondition.Condition.max_min_lt, | |||
| 'MEAN_GT': WatchCondition.Condition.mean_gt, | |||
| 'MEAN_LT': WatchCondition.Condition.mean_lt | |||
| } | |||
| class WatchNodeTree: | |||
| """The WatchNode Node Structure.""" | |||
| NOT_WATCH = 0 # the scope node and the nodes below are not watched | |||
| PARTIAL_WATCH = 1 # at least one node under the scope node is not watched | |||
| TOTAL_WATCH = 2 # the scope node and the nodes below are all watched | |||
| def __init__(self, node_name='', node_type=None, full_name='', watch_status=1): | |||
| self._node_name = node_name | |||
| self._full_name = full_name | |||
| self._node_type = self._translate_node_type(node_type) | |||
| self._watch_status = watch_status | |||
| self._children = {} | |||
| @property | |||
| def node_name(self): | |||
| """The property of node name.""" | |||
| return self._node_name | |||
| @property | |||
| def full_name(self): | |||
| """The property of node name.""" | |||
| return self._full_name | |||
| @property | |||
| def node_type(self): | |||
| """The property of node type.""" | |||
| return self._node_type | |||
| @node_type.setter | |||
| def node_type(self, value): | |||
| """Set the node type.""" | |||
| self._node_type = self._translate_node_type(value) | |||
| @property | |||
| def watch_status(self): | |||
| """The property of watch status about current node.""" | |||
| return self._watch_status | |||
| def enable_watch_status(self): | |||
| """The property of watch status about current node.""" | |||
| self._watch_status = WatchNodeTree.TOTAL_WATCH | |||
| @staticmethod | |||
| def _translate_node_type(node_type): | |||
| """Translate node type to watch node type.""" | |||
| if not node_type or node_type == NodeTypeEnum.NAME_SCOPE.value or \ | |||
| node_type == NodeTypeEnum.AGGREGATION_SCOPE.value: | |||
| return 'scope' | |||
| return 'leaf' | |||
| def get(self, sub_name): | |||
| """Get sub node.""" | |||
| return self._children.get(sub_name) | |||
| def get_children(self): | |||
| """Get all childrens.""" | |||
| for name_scope, sub_watch_node in self._children.items(): | |||
| yield name_scope, sub_watch_node | |||
| def add_node(self, node_name, node_type, full_name=''): | |||
| """ | |||
| Add watch node to watch node tree. | |||
| Args: | |||
| node_name (str): The node name. | |||
| node_type (str): The node type. | |||
| full_name (str): The full name of node. | |||
| """ | |||
| log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name) | |||
| scope_names = node_name.split('/', 1) | |||
| if len(scope_names) == 1: | |||
| if not self.get(node_name): | |||
| self.add(node_name, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH) | |||
| else: | |||
| self.get(node_name).enable_watch_status() | |||
| return | |||
| scope_name, sub_names = scope_names | |||
| sub_tree = self.get(scope_name) | |||
| if not sub_tree: | |||
| sub_tree = self.add(scope_name, watch_status=WatchNodeTree.PARTIAL_WATCH) | |||
| sub_tree.add_node(sub_names, node_type, full_name) | |||
| def add(self, name, node_type=None, full_name='', watch_status=1): | |||
| """Add sub WatchPointTree.""" | |||
| sub_name = '/'.join([self._node_name, name]) if self._node_name else name | |||
| sub_tree = WatchNodeTree(sub_name, node_type, full_name, watch_status) | |||
| self._children[name] = sub_tree | |||
| return sub_tree | |||
| def remove_node(self, node_name): | |||
| """Remove sub node from current tree.""" | |||
| log.debug("Remove %s", node_name) | |||
| scope_names = node_name.split('/', 1) | |||
| sub_tree_name = scope_names[0] | |||
| sub_tree = self._children.get(sub_tree_name) | |||
| if not sub_tree: | |||
| log.error("Failed to find node %s in WatchNodeTree.", sub_tree_name) | |||
| raise DebuggerParamValueError("Failed to find node {}".format(sub_tree_name)) | |||
| if len(scope_names) > 1: | |||
| sub_tree.remove_node(scope_names[1]) | |||
| if sub_tree.watch_status == WatchNodeTree.NOT_WATCH or len(scope_names) == 1: | |||
| self._children.pop(sub_tree_name) | |||
| self._watch_status = WatchNodeTree.PARTIAL_WATCH if self._children else \ | |||
| WatchNodeTree.NOT_WATCH | |||
| class Watchpoint: | |||
| """ | |||
| The class of watchpoint stream. | |||
| Args: | |||
| watchpoint_id (int): The id of Watchpoint. | |||
| watch_condition (dict): The condition of Watchpoint. | |||
| - condition (str): The name of the watch condition. Must be a key of | |||
| WATCHPOINT_CONDITION_MAPPING, e.g. `INF` or `NAN`. | |||
| - param (float, optional): The threshold of the condition, if required; | |||
| at most one param is provided. | |||
| """ | |||
| def __init__(self, watchpoint_id, watch_condition): | |||
| self._id = watchpoint_id | |||
| self._condition = watch_condition | |||
| self._watch_node = WatchNodeTree() | |||
| @property | |||
| def watchpoint_id(self): | |||
| """The property of watchpoint id.""" | |||
| return self._id | |||
| @property | |||
| def nodes(self): | |||
| """The property of watch nodes.""" | |||
| return self._watch_node | |||
| @property | |||
| def condition(self): | |||
| """The property of watch condition.""" | |||
| return self._condition | |||
| def copy_nodes_from(self, other_watchpoint): | |||
| """ | |||
| Copy nodes from other watchpoint. | |||
| Args: | |||
| other_watchpoint (Watchpoint): Other watchpoint. | |||
| """ | |||
| self._watch_node = other_watchpoint.nodes | |||
| def add_nodes(self, nodes): | |||
| """Add node into watchcpoint.""" | |||
| if not nodes: | |||
| log.warning("Add empty nodes.") | |||
| return | |||
| if not isinstance(nodes, list): | |||
| nodes = [nodes] | |||
| for node in nodes: | |||
| self._watch_node.add_node(node.name, node.type, node.full_name) | |||
| def remove_nodes(self, nodes): | |||
| """Remove nodes from watchpoint.""" | |||
| if not nodes: | |||
| return | |||
| if not isinstance(nodes, list): | |||
| nodes = [nodes] | |||
| for node in nodes: | |||
| node_name = node.split(':')[0] | |||
| self._watch_node.remove_node(node_name) | |||
| def get_node_status(self, node_name, node_type, full_name): | |||
| """Judge if the node is in watch nodes.""" | |||
| scope_names = node_name.split('/') | |||
| cur_node = self._watch_node | |||
| status = WatchNodeTree.PARTIAL_WATCH | |||
| for scope_name in scope_names: | |||
| cur_node = cur_node.get(scope_name) | |||
| if cur_node is None: | |||
| status = WatchNodeTree.NOT_WATCH | |||
| break | |||
| if cur_node.watch_status == WatchNodeTree.TOTAL_WATCH: | |||
| status = WatchNodeTree.TOTAL_WATCH | |||
| break | |||
| if status == WatchNodeTree.TOTAL_WATCH and cur_node.node_name != node_name: | |||
| self._watch_node.add_node(node_name, node_type, full_name) | |||
| return status | |||
| def get_watch_node(self, cur_watch_node, watch_node_list): | |||
| """ | |||
| Traverse the watch node tree and collect fully watched nodes into `watch_node_list`. | |||
| Args: | |||
| cur_watch_node (WatchNodeTree): The current watch node. | |||
| watch_node_list (list[WatchNodeTree]): The list of total watched node. | |||
| """ | |||
| if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH: | |||
| watch_node_list.append(cur_watch_node) | |||
| return | |||
| for _, watch_node in cur_watch_node.get_children(): | |||
| self.get_watch_node(watch_node, watch_node_list) | |||
| def get_set_cmd(self): | |||
| """Return the watchpoint in proto format.""" | |||
| # get watch nodes. | |||
| watch_nodes = [] | |||
| self.get_watch_node(self._watch_node, watch_nodes) | |||
| # construct SetCMD | |||
| set_cmd = SetCMD() | |||
| set_cmd.id = self._id | |||
| set_cmd.delete = False | |||
| set_cmd.watch_condition.condition = WATCHPOINT_CONDITION_MAPPING.get( | |||
| self._condition.get('condition')) | |||
| if self._condition.get('param'): | |||
| # at most one param is provided | |||
| set_cmd.watch_condition.value = self._condition.get('param') | |||
| for watch_node in watch_nodes: | |||
| event_node = set_cmd.watch_nodes.add() | |||
| event_node.node_name = watch_node.full_name | |||
| event_node.node_type = watch_node.node_type | |||
| return set_cmd | |||
| def get_watch_condition_info(self): | |||
| """Get watch condition info.""" | |||
| watchpoint_info = { | |||
| 'id': self._id, | |||
| 'watch_condition': self._condition | |||
| } | |||
| return watchpoint_info | |||
| class WatchpointHit: | |||
| """The watchpoint hit structure.""" | |||
| def __init__(self, tensor_proto, watchpoint, node_name): | |||
| self._node_name = node_name | |||
| self._full_name = tensor_proto.node_name | |||
| self._slot = tensor_proto.slot | |||
| self._watchpoint = watchpoint | |||
| @property | |||
| def tensor_full_name(self): | |||
| """The property of tensor_name.""" | |||
| tensor_name = ':'.join([self._full_name, self._slot]) | |||
| return tensor_name | |||
| @property | |||
| def tensor_name(self): | |||
| """The property of tensor_name.""" | |||
| return ':'.join([self._node_name, self._slot]) | |||
| @property | |||
| def watchpoint(self): | |||
| """The property of watchpoint.""" | |||
| watchpoint = self._watchpoint.get_watch_condition_info() | |||
| return watchpoint | |||
| def __eq__(self, other): | |||
| """Define the equal condition.""" | |||
| flag = self.tensor_full_name == other.tensor_full_name and self.watchpoint == other.watchpoint | |||
| return flag | |||
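| # A minimal sketch (hypothetical node values) of the watchpoint flow: watch | |||
| # one node for INF, then serialize the watchpoint into a SetCMD proto. | |||
| def _example_watchpoint(): | |||
|     from collections import namedtuple | |||
|     Node = namedtuple('Node', ['name', 'type', 'full_name']) | |||
|     watchpoint = Watchpoint(1, {'condition': 'INF'}) | |||
|     watchpoint.add_nodes(Node('Default/conv1', 'Conv2D', 'Default/conv1-op1')) | |||
|     set_cmd = watchpoint.get_set_cmd() | |||
|     return set_cmd.id, [node.node_name for node in set_cmd.watch_nodes] | |||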
| @@ -0,0 +1,23 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Import the streams handlers.""" | |||
| from .event_handler import EventHandler | |||
| from .metadata_handler import MetadataHandler | |||
| from .graph_handler import GraphHandler | |||
| from .tensor_handler import TensorHandler | |||
| from .watchpoint_handler import WatchpointHandler, WatchpointHitHandler | |||
| __all__ = ['EventHandler', 'MetadataHandler', 'GraphHandler', 'TensorHandler', | |||
| 'WatchpointHandler', 'WatchpointHitHandler'] | |||
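| # A minimal usage sketch: the re-exports above give callers a single import | |||
| # path for all stream handlers. | |||
| from mindinsight.debugger.stream_handler import EventHandler, WatchpointHandler | |||
| event_stream = EventHandler()      # long-poll event cache | |||
| watchpoints = WatchpointHandler()  # watchpoint CRUD cache | |||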
| @@ -0,0 +1,34 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define the stream handler base.""" | |||
| from abc import abstractmethod | |||
| class StreamHandlerBase: | |||
| """The stream handler base.""" | |||
| @abstractmethod | |||
| def put(self, value): | |||
| """Abstract method to set data.""" | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get(self, filter_condition): | |||
| """Abstract method to get data.""" | |||
| raise NotImplementedError | |||
| def clean(self): | |||
| """Clean cache.""" | |||
| self.__init__() | |||
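| # A hedged sketch (illustration only) of the contract subclasses follow: | |||
| # put() stores, get() filters, and the inherited clean() re-runs __init__, | |||
| # so any state defined there is reset. | |||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||
| class ListHandler(StreamHandlerBase): | |||
|     """Toy handler that caches values in a list.""" | |||
|     def __init__(self): | |||
|         self._values = [] | |||
|     def put(self, value): | |||
|         self._values.append(value) | |||
|     def get(self, filter_condition=None): | |||
|         if filter_condition is None: | |||
|             return list(self._values) | |||
|         return [value for value in self._values if value == filter_condition] | |||
| handler = ListHandler() | |||
| handler.put(1) | |||
| handler.clean()             # re-initializes the handler, emptying the cache | |||
| assert handler.get() == [] | |||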
| @@ -0,0 +1,159 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define the message handler.""" | |||
| import uuid | |||
| from queue import Queue, Empty | |||
| from threading import Lock | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||
| from mindinsight.debugger.common.log import logger as log | |||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||
| class EventHandler(StreamHandlerBase): | |||
| """Message Handler.""" | |||
| max_limit = 1000 # the max number of items in cache | |||
| def __init__(self): | |||
| self._prev_flag = str(uuid.uuid4()) | |||
| self._cur_flag = str(uuid.uuid4()) | |||
| self._next_idx = 0 | |||
| self._event_cache = [None] * self.max_limit | |||
| self._pending_requests = {} | |||
| self._lock = Lock() | |||
| @property | |||
| def next_pos(self): | |||
| """The next pos to be updated in cache.""" | |||
| return ':'.join([self._cur_flag, str(self._next_idx)]) | |||
| def has_pos(self, pos): | |||
| """Get the event according to pos.""" | |||
| cur_flag, cur_idx = self._parse_pos(pos) | |||
| event = self._event_cache[cur_idx] | |||
| if event is not None: | |||
| if not cur_flag or (cur_flag == self._cur_flag and cur_idx < self._next_idx) or \ | |||
| (cur_flag == self._prev_flag and cur_idx >= self._next_idx): | |||
| return event | |||
| return None | |||
| def clean(self): | |||
| """Clean event cache.""" | |||
| with self._lock: | |||
| self._prev_flag = str(uuid.uuid4()) | |||
| self._cur_flag = str(uuid.uuid4()) | |||
| self._next_idx = 0 | |||
| self._event_cache = [None] * self.max_limit | |||
| value = {'metadata': {'pos': '0'}} | |||
| self.clean_pending_requests(value) | |||
| log.debug("Clean event cache.") | |||
| def put(self, value): | |||
| """ | |||
| Put value into event_cache. | |||
| Args: | |||
| value (dict): The event to be put into cache. | |||
| """ | |||
| if not isinstance(value, dict): | |||
| log.error("Dict type required when put event message.") | |||
| raise DebuggerParamValueError("Dict type required when put event message.") | |||
| with self._lock: | |||
| log.debug("Put the %d-th message into queue. \n %d requests is waiting.", | |||
| self._next_idx, len(self._pending_requests)) | |||
| cur_pos = self._next_idx | |||
| # update next pos | |||
| self._next_idx += 1 | |||
| if self._next_idx >= self.max_limit: | |||
| self._next_idx = 0 | |||
| self._prev_flag = self._cur_flag | |||
| self._cur_flag = str(uuid.uuid4()) | |||
| # set next pos | |||
| if not value.get('metadata'): | |||
| value['metadata'] = {} | |||
| value['metadata']['pos'] = self.next_pos | |||
| self._event_cache[cur_pos] = value | |||
| # feed the value for pending requests | |||
| self.clean_pending_requests(value) | |||
| def clean_pending_requests(self, value): | |||
| """Clean pending requests.""" | |||
| for _, request in self._pending_requests.items(): | |||
| request.put(value) | |||
| self._pending_requests = {} | |||
| def get(self, filter_condition=None): | |||
| """ | |||
| Get the pos-th value from event_cache according to filter_condition. | |||
| Args: | |||
| filter_condition (str): The position of the event in cache, in 'flag:idx' format. Default: None. | |||
| Returns: | |||
| object, the pos-th event. | |||
| """ | |||
| flag, idx = self._parse_pos(filter_condition) | |||
| cur_id = str(uuid.uuid4()) | |||
| with self._lock: | |||
| # reset the pos after the cache is re-initialized. | |||
| if not flag or flag not in [self._cur_flag, self._prev_flag]: | |||
| idx = 0 | |||
| # get event from cache immediately | |||
| if idx != self._next_idx and self._event_cache[idx]: | |||
| return self._event_cache[idx] | |||
| # wait for the event | |||
| cur_queue = Queue(maxsize=1) | |||
| self._pending_requests[cur_id] = cur_queue | |||
| # block until event has been received | |||
| event = self._wait_for_event(cur_id, cur_queue, filter_condition) | |||
| return event | |||
| def _parse_pos(self, pos): | |||
| """Get next pos according to input position.""" | |||
| elements = pos.split(':') | |||
| try: | |||
| idx = int(elements[-1]) | |||
| except ValueError: | |||
| log.error("Invalid index. The index in pos should be digit but get pos:%s", pos) | |||
| raise DebuggerParamValueError("Invalid pos.") | |||
| if idx < 0 or idx >= self.max_limit: | |||
| log.error("Invalid index. The index in pos should between [0, %d)", self.max_limit) | |||
| raise DebuggerParamValueError(f"Invalid pos. {idx}") | |||
| flag = elements[0] if len(elements) == 2 else '' | |||
| return flag, idx | |||
| def _wait_for_event(self, cur_id, cur_queue, pos): | |||
| """Wait for the pos-th event.""" | |||
| try: | |||
| # set the timeout to 25 seconds, which is less than the timeout limit from the UI | |||
| event = cur_queue.get(timeout=25) | |||
| except Empty: | |||
| event = None | |||
| if event is None: | |||
| with self._lock: | |||
| if self._pending_requests.get(cur_id): | |||
| self._pending_requests.pop(cur_id) | |||
| log.debug("Clean timeout request. Left pending requests: %d", | |||
| len(self._pending_requests)) | |||
| event = {'metadata': {'pos': pos}} | |||
| return event | |||
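| # A hedged, single-threaded sketch of the long-polling cursor protocol; in | |||
| # production the producer is the grpc server and the consumer is the UI. | |||
| from mindinsight.debugger.stream_handler import EventHandler | |||
| handler = EventHandler() | |||
| handler.put({'payload': 'step done'})  # producer side | |||
| pos = '0'                              # initial cursor from the client | |||
| event = handler.get(pos)               # cached event returned immediately | |||
| pos = event['metadata']['pos']         # new 'flag:idx' cursor to echo back | |||
| # the next handler.get(pos) blocks for up to 25s until another put() arrives | |||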
| @@ -0,0 +1,314 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define the graph stream handler.""" | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||
| DebuggerNodeNotInGraphError, DebuggerGraphNotExistError | |||
| from mindinsight.debugger.common.log import logger as log | |||
| from mindinsight.debugger.stream_cache.debugger_graph import DebuggerGraph | |||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||
| class GraphHandler(StreamHandlerBase): | |||
| """Metadata Handler.""" | |||
| def __init__(self): | |||
| self._graph_proto = None | |||
| self._graph = None | |||
| self._searched_node_list = [] | |||
| self.bfs_order = [] | |||
| @property | |||
| def graph(self): | |||
| """The property of graph.""" | |||
| return self._graph_proto | |||
| def put(self, value): | |||
| """ | |||
| Put value into graph cache. Called by grpc server. | |||
| Args: | |||
| value (GraphProto): The Graph proto message. | |||
| """ | |||
| self._graph_proto = value | |||
| log.info("Put graph into cache.") | |||
| # build graph | |||
| graph = DebuggerGraph() | |||
| graph.build_graph(value) | |||
| self._graph = graph | |||
| self.bfs_order = self._graph.get_bfs_order() | |||
| def get(self, filter_condition=None): | |||
| """ | |||
| Get the graph of specific node. | |||
| Args: | |||
| filter_condition (dict): | |||
| - name (str): The full debug node name. | |||
| - single_node (bool): If True, return the graph from root | |||
| to the specific node; else, return the sublayer of the | |||
| graph. Default: False. | |||
| Returns: | |||
| dict, the metadata. | |||
| """ | |||
| try: | |||
| self._graph_exists() | |||
| except DebuggerGraphNotExistError: | |||
| log.warning('The graph is empty. To view a graph, ' | |||
| 'please start the training script first.') | |||
| return {'graph': {}} | |||
| if filter_condition is None: | |||
| filter_condition = {} | |||
| single_node = filter_condition.get('single_node', False) | |||
| name = filter_condition.get('name') | |||
| graph = {} | |||
| if single_node is True: | |||
| nodes = self.get_single_node(name) | |||
| else: | |||
| nodes = self.list_nodes(name) | |||
| graph.update(nodes) | |||
| return {'graph': graph} | |||
| def get_tensor_history(self, node_name, depth=0): | |||
| """ | |||
| Get the tensor history of a specified node. | |||
| Args: | |||
| node_name (str): The debug name of the node. | |||
| depth (int): The number of layers the user | |||
| wants to trace. Default is 0. | |||
| Returns: | |||
| dict, basic tensor history, including only the tensor name, tensor type and node type. | |||
| """ | |||
| self._graph_exists() | |||
| if not self._graph.exist_node(node_name): | |||
| raise DebuggerNodeNotInGraphError(node_name) | |||
| tensor_history, cur_outputs_nums = self._graph.get_tensor_history( | |||
| node_name, depth | |||
| ) | |||
| # add the tensor type for tensor history | |||
| self._update_tensor_history(tensor_history[0:cur_outputs_nums], 'output') | |||
| self._update_tensor_history(tensor_history[cur_outputs_nums:], 'input') | |||
| log.debug("Get %d tensors in tensor history for node <%s>.", len(tensor_history), node_name) | |||
| return {'tensor_history': tensor_history} | |||
| @staticmethod | |||
| def _update_tensor_history(tensor_history, tensor_type): | |||
| """ | |||
| Add tensor source type for tensor history. | |||
| Args: | |||
| tensor_history (list[dict]): Tensor history from Graph stream. Each element has two | |||
| keys: `node_type` and `name`. `node_type` refers to the type of the node which | |||
| the tensor comes from. `name` refers to the tensor name. | |||
| tensor_type (str): The source type of the tensor. `input` or `output`. | |||
| """ | |||
| for single_tensor_info in tensor_history: | |||
| single_tensor_info['type'] = tensor_type | |||
| def search_nodes(self, pattern): | |||
| """ | |||
| Search nodes by given pattern. | |||
| Args: | |||
| pattern (Union[str, None]): The pattern of the node to search, | |||
| if None, return all node names. | |||
| Returns: | |||
| dict, the searched nodes. | |||
| """ | |||
| self._graph_exists() | |||
| self._searched_node_list = self._graph.search_nodes_by_pattern(pattern) | |||
| nodes = self._graph.get_nodes(self._searched_node_list) | |||
| return {'nodes': nodes} | |||
| def get_node_names(self, pattern=None): | |||
| """Get graph nodes according to pattern.""" | |||
| return self._graph.search_nodes_by_pattern(pattern) | |||
| def get_searched_node_list(self): | |||
| """Get searched node list.""" | |||
| return self._searched_node_list | |||
| def get_node_type(self, node_name): | |||
| """ | |||
| Get the type of the specified node. | |||
| Args: | |||
| node_name (str): The debug name of the node. | |||
| Returns: | |||
| A string of the node type, name_scope or leaf. | |||
| """ | |||
| self._graph_exists() | |||
| node_type = self._graph.get_node_type(node_name) | |||
| return node_type | |||
| def get_full_name(self, node_name): | |||
| """Get full name according to ui node name.""" | |||
| full_name = self._graph.get_full_name_by_node_name(node_name) if node_name else '' | |||
| return full_name | |||
| def get_node_name_by_full_name(self, full_name): | |||
| """Get UI node name by full name.""" | |||
| if self._graph: | |||
| node_name = self._graph.get_node_name_by_full_name(full_name) | |||
| else: | |||
| node_name = '' | |||
| log.info("No graph received yet.") | |||
| return node_name | |||
| def list_nodes(self, scope): | |||
| """ | |||
| Get the nodes of every layer in graph. | |||
| Args: | |||
| scope (str): The name of a scope. | |||
| Returns: | |||
| TypedDict('Nodes', {'nodes': list[Node]}), format is {'nodes': [<Node object>]}. | |||
| example: | |||
| { | |||
| "nodes" : [ | |||
| { | |||
| "attr" : | |||
| { | |||
| "index" : "i: 0\n" | |||
| }, | |||
| "input" : {}, | |||
| "name" : "input_tensor", | |||
| "output" : | |||
| { | |||
| "Default/TensorAdd-op17" : | |||
| { | |||
| "edge_type" : "data", | |||
| "scope" : "name_scope", | |||
| "shape" : [1, 16, 128, 128] | |||
| } | |||
| }, | |||
| "output_i" : -1, | |||
| "proxy_input" : {}, | |||
| "proxy_output" : {}, | |||
| "independent_layout" : False, | |||
| "subnode_count" : 0, | |||
| "type" : "Data" | |||
| } | |||
| ] | |||
| } | |||
| """ | |||
| if scope and not self._graph.exist_node(scope): | |||
| raise DebuggerNodeNotInGraphError(node_name=scope) | |||
| nodes = self._graph.list_node_by_scope(scope=scope) | |||
| return {'nodes': nodes} | |||
| def get_node_by_bfs_order(self, node_name=None, ascend=True): | |||
| """ | |||
| Traverse the graph in breadth-first search order starting from the given node. | |||
| Args: | |||
| node_name (str): The node name which will be regarded | |||
| as the start node in graph. | |||
| ascend (bool): If True, traverse the input nodes; | |||
| If False, traverse the output nodes. Default is True. | |||
| Returns: | |||
| dict, including the searched node and its tensor value. | |||
| """ | |||
| self._graph_exists() | |||
| bfs_order = self.bfs_order | |||
| length = len(bfs_order) | |||
| if not bfs_order: | |||
| log.error('Cannot get the BFS order of the graph!') | |||
| msg = 'Cannot get the BFS order of the graph!' | |||
| raise DebuggerParamValueError(msg) | |||
| if node_name is None: | |||
| if ascend is False: | |||
| next_node = bfs_order[-1] | |||
| else: | |||
| next_node = bfs_order[0] | |||
| else: | |||
| try: | |||
| index = bfs_order.index(node_name) | |||
| log.debug("The index of the node in BFS list is: %d", index) | |||
| except ValueError as err: | |||
| log.error('Cannot find the node: %s. Please check ' | |||
| 'the node name: %s', node_name, err) | |||
| msg = f'Cannot find the node: {node_name}. ' \ | |||
| f'Please check the node name {err}.' | |||
| raise DebuggerParamValueError(msg) | |||
| next_node = self.get_next_node_in_bfs(index, length, ascend) | |||
| return next_node | |||
| def get_next_node_in_bfs(self, index, length, ascend): | |||
| """ | |||
| Get the next node in bfs order. | |||
| Args: | |||
| index (int): The current index. | |||
| length (int): The number of all leaf nodes. | |||
| ascend (bool): Whether to get the node in ascending order or not. | |||
| Returns: | |||
| Union[None, dict], the next node object in dict type or None. | |||
| """ | |||
| next_node = None | |||
| if 0 <= index < length: | |||
| if ascend is True and index < length - 1: | |||
| next_node = self.bfs_order[index + 1] | |||
| elif ascend is False and index > 0: | |||
| next_node = self.bfs_order[index - 1] | |||
| return next_node | |||
| def get_single_node(self, name): | |||
| """ | |||
| Search for the node, and return the nodes of every layer up to it. | |||
| Args: | |||
| name (str): The name of node. | |||
| Returns: | |||
| dict, the nodes of every layer up to this node. | |||
| """ | |||
| nodes = self._graph.search_single_node(name) | |||
| return nodes | |||
| def _graph_exists(self): | |||
| """ | |||
| Check if the graph has been loaded in the debugger cache. | |||
| Raises: | |||
| DebuggerGraphNotExistError: If the graph does not exist. | |||
| """ | |||
| if self._graph is None: | |||
| log.error('The graph does not exist. Please start the ' | |||
| 'training script and try again.') | |||
| raise DebuggerGraphNotExistError | |||
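| # Usage sketch for get(); the scope name is hypothetical, and a GraphProto | |||
| # must have been put() by the grpc server first, otherwise get() logs a | |||
| # warning and returns an empty graph. | |||
| from mindinsight.debugger.stream_handler import GraphHandler | |||
| graph_stream = GraphHandler() | |||
| whole_graph = graph_stream.get()  # {'graph': {}} until a graph arrives | |||
| sub_layer = graph_stream.get({'name': 'Default/conv1'}) | |||
| single = graph_stream.get({'name': 'Default/conv1', 'single_node': True}) | |||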
| @@ -0,0 +1,131 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define the metadata stream handler.""" | |||
| from mindinsight.debugger.common.log import logger as log | |||
| from mindinsight.debugger.common.utils import ServerStatus | |||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||
| class MetadataHandler(StreamHandlerBase): | |||
| """Metadata Handler.""" | |||
| def __init__(self): | |||
| self._state = ServerStatus.PENDING | |||
| self._device_name = "" | |||
| self._step = 0 | |||
| self._client_ip = "" | |||
| self._cur_node_name = "" | |||
| self._cur_full_name = "" | |||
| self._backend = "" | |||
| @property | |||
| def device_name(self): | |||
| """The property of device name.""" | |||
| return self._device_name | |||
| @property | |||
| def step(self): | |||
| """The property of current step.""" | |||
| return self._step | |||
| @property | |||
| def node_name(self): | |||
| """The property of current node name.""" | |||
| return self._cur_node_name | |||
| @node_name.setter | |||
| def node_name(self, node_name): | |||
| """The property of current node name.""" | |||
| self._cur_node_name = node_name | |||
| @property | |||
| def full_name(self): | |||
| """The property of current node name.""" | |||
| return self._cur_full_name | |||
| @property | |||
| def backend(self): | |||
| """The property of current backend.""" | |||
| return self._backend | |||
| @property | |||
| def state(self): | |||
| """The property of state.""" | |||
| return self._state.value | |||
| @state.setter | |||
| def state(self, value): | |||
| """ | |||
| Set the property of state. | |||
| Args: | |||
| value (str): The new state. | |||
| """ | |||
| self._state = ServerStatus(value) | |||
| @property | |||
| def client_ip(self): | |||
| """The property of client ip.""" | |||
| return self._client_ip | |||
| @client_ip.setter | |||
| def client_ip(self, value): | |||
| """ | |||
| Set the property of client ip. | |||
| Args: | |||
| value (str): The new ip. | |||
| """ | |||
| self._client_ip = str(value) | |||
| def put(self, value): | |||
| """ | |||
| Put value into metadata cache. Called by grpc server. | |||
| Args: | |||
| value (MetadataProto): The Metadata proto message. | |||
| """ | |||
| self._device_name = value.device_name.split(':')[0] | |||
| self._step = value.cur_step | |||
| self._cur_full_name = value.cur_node | |||
| self._backend = value.backend if value.backend else "Ascend" | |||
| log.debug("Put metadata into cache at the %d-th step.", self._step) | |||
| def get(self, filter_condition=None): | |||
| """ | |||
| Get updated value. Called by main server. | |||
| Args: | |||
| filter_condition (str): The filter property. | |||
| Returns: | |||
| dict, the metadata. | |||
| """ | |||
| metadata = {} | |||
| if filter_condition is None: | |||
| metadata = { | |||
| 'state': self.state, | |||
| 'step': self.step, | |||
| 'device_name': self.device_name, | |||
| 'pos': '0', | |||
| 'ip': self.client_ip, | |||
| 'node_name': self.node_name, | |||
| 'backend': self.backend | |||
| } | |||
| else: | |||
| metadata[filter_condition] = getattr(self, filter_condition) if \ | |||
| hasattr(self, filter_condition) else '' | |||
| return {'metadata': metadata} | |||
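| # Usage sketch: get() without a filter returns the full metadata dict; with | |||
| # a filter it returns only that property, or '' if the property is unknown. | |||
| from mindinsight.debugger.stream_handler import MetadataHandler | |||
| meta = MetadataHandler() | |||
| everything = meta.get()        # {'metadata': {'state': ..., 'step': 0, ...}} | |||
| only_step = meta.get('step')   # {'metadata': {'step': 0}} | |||
| unknown = meta.get('no_such')  # {'metadata': {'no_such': ''}} | |||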
| @@ -0,0 +1,298 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define the tensor stream handler.""" | |||
| import numpy as np | |||
| from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||
| from mindinsight.debugger.common.log import logger as log | |||
| from mindinsight.debugger.proto.ms_graph_pb2 import DataType | |||
| from mindinsight.debugger.stream_cache.tensor import OpTensor, ConstTensor | |||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||
| from mindinsight.utils.tensor import TensorUtils | |||
| class TensorHandler(StreamHandlerBase): | |||
| """Metadata Handler.""" | |||
| def __init__(self): | |||
| self._const_vals = {} | |||
| self._tensors = {} | |||
| self._cur_step = 0 | |||
| def put(self, value): | |||
| """ | |||
| Put value into tensor cache. Called by grpc server. | |||
| Args: | |||
| value (dict): The Tensor proto message. | |||
| - step (int): The current step of tensor. | |||
| - tensor_protos (list[TensorProto]): The tensor proto. | |||
| """ | |||
| tensor_protos = value.get('tensor_protos') | |||
| merged_tensor = self._get_merged_tensor(tensor_protos) | |||
| step = value.get('step', 0) | |||
| if merged_tensor.iter and step > 0: | |||
| log.debug("Received previous tensor.") | |||
| step -= 1 | |||
| tensor = OpTensor(merged_tensor, step) | |||
| self._put_tensor_into_cache(tensor, step) | |||
| log.debug("Put tensor %s of step: %d, into cache", tensor.name, step) | |||
| @staticmethod | |||
| def _get_merged_tensor(tensor_protos): | |||
| """ | |||
| Merge a list of parsed tensor values into one. | |||
| Args: | |||
| tensor_protos (list[TensorProto]): List of tensor proto. | |||
| Returns: | |||
| TensorProto, merged tensor proto. | |||
| """ | |||
| merged_tensor = tensor_protos[-1] | |||
| if len(tensor_protos) > 1: | |||
| tensor_value = bytes() | |||
| for tensor_proto in tensor_protos: | |||
| if not tensor_proto.tensor_content: | |||
| log.warning("Doesn't find tensor value for %s:%s", | |||
| tensor_proto.node_name, tensor_proto.slot) | |||
| break | |||
| tensor_value += tensor_proto.tensor_content | |||
| merged_tensor.tensor_content = tensor_value | |||
| log.debug("Merge multi tensor values into one.") | |||
| return merged_tensor | |||
| def _put_tensor_into_cache(self, tensor, step): | |||
| """ | |||
| Put tensor into cache. | |||
| Args: | |||
| tensor (OpTensor): The tensor value. | |||
| step (int): The step of the tensor. | |||
| """ | |||
| cache_tensor = self._tensors.get(tensor.name) | |||
| if cache_tensor is None: | |||
| cache_tensor = {} | |||
| self._tensors[tensor.name] = cache_tensor | |||
| cache_tensor[step] = tensor | |||
| def put_const_vals(self, const_vals): | |||
| """ | |||
| Put const value into tensor cache. | |||
| Args: | |||
| const_vals (list[NamedValueProto]): List of const values. | |||
| """ | |||
| for const_val in const_vals: | |||
| if not (const_val.value and const_val.key): | |||
| continue | |||
| if DataType.Name(const_val.value.dtype) == "DT_TENSOR": | |||
| tensor_proto = const_val.value.tensor_val | |||
| tensor_proto.node_name = const_val.key | |||
| tensor_proto.slot = '0' | |||
| const_tensor = OpTensor(tensor_proto) | |||
| else: | |||
| const_tensor = ConstTensor(const_val) | |||
| self._const_vals[const_tensor.name] = const_tensor | |||
| def get(self, filter_condition=None): | |||
| """ | |||
| Get full tensor value. | |||
| Args: | |||
| filter_condition (dict): Filter condition. | |||
| - name (str): The name of tensor. | |||
| - node_type (str): The type of the node. | |||
| Returns: | |||
| dict, the tensor_value. | |||
| """ | |||
| name = filter_condition.get('name') | |||
| node_type = filter_condition.get('node_type') | |||
| shape = filter_condition.get('shape') | |||
| tensor = self._get_tensor(name, node_type) | |||
| if not tensor: | |||
| log.error("No tensor named %s", name) | |||
| raise DebuggerParamValueError("No tensor named {}".format(name)) | |||
| tensor_info = tensor.get_full_info(shape) | |||
| self._update_has_prev_step_field(tensor_info, name, node_type) | |||
| return {'tensor_value': tensor_info} | |||
| def _get_tensor(self, tensor_name, node_type=None, step=None): | |||
| """ | |||
| Get tensor according to tensor name and node_type. | |||
| Args: | |||
| tensor_name (str): Tensor name, format like `node_name:slot`. | |||
| node_type (str): Node type. | |||
| step (int): The step of the tensor. Default: None, which means the current step. | |||
| Returns: | |||
| Union[OPTensor, ConstTensor], the tensor object. | |||
| """ | |||
| if step is None: | |||
| step = self._cur_step | |||
| tensor = self._tensors.get(tensor_name, {}).get(step) | |||
| if not tensor and node_type == NodeTypeEnum.CONST.value: | |||
| const_name = tensor_name.rsplit('/', 1)[-1] | |||
| tensor = self._const_vals.get(const_name) | |||
| self._tensors[tensor_name] = {step: tensor} | |||
| return tensor | |||
| def _get_basic_info(self, tensor_name, node_type=None): | |||
| """Get the latest basic tensor info by tensor name.""" | |||
| tensor = self._get_tensor(tensor_name, node_type) | |||
| if tensor: | |||
| return tensor.get_basic_info() | |||
| return None | |||
| def get_tensor_history(self, tensor_names): | |||
| """Get tensor history for tensor names.""" | |||
| # only used by grpc server, could be removed later | |||
| tensor_infos = [] | |||
| for tensor_name in tensor_names: | |||
| tensor_info = self._get_basic_info(tensor_name) | |||
| tensor_infos.append(tensor_info) | |||
| return {'tensor_history': tensor_infos} | |||
| def update_tensor_history(self, tensor_history): | |||
| """ | |||
| Add tensor basic info in tensor_history. | |||
| Args: | |||
| tensor_history (dict): Tensor history, including a list of tensor name and type. | |||
| Returns: | |||
| list[dict], the list of tensor basic info cache. | |||
| """ | |||
| missed_tensors = [] | |||
| for tensor_info in tensor_history.get('tensor_history'): | |||
| tensor_name = tensor_info.get('full_name') | |||
| node_type = tensor_info.get('node_type') | |||
| basic_info = self._get_basic_info(tensor_name, node_type) | |||
| flag = self._update_has_prev_step_field(basic_info, tensor_name, node_type) | |||
| if flag is False: | |||
| missed_tensor = tensor_info.copy() | |||
| missed_tensor['iter'] = 'prev' | |||
| missed_tensors.append(missed_tensor) | |||
| log.debug("Add previous view cmd for %s", tensor_name) | |||
| # add `has_prev_step` field to tensor basic info. | |||
| if basic_info: | |||
| tensor_info.update(basic_info) | |||
| else: | |||
| missed_tensors.append(tensor_info) | |||
| log.debug("Add view cmd for %s", tensor_name) | |||
| return missed_tensors | |||
| def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type): | |||
| """Update has_prev_step field in tensor info.""" | |||
| flag = None | |||
| if node_type == NodeTypeEnum.PARAMETER.value: | |||
| flag = self._has_prev_tensor_value(tensor_name) | |||
| if flag and tensor_info: | |||
| tensor_info['has_prev_step'] = True | |||
| return flag | |||
| def _has_prev_tensor_value(self, tensor_name): | |||
| """ | |||
| Check if the tensor has valid value of previous step. | |||
| Args: | |||
| tensor_name (str): Tensor name. | |||
| Returns: | |||
| bool, whether the tensor has valid tensor value. | |||
| """ | |||
| flag = None | |||
| # check if the tensor has previous step value. | |||
| prev_step = self._cur_step - 1 | |||
| if prev_step < 0: | |||
| return flag | |||
| tensor = self._get_tensor(tensor_name, step=prev_step) | |||
| flag = bool(tensor and tensor.value) | |||
| return flag | |||
| def get_tensor_value_by_name(self, tensor_name, prev=False): | |||
| """Get tensor value by name in numpy type.""" | |||
| cur_step = self._cur_step | |||
| step = cur_step - 1 if prev else cur_step | |||
| if step < 0: | |||
| log.warning("%d step has no previous value for tensor: %s", cur_step, tensor_name) | |||
| return None | |||
| tensor = self._get_tensor(tensor_name, step=step) | |||
| return tensor | |||
| def clean_tensors(self, cur_step): | |||
| """Clean the tensor cache.""" | |||
| self._cur_step = cur_step | |||
| expired_tensor = [] | |||
| for tensor_name, tensor in self._tensors.items(): | |||
| expired_step = [step for step in tensor.keys() if step <= cur_step - 2] | |||
| for step in expired_step: | |||
| tensor.pop(step) | |||
| if not tensor: | |||
| expired_tensor.append(tensor_name) | |||
| for tensor_name in expired_tensor: | |||
| self._tensors.pop(tensor_name) | |||
| def get_tensors_diff(self, tensor_name, shape, tolerance=0): | |||
| """ | |||
| Get tensor comparisons data for given name, detail, shape and tolerance. | |||
| Args: | |||
| tensor_name (str): The name of tensor for cache. | |||
| shape (tuple): Specify concrete dimensions of shape. | |||
| tolerance (float): Specify tolerance of difference between current step tensor and previous | |||
| step tensor. Default value is 0. It is a percentage. The boundary value is equal to | |||
| max(abs(min),abs(max)) * tolerance. The function of min and max is being used to | |||
| calculate the min value and max value of the result of the current step tensor subtract | |||
| the previous step tensor. If the absolute value of result is less than or equal to | |||
| boundary value, the result will set to be zero. | |||
| Raises: | |||
| DebuggerParamValueError: If getting the current step or previous step tensor failed. | |||
| Returns: | |||
| dict, the retrieved data. | |||
| """ | |||
| curr_tensor = self.get_tensor_value_by_name(tensor_name) | |||
| prev_tensor = self.get_tensor_value_by_name(tensor_name, prev=True) | |||
| if not (curr_tensor and prev_tensor): | |||
| log.error("Get current step and previous step for this tensor name %s failed.", tensor_name) | |||
| raise DebuggerParamValueError(f"Get current step and previous step for this tensor name " | |||
| f"{tensor_name} failed.") | |||
| curr_tensor_slice = curr_tensor.get_tensor_value_by_shape(shape) | |||
| prev_tensor_slice = prev_tensor.get_tensor_value_by_shape(shape) | |||
| tensor_info = curr_tensor.get_basic_info() | |||
| if isinstance(curr_tensor_slice, np.ndarray) and isinstance(prev_tensor_slice, np.ndarray): | |||
| diff_tensor = TensorUtils.calc_diff_between_two_tensor(curr_tensor_slice, prev_tensor_slice, tolerance) | |||
| result = np.stack([prev_tensor_slice, curr_tensor_slice, diff_tensor], axis=-1) | |||
| tensor_info['diff'] = result.tolist() | |||
| stats = TensorUtils.get_statistics_from_tensor(diff_tensor) | |||
| tensor_info['statistics'] = TensorUtils.get_statistics_dict(stats) | |||
| del tensor_info['has_prev_step'] | |||
| del tensor_info['value'] | |||
| reply = {'tensor_value': tensor_info} | |||
| return reply | |||
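| # A hedged flow sketch: the grpc server put()s tensor protos per step and | |||
| # the UI reads them back through get(); the tensor name here is | |||
| # hypothetical, so the lookup raises because nothing is cached yet. | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||
| from mindinsight.debugger.stream_handler import TensorHandler | |||
| tensors = TensorHandler() | |||
| try: | |||
|     reply = tensors.get({'name': 'Default/conv1:0', 'node_type': None}) | |||
| except DebuggerParamValueError: | |||
|     reply = None  # nothing cached under that name | |||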
| @@ -0,0 +1,333 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Define the watchpoint stream handler.""" | |||
| import numpy as np | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||
| DebuggerParamTypeError | |||
| from mindinsight.debugger.common.log import logger as log | |||
| from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD | |||
| from mindinsight.debugger.stream_cache.watchpoint import Watchpoint, WatchpointHit, \ | |||
| WATCHPOINT_CONDITION_MAPPING | |||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||
| class WatchpointHandler(StreamHandlerBase): | |||
| """watchpoint Handler.""" | |||
| def __init__(self): | |||
| self._watchpoints = {} | |||
| self._deleted_watchpoints = [] | |||
| self._updated_watchpoints = {} | |||
| self._latest_id = 0 | |||
| def put(self, value): | |||
| """ | |||
| Put Watchpoint into watchpoint handler. | |||
| Args: | |||
| value (Watchpoint): The watchpoint to be cached. | |||
| """ | |||
| new_id = value.watchpoint_id | |||
| self._watchpoints[new_id] = value | |||
| self._updated_watchpoints[new_id] = value | |||
| self._latest_id = new_id | |||
| log.debug("Put watchpoint %d into cache.", new_id) | |||
| def sync_set_cmd(self): | |||
| """Clean temp watchpoints.""" | |||
| self._deleted_watchpoints = [] | |||
| self._updated_watchpoints = {} | |||
| def get_watchpoint_by_id(self, watchpoint_id): | |||
| """Get watchpoint by watchpoint id.""" | |||
| watchpoint = self._watchpoints.get(watchpoint_id) | |||
| if not watchpoint: | |||
| log.error("Invalid watchpoint id %d", watchpoint_id) | |||
| raise DebuggerParamValueError("Invalid watchpoint id {}".format(watchpoint_id)) | |||
| return watchpoint | |||
| def get(self, filter_condition=False): | |||
| """ | |||
| Get the watchpoints. | |||
| Args: | |||
| filter_condition (bool): If True, get all watchpoints without nodes. If False, | |||
| get updated watchpoints in SetCMD proto format. Default: False. | |||
| Returns: | |||
| dict, the watchpoints. | |||
| """ | |||
| reply = [] | |||
| if not filter_condition: | |||
| # get watch condition list | |||
| for _, watchpoint in self._watchpoints.items(): | |||
| watchpoint_info = watchpoint.get_watch_condition_info() | |||
| reply.append(watchpoint_info) | |||
| else: | |||
| # get updated watchpoint list | |||
| for _, watchpoint in self._updated_watchpoints.items(): | |||
| set_cmd = watchpoint.get_set_cmd() | |||
| reply.append(set_cmd) | |||
| reply.extend(self._deleted_watchpoints) | |||
| log.debug("get the watch points with filter_condition:%s", filter_condition) | |||
| return {'watch_points': reply} | |||
| def set_watch_nodes(self, graph, watch_point_id): | |||
| """ | |||
| Set watch nodes for graph. | |||
| Args: | |||
| graph (dict): The graph with list of nodes. | |||
| watch_point_id (int): The id of watchpoint. | |||
| """ | |||
| if not (watch_point_id and graph): | |||
| return | |||
| self._validate_watchpoint_id(watch_point_id) | |||
| log.debug("add watch flags") | |||
| watchpoint = self._watchpoints.get(watch_point_id) | |||
| self._set_watch_status_recursively(graph, watchpoint) | |||
| def _set_watch_status_recursively(self, graph, watchpoint): | |||
| """Set watch status to graph.""" | |||
| if not isinstance(graph, dict): | |||
| log.warning("The graph is not dict.") | |||
| return | |||
| if graph.get('children'): | |||
| self._set_watch_status_recursively(graph.get('children'), watchpoint) | |||
| for node in graph.get('nodes', []): | |||
| if not isinstance(node, dict): | |||
| log.warning("The node is not dict.") | |||
| return | |||
| node_name = node.get('name') | |||
| if not node_name: | |||
| continue | |||
| flag = watchpoint.get_node_status(node_name, node.get('type'), node.get('full_name')) | |||
| node['watched'] = flag | |||
| if node.get('nodes'): | |||
| self._set_watch_status_recursively(node, watchpoint) | |||
| def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None): | |||
| """ | |||
| Create watchpoint. | |||
| Args: | |||
| watch_condition (dict): The watch condition. | |||
| - condition (str): Accept `INF` or `NAN`. | |||
| - param (list[float]): Not defined yet. | |||
| watch_nodes (list[NodeBasicInfo]): The list of node basic info. | |||
| watch_point_id (int): The id of watchpoint. | |||
| Returns: | |||
| int, the new id of watchpoint. | |||
| """ | |||
| validate_watch_condition(watch_condition) | |||
| new_id = self._latest_id + 1 | |||
| watchpoint = Watchpoint(new_id, watch_condition) | |||
| if watch_nodes: | |||
| watchpoint.add_nodes(watch_nodes) | |||
| elif watch_point_id: | |||
| self._validate_watchpoint_id(watch_point_id) | |||
| watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id)) | |||
| self.put(watchpoint) | |||
| return new_id | |||
| def update_watchpoint(self, watch_point_id, watch_nodes, watched=False): | |||
| """ | |||
| Update watchpoint. | |||
| Args: | |||
| watch_point_id (int): The id of watchpoint. | |||
| watch_nodes (list[str]): The list of node names. | |||
| watched (bool): The update operation on nodes. If False, remove nodes from watch nodes. | |||
| If True, add nodes to watch nodes. Default: False. | |||
| """ | |||
| self._validate_watchpoint_id(watch_point_id) | |||
| watchpoint = self._watchpoints.get(watch_point_id) | |||
| if watched: | |||
| watchpoint.add_nodes(watch_nodes) | |||
| else: | |||
| watchpoint.remove_nodes(watch_nodes) | |||
| self._updated_watchpoints[watch_point_id] = watchpoint | |||
| log.debug("Update watchpoint %d in cache.", watch_point_id) | |||
| def delete_watchpoint(self, watch_point_id): | |||
| """ | |||
| Delete watchpoint. | |||
| Args: | |||
| watch_point_id (int): The id of watchpoint. | |||
| """ | |||
| self._validate_watchpoint_id(watch_point_id) | |||
| self._watchpoints.pop(watch_point_id) | |||
| set_cmd = SetCMD() | |||
| set_cmd.id = watch_point_id | |||
| set_cmd.delete = True | |||
| self._deleted_watchpoints.append(set_cmd) | |||
| log.debug("Delete watchpoint %d in cache.", watch_point_id) | |||
| def _validate_watchpoint_id(self, watch_point_id): | |||
| """Validate watchpoint id.""" | |||
| if watch_point_id and watch_point_id not in self._watchpoints: | |||
| log.error("Invalid watchpoint id: %d.", watch_point_id) | |||
| raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id)) | |||
| class WatchpointHitHandler(StreamHandlerBase): | |||
| """Watchpoint hit handler.""" | |||
| def __init__(self): | |||
| self._hits = {} | |||
| def put(self, value): | |||
| """ | |||
| Put value into watchpoint hit cache. Called by grpc server. | |||
| Args: | |||
| value (dict): The watchpoint hit info. | |||
| - tensor_proto (TensorProto): The message about hit tensor. | |||
| - watchpoint (Watchpoint): The Watchpoint that a node hit. | |||
| """ | |||
| watchpoint_hit = WatchpointHit( | |||
| tensor_proto=value.get('tensor_proto'), | |||
| watchpoint=value.get('watchpoint'), | |||
| node_name=value.get('node_name') | |||
| ) | |||
| node_name = value.get('node_name') | |||
| hit_tensors = self._hits.get(node_name) | |||
| if hit_tensors is None: | |||
| hit_tensors = [] | |||
| self._hits[node_name] = hit_tensors | |||
| if watchpoint_hit not in hit_tensors: | |||
| hit_tensors.append(watchpoint_hit) | |||
| def get(self, filter_condition=None): | |||
| """ | |||
| Get watchpoint hit list. | |||
| Args: | |||
| filter_condition (str): Get the watchpoint hit according to the specified node name. | |||
| If not given, get all watchpoint hits. Default: None. | |||
| Returns: | |||
| dict, the watchpoint hit list. | |||
| """ | |||
| if filter_condition is None: | |||
| log.debug("Get all watchpoint hit list.") | |||
| reply = self.get_watchpoint_hits() | |||
| else: | |||
| log.debug("Get the watchpoint for node: <%s>.", filter_condition) | |||
| reply = self._hits.get(filter_condition) | |||
| return reply | |||
| def get_watchpoint_hits(self): | |||
| """Return the list of watchpoint hits.""" | |||
| watch_point_hits = [] | |||
| for node_name, watchpoint_hits in self._hits.items(): | |||
| watch_points = [watchpoint_hit.watchpoint for watchpoint_hit in watchpoint_hits] | |||
| watch_point_hits.append({ | |||
| 'node_name': node_name, | |||
| 'watch_points': watch_points | |||
| }) | |||
| return {'watch_point_hits': watch_point_hits} | |||
| def _is_tensor_hit(self, tensor_name): | |||
| """Check if the tensor is record in hit cache.""" | |||
| node_name = tensor_name.split(':')[0] | |||
| watchpoint_hits = self.get(node_name) | |||
| if watchpoint_hits is None: | |||
| return False | |||
| for watchpoint_hit in watchpoint_hits: | |||
| if tensor_name == watchpoint_hit.tensor_name: | |||
| return True | |||
| return False | |||
| def update_tensor_history(self, tensor_history): | |||
| """ | |||
| Add hit flag to tensor history. | |||
| Args: | |||
| tensor_history (dict): The tensor history. | |||
| """ | |||
| if not self._hits: | |||
| return | |||
| # add hit tensor names to `tensor_names` | |||
| for tensor_info in tensor_history.get('tensor_history'): | |||
| tensor_name = tensor_info['full_name'] | |||
| hit_flag = self._is_tensor_hit(tensor_name) | |||
| tensor_info['is_hit'] = hit_flag | |||
| def validate_watch_condition(watch_condition): | |||
| """Validate watch condition.""" | |||
| if not isinstance(watch_condition, dict): | |||
| log.error("<watch_condition> should be dict. %s received.", watch_condition) | |||
| raise DebuggerParamTypeError("<watch_condition> should be dict.") | |||
| # validate condition | |||
| condition = watch_condition.get('condition') | |||
| if condition not in WATCHPOINT_CONDITION_MAPPING.keys(): | |||
| log.error("Invalid watch condition. Acceptable values are <%s>.", | |||
| str(WATCHPOINT_CONDITION_MAPPING.keys())) | |||
| raise DebuggerParamValueError("Invalid watch condition value.") | |||
| # validate param | |||
| validate_watch_condition_params(watch_condition) | |||
| def validate_watch_condition_params(watch_condition): | |||
| """ | |||
| Validate watch condition parameters. | |||
| Args: | |||
| watch_condition (dict): Watch condition. | |||
| - condition (str): Condition type. Should be in WATCHPOINT_CONDITION_MAPPING. | |||
| - param (list): Condition value. Should be given for comparison condition. The value will | |||
| be translated to np.float32. | |||
| """ | |||
| condition = watch_condition.get('condition') | |||
| param = watch_condition.get('param') | |||
| if condition in ['NAN', 'INF', 'OVERFLOW']: | |||
| if param: | |||
| log.error("No param is expected for %s condition.", condition) | |||
| raise DebuggerParamValueError("No param is expected.") | |||
| else: | |||
| if not isinstance(param, (float, int)): | |||
| log.error("Number param should be given for condition <%s>.", | |||
| condition) | |||
| raise DebuggerParamValueError("Number param should be given.") | |||
| if np.isinf(np.float32(param)): | |||
| log.error("Condition param should be float32.") | |||
| raise DebuggerParamValueError("The value of condition param should be within float32.") | |||
| @@ -14,19 +14,28 @@ | |||
| # ============================================================================ | |||
| """Start mindinsight service.""" | |||
| import argparse | |||
| import os | |||
| import sys | |||
| import re | |||
| from importlib import import_module | |||
| import psutil | |||
| from mindinsight.conf import settings | |||
| from mindinsight.utils.command import BaseCommand | |||
| from mindinsight.utils.exceptions import PortNotAvailableError | |||
| from mindinsight.utils.hook import HookUtils | |||
| from mindinsight.utils.hook import init | |||
| def str2bool(string): | |||
| """Convert str to bool""" | |||
| if string.lower() == 'false': | |||
| return False | |||
| if string.lower() == 'true': | |||
| return True | |||
| raise ValueError("The string should be 'true' or 'false' (case-insensitive).") | |||
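| # A minimal sketch of how str2bool plugs into argparse; it mirrors the | |||
| # --enable_debugger flag registered further below. | |||
| demo_parser = argparse.ArgumentParser() | |||
| demo_parser.add_argument('--enable_debugger', type=str2bool, default=False) | |||
| demo_args = demo_parser.parse_args(['--enable_debugger', 'True']) | |||
| assert demo_args.enable_debugger is True  # any casing of true/false works | |||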
| class ConfigAction(argparse.Action): | |||
| @@ -146,6 +155,23 @@ class UrlPathPrefixAction(argparse.Action): | |||
| setattr(namespace, self.dest, prefix) | |||
| class EnableDebuggerAction(argparse.Action): | |||
| """SSL certificate action class definition.""" | |||
| def __call__(self, parser, namespace, values, option_string=None): | |||
| """ | |||
| Inherited __call__ method from argparse.Action. | |||
| Args: | |||
| parser (ArgumentParser): Passed-in argument parser. | |||
| namespace (Namespace): Namespace object to hold arguments. | |||
| values (object): Argument values with type depending on argument definition. | |||
| option_string (str): Optional string for specific argument name. Default: None. | |||
| """ | |||
| enable_debugger = values | |||
| setattr(namespace, self.dest, enable_debugger) | |||
| class Command(BaseCommand): | |||
| """ | |||
| Start mindinsight service. | |||
| @@ -186,6 +212,14 @@ class Command(BaseCommand): | |||
| Custom port ranging from %s to %s. Default value is %s. | |||
| """ % (PortAction.MIN_PORT, PortAction.MAX_PORT, settings.PORT)) | |||
| parser.add_argument( | |||
| '--debugger_port', | |||
| type=int, | |||
| action=PortAction, | |||
| help=""" | |||
| Debugger port ranging from %s to %s. Default value is %s. | |||
| """ % (PortAction.MIN_PORT, PortAction.MAX_PORT, settings.DEBUGGER_PORT)) | |||
| parser.add_argument( | |||
| '--url-path-prefix', | |||
| type=str, | |||
| @@ -197,6 +231,14 @@ class Command(BaseCommand): | |||
| dot or double dots. Default value is ''. | |||
| """) | |||
| parser.add_argument( | |||
| '--enable_debugger', | |||
| type=str2bool, | |||
| action=EnableDebuggerAction, | |||
| default=False, | |||
| help=""" | |||
| Enable debugger or not. | |||
| Dfault is False.""") | |||
| for hook in HookUtils.instance().hooks(): | |||
| hook.register_startup_arguments(parser) | |||
| @@ -33,6 +33,7 @@ class MindInsightModules(Enum): | |||
| SCRIPTCONVERTER = 7 | |||
| WIZARD = 9 | |||
| OPTIMIZER = 10 | |||
| DEBUGGER = 11 | |||
| class GeneralErrors(Enum): | |||
| @@ -56,6 +57,10 @@ class LineageMgrErrors(Enum): | |||
| """Enum definition for lineage errors.""" | |||
| class DebuggerErrors(Enum): | |||
| """Enum definition for debugger errors.""" | |||
| class DataVisualErrors(Enum): | |||
| """Enum definition for datavisual errors.""" | |||
| RESTFUL_API_NOT_EXIST = 1 | |||
| @@ -0,0 +1,298 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Tensor utils.""" | |||
| import numpy as np | |||
| from mindinsight.datavisual.utils.tools import to_int | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from mindinsight.utils.exceptions import ParamTypeError | |||
| from mindinsight.utils.log import utils_logger as logger | |||
| F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max | |||
| class Statistics: | |||
| """Statistics data class. | |||
| Args: | |||
| max_value (float): max value of tensor data. | |||
| min_value (float): min value of tensor data. | |||
| avg_value (float): avg value of tensor data. | |||
| count (int): total count of tensor data. | |||
| nan_count (int): count of NAN. | |||
| neg_inf_count (int): count of negative INF. | |||
| pos_inf_count (int): count of positive INF. | |||
| """ | |||
| def __init__(self, max_value=0, min_value=0, avg_value=0, | |||
| count=0, nan_count=0, neg_inf_count=0, pos_inf_count=0): | |||
| self._max = max_value | |||
| self._min = min_value | |||
| self._avg = avg_value | |||
| self._count = count | |||
| self._nan_count = nan_count | |||
| self._neg_inf_count = neg_inf_count | |||
| self._pos_inf_count = pos_inf_count | |||
| @property | |||
| def max(self): | |||
| """Get max value of tensor.""" | |||
| return self._max | |||
| @property | |||
| def min(self): | |||
| """Get min value of tensor.""" | |||
| return self._min | |||
| @property | |||
| def avg(self): | |||
| """Get avg value of tensor.""" | |||
| return self._avg | |||
| @property | |||
| def count(self): | |||
| """Get total count of tensor.""" | |||
| return self._count | |||
| @property | |||
| def nan_count(self): | |||
| """Get count of NAN.""" | |||
| return self._nan_count | |||
| @property | |||
| def neg_inf_count(self): | |||
| """Get count of negative INF.""" | |||
| return self._neg_inf_count | |||
| @property | |||
| def pos_inf_count(self): | |||
| """Get count of positive INF.""" | |||
| return self._pos_inf_count | |||
| class TensorUtils: | |||
| """Tensor Utils class.""" | |||
| @staticmethod | |||
| def validate_dims_format(dims): | |||
| """ | |||
| Validate the format of the dimension parameter. | |||
| Args: | |||
| dims (str): Dims of tensor. Its format is something like this "[0, 0, :, :]". | |||
| Raises: | |||
| ParamValueError: If format of dims is not correct. | |||
| """ | |||
| if dims is not None: | |||
| if not isinstance(dims, str): | |||
| raise ParamTypeError(dims, str) | |||
| dims = dims.strip() | |||
| if not (dims.startswith('[') and dims.endswith(']')): | |||
| raise ParamValueError('The value: {} of dims must ' | |||
| 'start with `[` and end with `]`.'.format(dims)) | |||
| for dim in dims[1:-1].split(','): | |||
| dim = dim.strip() | |||
| if dim == ":": | |||
| continue | |||
| if dim.startswith('-'): | |||
| dim = dim[1:] | |||
| if not dim.isdigit(): | |||
| raise ParamValueError('The value: {} of dims in the square brackets ' | |||
| 'must be int or `:`.'.format(dims)) | |||
| @staticmethod | |||
| def convert_array_from_str_dims(dims, limit=0): | |||
| """ | |||
| Convert string of dims data to array. | |||
| Args: | |||
| dims (str): Specify dims of tensor. | |||
| limit (int): The max flexible dimension count, default value is 0 which means that there is no limitation. | |||
| Returns: | |||
| list, a string like this: "[0, 0, :, :]" will convert to this value: [0, 0, None, None]. | |||
| Raises: | |||
| ParamValueError, If flexible dimensions exceed limit value. | |||
| """ | |||
| dims = dims.strip().lstrip('[').rstrip(']') | |||
| dims_list = [] | |||
| count = 0 | |||
| for dim in dims.split(','): | |||
| dim = dim.strip() | |||
| if dim == ':': | |||
| dims_list.append(None) | |||
| count += 1 | |||
| else: | |||
| dims_list.append(to_int(dim, "dim")) | |||
| if limit and count > limit: | |||
| raise ParamValueError("Flexible dimensions cannot exceed limit value: {}, size: {}" | |||
| .format(limit, count)) | |||
| return dims_list | |||
| @staticmethod | |||
| def get_specific_dims_data(ndarray, dims, tensor_dims): | |||
| """ | |||
| Get specific dims data. | |||
| Args: | |||
| ndarray (numpy.ndarray): An ndarray of numpy. | |||
| dims (list): A list of specific dims. | |||
| tensor_dims (list): A list of tensor dims. | |||
| Returns: | |||
| numpy.ndarray, an ndarray of specific dims tensor data. | |||
| Raises: | |||
| ParamValueError, If the length of param dims is not equal to the length of tensor dims or | |||
| the index of param dims out of range. | |||
| """ | |||
| if len(dims) != len(tensor_dims): | |||
| raise ParamValueError("The length of param dims: {}, is not equal to the " | |||
| "length of tensor dims: {}.".format(len(dims), len(tensor_dims))) | |||
| indices = [] | |||
| for k, d in enumerate(dims): | |||
| if d is not None: | |||
| if d >= tensor_dims[k]: | |||
| raise ParamValueError("The index: {} of param dims out of range: {}.".format(d, tensor_dims[k])) | |||
| indices.append(d) | |||
| else: | |||
| indices.append(slice(0, tensor_dims[k])) | |||
| result = ndarray[tuple(indices)] | |||
| # Make sure the return type is numpy.ndarray. | |||
| if not isinstance(result, np.ndarray): | |||
| result = np.array(result) | |||
| return result | |||
| @staticmethod | |||
| def get_statistics_from_tensor(tensors): | |||
| """ | |||
| Calculates statistics data of tensor. | |||
| Args: | |||
| tensors (numpy.ndarray): A numpy.ndarray of tensor data. | |||
| Returns: | |||
| Statistics, the statistics data of the tensor. | |||
| """ | |||
| ma_value = np.ma.masked_invalid(tensors) | |||
| total, valid = tensors.size, ma_value.count() | |||
| invalids = [] | |||
| for isfn in np.isnan, np.isposinf, np.isneginf: | |||
| if total - valid > sum(invalids): | |||
| count = np.count_nonzero(isfn(tensors)) | |||
| invalids.append(count) | |||
| else: | |||
| invalids.append(0) | |||
| nan_count, pos_inf_count, neg_inf_count = invalids | |||
| if not valid: | |||
| logger.warning('There are no valid values in the tensors(size=%d, shape=%s)', total, tensors.shape) | |||
| statistics = Statistics(max_value=0, | |||
| min_value=0, | |||
| avg_value=0, | |||
| count=total, | |||
| nan_count=nan_count, | |||
| neg_inf_count=neg_inf_count, | |||
| pos_inf_count=pos_inf_count) | |||
| return statistics | |||
| # BUG: max of a masked array with dtype np.float16 returns inf | |||
| # See numpy issue#15077 | |||
| if issubclass(tensors.dtype.type, np.floating): | |||
| tensor_min = ma_value.min(fill_value=np.PINF) | |||
| tensor_max = ma_value.max(fill_value=np.NINF) | |||
| if tensor_min < F32_MIN or tensor_max > F32_MAX: | |||
| logger.warning('Values(%f, %f) are too large, you may encounter some undefined ' | |||
| 'behaviours hereafter.', tensor_min, tensor_max) | |||
| else: | |||
| tensor_min = ma_value.min() | |||
| tensor_max = ma_value.max() | |||
| tensor_sum = ma_value.sum(dtype=np.float64) | |||
| statistics = Statistics(max_value=tensor_max, | |||
| min_value=tensor_min, | |||
| avg_value=tensor_sum / valid, | |||
| count=total, | |||
| nan_count=nan_count, | |||
| neg_inf_count=neg_inf_count, | |||
| pos_inf_count=pos_inf_count) | |||
| return statistics | |||
| @staticmethod | |||
| def get_statistics_dict(stats): | |||
| """ | |||
| Get statistics dict according to statistics value. | |||
| Args: | |||
| stats (Statistics): An instance of Statistics. | |||
| Returns: | |||
| dict, a dict including 'max', 'min', 'avg', 'count', 'nan_count', 'neg_inf_count', 'pos_inf_count'. | |||
| """ | |||
| statistics = { | |||
| "max": float(stats.max), | |||
| "min": float(stats.min), | |||
| "avg": float(stats.avg), | |||
| "count": stats.count, | |||
| "nan_count": stats.nan_count, | |||
| "neg_inf_count": stats.neg_inf_count, | |||
| "pos_inf_count": stats.pos_inf_count | |||
| } | |||
| return statistics | |||
| @staticmethod | |||
| def calc_diff_between_two_tensor(first_tensor, second_tensor, tolerance): | |||
| """ | |||
| Calculate the difference between the first tensor and the second tensor. | |||
| Args: | |||
| first_tensor (numpy.ndarray): Specify the first tensor. | |||
| second_tensor (numpy.ndarray): Specify the second tensor. | |||
| tolerance (float): The tolerance of difference between the first tensor and the second tensor. | |||
| It is a scale factor: the boundary value equals max(abs(min), abs(max)) * tolerance, where | |||
| min and max are the minimum and maximum of the result of subtracting the second tensor | |||
| from the first tensor. If the absolute value of an element of that result is less than | |||
| or equal to the boundary value, the element is set to zero. | |||
| Returns: | |||
| numpy.ndarray, the result of subtracting the second tensor from the first tensor, with | |||
| elements whose absolute value is less than or equal to the boundary value set to zero. | |||
| Raises: | |||
| ParamTypeError: If either of the two tensors is not a numpy.ndarray. | |||
| ParamValueError: If the two tensors differ in shape or dtype. | |||
| """ | |||
| if not isinstance(first_tensor, np.ndarray): | |||
| raise ParamTypeError('first_tensor', np.ndarray) | |||
| if not isinstance(second_tensor, np.ndarray): | |||
| raise ParamTypeError('second_tensor', np.ndarray) | |||
| if first_tensor.shape != second_tensor.shape: | |||
| raise ParamValueError("the shape: {} of first tensor is not equal to shape: {} of second tensor." | |||
| .format(first_tensor.shape, second_tensor.shape)) | |||
| if first_tensor.dtype != second_tensor.dtype: | |||
| raise ParamValueError("the dtype: {} of first tensor is not equal to dtype: {} of second tensor." | |||
| .format(first_tensor.dtype, second_tensor.dtype)) | |||
| diff_tensor = np.subtract(first_tensor, second_tensor) | |||
| stats = TensorUtils.get_statistics_from_tensor(diff_tensor) | |||
| boundary_value = max(abs(stats.max), abs(stats.min)) * tolerance | |||
| is_close = np.isclose(first_tensor, second_tensor, atol=boundary_value, rtol=0) | |||
| result = np.multiply(diff_tensor, ~is_close) | |||
| return result | |||
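| # Illustrative usage sketch (reviewer note, not part of this change; input values are | |||
| # hypothetical): with tolerance=0.5 the boundary value is max(|min|, |max|) * 0.5 of the | |||
| # element-wise difference, so differences within that boundary are zeroed out. | |||
| #     >>> a = np.array([1.0, 2.0, 5.0]) | |||
| #     >>> b = np.array([0.8, 2.0, 1.0]) | |||
| #     >>> TensorUtils.calc_diff_between_two_tensor(a, b, 0.5)  # boundary = 4.0 * 0.5 = 2.0 | |||
| #     array([0., 0., 4.]) | |||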
| @@ -16,4 +16,5 @@ Werkzeug>=1.0.0 | |||
| tabulate>=0.8.6 | |||
| pandas>=1.0.4 | |||
| yapf>=0.30.0 | |||
| treelib>=1.6.1 | |||
| treelib>=1.6.1 | |||
| grpcio>=1.29.0 | |||
| @@ -14,9 +14,11 @@ | |||
| # ============================================================================ | |||
| """Test tensor container.""" | |||
| import unittest.mock as mock | |||
| import numpy as np | |||
| from mindinsight.datavisual.data_transform import tensor_container as tensor | |||
| from mindinsight.utils.tensor import TensorUtils | |||
| class TestTensorContainer: | |||
| @@ -34,8 +36,9 @@ class TestTensorContainer: | |||
| def test_get_statistics_from_tensor(self): | |||
| """Tests get statistics from tensor.""" | |||
| ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape([2, 2, 2]) | |||
| statistics = tensor.get_statistics_from_tensor(ndarray) | |||
| ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape( | |||
| [2, 2, 2]) | |||
| statistics = TensorUtils.get_statistics_from_tensor(ndarray) | |||
| assert (statistics.max, statistics.min, statistics.avg, statistics.count, | |||
| statistics.nan_count, statistics.neg_inf_count, statistics.pos_inf_count) == \ | |||
| (5, 1, 3, 8, | |||
| @@ -43,8 +46,9 @@ class TestTensorContainer: | |||
| def test_calc_original_buckets(self): | |||
| """Tests calculate original buckets.""" | |||
| ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape([2, 2, 2]) | |||
| statistics = tensor.get_statistics_from_tensor(ndarray) | |||
| ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape( | |||
| [2, 2, 2]) | |||
| statistics = TensorUtils.get_statistics_from_tensor(ndarray) | |||
| buckets = tensor.calc_original_buckets(ndarray, statistics) | |||
| assert (buckets[0].left, buckets[0].width, buckets[0].count) == (1, 2, 2) | |||
| @@ -29,10 +29,8 @@ from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||
| from mindinsight.datavisual.common.exceptions import TensorNotExistError | |||
| from mindinsight.datavisual.data_transform import data_manager | |||
| from mindinsight.datavisual.data_transform.tensor_container import calc_original_buckets | |||
| from mindinsight.datavisual.data_transform.tensor_container import get_statistics_from_tensor | |||
| from mindinsight.datavisual.processors.tensor_processor import TensorProcessor | |||
| from mindinsight.datavisual.processors.tensor_processor import get_specific_dims_data | |||
| from mindinsight.datavisual.processors.tensor_processor import get_statistics_dict | |||
| from mindinsight.utils.tensor import TensorUtils | |||
| from mindinsight.datavisual.utils import crc32 | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from mindinsight.utils.exceptions import ParamMissError | |||
| @@ -187,7 +185,7 @@ class TestTensorProcessor: | |||
| dims = expected_values.get('value').get("dims") | |||
| expected_data = np.array(expected_values.get('value').get("float_data")).reshape(dims) | |||
| recv_tensor = np.array(recv_values.get('value').get("data")) | |||
| expected_tensor = get_specific_dims_data(expected_data, [0, 0, None, None], dims) | |||
| expected_tensor = TensorUtils.get_specific_dims_data(expected_data, [0, 0, None, None], dims) | |||
| assert np.sum(np.isclose(recv_tensor, expected_tensor, rtol=1e-6) == 0) == 0 | |||
| @pytest.mark.usefixtures('load_tensor_record') | |||
| @@ -204,7 +202,8 @@ class TestTensorProcessor: | |||
| assert recv_values.get('wall_time') == expected_values.get('wall_time') | |||
| assert recv_values.get('step') == expected_values.get('step') | |||
| expected_data = expected_values.get('value').get("float_data") | |||
| expected_statistic = get_statistics_dict(get_statistics_from_tensor(expected_data)) | |||
| expected_statistic_instance = TensorUtils.get_statistics_from_tensor(expected_data) | |||
| expected_statistic = TensorUtils.get_statistics_dict(expected_statistic_instance) | |||
| recv_statistic = recv_values.get('value').get("statistics") | |||
| assert recv_statistic.get("max") - expected_statistic.get("max") < 1e-6 | |||
| assert recv_statistic.get("min") - expected_statistic.get("min") < 1e-6 | |||
| @@ -225,7 +224,7 @@ class TestTensorProcessor: | |||
| assert recv_values.get('wall_time') == expected_values.get('wall_time') | |||
| assert recv_values.get('step') == expected_values.get('step') | |||
| expected_data = expected_values.get('value').get("float_data") | |||
| expected_statistic = get_statistics_from_tensor(expected_data) | |||
| expected_statistic = TensorUtils.get_statistics_from_tensor(expected_data) | |||
| expected_buckets = calc_original_buckets(expected_data, expected_statistic) | |||
| recv_buckets = recv_values.get('value').get("histogram_buckets") | |||