# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test debugger server utils."""
import json
import os
import time
import shutil
import tempfile

from mindinsight.debugger.proto import ms_graph_pb2
from tests.st.func.debugger.conftest import DEBUGGER_EXPECTED_RESULTS, DEBUGGER_BASE_URL, GRAPH_PROTO_FILE
from tests.utils.tools import compare_result_with_file, get_url


def check_state(app_client, server_state='waiting'):
    """
    Poll the 'retrieve' endpoint until the server reports the expected state.

    Args:
        app_client: Test client used to send the requests.
        server_state (str): The value expected in `metadata.state`. Default: 'waiting'.

    Raises:
        AssertionError: If the expected state is not observed within the retry budget.
    """
    url = 'retrieve'
    body_data = {'mode': 'all'}
    max_try_times = 30
    state = None
    ready = False
    for _ in range(max_try_times):
        res = get_request_result(app_client, url, body_data)
        state = res.get('metadata', {}).get('state')
        if state == server_state:
            ready = True
            break
        # Give the server a short moment before polling again.
        time.sleep(0.1)
    assert ready, 'Server state is %r, expected %r after %d tries.' % (state, server_state, max_try_times)


def get_request_result(app_client, url, body_data, method='post', expect_code=200, full_url=False):
    """
    Send a request through the test client and return the decoded JSON body.

    Args:
        app_client: Test client exposing `post` and `get`.
        url (str): Endpoint, relative to `DEBUGGER_BASE_URL` unless `full_url` is set.
        body_data (dict): Sent as the JSON body for 'post'; otherwise handed to `get_url`
            together with the url to build the final GET url.
        method (str): 'post' sends a POST request, any other value sends a GET request.
        expect_code (int): Status code the response must carry. Default: 200.
        full_url (bool): If True, `url` is used as is instead of being joined to the base url.

    Returns:
        The value of `response.get_json()`.

    Raises:
        AssertionError: If the response status code differs from `expect_code`.
    """
    # NOTE(review): os.path.join is used to build a URL here; this yields '/'-separated
    # results on POSIX only. Kept unchanged so existing callers behave the same.
    real_url = url if full_url else os.path.join(DEBUGGER_BASE_URL, url)
    if method == 'post':
        response = app_client.post(real_url, data=json.dumps(body_data))
    else:
        real_url = get_url(real_url, body_data)
        response = app_client.get(real_url)
    assert response.status_code == expect_code
    return response.get_json()


def send_and_compare_result(app_client, url, body_data, expect_file=None, method='post', full_url=False):
    """
    Send a request and, if requested, compare the cleaned response with an expected file.

    Args:
        app_client: Test client used to send the request.
        url (str): Endpoint, see `get_request_result`.
        body_data (dict): Request payload, see `get_request_result`.
        expect_file (str): File name under `<DEBUGGER_EXPECTED_RESULTS>/restful_results`.
            When falsy, the request is still sent but nothing is compared.
        method (str): Request method, see `get_request_result`.
        full_url (bool): Whether `url` is already complete.
    """
    res = get_request_result(app_client, url, body_data, method=method, full_url=full_url)
    # Strip the values that differ between runs before comparing.
    delete_random_items(res)
    if expect_file:
        real_path = os.path.join(DEBUGGER_EXPECTED_RESULTS, 'restful_results', expect_file)
        compare_result_with_file(res, real_path)


def send_and_save_result(app_client, url, body_data, file_path, method='post'):
    """
    Send a request and save the cleaned response as a JSON file.

    Args:
        app_client: Test client used to send the request.
        url (str): Endpoint, see `get_request_result`.
        body_data (dict): Request payload, see `get_request_result`.
        file_path (str): File name under `<DEBUGGER_EXPECTED_RESULTS>/restful_results`
            that will be (over)written.
        method (str): Request method, see `get_request_result`.
    """
    res = get_request_result(app_client, url, body_data, method=method)
    delete_random_items(res)
    real_path = os.path.join(DEBUGGER_EXPECTED_RESULTS, 'restful_results', file_path)
    # Use a context manager so the file is flushed and closed deterministically.
    with open(real_path, 'w') as file_handler:
        json.dump(res, file_handler)


def delete_random_items(res):
    """
    Delete the random items in metadata and devices, in place.

    Removes `ip` and `pos` from `res['metadata']`, `mi` and `ms` from
    `res['metadata']['debugger_version']`, and `server_ip` from every entry of
    `res['devices']`. A key is only removed when its current value is truthy
    (`ms` is removed whenever `mi` is). Non-dict input is left untouched.

    Args:
        res: The response object to clean; only dicts are processed.
    """
    if not isinstance(res, dict):
        return

    metadata = res.get('metadata')
    if metadata:
        for key in ('ip', 'pos'):
            if metadata.get(key):
                metadata.pop(key)
        version = metadata.get('debugger_version')
        if version and version.get('mi'):
            version.pop('mi')
            # 'ms' may be absent even when 'mi' is present; do not fail on that.
            version.pop('ms', None)

    for device in res.get('devices') or ():
        if device.get('server_ip'):
            device.pop('server_ip')


def build_dump_file_structure():
    """
    Build the dump file structure in a fresh temporary directory.

    Copies the `dump_files` directory located next to this file into a new temporary
    directory, creates the per-step directories for the async and sync dump layouts,
    and writes a serialized graph proto into the `graphs` directory of every sync dump.

    Returns:
        tuple[str, str], the temporary directory and the `dump_files` directory inside it.
        The temporary directory is not removed by this function.
    """
    async_file_structure = {
        "Ascend/async/device_0/Lenet_graph_1/1": 3,
        "Ascend/async/device_1/Lenet_graph_1/1": 3
    }
    sync_file_structure = {
        "Ascend/sync/Lenet/device_0": 4,
        "Ascend/sync/Lenet/device_1": 4,
        "GPU/sync/Lenet/device_0": 3,
        "GPU/sync/Lenet/device_1": 3
    }

    # The graph content is identical for every sync dump, so load and serialize it once.
    # Doing it before creating the temporary directory also avoids leaving a stray
    # directory behind when the proto file cannot be read or parsed.
    with open(GRAPH_PROTO_FILE, 'rb') as file_handler:
        content = file_handler.read()
    model = ms_graph_pb2.ModelProto()
    model.graph.ParseFromString(content)
    model_str = model.SerializeToString()

    debugger_tmp_dir = tempfile.mkdtemp(suffix='debugger_tmp')
    dump_files_dir = os.path.join(debugger_tmp_dir, 'dump_files')
    shutil.copytree(os.path.join(os.path.dirname(__file__), 'dump_files'), dump_files_dir)

    # Async dumps use plain step numbers as directory names.
    for sub_dir, steps in async_file_structure.items():
        for step in range(1, steps + 1):
            os.makedirs(os.path.join(dump_files_dir, sub_dir, str(step)), exist_ok=True)

    # Sync dumps use 'iteration_<step>' directories plus a 'graphs' directory per device.
    for sub_dir, steps in sync_file_structure.items():
        for step in range(1, steps + 1):
            os.makedirs(os.path.join(dump_files_dir, sub_dir, 'iteration_' + str(step)), exist_ok=True)
        graph_dir_path = os.path.join(dump_files_dir, sub_dir, 'graphs')
        os.makedirs(graph_dir_path, exist_ok=True)
        graph_path = os.path.join(graph_dir_path, 'ms_output_trace_code_graph_0.pb')
        with open(graph_path, 'wb') as file_handler:
            file_handler.write(model_str)

    return debugger_tmp_dir, dump_files_dir