# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Management of all conditions.

This module is used to register all conditions, as well as their parameters.
This module also provides the available conditions to the condition_collections api.
"""
from enum import Enum

from mindinsight.debugger.conditionmgr.log import logger


class ConditionIdEnum(Enum):
    """Condition ids."""
    WEIGHT_INITIALIZATION = "weight_initialization"
    WEIGHT_OVERFLOW = "weight_overflow"
    WEIGHT_TOO_LARGE = "weight_too_large"
    WEIGHT_TOO_SMALL = "weight_too_small"
    GRADIENT_VANISHING = "gradient_vanishing"
    GRADIENT_TOO_LARGE = "gradient_too_large"
    GRADIENT_EXPLODING = "gradient_exploding"
    TENSOR_OVERFLOW = "tensor_overflow"
    OPERATOR_OVERFLOW = "operator_overflow"
    TENSOR_INITIALIZATION = "tensor_initialization"
    TENSOR_TOO_LARGE = "tensor_too_large"
    TENSOR_TOO_SMALL = "tensor_too_small"
    TENSOR_ALL_ZERO = "tensor_all_zero"
    WEIGHT_NOT_CHANGED = "weight_not_changed"
    WEIGHT_CHANGE_TOO_LARGE = "weight_change_too_large"
    WEIGHT_CHANGE_TOO_SMALL = "weight_change_too_small"
    ACTIVATION_RANGE = "activation_range"
    TENSOR_RANGE = "tensor_range"


class OptimizePhaseEnum(Enum):
    """Optimize phases."""
    # NOTE: declaration order is kept as-is (it defines Enum iteration order);
    # the numeric values are not in ascending order.
    TENSOR_CHECK = 400
    OPERATOR_CHECK = 100
    LOSS_CHECK = 300
    INPUT_DATA_CHECK = 200


class ValueTypeEnum(Enum):
    """Value types."""
    FLOAT64 = 1
    INT64 = 2
    BOOL = 3


class PlatformEnum(Enum):
    """Platform types."""
    GPU = "GPU"
    ASCEND = "Ascend"


class TargetTypeEnum(Enum):
    """Target types."""
    TENSOR = 'tensor'
    WEIGHT = 'weight'
    ACTIVATION = 'activation'
    GRADIENT = 'gradient'


class ParamTypeEnum(Enum):
    """Param types."""
    CHECK_PARAM = "CHECK_PARAM"
    SUPPORT_PARAM = "SUPPORT_PARAM"


class ActivationFuncEnum(Enum):
    """Activation functions."""
    TANH = 'Tanh'
    SIGMOID = 'Sigmoid'
    RELU = 'ReLU'


class ConditionContext:
    """
    The class for condition context.

    A context describes the environment a condition is evaluated in; it is
    consumed by `Condition.is_available` and by availability test functions.

    Args:
        backend (str): the backend platform name. It is compared against the
            `PlatformEnum` values in `Condition.is_available`.
        step (int): the current step number. `check_initialization_available`
            treats step 0 as the initialization step. Default: 0.
        debugger_capability (tuple): the capability of the debugger. It is
            compared (as a tuple) against `Condition.minimum_debugger_capability`.
            Default: (1, 0).
    """

    def __init__(self, backend, step=0, debugger_capability=(1, 0)):
        self._backend = backend
        self._step = step
        self._debugger_capability = debugger_capability

    @property
    def backend(self):
        """Get backend."""
        return self._backend

    @property
    def step(self):
        """Get _step."""
        return self._step

    @property
    def debugger_capability(self):
        """Get debugger_capability."""
        return self._debugger_capability


class ConditionParameter:
    """
    The class for parameters of conditions.

    Args:
        name (str): parameter name.
        value_type (ValueTypeEnum): the type of value.
        valid_test_func (func): the function used to test whether the param is valid.
            If None, every value is considered valid.
        support_disable (bool): whether the param support no assignment.
        default_value (float): default value.
        visible_on_ui (bool): whether the param visible on ui.
        param_type (ParamTypeEnum): parameters type. Only its string value is stored,
            so the `param_type` property returns a str, not the enum member.
        required_params (list): the list of required parameters.
    """

    def __init__(self, name, value_type: ValueTypeEnum, valid_test_func=None, support_disable=True,
                 default_value=None, visible_on_ui=True, param_type=ParamTypeEnum.CHECK_PARAM,
                 required_params=None):
        self._name = name
        self._type = value_type
        self._valid_test_func = valid_test_func
        self._support_disable = support_disable
        self._default_value = default_value
        self._visible_on_ui = visible_on_ui
        self._param_type = param_type.value
        self._required_params = required_params

    @property
    def name(self):
        """Get name of parameter."""
        return self._name

    @property
    def type(self):
        """Get type of parameter."""
        return self._type

    @property
    def support_disable(self):
        """Get support_disable of parameter."""
        return self._support_disable

    @property
    def default_value(self):
        """Get default_value of parameter."""
        return self._default_value

    @property
    def visible_on_ui(self):
        """Get visible_on_ui of parameter."""
        return self._visible_on_ui

    @property
    def param_type(self):
        """Get param_type of parameter (the string value of a ParamTypeEnum)."""
        return self._param_type

    @property
    def required_params(self):
        """Get required_param of parameter."""
        return self._required_params

    def is_valid(self, value):
        """
        Check whether the parameter value is valid.

        Returns True when no validation function was given; otherwise returns
        the result of `valid_test_func(value)`.
        """
        if self._valid_test_func is None:
            return True
        return self._valid_test_func(value)


class Condition:
    """
    The class for conditions.

    A condition bundles its id, its parameter definitions, the target type and
    platforms it supports, and the rules deciding whether it is available.

    Args:
        condition_id (ConditionIdEnum): condition id. Only its string value is stored in `id`.
        abbr (str): the abbreviation of condition id.
        optimize_phase (OptimizePhaseEnum): optimize phase.
        parameters (Iterable[ConditionParameter]): parameters. Any iterable is accepted;
            the order is preserved in `ordered_parameter_names`.
        supported_target_type (TargetTypeEnum): the supported target type.
        supported_platforms (tuple[PlatformEnum, PlatformEnum]): the supported platforms.
        minimum_debugger_capability (tuple): the minimum debugger capability required.
        availability_test_func (func): the function used to test whether the condition is available.
    """

    def __init__(self, condition_id, abbr, optimize_phase, parameters, supported_target_type, supported_platforms,
                 minimum_debugger_capability, availability_test_func=None):
        self.id = condition_id.value
        self._abbr = abbr
        self.optimize_phase = optimize_phase
        # Materialize once: `parameters` is iterated twice below, so a one-shot
        # iterable (e.g. a generator) would otherwise leave the ordered names empty.
        parameters = list(parameters)
        self._parameters = {
            parameter.name: parameter
            for parameter in parameters
        }
        self.ordered_parameter_names = [parameter.name for parameter in parameters]
        self._supported_target_type = supported_target_type
        self.supported_platforms = supported_platforms
        self.minimum_debugger_capability = minimum_debugger_capability
        self.availability_test_func = availability_test_func

    def get_parameter_definition(self, name):
        """
        Return the parameter definition by its name.

        Raises:
            KeyError: if the condition has no parameter called `name`.
        """
        return self._parameters[name]

    def is_available(self, condition_context):
        """
        Check whether the condition is available in the given context.

        The condition is available only if all of the following hold:
        the context's debugger capability is not lower than
        `minimum_debugger_capability`, the context's backend is one of the
        supported platforms, and `availability_test_func` (if any) returns a
        truthy value for the context.

        Args:
            condition_context (ConditionContext): the evaluation context.

        Returns:
            The availability result (bool, or whatever `availability_test_func` returns).
        """
        backend = condition_context.backend
        debugger_capability = condition_context.debugger_capability
        if debugger_capability < self.minimum_debugger_capability:
            logger.debug("The debugger capability is lower than the minimum debugger capability.")
            return False
        if backend not in [platform.value for platform in self.supported_platforms]:
            logger.debug("The condition %s is not supported on the platform.", self.id)
            return False
        if self.availability_test_func is None:
            return True
        return self.availability_test_func(condition_context)

    @property
    def abbr(self):
        """The abbreviation of condition"""
        return self._abbr

    @property
    def names(self):
        """The names of the condition's parameters (a dict keys view)."""
        return self._parameters.keys()

    @property
    def parameters(self):
        """The parameters of condition"""
        return self._parameters.values()

    @property
    def supported_target_type(self):
        """The supported target type of condition"""
        return self._supported_target_type


def check_initialization_available(condition_context):
    """
    Check if initialization is available at this step.

    Returns True only when the context's step is 0.
    """
    return condition_context.step == 0


def check_percentage_param_range(value):
    """Return True if `value` is a percentage in the closed range [0, 100]."""
    return 0 <= value <= 100
def check_normal_param_range(value):
    """
    Return True if `value` is a finite number.

    Both infinities and NaN are rejected. NaN fails every ordered comparison,
    so the chained comparison is False for it.
    """
    return float("-inf") < value < float("inf")


def check_abs_param_range(value):
    """
    Return True if `value` is a finite, non-negative number.

    Valid values lie in the range [0, +inf). NaN is rejected because it fails
    every ordered comparison.
    """
    return 0 <= value < float("inf")