“update” “update” “update” “update” “update” “update” “update” “update” “update” “update” “update” update update update update update update update update updatepull/427/head
@@ -0,0 +1,119 @@ | |||||
# 人脸识别物理对抗攻击 | |||||
## 描述 | |||||
本项目是基于MindSpore框架对人脸识别模型的物理对抗攻击,通过生成对抗口罩,使人脸佩戴后实现有目标攻击和非目标攻击。 | |||||
## 模型结构 | |||||
采用华为MindSpore官方训练的FaceRecognition模型 | |||||
https://www.mindspore.cn/resources/hub/details?MindSpore/1.7/facerecognition_ms1mv2 | |||||
## 环境要求 | |||||
mindspore>=1.7,硬件平台为GPU。 | |||||
## 脚本说明 | |||||
```markdown | |||||
├── readme.md | |||||
├── photos | |||||
│ ├── adv_input //对抗图像 | |||||
│ ├── input //输入图像 | |||||
│ └── target //目标图像 | |||||
├── outputs //训练后的图像 | |||||
├── adversarial_attack.py //训练脚本 | |||||
│── example_non_target_attack.py //无目标攻击训练 | |||||
│── example_target_attack.py //有目标攻击训练 | |||||
│── loss_design.py //训练优化设置 | |||||
└── test.py //评估攻击效果 | |||||
``` | |||||
## 模型调用 | |||||
方法一: | |||||
```python | |||||
#基于mindspore_hub库调用FaceRecognition模型 | |||||
import mindspore_hub as mshub | |||||
from mindspore import context | |||||
def get_model(): | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=0) | |||||
model = "mindspore/1.7/facerecognition_ms1mv2" | |||||
network = mshub.load(model) | |||||
network.set_train(False) | |||||
return network | |||||
``` | |||||
方法二: | |||||
```text | |||||
利用MindSpore代码仓中的https://gitee.com/mindspore/models/blob/master/research/cv/FaceRecognition/eval.py的get_model函数加载模型 | |||||
``` | |||||
## 训练过程 | |||||
有目标攻击: | |||||
```shell | |||||
cd face_adversarial_attack/ | |||||
python example_target_attack.py | |||||
``` | |||||
非目标攻击: | |||||
```shell | |||||
cd face_adversarial_attack/ | |||||
python example_non_target_attack.py | |||||
``` | |||||
## 默认训练参数 | |||||
optimizer=adam, learning rate=0.01, weight_decay=0.0001, epoch=2000 | |||||
## 评估过程 | |||||
评估方法一: | |||||
```shell | |||||
adversarial_attack.FaceAdversarialAttack.test_non_target_attack() | |||||
adversarial_attack.FaceAdversarialAttack.test_target_attack() | |||||
``` | |||||
评估方法二: | |||||
```shell | |||||
cd face_adversarial_attack/ | |||||
python test.py | |||||
``` | |||||
## 实验结果 | |||||
有目标攻击: | |||||
```text | |||||
input_label: 60 | |||||
target_label: 345 | |||||
The confidence of the input image on the input label: 26.67 | |||||
The confidence of the input image on the target label: 0.95 | |||||
================================ | |||||
adversarial_label: 345 | |||||
The confidence of the adversarial sample on the correct label: 1.82 | |||||
The confidence of the adversarial sample on the target label: 10.96 | |||||
input_label: 60, target_label: 345, adversarial_label: 345 | |||||
photos中是有目标攻击的实验结果 | |||||
``` | |||||
非目标攻击: | |||||
```text | |||||
input_label: 60 | |||||
The confidence of the input image on the input label: 25.16 | |||||
================================ | |||||
adversarial_label: 251 | |||||
The confidence of the adversarial sample on the correct label: 9.52 | |||||
The confidence of the adversarial sample on the adversarial label: 60.96 | |||||
input_label: 60, adversarial_label: 251 | |||||
``` |
@@ -0,0 +1,275 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""Train set""" | |||||
import os | |||||
import re | |||||
import numpy as np | |||||
import face_recognition as fr | |||||
import face_recognition_models as frm | |||||
import dlib | |||||
from PIL import Image, ImageDraw | |||||
import mindspore | |||||
import mindspore.dataset.vision.py_transforms as P | |||||
from mindspore.dataset.vision.py_transforms import ToPIL as ToPILImage | |||||
from mindspore.dataset.vision.py_transforms import ToTensor | |||||
from mindspore import Parameter, ops, nn, Tensor | |||||
from loss_design import MyTrainOneStepCell, MyWithLossCellTargetAttack, \ | |||||
MyWithLossCellNonTargetAttack, FaceLossTargetAttack, FaceLossNoTargetAttack | |||||
class FaceAdversarialAttack(): | |||||
""" | |||||
Class used to create adversarial facial recognition attacks. | |||||
Args: | |||||
input_img (numpy.ndarray): The input image. | |||||
target_img (numpy.ndarray): The target image. | |||||
seed (int): optional Sets custom seed for reproducibility. Default is generated randomly. | |||||
net (mindspore.Model): face recognition model. | |||||
""" | |||||
def __init__(self, input_img, target_img, net, seed=None): | |||||
if seed is not None: | |||||
np.random.seed(seed) | |||||
self.mean = Tensor([0.485, 0.456, 0.406]) | |||||
self.std = Tensor([0.229, 0.224, 0.225]) | |||||
self.expand_dims = mindspore.ops.ExpandDims() | |||||
self.imageize = ToPILImage() | |||||
self.tensorize = ToTensor() | |||||
self.normalize = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |||||
self.resnet = net | |||||
self.input_tensor = Tensor(self.normalize(self.tensorize(input_img))) | |||||
self.target_tensor = Tensor(self.normalize(self.tensorize(target_img))) | |||||
self.input_emb = self.resnet(self.expand_dims(self.input_tensor, 0)) | |||||
self.target_emb = self.resnet(self.expand_dims(self.target_tensor, 0)) | |||||
self.adversarial_emb = None | |||||
self.mask_tensor = create_mask(input_img) | |||||
self.ref = self.mask_tensor | |||||
self.pm = Parameter(self.mask_tensor) | |||||
self.opt = nn.Adam([self.pm], learning_rate=0.01, weight_decay=0.0001) | |||||
def train(self, attack_method): | |||||
""" | |||||
Optimized adversarial image. | |||||
Args: | |||||
attack_method (String) : Including target attack and non_target attack. | |||||
Returns: | |||||
Tensor, adversarial image. | |||||
Tensor, mask image. | |||||
""" | |||||
if attack_method == "non_target_attack": | |||||
loss = FaceLossNoTargetAttack() | |||||
net_with_criterion = MyWithLossCellNonTargetAttack(self.resnet, loss, self.input_tensor) | |||||
if attack_method == "target_attack": | |||||
loss = FaceLossTargetAttack(self.target_emb) | |||||
net_with_criterion = MyWithLossCellTargetAttack(self.resnet, loss, self.input_tensor) | |||||
train_net = MyTrainOneStepCell(net_with_criterion, self.opt) | |||||
for i in range(2000): | |||||
self.mask_tensor = Tensor(self.pm) | |||||
loss = train_net(self.mask_tensor) | |||||
print("epoch %d ,loss: %f \n " % (i, loss.asnumpy().item())) | |||||
self.mask_tensor = ops.clip_by_value( | |||||
self.mask_tensor, Tensor(0, mindspore.float32), Tensor(1, mindspore.float32)) | |||||
adversarial_tensor = apply( | |||||
self.input_tensor, | |||||
(self.mask_tensor - self.mean[:, None, None]) / self.std[:, None, None], | |||||
self.ref) | |||||
adversarial_tensor = self._reverse_norm(adversarial_tensor) | |||||
processed_input_tensor = self._reverse_norm(self.input_tensor) | |||||
processed_target_tensor = self._reverse_norm(self.target_tensor) | |||||
return { | |||||
"adversarial_tensor": adversarial_tensor, | |||||
"mask_tensor": self.mask_tensor, | |||||
"processed_input_tensor": processed_input_tensor, | |||||
"processed_target_tensor": processed_target_tensor | |||||
} | |||||
def test_target_attack(self): | |||||
""" | |||||
The model is used to test the recognition ability of adversarial images under target attack. | |||||
""" | |||||
adversarial_tensor = apply( | |||||
self.input_tensor, | |||||
(self.mask_tensor - self.mean[:, None, None]) / self.std[:, None, None], | |||||
self.ref) | |||||
self.adversarial_emb = self.resnet(self.expand_dims(adversarial_tensor, 0)) | |||||
self.input_emb = self.resnet(self.expand_dims(self.input_tensor, 0)) | |||||
self.target_emb = self.resnet(self.expand_dims(self.target_tensor, 0)) | |||||
adversarial_index = np.argmax(self.adversarial_emb.asnumpy()) | |||||
target_index = np.argmax(self.target_emb.asnumpy()) | |||||
input_index = np.argmax(self.input_emb.asnumpy()) | |||||
print("input_label:", input_index) | |||||
print("target_label:", target_index) | |||||
print("The confidence of the input image on the input label:", self.input_emb.asnumpy()[0][input_index]) | |||||
print("The confidence of the input image on the target label:", self.input_emb.asnumpy()[0][target_index]) | |||||
print("================================") | |||||
print("adversarial_label:", adversarial_index) | |||||
print("The confidence of the adversarial sample on the correct label:", | |||||
self.adversarial_emb.asnumpy()[0][input_index]) | |||||
print("The confidence of the adversarial sample on the target label:", | |||||
self.adversarial_emb.asnumpy()[0][target_index]) | |||||
print("input_label: %d, target_label: %d, adversarial_label: %d" | |||||
% (input_index, target_index, adversarial_index)) | |||||
def test_non_target_attack(self): | |||||
""" | |||||
The model is used to test the recognition ability of adversarial images under non_target attack. | |||||
""" | |||||
adversarial_tensor = apply( | |||||
self.input_tensor, | |||||
(self.mask_tensor - self.mean[:, None, None]) / self.std[:, None, None], | |||||
self.ref) | |||||
self.adversarial_emb = self.resnet(self.expand_dims(adversarial_tensor, 0)) | |||||
self.input_emb = self.resnet(self.expand_dims(self.input_tensor, 0)) | |||||
adversarial_index = np.argmax(self.adversarial_emb.asnumpy()) | |||||
input_index = np.argmax(self.input_emb.asnumpy()) | |||||
print("input_label:", input_index) | |||||
print("The confidence of the input image on the input label:", self.input_emb.asnumpy()[0][input_index]) | |||||
print("================================") | |||||
print("adversarial_label:", adversarial_index) | |||||
print("The confidence of the adversarial sample on the correct label:", | |||||
self.adversarial_emb.asnumpy()[0][input_index]) | |||||
print("The confidence of the adversarial sample on the adversarial label:", | |||||
self.adversarial_emb.asnumpy()[0][adversarial_index]) | |||||
print( | |||||
"input_label: %d, adversarial_label: %d" % (input_index, adversarial_index)) | |||||
def _reverse_norm(self, image_tensor): | |||||
""" | |||||
Reverses normalization for a given image_tensor. | |||||
Args: | |||||
image_tensor (Tensor): Tensor. | |||||
Returns: | |||||
Tensor, image. | |||||
""" | |||||
tensor = image_tensor * self.std[:, None, None] + self.mean[:, None, None] | |||||
return tensor | |||||
def apply(image_tensor, mask_tensor, reference_tensor): | |||||
""" | |||||
Apply a mask over an image. | |||||
Args: | |||||
image_tensor (Tensor): Canvas to be used to apply mask on. | |||||
mask_tensor (Tensor): Mask to apply over the image. | |||||
reference_tensor (Tensor): Used to reference mask boundaries | |||||
Returns: | |||||
Tensor, image. | |||||
""" | |||||
tensor = mindspore.numpy.where((reference_tensor == 0), image_tensor, mask_tensor) | |||||
return tensor | |||||
def create_mask(face_image): | |||||
""" | |||||
Create mask image. | |||||
Args: | |||||
face_image (PIL.Image): image of a detected face. | |||||
Returns: | |||||
mask_tensor : a mask image. | |||||
""" | |||||
mask = Image.new('RGB', face_image.size, color=(0, 0, 0)) | |||||
d = ImageDraw.Draw(mask) | |||||
landmarks = fr.face_landmarks(np.array(face_image)) | |||||
area = [landmark | |||||
for landmark in landmarks[0]['chin'] | |||||
if landmark[1] > max(landmarks[0]['nose_tip'])[1]] | |||||
area.append(landmarks[0]['nose_bridge'][1]) | |||||
d.polygon(area, fill=(255, 255, 255)) | |||||
mask = np.array(mask) | |||||
mask = mask.astype(np.float32) | |||||
for i in range(mask.shape[0]): | |||||
for j in range(mask.shape[1]): | |||||
for k in range(mask.shape[2]): | |||||
if mask[i][j][k] == 255.: | |||||
mask[i][j][k] = 0.5 | |||||
else: | |||||
mask[i][j][k] = 0 | |||||
mask_tensor = Tensor(mask) | |||||
mask_tensor = mask_tensor.swapaxes(0, 2).swapaxes(1, 2) | |||||
mask_tensor.requires_grad = True | |||||
return mask_tensor | |||||
def detect_face(image): | |||||
""" | |||||
Face detection and alignment process using dlib library. | |||||
Args: | |||||
image (numpy.ndarray): image file location. | |||||
Returns: | |||||
face_image : Resized face image. | |||||
""" | |||||
dlib_detector = dlib.get_frontal_face_detector() | |||||
dlib_shape_predictor = dlib.shape_predictor(frm.pose_predictor_model_location()) | |||||
dlib_image = dlib.load_rgb_image(image) | |||||
detections = dlib_detector(dlib_image, 1) | |||||
dlib_faces = dlib.full_object_detections() | |||||
for det in detections: | |||||
dlib_faces.append(dlib_shape_predictor(dlib_image, det)) | |||||
face_image = Image.fromarray(dlib.get_face_chip(dlib_image, dlib_faces[0], size=112)) | |||||
return face_image | |||||
def load_data(data): | |||||
""" | |||||
An auxiliary function that loads image data. | |||||
Args: | |||||
data (String): The path to the given data. | |||||
Returns: | |||||
list : Resize list of face images. | |||||
""" | |||||
image_files = [f for f in os.listdir(data) if re.search(r'.*\.(jpe?g|png)', f)] | |||||
image_files_locs = [os.path.join(data, f) for f in image_files] | |||||
image_list = [] | |||||
for img in image_files_locs: | |||||
image_list.append(detect_face(img)) | |||||
return image_list |
@@ -0,0 +1,45 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""non target attack""" | |||||
import numpy as np | |||||
import matplotlib.image as mp | |||||
from mindspore import context | |||||
import adversarial_attack | |||||
from FaceRecognition.eval import get_model | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
if __name__ == '__main__': | |||||
inputs = adversarial_attack.load_data('photos/input/') | |||||
targets = adversarial_attack.load_data('photos/target/') | |||||
net = get_model() | |||||
adversarial = adversarial_attack.FaceAdversarialAttack(inputs[0], targets[0], net) | |||||
ATTACK_METHOD = "non_target_attack" | |||||
tensor_dict = adversarial.train(attack_method=ATTACK_METHOD) | |||||
mp.imsave('./outputs/adversarial_example.jpg', | |||||
np.transpose(tensor_dict.get("adversarial_tensor").asnumpy(), (1, 2, 0))) | |||||
mp.imsave('./outputs/mask.jpg', | |||||
np.transpose(tensor_dict.get("mask_tensor").asnumpy(), (1, 2, 0))) | |||||
mp.imsave('./outputs/input_image.jpg', | |||||
np.transpose(tensor_dict.get("processed_input_tensor").asnumpy(), (1, 2, 0))) | |||||
mp.imsave('./outputs/target_image.jpg', | |||||
np.transpose(tensor_dict.get("processed_target_tensor").asnumpy(), (1, 2, 0))) | |||||
adversarial.test_non_target_attack() |
@@ -0,0 +1,46 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""target attack""" | |||||
import numpy as np | |||||
import matplotlib.image as mp | |||||
from mindspore import context | |||||
import adversarial_attack | |||||
from FaceRecognition.eval import get_model | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
if __name__ == '__main__': | |||||
inputs = adversarial_attack.load_data('photos/input/') | |||||
targets = adversarial_attack.load_data('photos/target/') | |||||
net = get_model() | |||||
adversarial = adversarial_attack.FaceAdversarialAttack(inputs[0], targets[0], net) | |||||
ATTACK_METHOD = "target_attack" | |||||
tensor_dict = adversarial.train(attack_method=ATTACK_METHOD) | |||||
mp.imsave('./outputs/adversarial_example.jpg', | |||||
np.transpose(tensor_dict.get("adversarial_tensor").asnumpy(), (1, 2, 0))) | |||||
mp.imsave('./outputs/mask.jpg', | |||||
np.transpose(tensor_dict.get("mask_tensor").asnumpy(), (1, 2, 0))) | |||||
mp.imsave('./outputs/input_image.jpg', | |||||
np.transpose(tensor_dict.get("processed_input_tensor").asnumpy(), (1, 2, 0))) | |||||
mp.imsave('./outputs/target_image.jpg', | |||||
np.transpose(tensor_dict.get("processed_target_tensor").asnumpy(), (1, 2, 0))) | |||||
adversarial.test_target_attack() |
@@ -0,0 +1,154 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""optimization Settings""" | |||||
import mindspore | |||||
from mindspore import ops, nn, Tensor | |||||
from mindspore.dataset.vision.py_transforms import ToTensor | |||||
import mindspore.dataset.vision.py_transforms as P | |||||
class MyTrainOneStepCell(nn.TrainOneStepCell): | |||||
""" | |||||
Encapsulation class of network training. | |||||
Append an optimizer to the training network after that the construct | |||||
function can be called to create the backward graph. | |||||
Args: | |||||
network (Cell): The training network. Note that loss function should have been added. | |||||
optimizer (Optimizer): Optimizer for updating the weights. | |||||
sens (Number): The adjust parameter. Default: 1.0. | |||||
""" | |||||
def __init__(self, network, optimizer, sens=1.0): | |||||
super(MyTrainOneStepCell, self).__init__(network, optimizer, sens) | |||||
self.grad = ops.composite.GradOperation(get_all=True, sens_param=False) | |||||
def construct(self, *inputs): | |||||
"""Defines the computation performed.""" | |||||
loss = self.network(*inputs) | |||||
grads = self.grad(self.network)(*inputs) | |||||
self.optimizer(grads) | |||||
return loss | |||||
class MyWithLossCellTargetAttack(nn.Cell): | |||||
"""The loss function defined by the target attack""" | |||||
def __init__(self, net, loss_fn, input_tensor): | |||||
super(MyWithLossCellTargetAttack, self).__init__(auto_prefix=False) | |||||
self.net = net | |||||
self._loss_fn = loss_fn | |||||
self.std = Tensor([0.229, 0.224, 0.225]) | |||||
self.mean = Tensor([0.485, 0.456, 0.406]) | |||||
self.expand_dims = mindspore.ops.ExpandDims() | |||||
self.normalize = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |||||
self.tensorize = ToTensor() | |||||
self.input_tensor = input_tensor | |||||
self.input_emb = self.net(self.expand_dims(self.input_tensor, 0)) | |||||
@property | |||||
def backbone_network(self): | |||||
return self.net | |||||
def construct(self, mask_tensor): | |||||
ref = mask_tensor | |||||
adversarial_tensor = mindspore.numpy.where( | |||||
(ref == 0), | |||||
self.input_tensor, | |||||
(mask_tensor - self.mean[:, None, None]) / self.std[:, None, None]) | |||||
adversarial_emb = self.net(self.expand_dims(adversarial_tensor, 0)) | |||||
loss = self._loss_fn(adversarial_emb) | |||||
return loss | |||||
class MyWithLossCellNonTargetAttack(nn.Cell): | |||||
"""The loss function defined by the non target attack""" | |||||
def __init__(self, net, loss_fn, input_tensor): | |||||
super(MyWithLossCellNonTargetAttack, self).__init__(auto_prefix=False) | |||||
self.net = net | |||||
self._loss_fn = loss_fn | |||||
self.std = Tensor([0.229, 0.224, 0.225]) | |||||
self.mean = Tensor([0.485, 0.456, 0.406]) | |||||
self.expand_dims = mindspore.ops.ExpandDims() | |||||
self.normalize = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |||||
self.tensorize = ToTensor() | |||||
self.input_tensor = input_tensor | |||||
self.input_emb = self.net(self.expand_dims(self.input_tensor, 0)) | |||||
@property | |||||
def backbone_network(self): | |||||
return self.net | |||||
def construct(self, mask_tensor): | |||||
ref = mask_tensor | |||||
adversarial_tensor = mindspore.numpy.where( | |||||
(ref == 0), | |||||
self.input_tensor, | |||||
(mask_tensor - self.mean[:, None, None]) / self.std[:, None, None]) | |||||
adversarial_emb = self.net(self.expand_dims(adversarial_tensor, 0)) | |||||
loss = self._loss_fn(adversarial_emb, self.input_emb) | |||||
return loss | |||||
class FaceLossTargetAttack(nn.Cell): | |||||
"""The loss function of the target attack""" | |||||
def __init__(self, target_emb): | |||||
super(FaceLossTargetAttack, self).__init__() | |||||
self.uniformreal = ops.UniformReal(seed=2) | |||||
self.sum = ops.ReduceSum(keep_dims=False) | |||||
self.norm = nn.Norm(keep_dims=True) | |||||
self.zeroslike = ops.ZerosLike() | |||||
self.concat_op1 = ops.Concat(1) | |||||
self.concat_op2 = ops.Concat(2) | |||||
self.pow = ops.Pow() | |||||
self.reduce_sum = ops.operations.ReduceSum() | |||||
self.target_emb = target_emb | |||||
self.abs = ops.Abs() | |||||
self.reduce_mean = ops.ReduceMean() | |||||
def construct(self, adversarial_emb): | |||||
prod_sum = self.reduce_sum(adversarial_emb * self.target_emb, (1,)) | |||||
square1 = self.reduce_sum(ops.functional.square(adversarial_emb), (1,)) | |||||
square2 = self.reduce_sum(ops.functional.square(self.target_emb), (1,)) | |||||
denom = ops.functional.sqrt(square1) * ops.functional.sqrt(square2) | |||||
loss = -(prod_sum / denom) | |||||
return loss | |||||
class FaceLossNoTargetAttack(nn.Cell): | |||||
"""The loss function of the non-target attack""" | |||||
def __init__(self): | |||||
"""Initialization""" | |||||
super(FaceLossNoTargetAttack, self).__init__() | |||||
self.uniformreal = ops.UniformReal(seed=2) | |||||
self.sum = ops.ReduceSum(keep_dims=False) | |||||
self.norm = nn.Norm(keep_dims=True) | |||||
self.zeroslike = ops.ZerosLike() | |||||
self.concat_op1 = ops.Concat(1) | |||||
self.concat_op2 = ops.Concat(2) | |||||
self.pow = ops.Pow() | |||||
self.reduce_sum = ops.operations.ReduceSum() | |||||
self.abs = ops.Abs() | |||||
self.reduce_mean = ops.ReduceMean() | |||||
def construct(self, adversarial_emb, input_emb): | |||||
prod_sum = self.reduce_sum(adversarial_emb * input_emb, (1,)) | |||||
square1 = self.reduce_sum(ops.functional.square(adversarial_emb), (1,)) | |||||
square2 = self.reduce_sum(ops.functional.square(input_emb), (1,)) | |||||
denom = ops.functional.sqrt(square1) * ops.functional.sqrt(square2) | |||||
loss = prod_sum / denom | |||||
return loss |
@@ -0,0 +1,59 @@ | |||||
# Copyright 2022 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
"""test""" | |||||
import numpy as np | |||||
from mindspore import context, Tensor | |||||
import mindspore | |||||
from mindspore.dataset.vision.py_transforms import ToTensor | |||||
import mindspore.dataset.vision.py_transforms as P | |||||
from FaceRecognition.eval import get_model | |||||
import adversarial_attack | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
if __name__ == '__main__': | |||||
image = adversarial_attack.load_data('photos/adv_input/') | |||||
inputs = adversarial_attack.load_data('photos/input/') | |||||
targets = adversarial_attack.load_data('photos/target/') | |||||
tensorize = ToTensor() | |||||
normalize = P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |||||
expand_dims = mindspore.ops.ExpandDims() | |||||
mean = Tensor([0.485, 0.456, 0.406]) | |||||
std = Tensor([0.229, 0.224, 0.225]) | |||||
resnet = get_model() | |||||
adv = Tensor(normalize(tensorize(image[0]))) | |||||
input_tensor = Tensor(normalize(tensorize(inputs[0]))) | |||||
target_tensor = Tensor(normalize(tensorize(targets[0]))) | |||||
adversarial_emb = resnet(expand_dims(adv, 0)) | |||||
input_emb = resnet(expand_dims(input_tensor, 0)) | |||||
target_emb = resnet(expand_dims(target_tensor, 0)) | |||||
adversarial_index = np.argmax(adversarial_emb.asnumpy()) | |||||
target_index = np.argmax(target_emb.asnumpy()) | |||||
input_index = np.argmax(input_emb.asnumpy()) | |||||
print("input_label:", input_index) | |||||
print("The confidence of the input image on the input label:", input_emb.asnumpy()[0][input_index]) | |||||
print("================================") | |||||
print("adversarial_label:", adversarial_index) | |||||
print("The confidence of the adversarial sample on the correct label:", adversarial_emb.asnumpy()[0][input_index]) | |||||
print("The confidence of the adversarial sample on the adversarial label:", | |||||
adversarial_emb.asnumpy()[0][adversarial_index]) | |||||
print("input_label:%d, adversarial_label:%d" % (input_index, adversarial_index)) |