You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

condition.py 9.6 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
"""
Management of all conditions.

This module is used to register all conditions, as well as their parameters.
This module also provides the available conditions to the condition_collections API.
"""
  20. from enum import Enum
  21. from mindinsight.debugger.conditionmgr.log import logger
  22. class ParamNameEnum(Enum):
  23. """Param names."""
  24. ABS_MEAN_GT = "abs_mean_gt"
  25. ABS_MEAN_LT = "abs_mean_lt"
  26. ABS_MEAN_UPDATE_RATIO_GT = "abs_mean_update_ratio_gt"
  27. ABS_MEAN_UPDATE_RATIO_LT = "abs_mean_update_ratio_lt"
  28. ATOL = "atol"
  29. EQUAL_NAN = "equal_nan"
  30. EPSILON = "epsilon"
  31. MAX_GT = "max_gt"
  32. MAX_LT = "max_lt"
  33. MIN_GT = "min_gt"
  34. MIN_LT = "min_lt"
  35. MEAN_GT = "mean_gt"
  36. MEAN_LT = "mean_lt"
  37. MAX_MIN_GT = "max_min_gt"
  38. MAX_MIN_LT = "max_min_lt"
  39. PARAM = "param"
  40. RANGE_START_INCLUSIVE = "range_start_inclusive"
  41. RANGE_END_INCLUSIVE = "range_end_inclusive"
  42. RANGE_PERCENTAGE_GT = "range_percentage_gt"
  43. RANGE_PERCENTAGE_LT = "range_percentage_lt"
  44. RTOL = "rtol"
  45. ZERO_PERCENTAGE_GE = "zero_percentage_ge"
  46. class ConditionIdEnum(Enum):
  47. """Condition ids."""
  48. WEIGHT_INITIALIZATION = "weight_initialization"
  49. WEIGHT_OVERFLOW = "weight_overflow"
  50. WEIGHT_TOO_LARGE = "weight_too_large"
  51. WEIGHT_TOO_SMALL = "weight_too_small"
  52. GRADIENT_VANISHING = "gradient_vanishing"
  53. GRADIENT_TOO_LARGE = "gradient_too_large"
  54. GRADIENT_EXPLODING = "gradient_exploding"
  55. TENSOR_OVERFLOW = "tensor_overflow"
  56. OPERATOR_OVERFLOW = "operator_overflow"
  57. TENSOR_TOO_LARGE = "tensor_too_large"
  58. TENSOR_TOO_SMALL = "tensor_too_small"
  59. TENSOR_ALL_ZERO = "tensor_all_zero"
  60. WEIGHT_NOT_CHANGED = "weight_not_changed"
  61. WEIGHT_CHANGE_TOO_LARGE = "weight_change_too_large"
  62. WEIGHT_CHANGE_TOO_SMALL = "weight_change_too_small"
  63. ACTIVATION_RANGE = "activation_range"
  64. TENSOR_RANGE = "tensor_range"
  65. class OptimizePhaseEnum(Enum):
  66. """Optimize phases."""
  67. TENSOR_CHECK = 400
  68. OPERATOR_CHECK = 100
  69. LOSS_CHECK = 300
  70. INPUT_DATA_CHECK = 200
  71. class ValueTypeEnum(Enum):
  72. """Value types."""
  73. FLOAT64 = 1
  74. INT64 = 2
  75. BOOL = 3
  76. class PlatformEnum(Enum):
  77. """Platform types."""
  78. GPU = "GPU"
  79. ASCEND = "Ascend"
  80. class TargetTypeEnum(Enum):
  81. """Target types."""
  82. TENSOR = 'tensor'
  83. ACTIVATION = 'activation'
  84. GRADIENT = 'gradient'
  85. PARAMETER = 'parameter'
  86. WEIGHT = 'weight'
  87. class ParamTypeEnum(Enum):
  88. """Param types."""
  89. CHECK_PARAM = "CHECK_PARAM"
  90. SUPPORT_PARAM = "SUPPORT_PARAM"
  91. class ActivationFuncEnum(Enum):
  92. """Activation functions."""
  93. TANH = 'tanh'
  94. SIGMOID = 'sigmoid'
  95. RELU = 'relu'
  96. RELUV2 = 'reluv2'
  97. class ConditionContext:
  98. """
  99. The class for condition context.
  100. Args:
  101. backend (str): parameter name.
  102. step (int): the type of value.
  103. debugger_capability (tuple): whether the param support no assignment.
  104. """
  105. def __init__(self, backend, step=0, debugger_capability=(1, 1)):
  106. self._backend = backend
  107. self._step = step
  108. self._debugger_capability = debugger_capability
  109. @property
  110. def backend(self):
  111. """Get backend."""
  112. return self._backend
  113. @property
  114. def step(self):
  115. """Get _step."""
  116. return self._step
  117. @property
  118. def debugger_capability(self):
  119. """Get debugger_capability."""
  120. return self._debugger_capability
  121. class ConditionParameter:
  122. """
  123. The class for parameters of conditions.
  124. Args:
  125. name (ParamNameEnum): parameter name.
  126. value_type (ValueTypeEnum): the type of value.
  127. valid_test_func (func): the function used to test whether the param is valid.
  128. support_disable (bool): whether the param support no assignment.
  129. default_value (float): default value.
  130. visible_on_ui (bool): whether the param visible on ui.
  131. param_type (ParamTypeEnum): parameters type.
  132. required_params (list): the list of required parameters.
  133. """
  134. def __init__(self, name, value_type: ValueTypeEnum, valid_test_func=None, support_disable=True, default_value=None,
  135. visible_on_ui=True, param_type=ParamTypeEnum.CHECK_PARAM, required_params=None):
  136. self._name = name.value
  137. self._type = value_type
  138. self._valid_test_func = valid_test_func
  139. self._support_disable = support_disable
  140. self._default_value = default_value
  141. self._visible_on_ui = visible_on_ui
  142. self._param_type = param_type.value
  143. self._required_params = required_params
  144. @property
  145. def name(self):
  146. """Get name of parameter."""
  147. return self._name
  148. @property
  149. def type(self):
  150. """Get type of parameter."""
  151. return self._type
  152. @property
  153. def support_disable(self):
  154. """Get support_disable of parameter."""
  155. return self._support_disable
  156. @property
  157. def default_value(self):
  158. """Get default_value of parameter."""
  159. return self._default_value
  160. @property
  161. def visible_on_ui(self):
  162. """Get visible_on_ui of parameter."""
  163. return self._visible_on_ui
  164. @property
  165. def param_type(self):
  166. """Get param_type of parameter."""
  167. return self._param_type
  168. @property
  169. def required_params(self):
  170. """Get required_param of parameter."""
  171. return self._required_params
  172. def is_valid(self, value):
  173. """Check is the parameter valid."""
  174. if self._valid_test_func is None:
  175. return True
  176. return self._valid_test_func(value)
  177. class Condition:
  178. """
  179. The class for parameters of conditions.
  180. Args:
  181. condition_id (ConditionIdEnum): condition id.
  182. abbr (str): the abbreviation of condition id.
  183. optimize_phase (OptimizePhaseEnum): optimize phase.
  184. parameters (List[ConditionParameter]): parameters.
  185. supported_target_type (TargetTypeEnum): the supported target type.
  186. supported_platforms (tuple[PlatformEnum, PlatformEnum]): the supported platforms.
  187. minimum_debugger_capability (tuple): the minimum debugger capability required.
  188. availability_test_func (func): the function used to test whether the condition is available.
  189. """
  190. def __init__(self, condition_id, abbr, optimize_phase, parameters, supported_target_type, supported_platforms,
  191. minimum_debugger_capability, availability_test_func=None):
  192. self.id = condition_id.value
  193. self._abbr = abbr
  194. self.optimize_phase = optimize_phase
  195. self._parameters = {
  196. parameter.name: parameter for parameter in parameters
  197. }
  198. self.ordered_parameter_names = [parameter.name for parameter in parameters]
  199. self._supported_target_type = supported_target_type
  200. self.supported_platforms = supported_platforms
  201. self.minimum_debugger_capability = minimum_debugger_capability
  202. self.availability_test_func = availability_test_func
  203. def get_parameter_definition(self, name):
  204. """Return parameter definition by the name"""
  205. return self._parameters[name]
  206. def is_available(self, condition_context):
  207. """Check is the condition available."""
  208. backend = condition_context.backend
  209. debugger_capability = condition_context.debugger_capability
  210. if debugger_capability < self.minimum_debugger_capability:
  211. logger.debug("The debugger capability is lower than the minimum debugger capability.")
  212. return False
  213. if backend not in [platform.value for platform in self.supported_platforms]:
  214. logger.debug("The condition %s is not supported on the platform.", self.id)
  215. return False
  216. if self.availability_test_func is None:
  217. return True
  218. return self.availability_test_func(condition_context)
  219. @property
  220. def abbr(self):
  221. """The abbreviation of condition"""
  222. return self._abbr
  223. @property
  224. def names(self):
  225. """The name of condition"""
  226. return self._parameters.keys()
  227. @property
  228. def parameters(self):
  229. """The parameters of condition"""
  230. return self._parameters.values()
  231. @property
  232. def supported_target_type(self):
  233. """The supported target type of condition"""
  234. return self._supported_target_type
  235. def check_initialization_available(condition_context):
  236. """Check if initialization is available at this step"""
  237. if condition_context.step == 0:
  238. return True
  239. return False
  240. def check_percentage_param_range(value):
  241. if 0 <= value <= 100:
  242. return True
  243. return False
  244. def check_normal_param_range(value):
  245. if float("-inf") < value < float("inf"):
  246. return True
  247. return False
  248. def check_abs_param_range(value):
  249. if 0 <= value < float("inf"):
  250. return True
  251. return False
  252. def check_positive_param_range(value):
  253. if 0 < value < float("inf"):
  254. return True
  255. return False