Browse Source

!427 Face recognition physical adversarial attack

Merge pull request !427 from 君君臣臣君/master
pull/428/head
i-robot Gitee 2 years ago
parent
commit
db6b315af2
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 698 additions and 0 deletions
  1. +119
    -0
      examples/community/face_adversarial_attack/README.md
  2. +275
    -0
      examples/community/face_adversarial_attack/adversarial_attack.py
  3. +45
    -0
      examples/community/face_adversarial_attack/example_non_target_attack.py
  4. +46
    -0
      examples/community/face_adversarial_attack/example_target_attack.py
  5. +154
    -0
      examples/community/face_adversarial_attack/loss_design.py
  6. BIN
      examples/community/face_adversarial_attack/photos/adv_input/adv.png
  7. BIN
      examples/community/face_adversarial_attack/photos/input/input1.jpg
  8. BIN
      examples/community/face_adversarial_attack/photos/target/target1.jpg
  9. +59
    -0
      examples/community/face_adversarial_attack/test.py

+ 119
- 0
examples/community/face_adversarial_attack/README.md View File

@@ -0,0 +1,119 @@

# 人脸识别物理对抗攻击

## 描述

本项目是基于MindSpore框架对人脸识别模型的物理对抗攻击,通过生成对抗口罩,使人脸佩戴后实现有目标攻击和非目标攻击。

## 模型结构

采用华为MindSpore官方训练的FaceRecognition模型
https://www.mindspore.cn/resources/hub/details?MindSpore/1.7/facerecognition_ms1mv2

## 环境要求

mindspore>=1.7,硬件平台为GPU。

## 脚本说明

```markdown
├── readme.md
├── photos
│ ├── adv_input //对抗图像
│ ├── input //输入图像
│ └── target //目标图像
├── outputs //训练后的图像
├── adversarial_attack.py //训练脚本
│── example_non_target_attack.py //无目标攻击训练
│── example_target_attack.py //有目标攻击训练
│── loss_design.py //训练优化设置
└── test.py //评估攻击效果
```

## 模型调用

方法一:

```python
#基于mindspore_hub库调用FaceRecognition模型
import mindspore_hub as mshub
from mindspore import context
def get_model():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=0)
model = "mindspore/1.7/facerecognition_ms1mv2"
network = mshub.load(model)
network.set_train(False)
return network
```

方法二:

```text
利用MindSpore代码仓中的https://gitee.com/mindspore/models/blob/master/research/cv/FaceRecognition/eval.py的get_model函数加载模型
```

## 训练过程

有目标攻击:

```shell
cd face_adversarial_attack/
python example_target_attack.py
```

非目标攻击:

```shell
cd face_adversarial_attack/
python example_non_target_attack.py
```

## 默认训练参数

optimizer=adam, learning rate=0.01, weight_decay=0.0001, epoch=2000

## 评估过程

评估方法一:

```shell
adversarial_attack.FaceAdversarialAttack.test_non_target_attack()
adversarial_attack.FaceAdversarialAttack.test_target_attack()
```

评估方法二:

```shell
cd face_adversarial_attack/
python test.py
```

## 实验结果

有目标攻击:

```text
input_label: 60
target_label: 345
The confidence of the input image on the input label: 26.67
The confidence of the input image on the target label: 0.95
================================
adversarial_label: 345
The confidence of the adversarial sample on the correct label: 1.82
The confidence of the adversarial sample on the target label: 10.96
input_label: 60, target_label: 345, adversarial_label: 345

photos中是有目标攻击的实验结果
```

非目标攻击:

```text
input_label: 60
The confidence of the input image on the input label: 25.16
================================
adversarial_label: 251
The confidence of the adversarial sample on the correct label: 9.52
The confidence of the adversarial sample on the adversarial label: 60.96
input_label: 60, adversarial_label: 251
```

+ 275
- 0
examples/community/face_adversarial_attack/adversarial_attack.py View File

@@ -0,0 +1,275 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train set"""
import os
import re
import numpy as np
import face_recognition as fr
import face_recognition_models as frm
import dlib
from PIL import Image, ImageDraw
import mindspore
import mindspore.dataset.vision.py_transforms as P
from mindspore.dataset.vision.py_transforms import ToPIL as ToPILImage
from mindspore.dataset.vision.py_transforms import ToTensor
from mindspore import Parameter, ops, nn, Tensor
from loss_design import MyTrainOneStepCell, MyWithLossCellTargetAttack, \
MyWithLossCellNonTargetAttack, FaceLossTargetAttack, FaceLossNoTargetAttack


class FaceAdversarialAttack():
"""
Class used to create adversarial facial recognition attacks.

