Browse Source

!62 fix issue check params type.

Merge pull request !62 from zheng-huanhuan/master
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
cb2825e36b
2 changed files with 8 additions and 4 deletions
  1. +7
    -2
      mindarmour/diff_privacy/mechanisms/mechanisms.py
  2. +1
    -2
      mindarmour/diff_privacy/train/model.py

+ 7
- 2
mindarmour/diff_privacy/mechanisms/mechanisms.py View File

@@ -26,6 +26,7 @@ from mindspore.common import dtype as mstype
from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_param_in_range
from mindarmour.utils._check_param import check_value_non_negative
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
@@ -204,8 +205,10 @@ class NoiseGaussianRandom(_Mechanisms):

def __init__(self, norm_bound=1.0, initial_noise_multiplier=1.0, seed=0, decay_policy=None):
super(NoiseGaussianRandom, self).__init__()
norm_bound = check_param_type('norm_bound', norm_bound, float)
self._norm_bound = check_value_positive('norm_bound', norm_bound)
self._norm_bound = Tensor(norm_bound, mstype.float32)
initial_noise_multiplier = check_param_type('initial_noise_multiplier', initial_noise_multiplier, float)
self._initial_noise_multiplier = check_value_positive('initial_noise_multiplier',
initial_noise_multiplier)
self._initial_noise_multiplier = Tensor(initial_noise_multiplier, mstype.float32)
@@ -213,7 +216,8 @@ class NoiseGaussianRandom(_Mechanisms):
if decay_policy is not None:
raise ValueError('decay_policy must be None in GaussianRandom class, but got {}.'.format(decay_policy))
self._decay_policy = decay_policy
self._seed = seed
seed = check_param_type('seed', seed, int)
self._seed = check_value_non_negative('seed', seed)

def construct(self, gradients):
"""
@@ -400,7 +404,8 @@ class AdaClippingWithGaussianRandom(Cell):
self._sub = P.Sub()
self._mul = P.Mul()
self._exp = P.Exp()
self._seed = seed
seed = check_param_type('seed', seed, int)
self._seed = check_value_non_negative('seed', seed)

def construct(self, empirical_fraction, norm_bound):
"""


+ 1
- 2
mindarmour/diff_privacy/train/model.py View File

@@ -50,8 +50,7 @@ from mindspore import ParameterTuple
from mindarmour.utils.logger import LogUtil
from mindarmour.diff_privacy.mechanisms.mechanisms import \
_MechanismsParamsUpdater
from mindarmour.utils._check_param import check_param_type
from mindarmour.utils._check_param import check_value_positive
from mindarmour.utils._check_param import check_value_positive, check_param_type
from mindarmour.utils._check_param import check_int_positive

LOGGER = LogUtil.get_instance()


Loading…
Cancel
Save