You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

constant.py 4.2 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Constant definition."""
  16. from enum import Enum, unique
  17. import numpy as np
  18. SEPARATOR_IN_ONNX_OP = "::"
  19. SEPARATOR_IN_SCOPE = "/"
  20. SEPARATOR_BTW_NAME_AND_ID = "_"
  21. SEPARATOR_TITLE_AND_CONTENT_IN_CONSTRUCT = "="
  22. LINK_IN_SCOPE = "-"
  23. LINK_IN_WEIGHT_NAME = "."
  24. LEFT_BUCKET = "["
  25. RIGHT_BUCKET = "]"
  26. BLANK_SYM = " "
  27. FIRST_LEVEL_INDENT = BLANK_SYM * 4
  28. SECOND_LEVEL_INDENT = BLANK_SYM * 8
  29. NEW_LINE = "\n"
  30. ONNX_TYPE_INT = 2
  31. ONNX_TYPE_INTS = 7
  32. ONNX_TYPE_FLOAT = 1
  33. ONNX_TYPE_FLOATS = 6
  34. ONNX_TYPE_STRING = 3
  35. DYNAMIC_SHAPE = -1
  36. SCALAR_WITHOUT_SHAPE = 0
  37. UNKNOWN_DIM_VAL = "unk__001"
  38. ONNX_MIN_VER = "1.8.0"
  39. TF2ONNX_MIN_VER = "1.7.1"
  40. ONNXRUNTIME_MIN_VER = "1.5.2"
  41. ONNXOPTIMIZER_MIN_VER = "0.1.2"
  42. ONNXOPTIMIZER_MAX_VER = "0.1.2"
  43. CHECKPOINT_SEGMENT_SIZE = 2040109465 # 1.9GB, no more than 2GB
  44. DTYPE_MAP = {
  45. 1: np.float32,
  46. 2: np.uint8,
  47. 3: np.int8,
  48. 4: np.uint16,
  49. 5: np.int16,
  50. 6: np.int32,
  51. 7: np.int64,
  52. 8: str,
  53. 9: bool,
  54. 10: np.float16,
  55. 11: np.double,
  56. 12: np.uint32,
  57. 13: np.uint64,
  58. 14: np.complex64,
  59. 15: np.complex128,
  60. 16: None
  61. }
  62. @unique
  63. class TemplateKeywords(Enum):
  64. """Define keywords in template message."""
  65. INIT = "init"
  66. CONSTRUCT = "construct"
  67. @unique
  68. class ExchangeMessageKeywords(Enum):
  69. """Define keywords in exchange message."""
  70. METADATA = "metadata"
  71. @unique
  72. class MetadataScope(Enum):
  73. """Define metadata scope keywords in exchange message."""
  74. SOURCE = "source"
  75. OPERATION = "operation"
  76. INPUTS = "inputs"
  77. INPUTS_SHAPE = "inputs_shape"
  78. OUTPUTS = "outputs"
  79. OUTPUTS_SHAPE = "outputs_shape"
  80. PRECURSOR = "precursor_nodes"
  81. SUCCESSOR = "successor_nodes"
  82. ATTRS = "attributes"
  83. SCOPE = "scope"
  84. @unique
  85. class VariableScope(Enum):
  86. """Define variable scope keywords in exchange message."""
  87. OPERATION = "operation"
  88. VARIABLE_NAME = "variable_name"
  89. OUTPUT_TYPE = "output_type"
  90. TSR_TYPE = "tensor"
  91. ARR_TYPE = "array"
  92. INPUTS = "inputs"
  93. ARGS = "args"
  94. WEIGHTS = "weights"
  95. TRAINABLE_PARAMS = "trainable_params"
  96. PARAMETERS_DECLARED = "parameters"
  97. GROUP_INPUTS = "group_inputs"
  98. ONNX_MODEL_SUFFIX = "onnx"
  99. TENSORFLOW_MODEL_SUFFIX = "pb"
  100. BINARY_HEADER_PYTORCH_BITS = 32
  101. ARGUMENT_LENGTH_LIMIT = 128
  102. ARGUMENT_NUM_LIMIT = 32
  103. ARGUMENT_LEN_LIMIT = 64
  104. EXPECTED_NUMBER = 1
  105. MIN_SCOPE_LENGTH = 2
  106. ONNX_OPSET_VERSION = 11
  107. NO_CONVERTED_OPERATORS = [
  108. "onnx::Constant",
  109. "Constant"
  110. ]
  111. THIRD_PART_VERSION = {
  112. "onnx": (ONNX_MIN_VER,),
  113. "onnxruntime": (ONNXRUNTIME_MIN_VER,),
  114. "onnxoptimizer": (ONNXOPTIMIZER_MIN_VER,),
  115. "tf2onnx": (TF2ONNX_MIN_VER,)
  116. }
  117. @unique
  118. class NodeType(Enum):
  119. MODULE = "module"
  120. OPERATION = "operation"
  121. CLASS = "class"
  122. FUNC = "func"
  123. INPUTS = "DataInput"
  124. @unique
  125. class InputType(Enum):
  126. TENSOR = "tensor"
  127. LIST = "list"
  128. @unique
  129. class FrameworkType(Enum):
  130. ONNX = 0
  131. TENSORFLOW = 1
  132. UNKNOWN = 2
  133. @unique
  134. class WeightType(Enum):
  135. PARAMETER = 0
  136. COMMON = 1
  137. def get_imported_module():
  138. """
  139. Generate imported module header.
  140. Returns:
  141. str, imported module.
  142. """
  143. return f"import numpy as np{NEW_LINE}" \
  144. f"import mindspore{NEW_LINE}" \
  145. f"import mindspore.ops as P{NEW_LINE}" \
  146. f"from mindspore import nn{NEW_LINE}" \
  147. f"from mindspore import Tensor, Parameter{NEW_LINE * 3}"