You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

constant.py 4.1 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Constant definition."""
  16. from enum import Enum, unique
  17. import numpy as np
  18. SEPARATOR_IN_ONNX_OP = "::"
  19. SEPARATOR_IN_SCOPE = "/"
  20. SEPARATOR_BTW_NAME_AND_ID = "_"
  21. SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "="
  22. LINK_IN_SCOPE = "-"
  23. LINK_IN_WEIGHT_NAME = "."
  24. LEFT_BUCKET = "["
  25. RIGHT_BUCKET = "]"
  26. BLANK_SYM = " "
  27. FIRST_LEVEL_INDENT = BLANK_SYM * 4
  28. SECOND_LEVEL_INDENT = BLANK_SYM * 8
  29. NEW_LINE = "\n"
  30. ONNX_TYPE_INT = 2
  31. ONNX_TYPE_INTS = 7
  32. ONNX_TYPE_FLOAT = 1
  33. ONNX_TYPE_FLOATS = 6
  34. ONNX_TYPE_STRING = 3
  35. DYNAMIC_SHAPE = -1
  36. SCALAR_WITHOUT_SHAPE = 0
  37. UNKNOWN_DIM_VAL = "unk__001"
  38. ONNX_MIN_VER = "1.8.0"
  39. TF2ONNX_MIN_VER = "1.7.1"
  40. ONNXRUNTIME_MIN_VER = "1.5.2"
  41. ONNXOPTIMIZER_MIN_VER = "0.1.2"
  42. ONNXOPTIMIZER_MAX_VER = "0.1.2"
  43. DTYPE_MAP = {
  44. 1: np.float32,
  45. 2: np.uint8,
  46. 3: np.int8,
  47. 4: np.uint16,
  48. 5: np.int16,
  49. 6: np.int32,
  50. 7: np.int64,
  51. 8: str,
  52. 9: bool,
  53. 10: np.float16,
  54. 11: np.double,
  55. 12: np.uint32,
  56. 13: np.uint64,
  57. 14: np.complex64,
  58. 15: np.complex128,
  59. 16: None
  60. }
  61. @unique
  62. class TemplateKeywords(Enum):
  63. """Define keywords in template message."""
  64. INIT = "init"
  65. CONSTRUCT = "construct"
  66. @unique
  67. class ExchangeMessageKeywords(Enum):
  68. """Define keywords in exchange message."""
  69. METADATA = "metadata"
  70. @unique
  71. class MetadataScope(Enum):
  72. """Define metadata scope keywords in exchange message."""
  73. SOURCE = "source"
  74. OPERATION = "operation"
  75. INPUTS = "inputs"
  76. INPUTS_SHAPE = "inputs_shape"
  77. OUTPUTS = "outputs"
  78. OUTPUTS_SHAPE = "outputs_shape"
  79. PRECURSOR = "precursor_nodes"
  80. SUCCESSOR = "successor_nodes"
  81. ATTRS = "attributes"
  82. SCOPE = "scope"
  83. @unique
  84. class VariableScope(Enum):
  85. """Define variable scope keywords in exchange message."""
  86. OPERATION = "operation"
  87. VARIABLE_NAME = "variable_name"
  88. OUTPUT_TYPE = "output_type"
  89. TSR_TYPE = "tensor"
  90. ARR_TYPE = "array"
  91. INPUTS = "inputs"
  92. ARGS = "args"
  93. WEIGHTS = "weights"
  94. TRAINABLE_PARAMS = "trainable_params"
  95. PARAMETERS_DECLARED = "parameters"
  96. GROUP_INPUTS = "group_inputs"
  97. ONNX_MODEL_SUFFIX = "onnx"
  98. TENSORFLOW_MODEL_SUFFIX = "pb"
  99. BINARY_HEADER_PYTORCH_BITS = 32
  100. ARGUMENT_LENGTH_LIMIT = 128
  101. ARGUMENT_NUM_LIMIT = 32
  102. ARGUMENT_LEN_LIMIT = 64
  103. EXPECTED_NUMBER = 1
  104. MIN_SCOPE_LENGTH = 2
  105. ONNX_OPSET_VERSION = 11
  106. NO_CONVERTED_OPERATORS = [
  107. "onnx::Constant",
  108. "Constant"
  109. ]
  110. THIRD_PART_VERSION = {
  111. "onnx": (ONNX_MIN_VER,),
  112. "onnxruntime": (ONNXRUNTIME_MIN_VER,),
  113. "onnxoptimizer": (ONNXOPTIMIZER_MIN_VER,),
  114. "tf2onnx": (TF2ONNX_MIN_VER,)
  115. }
  116. @unique
  117. class NodeType(Enum):
  118. MODULE = "module"
  119. OPERATION = "operation"
  120. CLASS = "class"
  121. FUNC = "func"
  122. INPUTS = "DataInput"
  123. @unique
  124. class InputType(Enum):
  125. TENSOR = "tensor"
  126. LIST = "list"
  127. @unique
  128. class FrameworkType(Enum):
  129. ONNX = 0
  130. TENSORFLOW = 1
  131. UNKNOWN = 2
  132. @unique
  133. class WeightType(Enum):
  134. PARAMETER = 0
  135. COMMON = 1
  136. def get_imported_module():
  137. """
  138. Generate imported module header.
  139. Returns:
  140. str, imported module.
  141. """
  142. return f"import numpy as np{NEW_LINE}" \
  143. f"import mindspore{NEW_LINE}" \
  144. f"import mindspore.ops as P{NEW_LINE}" \
  145. f"from mindspore import nn{NEW_LINE}" \
  146. f"from mindspore import Tensor, Parameter{NEW_LINE * 3}"