Browse Source

solve DI: [MS][MindArmour][Doc]EnsembleAdversarialDefense format is defferent with others

https://gitee.com/mindspore/dashboard/issues?id=I1GSW0
tags/v0.3.0-alpha
ZhidanLiu 5 years ago
parent
commit
8da8f9525c
1 changed files with 34 additions and 2 deletions
  1. +34
    -2
      mindarmour/defenses/adversarial_defense.py

+ 34
- 2
mindarmour/defenses/adversarial_defense.py View File

@@ -23,7 +23,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.nn import WithLossCell, TrainOneStepCell from mindspore.nn import WithLossCell, TrainOneStepCell


from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \ from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \
check_param_in_range, check_param_type, check_param_multi_types
check_param_in_range, check_param_type, check_param_multi_types
from mindarmour.defenses.defense import Defense from mindarmour.defenses.defense import Defense




@@ -166,4 +166,36 @@ class AdversarialDefenseWithAttacks(AdversarialDefense):
return loss.asnumpy() return loss.asnumpy()




EnsembleAdversarialDefense = AdversarialDefenseWithAttacks
class EnsembleAdversarialDefense(AdversarialDefenseWithAttacks):
"""
Ensemble adversarial defense.

Args:
network (Cell): A MindSpore network to be defensed.
attacks (list[Attack]): List of attack method.
loss_fn (Functions): Loss function. Default: None.
optimizer (Cell): Optimizer used to train the network. Default: None.
bounds (tuple): Upper and lower bounds of data. In form of (clip_min,
clip_max). Default: (0.0, 1.0).
replace_ratio (float): Ratio of replacing original samples with
adversarial, which must be between 0 and 1. Default: 0.5.

Raises:
ValueError: If replace_ratio is not between 0 and 1.

Examples:
>>> net = Net()
>>> fgsm = FastGradientSignMethod(net)
>>> pgd = ProjectedGradientDescent(net)
>>> ead = EnsembleAdversarialDefense(net, [fgsm, pgd])
>>> ead.defense(inputs, labels)
"""

def __init__(self, network, attacks, loss_fn=None, optimizer=None,
bounds=(0.0, 1.0), replace_ratio=0.5):
super(EnsembleAdversarialDefense, self).__init__(network,
attacks,
loss_fn,
optimizer,
bounds,
replace_ratio)

Loading…
Cancel
Save