
!135 Fix two issues and add unit tests for PSOAttack and GeneticAttack (detection model).

From: @jxlang910
Reviewed-by: @pkuliuliu,@liu_luobin
Signed-off-by: @pkuliuliu
tags/v1.1.0
Committed by mindspore-ci-bot via Gitee, 4 years ago
Parent commit: 025c3d2283
6 changed files with 138 additions and 26 deletions
  1. mindarmour/adv_robustness/attacks/attack.py (+2, -2)
  2. mindarmour/adv_robustness/attacks/black/genetic_attack.py (+19, -11)
  3. mindarmour/adv_robustness/attacks/black/pso_attack.py (+18, -11)
  4. mindarmour/utils/_check_param.py (+2, -0)
  5. tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py (+51, -2)
  6. tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py (+46, -0)

mindarmour/adv_robustness/attacks/attack.py (+2, -2)

@@ -182,11 +182,11 @@ class Attack:
         best_position = check_numpy_param('best_position', best_position)
         x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position)
         x_shape = best_position.shape
-        reduction_iters = 10000  # recover 0.01% each step
+        reduction_iters = 1000  # recover 0.1% each step
         _, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model)
         for _ in range(reduction_iters):
             diff = x_ori - best_position
-            res = 0.5*diff*(np.random.random(x_shape) < 0.0001)
+            res = 0.5*diff*(np.random.random(x_shape) < 0.001)
             best_position += res
             _, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model)
             q_times += 1
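Both settings cover roughly 100% of the pixel mass in expectation (10000 steps at 0.01% versus 1000 steps at 0.1%), so the fix reaches the same expected coverage with a tenth of the model queries. A minimal, self-contained sketch of the recovery loop above, with hypothetical shapes (the real method also re-queries the detector via self._detection_scores after every step, omitted here):

    import numpy as np

    x_ori = np.random.random((3, 28, 28))           # original image
    best_position = np.random.random((3, 28, 28))   # current adversarial example
    x_shape = best_position.shape
    reduction_iters = 1000  # recover 0.1% each step

    for _ in range(reduction_iters):
        diff = x_ori - best_position
        # Each step pulls roughly 0.1% of randomly chosen pixels halfway back
        # toward the original image, shrinking the perturbation gradually.
        res = 0.5*diff*(np.random.random(x_shape) < 0.001)
        best_position += res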


mindarmour/adv_robustness/attacks/black/genetic_attack.py (+19, -11)

@@ -46,16 +46,16 @@ class GeneticAttack(Attack):
         targeted (bool): If True, turns on the targeted attack. If False,
             turns on untargeted attack. It should be noted that only untargeted attack
             is supported for model_type='detection'. Default: False.
-        reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for
-            model_type='detection'. Default: 0.3.
+        reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks,
+            specifically for model_type='detection'. Default: 0.3.
         pop_size (int): The number of particles, which should be greater than
             zero. Default: 6.
-        mutation_rate (float): The probability of mutations. Default: 0.005.
-        per_bounds (float): Maximum L_inf distance.
+        mutation_rate (Union[int, float]): The probability of mutations. Default: 0.005.
+        per_bounds (Union[int, float]): Maximum L_inf distance.
         max_steps (int): The maximum round of iteration for each adversarial
             example. Default: 1000.
-        step_size (float): Attack step size. Default: 0.2.
-        temp (float): Sampling temperature for selection. Default: 0.3.
+        step_size (Union[int, float]): Attack step size. Default: 0.2.
+        temp (Union[int, float]): Sampling temperature for selection. Default: 0.3.
             The greater the temp, the greater the differences between individuals'
             selecting probabilities.
         bounds (Union[tuple, list, None]): Upper and lower bounds of data. In form
@@ -65,7 +65,7 @@ class GeneticAttack(Attack):
             Default: False.
         sparse (bool): If True, input labels are sparse-encoded. If False,
             input labels are one-hot-encoded. Default: True.
-        c (float): Weight of perturbation loss. Default: 0.1.
+        c (Union[int, float]): Weight of perturbation loss. Default: 0.1.

     Examples:
         >>> attack = GeneticAttack(model)
@@ -76,6 +76,10 @@ class GeneticAttack(Attack):
         super(GeneticAttack, self).__init__()
         self._model = check_model('model', model, BlackModel)
         self._model_type = check_param_type('model_type', model_type, str)
+        if self._model_type not in ('classification', 'detection'):
+            msg = "Only 'classification' or 'detection' is supported now, but got {}.".format(self._model_type)
+            LOGGER.error(TAG, msg)
+            raise ValueError(msg)
         self._targeted = check_param_type('targeted', targeted, bool)
         self._reserve_ratio = check_value_non_negative('reserve_ratio', reserve_ratio)
         if self._reserve_ratio > 1:
@@ -153,10 +157,14 @@ class GeneticAttack(Attack):
         if self._model_type == 'classification':
             inputs, labels = check_pair_numpy_param('inputs', inputs,
                                                     'labels', labels)
-            if not self._sparse:
-                if labels.ndim != 2:
-                    raise ValueError('labels must be 2 dims, '
-                                     'but got {} dims.'.format(labels.ndim))
+            if self._sparse:
+                label_squ = np.squeeze(labels)
+                if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]:
+                    msg = "The parameter 'sparse' of GeneticAttack is True, but the input labels are not " \
+                          "in sparse format; got shape {}.".format(labels.shape)
+                    LOGGER.error(TAG, msg)
+                    raise ValueError(msg)
+            else:
                 labels = np.argmax(labels, axis=1)
             images = inputs
         elif self._model_type == 'detection':
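The new validation rejects one-hot labels when sparse=True instead of silently mis-handling them. A standalone sketch of the check (the function name and shapes are illustrative, not part of the library):

    import numpy as np

    def check_sparse_labels(inputs, labels):
        """Illustrative re-statement of the shape check added above."""
        label_squ = np.squeeze(labels)
        # A sparse label array must squeeze to a 1-D vector of class indices,
        # with exactly one entry per input sample.
        if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]:
            raise ValueError("sparse=True, but labels have shape {}.".format(labels.shape))

    inputs = np.zeros((4, 3, 28, 28))
    check_sparse_labels(inputs, np.array([1, 0, 2, 1]))     # passes: sparse class indices
    # check_sparse_labels(inputs, np.eye(3)[[1, 0, 2, 1]])  # raises: one-hot labels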


mindarmour/adv_robustness/attacks/black/pso_attack.py (+18, -11)

@@ -41,17 +41,17 @@ class PSOAttack(Attack):

     Args:
         model (BlackModel): Target model.
-        step_size (float): Attack step size. Default: 0.5.
-        per_bounds (float): Relative variation range of perturbations. Default: 0.6.
-        c1 (float): Weight coefficient. Default: 2.
-        c2 (float): Weight coefficient. Default: 2.
-        c (float): Weight of perturbation loss. Default: 2.
+        step_size (Union[int, float]): Attack step size. Default: 0.5.
+        per_bounds (Union[int, float]): Relative variation range of perturbations. Default: 0.6.
+        c1 (Union[int, float]): Weight coefficient. Default: 2.
+        c2 (Union[int, float]): Weight coefficient. Default: 2.
+        c (Union[int, float]): Weight of perturbation loss. Default: 2.
         pop_size (int): The number of particles, which should be greater
             than zero. Default: 6.
         t_max (int): The maximum round of iteration for each adversarial example,
             which should be greater than zero. Default: 1000.
-        pm (float): The probability of mutations. Default: 0.5.
-        bounds (tuple): Upper and lower bounds of data. In form of (clip_min,
+        pm (Union[int, float]): The probability of mutations. Default: 0.5.
+        bounds (Union[list, tuple, None]): Upper and lower bounds of data. In form of (clip_min,
             clip_max). Default: None.
         targeted (bool): If True, turns on the targeted attack. If False,
             turns on untargeted attack. It should be noted that only untargeted attack
@@ -60,8 +60,8 @@ class PSOAttack(Attack):
             input labels are one-hot-encoded. Default: True.
         model_type (str): The type of targeted model. 'classification' and 'detection' are supported now.
             Default: 'classification'.
-        reserve_ratio (float): The percentage of objects that can be detected after attacks, specifically for
-            model_type='detection'. Default: 0.3.
+        reserve_ratio (Union[int, float]): The percentage of objects that can be detected after attacks,
+            specifically for model_type='detection'. Default: 0.3.

     Examples:
         >>> attack = PSOAttack(model)
@@ -161,7 +161,7 @@ class PSOAttack(Attack):
         pixel_deep = self._bounds[1] - self._bounds[0]
         cur_pop = check_numpy_param('cur_pop', cur_pop)
         perturb_noise = (np.random.random(cur_pop.shape) - 0.5)*pixel_deep
-        mutated_pop = perturb_noise + cur_pop
+        mutated_pop = perturb_noise*(np.random.random(cur_pop.shape) < self._pm) + cur_pop
         if self._model_type == 'classification':
             mutated_pop = np.clip(np.clip(mutated_pop, cur_pop - self._per_bounds*np.abs(cur_pop),
                                           cur_pop + self._per_bounds*np.abs(cur_pop)),
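This one-line fix changes mutation from being applied to every element to being applied element-wise with probability pm, which is what the pm parameter documents. A small sketch of the before/after behaviour (illustrative values):

    import numpy as np

    pm = 0.5                                     # documented mutation probability
    pixel_deep = 1.0                             # bounds[1] - bounds[0]
    cur_pop = np.random.random((6, 3, 28, 28))   # particle population
    perturb_noise = (np.random.random(cur_pop.shape) - 0.5)*pixel_deep

    old_mutated = perturb_noise + cur_pop        # before: every element mutated, pm ignored
    mask = np.random.random(cur_pop.shape) < pm  # True for roughly pm of the elements
    new_mutated = perturb_noise*mask + cur_pop   # after: only ~pm of elements mutated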
@@ -194,7 +194,14 @@ class PSOAttack(Attack):
         if self._model_type == 'classification':
             inputs, labels = check_pair_numpy_param('inputs', inputs,
                                                     'labels', labels)
-            if not self._sparse:
+            if self._sparse:
+                label_squ = np.squeeze(labels)
+                if len(label_squ.shape) >= 2 or label_squ.shape[0] != inputs.shape[0]:
+                    msg = "The parameter 'sparse' of PSOAttack is True, but the input labels are not " \
+                          "in sparse format; got shape {}.".format(labels.shape)
+                    LOGGER.error(TAG, msg)
+                    raise ValueError(msg)
+            else:
                 labels = np.argmax(labels, axis=1)
             images = inputs
         elif self._model_type == 'detection':


mindarmour/utils/_check_param.py (+2, -0)

@@ -302,6 +302,8 @@ def check_detection_inputs(inputs, labels):
             raise ValueError(msg)
     else:
         check_numpy_param('inputs', inputs)
+        images = inputs
+        auxiliary_inputs = ()

     check_param_type('labels', labels, tuple)
     if len(labels) != 2:

tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py (+51, -2)

@@ -24,8 +24,6 @@ from mindspore.nn import Cell
 from mindarmour import BlackModel
 from mindarmour.adv_robustness.attacks import GeneticAttack

-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-

 # for user
 class ModelToBeAttacked(BlackModel):
@@ -41,6 +39,28 @@ class ModelToBeAttacked(BlackModel):
         return result.asnumpy()


+class DetectionModel(BlackModel):
+    """Stub detection model to be attacked."""
+
+    def predict(self, inputs):
+        """predict"""
+        # Adapt to the input shape requirements of the target network if inputs is only one image.
+        if len(inputs.shape) == 3:
+            inputs_num = 1
+        else:
+            inputs_num = inputs.shape[0]
+        box_and_confi = []
+        pred_labels = []
+        gt_number = np.random.randint(1, 128)
+
+        for _ in range(inputs_num):
+            boxes_i = np.random.random((gt_number, 5))
+            labels_i = np.random.randint(0, 10, gt_number)
+            box_and_confi.append(boxes_i)
+            pred_labels.append(labels_i)
+        return np.array(box_and_confi), np.array(pred_labels)
+
+
 class SimpleNet(Cell):
     """
     Construct the network of target model.
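For orientation, the stub detector returns randomly generated boxes and labels with these shapes (a usage sketch, not part of the diff; the five box values are presumably four coordinates plus a confidence, as the name box_and_confi suggests):

    import numpy as np

    model = DetectionModel()
    boxes, labels = model.predict(np.random.random((2, 3, 28, 28)))
    # boxes.shape  == (2, gt_number, 5)
    # labels.shape == (2, gt_number): one random class id in [0, 10) per object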
@@ -76,6 +96,7 @@ def test_genetic_attack():
     """
     Genetic_Attack test
     """
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     batch_size = 6

     net = SimpleNet()
@@ -98,6 +119,7 @@ def test_genetic_attack():
 @pytest.mark.env_card
 @pytest.mark.component_mindarmour
 def test_supplement():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     batch_size = 6

     net = SimpleNet()
@@ -123,6 +145,7 @@ def test_supplement():
 @pytest.mark.component_mindarmour
 def test_value_error():
     """test that exception is raised for invalid labels"""
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     batch_size = 6

     net = SimpleNet()
@@ -140,3 +163,29 @@ def test_value_error():
     # raise error
     with pytest.raises(ValueError):
         assert attack.generate(inputs, labels)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_card
+@pytest.mark.component_mindarmour
+def test_genetic_attack_detection_cpu():
+    """
+    Genetic_Attack test with a detection model on CPU.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    batch_size = 2
+    inputs = np.random.random((batch_size, 3, 28, 28))
+    model = DetectionModel()
+    attack = GeneticAttack(model, model_type='detection', pop_size=6, mutation_rate=0.05,
+                           per_bounds=0.1, step_size=0.25, temp=0.1,
+                           sparse=False, max_steps=50)
+
+    # generate adversarial samples, one image at a time
+    adv_imgs = []
+    for i in range(batch_size):
+        img_data = np.expand_dims(inputs[i], axis=0)
+        pre_gt_boxes, pre_gt_labels = model.predict(img_data)
+        _, adv_img, _ = attack.generate(img_data, (pre_gt_boxes, pre_gt_labels))
+        adv_imgs.append(adv_img)
+    assert np.any(inputs != np.array(adv_imgs))

tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py (+46, -0)

@@ -43,6 +43,28 @@ class ModelToBeAttacked(BlackModel):
         return result.asnumpy()


+class DetectionModel(BlackModel):
+    """Stub detection model to be attacked."""
+
+    def predict(self, inputs):
+        """predict"""
+        # Adapt to the input shape requirements of the target network if inputs is only one image.
+        if len(inputs.shape) == 3:
+            inputs_num = 1
+        else:
+            inputs_num = inputs.shape[0]
+        box_and_confi = []
+        pred_labels = []
+        gt_number = np.random.randint(1, 128)
+
+        for _ in range(inputs_num):
+            boxes_i = np.random.random((gt_number, 5))
+            labels_i = np.random.randint(0, 10, gt_number)
+            box_and_confi.append(boxes_i)
+            pred_labels.append(labels_i)
+        return np.array(box_and_confi), np.array(pred_labels)
+
+
 class SimpleNet(Cell):
     """
     Construct the network of target model.
@@ -167,3 +189,27 @@ def test_pso_attack_cpu():
     attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=False)
     _, adv_data, _ = attack.generate(inputs, labels)
     assert np.any(inputs != adv_data)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_card
+@pytest.mark.component_mindarmour
+def test_pso_attack_detection_cpu():
+    """
+    PSO_Attack test with a detection model on CPU.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+    batch_size = 2
+    inputs = np.random.random((batch_size, 3, 28, 28))
+    model = DetectionModel()
+    attack = PSOAttack(model, t_max=30, pm=0.5, model_type='detection', reserve_ratio=0.5)
+
+    # generate adversarial samples, one image at a time
+    adv_imgs = []
+    for i in range(batch_size):
+        img_data = np.expand_dims(inputs[i], axis=0)
+        pre_gt_boxes, pre_gt_labels = model.predict(img_data)
+        _, adv_img, _ = attack.generate(img_data, (pre_gt_boxes, pre_gt_labels))
+        adv_imgs.append(adv_img)
+    assert np.any(inputs != np.array(adv_imgs))
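Assuming a checkout of the repository root, one way to run just the two new CPU detection tests locally is pytest's keyword filter (paths and the filter string are illustrative):

    import pytest

    pytest.main(["-k", "detection_cpu",
                 "tests/ut/python/adv_robustness/attacks/black/test_genetic_attack.py",
                 "tests/ut/python/adv_robustness/attacks/black/test_pso_attack.py"])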
