Browse Source

fix some bug in quant debug

tags/v0.3.0-alpha
chenzomi 5 years ago
parent
commit
2a695cfe24
4 changed files with 21 additions and 8 deletions
  1. +1
    -1
      mindspore/ops/_grad/__init__.py
  2. +4
    -4
      mindspore/ops/operations/_quant_ops.py
  3. +3
    -2
      mindspore/train/amp.py
  4. +13
    -1
      scripts/get_shape_from_ir.sh

+ 1
- 1
mindspore/ops/_grad/__init__.py View File

@@ -15,7 +15,7 @@


"""grad impl.""" """grad impl."""
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \ from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
grad_math_ops, grad_nn_ops, grad_other_ops
grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops
from .grad_base import get_bprop_fn from .grad_base import get_bprop_fn


__all__ = ['get_bprop_fn'] __all__ = ['get_bprop_fn']

+ 4
- 4
mindspore/ops/operations/_quant_ops.py View File

@@ -223,8 +223,8 @@ class BatchNormFold(PrimitiveWithInfer):


Args: Args:
momentum (float): Momentum value should be [0, 1]. Default: 0.1. momentum (float): Momentum value should be [0, 1]. Default: 0.1.
epsilon (float): A small float number to avoid dividing by 0. 1e-12 if dtype in
float32 else 1e-3. Default: 1e-12.
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
float32 else 1e-3. Default: 1e-5.
is_training (bool): In training mode set True, else set False. Default: True. is_training (bool): In training mode set True, else set False. Default: True.
freeze_bn (int): Delay in steps at which computation switches from regular batch freeze_bn (int): Delay in steps at which computation switches from regular batch
norm to frozen mean and std. Default: 0. norm to frozen mean and std. Default: 0.
@@ -247,7 +247,7 @@ class BatchNormFold(PrimitiveWithInfer):
channel = 1 channel = 1


@prim_attr_register @prim_attr_register
def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0):
def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
"""init batch norm fold layer""" """init batch norm fold layer"""
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
@@ -277,7 +277,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
channel = 1 channel = 1


@prim_attr_register @prim_attr_register
def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0):
def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
"""init BatchNormGrad layer""" """init BatchNormGrad layer"""
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)


+ 3
- 2
mindspore/train/amp.py View File

@@ -32,6 +32,7 @@ __all__ = ["build_train_network"]


class OutputTo16(nn.Cell): class OutputTo16(nn.Cell):
"Wrap cell for amp. Cast network output back to float16" "Wrap cell for amp. Cast network output back to float16"

def __init__(self, op): def __init__(self, op):
super(OutputTo16, self).__init__(auto_prefix=False) super(OutputTo16, self).__init__(auto_prefix=False)
self._op = op self._op = op
@@ -53,7 +54,7 @@ def _do_keep_batchnorm_fp32(network):
change = True change = True
else: else:
_do_keep_batchnorm_fp32(subcell) _do_keep_batchnorm_fp32(subcell)
if isinstance(network, nn.SequentialCell) and change:
if isinstance(network, nn.SequentialCell) and change:
network.cell_list = list(network.cells()) network.cell_list = list(network.cells())




@@ -72,7 +73,7 @@ def _check_kwargs(key_words):
"""Check kwargs.""" """Check kwargs."""
for arg in key_words: for arg in key_words:
if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']: if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']:
raise ValueError(f"Unsupported arg '{arg}'")
raise ValueError(f"Unsupported arg '{arg}'")


if 'cast_model_type' in key_words: if 'cast_model_type' in key_words:
validator.check_type_name('cast_model_type', key_words['cast_model_type'], validator.check_type_name('cast_model_type', key_words['cast_model_type'],


+ 13
- 1
scripts/get_shape_from_ir.sh View File

@@ -18,4 +18,16 @@ set -e


# Usage : get_shape_from_ir.sh ir_file # Usage : get_shape_from_ir.sh ir_file


cat "$1" | perl -p -e 's/\n/NEWLINE/' | sed 's/NEWLINE :/:/g' | sed 's/Tensor NEWLINEshape//g' | perl -p -e 's/NEWLINE/\n/g' | perl -p -e 's/<Array\[([\d\w]+)\]x\[[\w ]+\](\[[\d, ]*\])>/\2/g' | perl -p -e 's/<Tuple\[([\[\]\d\w\.\*]*)\]>/Tuple/g' | perl -p -e 's/ \%(\d+)\(.*= /\1\t/g' | perl -p -e 's/\(.*\)( \{.*\})*:/\t\1\t/g' | tr -d '()' | awk '/subgraph/{p=1;next}{if(p){print}}'| awk '/return/{p=1;next}{if(!p){print}}' | sed '/^$/d' | awk -F'\t' '{print $1"\t"$2"\t"$4"\t"$3}'
cat "$1" | perl -p -e 's/\n/NEWLINE/' \
| sed 's/NEWLINE :/:/g' \
| sed 's/Tensor NEWLINEshape//g' \
| perl -p -e 's/NEWLINE/\n/g' \
| perl -p -e 's/<Array\[([\d\w]+)\]x\[[\w ]+\](\[[\d, ]*\])>/\2/g' \
| perl -p -e 's/<Tuple\[([\[\]\d\w\.\*]*)\]>/Tuple/g' \
| perl -p -e 's/ \%(\d+)\(.*= /\1\t/g' \
| perl -p -e 's/\(.*\)( \{.*\})*:/\t\1\t/g' \
| tr -d '()' \
| awk '/subgraph/{p=1;next}{if(p){print}}'\
| awk '/return/{p=1;next}{if(!p){print}}' \
| sed '/^$/d' \
| awk -F'\t' '{print $1"\t"$2"\t"$4}'

Loading…
Cancel
Save