Browse Source

fix parameter check

tags/v0.5.0-beta
ZhidanLiu 5 years ago
parent
commit
ec9d46a7e1
4 changed files with 11 additions and 8 deletions
  1. +4
    -1
      mindarmour/diff_privacy/mechanisms/mechanisms.py
  2. +3
    -3
      mindarmour/utils/_check_param.py
  3. +1
    -1
      tests/ut/python/attacks/test_iterative_gradient_method.py
  4. +3
    -3
      tests/ut/python/detectors/test_region_based_detector.py

+ 4
- 1
mindarmour/diff_privacy/mechanisms/mechanisms.py View File

@@ -159,7 +159,10 @@ class AdaGaussianRandom(Mechanisms):
alpha = check_param_type('alpha', alpha, float) alpha = check_param_type('alpha', alpha, float)
self._alpha = Tensor(np.array(alpha, np.float32)) self._alpha = Tensor(np.array(alpha, np.float32))


self._decay_policy = check_param_type('decay_policy', decay_policy, str)
if decay_policy not in ['Time', 'Step']:
raise NameError("The decay_policy must be in ['Time', 'Step'], but "
"get {}".format(decay_policy))
self._decay_policy = decay_policy
self._mean = 0.0 self._mean = 0.0
self._sub = P.Sub() self._sub = P.Sub()
self._mul = P.Mul() self._mul = P.Mul()


+ 3
- 3
mindarmour/utils/_check_param.py View File

@@ -43,7 +43,7 @@ def check_param_type(arg_name, arg_value, valid_type):
valid_type, valid_type,
type(arg_value).__name__) type(arg_value).__name__)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg)
raise TypeError(msg)


return arg_value return arg_value


@@ -54,7 +54,7 @@ def check_param_multi_types(arg_name, arg_value, valid_types):
msg = 'type of {} must be in {}, but got {}' \ msg = 'type of {} must be in {}, but got {}' \
.format(arg_name, valid_types, type(arg_value).__name__) .format(arg_name, valid_types, type(arg_value).__name__)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg)
raise TypeError(msg)


return arg_value return arg_value


@@ -157,7 +157,7 @@ def check_numpy_param(arg_name, arg_value):
msg = 'type of {} must be in (list, tuple, numpy.ndarray)'.format( msg = 'type of {} must be in (list, tuple, numpy.ndarray)'.format(
arg_name) arg_name)
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise ValueError(msg)
raise TypeError(msg)
return arg_value return arg_value






+ 1
- 1
tests/ut/python/attacks/test_iterative_gradient_method.py View File

@@ -167,7 +167,7 @@ def test_momentum_diverse_input_iterative_method():
@pytest.mark.env_card @pytest.mark.env_card
@pytest.mark.component_mindarmour @pytest.mark.component_mindarmour
def test_error(): def test_error():
with pytest.raises(ValueError):
with pytest.raises(TypeError):
# check_param_multi_types # check_param_multi_types
assert IterativeGradientMethod(Net(), bounds=None) assert IterativeGradientMethod(Net(), bounds=None)
attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0)) attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0))


+ 3
- 3
tests/ut/python/detectors/test_region_based_detector.py View File

@@ -100,16 +100,16 @@ def test_value_error():
with pytest.raises(ValueError): with pytest.raises(ValueError):
assert RegionBasedDetector(model, search_step=0) assert RegionBasedDetector(model, search_step=0)


with pytest.raises(ValueError):
with pytest.raises(TypeError):
assert RegionBasedDetector(model, sparse='False') assert RegionBasedDetector(model, sparse='False')


detector = RegionBasedDetector(model) detector = RegionBasedDetector(model)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
# radius must not empty # radius must not empty
assert detector.detect(adv) assert detector.detect(adv)


radius = detector.fit(ori, labels) radius = detector.fit(ori, labels)
detector.set_radius(radius) detector.set_radius(radius)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
# adv type should be in (list, tuple, numpy.ndarray) # adv type should be in (list, tuple, numpy.ndarray)
assert detector.detect(adv.tostring()) assert detector.detect(adv.tostring())

Loading…
Cancel
Save