From: @jxlang910 Reviewed-by: @liu_luobin,@pkuliuliu Signed-off-by: @pkuliuliutags/v1.1.0
@@ -183,12 +183,12 @@ class Attack: | |||||
x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position) | x_ori, best_position = check_equal_shape('x_ori', x_ori, 'best_position', best_position) | ||||
x_shape = best_position.shape | x_shape = best_position.shape | ||||
reduction_iters = 10000 # recover 0.01% each step | reduction_iters = 10000 # recover 0.01% each step | ||||
_, original_num = self.detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | |||||
_, original_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | |||||
for _ in range(reduction_iters): | for _ in range(reduction_iters): | ||||
diff = x_ori - best_position | diff = x_ori - best_position | ||||
res = 0.5*diff*(np.random.random(x_shape) < 0.0001) | res = 0.5*diff*(np.random.random(x_shape) < 0.0001) | ||||
best_position += res | best_position += res | ||||
_, correct_num = self.detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | |||||
_, correct_num = self._detection_scores((best_position,) + auxiliary_inputs, gt_boxes, gt_labels, model) | |||||
q_times += 1 | q_times += 1 | ||||
if correct_num > original_num: | if correct_num > original_num: | ||||
best_position -= res | best_position -= res | ||||
@@ -313,15 +313,15 @@ def check_detection_inputs(inputs, labels): | |||||
has_labels = False | has_labels = False | ||||
for item in labels: | for item in labels: | ||||
check_numpy_param('item', item) | check_numpy_param('item', item) | ||||
if len(item.shape) == 3 and item.shape[2] == 5: | |||||
if len(item.shape) == 3: | |||||
gt_boxes = item | gt_boxes = item | ||||
has_boxes = True | has_boxes = True | ||||
elif len(item.shape) == 2: | elif len(item.shape) == 2: | ||||
gt_labels = item | gt_labels = item | ||||
has_labels = True | has_labels = True | ||||
if (not has_boxes) or (not has_labels): | if (not has_boxes) or (not has_labels): | ||||
msg = 'The shape of boxes array and ground-truth labels array should be (N, M, 5) and (N, M), respectively. ' \ | |||||
'But got {} and {}.'.format(labels[0].shape, labels[1].shape) | |||||
msg = 'The shape of boxes array should be (N, M, 5) or (N, M, 4), and the shape of ground-truth' \ | |||||
'labels array should be (N, M). But got {} and {}.'.format(labels[0].shape, labels[1].shape) | |||||
LOGGER.error(TAG, msg) | LOGGER.error(TAG, msg) | ||||
raise ValueError(msg) | raise ValueError(msg) | ||||
return images, auxiliary_inputs, gt_boxes, gt_labels | return images, auxiliary_inputs, gt_boxes, gt_labels |