Args:
input_img (numpy.ndarray): The input image.
target_img (numpy.ndarray): The target image.
seed (int): optional Sets custom seed for reproducibility. Default is generated randomly.
net (mindspore.Model): face recognition model.
"""
def __init__(self, input_img, target_img, net, seed=None):
if seed is not None:
np.random.seed(seed)
self.mean = Tensor([0.485, 0.456, 0.406])
self.std = Tensor([0.229, 0.224, 0.225])
self.expand_dims = mindspore.ops.ExpandDims()
self.imageize = ToPILImage()
self.tensorize = ToTensor()
self.normalize = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
self.resnet = net
self.input_tensor = Tensor(self.normalize(self.tensorize(input_img)))
self.target_tensor = Tensor(self.normalize(self.tensorize(target_img)))
self.input_emb = self.resnet(self.expand_dims(self.input_tensor, 0))
self.target_emb = self.resnet(self.expand_dims(self.target_tensor, 0))
self.adversarial_emb = None
self.mask_tensor = create_mask(input_img)
self.ref = self.mask_tensor
self.pm = Parameter(self.mask_tensor)
self.opt = nn.Adam([self.pm], learning_rate=0.01, weight_decay=0.0001)

def train(self, attack_method):
"""
Optimized adversarial image.

Args:
attack_method (String) : Including target attack and non_target attack.

Returns:
Tensor, adversarial image.
Tensor, mask image.
"""

if attack_method == "non_target_attack":
loss = FaceLossNoTargetAttack()
net_with_criterion = MyWithLossCellNonTargetAttack(self.resnet, loss, self.input_tensor)
if attack_method == "target_attack":
loss = FaceLossTargetAttack(self.target_emb)
net_with_criterion = MyWithLossCellTargetAttack(self.resnet, loss, self.input_tensor)

train_net = MyTrainOneStepCell(net_with_criterion, self.opt)

for i in range(2000):

self.mask_tensor = Tensor(self.pm)

loss = train_net(self.mask_tensor)

print("epoch %d ,loss: %f \n " % (i, loss.asnumpy().item()))

self.mask_tensor = ops.clip_by_value(
self.mask_tensor, Tensor(0, mindspore.float32), Tensor(1, mindspore.float32))

adversarial_tensor = apply(
self.input_tensor,
(self.mask_tensor - self.mean[:, None, None]) / self.std[:, None, None],
self.ref)

adversarial_tensor = self._reverse_norm(adversarial_tensor)
processed_input_tensor = self._reverse_norm(self.input_tensor)
processed_target_tensor = self._reverse_norm(self.target_tensor)

return {
"adversarial_tensor": adversarial_tensor,
"mask_tensor": self.mask_tensor,
"processed_input_tensor": processed_input_tensor,
"processed_target_tensor": processed_target_tensor
}

def test_target_attack(self):
"""
The model is used to test the recognition ability of adversarial images under target attack.
"""

adversarial_tensor = apply(
self.input_tensor,
(self.mask_tensor - self.mean[:, None, None]) / self.std[:, None, None],
self.ref)

self.adversarial_emb = self.resnet(self.expand_dims(adversarial_tensor, 0))
self.input_emb = self.resnet(self.expand_dims(self.input_tensor, 0))
self.target_emb = self.resnet(self.expand_dims(self.target_tensor, 0))

adversarial_index = np.argmax(self.adversarial_emb.asnumpy())
target_index = np.argmax(self.target_emb.asnumpy())
input_index = np.argmax(self.input_emb.asnumpy())

print("input_label:", input_index)
print("target_label:", target_index)
print("The confidence of the input image on the input label:", self.input_emb.asnumpy()[0][input_index])
print("The confidence of the input image on the target label:", self.input_emb.asnumpy()[0][target_index])
print("================================")
print("adversarial_label:", adversarial_index)
print("The confidence of the adversarial sample on the correct label:",
self.adversarial_emb.asnumpy()[0][input_index])
print("The confidence of the adversarial sample on the target label:",
self.adversarial_emb.asnumpy()[0][target_index])
print("input_label: %d, target_label: %d, adversarial_label: %d"
% (input_index, target_index, adversarial_index))

def test_non_target_attack(self):
"""
The model is used to test the recognition ability of adversarial images under non_target attack.
"""

adversarial_tensor = apply(
self.input_tensor,
(self.mask_tensor - self.mean[:, None, None]) / self.std[:, None, None],
self.ref)

self.adversarial_emb = self.resnet(self.expand_dims(adversarial_tensor, 0))
self.input_emb = self.resnet(self.expand_dims(self.input_tensor, 0))

adversarial_index = np.argmax(self.adversarial_emb.asnumpy())
input_index = np.argmax(self.input_emb.asnumpy())

print("input_label:", input_index)
print("The confidence of the input image on the input label:", self.input_emb.asnumpy()[0][input_index])
print("================================")
print("adversarial_label:", adversarial_index)
print("The confidence of the adversarial sample on the correct label:",
self.adversarial_emb.asnumpy()[0][input_index])
print("The confidence of the adversarial sample on the adversarial label:",
self.adversarial_emb.asnumpy()[0][adversarial_index])
print(
"input_label: %d, adversarial_label: %d" % (input_index, adversarial_index))

def _reverse_norm(self, image_tensor):
"""
Reverses normalization for a given image_tensor.

