You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

constant.py 4.3 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Constant definition."""
  16. from enum import Enum, unique
  17. import numpy as np
  18. SEPARATOR_IN_ONNX_OP = "::"
  19. SEPARATOR_IN_SCOPE = "/"
  20. SEPARATOR_BTW_NAME_AND_ID = "_"
  21. SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "="
  22. LINK_IN_SCOPE = "-"
  23. LINK_IN_WEIGHT_NAME = "."
  24. LEFT_BUCKET = "["
  25. RIGHT_BUCKET = "]"
  26. BLANK_SYM = " "
  27. FIRST_LEVEL_INDENT = BLANK_SYM * 4
  28. SECOND_LEVEL_INDENT = BLANK_SYM * 8
  29. NEW_LINE = "\n"
  30. ONNX_TYPE_INT = 2
  31. ONNX_TYPE_INTS = 7
  32. ONNX_TYPE_FLOAT = 1
  33. ONNX_TYPE_FLOATS = 6
  34. ONNX_TYPE_STRING = 3
  35. DYNAMIC_SHAPE = -1
  36. SCALAR_WITHOUT_SHAPE = 0
  37. UNKNOWN_DIM_VAL = "unk__001"
  38. ONNX_MIN_VER = "1.8.0"
  39. TF2ONNX_MIN_VER = "1.7.1"
  40. ONNXRUNTIME_MIN_VER = "1.5.2"
  41. ONNXOPTIMIZER_MIN_VER = "0.1.2"
  42. ONNXOPTIMIZER_MAX_VER = "0.1.2"
  43. TORCH_MIN_VER = "1.5.0"
  44. DTYPE_MAP = {
  45. 1: np.float32,
  46. 2: np.uint8,
  47. 3: np.int8,
  48. 4: np.uint16,
  49. 5: np.int16,
  50. 6: np.int32,
  51. 7: np.int64,
  52. 8: str,
  53. 9: bool,
  54. 10: np.float16,
  55. 11: np.double,
  56. 12: np.uint32,
  57. 13: np.uint64,
  58. 14: np.complex64,
  59. 15: np.complex128,
  60. 16: None
  61. }
  62. @unique
  63. class TemplateKeywords(Enum):
  64. """Define keywords in template message."""
  65. INIT = "init"
  66. CONSTRUCT = "construct"
  67. @unique
  68. class ExchangeMessageKeywords(Enum):
  69. """Define keywords in exchange message."""
  70. METADATA = "metadata"
  71. @unique
  72. class MetadataScope(Enum):
  73. """Define metadata scope keywords in exchange message."""
  74. SOURCE = "source"
  75. OPERATION = "operation"
  76. INPUTS = "inputs"
  77. INPUTS_SHAPE = "inputs_shape"
  78. OUTPUTS = "outputs"
  79. OUTPUTS_SHAPE = "outputs_shape"
  80. PRECURSOR = "precursor_nodes"
  81. SUCCESSOR = "successor_nodes"
  82. ATTRS = "attributes"
  83. SCOPE = "scope"
  84. @unique
  85. class VariableScope(Enum):
  86. """Define variable scope keywords in exchange message."""
  87. OPERATION = "operation"
  88. VARIABLE_NAME = "variable_name"
  89. OUTPUT_TYPE = "output_type"
  90. TSR_TYPE = "tensor"
  91. ARR_TYPE = "array"
  92. INPUTS = "inputs"
  93. ARGS = "args"
  94. WEIGHTS = "weights"
  95. TRAINABLE_PARAMS = "trainable_params"
  96. PARAMETERS_DECLARED = "parameters"
  97. GROUP_INPUTS = "group_inputs"
  98. BINARY_HEADER_PYTORCH_FILE = \
  99. b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00'
  100. TENSORFLOW_MODEL_SUFFIX = "pb"
  101. BINARY_HEADER_PYTORCH_BITS = 32
  102. ARGUMENT_LENGTH_LIMIT = 128
  103. ARGUMENT_NUM_LIMIT = 32
  104. ARGUMENT_LEN_LIMIT = 64
  105. EXPECTED_NUMBER = 1
  106. MIN_SCOPE_LENGTH = 2
  107. ONNX_OPSET_VERSION = 11
  108. MODEL_INPUT_NAME = 'input.1'
  109. NO_CONVERTED_OPERATORS = [
  110. "onnx::Constant",
  111. "Constant"
  112. ]
  113. THIRD_PART_VERSION = {
  114. "torch": (TORCH_MIN_VER,),
  115. "onnx": (ONNX_MIN_VER,),
  116. "onnxruntime": (ONNXRUNTIME_MIN_VER,),
  117. "onnxoptimizer": (ONNXOPTIMIZER_MIN_VER,),
  118. "tf2onnx": (TF2ONNX_MIN_VER,)
  119. }
  120. @unique
  121. class NodeType(Enum):
  122. MODULE = "module"
  123. OPERATION = "operation"
  124. CLASS = "class"
  125. FUNC = "func"
  126. INPUTS = "DataInput"
  127. @unique
  128. class InputType(Enum):
  129. TENSOR = "tensor"
  130. LIST = "list"
  131. @unique
  132. class FrameworkType(Enum):
  133. PYTORCH = 0
  134. TENSORFLOW = 1
  135. UNKNOWN = 2
  136. @unique
  137. class WeightType(Enum):
  138. PARAMETER = 0
  139. COMMON = 1
  140. def get_imported_module():
  141. """
  142. Generate imported module header.
  143. Returns:
  144. str, imported module.
  145. """
  146. return f"import numpy as np{NEW_LINE}" \
  147. f"import mindspore{NEW_LINE}" \
  148. f"import mindspore.ops as P{NEW_LINE}" \
  149. f"from mindspore import nn{NEW_LINE}" \
  150. f"from mindspore import Tensor, Parameter{NEW_LINE * 3}"