| @@ -15,7 +15,7 @@ | |||||
| """grad impl.""" | """grad impl.""" | ||||
| from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \ | from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \ | ||||
| grad_math_ops, grad_nn_ops, grad_other_ops | |||||
| grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops | |||||
| from .grad_base import get_bprop_fn | from .grad_base import get_bprop_fn | ||||
| __all__ = ['get_bprop_fn'] | __all__ = ['get_bprop_fn'] | ||||
| @@ -223,8 +223,8 @@ class BatchNormFold(PrimitiveWithInfer): | |||||
| Args: | Args: | ||||
| momentum (float): Momentum value should be [0, 1]. Default: 0.1. | momentum (float): Momentum value should be [0, 1]. Default: 0.1. | ||||
| epsilon (float): A small float number to avoid dividing by 0. 1e-12 if dtype in | |||||
| float32 else 1e-3. Default: 1e-12. | |||||
| epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in | |||||
| float32 else 1e-3. Default: 1e-5. | |||||
| is_training (bool): In training mode set True, else set False. Default: True. | is_training (bool): In training mode set True, else set False. Default: True. | ||||
| freeze_bn (int): Delay in steps at which computation switches from regular batch | freeze_bn (int): Delay in steps at which computation switches from regular batch | ||||
| norm to frozen mean and std. Default: 0. | norm to frozen mean and std. Default: 0. | ||||
| @@ -247,7 +247,7 @@ class BatchNormFold(PrimitiveWithInfer): | |||||
| channel = 1 | channel = 1 | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0): | |||||
| def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0): | |||||
| """init batch norm fold layer""" | """init batch norm fold layer""" | ||||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | ||||
| self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) | self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) | ||||
| @@ -277,7 +277,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer): | |||||
| channel = 1 | channel = 1 | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0): | |||||
| def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0): | |||||
| """init BatchNormGrad layer""" | """init BatchNormGrad layer""" | ||||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | ||||
| self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | ||||
| @@ -32,6 +32,7 @@ __all__ = ["build_train_network"] | |||||
| class OutputTo16(nn.Cell): | class OutputTo16(nn.Cell): | ||||
| "Wrap cell for amp. Cast network output back to float16" | "Wrap cell for amp. Cast network output back to float16" | ||||
| def __init__(self, op): | def __init__(self, op): | ||||
| super(OutputTo16, self).__init__(auto_prefix=False) | super(OutputTo16, self).__init__(auto_prefix=False) | ||||
| self._op = op | self._op = op | ||||
| @@ -53,7 +54,7 @@ def _do_keep_batchnorm_fp32(network): | |||||
| change = True | change = True | ||||
| else: | else: | ||||
| _do_keep_batchnorm_fp32(subcell) | _do_keep_batchnorm_fp32(subcell) | ||||
| if isinstance(network, nn.SequentialCell) and change: | |||||
| if isinstance(network, nn.SequentialCell) and change: | |||||
| network.cell_list = list(network.cells()) | network.cell_list = list(network.cells()) | ||||
| @@ -72,7 +73,7 @@ def _check_kwargs(key_words): | |||||
| """Check kwargs.""" | """Check kwargs.""" | ||||
| for arg in key_words: | for arg in key_words: | ||||
| if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']: | if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']: | ||||
| raise ValueError(f"Unsupported arg '{arg}'") | |||||
| raise ValueError(f"Unsupported arg '{arg}'") | |||||
| if 'cast_model_type' in key_words: | if 'cast_model_type' in key_words: | ||||
| validator.check_type_name('cast_model_type', key_words['cast_model_type'], | validator.check_type_name('cast_model_type', key_words['cast_model_type'], | ||||
| @@ -18,4 +18,16 @@ set -e | |||||
| # Usage : get_shape_from_ir.sh ir_file | # Usage : get_shape_from_ir.sh ir_file | ||||
| cat "$1" | perl -p -e 's/\n/NEWLINE/' | sed 's/NEWLINE :/:/g' | sed 's/Tensor NEWLINEshape//g' | perl -p -e 's/NEWLINE/\n/g' | perl -p -e 's/<Array\[([\d\w]+)\]x\[[\w ]+\](\[[\d, ]*\])>/\2/g' | perl -p -e 's/<Tuple\[([\[\]\d\w\.\*]*)\]>/Tuple/g' | perl -p -e 's/ \%(\d+)\(.*= /\1\t/g' | perl -p -e 's/\(.*\)( \{.*\})*:/\t\1\t/g' | tr -d '()' | awk '/subgraph/{p=1;next}{if(p){print}}'| awk '/return/{p=1;next}{if(!p){print}}' | sed '/^$/d' | awk -F'\t' '{print $1"\t"$2"\t"$4"\t"$3}' | |||||
| cat "$1" | perl -p -e 's/\n/NEWLINE/' \ | |||||
| | sed 's/NEWLINE :/:/g' \ | |||||
| | sed 's/Tensor NEWLINEshape//g' \ | |||||
| | perl -p -e 's/NEWLINE/\n/g' \ | |||||
| | perl -p -e 's/<Array\[([\d\w]+)\]x\[[\w ]+\](\[[\d, ]*\])>/\2/g' \ | |||||
| | perl -p -e 's/<Tuple\[([\[\]\d\w\.\*]*)\]>/Tuple/g' \ | |||||
| | perl -p -e 's/ \%(\d+)\(.*= /\1\t/g' \ | |||||
| | perl -p -e 's/\(.*\)( \{.*\})*:/\t\1\t/g' \ | |||||
| | tr -d '()' \ | |||||
| | awk '/subgraph/{p=1;next}{if(p){print}}'\ | |||||
| | awk '/return/{p=1;next}{if(!p){print}}' \ | |||||
| | sed '/^$/d' \ | |||||
| | awk -F'\t' '{print $1"\t"$2"\t"$4}' | |||||