Args:
image_tensor (Tensor): Tensor.

Returns:
Tensor, image.
"""
tensor = image_tensor * self.std[:, None, None] + self.mean[:, None, None]
return tensor


def apply(image_tensor, mask_tensor, reference_tensor):
"""
Apply a mask over an image.

Args:
image_tensor (Tensor): Canvas to be used to apply mask on.
mask_tensor (Tensor): Mask to apply over the image.
reference_tensor (Tensor): Used to reference mask boundaries

Returns:
Tensor, image.
"""
tensor = mindspore.numpy.where((reference_tensor == 0), image_tensor, mask_tensor)

return tensor


def create_mask(face_image):
"""
Create mask image.

Args:
face_image (PIL.Image): image of a detected face.

Returns:
mask_tensor : a mask image.
"""

mask = Image.new('RGB', face_image.size, color=(0, 0, 0))
d = ImageDraw.Draw(mask)
landmarks = fr.face_landmarks(np.array(face_image))
area = [landmark
for landmark in landmarks[0]['chin']
if landmark[1] > max(landmarks[0]['nose_tip'])[1]]
area.append(landmarks[0]['nose_bridge'][1])
d.polygon(area, fill=(255, 255, 255))
mask = np.array(mask)
mask = mask.astype(np.float32)

for i in range(mask.shape[0]):
for j in range(mask.shape[1]):
for k in range(mask.shape[2]):
if mask[i][j][k] == 255.:
mask[i][j][k] = 0.5
else:
mask[i][j][k] = 0

mask_tensor = Tensor(mask)
mask_tensor = mask_tensor.swapaxes(0, 2).swapaxes(1, 2)
mask_tensor.requires_grad = True
return mask_tensor


def detect_face(image):
"""
Face detection and alignment process using dlib library.

Args:
image (numpy.ndarray): image file location.

Returns:
face_image : Resized face image.
"""
dlib_detector = dlib.get_frontal_face_detector()
dlib_shape_predictor = dlib.shape_predictor(frm.pose_predictor_model_location())
dlib_image = dlib.load_rgb_image(image)
detections = dlib_detector(dlib_image, 1)

dlib_faces = dlib.full_object_detections()
for det in detections:
dlib_faces.append(dlib_shape_predictor(dlib_image, det))
face_image = Image.fromarray(dlib.get_face_chip(dlib_image, dlib_faces[0], size=112))

return face_image


def load_data(data):
"""
An auxiliary function that loads image data.

Args:
data (String): The path to the given data.

Returns:
list : Resize list of face images.
"""
image_files = [f for f in os.listdir(data) if re.search(r'.*\.(jpe?g|png)', f)]
image_files_locs = [os.path.join(data, f) for f in image_files]

image_list = []
for img in image_files_locs:
image_list.append(detect_face(img))

return image_list

+ 45
- 0
examples/community/face_adversarial_attack/example_non_target_attack.py View File

@@ -0,0 +1,45 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""non target attack"""
import numpy as np
import matplotlib.image as mp
from mindspore import context
import adversarial_attack
from FaceRecognition.eval import get_model

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


if __name__ == '__main__':

inputs = adversarial_attack.load_data('photos/input/')
targets = adversarial_attack.load_data('photos/target/')

net = get_model()
adversarial = adversarial_attack.FaceAdversarialAttack(inputs[0], targets[0], net)
ATTACK_METHOD = "non_target_attack"

tensor_dict = adversarial.train(attack_method=ATTACK_METHOD)

mp.imsave('./outputs/adversarial_example.jpg',
np.transpose(tensor_dict.get("adversarial_tensor").asnumpy(), (1, 2, 0)))
mp.imsave('./outputs/mask.jpg',
np.transpose(tensor_dict.get("mask_tensor").asnumpy(), (1, 2, 0)))
mp.imsave('./outputs/input_image.jpg',
np.transpose(tensor_dict.get("processed_input_tensor").asnumpy(), (1, 2, 0)))
mp.imsave('./outputs/target_image.jpg',
np.transpose(tensor_dict.get("processed_target_tensor").asnumpy(), (1, 2, 0)))

