Browse Source

!133 Fix a bug in attack.py

From: @jxlang910
Reviewed-by: @liu_luobin,@pkuliuliu
Signed-off-by: @pkuliuliu
tags/v1.1.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
1437d3df1c
2 changed files with 5 additions and 5 deletions
  1. +2
    -2
      mindarmour/adv_robustness/attacks/attack.py
  2. +3
    -3
      mindarmour/utils/_check_param.py

+ 2
- 2
mindarmour/adv_robustness/attacks/attack.py View File

@@ -183,12 +183,12 @@ class Attack:
x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position)
x_shape = best_position.shape
reduction_iters = 10000 # recover 0.01% each step
_, original_num = self.detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model)
_, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model)
for _ in range(reduction_iters):
diff = x_ori - best_position
res = 0.5*diff*(np.random.random(x_shape) < 0.0001)
best_position += res
_, correct_num = self.detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model)
_, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model)
q_times += 1
if correct_num > original_num:
best_position -= res


+ 3
- 3
mindarmour/utils/_check_param.py View File

@@ -313,15 +313,15 @@ def check_detection_inputs(inputs, labels):
has_labels = False
for item in labels:
check_numpy_param('item', item)
if len(item.shape) == 3 and item.shape[2] == 5:
if len(item.shape) == 3:
gt_boxes = item
has_boxes = True
elif len(item.shape) == 2:
gt_labels = item
has_labels = True
if (not has_boxes) or (not has_labels):
msg = 'The shape of boxes array and ground-truth labels array should be (N, M, 5) and (N, M), respectively. ' \
'But got {} and {}.'.format(labels[0].shape, labels[1].shape)
msg = 'The shape of boxes array should be (N, M, 5) or (N, M, 4), and the shape of ground-truth' \
'labels array should be (N, M). But got {} and {}.'.format(labels[0].shape, labels[1].shape)
LOGGER.error(TAG, msg)
raise ValueError(msg)
return images, auxiliary_inputs, gt_boxes, gt_labels

Loading…
Cancel
Save