You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Test debugger server utils."""
  16. import json
  17. import os
  18. import time
  19. import shutil
  20. import tempfile
  21. from mindinsight.debugger.proto import ms_graph_pb2
  22. from tests.st.func.debugger.conftest import DEBUGGER_EXPECTED_RESULTS, DEBUGGER_BASE_URL, GRAPH_PROTO_FILE
  23. from tests.utils.tools import compare_result_with_file, get_url
  24. def check_state(app_client, server_state='waiting'):
  25. """Check if the Server is ready."""
  26. url = 'retrieve'
  27. body_data = {'mode': 'all'}
  28. max_try_times = 30
  29. count = 0
  30. flag = False
  31. while count < max_try_times:
  32. res = get_request_result(app_client, url, body_data)
  33. state = res.get('metadata', {}).get('state')
  34. if state == server_state:
  35. flag = True
  36. break
  37. count += 1
  38. time.sleep(0.1)
  39. assert flag is True
  40. def get_request_result(app_client, url, body_data, method='post', expect_code=200, full_url=False):
  41. """Get request result."""
  42. if not full_url:
  43. real_url = os.path.join(DEBUGGER_BASE_URL, url)
  44. else:
  45. real_url = url
  46. if method == 'post':
  47. response = app_client.post(real_url, data=json.dumps(body_data))
  48. else:
  49. real_url = get_url(real_url, body_data)
  50. response = app_client.get(real_url)
  51. assert response.status_code == expect_code
  52. res = response.get_json()
  53. return res
  54. def send_and_compare_result(app_client, url, body_data, expect_file=None, method='post', full_url=False):
  55. """Send and compare result."""
  56. res = get_request_result(app_client, url, body_data, method=method, full_url=full_url)
  57. delete_random_items(res)
  58. if expect_file:
  59. real_path = os.path.join(DEBUGGER_EXPECTED_RESULTS, 'restful_results', expect_file)
  60. compare_result_with_file(res, real_path)
  61. def send_and_save_result(app_client, url, body_data, file_path, method='post'):
  62. """Send and save result."""
  63. res = get_request_result(app_client, url, body_data, method=method)
  64. delete_random_items(res)
  65. real_path = os.path.join(DEBUGGER_EXPECTED_RESULTS, 'restful_results', file_path)
  66. json.dump(res, open(real_path, 'w'))
  67. def delete_random_items(res):
  68. """delete the random items in metadata."""
  69. if isinstance(res, dict):
  70. if res.get('metadata'):
  71. if res['metadata'].get('ip'):
  72. res['metadata'].pop('ip')
  73. if res['metadata'].get('pos'):
  74. res['metadata'].pop('pos')
  75. if res['metadata'].get('debugger_version') and res['metadata']['debugger_version'].get('mi'):
  76. res['metadata']['debugger_version'].pop('mi')
  77. res['metadata']['debugger_version'].pop('ms')
  78. if res.get('devices'):
  79. for device in res.get('devices'):
  80. if device.get('server_ip'):
  81. device.pop('server_ip')
  82. def build_dump_file_structure():
  83. """Build the dump file structure."""
  84. async_file_structure = {
  85. "Ascend/async/device_0/Lenet_graph_1/1": 3,
  86. "Ascend/async/device_1/Lenet_graph_1/1": 3
  87. }
  88. sync_file_structure = {
  89. "Ascend/sync/Lenet/device_0": 4,
  90. "Ascend/sync/Lenet/device_1": 4,
  91. "GPU/sync/Lenet/device_0": 3,
  92. "GPU/sync/Lenet/device_1": 3
  93. }
  94. debugger_tmp_dir = tempfile.mkdtemp(suffix='debugger_tmp')
  95. dump_files_dir = os.path.join(debugger_tmp_dir, 'dump_files')
  96. shutil.copytree(os.path.join(os.path.dirname(__file__), 'dump_files'), dump_files_dir)
  97. for sub_dir, steps in async_file_structure.items():
  98. for step in range(1, steps + 1):
  99. os.makedirs(os.path.join(os.path.join(dump_files_dir, sub_dir), str(step)), exist_ok=True)
  100. for sub_dir, steps in sync_file_structure.items():
  101. for step in range(1, steps + 1):
  102. os.makedirs(os.path.join(os.path.join(dump_files_dir, sub_dir), 'iteration_' + str(step)),
  103. exist_ok=True)
  104. graph_dir_path = os.path.join(os.path.join(dump_files_dir, sub_dir), 'graphs')
  105. os.makedirs(graph_dir_path, exist_ok=True)
  106. graph_path = os.path.join(graph_dir_path, 'ms_output_trace_code_graph_0.pb')
  107. with open(GRAPH_PROTO_FILE, 'rb') as file_handler:
  108. content = file_handler.read()
  109. model = ms_graph_pb2.ModelProto()
  110. model.graph.ParseFromString(content)
  111. model_str = model.SerializeToString()
  112. with open(graph_path, 'wb') as file_handler:
  113. file_handler.write(model_str)
  114. return debugger_tmp_dir, dump_files_dir