adversarial.test_non_target_attack()

+ 46
- 0
examples/community/face_adversarial_attack/example_target_attack.py View File

@@ -0,0 +1,46 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""target attack"""
import numpy as np
import matplotlib.image as mp
from mindspore import context
import adversarial_attack
from FaceRecognition.eval import get_model

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


if __name__ == '__main__':

inputs = adversarial_attack.load_data('photos/input/')
targets = adversarial_attack.load_data('photos/target/')

net = get_model()
adversarial = adversarial_attack.FaceAdversarialAttack(inputs[0], targets[0], net)

ATTACK_METHOD = "target_attack"

tensor_dict = adversarial.train(attack_method=ATTACK_METHOD)

mp.imsave('./outputs/adversarial_example.jpg',
np.transpose(tensor_dict.get("adversarial_tensor").asnumpy(), (1, 2, 0)))
mp.imsave('./outputs/mask.jpg',
np.transpose(tensor_dict.get("mask_tensor").asnumpy(), (1, 2, 0)))
mp.imsave('./outputs/input_image.jpg',
np.transpose(tensor_dict.get("processed_input_tensor").asnumpy(), (1, 2, 0)))
mp.imsave('./outputs/target_image.jpg',
np.transpose(tensor_dict.get("processed_target_tensor").asnumpy(), (1, 2, 0)))

adversarial.test_target_attack()

+ 154
- 0
examples/community/face_adversarial_attack/loss_design.py View File

@@ -0,0 +1,154 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""optimization Settings"""
import mindspore
from mindspore import ops, nn, Tensor
from mindspore.dataset.vision.py_transforms import ToTensor
import mindspore.dataset.vision.py_transforms as P


class MyTrainOneStepCell(nn.TrainOneStepCell):
"""
Encapsulation class of network training.

Append an optimizer to the training network after that the construct
function can be called to create the backward graph.

Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default: 1.0.
"""

def __init__(self, network, optimizer, sens=1.0):
super(MyTrainOneStepCell, self).__init__(network, optimizer, sens)
self.grad = ops.composite.GradOperation(get_all=True, sens_param=False)

def construct(self, *inputs):
"""Defines the computation performed."""
loss = self.network(*inputs)
grads = self.grad(self.network)(*inputs)
self.optimizer(grads)
return loss


class MyWithLossCellTargetAttack(nn.Cell):
"""The loss function defined by the target attack"""
def __init__(self, net, loss_fn, input_tensor):
super(MyWithLossCellTargetAttack, self).__init__(auto_prefix=False)
self.net = net
self._loss_fn = loss_fn
self.std = Tensor([0.229, 0.224, 0.225])
self.mean = Tensor([0.485, 0.456, 0.406])
self.expand_dims = mindspore.ops.ExpandDims()
self.normalize = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
self.tensorize = ToTensor()
self.input_tensor = input_tensor
self.input_emb = self.net(self.expand_dims(self.input_tensor, 0))

@property
def backbone_network(self):
return self.net

def construct(self, mask_tensor):
ref = mask_tensor
adversarial_tensor = mindspore.numpy.where(
(ref == 0),
self.input_tensor,
(mask_tensor - self.mean[:, None, None]) / self.std[:, None, None])
adversarial_emb = self.net(self.expand_dims(adversarial_tensor, 0))
loss = self._loss_fn(adversarial_emb)
return loss


class MyWithLossCellNonTargetAttack(nn.Cell):
"""The loss function defined by the non target attack"""
def __init__(self, net, loss_fn, input_tensor):
super(MyWithLossCellNonTargetAttack, self).__init__(auto_prefix=False)
self.net = net
self._loss_fn = loss_fn
self.std = Tensor([0.229, 0.224, 0.225])
self.mean = Tensor([0.485, 0.456, 0.406])
self.expand_dims = mindspore.ops.ExpandDims()
self.normalize = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
self.tensorize = ToTensor()
self.input_tensor = input_tensor
self.input_emb = self.net(self.expand_dims(self.input_tensor, 0))

@property
def backbone_network(self):
return self.net

def construct(self, mask_tensor):
ref = mask_tensor
adversarial_tensor = mindspore.numpy.where(
(ref == 0),
self.input_tensor,
(mask_tensor - self.mean[:, None, None]) / self.std[:, None, None])
adversarial_emb = self.net(self.expand_dims(adversarial_tensor, 0))
loss = self._loss_fn(adversarial_emb, self.input_emb)
return loss


