You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define common utils."""
  16. import os
  17. import stat
  18. from importlib import import_module
  19. from typing import List, Tuple, Mapping
  20. from mindinsight.mindconverter.common.log import logger as log
  21. from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP
  22. def is_converted(operation: str):
  23. """
  24. Whether convert successful.
  25. Args:
  26. operation (str): Operation name.
  27. Returns:
  28. bool, true or false.
  29. """
  30. return operation and SEPARATOR_IN_ONNX_OP not in operation
  31. def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]):
  32. """
  33. Fetch specific nodes output from onnx model.
  34. Notes:
  35. Only support to get output without batch dimension.
  36. Args:
  37. model (ModelProto): ONNX model.
  38. feed_dict (dict): Feed forward inputs.
  39. output_nodes (list[str]): Output nodes list.
  40. Returns:
  41. dict, nodes' output value.
  42. """
  43. if not isinstance(feed_dict, dict) or not isinstance(output_nodes, list):
  44. raise TypeError("`feed_dict` should be type of dict, and `output_nodes` "
  45. "should be type of List[str].")
  46. ort = import_module("onnxruntime")
  47. input_nodes = list(feed_dict.keys())
  48. extractor = getattr(import_module("onnx.utils"), "Extractor")(model)
  49. extracted_model = extractor.extract_model(input_nodes, output_nodes)
  50. sess = ort.InferenceSession(path_or_bytes=bytes(extracted_model.SerializeToString()))
  51. fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict)
  52. run_result = dict()
  53. for idx, opt in enumerate(output_nodes):
  54. run_result[opt] = fetched_res[idx]
  55. return run_result
  56. def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
  57. out_folder: str, report_folder: str):
  58. """
  59. Save code file and report.
  60. Args:
  61. model_name (str): Model name.
  62. code_lines (dict): Code lines.
  63. out_folder (str): Output folder.
  64. report_folder (str): Report output folder.
  65. """
  66. flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
  67. modes = stat.S_IRUSR | stat.S_IWUSR
  68. modes_usr = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR
  69. out_folder = os.path.realpath(out_folder)
  70. if not report_folder:
  71. report_folder = out_folder
  72. else:
  73. report_folder = os.path.realpath(report_folder)
  74. if not os.path.exists(out_folder):
  75. os.makedirs(out_folder, modes_usr)
  76. if not os.path.exists(report_folder):
  77. os.makedirs(report_folder, modes_usr)
  78. for file_name in code_lines:
  79. code, report = code_lines[file_name]
  80. try:
  81. with os.fdopen(os.open(os.path.realpath(os.path.join(out_folder, f"{model_name}.py")),
  82. flags, modes), 'w') as file:
  83. file.write(code)
  84. with os.fdopen(os.open(os.path.realpath(os.path.join(report_folder,
  85. f"report_of_{model_name}.txt")),
  86. flags, stat.S_IRUSR), "w") as rpt_f:
  87. rpt_f.write(report)
  88. except IOError as error:
  89. log.error(str(error))
  90. log.exception(error)
  91. raise error
  92. def lib_version_satisfied(current_ver: str, mini_ver_limited: str,
  93. newest_ver_limited: str = ""):
  94. """
  95. Check python lib version whether is satisfied.
  96. Notes:
  97. Version number must be format of x.x.x, e.g. 1.1.0.
  98. Args:
  99. current_ver (str): Current lib version.
  100. mini_ver_limited (str): Mini lib version.
  101. newest_ver_limited (str): Newest lib version.
  102. Returns:
  103. bool, true or false.
  104. """
  105. required_version_number_len = 3
  106. if len(list(current_ver.split("."))) != required_version_number_len or \
  107. len(list(mini_ver_limited.split("."))) != required_version_number_len or \
  108. (newest_ver_limited and len(newest_ver_limited.split(".")) != required_version_number_len):
  109. raise ValueError("Version number must be format of x.x.x.")
  110. if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited):
  111. return False
  112. return True
  113. def get_dict_key_by_value(val, dic):
  114. """
  115. Return the first appeared key of a dictionary by given value.
  116. Args:
  117. val (Any): Value of the key.
  118. dic (dict): Dictionary to be checked.
  119. Returns:
  120. Any, key of the given value.
  121. """
  122. for d_key, d_val in dic.items():
  123. if d_val == val:
  124. return d_key
  125. return None