|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Common configurations for debugger unit testing."""
- import json
- import os
-
- from google.protobuf import json_format
-
- from mindinsight.debugger.proto import ms_graph_pb2
- from mindinsight.debugger.stream_handler.graph_handler import GraphHandler
- from mindinsight.debugger.stream_handler.watchpoint_handler import WatchpointHitHandler
- from tests.utils.tools import compare_result_with_file
-
- GRAPH_PROTO_FILE = os.path.join(
- os.path.dirname(__file__), '../../utils/resource/graph_pb/lenet.pb'
- )
- DEBUGGER_EXPECTED_RESULTS = os.path.join(os.path.dirname(__file__), 'expected_results')
-
-
- def get_graph_proto():
- """Get graph proto."""
- with open(GRAPH_PROTO_FILE, 'rb') as f:
- content = f.read()
-
- graph = ms_graph_pb2.GraphProto()
- graph.ParseFromString(content)
-
- return graph
-
-
- def init_graph_handler():
- """Init GraphHandler."""
- graph = get_graph_proto()
- graph_handler = GraphHandler()
- graph_handler.put({graph.name: graph})
-
- return graph_handler
-
-
- def init_watchpoint_hit_handler(value):
- """Init WatchpointHitHandler."""
- wph_handler = WatchpointHitHandler()
- wph_handler.put(value)
-
- return wph_handler
-
-
- def get_node_basic_infos(node_names):
- """Get node info according to node names."""
- if not node_names:
- return []
- graph_stream = init_graph_handler()
- graph_name = graph_stream.graph_names[0]
- node_infos = []
- for node_name in node_names:
- node_infos.append(graph_stream.get_node_basic_info(node_name, graph_name))
- return node_infos
-
-
- def get_watch_nodes_by_search(watch_nodes):
- """Get watched leaf nodes by search name."""
- watched_leaf_nodes = []
- graph_stream = init_graph_handler()
- graph_name = graph_stream.graph_names[0]
- for search_name in watch_nodes:
- search_node_info = graph_stream.get_node_basic_info_by_scope(search_name, graph_name)
- watched_leaf_nodes.extend(search_node_info)
-
- return watched_leaf_nodes
-
-
- def mock_tensor_proto():
- """Mock tensor proto."""
- tensor_dict = {
- "node_name":
- "Default/network-WithLossCell/_backbone-LeNet5/relu-ReLU/gradReLU/ReluGradV2-op92",
- "slot": "0"
- }
- tensor_proto = json_format.Parse(json.dumps(tensor_dict), ms_graph_pb2.TensorProto())
-
- return tensor_proto
-
-
- def mock_tensor_history():
- """Mock tensor history."""
- tensor_history = {
- "tensor_history": [
- {"name": "Default/TransData-op99:0",
- "full_name": "Default/TransData-op99:0",
- "node_type": "TransData",
- "type": "output",
- "step": 0,
- "dtype": "DT_FLOAT32",
- "shape": [2, 3],
- "has_prev_step": False,
- "value": "click to view"},
- {"name": "Default/args0:0",
- "full_name": "Default/args0:0",
- "node_type": "Parameter",
- "type": "input",
- "step": 0,
- "dtype": "DT_FLOAT32",
- "shape": [2, 3],
- "has_prev_step": False,
- "value": "click to view"}
- ],
- "metadata": {
- "state": "waiting",
- "step": 0,
- "device_name": "0",
- "pos": "0",
- "ip": "127.0.0.1:57492",
- "node_name": "",
- "backend": "Ascend"
- }
- }
-
- return tensor_history
-
-
- def compare_debugger_result_with_file(res, expect_file, save=False):
- """
- Compare debugger result with file.
-
- Args:
- res (dict): The debugger result in dict type.
- expect_file: The expected file name.
- """
- real_path = os.path.join(DEBUGGER_EXPECTED_RESULTS, expect_file)
- if save:
- with open(real_path, 'w') as file_handler:
- json.dump(res, file_handler)
- else:
- compare_result_with_file(res, real_path)
|