class FaceLossTargetAttack(nn.Cell):
"""The loss function of the target attack"""

def __init__(self, target_emb):
super(FaceLossTargetAttack, self).__init__()
self.uniformreal = ops.UniformReal(seed=2)
self.sum = ops.ReduceSum(keep_dims=False)
self.norm = nn.Norm(keep_dims=True)
self.zeroslike = ops.ZerosLike()
self.concat_op1 = ops.Concat(1)
self.concat_op2 = ops.Concat(2)
self.pow = ops.Pow()
self.reduce_sum = ops.operations.ReduceSum()
self.target_emb = target_emb
self.abs = ops.Abs()
self.reduce_mean = ops.ReduceMean()

def construct(self, adversarial_emb):
prod_sum = self.reduce_sum(adversarial_emb * self.target_emb, (1,))
square1 = self.reduce_sum(ops.functional.square(adversarial_emb), (1,))
square2 = self.reduce_sum(ops.functional.square(self.target_emb), (1,))
denom = ops.functional.sqrt(square1) * ops.functional.sqrt(square2)
loss = -(prod_sum / denom)
return loss


class FaceLossNoTargetAttack(nn.Cell):
"""The loss function of the non-target attack"""

def __init__(self):
"""Initialization"""
super(FaceLossNoTargetAttack, self).__init__()
self.uniformreal = ops.UniformReal(seed=2)
self.sum = ops.ReduceSum(keep_dims=False)
self.norm = nn.Norm(keep_dims=True)
self.zeroslike = ops.ZerosLike()
self.concat_op1 = ops.Concat(1)
self.concat_op2 = ops.Concat(2)
self.pow = ops.Pow()
self.reduce_sum = ops.operations.ReduceSum()
self.abs = ops.Abs()
self.reduce_mean = ops.ReduceMean()

def construct(self, adversarial_emb, input_emb):
prod_sum = self.reduce_sum(adversarial_emb * input_emb, (1,))
square1 = self.reduce_sum(ops.functional.square(adversarial_emb), (1,))
square2 = self.reduce_sum(ops.functional.square(input_emb), (1,))
denom = ops.functional.sqrt(square1) * ops.functional.sqrt(square2)
loss = prod_sum / denom
return loss

BIN
examples/community/face_adversarial_attack/photos/adv_input/adv.png View File

Before After
Width: 318  |  Height: 318  |  Size: 267 kB

BIN
examples/community/face_adversarial_attack/photos/input/input1.jpg View File

Before After
Width: 1024  |  Height: 774  |  Size: 88 kB

BIN
examples/community/face_adversarial_attack/photos/target/target1.jpg View File

Before After
Width: 1523  |  Height: 2048  |  Size: 653 kB

+ 59
- 0
examples/community/face_adversarial_attack/test.py View File

@@ -0,0 +1,59 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test"""
import numpy as np
from mindspore import context, Tensor
import mindspore
from mindspore.dataset.vision.py_transforms import ToTensor
import mindspore.dataset.vision.py_transforms as P
from FaceRecognition.eval import get_model
import adversarial_attack

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

if __name__ == '__main__':

image = adversarial_attack.load_data('photos/adv_input/')
inputs = adversarial_attack.load_data('photos/input/')
targets = adversarial_attack.load_data('photos/target/')

tensorize = ToTensor()
normalize = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
expand_dims = mindspore.ops.ExpandDims()
mean = Tensor([0.485, 0.456, 0.406])
std = Tensor([0.229, 0.224, 0.225])

resnet = get_model()

adv = Tensor(normalize(tensorize(image[0])))
input_tensor = Tensor(normalize(tensorize(inputs[0])))
target_tensor = Tensor(normalize(tensorize(targets[0])))

adversarial_emb = resnet(expand_dims(adv, 0))
input_emb = resnet(expand_dims(input_tensor, 0))
target_emb = resnet(expand_dims(target_tensor, 0))

adversarial_index = np.argmax(adversarial_emb.asnumpy())
target_index = np.argmax(target_emb.asnumpy())
input_index = np.argmax(input_emb.asnumpy())

print("input_label:", input_index)
print("The confidence of the input image on the input label:", input_emb.asnumpy()[0][input_index])
print("================================")
print("adversarial_label:", adversarial_index)
print("The confidence of the adversarial sample on the correct label:", adversarial_emb.asnumpy()[0][input_index])
print("The confidence of the adversarial sample on the adversarial label:",
adversarial_emb.asnumpy()[0][adversarial_index])
print("input_label:%d, adversarial_label:%d" % (input_index, adversarial_index))

Loading…
Cancel
Save